diff --git a/src/main/java/cn/yinlihupo/controller/risk/RiskController.java b/src/main/java/cn/yinlihupo/controller/risk/RiskController.java index d9063eb..7685f81 100644 --- a/src/main/java/cn/yinlihupo/controller/risk/RiskController.java +++ b/src/main/java/cn/yinlihupo/controller/risk/RiskController.java @@ -1,20 +1,27 @@ package cn.yinlihupo.controller.risk; import cn.yinlihupo.common.core.BaseResponse; +import cn.yinlihupo.common.enums.AsyncTaskStatus; import cn.yinlihupo.common.page.TableDataInfo; +import cn.yinlihupo.common.sse.SseChannelManager; +import cn.yinlihupo.common.sse.SseMessage; import cn.yinlihupo.common.util.ResultUtils; import cn.yinlihupo.common.util.SecurityUtils; import cn.yinlihupo.domain.dto.CreateRiskRequest; import cn.yinlihupo.domain.dto.CreateWorkOrderRequest; import cn.yinlihupo.domain.vo.RiskAssessmentResult; +import cn.yinlihupo.domain.vo.RiskAssessmentTaskVO; import cn.yinlihupo.domain.vo.RiskStatisticsVO; import cn.yinlihupo.domain.vo.RiskVO; +import cn.yinlihupo.service.risk.RiskAssessmentAsyncService; import cn.yinlihupo.service.risk.RiskService; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.web.bind.annotation.*; +import java.util.HashMap; import java.util.List; +import java.util.Map; /** * 风险管理控制器 @@ -27,6 +34,13 @@ import java.util.List; public class RiskController { private final RiskService riskService; + private final RiskAssessmentAsyncService riskAssessmentAsyncService; + private final SseChannelManager sseChannelManager; + + /** + * SSE 消息类型常量 + */ + private static final String MESSAGE_TYPE = "risk-assess"; /** * 创建风险评估 @@ -201,7 +215,7 @@ public class RiskController { // ==================== AI 风险评估接口 ==================== /** - * AI风险评估 + * AI风险评估(同步接口,保留兼容) * 使用AI能力对项目整体的进度、人员、资金等会影响项目开展的所有因素进行风险评估 * 评估完成后自动将识别的风险入库 * @@ -210,7 +224,7 @@ public class RiskController { */ @PostMapping("/assess/{projectId}") public BaseResponse assessProjectRisk(@PathVariable Long projectId) { - log.info("AI风险评估, projectId: {}", projectId); + log.info("AI风险评估(同步), projectId: {}", projectId); try { RiskAssessmentResult result = riskService.assessProjectRisk(projectId); @@ -220,4 +234,161 @@ public class RiskController { return ResultUtils.error("风险评估失败: " + e.getMessage()); } } + + // ==================== SSE 异步风险评估接口 ==================== + + /** + * 通过 SSE 提交异步风险评估任务 + * + * @param projectId 项目ID + * @return 提交结果 + */ + @PostMapping("/sse/assess/{projectId}") + public BaseResponse> submitAssessmentTaskWithSse(@PathVariable Long projectId) { + Long userId = SecurityUtils.getCurrentUserId(); + log.info("用户通过SSE提交风险评估任务, userId: {}, projectId: {}", userId, projectId); + + if (userId == null) { + return ResultUtils.error("用户未登录"); + } + + // 检查用户是否在线 + String userIdStr = String.valueOf(userId); + if (!sseChannelManager.isOnline(userIdStr)) { + return ResultUtils.error("用户未建立SSE连接,请先调用 /api/v1/sse/connect/" + userId); + } + + try { + // 提交异步任务 + String taskId = riskAssessmentAsyncService.submitAssessmentTask(projectId, userId); + + // 推送任务提交成功事件 + Map submittedData = new HashMap<>(); + submittedData.put("taskId", taskId); + submittedData.put("projectId", projectId); + submittedData.put("message", "风险评估任务已提交"); + + SseMessage submittedMessage = SseMessage.of(MESSAGE_TYPE, "submitted", userIdStr, submittedData); + sseChannelManager.send(userIdStr, submittedMessage); + + log.info("风险评估任务提交成功, userId: {}, projectId: {}, taskId: {}", userId, projectId, taskId); + + Map result = new HashMap<>(); + result.put("taskId", taskId); + result.put("projectId", projectId); + result.put("message", "任务已提交,进度将通过SSE推送"); + return ResultUtils.success("提交成功", result); + + } catch (Exception e) { + log.error("提交风险评估任务失败, userId: {}, error: {}", userId, e.getMessage(), e); + + // 推送错误事件 + Map errorData = new HashMap<>(); + errorData.put("projectId", projectId); + errorData.put("error", e.getMessage()); + SseMessage errorMessage = SseMessage.of(MESSAGE_TYPE, "error", userIdStr, errorData); + sseChannelManager.send(userIdStr, errorMessage); + + return ResultUtils.error("提交失败: " + e.getMessage()); + } + } + + /** + * 查询我的风险评估任务列表 + * + * @return 任务列表 + */ + @GetMapping("/sse/my-tasks") + public BaseResponse> getMyTasks() { + Long userId = SecurityUtils.getCurrentUserId(); + if (userId == null) { + return ResultUtils.error("用户未登录"); + } + + List tasks = riskAssessmentAsyncService.getTasksByUserId(userId); + return ResultUtils.success(tasks); + } + + /** + * 查询我的风险评估任务统计信息 + * + * @return 统计信息 + */ + @GetMapping("/sse/my-tasks/stats") + public BaseResponse> getMyTaskStats() { + Long userId = SecurityUtils.getCurrentUserId(); + if (userId == null) { + return ResultUtils.error("用户未登录"); + } + + List tasks = riskAssessmentAsyncService.getTasksByUserId(userId); + int processingCount = riskAssessmentAsyncService.getProcessingTaskCount(userId); + + Map stats = new HashMap<>(); + stats.put("total", tasks.size()); + stats.put("processing", processingCount); + stats.put("completed", (int) tasks.stream() + .filter(t -> AsyncTaskStatus.COMPLETED.getCode().equals(t.getStatus())).count()); + stats.put("failed", (int) tasks.stream() + .filter(t -> AsyncTaskStatus.FAILED.getCode().equals(t.getStatus())).count()); + + return ResultUtils.success(stats); + } + + /** + * 查询单个风险评估任务状态 + * + * @param taskId 任务ID + * @return 任务状态 + */ + @GetMapping("/sse/task/{taskId}") + public BaseResponse getTaskStatus(@PathVariable String taskId) { + Long userId = SecurityUtils.getCurrentUserId(); + if (userId == null) { + return ResultUtils.error("用户未登录"); + } + + RiskAssessmentTaskVO task = riskAssessmentAsyncService.getTaskStatus(taskId); + if (task == null) { + return ResultUtils.error("任务不存在"); + } + + // 校验任务归属 + if (!userId.equals(task.getUserId())) { + return ResultUtils.error("无权访问该任务"); + } + + return ResultUtils.success(task); + } + + /** + * 获取风险评估任务结果 + * + * @param taskId 任务ID + * @return 风险评估结果 + */ + @GetMapping("/sse/task/{taskId}/result") + public BaseResponse getTaskResult(@PathVariable String taskId) { + Long userId = SecurityUtils.getCurrentUserId(); + if (userId == null) { + return ResultUtils.error("用户未登录"); + } + + RiskAssessmentTaskVO task = riskAssessmentAsyncService.getTaskStatus(taskId); + if (task == null) { + return ResultUtils.error("任务不存在"); + } + + // 校验任务归属 + if (!userId.equals(task.getUserId())) { + return ResultUtils.error("无权访问该任务"); + } + + if (!AsyncTaskStatus.COMPLETED.getCode().equals(task.getStatus())) { + return ResultUtils.error("任务尚未完成"); + } + + RiskAssessmentResult result = riskAssessmentAsyncService.getTaskResult(taskId); + return ResultUtils.success(result); + } } diff --git a/src/main/java/cn/yinlihupo/domain/vo/RiskAssessmentTaskVO.java b/src/main/java/cn/yinlihupo/domain/vo/RiskAssessmentTaskVO.java new file mode 100644 index 0000000..b961940 --- /dev/null +++ b/src/main/java/cn/yinlihupo/domain/vo/RiskAssessmentTaskVO.java @@ -0,0 +1,77 @@ +package cn.yinlihupo.domain.vo; + +import lombok.Data; + +import java.time.LocalDateTime; + +/** + * 风险评估异步任务VO + */ +@Data +public class RiskAssessmentTaskVO { + + /** + * 任务ID + */ + private String taskId; + + /** + * 用户ID(任务所属用户) + */ + private Long userId; + + /** + * 项目ID + */ + private Long projectId; + + /** + * 项目名称 + */ + private String projectName; + + /** + * 任务状态: pending-待处理, processing-处理中, completed-已完成, failed-失败 + */ + private String status; + + /** + * 状态描述 + */ + private String statusDesc; + + /** + * 当前进度百分比 (0-100) + */ + private Integer progress; + + /** + * 进度描述信息 + */ + private String progressMessage; + + /** + * 任务创建时间 + */ + private LocalDateTime createTime; + + /** + * 任务开始处理时间 + */ + private LocalDateTime startTime; + + /** + * 任务完成时间 + */ + private LocalDateTime completeTime; + + /** + * 处理结果(仅当status=completed时有值) + */ + private RiskAssessmentResult result; + + /** + * 错误信息(仅当status=failed时有值) + */ + private String errorMessage; +} diff --git a/src/main/java/cn/yinlihupo/service/risk/RiskAssessmentAsyncService.java b/src/main/java/cn/yinlihupo/service/risk/RiskAssessmentAsyncService.java new file mode 100644 index 0000000..18dff4a --- /dev/null +++ b/src/main/java/cn/yinlihupo/service/risk/RiskAssessmentAsyncService.java @@ -0,0 +1,61 @@ +package cn.yinlihupo.service.risk; + +import cn.yinlihupo.domain.vo.RiskAssessmentResult; +import cn.yinlihupo.domain.vo.RiskAssessmentTaskVO; + +import java.util.List; + +/** + * 风险评估异步任务服务接口 + */ +public interface RiskAssessmentAsyncService { + + /** + * 提交异步风险评估任务 + * + * @param projectId 项目ID + * @param userId 用户ID + * @return 任务ID + */ + String submitAssessmentTask(Long projectId, Long userId); + + /** + * 获取指定用户的所有风险评估任务 + * + * @param userId 用户ID + * @return 任务列表(按创建时间倒序) + */ + List getTasksByUserId(Long userId); + + /** + * 获取任务状态 + * + * @param taskId 任务ID + * @return 任务状态VO + */ + RiskAssessmentTaskVO getTaskStatus(String taskId); + + /** + * 获取任务结果 + * + * @param taskId 任务ID + * @return 风险评估结果 + */ + RiskAssessmentResult getTaskResult(String taskId); + + /** + * 取消任务 + * + * @param taskId 任务ID + * @return 是否取消成功 + */ + boolean cancelTask(String taskId); + + /** + * 获取指定用户的正在进行的任务数量 + * + * @param userId 用户ID + * @return 正在进行的任务数量 + */ + int getProcessingTaskCount(Long userId); +} diff --git a/src/main/java/cn/yinlihupo/service/risk/impl/RiskAssessmentAsyncServiceImpl.java b/src/main/java/cn/yinlihupo/service/risk/impl/RiskAssessmentAsyncServiceImpl.java new file mode 100644 index 0000000..21dc3a4 --- /dev/null +++ b/src/main/java/cn/yinlihupo/service/risk/impl/RiskAssessmentAsyncServiceImpl.java @@ -0,0 +1,317 @@ +package cn.yinlihupo.service.risk.impl; + +import cn.hutool.core.util.IdUtil; +import cn.yinlihupo.common.enums.AsyncTaskStatus; +import cn.yinlihupo.common.sse.SseChannelManager; +import cn.yinlihupo.common.sse.SseMessage; +import cn.yinlihupo.common.util.RedisService; +import cn.yinlihupo.domain.vo.RiskAssessmentResult; +import cn.yinlihupo.domain.vo.RiskAssessmentTaskVO; +import cn.yinlihupo.service.risk.RiskAssessmentAsyncService; +import cn.yinlihupo.service.risk.RiskService; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.scheduling.annotation.Async; +import org.springframework.stereotype.Service; + +import java.time.Duration; +import java.time.LocalDateTime; +import java.util.*; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; + +/** + * 风险评估异步任务服务实现类 + * 使用 Redis 存储任务状态,支持分布式环境 + */ +@Slf4j +@Service +@RequiredArgsConstructor +public class RiskAssessmentAsyncServiceImpl implements RiskAssessmentAsyncService { + + private final RiskService riskService; + private final RedisService redisService; + private final SseChannelManager sseChannelManager; + + /** + * 任务存储 key 前缀 + */ + private static final String TASK_KEY_PREFIX = "risk:assess:task:"; + + /** + * 用户任务列表 key 前缀 + */ + private static final String USER_TASKS_KEY_PREFIX = "risk:assess:user:"; + + /** + * SSE 消息类型 + */ + private static final String MESSAGE_TYPE = "risk-assess"; + + /** + * 任务默认过期时间:24小时 + */ + private static final Duration TASK_EXPIRE_DURATION = Duration.ofHours(24); + + @Override + public String submitAssessmentTask(Long projectId, Long userId) { + // 生成任务ID + String taskId = IdUtil.fastSimpleUUID(); + + log.info("提交风险评估任务, taskId: {}, projectId: {}, userId: {}", taskId, projectId, userId); + + // 创建任务记录 + RiskAssessmentTaskVO taskVO = new RiskAssessmentTaskVO(); + taskVO.setTaskId(taskId); + taskVO.setUserId(userId); + taskVO.setProjectId(projectId); + taskVO.setStatus(AsyncTaskStatus.PENDING.getCode()); + taskVO.setStatusDesc(AsyncTaskStatus.PENDING.getDescription()); + taskVO.setProgress(0); + taskVO.setProgressMessage("任务已提交,等待处理..."); + taskVO.setCreateTime(LocalDateTime.now()); + + // 存储任务到 Redis + String taskKey = getTaskKey(taskId); + redisService.set(taskKey, taskVO, TASK_EXPIRE_DURATION); + + // 将任务ID添加到用户的任务列表 + if (userId != null) { + String userTasksKey = getUserTasksKey(userId); + redisService.sAdd(userTasksKey, taskId); + redisService.expire(userTasksKey, TASK_EXPIRE_DURATION); + } + + // 异步执行任务 + executeAssessmentTaskAsync(taskId, projectId); + + return taskId; + } + + /** + * 异步执行风险评估任务 + */ + @Async("projectInitTaskExecutor") + public CompletableFuture executeAssessmentTaskAsync(String taskId, Long projectId) { + RiskAssessmentTaskVO taskVO = getTaskFromRedis(taskId); + if (taskVO == null) { + log.error("任务不存在, taskId: {}", taskId); + return CompletableFuture.completedFuture(null); + } + + try { + // 更新状态为处理中 + updateTaskProgress(taskId, AsyncTaskStatus.PROCESSING, 10, "正在收集项目数据..."); + + // 调用风险评估服务 + updateTaskProgress(taskId, AsyncTaskStatus.PROCESSING, 30, "项目数据收集完成,AI正在分析..."); + RiskAssessmentResult result = riskService.assessProjectRisk(projectId); + + // 更新任务完成状态 + taskVO = getTaskFromRedis(taskId); + if (taskVO != null) { + taskVO.setResult(result); + taskVO.setProjectName(result.getProjectName()); + taskVO.setCompleteTime(LocalDateTime.now()); + saveTaskToRedis(taskVO); + } + + updateTaskProgress(taskId, AsyncTaskStatus.COMPLETED, 100, "风险评估完成"); + + log.info("风险评估任务完成, taskId: {}, projectId: {}", taskId, projectId); + + } catch (Exception e) { + log.error("风险评估任务失败, taskId: {}, error: {}", taskId, e.getMessage(), e); + taskVO = getTaskFromRedis(taskId); + if (taskVO != null) { + taskVO.setErrorMessage(e.getMessage()); + taskVO.setCompleteTime(LocalDateTime.now()); + saveTaskToRedis(taskVO); + } + updateTaskProgress(taskId, AsyncTaskStatus.FAILED, 0, "任务执行失败: " + e.getMessage()); + } + + return CompletableFuture.completedFuture(null); + } + + /** + * 更新任务进度 + */ + private void updateTaskProgress(String taskId, AsyncTaskStatus status, int progress, String message) { + RiskAssessmentTaskVO taskVO = getTaskFromRedis(taskId); + if (taskVO == null) { + return; + } + + taskVO.setStatus(status.getCode()); + taskVO.setStatusDesc(status.getDescription()); + taskVO.setProgress(progress); + taskVO.setProgressMessage(message); + + if (status == AsyncTaskStatus.PROCESSING && taskVO.getStartTime() == null) { + taskVO.setStartTime(LocalDateTime.now()); + } + + // 更新 Redis + saveTaskToRedis(taskVO); + + // 通过 SSE 推送进度给用户 + pushProgressToUser(taskVO, status); + + log.debug("任务进度更新, taskId: {}, status: {}, progress: {}%, message: {}", + taskId, status.getCode(), progress, message); + } + + /** + * 通过 SSE 推送进度给用户 + */ + private void pushProgressToUser(RiskAssessmentTaskVO taskVO, AsyncTaskStatus status) { + if (taskVO.getUserId() == null) { + return; + } + + String userId = String.valueOf(taskVO.getUserId()); + + // 判断任务是否结束 + boolean isFinished = status == AsyncTaskStatus.COMPLETED || + status == AsyncTaskStatus.FAILED || + status == AsyncTaskStatus.CANCELLED; + + // 推送进度消息 + SseMessage message = SseMessage.of(MESSAGE_TYPE, "progress", userId, taskVO); + sseChannelManager.send(userId, message); + + // 任务结束时推送完成消息 + if (isFinished) { + SseMessage completeMessage = SseMessage.of(MESSAGE_TYPE, "complete", userId, taskVO); + sseChannelManager.send(userId, completeMessage); + } + } + + @Override + public RiskAssessmentTaskVO getTaskStatus(String taskId) { + RiskAssessmentTaskVO taskVO = getTaskFromRedis(taskId); + if (taskVO == null) { + return null; + } + return copyTaskVO(taskVO); + } + + @Override + public RiskAssessmentResult getTaskResult(String taskId) { + RiskAssessmentTaskVO taskVO = getTaskFromRedis(taskId); + if (taskVO == null || !AsyncTaskStatus.COMPLETED.getCode().equals(taskVO.getStatus())) { + return null; + } + return taskVO.getResult(); + } + + @Override + public boolean cancelTask(String taskId) { + RiskAssessmentTaskVO taskVO = getTaskFromRedis(taskId); + if (taskVO == null) { + return false; + } + + // 只能取消待处理或处理中的任务 + if (AsyncTaskStatus.PENDING.getCode().equals(taskVO.getStatus()) || + AsyncTaskStatus.PROCESSING.getCode().equals(taskVO.getStatus())) { + updateTaskProgress(taskId, AsyncTaskStatus.CANCELLED, 0, "任务已取消"); + taskVO.setCompleteTime(LocalDateTime.now()); + saveTaskToRedis(taskVO); + log.info("风险评估任务已取消, taskId: {}", taskId); + return true; + } + + return false; + } + + @Override + public List getTasksByUserId(Long userId) { + if (userId == null) { + return new ArrayList<>(); + } + + String userTasksKey = getUserTasksKey(userId); + Set taskIds = redisService.sMembers(userTasksKey); + + if (taskIds == null || taskIds.isEmpty()) { + return new ArrayList<>(); + } + + return taskIds.stream() + .map(this::getTaskFromRedis) + .filter(Objects::nonNull) + .sorted(Comparator.comparing(RiskAssessmentTaskVO::getCreateTime).reversed()) + .map(this::copyTaskVO) + .collect(Collectors.toList()); + } + + @Override + public int getProcessingTaskCount(Long userId) { + if (userId == null) { + return 0; + } + + List tasks = getTasksByUserId(userId); + return (int) tasks.stream() + .filter(task -> AsyncTaskStatus.PENDING.getCode().equals(task.getStatus()) + || AsyncTaskStatus.PROCESSING.getCode().equals(task.getStatus())) + .count(); + } + + // ==================== Redis 操作方法 ==================== + + /** + * 从 Redis 获取任务 + */ + private RiskAssessmentTaskVO getTaskFromRedis(String taskId) { + String key = getTaskKey(taskId); + return redisService.get(key); + } + + /** + * 保存任务到 Redis + */ + private void saveTaskToRedis(RiskAssessmentTaskVO taskVO) { + String key = getTaskKey(taskVO.getTaskId()); + redisService.set(key, taskVO, TASK_EXPIRE_DURATION); + } + + /** + * 获取任务存储 key + */ + private String getTaskKey(String taskId) { + return TASK_KEY_PREFIX + taskId; + } + + /** + * 获取用户任务列表 key + */ + private String getUserTasksKey(Long userId) { + return USER_TASKS_KEY_PREFIX + userId + ":tasks"; + } + + // ==================== 工具方法 ==================== + + /** + * 复制任务VO + */ + private RiskAssessmentTaskVO copyTaskVO(RiskAssessmentTaskVO source) { + RiskAssessmentTaskVO copy = new RiskAssessmentTaskVO(); + copy.setTaskId(source.getTaskId()); + copy.setUserId(source.getUserId()); + copy.setProjectId(source.getProjectId()); + copy.setProjectName(source.getProjectName()); + copy.setStatus(source.getStatus()); + copy.setStatusDesc(source.getStatusDesc()); + copy.setProgress(source.getProgress()); + copy.setProgressMessage(source.getProgressMessage()); + copy.setCreateTime(source.getCreateTime()); + copy.setStartTime(source.getStartTime()); + copy.setCompleteTime(source.getCompleteTime()); + copy.setResult(source.getResult()); + copy.setErrorMessage(source.getErrorMessage()); + return copy; + } +}