feat(project): 实现异步项目初始化及SSE进度推送功能

- 新增异步任务线程池配置,支持项目初始化异步执行
- 定义异步任务状态枚举,统一管理任务生命周期状态
- 实现通用SSE通道管理器,支持用户绑定及多业务消息推送
- 创建统一SSE消息结构,支持多业务类型及事件分类
- 提供基础SSE连接管理接口,支持连接建立、状态查询及关闭
- 提供项目初始化异步任务服务接口及实现,支持进度回调和任务取消
- 添加项目初始化异步预览任务接口,支持异步提交、状态查询、结果获取及取消
- 新增项目初始化任务SSE接口,实现任务异步提交与实时进度推送
- 设计前端SSE集成文档,详细说明SSE连接、消息格式和对接步骤
- 添加Spring工具类,方便非Spring管理类获取Bean实例
- 优化项目控制器,整合异步任务相关API接口支持异步项目初始化工作流
This commit is contained in:
2026-03-28 16:57:55 +08:00
parent a7bb054e6e
commit 6d91be8af5
13 changed files with 1505 additions and 6 deletions

View File

@@ -0,0 +1,46 @@
package cn.yinlihupo.common.config;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import java.util.concurrent.Executor;
import java.util.concurrent.ThreadPoolExecutor;
/**
* 异步任务配置类
*/
@Slf4j
@Configuration
@EnableAsync
public class AsyncConfig {
/**
* 项目初始化任务线程池
*/
@Bean("projectInitTaskExecutor")
public Executor projectInitTaskExecutor() {
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
// 核心线程数
executor.setCorePoolSize(2);
// 最大线程数
executor.setMaxPoolSize(5);
// 队列容量
executor.setQueueCapacity(50);
// 线程名称前缀
executor.setThreadNamePrefix("project-init-");
// 拒绝策略:由调用线程处理
executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
// 等待所有任务完成后再关闭线程池
executor.setWaitForTasksToCompleteOnShutdown(true);
// 等待时间(秒)
executor.setAwaitTerminationSeconds(60);
// 初始化
executor.initialize();
log.info("项目初始化异步任务线程池初始化完成");
return executor;
}
}

View File

@@ -0,0 +1,55 @@
package cn.yinlihupo.common.enums;
import lombok.Getter;
/**
* 异步任务状态枚举
*/
@Getter
public enum AsyncTaskStatus {
/**
* 待处理
*/
PENDING("pending", "待处理"),
/**
* 处理中
*/
PROCESSING("processing", "处理中"),
/**
* 已完成
*/
COMPLETED("completed", "已完成"),
/**
* 失败
*/
FAILED("failed", "失败"),
/**
* 已取消
*/
CANCELLED("cancelled", "已取消");
private final String code;
private final String description;
AsyncTaskStatus(String code, String description) {
this.code = code;
this.description = description;
}
/**
* 根据code获取枚举
*/
public static AsyncTaskStatus fromCode(String code) {
for (AsyncTaskStatus status : values()) {
if (status.code.equals(code)) {
return status;
}
}
return null;
}
}

View File

@@ -0,0 +1,179 @@
package cn.yinlihupo.common.sse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
/**
* SSE 通道管理器
* 统一管理用户与 SSE 通道的关联,支持多业务类型消息推送
*/
@Slf4j
@Component
public class SseChannelManager {
/**
* 用户通道映射userId -> SseEmitter
*/
private final Map<String, SseEmitter> userChannelMap = new ConcurrentHashMap<>();
/**
* 默认超时时间30分钟适合长连接场景
*/
private static final long DEFAULT_TIMEOUT = 30 * 60 * 1000;
/**
* 创建用户通道
*
* @param userId 用户ID
* @return SseEmitter
*/
public SseEmitter createChannel(String userId) {
return createChannel(userId, DEFAULT_TIMEOUT, null);
}
/**
* 创建用户通道
*
* @param userId 用户ID
* @param timeout 超时时间(毫秒)
* @return SseEmitter
*/
public SseEmitter createChannel(String userId, long timeout) {
return createChannel(userId, timeout, null);
}
/**
* 创建用户通道
*
* @param userId 用户ID
* @param timeout 超时时间(毫秒)
* @param onCloseCallback 通道关闭回调
* @return SseEmitter
*/
public SseEmitter createChannel(String userId, long timeout, Consumer<String> onCloseCallback) {
// 关闭已存在的通道
closeChannel(userId);
SseEmitter emitter = new SseEmitter(timeout);
// 存储通道
userChannelMap.put(userId, emitter);
// 设置回调
emitter.onCompletion(() -> {
log.debug("SSE通道完成关闭, userId: {}", userId);
cleanupChannel(userId);
if (onCloseCallback != null) {
onCloseCallback.accept(userId);
}
});
emitter.onError((e) -> {
log.error("SSE通道发生错误, userId: {}", userId, e);
cleanupChannel(userId);
if (onCloseCallback != null) {
onCloseCallback.accept(userId);
}
});
emitter.onTimeout(() -> {
log.warn("SSE通道超时, userId: {}", userId);
cleanupChannel(userId);
if (onCloseCallback != null) {
onCloseCallback.accept(userId);
}
});
log.info("SSE通道创建成功, userId: {}", userId);
return emitter;
}
/**
* 发送消息到指定用户
*
* @param userId 用户ID
* @param message 消息对象
* @return 是否发送成功
*/
public boolean send(String userId, SseMessage message) {
SseEmitter emitter = userChannelMap.get(userId);
if (emitter == null) {
log.warn("用户未建立SSE连接无法推送, userId: {}", userId);
return false;
}
try {
emitter.send(SseEmitter.event()
.name(message.getEvent())
.data(message));
return true;
} catch (IOException e) {
log.error("发送消息失败, userId: {}, event: {}", userId, message.getEvent(), e);
closeChannel(userId);
return false;
}
}
/**
* 发送消息到指定用户(简化版)
*
* @param userId 用户ID
* @param type 消息类型
* @param event 事件名称
* @param data 业务数据
* @return 是否发送成功
*/
public boolean send(String userId, String type, String event, Object data) {
SseMessage message = SseMessage.of(type, event, userId, data);
return send(userId, message);
}
/**
* 检查用户是否在线
*
* @param userId 用户ID
* @return 是否在线
*/
public boolean isOnline(String userId) {
return userChannelMap.containsKey(userId);
}
/**
* 获取在线用户数量
*
* @return 在线用户数
*/
public int getOnlineCount() {
return userChannelMap.size();
}
/**
* 关闭用户通道
*
* @param userId 用户ID
*/
public void closeChannel(String userId) {
SseEmitter emitter = userChannelMap.remove(userId);
if (emitter != null) {
try {
emitter.complete();
} catch (Exception e) {
log.warn("关闭通道时发生异常, userId: {}", userId, e);
}
}
log.info("SSE通道已关闭, userId: {}", userId);
}
/**
* 清理通道
*/
private void cleanupChannel(String userId) {
userChannelMap.remove(userId);
}
}

View File

@@ -0,0 +1,76 @@
package cn.yinlihupo.common.sse;
import lombok.Builder;
import lombok.Data;
import java.time.LocalDateTime;
/**
* SSE 消息包装类
* 统一的消息格式,支持多业务类型
*/
@Data
@Builder
public class SseMessage {
/**
* 消息类型
* 例如project-init、system-notification、task-notification 等
*/
private String type;
/**
* 事件名称
* 例如submitted、progress、complete、error 等
*/
private String event;
/**
* 用户ID
*/
private String userId;
/**
* 业务数据
*/
private Object data;
/**
* 消息时间戳
*/
private LocalDateTime timestamp;
/**
* 创建消息
*/
public static SseMessage of(String type, String event, String userId, Object data) {
return SseMessage.builder()
.type(type)
.event(event)
.userId(userId)
.data(data)
.timestamp(LocalDateTime.now())
.build();
}
/**
* 创建项目初始化消息
*/
public static SseMessage projectInit(String event, String userId, Object data) {
return of("project-init", event, userId, data);
}
/**
* 创建系统通知消息
*/
public static SseMessage systemNotify(String event, String userId, Object data) {
return of("system-notification", event, userId, data);
}
/**
* 创建任务通知消息
*/
public static SseMessage taskNotify(String event, String userId, Object data) {
return of("task-notification", event, userId, data);
}
}

View File

@@ -0,0 +1,100 @@
package cn.yinlihupo.common.util;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;
/**
* Spring工具类
* 用于在非Spring管理的类中获取Spring容器中的Bean
*
* @author cheems
*/
@Component
public class SpringUtils implements ApplicationContextAware {
private static ApplicationContext applicationContext;
@Override
public void setApplicationContext(ApplicationContext context) throws BeansException {
SpringUtils.applicationContext = context;
}
/**
* 获取ApplicationContext
*
* @return ApplicationContext
*/
public static ApplicationContext getApplicationContext() {
return applicationContext;
}
/**
* 通过name获取Bean
*
* @param name Bean名称
* @return Bean实例
*/
public static Object getBean(String name) {
return applicationContext.getBean(name);
}
/**
* 通过class获取Bean
*
* @param clazz Bean类型
* @param <T> 泛型
* @return Bean实例
*/
public static <T> T getBean(Class<T> clazz) {
return applicationContext.getBean(clazz);
}
/**
* 通过name和class获取Bean
*
* @param name Bean名称
* @param clazz Bean类型
* @param <T> 泛型
* @return Bean实例
*/
public static <T> T getBean(String name, Class<T> clazz) {
return applicationContext.getBean(name, clazz);
}
/**
* 判断是否包含Bean
*
* @param name Bean名称
* @return 是否包含
*/
public static boolean containsBean(String name) {
return applicationContext.containsBean(name);
}
/**
* 判断Bean是否为单例
*
* @param name Bean名称
* @return 是否为单例
*/
public static boolean isSingleton(String name) {
return applicationContext.isSingleton(name);
}
/**
* 获取Bean的类型
*
* @param name Bean名称
* @return Bean类型
*/
public static Class<?> getType(String name) {
return applicationContext.getType(name);
}
}

View File

@@ -3,13 +3,18 @@ package cn.yinlihupo.controller.project;
import cn.yinlihupo.common.core.BaseResponse;
import cn.yinlihupo.common.util.ResultUtils;
import cn.yinlihupo.domain.vo.ProjectInitResult;
import cn.yinlihupo.domain.vo.ProjectInitTaskVO;
import cn.yinlihupo.service.oss.OssService;
import cn.yinlihupo.service.project.ProjectInitAsyncService;
import cn.yinlihupo.service.project.ProjectService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import java.util.HashMap;
import java.util.Map;
/**
* AI项目初始化控制器
* 提供项目文档解析和结构化数据生成功能
@@ -21,6 +26,7 @@ import org.springframework.web.multipart.MultipartFile;
public class ProjectController {
private final ProjectService projectService;
private final ProjectInitAsyncService projectInitAsyncService;
private final OssService ossService;
/**
@@ -71,6 +77,98 @@ public class ProjectController {
}
}
/**
* 异步提交项目初始化预览任务
*
* @param file 项目资料文件
* @return 任务ID和初始状态
*/
@PostMapping("/preview-async")
public BaseResponse<Map<String, Object>> submitPreviewTask(@RequestParam("file") MultipartFile file) {
log.info("收到异步项目初始化预览请求, 文件名: {}", file.getOriginalFilename());
if (file.isEmpty()) {
return ResultUtils.error("上传文件不能为空");
}
try {
// 提交异步任务
String taskId = projectInitAsyncService.submitPreviewTask(file);
Map<String, Object> result = new HashMap<>();
result.put("taskId", taskId);
result.put("status", "pending");
result.put("message", "任务已提交请使用任务ID查询进度");
return ResultUtils.success("任务提交成功", result);
} catch (Exception e) {
log.error("任务提交失败: {}", e.getMessage(), e);
return ResultUtils.error("任务提交失败: " + e.getMessage());
}
}
/**
* 查询任务状态和进度
*
* @param taskId 任务ID
* @return 任务状态信息
*/
@GetMapping("/preview-async/status/{taskId}")
public BaseResponse<ProjectInitTaskVO> getTaskStatus(@PathVariable("taskId") String taskId) {
log.info("查询任务状态, taskId: {}", taskId);
ProjectInitTaskVO taskVO = projectInitAsyncService.getTaskStatus(taskId);
if (taskVO == null) {
return ResultUtils.error("任务不存在");
}
return ResultUtils.success("查询成功", taskVO);
}
/**
* 获取任务结果
*
* @param taskId 任务ID
* @return 项目初始化结果
*/
@GetMapping("/preview-async/result/{taskId}")
public BaseResponse<ProjectInitResult> getTaskResult(@PathVariable("taskId") String taskId) {
log.info("获取任务结果, taskId: {}", taskId);
ProjectInitTaskVO taskVO = projectInitAsyncService.getTaskStatus(taskId);
if (taskVO == null) {
return ResultUtils.error("任务不存在");
}
if (!"completed".equals(taskVO.getStatus())) {
return ResultUtils.error("任务尚未完成,当前状态: " + taskVO.getStatus());
}
ProjectInitResult result = projectInitAsyncService.getTaskResult(taskId);
return ResultUtils.success("获取成功", result);
}
/**
* 取消任务
*
* @param taskId 任务ID
* @return 取消结果
*/
@PostMapping("/preview-async/cancel/{taskId}")
public BaseResponse<String> cancelTask(@PathVariable("taskId") String taskId) {
log.info("取消任务, taskId: {}", taskId);
boolean success = projectInitAsyncService.cancelTask(taskId);
if (success) {
return ResultUtils.success("任务已取消", null);
} else {
return ResultUtils.error("任务不存在或已完成,无法取消");
}
}
/**
* 获取文件扩展名
*

View File

@@ -0,0 +1,118 @@
package cn.yinlihupo.controller.project;
import cn.yinlihupo.common.core.BaseResponse;
import cn.yinlihupo.common.enums.AsyncTaskStatus;
import cn.yinlihupo.common.sse.SseChannelManager;
import cn.yinlihupo.common.sse.SseMessage;
import cn.yinlihupo.common.util.ResultUtils;
import cn.yinlihupo.service.project.ProjectInitAsyncService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;
import java.util.HashMap;
import java.util.Map;
/**
* 项目初始化 SSE 控制器
* 使用通用 SSE 通道管理器,通过 userId 绑定type 字段区分业务
*/
@Slf4j
@RestController
@RequestMapping("/api/v1/project-init")
@RequiredArgsConstructor
public class ProjectInitSseController {
private final ProjectInitAsyncService projectInitAsyncService;
private final SseChannelManager sseChannelManager;
/**
* 消息类型常量
*/
private static final String MESSAGE_TYPE = "project-init";
/**
* 通过 SSE 提交项目初始化任务
* 使用通用 SSE 通道,通过 userId 推送进度
*
* @param userId 用户ID
* @param file 项目资料文件
* @return 提交结果
*/
@PostMapping("/sse/submit-task")
public BaseResponse<Map<String, Object>> submitTaskWithSse(@RequestParam("userId") String userId,
@RequestParam("file") MultipartFile file) {
log.info("用户通过SSE提交任务, userId: {}, 文件名: {}", userId, file.getOriginalFilename());
if (file.isEmpty()) {
return ResultUtils.error("上传文件不能为空");
}
// 检查用户是否在线
if (!sseChannelManager.isOnline(userId)) {
return ResultUtils.error("用户未建立SSE连接请先调用 /api/v1/sse/connect/" + userId);
}
try {
// 提交异步任务,带进度回调
String taskId = projectInitAsyncService.submitPreviewTask(file, taskVO -> {
// 构建消息并推送
SseMessage message = SseMessage.of(MESSAGE_TYPE, "progress", userId, taskVO);
sseChannelManager.send(userId, message);
// 任务完成或失败,推送完成事件
if (isTaskFinished(taskVO.getStatus())) {
SseMessage completeMessage = SseMessage.of(MESSAGE_TYPE, "complete", userId, taskVO);
sseChannelManager.send(userId, completeMessage);
}
});
// 推送任务提交成功事件
Map<String, Object> submittedData = new HashMap<>();
submittedData.put("taskId", taskId);
submittedData.put("message", "任务已提交");
SseMessage submittedMessage = SseMessage.of(MESSAGE_TYPE, "submitted", userId, submittedData);
sseChannelManager.send(userId, submittedMessage);
log.info("任务提交成功并通过SSE推送, userId: {}, taskId: {}", userId, taskId);
Map<String, Object> result = new HashMap<>();
result.put("taskId", taskId);
result.put("message", "任务已提交进度将通过SSE推送");
return ResultUtils.success("提交成功", result);
} catch (Exception e) {
log.error("提交任务失败, userId: {}, error: {}", userId, e.getMessage(), e);
// 推送错误事件
Map<String, Object> errorData = new HashMap<>();
errorData.put("error", e.getMessage());
SseMessage errorMessage = SseMessage.of(MESSAGE_TYPE, "error", userId, errorData);
sseChannelManager.send(userId, errorMessage);
return ResultUtils.error("提交失败: " + e.getMessage());
}
}
// ==================== 工具方法 ====================
private boolean isTaskFinished(String status) {
return AsyncTaskStatus.COMPLETED.getCode().equals(status) ||
AsyncTaskStatus.FAILED.getCode().equals(status) ||
AsyncTaskStatus.CANCELLED.getCode().equals(status);
}
private void sendErrorAndClose(String userId, String errorMessage) {
Map<String, Object> errorData = new HashMap<>();
errorData.put("error", errorMessage);
SseMessage errorMessage_obj = SseMessage.of(MESSAGE_TYPE, "error", userId, errorData);
sseChannelManager.send(userId, errorMessage_obj);
sseChannelManager.closeChannel(userId);
}
}

View File

@@ -0,0 +1,106 @@
package cn.yinlihupo.controller.sse;
import cn.yinlihupo.common.core.BaseResponse;
import cn.yinlihupo.common.sse.SseChannelManager;
import cn.yinlihupo.common.sse.SseMessage;
import cn.yinlihupo.common.util.ResultUtils;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
/**
* 通用 SSE 控制器
* 提供基础的 SSE 连接管理,通过 userId 绑定通道
*/
@Slf4j
@RestController
@RequestMapping("/api/v1/sse")
@RequiredArgsConstructor
public class SseController {
private final SseChannelManager sseChannelManager;
/**
* 建立 SSE 连接
* 通过 userId 绑定通道,支持多业务类型消息推送
*
* @param userId 用户ID
* @return SSE连接
*/
@GetMapping(value = "/connect/{userId}", produces = "text/event-stream;charset=UTF-8")
public SseEmitter connect(@PathVariable("userId") String userId) {
log.info("建立SSE连接, userId: {}", userId);
// 创建通道默认30分钟超时
SseEmitter emitter = sseChannelManager.createChannel(userId);
try {
// 发送连接成功事件
Map<String, Object> data = new HashMap<>();
data.put("userId", userId);
data.put("message", "连接成功");
data.put("timestamp", System.currentTimeMillis());
SseMessage message = SseMessage.of("system", "connected", userId, data);
emitter.send(SseEmitter.event()
.name("connected")
.data(message));
} catch (IOException e) {
log.error("发送连接成功消息失败, userId: {}", userId, e);
sseChannelManager.closeChannel(userId);
emitter.completeWithError(e);
}
return emitter;
}
/**
* 检查用户在线状态
*
* @param userId 用户ID
* @return 在线状态
*/
@GetMapping("/status/{userId}")
public BaseResponse<Map<String, Object>> getStatus(@PathVariable("userId") String userId) {
boolean online = sseChannelManager.isOnline(userId);
Map<String, Object> data = new HashMap<>();
data.put("userId", userId);
data.put("online", online);
return ResultUtils.success("查询成功", data);
}
/**
* 关闭 SSE 连接
*
* @param userId 用户ID
* @return 操作结果
*/
@PostMapping("/close/{userId}")
public BaseResponse<String> close(@PathVariable("userId") String userId) {
log.info("主动关闭SSE连接, userId: {}", userId);
sseChannelManager.closeChannel(userId);
return ResultUtils.success("连接已关闭", null);
}
/**
* 获取在线统计信息
*
* @return 统计信息
*/
@GetMapping("/stats")
public BaseResponse<Map<String, Object>> getStats() {
Map<String, Object> data = new HashMap<>();
data.put("onlineCount", sseChannelManager.getOnlineCount());
data.put("timestamp", System.currentTimeMillis());
return ResultUtils.success("查询成功", data);
}
}

View File

@@ -0,0 +1,67 @@
package cn.yinlihupo.domain.vo;
import lombok.Data;
import java.time.LocalDateTime;
/**
* 项目初始化异步任务VO
*/
@Data
public class ProjectInitTaskVO {
/**
* 任务ID
*/
private String taskId;
/**
* 任务状态: pending-待处理, processing-处理中, completed-已完成, failed-失败
*/
private String status;
/**
* 状态描述
*/
private String statusDesc;
/**
* 当前进度百分比 (0-100)
*/
private Integer progress;
/**
* 进度描述信息
*/
private String progressMessage;
/**
* 原始文件名
*/
private String originalFilename;
/**
* 任务创建时间
*/
private LocalDateTime createTime;
/**
* 任务开始处理时间
*/
private LocalDateTime startTime;
/**
* 任务完成时间
*/
private LocalDateTime completeTime;
/**
* 处理结果仅当status=completed时有值
*/
private ProjectInitResult result;
/**
* 错误信息仅当status=failed时有值
*/
private String errorMessage;
}

View File

@@ -0,0 +1,61 @@
package cn.yinlihupo.service.project;
import cn.yinlihupo.domain.vo.ProjectInitResult;
import cn.yinlihupo.domain.vo.ProjectInitTaskVO;
import org.springframework.web.multipart.MultipartFile;
import java.util.function.Consumer;
/**
* 项目初始化异步任务服务接口
*/
public interface ProjectInitAsyncService {
/**
* 提交异步项目初始化预览任务
*
* @param file 项目资料文件
* @return 任务ID
*/
String submitPreviewTask(MultipartFile file);
/**
* 提交异步项目初始化预览任务(带进度回调)
*
* @param file 项目资料文件
* @param progressCallback 进度回调函数
* @return 任务ID
*/
String submitPreviewTask(MultipartFile file, Consumer<ProjectInitTaskVO> progressCallback);
/**
* 获取任务状态
*
* @param taskId 任务ID
* @return 任务状态VO
*/
ProjectInitTaskVO getTaskStatus(String taskId);
/**
* 获取任务结果
*
* @param taskId 任务ID
* @return 项目初始化结果
*/
ProjectInitResult getTaskResult(String taskId);
/**
* 取消任务
*
* @param taskId 任务ID
* @return 是否取消成功
*/
boolean cancelTask(String taskId);
/**
* 清理过期任务
*
* @param expireHours 过期时间(小时)
*/
void cleanExpiredTasks(int expireHours);
}

View File

@@ -0,0 +1,242 @@
package cn.yinlihupo.service.project.impl;
import cn.hutool.core.util.IdUtil;
import cn.yinlihupo.common.enums.AsyncTaskStatus;
import cn.yinlihupo.domain.vo.ProjectInitResult;
import cn.yinlihupo.domain.vo.ProjectInitTaskVO;
import cn.yinlihupo.service.oss.OssService;
import cn.yinlihupo.service.project.ProjectInitAsyncService;
import cn.yinlihupo.service.project.ProjectService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import java.time.LocalDateTime;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
/**
* 项目初始化异步任务服务实现类
* 使用内存存储任务状态
*/
@Slf4j
@Service
@RequiredArgsConstructor
public class ProjectInitAsyncServiceImpl implements ProjectInitAsyncService {
private final ProjectService projectService;
private final OssService ossService;
/**
* 任务存储(内存存储)
*/
private final Map<String, ProjectInitTaskVO> taskStore = new ConcurrentHashMap<>();
/**
* 进度回调存储(内存存储,仅当前实例有效)
*/
private final Map<String, Consumer<ProjectInitTaskVO>> progressCallbacks = new ConcurrentHashMap<>();
@Override
public String submitPreviewTask(MultipartFile file) {
return submitPreviewTask(file, null);
}
@Override
public String submitPreviewTask(MultipartFile file, Consumer<ProjectInitTaskVO> progressCallback) {
// 生成任务ID
String taskId = IdUtil.fastSimpleUUID();
String originalFilename = file.getOriginalFilename();
log.info("提交项目初始化预览任务, taskId: {}, 文件名: {}", taskId, originalFilename);
// 创建任务记录
ProjectInitTaskVO taskVO = new ProjectInitTaskVO();
taskVO.setTaskId(taskId);
taskVO.setStatus(AsyncTaskStatus.PENDING.getCode());
taskVO.setStatusDesc(AsyncTaskStatus.PENDING.getDescription());
taskVO.setProgress(0);
taskVO.setProgressMessage("任务已提交,等待处理...");
taskVO.setOriginalFilename(originalFilename);
taskVO.setCreateTime(LocalDateTime.now());
// 存储到内存
taskStore.put(taskId, taskVO);
// 保存进度回调
if (progressCallback != null) {
progressCallbacks.put(taskId, progressCallback);
}
// 异步执行任务
executePreviewTaskAsync(taskId, file);
return taskId;
}
/**
* 异步执行预览任务
*/
@Async("projectInitTaskExecutor")
public CompletableFuture<Void> executePreviewTaskAsync(String taskId, MultipartFile file) {
ProjectInitTaskVO taskVO = taskStore.get(taskId);
if (taskVO == null) {
log.error("任务不存在, taskId: {}", taskId);
return CompletableFuture.completedFuture(null);
}
try {
// 更新状态为处理中
updateTaskProgress(taskId, AsyncTaskStatus.PROCESSING, 10, "正在上传文件...");
// 1. 上传文件到OSS
String fileUrl = ossService.uploadFile(file, file.getOriginalFilename());
log.info("文件上传成功, taskId: {}, URL: {}", taskId, fileUrl);
updateTaskProgress(taskId, AsyncTaskStatus.PROCESSING, 30, "文件上传完成,正在读取内容...");
// 2. 读取文件内容
String content = ossService.readFileAsString(fileUrl);
if (content == null || content.isEmpty()) {
throw new RuntimeException("无法读取文件内容: " + fileUrl);
}
updateTaskProgress(taskId, AsyncTaskStatus.PROCESSING, 50, "文件读取完成AI正在分析...");
// 3. 调用AI生成项目预览数据
updateTaskProgress(taskId, AsyncTaskStatus.PROCESSING, 60, "AI正在解析项目结构...");
ProjectInitResult result = projectService.generateProjectFromContent(content);
// 4. 更新任务完成状态
updateTaskProgress(taskId, AsyncTaskStatus.COMPLETED, 100, "项目预览数据生成成功");
taskVO.setResult(result);
taskVO.setCompleteTime(LocalDateTime.now());
log.info("项目初始化预览任务完成, taskId: {}", taskId);
} catch (Exception e) {
log.error("项目初始化预览任务失败, taskId: {}, error: {}", taskId, e.getMessage(), e);
updateTaskProgress(taskId, AsyncTaskStatus.FAILED, 0, "任务执行失败");
taskVO.setErrorMessage(e.getMessage());
taskVO.setCompleteTime(LocalDateTime.now());
} finally {
// 清理回调(仅清理内存中的回调)
progressCallbacks.remove(taskId);
// 注意Redis中的任务数据保留供后续查询24小时后自动过期
}
return CompletableFuture.completedFuture(null);
}
/**
* 更新任务进度
*/
private void updateTaskProgress(String taskId, AsyncTaskStatus status, int progress, String message) {
ProjectInitTaskVO taskVO = taskStore.get(taskId);
if (taskVO == null) {
return;
}
taskVO.setStatus(status.getCode());
taskVO.setStatusDesc(status.getDescription());
taskVO.setProgress(progress);
taskVO.setProgressMessage(message);
if (status == AsyncTaskStatus.PROCESSING && taskVO.getStartTime() == null) {
taskVO.setStartTime(LocalDateTime.now());
}
// 更新内存存储
taskStore.put(taskId, taskVO);
// 触发进度回调
Consumer<ProjectInitTaskVO> callback = progressCallbacks.get(taskId);
if (callback != null) {
try {
callback.accept(taskVO);
} catch (Exception e) {
log.warn("进度回调执行失败, taskId: {}", taskId, e);
}
}
log.debug("任务进度更新, taskId: {}, status: {}, progress: {}%, message: {}",
taskId, status.getCode(), progress, message);
}
@Override
public ProjectInitTaskVO getTaskStatus(String taskId) {
ProjectInitTaskVO taskVO = taskStore.get(taskId);
if (taskVO == null) {
return null;
}
// 返回副本,避免外部修改
return copyTaskVO(taskVO);
}
@Override
public ProjectInitResult getTaskResult(String taskId) {
ProjectInitTaskVO taskVO = taskStore.get(taskId);
if (taskVO == null || !AsyncTaskStatus.COMPLETED.getCode().equals(taskVO.getStatus())) {
return null;
}
return taskVO.getResult();
}
@Override
public boolean cancelTask(String taskId) {
ProjectInitTaskVO taskVO = taskStore.get(taskId);
if (taskVO == null) {
return false;
}
// 只能取消待处理或处理中的任务
if (AsyncTaskStatus.PENDING.getCode().equals(taskVO.getStatus()) ||
AsyncTaskStatus.PROCESSING.getCode().equals(taskVO.getStatus())) {
updateTaskProgress(taskId, AsyncTaskStatus.CANCELLED, 0, "任务已取消");
taskVO.setCompleteTime(LocalDateTime.now());
taskStore.put(taskId, taskVO);
progressCallbacks.remove(taskId);
log.info("任务已取消, taskId: {}", taskId);
return true;
}
return false;
}
@Override
public void cleanExpiredTasks(int expireHours) {
// 清理已完成的任务,释放内存
LocalDateTime expireTime = LocalDateTime.now().minusHours(expireHours);
int count = 0;
for (Map.Entry<String, ProjectInitTaskVO> entry : taskStore.entrySet()) {
ProjectInitTaskVO task = entry.getValue();
if (task.getCompleteTime() != null && task.getCompleteTime().isBefore(expireTime)) {
taskStore.remove(entry.getKey());
progressCallbacks.remove(entry.getKey());
count++;
}
}
log.info("已清理 {} 个过期任务", count);
}
/**
* 复制任务VO
*/
private ProjectInitTaskVO copyTaskVO(ProjectInitTaskVO source) {
ProjectInitTaskVO copy = new ProjectInitTaskVO();
copy.setTaskId(source.getTaskId());
copy.setStatus(source.getStatus());
copy.setStatusDesc(source.getStatusDesc());
copy.setProgress(source.getProgress());
copy.setProgressMessage(source.getProgressMessage());
copy.setOriginalFilename(source.getOriginalFilename());
copy.setCreateTime(source.getCreateTime());
copy.setStartTime(source.getStartTime());
copy.setCompleteTime(source.getCompleteTime());
copy.setResult(source.getResult());
copy.setErrorMessage(source.getErrorMessage());
return copy;
}
}