feat(risk): 支持基于SSE的异步风险评估任务管理
- 新增RiskAssessmentTaskVO表示风险评估异步任务 - 添加RiskAssessmentAsyncService接口及实现类,实现任务的提交、状态查询和取消 - 实现基于Redis的任务状态存储和管理,支持分布式环境 - 在RiskController中新增异步任务提交及查询相关接口,支持任务列表、统计、状态和结果查询 - 通过SseChannelManager实时推送任务状态和进度给用户 - 保留原有同步风险评估接口,新增日志区分同步与异步调用 - 增加用户登录及SSE连接状态校验,提升异步任务交互的安全性与可靠性
This commit is contained in:
@@ -1,20 +1,27 @@
|
||||
package cn.yinlihupo.controller.risk;
|
||||
|
||||
import cn.yinlihupo.common.core.BaseResponse;
|
||||
import cn.yinlihupo.common.enums.AsyncTaskStatus;
|
||||
import cn.yinlihupo.common.page.TableDataInfo;
|
||||
import cn.yinlihupo.common.sse.SseChannelManager;
|
||||
import cn.yinlihupo.common.sse.SseMessage;
|
||||
import cn.yinlihupo.common.util.ResultUtils;
|
||||
import cn.yinlihupo.common.util.SecurityUtils;
|
||||
import cn.yinlihupo.domain.dto.CreateRiskRequest;
|
||||
import cn.yinlihupo.domain.dto.CreateWorkOrderRequest;
|
||||
import cn.yinlihupo.domain.vo.RiskAssessmentResult;
|
||||
import cn.yinlihupo.domain.vo.RiskAssessmentTaskVO;
|
||||
import cn.yinlihupo.domain.vo.RiskStatisticsVO;
|
||||
import cn.yinlihupo.domain.vo.RiskVO;
|
||||
import cn.yinlihupo.service.risk.RiskAssessmentAsyncService;
|
||||
import cn.yinlihupo.service.risk.RiskService;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* 风险管理控制器
|
||||
@@ -27,6 +34,13 @@ import java.util.List;
|
||||
public class RiskController {
|
||||
|
||||
private final RiskService riskService;
|
||||
private final RiskAssessmentAsyncService riskAssessmentAsyncService;
|
||||
private final SseChannelManager sseChannelManager;
|
||||
|
||||
/**
|
||||
* SSE 消息类型常量
|
||||
*/
|
||||
private static final String MESSAGE_TYPE = "risk-assess";
|
||||
|
||||
/**
|
||||
* 创建风险评估
|
||||
@@ -201,7 +215,7 @@ public class RiskController {
|
||||
// ==================== AI 风险评估接口 ====================
|
||||
|
||||
/**
|
||||
* AI风险评估
|
||||
* AI风险评估(同步接口,保留兼容)
|
||||
* 使用AI能力对项目整体的进度、人员、资金等会影响项目开展的所有因素进行风险评估
|
||||
* 评估完成后自动将识别的风险入库
|
||||
*
|
||||
@@ -210,7 +224,7 @@ public class RiskController {
|
||||
*/
|
||||
@PostMapping("/assess/{projectId}")
|
||||
public BaseResponse<RiskAssessmentResult> assessProjectRisk(@PathVariable Long projectId) {
|
||||
log.info("AI风险评估, projectId: {}", projectId);
|
||||
log.info("AI风险评估(同步), projectId: {}", projectId);
|
||||
|
||||
try {
|
||||
RiskAssessmentResult result = riskService.assessProjectRisk(projectId);
|
||||
@@ -220,4 +234,161 @@ public class RiskController {
|
||||
return ResultUtils.error("风险评估失败: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== SSE 异步风险评估接口 ====================
|
||||
|
||||
/**
|
||||
* 通过 SSE 提交异步风险评估任务
|
||||
*
|
||||
* @param projectId 项目ID
|
||||
* @return 提交结果
|
||||
*/
|
||||
@PostMapping("/sse/assess/{projectId}")
|
||||
public BaseResponse<Map<String, Object>> submitAssessmentTaskWithSse(@PathVariable Long projectId) {
|
||||
Long userId = SecurityUtils.getCurrentUserId();
|
||||
log.info("用户通过SSE提交风险评估任务, userId: {}, projectId: {}", userId, projectId);
|
||||
|
||||
if (userId == null) {
|
||||
return ResultUtils.error("用户未登录");
|
||||
}
|
||||
|
||||
// 检查用户是否在线
|
||||
String userIdStr = String.valueOf(userId);
|
||||
if (!sseChannelManager.isOnline(userIdStr)) {
|
||||
return ResultUtils.error("用户未建立SSE连接,请先调用 /api/v1/sse/connect/" + userId);
|
||||
}
|
||||
|
||||
try {
|
||||
// 提交异步任务
|
||||
String taskId = riskAssessmentAsyncService.submitAssessmentTask(projectId, userId);
|
||||
|
||||
// 推送任务提交成功事件
|
||||
Map<String, Object> submittedData = new HashMap<>();
|
||||
submittedData.put("taskId", taskId);
|
||||
submittedData.put("projectId", projectId);
|
||||
submittedData.put("message", "风险评估任务已提交");
|
||||
|
||||
SseMessage submittedMessage = SseMessage.of(MESSAGE_TYPE, "submitted", userIdStr, submittedData);
|
||||
sseChannelManager.send(userIdStr, submittedMessage);
|
||||
|
||||
log.info("风险评估任务提交成功, userId: {}, projectId: {}, taskId: {}", userId, projectId, taskId);
|
||||
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
result.put("taskId", taskId);
|
||||
result.put("projectId", projectId);
|
||||
result.put("message", "任务已提交,进度将通过SSE推送");
|
||||
return ResultUtils.success("提交成功", result);
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("提交风险评估任务失败, userId: {}, error: {}", userId, e.getMessage(), e);
|
||||
|
||||
// 推送错误事件
|
||||
Map<String, Object> errorData = new HashMap<>();
|
||||
errorData.put("projectId", projectId);
|
||||
errorData.put("error", e.getMessage());
|
||||
SseMessage errorMessage = SseMessage.of(MESSAGE_TYPE, "error", userIdStr, errorData);
|
||||
sseChannelManager.send(userIdStr, errorMessage);
|
||||
|
||||
return ResultUtils.error("提交失败: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询我的风险评估任务列表
|
||||
*
|
||||
* @return 任务列表
|
||||
*/
|
||||
@GetMapping("/sse/my-tasks")
|
||||
public BaseResponse<List<RiskAssessmentTaskVO>> getMyTasks() {
|
||||
Long userId = SecurityUtils.getCurrentUserId();
|
||||
if (userId == null) {
|
||||
return ResultUtils.error("用户未登录");
|
||||
}
|
||||
|
||||
List<RiskAssessmentTaskVO> tasks = riskAssessmentAsyncService.getTasksByUserId(userId);
|
||||
return ResultUtils.success(tasks);
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询我的风险评估任务统计信息
|
||||
*
|
||||
* @return 统计信息
|
||||
*/
|
||||
@GetMapping("/sse/my-tasks/stats")
|
||||
public BaseResponse<Map<String, Object>> getMyTaskStats() {
|
||||
Long userId = SecurityUtils.getCurrentUserId();
|
||||
if (userId == null) {
|
||||
return ResultUtils.error("用户未登录");
|
||||
}
|
||||
|
||||
List<RiskAssessmentTaskVO> tasks = riskAssessmentAsyncService.getTasksByUserId(userId);
|
||||
int processingCount = riskAssessmentAsyncService.getProcessingTaskCount(userId);
|
||||
|
||||
Map<String, Object> stats = new HashMap<>();
|
||||
stats.put("total", tasks.size());
|
||||
stats.put("processing", processingCount);
|
||||
stats.put("completed", (int) tasks.stream()
|
||||
.filter(t -> AsyncTaskStatus.COMPLETED.getCode().equals(t.getStatus())).count());
|
||||
stats.put("failed", (int) tasks.stream()
|
||||
.filter(t -> AsyncTaskStatus.FAILED.getCode().equals(t.getStatus())).count());
|
||||
|
||||
return ResultUtils.success(stats);
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询单个风险评估任务状态
|
||||
*
|
||||
* @param taskId 任务ID
|
||||
* @return 任务状态
|
||||
*/
|
||||
@GetMapping("/sse/task/{taskId}")
|
||||
public BaseResponse<RiskAssessmentTaskVO> getTaskStatus(@PathVariable String taskId) {
|
||||
Long userId = SecurityUtils.getCurrentUserId();
|
||||
if (userId == null) {
|
||||
return ResultUtils.error("用户未登录");
|
||||
}
|
||||
|
||||
RiskAssessmentTaskVO task = riskAssessmentAsyncService.getTaskStatus(taskId);
|
||||
if (task == null) {
|
||||
return ResultUtils.error("任务不存在");
|
||||
}
|
||||
|
||||
// 校验任务归属
|
||||
if (!userId.equals(task.getUserId())) {
|
||||
return ResultUtils.error("无权访问该任务");
|
||||
}
|
||||
|
||||
return ResultUtils.success(task);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取风险评估任务结果
|
||||
*
|
||||
* @param taskId 任务ID
|
||||
* @return 风险评估结果
|
||||
*/
|
||||
@GetMapping("/sse/task/{taskId}/result")
|
||||
public BaseResponse<RiskAssessmentResult> getTaskResult(@PathVariable String taskId) {
|
||||
Long userId = SecurityUtils.getCurrentUserId();
|
||||
if (userId == null) {
|
||||
return ResultUtils.error("用户未登录");
|
||||
}
|
||||
|
||||
RiskAssessmentTaskVO task = riskAssessmentAsyncService.getTaskStatus(taskId);
|
||||
if (task == null) {
|
||||
return ResultUtils.error("任务不存在");
|
||||
}
|
||||
|
||||
// 校验任务归属
|
||||
if (!userId.equals(task.getUserId())) {
|
||||
return ResultUtils.error("无权访问该任务");
|
||||
}
|
||||
|
||||
if (!AsyncTaskStatus.COMPLETED.getCode().equals(task.getStatus())) {
|
||||
return ResultUtils.error("任务尚未完成");
|
||||
}
|
||||
|
||||
RiskAssessmentResult result = riskAssessmentAsyncService.getTaskResult(taskId);
|
||||
return ResultUtils.success(result);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
package cn.yinlihupo.domain.vo;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
|
||||
/**
|
||||
* 风险评估异步任务VO
|
||||
*/
|
||||
@Data
|
||||
public class RiskAssessmentTaskVO {
|
||||
|
||||
/**
|
||||
* 任务ID
|
||||
*/
|
||||
private String taskId;
|
||||
|
||||
/**
|
||||
* 用户ID(任务所属用户)
|
||||
*/
|
||||
private Long userId;
|
||||
|
||||
/**
|
||||
* 项目ID
|
||||
*/
|
||||
private Long projectId;
|
||||
|
||||
/**
|
||||
* 项目名称
|
||||
*/
|
||||
private String projectName;
|
||||
|
||||
/**
|
||||
* 任务状态: pending-待处理, processing-处理中, completed-已完成, failed-失败
|
||||
*/
|
||||
private String status;
|
||||
|
||||
/**
|
||||
* 状态描述
|
||||
*/
|
||||
private String statusDesc;
|
||||
|
||||
/**
|
||||
* 当前进度百分比 (0-100)
|
||||
*/
|
||||
private Integer progress;
|
||||
|
||||
/**
|
||||
* 进度描述信息
|
||||
*/
|
||||
private String progressMessage;
|
||||
|
||||
/**
|
||||
* 任务创建时间
|
||||
*/
|
||||
private LocalDateTime createTime;
|
||||
|
||||
/**
|
||||
* 任务开始处理时间
|
||||
*/
|
||||
private LocalDateTime startTime;
|
||||
|
||||
/**
|
||||
* 任务完成时间
|
||||
*/
|
||||
private LocalDateTime completeTime;
|
||||
|
||||
/**
|
||||
* 处理结果(仅当status=completed时有值)
|
||||
*/
|
||||
private RiskAssessmentResult result;
|
||||
|
||||
/**
|
||||
* 错误信息(仅当status=failed时有值)
|
||||
*/
|
||||
private String errorMessage;
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package cn.yinlihupo.service.risk;
|
||||
|
||||
import cn.yinlihupo.domain.vo.RiskAssessmentResult;
|
||||
import cn.yinlihupo.domain.vo.RiskAssessmentTaskVO;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 风险评估异步任务服务接口
|
||||
*/
|
||||
public interface RiskAssessmentAsyncService {
|
||||
|
||||
/**
|
||||
* 提交异步风险评估任务
|
||||
*
|
||||
* @param projectId 项目ID
|
||||
* @param userId 用户ID
|
||||
* @return 任务ID
|
||||
*/
|
||||
String submitAssessmentTask(Long projectId, Long userId);
|
||||
|
||||
/**
|
||||
* 获取指定用户的所有风险评估任务
|
||||
*
|
||||
* @param userId 用户ID
|
||||
* @return 任务列表(按创建时间倒序)
|
||||
*/
|
||||
List<RiskAssessmentTaskVO> getTasksByUserId(Long userId);
|
||||
|
||||
/**
|
||||
* 获取任务状态
|
||||
*
|
||||
* @param taskId 任务ID
|
||||
* @return 任务状态VO
|
||||
*/
|
||||
RiskAssessmentTaskVO getTaskStatus(String taskId);
|
||||
|
||||
/**
|
||||
* 获取任务结果
|
||||
*
|
||||
* @param taskId 任务ID
|
||||
* @return 风险评估结果
|
||||
*/
|
||||
RiskAssessmentResult getTaskResult(String taskId);
|
||||
|
||||
/**
|
||||
* 取消任务
|
||||
*
|
||||
* @param taskId 任务ID
|
||||
* @return 是否取消成功
|
||||
*/
|
||||
boolean cancelTask(String taskId);
|
||||
|
||||
/**
|
||||
* 获取指定用户的正在进行的任务数量
|
||||
*
|
||||
* @param userId 用户ID
|
||||
* @return 正在进行的任务数量
|
||||
*/
|
||||
int getProcessingTaskCount(Long userId);
|
||||
}
|
||||
@@ -0,0 +1,317 @@
|
||||
package cn.yinlihupo.service.risk.impl;
|
||||
|
||||
import cn.hutool.core.util.IdUtil;
|
||||
import cn.yinlihupo.common.enums.AsyncTaskStatus;
|
||||
import cn.yinlihupo.common.sse.SseChannelManager;
|
||||
import cn.yinlihupo.common.sse.SseMessage;
|
||||
import cn.yinlihupo.common.util.RedisService;
|
||||
import cn.yinlihupo.domain.vo.RiskAssessmentResult;
|
||||
import cn.yinlihupo.domain.vo.RiskAssessmentTaskVO;
|
||||
import cn.yinlihupo.service.risk.RiskAssessmentAsyncService;
|
||||
import cn.yinlihupo.service.risk.RiskService;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* 风险评估异步任务服务实现类
|
||||
* 使用 Redis 存储任务状态,支持分布式环境
|
||||
*/
|
||||
@Slf4j
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class RiskAssessmentAsyncServiceImpl implements RiskAssessmentAsyncService {
|
||||
|
||||
private final RiskService riskService;
|
||||
private final RedisService redisService;
|
||||
private final SseChannelManager sseChannelManager;
|
||||
|
||||
/**
|
||||
* 任务存储 key 前缀
|
||||
*/
|
||||
private static final String TASK_KEY_PREFIX = "risk:assess:task:";
|
||||
|
||||
/**
|
||||
* 用户任务列表 key 前缀
|
||||
*/
|
||||
private static final String USER_TASKS_KEY_PREFIX = "risk:assess:user:";
|
||||
|
||||
/**
|
||||
* SSE 消息类型
|
||||
*/
|
||||
private static final String MESSAGE_TYPE = "risk-assess";
|
||||
|
||||
/**
|
||||
* 任务默认过期时间:24小时
|
||||
*/
|
||||
private static final Duration TASK_EXPIRE_DURATION = Duration.ofHours(24);
|
||||
|
||||
@Override
|
||||
public String submitAssessmentTask(Long projectId, Long userId) {
|
||||
// 生成任务ID
|
||||
String taskId = IdUtil.fastSimpleUUID();
|
||||
|
||||
log.info("提交风险评估任务, taskId: {}, projectId: {}, userId: {}", taskId, projectId, userId);
|
||||
|
||||
// 创建任务记录
|
||||
RiskAssessmentTaskVO taskVO = new RiskAssessmentTaskVO();
|
||||
taskVO.setTaskId(taskId);
|
||||
taskVO.setUserId(userId);
|
||||
taskVO.setProjectId(projectId);
|
||||
taskVO.setStatus(AsyncTaskStatus.PENDING.getCode());
|
||||
taskVO.setStatusDesc(AsyncTaskStatus.PENDING.getDescription());
|
||||
taskVO.setProgress(0);
|
||||
taskVO.setProgressMessage("任务已提交,等待处理...");
|
||||
taskVO.setCreateTime(LocalDateTime.now());
|
||||
|
||||
// 存储任务到 Redis
|
||||
String taskKey = getTaskKey(taskId);
|
||||
redisService.set(taskKey, taskVO, TASK_EXPIRE_DURATION);
|
||||
|
||||
// 将任务ID添加到用户的任务列表
|
||||
if (userId != null) {
|
||||
String userTasksKey = getUserTasksKey(userId);
|
||||
redisService.sAdd(userTasksKey, taskId);
|
||||
redisService.expire(userTasksKey, TASK_EXPIRE_DURATION);
|
||||
}
|
||||
|
||||
// 异步执行任务
|
||||
executeAssessmentTaskAsync(taskId, projectId);
|
||||
|
||||
return taskId;
|
||||
}
|
||||
|
||||
/**
|
||||
* 异步执行风险评估任务
|
||||
*/
|
||||
@Async("projectInitTaskExecutor")
|
||||
public CompletableFuture<Void> executeAssessmentTaskAsync(String taskId, Long projectId) {
|
||||
RiskAssessmentTaskVO taskVO = getTaskFromRedis(taskId);
|
||||
if (taskVO == null) {
|
||||
log.error("任务不存在, taskId: {}", taskId);
|
||||
return CompletableFuture.completedFuture(null);
|
||||
}
|
||||
|
||||
try {
|
||||
// 更新状态为处理中
|
||||
updateTaskProgress(taskId, AsyncTaskStatus.PROCESSING, 10, "正在收集项目数据...");
|
||||
|
||||
// 调用风险评估服务
|
||||
updateTaskProgress(taskId, AsyncTaskStatus.PROCESSING, 30, "项目数据收集完成,AI正在分析...");
|
||||
RiskAssessmentResult result = riskService.assessProjectRisk(projectId);
|
||||
|
||||
// 更新任务完成状态
|
||||
taskVO = getTaskFromRedis(taskId);
|
||||
if (taskVO != null) {
|
||||
taskVO.setResult(result);
|
||||
taskVO.setProjectName(result.getProjectName());
|
||||
taskVO.setCompleteTime(LocalDateTime.now());
|
||||
saveTaskToRedis(taskVO);
|
||||
}
|
||||
|
||||
updateTaskProgress(taskId, AsyncTaskStatus.COMPLETED, 100, "风险评估完成");
|
||||
|
||||
log.info("风险评估任务完成, taskId: {}, projectId: {}", taskId, projectId);
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("风险评估任务失败, taskId: {}, error: {}", taskId, e.getMessage(), e);
|
||||
taskVO = getTaskFromRedis(taskId);
|
||||
if (taskVO != null) {
|
||||
taskVO.setErrorMessage(e.getMessage());
|
||||
taskVO.setCompleteTime(LocalDateTime.now());
|
||||
saveTaskToRedis(taskVO);
|
||||
}
|
||||
updateTaskProgress(taskId, AsyncTaskStatus.FAILED, 0, "任务执行失败: " + e.getMessage());
|
||||
}
|
||||
|
||||
return CompletableFuture.completedFuture(null);
|
||||
}
|
||||
|
||||
/**
|
||||
* 更新任务进度
|
||||
*/
|
||||
private void updateTaskProgress(String taskId, AsyncTaskStatus status, int progress, String message) {
|
||||
RiskAssessmentTaskVO taskVO = getTaskFromRedis(taskId);
|
||||
if (taskVO == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
taskVO.setStatus(status.getCode());
|
||||
taskVO.setStatusDesc(status.getDescription());
|
||||
taskVO.setProgress(progress);
|
||||
taskVO.setProgressMessage(message);
|
||||
|
||||
if (status == AsyncTaskStatus.PROCESSING && taskVO.getStartTime() == null) {
|
||||
taskVO.setStartTime(LocalDateTime.now());
|
||||
}
|
||||
|
||||
// 更新 Redis
|
||||
saveTaskToRedis(taskVO);
|
||||
|
||||
// 通过 SSE 推送进度给用户
|
||||
pushProgressToUser(taskVO, status);
|
||||
|
||||
log.debug("任务进度更新, taskId: {}, status: {}, progress: {}%, message: {}",
|
||||
taskId, status.getCode(), progress, message);
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过 SSE 推送进度给用户
|
||||
*/
|
||||
private void pushProgressToUser(RiskAssessmentTaskVO taskVO, AsyncTaskStatus status) {
|
||||
if (taskVO.getUserId() == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
String userId = String.valueOf(taskVO.getUserId());
|
||||
|
||||
// 判断任务是否结束
|
||||
boolean isFinished = status == AsyncTaskStatus.COMPLETED ||
|
||||
status == AsyncTaskStatus.FAILED ||
|
||||
status == AsyncTaskStatus.CANCELLED;
|
||||
|
||||
// 推送进度消息
|
||||
SseMessage message = SseMessage.of(MESSAGE_TYPE, "progress", userId, taskVO);
|
||||
sseChannelManager.send(userId, message);
|
||||
|
||||
// 任务结束时推送完成消息
|
||||
if (isFinished) {
|
||||
SseMessage completeMessage = SseMessage.of(MESSAGE_TYPE, "complete", userId, taskVO);
|
||||
sseChannelManager.send(userId, completeMessage);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public RiskAssessmentTaskVO getTaskStatus(String taskId) {
|
||||
RiskAssessmentTaskVO taskVO = getTaskFromRedis(taskId);
|
||||
if (taskVO == null) {
|
||||
return null;
|
||||
}
|
||||
return copyTaskVO(taskVO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RiskAssessmentResult getTaskResult(String taskId) {
|
||||
RiskAssessmentTaskVO taskVO = getTaskFromRedis(taskId);
|
||||
if (taskVO == null || !AsyncTaskStatus.COMPLETED.getCode().equals(taskVO.getStatus())) {
|
||||
return null;
|
||||
}
|
||||
return taskVO.getResult();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean cancelTask(String taskId) {
|
||||
RiskAssessmentTaskVO taskVO = getTaskFromRedis(taskId);
|
||||
if (taskVO == null) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 只能取消待处理或处理中的任务
|
||||
if (AsyncTaskStatus.PENDING.getCode().equals(taskVO.getStatus()) ||
|
||||
AsyncTaskStatus.PROCESSING.getCode().equals(taskVO.getStatus())) {
|
||||
updateTaskProgress(taskId, AsyncTaskStatus.CANCELLED, 0, "任务已取消");
|
||||
taskVO.setCompleteTime(LocalDateTime.now());
|
||||
saveTaskToRedis(taskVO);
|
||||
log.info("风险评估任务已取消, taskId: {}", taskId);
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<RiskAssessmentTaskVO> getTasksByUserId(Long userId) {
|
||||
if (userId == null) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
String userTasksKey = getUserTasksKey(userId);
|
||||
Set<String> taskIds = redisService.sMembers(userTasksKey);
|
||||
|
||||
if (taskIds == null || taskIds.isEmpty()) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
return taskIds.stream()
|
||||
.map(this::getTaskFromRedis)
|
||||
.filter(Objects::nonNull)
|
||||
.sorted(Comparator.comparing(RiskAssessmentTaskVO::getCreateTime).reversed())
|
||||
.map(this::copyTaskVO)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getProcessingTaskCount(Long userId) {
|
||||
if (userId == null) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
List<RiskAssessmentTaskVO> tasks = getTasksByUserId(userId);
|
||||
return (int) tasks.stream()
|
||||
.filter(task -> AsyncTaskStatus.PENDING.getCode().equals(task.getStatus())
|
||||
|| AsyncTaskStatus.PROCESSING.getCode().equals(task.getStatus()))
|
||||
.count();
|
||||
}
|
||||
|
||||
// ==================== Redis 操作方法 ====================
|
||||
|
||||
/**
|
||||
* 从 Redis 获取任务
|
||||
*/
|
||||
private RiskAssessmentTaskVO getTaskFromRedis(String taskId) {
|
||||
String key = getTaskKey(taskId);
|
||||
return redisService.get(key);
|
||||
}
|
||||
|
||||
/**
|
||||
* 保存任务到 Redis
|
||||
*/
|
||||
private void saveTaskToRedis(RiskAssessmentTaskVO taskVO) {
|
||||
String key = getTaskKey(taskVO.getTaskId());
|
||||
redisService.set(key, taskVO, TASK_EXPIRE_DURATION);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取任务存储 key
|
||||
*/
|
||||
private String getTaskKey(String taskId) {
|
||||
return TASK_KEY_PREFIX + taskId;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取用户任务列表 key
|
||||
*/
|
||||
private String getUserTasksKey(Long userId) {
|
||||
return USER_TASKS_KEY_PREFIX + userId + ":tasks";
|
||||
}
|
||||
|
||||
// ==================== 工具方法 ====================
|
||||
|
||||
/**
|
||||
* 复制任务VO
|
||||
*/
|
||||
private RiskAssessmentTaskVO copyTaskVO(RiskAssessmentTaskVO source) {
|
||||
RiskAssessmentTaskVO copy = new RiskAssessmentTaskVO();
|
||||
copy.setTaskId(source.getTaskId());
|
||||
copy.setUserId(source.getUserId());
|
||||
copy.setProjectId(source.getProjectId());
|
||||
copy.setProjectName(source.getProjectName());
|
||||
copy.setStatus(source.getStatus());
|
||||
copy.setStatusDesc(source.getStatusDesc());
|
||||
copy.setProgress(source.getProgress());
|
||||
copy.setProgressMessage(source.getProgressMessage());
|
||||
copy.setCreateTime(source.getCreateTime());
|
||||
copy.setStartTime(source.getStartTime());
|
||||
copy.setCompleteTime(source.getCompleteTime());
|
||||
copy.setResult(source.getResult());
|
||||
copy.setErrorMessage(source.getErrorMessage());
|
||||
return copy;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user