feat(risk): 支持基于SSE的异步风险评估任务管理

- 新增RiskAssessmentTaskVO表示风险评估异步任务
- 添加RiskAssessmentAsyncService接口及实现类,实现任务的提交、状态查询和取消
- 实现基于Redis的任务状态存储和管理,支持分布式环境
- 在RiskController中新增异步任务提交及查询相关接口,支持任务列表、统计、状态和结果查询
- 通过SseChannelManager实时推送任务状态和进度给用户
- 保留原有同步风险评估接口,新增日志区分同步与异步调用
- 增加用户登录及SSE连接状态校验,提升异步任务交互的安全性与可靠性
This commit is contained in:
2026-03-30 15:16:59 +08:00
parent 4d20bf21cc
commit e7a21ba665
4 changed files with 628 additions and 2 deletions

View File

@@ -0,0 +1,61 @@
package cn.yinlihupo.service.risk;
import cn.yinlihupo.domain.vo.RiskAssessmentResult;
import cn.yinlihupo.domain.vo.RiskAssessmentTaskVO;
import java.util.List;
/**
* 风险评估异步任务服务接口
*/
public interface RiskAssessmentAsyncService {
/**
* 提交异步风险评估任务
*
* @param projectId 项目ID
* @param userId 用户ID
* @return 任务ID
*/
String submitAssessmentTask(Long projectId, Long userId);
/**
* 获取指定用户的所有风险评估任务
*
* @param userId 用户ID
* @return 任务列表(按创建时间倒序)
*/
List<RiskAssessmentTaskVO> getTasksByUserId(Long userId);
/**
* 获取任务状态
*
* @param taskId 任务ID
* @return 任务状态VO
*/
RiskAssessmentTaskVO getTaskStatus(String taskId);
/**
* 获取任务结果
*
* @param taskId 任务ID
* @return 风险评估结果
*/
RiskAssessmentResult getTaskResult(String taskId);
/**
* 取消任务
*
* @param taskId 任务ID
* @return 是否取消成功
*/
boolean cancelTask(String taskId);
/**
* 获取指定用户的正在进行的任务数量
*
* @param userId 用户ID
* @return 正在进行的任务数量
*/
int getProcessingTaskCount(Long userId);
}

View File

@@ -0,0 +1,317 @@
package cn.yinlihupo.service.risk.impl;
import cn.hutool.core.util.IdUtil;
import cn.yinlihupo.common.enums.AsyncTaskStatus;
import cn.yinlihupo.common.sse.SseChannelManager;
import cn.yinlihupo.common.sse.SseMessage;
import cn.yinlihupo.common.util.RedisService;
import cn.yinlihupo.domain.vo.RiskAssessmentResult;
import cn.yinlihupo.domain.vo.RiskAssessmentTaskVO;
import cn.yinlihupo.service.risk.RiskAssessmentAsyncService;
import cn.yinlihupo.service.risk.RiskService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import java.time.Duration;
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
/**
* 风险评估异步任务服务实现类
* 使用 Redis 存储任务状态,支持分布式环境
*/
@Slf4j
@Service
@RequiredArgsConstructor
public class RiskAssessmentAsyncServiceImpl implements RiskAssessmentAsyncService {
private final RiskService riskService;
private final RedisService redisService;
private final SseChannelManager sseChannelManager;
/**
* 任务存储 key 前缀
*/
private static final String TASK_KEY_PREFIX = "risk:assess:task:";
/**
* 用户任务列表 key 前缀
*/
private static final String USER_TASKS_KEY_PREFIX = "risk:assess:user:";
/**
* SSE 消息类型
*/
private static final String MESSAGE_TYPE = "risk-assess";
/**
* 任务默认过期时间24小时
*/
private static final Duration TASK_EXPIRE_DURATION = Duration.ofHours(24);
@Override
public String submitAssessmentTask(Long projectId, Long userId) {
// 生成任务ID
String taskId = IdUtil.fastSimpleUUID();
log.info("提交风险评估任务, taskId: {}, projectId: {}, userId: {}", taskId, projectId, userId);
// 创建任务记录
RiskAssessmentTaskVO taskVO = new RiskAssessmentTaskVO();
taskVO.setTaskId(taskId);
taskVO.setUserId(userId);
taskVO.setProjectId(projectId);
taskVO.setStatus(AsyncTaskStatus.PENDING.getCode());
taskVO.setStatusDesc(AsyncTaskStatus.PENDING.getDescription());
taskVO.setProgress(0);
taskVO.setProgressMessage("任务已提交,等待处理...");
taskVO.setCreateTime(LocalDateTime.now());
// 存储任务到 Redis
String taskKey = getTaskKey(taskId);
redisService.set(taskKey, taskVO, TASK_EXPIRE_DURATION);
// 将任务ID添加到用户的任务列表
if (userId != null) {
String userTasksKey = getUserTasksKey(userId);
redisService.sAdd(userTasksKey, taskId);
redisService.expire(userTasksKey, TASK_EXPIRE_DURATION);
}
// 异步执行任务
executeAssessmentTaskAsync(taskId, projectId);
return taskId;
}
/**
* 异步执行风险评估任务
*/
@Async("projectInitTaskExecutor")
public CompletableFuture<Void> executeAssessmentTaskAsync(String taskId, Long projectId) {
RiskAssessmentTaskVO taskVO = getTaskFromRedis(taskId);
if (taskVO == null) {
log.error("任务不存在, taskId: {}", taskId);
return CompletableFuture.completedFuture(null);
}
try {
// 更新状态为处理中
updateTaskProgress(taskId, AsyncTaskStatus.PROCESSING, 10, "正在收集项目数据...");
// 调用风险评估服务
updateTaskProgress(taskId, AsyncTaskStatus.PROCESSING, 30, "项目数据收集完成AI正在分析...");
RiskAssessmentResult result = riskService.assessProjectRisk(projectId);
// 更新任务完成状态
taskVO = getTaskFromRedis(taskId);
if (taskVO != null) {
taskVO.setResult(result);
taskVO.setProjectName(result.getProjectName());
taskVO.setCompleteTime(LocalDateTime.now());
saveTaskToRedis(taskVO);
}
updateTaskProgress(taskId, AsyncTaskStatus.COMPLETED, 100, "风险评估完成");
log.info("风险评估任务完成, taskId: {}, projectId: {}", taskId, projectId);
} catch (Exception e) {
log.error("风险评估任务失败, taskId: {}, error: {}", taskId, e.getMessage(), e);
taskVO = getTaskFromRedis(taskId);
if (taskVO != null) {
taskVO.setErrorMessage(e.getMessage());
taskVO.setCompleteTime(LocalDateTime.now());
saveTaskToRedis(taskVO);
}
updateTaskProgress(taskId, AsyncTaskStatus.FAILED, 0, "任务执行失败: " + e.getMessage());
}
return CompletableFuture.completedFuture(null);
}
/**
* 更新任务进度
*/
private void updateTaskProgress(String taskId, AsyncTaskStatus status, int progress, String message) {
RiskAssessmentTaskVO taskVO = getTaskFromRedis(taskId);
if (taskVO == null) {
return;
}
taskVO.setStatus(status.getCode());
taskVO.setStatusDesc(status.getDescription());
taskVO.setProgress(progress);
taskVO.setProgressMessage(message);
if (status == AsyncTaskStatus.PROCESSING && taskVO.getStartTime() == null) {
taskVO.setStartTime(LocalDateTime.now());
}
// 更新 Redis
saveTaskToRedis(taskVO);
// 通过 SSE 推送进度给用户
pushProgressToUser(taskVO, status);
log.debug("任务进度更新, taskId: {}, status: {}, progress: {}%, message: {}",
taskId, status.getCode(), progress, message);
}
/**
* 通过 SSE 推送进度给用户
*/
private void pushProgressToUser(RiskAssessmentTaskVO taskVO, AsyncTaskStatus status) {
if (taskVO.getUserId() == null) {
return;
}
String userId = String.valueOf(taskVO.getUserId());
// 判断任务是否结束
boolean isFinished = status == AsyncTaskStatus.COMPLETED ||
status == AsyncTaskStatus.FAILED ||
status == AsyncTaskStatus.CANCELLED;
// 推送进度消息
SseMessage message = SseMessage.of(MESSAGE_TYPE, "progress", userId, taskVO);
sseChannelManager.send(userId, message);
// 任务结束时推送完成消息
if (isFinished) {
SseMessage completeMessage = SseMessage.of(MESSAGE_TYPE, "complete", userId, taskVO);
sseChannelManager.send(userId, completeMessage);
}
}
@Override
public RiskAssessmentTaskVO getTaskStatus(String taskId) {
RiskAssessmentTaskVO taskVO = getTaskFromRedis(taskId);
if (taskVO == null) {
return null;
}
return copyTaskVO(taskVO);
}
@Override
public RiskAssessmentResult getTaskResult(String taskId) {
RiskAssessmentTaskVO taskVO = getTaskFromRedis(taskId);
if (taskVO == null || !AsyncTaskStatus.COMPLETED.getCode().equals(taskVO.getStatus())) {
return null;
}
return taskVO.getResult();
}
@Override
public boolean cancelTask(String taskId) {
RiskAssessmentTaskVO taskVO = getTaskFromRedis(taskId);
if (taskVO == null) {
return false;
}
// 只能取消待处理或处理中的任务
if (AsyncTaskStatus.PENDING.getCode().equals(taskVO.getStatus()) ||
AsyncTaskStatus.PROCESSING.getCode().equals(taskVO.getStatus())) {
updateTaskProgress(taskId, AsyncTaskStatus.CANCELLED, 0, "任务已取消");
taskVO.setCompleteTime(LocalDateTime.now());
saveTaskToRedis(taskVO);
log.info("风险评估任务已取消, taskId: {}", taskId);
return true;
}
return false;
}
@Override
public List<RiskAssessmentTaskVO> getTasksByUserId(Long userId) {
if (userId == null) {
return new ArrayList<>();
}
String userTasksKey = getUserTasksKey(userId);
Set<String> taskIds = redisService.sMembers(userTasksKey);
if (taskIds == null || taskIds.isEmpty()) {
return new ArrayList<>();
}
return taskIds.stream()
.map(this::getTaskFromRedis)
.filter(Objects::nonNull)
.sorted(Comparator.comparing(RiskAssessmentTaskVO::getCreateTime).reversed())
.map(this::copyTaskVO)
.collect(Collectors.toList());
}
@Override
public int getProcessingTaskCount(Long userId) {
if (userId == null) {
return 0;
}
List<RiskAssessmentTaskVO> tasks = getTasksByUserId(userId);
return (int) tasks.stream()
.filter(task -> AsyncTaskStatus.PENDING.getCode().equals(task.getStatus())
|| AsyncTaskStatus.PROCESSING.getCode().equals(task.getStatus()))
.count();
}
// ==================== Redis 操作方法 ====================
/**
* 从 Redis 获取任务
*/
private RiskAssessmentTaskVO getTaskFromRedis(String taskId) {
String key = getTaskKey(taskId);
return redisService.get(key);
}
/**
* 保存任务到 Redis
*/
private void saveTaskToRedis(RiskAssessmentTaskVO taskVO) {
String key = getTaskKey(taskVO.getTaskId());
redisService.set(key, taskVO, TASK_EXPIRE_DURATION);
}
/**
* 获取任务存储 key
*/
private String getTaskKey(String taskId) {
return TASK_KEY_PREFIX + taskId;
}
/**
* 获取用户任务列表 key
*/
private String getUserTasksKey(Long userId) {
return USER_TASKS_KEY_PREFIX + userId + ":tasks";
}
// ==================== 工具方法 ====================
/**
* 复制任务VO
*/
private RiskAssessmentTaskVO copyTaskVO(RiskAssessmentTaskVO source) {
RiskAssessmentTaskVO copy = new RiskAssessmentTaskVO();
copy.setTaskId(source.getTaskId());
copy.setUserId(source.getUserId());
copy.setProjectId(source.getProjectId());
copy.setProjectName(source.getProjectName());
copy.setStatus(source.getStatus());
copy.setStatusDesc(source.getStatusDesc());
copy.setProgress(source.getProgress());
copy.setProgressMessage(source.getProgressMessage());
copy.setCreateTime(source.getCreateTime());
copy.setStartTime(source.getStartTime());
copy.setCompleteTime(source.getCompleteTime());
copy.setResult(source.getResult());
copy.setErrorMessage(source.getErrorMessage());
return copy;
}
}