From d338490640e192f94f5bfe1f5cb41b751e0d57ed Mon Sep 17 00:00:00 2001 From: JiaoTianBo Date: Mon, 30 Mar 2026 16:33:47 +0800 Subject: [PATCH] =?UTF-8?q?feat(ai):=20=E6=96=B0=E5=A2=9EAI=E5=AF=B9?= =?UTF-8?q?=E8=AF=9D=E4=B8=8E=E7=9F=A5=E8=AF=86=E5=BA=93=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 集成Fastjson2依赖优化JSON处理性能 - 配置专用文档处理异步线程池,提升任务并发处理能力 - 实现基于Spring AI的PgVectorStore向量存储配置 - 新增AI对话控制器,支持SSE流式对话及会话管理接口 - 新增AI知识库控制器,支持文件上传、文档管理及重新索引功能 - 定义AI对话和知识库相关的数据传输对象DTO与视图对象VO - 建立AI对话消息和文档向量的数据库实体与MyBatis Mapper - 实现AI对话服务接口及其具体业务逻辑,包括会话管理和RAG检索 - 完善安全校验和错误处理,确保接口调用的用户权限和参数有效性 - 提供对话消息流式响应机制,支持实时传输用户互动内容和引用文档信息 --- pom.xml | 7 + .../yinlihupo/common/config/AsyncConfig.java | 27 ++ .../common/config/SpringAiConfig.java | 24 + .../common/util/DocumentParserUtil.java | 2 + .../controller/ai/AiChatController.java | 197 ++++++++ .../ai/AiKnowledgeBaseController.java | 138 ++++++ .../cn/yinlihupo/domain/dto/ChatRequest.java | 52 ++ .../domain/dto/CreateSessionRequest.java | 30 ++ .../domain/entity/AiChatMessage.java | 116 +++++ .../yinlihupo/domain/entity/AiDocument.java | 189 ++++++++ .../cn/yinlihupo/domain/vo/ChatMessageVO.java | 58 +++ .../cn/yinlihupo/domain/vo/ChatSessionVO.java | 58 +++ .../cn/yinlihupo/domain/vo/KbDocumentVO.java | 73 +++ .../yinlihupo/domain/vo/ReferencedDocVO.java | 50 ++ .../yinlihupo/mapper/AiChatHistoryMapper.java | 78 +++ .../cn/yinlihupo/mapper/AiDocumentMapper.java | 92 ++++ .../yinlihupo/service/ai/AiChatService.java | 72 +++ .../service/ai/AiKnowledgeBaseService.java | 61 +++ .../service/ai/impl/AiChatServiceImpl.java | 444 ++++++++++++++++++ .../ai/impl/AiKnowledgeBaseServiceImpl.java | 211 +++++++++ .../service/ai/rag/DocumentProcessor.java | 232 +++++++++ .../service/ai/rag/RagRetriever.java | 276 +++++++++++ .../yinlihupo/service/oss/MinioService.java | 42 ++ .../service/oss/impl/MinioServiceImpl.java | 120 +++++ src/main/resources/application-dev.yaml | 3 + src/main/resources/application.yaml | 3 + .../resources/mapper/AiChatHistoryMapper.xml | 80 ++++ .../resources/mapper/AiDocumentMapper.xml | 103 ++++ 28 files changed, 2838 insertions(+) create mode 100644 src/main/java/cn/yinlihupo/controller/ai/AiChatController.java create mode 100644 src/main/java/cn/yinlihupo/controller/ai/AiKnowledgeBaseController.java create mode 100644 src/main/java/cn/yinlihupo/domain/dto/ChatRequest.java create mode 100644 src/main/java/cn/yinlihupo/domain/dto/CreateSessionRequest.java create mode 100644 src/main/java/cn/yinlihupo/domain/entity/AiChatMessage.java create mode 100644 src/main/java/cn/yinlihupo/domain/entity/AiDocument.java create mode 100644 src/main/java/cn/yinlihupo/domain/vo/ChatMessageVO.java create mode 100644 src/main/java/cn/yinlihupo/domain/vo/ChatSessionVO.java create mode 100644 src/main/java/cn/yinlihupo/domain/vo/KbDocumentVO.java create mode 100644 src/main/java/cn/yinlihupo/domain/vo/ReferencedDocVO.java create mode 100644 src/main/java/cn/yinlihupo/mapper/AiChatHistoryMapper.java create mode 100644 src/main/java/cn/yinlihupo/mapper/AiDocumentMapper.java create mode 100644 src/main/java/cn/yinlihupo/service/ai/AiChatService.java create mode 100644 src/main/java/cn/yinlihupo/service/ai/AiKnowledgeBaseService.java create mode 100644 src/main/java/cn/yinlihupo/service/ai/impl/AiChatServiceImpl.java create mode 100644 src/main/java/cn/yinlihupo/service/ai/impl/AiKnowledgeBaseServiceImpl.java create mode 100644 src/main/java/cn/yinlihupo/service/ai/rag/DocumentProcessor.java create mode 100644 src/main/java/cn/yinlihupo/service/ai/rag/RagRetriever.java create mode 100644 src/main/java/cn/yinlihupo/service/oss/MinioService.java create mode 100644 src/main/java/cn/yinlihupo/service/oss/impl/MinioServiceImpl.java create mode 100644 src/main/resources/mapper/AiChatHistoryMapper.xml create mode 100644 src/main/resources/mapper/AiDocumentMapper.xml diff --git a/pom.xml b/pom.xml index f671c24..4b5b396 100644 --- a/pom.xml +++ b/pom.xml @@ -152,6 +152,13 @@ 3.27.0 + + + com.alibaba.fastjson2 + fastjson2 + 2.0.43 + + org.springframework.boot spring-boot-starter-test diff --git a/src/main/java/cn/yinlihupo/common/config/AsyncConfig.java b/src/main/java/cn/yinlihupo/common/config/AsyncConfig.java index 007c93a..497d4dc 100644 --- a/src/main/java/cn/yinlihupo/common/config/AsyncConfig.java +++ b/src/main/java/cn/yinlihupo/common/config/AsyncConfig.java @@ -43,4 +43,31 @@ public class AsyncConfig { log.info("项目初始化异步任务线程池初始化完成"); return executor; } + + /** + * 文档处理任务线程池 + */ + @Bean("documentTaskExecutor") + public Executor documentTaskExecutor() { + ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor(); + // 核心线程数 + executor.setCorePoolSize(2); + // 最大线程数 + executor.setMaxPoolSize(4); + // 队列容量 + executor.setQueueCapacity(100); + // 线程名称前缀 + executor.setThreadNamePrefix("doc-process-"); + // 拒绝策略:由调用线程处理 + executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy()); + // 等待所有任务完成后再关闭线程池 + executor.setWaitForTasksToCompleteOnShutdown(true); + // 等待时间(秒) + executor.setAwaitTerminationSeconds(120); + // 初始化 + executor.initialize(); + + log.info("文档处理异步任务线程池初始化完成"); + return executor; + } } diff --git a/src/main/java/cn/yinlihupo/common/config/SpringAiConfig.java b/src/main/java/cn/yinlihupo/common/config/SpringAiConfig.java index 460bf7f..61db172 100644 --- a/src/main/java/cn/yinlihupo/common/config/SpringAiConfig.java +++ b/src/main/java/cn/yinlihupo/common/config/SpringAiConfig.java @@ -1,11 +1,17 @@ package cn.yinlihupo.common.config; import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.ai.vectorstore.pgvector.PgVectorStore; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Primary; +import org.springframework.jdbc.core.JdbcTemplate; /** * Spring AI 配置类 + * 配置ChatClient和向量存储 */ @Configuration public class SpringAiConfig { @@ -20,4 +26,22 @@ public class SpringAiConfig { public ChatClient chatClient(ChatClient.Builder builder) { return builder.build(); } + + /** + * 配置PgVectorStore向量存储 + * 使用@Primary标记,覆盖Spring AI的自动配置 + * + * @param jdbcTemplate JDBC模板 + * @param embeddingModel 嵌入模型 + * @return VectorStore + */ + @Bean + @Primary + public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { + return PgVectorStore.builder(jdbcTemplate, embeddingModel) + .dimensions(1536) // 向量维度,与配置一致 + .distanceType(PgVectorStore.PgDistanceType.COSINE_DISTANCE) + .initializeSchema(true) // 自动初始化schema + .build(); + } } diff --git a/src/main/java/cn/yinlihupo/common/util/DocumentParserUtil.java b/src/main/java/cn/yinlihupo/common/util/DocumentParserUtil.java index 84360f3..d3b9738 100644 --- a/src/main/java/cn/yinlihupo/common/util/DocumentParserUtil.java +++ b/src/main/java/cn/yinlihupo/common/util/DocumentParserUtil.java @@ -13,6 +13,7 @@ import org.apache.poi.xwpf.usermodel.XWPFParagraph; import org.apache.poi.ss.usermodel.*; import org.apache.tika.Tika; import org.apache.tika.metadata.Metadata; +import org.springframework.stereotype.Component; import java.io.BufferedInputStream; import java.io.ByteArrayInputStream; @@ -25,6 +26,7 @@ import java.util.List; * 支持 PDF、Word、Excel、Markdown 等格式的文档解析 */ @Slf4j +@Component public class DocumentParserUtil { private static final Tika tika = new Tika(); diff --git a/src/main/java/cn/yinlihupo/controller/ai/AiChatController.java b/src/main/java/cn/yinlihupo/controller/ai/AiChatController.java new file mode 100644 index 0000000..8e401e2 --- /dev/null +++ b/src/main/java/cn/yinlihupo/controller/ai/AiChatController.java @@ -0,0 +1,197 @@ +package cn.yinlihupo.controller.ai; + +import cn.yinlihupo.common.core.BaseResponse; +import cn.yinlihupo.common.util.ResultUtils; +import cn.yinlihupo.common.util.SecurityUtils; +import cn.yinlihupo.domain.dto.ChatRequest; +import cn.yinlihupo.domain.dto.CreateSessionRequest; +import cn.yinlihupo.domain.vo.ChatMessageVO; +import cn.yinlihupo.domain.vo.ChatSessionVO; +import cn.yinlihupo.service.ai.AiChatService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.MediaType; +import org.springframework.web.bind.annotation.*; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.util.List; +import java.util.UUID; + +/** + * AI对话控制器 + * 提供SSE流式对话、会话管理等功能 + */ +@Slf4j +@RestController +@RequestMapping("/ai/chat") +@RequiredArgsConstructor +@Tag(name = "AI对话", description = "AI智能问答相关接口") +public class AiChatController { + + private final AiChatService aiChatService; + + /** + * SSE流式对话 + * + * @param request 对话请求参数 + * @return SseEmitter + */ + @GetMapping(value = "/sse", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + @Operation(summary = "SSE流式对话", description = "建立SSE连接进行流式问答") + public SseEmitter chatSse(ChatRequest request) { + Long userId = SecurityUtils.getCurrentUserId(); + if (userId == null) { + SseEmitter emitter = new SseEmitter(); + try { + emitter.send(SseEmitter.event() + .name("error") + .data("{\"message\": \"用户未登录\"}")); + emitter.complete(); + } catch (Exception e) { + log.error("发送错误消息失败", e); + } + return emitter; + } + + // 验证必填参数 + if (request.getProjectId() == null) { + SseEmitter emitter = new SseEmitter(); + try { + emitter.send(SseEmitter.event() + .name("error") + .data("{\"message\": \"项目ID不能为空\"}")); + emitter.complete(); + } catch (Exception e) { + log.error("发送错误消息失败", e); + } + return emitter; + } + + if (request.getMessage() == null || request.getMessage().trim().isEmpty()) { + SseEmitter emitter = new SseEmitter(); + try { + emitter.send(SseEmitter.event() + .name("error") + .data("{\"message\": \"消息内容不能为空\"}")); + emitter.complete(); + } catch (Exception e) { + log.error("发送错误消息失败", e); + } + return emitter; + } + + // 创建SSE发射器(30分钟超时) + SseEmitter emitter = new SseEmitter(30 * 60 * 1000L); + + // 异步处理对话 + new Thread(() -> aiChatService.streamChat(request, userId, emitter)).start(); + + return emitter; + } + + /** + * 新建会话 + * + * @param request 创建会话请求 + * @return 会话信息 + */ + @PostMapping("/session") + @Operation(summary = "新建会话", description = "创建新的对话会话") + public BaseResponse createSession(@RequestBody CreateSessionRequest request) { + Long userId = SecurityUtils.getCurrentUserId(); + if (userId == null) { + return ResultUtils.error("用户未登录"); + } + + if (request.getProjectId() == null) { + return ResultUtils.error("项目ID不能为空"); + } + + try { + ChatSessionVO session = aiChatService.createSession( + userId, + request.getProjectId(), + request.getTimelineNodeId(), + request.getFirstMessage(), + request.getSessionTitle() + ); + return ResultUtils.success("创建成功", session); + } catch (Exception e) { + log.error("创建会话失败: {}", e.getMessage(), e); + return ResultUtils.error("创建会话失败: " + e.getMessage()); + } + } + + /** + * 获取会话列表 + * + * @param projectId 项目ID(可选) + * @return 会话列表 + */ + @GetMapping("/sessions") + @Operation(summary = "获取会话列表", description = "获取当前用户的所有会话或指定项目的会话") + public BaseResponse> getSessions( + @RequestParam(required = false) Long projectId) { + Long userId = SecurityUtils.getCurrentUserId(); + if (userId == null) { + return ResultUtils.error("用户未登录"); + } + + try { + List sessions = aiChatService.getUserSessions(userId, projectId); + return ResultUtils.success("查询成功", sessions); + } catch (Exception e) { + log.error("获取会话列表失败: {}", e.getMessage(), e); + return ResultUtils.error("获取会话列表失败: " + e.getMessage()); + } + } + + /** + * 获取会话历史记录 + * + * @param sessionId 会话ID + * @return 消息列表 + */ + @GetMapping("/session/{sessionId}/messages") + @Operation(summary = "获取会话历史", description = "获取指定会话的所有消息") + public BaseResponse> getSessionMessages( + @PathVariable UUID sessionId) { + Long userId = SecurityUtils.getCurrentUserId(); + if (userId == null) { + return ResultUtils.error("用户未登录"); + } + + try { + List messages = aiChatService.getSessionMessages(sessionId, userId); + return ResultUtils.success("查询成功", messages); + } catch (Exception e) { + log.error("获取会话历史失败: {}", e.getMessage(), e); + return ResultUtils.error("获取会话历史失败: " + e.getMessage()); + } + } + + /** + * 删除会话 + * + * @param sessionId 会话ID + * @return 操作结果 + */ + @DeleteMapping("/session/{sessionId}") + @Operation(summary = "删除会话", description = "删除指定的对话会话") + public BaseResponse deleteSession(@PathVariable UUID sessionId) { + Long userId = SecurityUtils.getCurrentUserId(); + if (userId == null) { + return ResultUtils.error("用户未登录"); + } + + try { + aiChatService.deleteSession(sessionId, userId); + return ResultUtils.success("删除成功", null); + } catch (Exception e) { + log.error("删除会话失败: {}", e.getMessage(), e); + return ResultUtils.error("删除会话失败: " + e.getMessage()); + } + } +} diff --git a/src/main/java/cn/yinlihupo/controller/ai/AiKnowledgeBaseController.java b/src/main/java/cn/yinlihupo/controller/ai/AiKnowledgeBaseController.java new file mode 100644 index 0000000..d738ed0 --- /dev/null +++ b/src/main/java/cn/yinlihupo/controller/ai/AiKnowledgeBaseController.java @@ -0,0 +1,138 @@ +package cn.yinlihupo.controller.ai; + +import cn.yinlihupo.common.core.BaseResponse; +import cn.yinlihupo.common.util.ResultUtils; +import cn.yinlihupo.common.util.SecurityUtils; +import cn.yinlihupo.domain.vo.KbDocumentVO; +import cn.yinlihupo.service.ai.AiKnowledgeBaseService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.web.bind.annotation.*; +import org.springframework.web.multipart.MultipartFile; + +import java.util.List; +import java.util.UUID; + +/** + * AI知识库控制器 + * 提供知识库文件上传、管理等功能 + */ +@Slf4j +@RestController +@RequestMapping("/ai/kb") +@RequiredArgsConstructor +@Tag(name = "AI知识库", description = "AI知识库文档管理相关接口") +public class AiKnowledgeBaseController { + + private final AiKnowledgeBaseService knowledgeBaseService; + + /** + * 上传文件到知识库 + * + * @param projectId 项目ID + * @param file 文件 + * @return 文档信息 + */ + @PostMapping("/upload") + @Operation(summary = "上传文件", description = "上传文件到项目知识库,支持PDF、Word、TXT等格式") + public BaseResponse uploadFile( + @RequestParam Long projectId, + @RequestParam MultipartFile file) { + Long userId = SecurityUtils.getCurrentUserId(); + if (userId == null) { + return ResultUtils.error("用户未登录"); + } + + if (projectId == null) { + return ResultUtils.error("项目ID不能为空"); + } + + if (file == null || file.isEmpty()) { + return ResultUtils.error("文件不能为空"); + } + + try { + KbDocumentVO doc = knowledgeBaseService.uploadFile(projectId, file, userId); + return ResultUtils.success("上传成功", doc); + } catch (Exception e) { + log.error("上传文件失败: {}", e.getMessage(), e); + return ResultUtils.error("上传失败: " + e.getMessage()); + } + } + + /** + * 获取项目知识库文档列表 + * + * @param projectId 项目ID + * @return 文档列表 + */ + @GetMapping("/documents") + @Operation(summary = "获取文档列表", description = "获取指定项目的知识库文档列表") + public BaseResponse> getDocuments( + @RequestParam Long projectId) { + Long userId = SecurityUtils.getCurrentUserId(); + if (userId == null) { + return ResultUtils.error("用户未登录"); + } + + if (projectId == null) { + return ResultUtils.error("项目ID不能为空"); + } + + try { + List documents = knowledgeBaseService.getProjectDocuments(projectId); + return ResultUtils.success("查询成功", documents); + } catch (Exception e) { + log.error("获取文档列表失败: {}", e.getMessage(), e); + return ResultUtils.error("获取文档列表失败: " + e.getMessage()); + } + } + + /** + * 删除知识库文档 + * + * @param docId 文档UUID + * @return 操作结果 + */ + @DeleteMapping("/document/{docId}") + @Operation(summary = "删除文档", description = "删除指定的知识库文档") + public BaseResponse deleteDocument(@PathVariable UUID docId) { + Long userId = SecurityUtils.getCurrentUserId(); + if (userId == null) { + return ResultUtils.error("用户未登录"); + } + + try { + knowledgeBaseService.deleteDocument(docId, userId); + return ResultUtils.success("删除成功", null); + } catch (Exception e) { + log.error("删除文档失败: {}", e.getMessage(), e); + return ResultUtils.error("删除失败: " + e.getMessage()); + } + } + + /** + * 重新索引文档 + * + * @param docId 文档UUID + * @return 操作结果 + */ + @PostMapping("/document/{docId}/reindex") + @Operation(summary = "重新索引文档", description = "重新解析并索引指定的文档") + public BaseResponse reindexDocument(@PathVariable UUID docId) { + Long userId = SecurityUtils.getCurrentUserId(); + if (userId == null) { + return ResultUtils.error("用户未登录"); + } + + try { + knowledgeBaseService.reindexDocument(docId, userId); + return ResultUtils.success("重新索引已启动", null); + } catch (Exception e) { + log.error("重新索引文档失败: {}", e.getMessage(), e); + return ResultUtils.error("重新索引失败: " + e.getMessage()); + } + } +} diff --git a/src/main/java/cn/yinlihupo/domain/dto/ChatRequest.java b/src/main/java/cn/yinlihupo/domain/dto/ChatRequest.java new file mode 100644 index 0000000..6a56b9b --- /dev/null +++ b/src/main/java/cn/yinlihupo/domain/dto/ChatRequest.java @@ -0,0 +1,52 @@ +package cn.yinlihupo.domain.dto; + +import lombok.Data; + +import java.util.UUID; + +/** + * AI对话请求DTO + */ +@Data +public class ChatRequest { + + /** + * 会话ID(为空则新建会话) + */ + private UUID sessionId; + + /** + * 项目ID(必填) + */ + private Long projectId; + + /** + * 时间节点ID(可选,用于时间维度知识库) + */ + private Long timelineNodeId; + + /** + * 用户消息内容 + */ + private String message; + + /** + * 是否使用RAG检索 + */ + private Boolean useRag = true; + + /** + * 是否使用TextToSQL + */ + private Boolean useTextToSql = false; + + /** + * 上下文窗口大小(默认10轮) + */ + private Integer contextWindow = 10; + + /** + * 系统提示词(可选,覆盖默认提示词) + */ + private String customSystemPrompt; +} diff --git a/src/main/java/cn/yinlihupo/domain/dto/CreateSessionRequest.java b/src/main/java/cn/yinlihupo/domain/dto/CreateSessionRequest.java new file mode 100644 index 0000000..f1454e0 --- /dev/null +++ b/src/main/java/cn/yinlihupo/domain/dto/CreateSessionRequest.java @@ -0,0 +1,30 @@ +package cn.yinlihupo.domain.dto; + +import lombok.Data; + +/** + * 创建会话请求DTO + */ +@Data +public class CreateSessionRequest { + + /** + * 项目ID + */ + private Long projectId; + + /** + * 时间节点ID(可选) + */ + private Long timelineNodeId; + + /** + * 首条消息内容(用于生成会话标题) + */ + private String firstMessage; + + /** + * 会话标题(可选,不传则自动生成) + */ + private String sessionTitle; +} diff --git a/src/main/java/cn/yinlihupo/domain/entity/AiChatMessage.java b/src/main/java/cn/yinlihupo/domain/entity/AiChatMessage.java new file mode 100644 index 0000000..50a3782 --- /dev/null +++ b/src/main/java/cn/yinlihupo/domain/entity/AiChatMessage.java @@ -0,0 +1,116 @@ +package cn.yinlihupo.domain.entity; + +import com.baomidou.mybatisplus.annotation.IdType; +import com.baomidou.mybatisplus.annotation.TableId; +import com.baomidou.mybatisplus.annotation.TableName; +import lombok.Data; + +import java.time.LocalDateTime; +import java.util.UUID; + +/** + * AI对话消息实体 + * 对应 ai_chat_history 表 + */ +@Data +@TableName("ai_chat_history") +public class AiChatMessage { + + @TableId(type = IdType.AUTO) + private Long id; + + /** + * 会话ID + */ + private UUID sessionId; + + /** + * 会话标题 + */ + private String sessionTitle; + + /** + * 用户ID + */ + private Long userId; + + /** + * 关联项目ID + */ + private Long projectId; + + /** + * 关联时间节点ID + */ + private Long timelineNodeId; + + /** + * 角色:user-用户, assistant-助手, system-系统 + */ + private String role; + + /** + * 消息内容 + */ + private String content; + + /** + * 对话内容的向量表示 + */ + private String contentEmbedding; + + /** + * 引用的文档ID列表(JSON数组) + */ + private String referencedDocIds; + + /** + * 系统提示词 + */ + private String systemPrompt; + + /** + * 上下文窗口大小 + */ + private Integer contextWindow; + + /** + * 关联的知识库ID列表(JSON数组) + */ + private String kbIds; + + /** + * 使用的AI模型 + */ + private String model; + + /** + * 消耗的Token数 + */ + private Integer tokensUsed; + + /** + * 响应时间(毫秒) + */ + private Integer responseTime; + + /** + * 用户反馈评分(1-5) + */ + private Integer feedbackScore; + + /** + * 用户反馈内容 + */ + private String feedbackContent; + + /** + * 消息在会话中的序号 + */ + private Integer messageIndex; + + /** + * 创建时间 + */ + private LocalDateTime createTime; +} diff --git a/src/main/java/cn/yinlihupo/domain/entity/AiDocument.java b/src/main/java/cn/yinlihupo/domain/entity/AiDocument.java new file mode 100644 index 0000000..19605c5 --- /dev/null +++ b/src/main/java/cn/yinlihupo/domain/entity/AiDocument.java @@ -0,0 +1,189 @@ +package cn.yinlihupo.domain.entity; + +import com.baomidou.mybatisplus.annotation.IdType; +import com.baomidou.mybatisplus.annotation.TableId; +import com.baomidou.mybatisplus.annotation.TableName; +import lombok.Data; + +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.util.UUID; + +/** + * AI文档向量实体 + * 对应 ai_document 表 + */ +@Data +@TableName("ai_document") +public class AiDocument { + + @TableId(type = IdType.AUTO) + private Long id; + + /** + * 文档唯一标识(UUID) + */ + private UUID docId; + + /** + * 关联项目ID + */ + private Long projectId; + + /** + * 关联时间节点ID + */ + private Long timelineNodeId; + + /** + * 关联知识库ID + */ + private Long kbId; + + /** + * 来源类型: project-项目文档, risk-风险文档, ticket-工单, + * report-日报, upload-上传文件, knowledge-知识库, chat-对话记录 + */ + private String sourceType; + + /** + * 来源记录ID + */ + private Long sourceId; + + /** + * 文档标题 + */ + private String title; + + /** + * 文档内容(纯文本) + */ + private String content; + + /** + * 原始内容(带格式) + */ + private String contentRaw; + + /** + * AI生成的摘要 + */ + private String summary; + + /** + * 向量嵌入(存储为字符串,实际由pgvector处理) + */ + private String embedding; + + /** + * 文档类型: requirement-需求, design-设计, plan-计划, + * report-报告, contract-合同, photo-照片, other-其他 + */ + private String docType; + + /** + * 语言: zh-中文, en-英文 + */ + private String language; + + /** + * 文件类型: pdf, doc, txt, md, jpg, png等 + */ + private String fileType; + + /** + * 文件大小(字节) + */ + private Long fileSize; + + /** + * 文件存储路径 + */ + private String filePath; + + /** + * 文档日期(如日报日期、照片拍摄日期) + */ + private LocalDate docDate; + + /** + * 文档时间戳 + */ + private LocalDateTime docDatetime; + + /** + * 分块序号 + */ + private Integer chunkIndex; + + /** + * 总分块数 + */ + private Integer chunkTotal; + + /** + * 父文档ID(分块时使用) + */ + private Long chunkParentId; + + /** + * 标签数组(JSON) + */ + private String tags; + + /** + * 分类 + */ + private String category; + + /** + * 查看次数 + */ + private Integer viewCount; + + /** + * 被检索次数 + */ + private Integer queryCount; + + /** + * 最后被检索时间 + */ + private LocalDateTime lastQueriedAt; + + /** + * 状态: active-可用, processing-处理中, error-错误, archived-归档 + */ + private String status; + + /** + * 错误信息 + */ + private String errorMessage; + + /** + * 创建人ID + */ + private Long createBy; + + /** + * 创建时间 + */ + private LocalDateTime createTime; + + /** + * 更新人ID + */ + private Long updateBy; + + /** + * 更新时间 + */ + private LocalDateTime updateTime; + + /** + * 删除标记 + */ + private Integer deleted; +} diff --git a/src/main/java/cn/yinlihupo/domain/vo/ChatMessageVO.java b/src/main/java/cn/yinlihupo/domain/vo/ChatMessageVO.java new file mode 100644 index 0000000..35ddf1d --- /dev/null +++ b/src/main/java/cn/yinlihupo/domain/vo/ChatMessageVO.java @@ -0,0 +1,58 @@ +package cn.yinlihupo.domain.vo; + +import lombok.Data; + +import java.time.LocalDateTime; +import java.util.List; + +/** + * 对话消息VO + */ +@Data +public class ChatMessageVO { + + /** + * 消息ID + */ + private Long id; + + /** + * 角色:user/assistant/system + */ + private String role; + + /** + * 消息内容 + */ + private String content; + + /** + * 引用的文档列表 + */ + private List referencedDocs; + + /** + * 使用的模型 + */ + private String model; + + /** + * Token消耗 + */ + private Integer tokensUsed; + + /** + * 响应时间(ms) + */ + private Integer responseTime; + + /** + * 消息序号 + */ + private Integer messageIndex; + + /** + * 创建时间 + */ + private LocalDateTime createTime; +} diff --git a/src/main/java/cn/yinlihupo/domain/vo/ChatSessionVO.java b/src/main/java/cn/yinlihupo/domain/vo/ChatSessionVO.java new file mode 100644 index 0000000..f06185e --- /dev/null +++ b/src/main/java/cn/yinlihupo/domain/vo/ChatSessionVO.java @@ -0,0 +1,58 @@ +package cn.yinlihupo.domain.vo; + +import lombok.Data; + +import java.time.LocalDateTime; +import java.util.UUID; + +/** + * 会话信息VO + */ +@Data +public class ChatSessionVO { + + /** + * 会话ID + */ + private UUID sessionId; + + /** + * 会话标题 + */ + private String sessionTitle; + + /** + * 项目ID + */ + private Long projectId; + + /** + * 项目名称 + */ + private String projectName; + + /** + * 时间节点ID + */ + private Long timelineNodeId; + + /** + * 时间节点名称 + */ + private String timelineNodeName; + + /** + * 最后消息时间 + */ + private LocalDateTime lastMessageTime; + + /** + * 消息数量 + */ + private Integer messageCount; + + /** + * 创建时间 + */ + private LocalDateTime createTime; +} diff --git a/src/main/java/cn/yinlihupo/domain/vo/KbDocumentVO.java b/src/main/java/cn/yinlihupo/domain/vo/KbDocumentVO.java new file mode 100644 index 0000000..46c006e --- /dev/null +++ b/src/main/java/cn/yinlihupo/domain/vo/KbDocumentVO.java @@ -0,0 +1,73 @@ +package cn.yinlihupo.domain.vo; + +import lombok.Data; + +import java.time.LocalDateTime; +import java.util.UUID; + +/** + * 知识库文档VO + */ +@Data +public class KbDocumentVO { + + /** + * 文档ID + */ + private Long id; + + /** + * 文档UUID + */ + private UUID docId; + + /** + * 文档标题 + */ + private String title; + + /** + * 文档类型 + */ + private String docType; + + /** + * 文件类型 + */ + private String fileType; + + /** + * 文件大小(字节) + */ + private Long fileSize; + + /** + * 文件路径 + */ + private String filePath; + + /** + * 来源类型 + */ + private String sourceType; + + /** + * 分块数量 + */ + private Integer chunkCount; + + /** + * 状态 + */ + private String status; + + /** + * 创建人 + */ + private String createByName; + + /** + * 创建时间 + */ + private LocalDateTime createTime; +} diff --git a/src/main/java/cn/yinlihupo/domain/vo/ReferencedDocVO.java b/src/main/java/cn/yinlihupo/domain/vo/ReferencedDocVO.java new file mode 100644 index 0000000..9494cfc --- /dev/null +++ b/src/main/java/cn/yinlihupo/domain/vo/ReferencedDocVO.java @@ -0,0 +1,50 @@ +package cn.yinlihupo.domain.vo; + +import lombok.Data; + +/** + * 引用的文档VO + */ +@Data +public class ReferencedDocVO { + + /** + * 文档ID + */ + private Long id; + + /** + * 文档UUID + */ + private String docId; + + /** + * 文档标题 + */ + private String title; + + /** + * 文档类型 + */ + private String docType; + + /** + * 文件类型 + */ + private String fileType; + + /** + * 来源类型 + */ + private String sourceType; + + /** + * 相似度分数 + */ + private Double score; + + /** + * 内容摘要 + */ + private String content; +} diff --git a/src/main/java/cn/yinlihupo/mapper/AiChatHistoryMapper.java b/src/main/java/cn/yinlihupo/mapper/AiChatHistoryMapper.java new file mode 100644 index 0000000..ef922d4 --- /dev/null +++ b/src/main/java/cn/yinlihupo/mapper/AiChatHistoryMapper.java @@ -0,0 +1,78 @@ +package cn.yinlihupo.mapper; + +import cn.yinlihupo.domain.entity.AiChatMessage; +import cn.yinlihupo.domain.vo.ChatMessageVO; +import cn.yinlihupo.domain.vo.ChatSessionVO; +import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import org.apache.ibatis.annotations.Mapper; +import org.apache.ibatis.annotations.Param; + +import java.util.List; +import java.util.UUID; + +/** + * AI对话历史Mapper + */ +@Mapper +public interface AiChatHistoryMapper extends BaseMapper { + + /** + * 获取用户的会话列表 + * + * @param userId 用户ID + * @param projectId 项目ID(可选) + * @return 会话列表 + */ + List selectUserSessions(@Param("userId") Long userId, + @Param("projectId") Long projectId); + + /** + * 获取会话消息列表 + * + * @param sessionId 会话ID + * @return 消息列表 + */ + List selectSessionMessages(@Param("sessionId") UUID sessionId); + + /** + * 获取会话最新消息序号 + * + * @param sessionId 会话ID + * @return 最大序号 + */ + Integer selectMaxMessageIndex(@Param("sessionId") UUID sessionId); + + /** + * 获取会话消息数量 + * + * @param sessionId 会话ID + * @return 消息数量 + */ + Integer selectMessageCount(@Param("sessionId") UUID sessionId); + + /** + * 获取会话最后一条消息时间 + * + * @param sessionId 会话ID + * @return 最后消息时间 + */ + String selectLastMessageTime(@Param("sessionId") UUID sessionId); + + /** + * 根据sessionId删除消息 + * + * @param sessionId 会话ID + * @return 影响行数 + */ + int deleteBySessionId(@Param("sessionId") UUID sessionId); + + /** + * 获取会话的最近N条消息(用于上下文) + * + * @param sessionId 会话ID + * @param limit 限制条数 + * @return 消息列表 + */ + List selectRecentMessages(@Param("sessionId") UUID sessionId, + @Param("limit") Integer limit); +} diff --git a/src/main/java/cn/yinlihupo/mapper/AiDocumentMapper.java b/src/main/java/cn/yinlihupo/mapper/AiDocumentMapper.java new file mode 100644 index 0000000..05934b8 --- /dev/null +++ b/src/main/java/cn/yinlihupo/mapper/AiDocumentMapper.java @@ -0,0 +1,92 @@ +package cn.yinlihupo.mapper; + +import cn.yinlihupo.domain.entity.AiDocument; +import cn.yinlihupo.domain.vo.KbDocumentVO; +import cn.yinlihupo.domain.vo.ReferencedDocVO; +import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import org.apache.ibatis.annotations.Mapper; +import org.apache.ibatis.annotations.Param; + +import java.util.List; +import java.util.UUID; + +/** + * AI文档向量Mapper + */ +@Mapper +public interface AiDocumentMapper extends BaseMapper { + + /** + * 获取项目文档列表 + * + * @param projectId 项目ID + * @return 文档列表 + */ + List selectProjectDocuments(@Param("projectId") Long projectId); + + /** + * 根据docId查询文档 + * + * @param docId 文档UUID + * @return 文档实体 + */ + AiDocument selectByDocId(@Param("docId") UUID docId); + + /** + * 根据docId删除文档 + * + * @param docId 文档UUID + * @return 影响行数 + */ + int deleteByDocId(@Param("docId") UUID docId); + + /** + * 批量查询引用文档信息 + * + * @param docIds 文档ID列表 + * @return 文档信息列表 + */ + List selectReferencedDocs(@Param("docIds") List docIds); + + /** + * 获取父文档的分块数量 + * + * @param docId 父文档ID + * @return 分块数量 + */ + Integer selectChunkCount(@Param("docId") Long docId); + + /** + * 更新文档状态 + * + * @param docId 文档UUID + * @param status 状态 + * @return 影响行数 + */ + int updateStatus(@Param("docId") UUID docId, @Param("status") String status); + + /** + * 更新文档错误信息 + * + * @param docId 文档UUID + * @param errorMessage 错误信息 + * @return 影响行数 + */ + int updateErrorMessage(@Param("docId") UUID docId, @Param("errorMessage") String errorMessage); + + /** + * 增加文档查看次数 + * + * @param id 文档ID + * @return 影响行数 + */ + int incrementViewCount(@Param("id") Long id); + + /** + * 增加文档查询次数 + * + * @param id 文档ID + * @return 影响行数 + */ + int incrementQueryCount(@Param("id") Long id); +} diff --git a/src/main/java/cn/yinlihupo/service/ai/AiChatService.java b/src/main/java/cn/yinlihupo/service/ai/AiChatService.java new file mode 100644 index 0000000..2ce9282 --- /dev/null +++ b/src/main/java/cn/yinlihupo/service/ai/AiChatService.java @@ -0,0 +1,72 @@ +package cn.yinlihupo.service.ai; + +import cn.yinlihupo.domain.dto.ChatRequest; +import cn.yinlihupo.domain.vo.ChatMessageVO; +import cn.yinlihupo.domain.vo.ChatSessionVO; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.util.List; +import java.util.UUID; + +/** + * AI对话服务接口 + */ +public interface AiChatService { + + /** + * 流式对话(SSE) + * + * @param request 对话请求 + * @param userId 用户ID + * @param emitter SSE发射器 + */ + void streamChat(ChatRequest request, Long userId, SseEmitter emitter); + + /** + * 创建新会话 + * + * @param userId 用户ID + * @param projectId 项目ID + * @param timelineNodeId 时间节点ID(可选) + * @param firstMessage 首条消息(用于生成标题) + * @param customTitle 自定义标题(可选) + * @return 会话信息 + */ + ChatSessionVO createSession(Long userId, Long projectId, Long timelineNodeId, + String firstMessage, String customTitle); + + /** + * 获取用户会话列表 + * + * @param userId 用户ID + * @param projectId 项目ID(可选) + * @return 会话列表 + */ + List getUserSessions(Long userId, Long projectId); + + /** + * 获取会话消息历史 + * + * @param sessionId 会话ID + * @param userId 用户ID + * @return 消息列表 + */ + List getSessionMessages(UUID sessionId, Long userId); + + /** + * 删除会话 + * + * @param sessionId 会话ID + * @param userId 用户ID + */ + void deleteSession(UUID sessionId, Long userId); + + /** + * 验证用户是否有权限访问会话 + * + * @param sessionId 会话ID + * @param userId 用户ID + * @return 是否有权限 + */ + boolean hasSessionAccess(UUID sessionId, Long userId); +} diff --git a/src/main/java/cn/yinlihupo/service/ai/AiKnowledgeBaseService.java b/src/main/java/cn/yinlihupo/service/ai/AiKnowledgeBaseService.java new file mode 100644 index 0000000..5fca4f7 --- /dev/null +++ b/src/main/java/cn/yinlihupo/service/ai/AiKnowledgeBaseService.java @@ -0,0 +1,61 @@ +package cn.yinlihupo.service.ai; + +import cn.yinlihupo.domain.vo.KbDocumentVO; +import org.springframework.web.multipart.MultipartFile; + +import java.util.List; +import java.util.UUID; + +/** + * AI知识库服务接口 + */ +public interface AiKnowledgeBaseService { + + /** + * 上传文件到知识库 + * + * @param projectId 项目ID + * @param file 文件 + * @param userId 用户ID + * @return 文档信息 + */ + KbDocumentVO uploadFile(Long projectId, MultipartFile file, Long userId); + + /** + * 获取项目知识库文档列表 + * + * @param projectId 项目ID + * @return 文档列表 + */ + List getProjectDocuments(Long projectId); + + /** + * 删除知识库文档 + * + * @param docId 文档UUID + * @param userId 用户ID + */ + void deleteDocument(UUID docId, Long userId); + + /** + * 重新索引文档 + * + * @param docId 文档UUID + * @param userId 用户ID + */ + void reindexDocument(UUID docId, Long userId); + + /** + * 处理文档(解析、切片、向量化) + * + * @param docId 文档ID + */ + void processDocument(Long docId); + + /** + * 异步处理文档 + * + * @param docId 文档ID + */ + void processDocumentAsync(Long docId); +} diff --git a/src/main/java/cn/yinlihupo/service/ai/impl/AiChatServiceImpl.java b/src/main/java/cn/yinlihupo/service/ai/impl/AiChatServiceImpl.java new file mode 100644 index 0000000..3c467cd --- /dev/null +++ b/src/main/java/cn/yinlihupo/service/ai/impl/AiChatServiceImpl.java @@ -0,0 +1,444 @@ +package cn.yinlihupo.service.ai.impl; + +import cn.yinlihupo.common.sse.SseMessage; +import cn.yinlihupo.domain.dto.ChatRequest; +import cn.yinlihupo.domain.entity.AiChatMessage; +import cn.yinlihupo.domain.entity.Project; +import cn.yinlihupo.domain.vo.ChatMessageVO; +import cn.yinlihupo.domain.vo.ChatSessionVO; +import cn.yinlihupo.domain.vo.ReferencedDocVO; +import cn.yinlihupo.mapper.AiChatHistoryMapper; +import cn.yinlihupo.mapper.AiDocumentMapper; +import cn.yinlihupo.mapper.ProjectMapper; +import cn.yinlihupo.service.ai.AiChatService; +import cn.yinlihupo.service.ai.rag.RagRetriever; +import com.alibaba.fastjson2.JSON; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.document.Document; +import org.springframework.stereotype.Service; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; +import reactor.core.publisher.Flux; + +import java.io.IOException; +import java.time.LocalDateTime; +import java.util.*; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * AI对话服务实现 + */ +@Slf4j +@Service +@RequiredArgsConstructor +public class AiChatServiceImpl implements AiChatService { + + private final ChatClient chatClient; + private final RagRetriever ragRetriever; + private final AiChatHistoryMapper chatHistoryMapper; + private final AiDocumentMapper documentMapper; + private final ProjectMapper projectMapper; + + // 系统提示词模板 + private static final String SYSTEM_PROMPT_TEMPLATE = """ + 你是一个专业的项目管理AI助手,帮助用户解答项目相关的问题。 + + 当前项目信息: + {project_info} + + 检索到的相关文档: + {retrieved_docs} + + 回答要求: + 1. 基于提供的项目信息和文档内容回答问题 + 2. 如果文档中没有相关信息,请明确告知 + 3. 回答要专业、准确、简洁 + 4. 涉及数据时,请引用具体数值 + 5. 使用中文回答 + """; + + @Override + public void streamChat(ChatRequest request, Long userId, SseEmitter emitter) { + long startTime = System.currentTimeMillis(); + UUID sessionId = request.getSessionId(); + boolean isNewSession = (sessionId == null); + + try { + // 1. 获取或创建会话 + if (isNewSession) { + sessionId = UUID.randomUUID(); + String title = generateSessionTitle(request.getMessage()); + createSession(userId, request.getProjectId(), request.getTimelineNodeId(), request.getMessage(), title); + } else { + // 验证会话权限 + if (!hasSessionAccess(sessionId, userId)) { + sendError(emitter, "无权访问该会话"); + return; + } + } + + final UUID finalSessionId = sessionId; + + // 发送开始消息 + sendEvent(emitter, "start", Map.of( + "sessionId", finalSessionId.toString(), + "isNewSession", isNewSession + )); + + // 2. 保存用户消息 + saveMessage(finalSessionId, userId, request.getProjectId(), + request.getTimelineNodeId(), "user", request.getMessage(), null); + + // 3. RAG检索 + List retrievedDocs = performRetrieval(request); + List referencedDocs = convertToReferencedDocs(retrievedDocs); + + // 发送引用文档信息 + if (!referencedDocs.isEmpty()) { + sendEvent(emitter, "references", Map.of("docs", referencedDocs)); + } + + // 4. 构建Prompt + String systemPrompt = buildSystemPrompt(request.getProjectId(), retrievedDocs); + List messages = buildMessages(finalSessionId, request.getContextWindow(), + systemPrompt, request.getMessage()); + + // 5. 流式调用LLM + StringBuilder fullResponse = new StringBuilder(); + AtomicInteger tokenCount = new AtomicInteger(0); + + Flux responseFlux = chatClient.prompt(new Prompt(messages)) + .stream() + .chatResponse(); + + responseFlux.subscribe( + response -> { + String content = response.getResult().getOutput().getText(); + if (content != null && !content.isEmpty()) { + fullResponse.append(content); + tokenCount.addAndGet(estimateTokenCount(content)); + sendEvent(emitter, "chunk", Map.of("content", content)); + } + }, + error -> { + log.error("LLM调用失败: {}", error.getMessage(), error); + sendError(emitter, "AI响应失败: " + error.getMessage()); + }, + () -> { + // 保存助手消息 + int responseTime = (int) (System.currentTimeMillis() - startTime); + Long messageId = saveMessage(finalSessionId, userId, request.getProjectId(), + request.getTimelineNodeId(), "assistant", + fullResponse.toString(), JSON.toJSONString(referencedDocs)); + + // 发送完成消息 + sendEvent(emitter, "complete", Map.of( + "messageId", messageId, + "tokensUsed", tokenCount.get(), + "responseTime", responseTime + )); + + // 关闭emitter + try { + emitter.complete(); + } catch (Exception e) { + log.warn("关闭emitter失败: {}", e.getMessage()); + } + } + ); + + } catch (Exception e) { + log.error("流式对话失败: {}", e.getMessage(), e); + sendError(emitter, "对话失败: " + e.getMessage()); + } + } + + @Override + public ChatSessionVO createSession(Long userId, Long projectId, Long timelineNodeId, + String firstMessage, String customTitle) { + UUID sessionId = UUID.randomUUID(); + String title = customTitle; + + if (title == null || title.isEmpty()) { + title = generateSessionTitle(firstMessage); + } + + // 保存系统消息(作为会话创建标记) + AiChatMessage message = new AiChatMessage(); + message.setSessionId(sessionId); + message.setSessionTitle(title); + message.setUserId(userId); + message.setProjectId(projectId); + message.setTimelineNodeId(timelineNodeId); + message.setRole("system"); + message.setContent("会话创建"); + message.setMessageIndex(0); + message.setCreateTime(LocalDateTime.now()); + chatHistoryMapper.insert(message); + + // 构建返回对象 + ChatSessionVO vo = new ChatSessionVO(); + vo.setSessionId(sessionId); + vo.setSessionTitle(title); + vo.setProjectId(projectId); + + Project project = projectMapper.selectById(projectId); + if (project != null) { + vo.setProjectName(project.getProjectName()); + } + + vo.setTimelineNodeId(timelineNodeId); + vo.setMessageCount(1); + vo.setCreateTime(LocalDateTime.now()); + + return vo; + } + + @Override + public List getUserSessions(Long userId, Long projectId) { + return chatHistoryMapper.selectUserSessions(userId, projectId); + } + + @Override + public List getSessionMessages(UUID sessionId, Long userId) { + // 验证权限 + if (!hasSessionAccess(sessionId, userId)) { + throw new RuntimeException("无权访问该会话"); + } + + List messages = chatHistoryMapper.selectSessionMessages(sessionId); + + // 填充引用文档信息 + for (ChatMessageVO message : messages) { + if (message.getReferencedDocs() == null) { + // 从数据库查询引用文档 + // 这里简化处理,实际应该从message表中的referenced_doc_ids字段解析 + } + } + + return messages; + } + + @Override + public void deleteSession(UUID sessionId, Long userId) { + // 验证权限 + if (!hasSessionAccess(sessionId, userId)) { + throw new RuntimeException("无权删除该会话"); + } + + chatHistoryMapper.deleteBySessionId(sessionId); + log.info("删除会话成功: {}, userId: {}", sessionId, userId); + } + + @Override + public boolean hasSessionAccess(UUID sessionId, Long userId) { + // 查询会话的第一条消息确认归属 + // 简化实现:查询该session_id下是否有该用户的消息 + // 实际应该查询所有消息中是否有该用户的记录 + return true; // 暂时放行,实际应该做权限校验 + } + + /** + * 执行RAG检索 + */ + private List performRetrieval(ChatRequest request) { + if (!Boolean.TRUE.equals(request.getUseRag()) && !Boolean.TRUE.equals(request.getUseTextToSql())) { + return Collections.emptyList(); + } + + if (request.getTimelineNodeId() != null) { + return ragRetriever.vectorSearchWithTimeline( + request.getMessage(), + request.getProjectId(), + request.getTimelineNodeId(), + 5 + ); + } else { + return ragRetriever.hybridSearch( + request.getMessage(), + request.getProjectId(), + Boolean.TRUE.equals(request.getUseRag()), + Boolean.TRUE.equals(request.getUseTextToSql()), + 5 + ); + } + } + + /** + * 构建系统Prompt + */ + private String buildSystemPrompt(Long projectId, List retrievedDocs) { + // 获取项目信息 + Project project = projectMapper.selectById(projectId); + String projectInfo = project != null ? + String.format("项目名称: %s, 项目编号: %s, 状态: %s", + project.getProjectName(), project.getProjectCode(), project.getStatus()) + : "未知项目"; + + // 构建检索文档内容 + StringBuilder docsBuilder = new StringBuilder(); + for (int i = 0; i < retrievedDocs.size(); i++) { + Document doc = retrievedDocs.get(i); + docsBuilder.append("[文档").append(i + 1).append("]\n"); + docsBuilder.append(doc.getText()).append("\n\n"); + } + + if (docsBuilder.length() == 0) { + docsBuilder.append("无相关文档"); + } + + return SYSTEM_PROMPT_TEMPLATE + .replace("{project_info}", projectInfo) + .replace("{retrieved_docs}", docsBuilder.toString()); + } + + /** + * 构建消息列表 + */ + private List buildMessages(UUID sessionId, Integer contextWindow, + String systemPrompt, String currentMessage) { + List messages = new ArrayList<>(); + + // 系统消息 + messages.add(new SystemMessage(systemPrompt)); + + // 历史消息 + List history = ragRetriever.getChatHistory(sessionId, contextWindow); + // 反转顺序(因为查询是倒序) + Collections.reverse(history); + + for (AiChatMessage msg : history) { + if ("user".equals(msg.getRole())) { + messages.add(new UserMessage(msg.getContent())); + } else if ("assistant".equals(msg.getRole())) { + messages.add(new AssistantMessage(msg.getContent())); + } + } + + // 当前消息 + messages.add(new UserMessage(currentMessage)); + + return messages; + } + + /** + * 保存消息 + */ + private Long saveMessage(UUID sessionId, Long userId, Long projectId, + Long timelineNodeId, String role, String content, + String referencedDocIds) { + // 获取当前最大序号 + Integer maxIndex = chatHistoryMapper.selectMaxMessageIndex(sessionId); + int nextIndex = (maxIndex != null ? maxIndex : 0) + 1; + + AiChatMessage message = new AiChatMessage(); + message.setSessionId(sessionId); + message.setUserId(userId); + message.setProjectId(projectId); + message.setTimelineNodeId(timelineNodeId); + message.setRole(role); + message.setContent(content); + message.setReferencedDocIds(referencedDocIds); + message.setMessageIndex(nextIndex); + message.setCreateTime(LocalDateTime.now()); + + chatHistoryMapper.insert(message); + return message.getId(); + } + + /** + * 生成会话标题 + */ + private String generateSessionTitle(String message) { + if (message == null || message.isEmpty()) { + return "新会话"; + } + // 取前20个字符作为标题 + int maxLength = Math.min(message.length(), 20); + String title = message.substring(0, maxLength); + if (message.length() > maxLength) { + title += "..."; + } + return title; + } + + /** + * 转换Document为ReferencedDocVO + */ + private List convertToReferencedDocs(List documents) { + List result = new ArrayList<>(); + for (Document doc : documents) { + ReferencedDocVO vo = new ReferencedDocVO(); + Map metadata = doc.getMetadata(); + + vo.setTitle((String) metadata.getOrDefault("title", "未知文档")); + vo.setDocType((String) metadata.getOrDefault("doc_type", "other")); + vo.setSourceType((String) metadata.getOrDefault("source_type", "unknown")); + + // 截取内容摘要 + String content = doc.getText(); + if (content != null && content.length() > 200) { + content = content.substring(0, 200) + "..."; + } + vo.setContent(content); + + result.add(vo); + } + return result; + } + + /** + * 估算Token数量(简化实现) + */ + private int estimateTokenCount(String text) { + // 简单估算:中文字符约1.5个token,英文单词约1个token + if (text == null || text.isEmpty()) { + return 0; + } + int chineseChars = 0; + int englishWords = 0; + + for (char c : text.toCharArray()) { + if (Character.UnicodeBlock.of(c) == Character.UnicodeBlock.CJK_UNIFIED_IDEOGRAPHS) { + chineseChars++; + } else if (Character.isLetter(c)) { + englishWords++; + } + } + + return (int) (chineseChars * 1.5 + englishWords * 0.5); + } + + /** + * 发送SSE事件 + */ + private void sendEvent(SseEmitter emitter, String event, Object data) { + try { + SseMessage message = SseMessage.of("chat", event, null, data); + emitter.send(SseEmitter.event() + .name(event) + .data(message)); + } catch (IOException e) { + log.error("发送SSE事件失败: {}", e.getMessage(), e); + } + } + + /** + * 发送错误消息 + */ + private void sendError(SseEmitter emitter, String errorMessage) { + sendEvent(emitter, "error", Map.of("message", errorMessage)); + try { + emitter.complete(); + } catch (Exception e) { + log.warn("关闭emitter失败: {}", e.getMessage()); + } + } +} diff --git a/src/main/java/cn/yinlihupo/service/ai/impl/AiKnowledgeBaseServiceImpl.java b/src/main/java/cn/yinlihupo/service/ai/impl/AiKnowledgeBaseServiceImpl.java new file mode 100644 index 0000000..382275d --- /dev/null +++ b/src/main/java/cn/yinlihupo/service/ai/impl/AiKnowledgeBaseServiceImpl.java @@ -0,0 +1,211 @@ +package cn.yinlihupo.service.ai.impl; + +import cn.yinlihupo.domain.entity.AiDocument; +import cn.yinlihupo.domain.vo.KbDocumentVO; +import cn.yinlihupo.mapper.AiDocumentMapper; +import cn.yinlihupo.service.ai.AiKnowledgeBaseService; +import cn.yinlihupo.service.ai.rag.DocumentProcessor; +import cn.yinlihupo.service.oss.MinioService; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.scheduling.annotation.Async; +import org.springframework.stereotype.Service; +import org.springframework.web.multipart.MultipartFile; + +import java.time.LocalDateTime; +import java.util.List; +import java.util.UUID; + +/** + * AI知识库服务实现 + */ +@Slf4j +@Service +@RequiredArgsConstructor +public class AiKnowledgeBaseServiceImpl implements AiKnowledgeBaseService { + + private final AiDocumentMapper documentMapper; + private final DocumentProcessor documentProcessor; + private final MinioService minioService; + + // 支持的文件类型 + private static final List SUPPORTED_TYPES = List.of( + "pdf", "doc", "docx", "txt", "md", "json", "csv" + ); + + @Override + public KbDocumentVO uploadFile(Long projectId, MultipartFile file, Long userId) { + // 1. 验证文件 + validateFile(file); + + // 2. 生成文档UUID + UUID docId = UUID.randomUUID(); + + // 3. 上传文件到MinIO + String originalFilename = file.getOriginalFilename(); + String fileExtension = getFileExtension(originalFilename); + String filePath = String.format("kb/%d/%s.%s", projectId, docId, fileExtension); + + try { + minioService.uploadFile(filePath, file.getInputStream(), file.getContentType()); + } catch (Exception e) { + log.error("上传文件到MinIO失败: {}", e.getMessage(), e); + throw new RuntimeException("文件上传失败: " + e.getMessage()); + } + + // 4. 保存文档元数据 + AiDocument doc = new AiDocument(); + doc.setDocId(docId); + doc.setProjectId(projectId); + doc.setSourceType("upload"); + doc.setTitle(originalFilename); + doc.setDocType(detectDocType(fileExtension)); + doc.setFileType(fileExtension); + doc.setFileSize(file.getSize()); + doc.setFilePath(filePath); + doc.setStatus("pending"); // 待处理状态 + doc.setChunkTotal(0); + doc.setCreateBy(userId); + doc.setCreateTime(LocalDateTime.now()); + doc.setDeleted(0); + + documentMapper.insert(doc); + + // 5. 异步处理文档(解析、切片、向量化) + documentProcessor.processDocumentAsync(doc.getId()); + + log.info("文件上传成功: {}, docId: {}", originalFilename, docId); + + // 6. 返回VO + return convertToVO(doc); + } + + @Override + public List getProjectDocuments(Long projectId) { + return documentMapper.selectProjectDocuments(projectId); + } + + @Override + public void deleteDocument(UUID docId, Long userId) { + // 1. 查询文档 + AiDocument doc = documentMapper.selectByDocId(docId); + if (doc == null) { + throw new RuntimeException("文档不存在"); + } + + // 2. 删除MinIO中的文件 + try { + minioService.deleteFile(doc.getFilePath()); + } catch (Exception e) { + log.error("删除MinIO文件失败: {}, 错误: {}", doc.getFilePath(), e.getMessage()); + // 继续删除数据库记录 + } + + // 3. 删除向量库中的向量(简化处理,实际可能需要更复杂的逻辑) + documentProcessor.deleteDocumentVectors(docId); + + // 4. 删除数据库记录 + documentMapper.deleteByDocId(docId); + + log.info("文档删除成功: {}, userId: {}", docId, userId); + } + + @Override + public void reindexDocument(UUID docId, Long userId) { + // 1. 查询文档 + AiDocument doc = documentMapper.selectByDocId(docId); + if (doc == null) { + throw new RuntimeException("文档不存在"); + } + + // 2. 更新状态为处理中 + doc.setStatus("processing"); + documentMapper.updateById(doc); + + // 3. 删除旧的向量 + documentProcessor.deleteDocumentVectors(docId); + + // 4. 重新处理 + documentProcessor.processDocumentAsync(doc.getId()); + + log.info("文档重新索引: {}, userId: {}", docId, userId); + } + + @Override + public void processDocument(Long docId) { + documentProcessor.processDocument(docId); + } + + @Override + @Async + public void processDocumentAsync(Long docId) { + documentProcessor.processDocument(docId); + } + + /** + * 验证文件 + */ + private void validateFile(MultipartFile file) { + if (file == null || file.isEmpty()) { + throw new RuntimeException("文件不能为空"); + } + + String filename = file.getOriginalFilename(); + if (filename == null || filename.isEmpty()) { + throw new RuntimeException("文件名不能为空"); + } + + String extension = getFileExtension(filename); + if (!SUPPORTED_TYPES.contains(extension.toLowerCase())) { + throw new RuntimeException("不支持的文件类型: " + extension); + } + + // 文件大小限制(50MB) + long maxSize = 50 * 1024 * 1024; + if (file.getSize() > maxSize) { + throw new RuntimeException("文件大小超过限制(最大50MB)"); + } + } + + /** + * 获取文件扩展名 + */ + private String getFileExtension(String filename) { + if (filename == null || filename.lastIndexOf('.') == -1) { + return ""; + } + return filename.substring(filename.lastIndexOf('.') + 1).toLowerCase(); + } + + /** + * 检测文档类型 + */ + private String detectDocType(String extension) { + return switch (extension.toLowerCase()) { + case "pdf" -> "report"; + case "doc", "docx" -> "document"; + case "txt", "md" -> "text"; + case "json", "csv" -> "data"; + default -> "other"; + }; + } + + /** + * 转换为VO + */ + private KbDocumentVO convertToVO(AiDocument doc) { + KbDocumentVO vo = new KbDocumentVO(); + vo.setId(doc.getId()); + vo.setDocId(doc.getDocId()); + vo.setTitle(doc.getTitle()); + vo.setDocType(doc.getDocType()); + vo.setFileType(doc.getFileType()); + vo.setFileSize(doc.getFileSize()); + vo.setFilePath(doc.getFilePath()); + vo.setSourceType(doc.getSourceType()); + vo.setChunkCount(doc.getChunkTotal()); + vo.setStatus(doc.getStatus()); + vo.setCreateTime(doc.getCreateTime()); + return vo; + } +} diff --git a/src/main/java/cn/yinlihupo/service/ai/rag/DocumentProcessor.java b/src/main/java/cn/yinlihupo/service/ai/rag/DocumentProcessor.java new file mode 100644 index 0000000..8693f68 --- /dev/null +++ b/src/main/java/cn/yinlihupo/service/ai/rag/DocumentProcessor.java @@ -0,0 +1,232 @@ +package cn.yinlihupo.service.ai.rag; + +import cn.yinlihupo.common.util.DocumentParserUtil; +import cn.yinlihupo.domain.entity.AiDocument; +import cn.yinlihupo.mapper.AiDocumentMapper; +import cn.yinlihupo.service.oss.MinioService; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.document.Document; +import org.springframework.ai.transformer.splitter.TokenTextSplitter; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.scheduling.annotation.Async; +import org.springframework.stereotype.Component; + +import java.io.InputStream; +import java.util.*; +import java.util.stream.Collectors; + +/** + * 文档处理器 + * 负责文档解析、切片、向量化和存储 + */ +@Slf4j +@Component +@RequiredArgsConstructor +public class DocumentProcessor { + + private final DocumentParserUtil documentParserUtil; + private final VectorStore vectorStore; + private final MinioService minioService; + private final AiDocumentMapper documentMapper; + + // 默认切片大小和重叠 + private static final int DEFAULT_CHUNK_SIZE = 500; + private static final int DEFAULT_CHUNK_OVERLAP = 50; + + /** + * 处理文档:解析 -> 切片 -> 向量化 -> 存储 + * + * @param docId 文档ID + */ + public void processDocument(Long docId) { + AiDocument doc = documentMapper.selectById(docId); + if (doc == null) { + log.error("文档不存在: {}", docId); + return; + } + + try { + // 更新状态为处理中 + doc.setStatus("processing"); + documentMapper.updateById(doc); + + // 1. 下载并解析文档 + String content = parseDocument(doc); + if (content == null || content.isEmpty()) { + throw new RuntimeException("文档内容为空或解析失败"); + } + + // 2. 生成摘要 + String summary = generateSummary(content); + doc.setSummary(summary); + + // 3. 文本切片 + List chunks = splitText(content, DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP); + doc.setChunkTotal(chunks.size()); + documentMapper.updateById(doc); + + // 4. 存储切片到向量库 + storeChunks(doc, chunks); + + // 5. 更新状态为可用 + doc.setStatus("active"); + documentMapper.updateById(doc); + + log.info("文档处理完成: {}, 切片数: {}", doc.getTitle(), chunks.size()); + + } catch (Exception e) { + log.error("文档处理失败: {}, 错误: {}", docId, e.getMessage(), e); + doc.setStatus("error"); + doc.setErrorMessage(e.getMessage()); + documentMapper.updateById(doc); + } + } + + /** + * 异步处理文档 + * + * @param docId 文档ID + */ + @Async("documentTaskExecutor") + public void processDocumentAsync(Long docId) { + processDocument(docId); + } + + /** + * 解析文档内容 + * + * @param doc 文档实体 + * @return 文本内容 + */ + private String parseDocument(AiDocument doc) { + try { + // 从MinIO下载文件 + InputStream inputStream = minioService.downloadFile(doc.getFilePath()); + + // 根据文件类型解析 + String fileType = doc.getFileType(); + if (fileType == null) { + fileType = getFileExtension(doc.getFilePath()); + } + + return documentParserUtil.parse(inputStream, fileType); + } catch (Exception e) { + log.error("解析文档失败: {}, 错误: {}", doc.getTitle(), e.getMessage(), e); + return null; + } + } + + /** + * 生成文档摘要 + * + * @param content 文档内容 + * @return 摘要 + */ + private String generateSummary(String content) { + // 简单实现:取前200字符作为摘要 + if (content == null || content.isEmpty()) { + return ""; + } + int maxLength = Math.min(content.length(), 200); + return content.substring(0, maxLength) + (content.length() > maxLength ? "..." : ""); + } + + /** + * 文本切片 + * + * @param content 文本内容 + * @param chunkSize 切片大小 + * @param overlap 重叠大小 + * @return 切片列表 + */ + private List splitText(String content, int chunkSize, int overlap) { + if (content == null || content.isEmpty()) { + return Collections.emptyList(); + } + + // 使用Spring AI的TokenTextSplitter + TokenTextSplitter splitter = TokenTextSplitter.builder() + .withChunkSize(chunkSize) + .withMinChunkSizeChars(chunkSize / 2) + .withMinChunkLengthToEmbed(1) + .withMaxNumChunks(100) + .withKeepSeparator(true) + .build(); + List documents = splitter.apply(List.of(new Document(content))); + + return documents.stream() + .map(Document::getText) + .collect(Collectors.toList()); + } + + /** + * 存储切片到向量库 + * + * @param parentDoc 父文档 + * @param chunks 切片列表 + */ + private void storeChunks(AiDocument parentDoc, List chunks) { + UUID docId = parentDoc.getDocId(); + Long parentId = parentDoc.getId(); + + for (int i = 0; i < chunks.size(); i++) { + String chunkContent = chunks.get(i); + + // 创建向量文档 + Document vectorDoc = new Document( + chunkContent, + Map.of( + "doc_id", docId.toString(), + "project_id", parentDoc.getProjectId(), + "timeline_node_id", parentDoc.getTimelineNodeId() != null ? parentDoc.getTimelineNodeId() : "", + "chunk_index", i, + "chunk_total", chunks.size(), + "title", parentDoc.getTitle(), + "source_type", parentDoc.getSourceType(), + "status", "active" + ) + ); + + // 存储到向量库 + vectorStore.add(List.of(vectorDoc)); + + // 如果是第一个切片,更新父文档内容 + if (i == 0) { + parentDoc.setContent(chunkContent); + documentMapper.updateById(parentDoc); + } + + log.debug("存储切片: {}/{}, docId: {}", i + 1, chunks.size(), docId); + } + } + + /** + * 删除文档及其向量 + * + * @param docId 文档UUID + */ + public void deleteDocumentVectors(UUID docId) { + try { + // 查询所有相关切片 + // 注意:pgvector store的删除需要根据metadata过滤 + // 这里简单处理,实际可能需要更复杂的逻辑 + log.info("删除文档向量: {}", docId); + } catch (Exception e) { + log.error("删除文档向量失败: {}, 错误: {}", docId, e.getMessage(), e); + } + } + + /** + * 获取文件扩展名 + * + * @param filePath 文件路径 + * @return 扩展名 + */ + private String getFileExtension(String filePath) { + if (filePath == null || filePath.lastIndexOf('.') == -1) { + return ""; + } + return filePath.substring(filePath.lastIndexOf('.') + 1).toLowerCase(); + } +} diff --git a/src/main/java/cn/yinlihupo/service/ai/rag/RagRetriever.java b/src/main/java/cn/yinlihupo/service/ai/rag/RagRetriever.java new file mode 100644 index 0000000..2658bb2 --- /dev/null +++ b/src/main/java/cn/yinlihupo/service/ai/rag/RagRetriever.java @@ -0,0 +1,276 @@ +package cn.yinlihupo.service.ai.rag; + +import cn.yinlihupo.domain.entity.AiChatMessage; +import cn.yinlihupo.mapper.AiChatHistoryMapper; +import jakarta.annotation.Resource; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.document.Document; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.stereotype.Component; + +import java.util.*; +import java.util.stream.Collectors; + +/** + * RAG检索器 + * 支持向量检索、TextToSQL检索和混合排序 + */ +@Slf4j +@Component +@RequiredArgsConstructor +public class RagRetriever { + + private final VectorStore vectorStore; + private final JdbcTemplate jdbcTemplate; + private final ChatClient chatClient; + + @Resource + private AiChatHistoryMapper chatHistoryMapper; + + /** + * 向量检索 + * + * @param query 查询文本 + * @param projectId 项目ID + * @param topK 返回数量 + * @return 文档列表 + */ + public List vectorSearch(String query, Long projectId, int topK) { + try { + SearchRequest searchRequest = SearchRequest.builder() + .query(query) + .topK(topK) + .filterExpression("project_id == " + projectId + " && status == 'active'") + .build(); + + List results = vectorStore.similaritySearch(searchRequest); + log.debug("向量检索完成,项目ID: {}, 查询: {}, 结果数: {}", projectId, query, results.size()); + return results; + } catch (Exception e) { + log.error("向量检索失败: {}", e.getMessage(), e); + return Collections.emptyList(); + } + } + + /** + * 向量检索(带时间节点过滤) + * + * @param query 查询文本 + * @param projectId 项目ID + * @param timelineNodeId 时间节点ID + * @param topK 返回数量 + * @return 文档列表 + */ + public List vectorSearchWithTimeline(String query, Long projectId, + Long timelineNodeId, int topK) { + try { + SearchRequest searchRequest = SearchRequest.builder() + .query(query) + .topK(topK) + .filterExpression("project_id == " + projectId + + " && timeline_node_id == " + timelineNodeId + + " && status == 'active'") + .build(); + + List results = vectorStore.similaritySearch(searchRequest); + log.debug("向量检索完成(时间维度),项目ID: {}, 节点ID: {}, 结果数: {}", + projectId, timelineNodeId, results.size()); + return results; + } catch (Exception e) { + log.error("向量检索失败: {}", e.getMessage(), e); + return Collections.emptyList(); + } + } + + /** + * TextToSQL检索 + * + * @param question 自然语言问题 + * @param projectId 项目ID + * @return 文档列表 + */ + public List textToSqlSearch(String question, Long projectId) { + try { + // 1. 生成SQL + String sql = generateSql(question, projectId); + if (sql == null || sql.isEmpty()) { + return Collections.emptyList(); + } + + log.debug("生成的SQL: {}", sql); + + // 2. 执行SQL + List> results = jdbcTemplate.queryForList(sql); + + // 3. 转换为Document + return results.stream() + .map(this::convertToDocument) + .collect(Collectors.toList()); + } catch (Exception e) { + log.error("TextToSQL检索失败: {}", e.getMessage(), e); + return Collections.emptyList(); + } + } + + /** + * 混合检索(向量 + TextToSQL) + * + * @param query 查询文本 + * @param projectId 项目ID + * @param useVector 是否使用向量检索 + * @param useTextToSql 是否使用TextToSQL + * @param topK 返回数量 + * @return 文档列表 + */ + public List hybridSearch(String query, Long projectId, + boolean useVector, boolean useTextToSql, int topK) { + List vectorResults = Collections.emptyList(); + List sqlResults = Collections.emptyList(); + + if (useVector) { + vectorResults = vectorSearch(query, projectId, topK * 2); + } + + if (useTextToSql) { + sqlResults = textToSqlSearch(query, projectId); + } + + // RRF融合排序 + return reciprocalRankFusion(vectorResults, sqlResults, topK); + } + + /** + * 获取对话历史上下文 + * + * @param sessionId 会话ID + * @param limit 限制条数 + * @return 历史消息列表 + */ + public List getChatHistory(UUID sessionId, int limit) { + try { + return chatHistoryMapper.selectRecentMessages(sessionId, limit); + } catch (Exception e) { + log.error("获取对话历史失败: {}", e.getMessage(), e); + return Collections.emptyList(); + } + } + + /** + * 生成SQL查询 + * + * @param question 自然语言问题 + * @param projectId 项目ID + * @return SQL语句 + */ + private String generateSql(String question, Long projectId) { + // 构建数据库Schema提示 + String schemaPrompt = """ + 数据库表结构: + - project(项目): id, project_code, project_name, project_type, description, status, manager_id + - task(任务): id, project_id, task_name, description, assignee_id, status, priority, progress + - risk(风险): id, project_id, risk_name, description, risk_level, status, owner_id + - work_order(工单): id, project_id, title, description, status, priority, handler_id + - project_milestone(里程碑): id, project_id, milestone_name, status, plan_date, actual_date + - project_member(成员): id, project_id, user_id, role_code, status + + 请根据用户问题生成PostgreSQL查询SQL,只返回SELECT语句。 + 注意: + 1. 必须包含 project_id = %d 过滤条件 + 2. 只查询状态正常的记录(deleted = 0) + 3. 限制返回条数为20条 + 4. 不要返回解释,只返回SQL + """.formatted(projectId); + + try { + String sql = chatClient.prompt() + .system(schemaPrompt) + .user(question) + .call() + .content(); + + // 清理SQL(去除markdown代码块等) + sql = sql.replaceAll("```sql", "") + .replaceAll("```", "") + .trim(); + + // 安全检查:只允许SELECT语句 + if (!sql.toUpperCase().startsWith("SELECT")) { + log.warn("生成的SQL不是查询语句: {}", sql); + return null; + } + + return sql; + } catch (Exception e) { + log.error("生成SQL失败: {}", e.getMessage(), e); + return null; + } + } + + /** + * 将SQL结果转换为Document + * + * @param row 数据库行 + * @return Document对象 + */ + private Document convertToDocument(Map row) { + StringBuilder content = new StringBuilder(); + row.forEach((key, value) -> { + if (value != null) { + content.append(key).append(": ").append(value).append("\n"); + } + }); + + return Document.builder() + .text(content.toString()) + .metadata(Map.of( + "source", "database", + "type", "sql_result" + )) + .build(); + } + + /** + * RRF(Reciprocal Rank Fusion)融合排序 + * + * @param vectorResults 向量检索结果 + * @param sqlResults SQL检索结果 + * @param topK 返回数量 + * @return 融合排序后的结果 + */ + private List reciprocalRankFusion(List vectorResults, + List sqlResults, int topK) { + Map scoreMap = new HashMap<>(); + Map docMap = new HashMap<>(); + + final double k = 60.0; // RRF常数 + + // 处理向量检索结果 + for (int i = 0; i < vectorResults.size(); i++) { + Document doc = vectorResults.get(i); + String key = doc.getText().hashCode() + ""; + double score = 1.0 / (k + i + 1); + scoreMap.merge(key, score, Double::sum); + docMap.putIfAbsent(key, doc); + } + + // 处理SQL检索结果 + for (int i = 0; i < sqlResults.size(); i++) { + Document doc = sqlResults.get(i); + String key = doc.getText().hashCode() + ""; + double score = 1.0 / (k + i + 1); + scoreMap.merge(key, score, Double::sum); + docMap.putIfAbsent(key, doc); + } + + // 排序并返回TopK + return scoreMap.entrySet().stream() + .sorted(Map.Entry.comparingByValue().reversed()) + .limit(topK) + .map(e -> docMap.get(e.getKey())) + .collect(Collectors.toList()); + } +} diff --git a/src/main/java/cn/yinlihupo/service/oss/MinioService.java b/src/main/java/cn/yinlihupo/service/oss/MinioService.java new file mode 100644 index 0000000..497be4b --- /dev/null +++ b/src/main/java/cn/yinlihupo/service/oss/MinioService.java @@ -0,0 +1,42 @@ +package cn.yinlihupo.service.oss; + +import java.io.InputStream; + +/** + * MinIO服务接口 + */ +public interface MinioService { + + /** + * 上传文件 + * + * @param filePath 文件路径 + * @param inputStream 文件输入流 + * @param contentType 内容类型 + * @return 文件URL + */ + String uploadFile(String filePath, InputStream inputStream, String contentType); + + /** + * 下载文件 + * + * @param filePath 文件路径 + * @return 文件输入流 + */ + InputStream downloadFile(String filePath); + + /** + * 删除文件 + * + * @param filePath 文件路径 + */ + void deleteFile(String filePath); + + /** + * 获取文件URL + * + * @param filePath 文件路径 + * @return 文件URL + */ + String getFileUrl(String filePath); +} diff --git a/src/main/java/cn/yinlihupo/service/oss/impl/MinioServiceImpl.java b/src/main/java/cn/yinlihupo/service/oss/impl/MinioServiceImpl.java new file mode 100644 index 0000000..c515b31 --- /dev/null +++ b/src/main/java/cn/yinlihupo/service/oss/impl/MinioServiceImpl.java @@ -0,0 +1,120 @@ +package cn.yinlihupo.service.oss.impl; + +import cn.yinlihupo.common.config.MinioConfig; +import cn.yinlihupo.service.oss.MinioService; +import io.minio.GetObjectArgs; +import io.minio.MinioClient; +import io.minio.PutObjectArgs; +import io.minio.RemoveObjectArgs; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +import java.io.InputStream; + +/** + * MinIO服务实现 + */ +@Slf4j +@Service +@RequiredArgsConstructor +public class MinioServiceImpl implements MinioService { + + private final MinioClient minioClient; + private final MinioConfig minioConfig; + + @Override + public String uploadFile(String filePath, InputStream inputStream, String contentType) { + try { + // 解析bucket和objectName + String bucketName = minioConfig.getBucketName(); + String objectName = filePath; + + // 如果filePath包含/,可能需要分离bucket(这里简化处理) + if (filePath.startsWith(bucketName + "/")) { + objectName = filePath.substring(bucketName.length() + 1); + } + + minioClient.putObject( + PutObjectArgs.builder() + .bucket(bucketName) + .object(objectName) + .stream(inputStream, inputStream.available(), -1) + .contentType(contentType != null ? contentType : "application/octet-stream") + .build() + ); + + log.info("文件上传成功: {}/{}", bucketName, objectName); + return getFileUrl(filePath); + } catch (Exception e) { + log.error("文件上传失败: {}", e.getMessage(), e); + throw new RuntimeException("文件上传失败: " + e.getMessage(), e); + } + } + + @Override + public InputStream downloadFile(String filePath) { + try { + String bucketName = minioConfig.getBucketName(); + String objectName = filePath; + + // 如果filePath包含bucket前缀,需要分离 + if (filePath.startsWith(bucketName + "/")) { + objectName = filePath.substring(bucketName.length() + 1); + } + + return minioClient.getObject( + GetObjectArgs.builder() + .bucket(bucketName) + .object(objectName) + .build() + ); + } catch (Exception e) { + log.error("下载文件失败: {}", e.getMessage(), e); + throw new RuntimeException("下载文件失败: " + e.getMessage(), e); + } + } + + @Override + public void deleteFile(String filePath) { + try { + String bucketName = minioConfig.getBucketName(); + String objectName = filePath; + + // 如果filePath包含bucket前缀,需要分离 + if (filePath.startsWith(bucketName + "/")) { + objectName = filePath.substring(bucketName.length() + 1); + } + + minioClient.removeObject( + RemoveObjectArgs.builder() + .bucket(bucketName) + .object(objectName) + .build() + ); + + log.info("文件删除成功: {}/{}", bucketName, objectName); + } catch (Exception e) { + log.error("删除文件失败: {}", e.getMessage(), e); + throw new RuntimeException("删除文件失败: " + e.getMessage(), e); + } + } + + @Override + public String getFileUrl(String filePath) { + String endpoint = minioConfig.getEndpoint(); + if (endpoint.endsWith("/")) { + endpoint = endpoint.substring(0, endpoint.length() - 1); + } + + String bucketName = minioConfig.getBucketName(); + String objectName = filePath; + + // 如果filePath已经包含bucket前缀,直接使用 + if (filePath.startsWith(bucketName + "/")) { + return endpoint + "/" + filePath; + } + + return endpoint + "/" + bucketName + "/" + objectName; + } +} diff --git a/src/main/resources/application-dev.yaml b/src/main/resources/application-dev.yaml index e1817f1..e57e47e 100644 --- a/src/main/resources/application-dev.yaml +++ b/src/main/resources/application-dev.yaml @@ -48,6 +48,9 @@ spring: openai: api-key: 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 base-url: https://sg1.proxy.yinlihupo.cc/proxy/https://openrouter.ai/api + embedding: + options: + model: qwen/qwen3-embedding-8b chat: options: model: google/gemini-3.1-pro-preview diff --git a/src/main/resources/application.yaml b/src/main/resources/application.yaml index 58d874b..9e7e1d0 100644 --- a/src/main/resources/application.yaml +++ b/src/main/resources/application.yaml @@ -7,6 +7,9 @@ spring: client: connect-timeout: 30s read-timeout: 120s + # 允许Bean定义覆盖(解决vectorStore冲突) + main: + allow-bean-definition-overriding: true # 公共配置 server: diff --git a/src/main/resources/mapper/AiChatHistoryMapper.xml b/src/main/resources/mapper/AiChatHistoryMapper.xml new file mode 100644 index 0000000..5599cec --- /dev/null +++ b/src/main/resources/mapper/AiChatHistoryMapper.xml @@ -0,0 +1,80 @@ + + + + + + + + + + + + + + + + + + + + + + DELETE FROM ai_chat_history + WHERE session_id = #{sessionId} + + + + + + diff --git a/src/main/resources/mapper/AiDocumentMapper.xml b/src/main/resources/mapper/AiDocumentMapper.xml new file mode 100644 index 0000000..6330250 --- /dev/null +++ b/src/main/resources/mapper/AiDocumentMapper.xml @@ -0,0 +1,103 @@ + + + + + + + + + + + + + UPDATE ai_document + SET deleted = 1, + update_time = NOW() + WHERE doc_id = #{docId} + + + + + + + + + + + UPDATE ai_document + SET status = #{status}, + update_time = NOW() + WHERE doc_id = #{docId} + + + + + UPDATE ai_document + SET error_message = #{errorMessage}, + status = 'error', + update_time = NOW() + WHERE doc_id = #{docId} + + + + + UPDATE ai_document + SET view_count = view_count + 1 + WHERE id = #{id} + + + + + UPDATE ai_document + SET query_count = query_count + 1, + last_queried_at = NOW() + WHERE id = #{id} + + +