feat(ai): 新增AI对话与知识库功能模块

- 集成Fastjson2依赖优化JSON处理性能
- 配置专用文档处理异步线程池,提升任务并发处理能力
- 实现基于Spring AI的PgVectorStore向量存储配置
- 新增AI对话控制器,支持SSE流式对话及会话管理接口
- 新增AI知识库控制器,支持文件上传、文档管理及重新索引功能
- 定义AI对话和知识库相关的数据传输对象DTO与视图对象VO
- 建立AI对话消息和文档向量的数据库实体与MyBatis Mapper
- 实现AI对话服务接口及其具体业务逻辑,包括会话管理和RAG检索
- 完善安全校验和错误处理,确保接口调用的用户权限和参数有效性
- 提供对话消息流式响应机制,支持实时传输用户互动内容和引用文档信息
This commit is contained in:
2026-03-30 16:33:47 +08:00
parent e7a21ba665
commit d338490640
28 changed files with 2838 additions and 0 deletions

View File

@@ -152,6 +152,13 @@
<version>3.27.0</version>
</dependency>
<!-- Fastjson2 -->
<dependency>
<groupId>com.alibaba.fastjson2</groupId>
<artifactId>fastjson2</artifactId>
<version>2.0.43</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>

View File

@@ -43,4 +43,31 @@ public class AsyncConfig {
log.info("项目初始化异步任务线程池初始化完成");
return executor;
}
/**
* 文档处理任务线程池
*/
@Bean("documentTaskExecutor")
public Executor documentTaskExecutor() {
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
// 核心线程数
executor.setCorePoolSize(2);
// 最大线程数
executor.setMaxPoolSize(4);
// 队列容量
executor.setQueueCapacity(100);
// 线程名称前缀
executor.setThreadNamePrefix("doc-process-");
// 拒绝策略:由调用线程处理
executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
// 等待所有任务完成后再关闭线程池
executor.setWaitForTasksToCompleteOnShutdown(true);
// 等待时间(秒)
executor.setAwaitTerminationSeconds(120);
// 初始化
executor.initialize();
log.info("文档处理异步任务线程池初始化完成");
return executor;
}
}

View File

@@ -1,11 +1,17 @@
package cn.yinlihupo.common.config;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.pgvector.PgVectorStore;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import org.springframework.jdbc.core.JdbcTemplate;
/**
* Spring AI 配置类
* 配置ChatClient和向量存储
*/
@Configuration
public class SpringAiConfig {
@@ -20,4 +26,22 @@ public class SpringAiConfig {
public ChatClient chatClient(ChatClient.Builder builder) {
return builder.build();
}
/**
* 配置PgVectorStore向量存储
* 使用@Primary标记覆盖Spring AI的自动配置
*
* @param jdbcTemplate JDBC模板
* @param embeddingModel 嵌入模型
* @return VectorStore
*/
@Bean
@Primary
public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
return PgVectorStore.builder(jdbcTemplate, embeddingModel)
.dimensions(1536) // 向量维度,与配置一致
.distanceType(PgVectorStore.PgDistanceType.COSINE_DISTANCE)
.initializeSchema(true) // 自动初始化schema
.build();
}
}

View File

@@ -13,6 +13,7 @@ import org.apache.poi.xwpf.usermodel.XWPFParagraph;
import org.apache.poi.ss.usermodel.*;
import org.apache.tika.Tika;
import org.apache.tika.metadata.Metadata;
import org.springframework.stereotype.Component;
import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
@@ -25,6 +26,7 @@ import java.util.List;
* 支持 PDF、Word、Excel、Markdown 等格式的文档解析
*/
@Slf4j
@Component
public class DocumentParserUtil {
private static final Tika tika = new Tika();

View File

@@ -0,0 +1,197 @@
package cn.yinlihupo.controller.ai;
import cn.yinlihupo.common.core.BaseResponse;
import cn.yinlihupo.common.util.ResultUtils;
import cn.yinlihupo.common.util.SecurityUtils;
import cn.yinlihupo.domain.dto.ChatRequest;
import cn.yinlihupo.domain.dto.CreateSessionRequest;
import cn.yinlihupo.domain.vo.ChatMessageVO;
import cn.yinlihupo.domain.vo.ChatSessionVO;
import cn.yinlihupo.service.ai.AiChatService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.List;
import java.util.UUID;
/**
* AI对话控制器
* 提供SSE流式对话、会话管理等功能
*/
@Slf4j
@RestController
@RequestMapping("/ai/chat")
@RequiredArgsConstructor
@Tag(name = "AI对话", description = "AI智能问答相关接口")
public class AiChatController {
private final AiChatService aiChatService;
/**
* SSE流式对话
*
* @param request 对话请求参数
* @return SseEmitter
*/
@GetMapping(value = "/sse", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
@Operation(summary = "SSE流式对话", description = "建立SSE连接进行流式问答")
public SseEmitter chatSse(ChatRequest request) {
Long userId = SecurityUtils.getCurrentUserId();
if (userId == null) {
SseEmitter emitter = new SseEmitter();
try {
emitter.send(SseEmitter.event()
.name("error")
.data("{\"message\": \"用户未登录\"}"));
emitter.complete();
} catch (Exception e) {
log.error("发送错误消息失败", e);
}
return emitter;
}
// 验证必填参数
if (request.getProjectId() == null) {
SseEmitter emitter = new SseEmitter();
try {
emitter.send(SseEmitter.event()
.name("error")
.data("{\"message\": \"项目ID不能为空\"}"));
emitter.complete();
} catch (Exception e) {
log.error("发送错误消息失败", e);
}
return emitter;
}
if (request.getMessage() == null || request.getMessage().trim().isEmpty()) {
SseEmitter emitter = new SseEmitter();
try {
emitter.send(SseEmitter.event()
.name("error")
.data("{\"message\": \"消息内容不能为空\"}"));
emitter.complete();
} catch (Exception e) {
log.error("发送错误消息失败", e);
}
return emitter;
}
// 创建SSE发射器30分钟超时
SseEmitter emitter = new SseEmitter(30 * 60 * 1000L);
// 异步处理对话
new Thread(() -> aiChatService.streamChat(request, userId, emitter)).start();
return emitter;
}
/**
* 新建会话
*
* @param request 创建会话请求
* @return 会话信息
*/
@PostMapping("/session")
@Operation(summary = "新建会话", description = "创建新的对话会话")
public BaseResponse<ChatSessionVO> createSession(@RequestBody CreateSessionRequest request) {
Long userId = SecurityUtils.getCurrentUserId();
if (userId == null) {
return ResultUtils.error("用户未登录");
}
if (request.getProjectId() == null) {
return ResultUtils.error("项目ID不能为空");
}
try {
ChatSessionVO session = aiChatService.createSession(
userId,
request.getProjectId(),
request.getTimelineNodeId(),
request.getFirstMessage(),
request.getSessionTitle()
);
return ResultUtils.success("创建成功", session);
} catch (Exception e) {
log.error("创建会话失败: {}", e.getMessage(), e);
return ResultUtils.error("创建会话失败: " + e.getMessage());
}
}
/**
* 获取会话列表
*
* @param projectId 项目ID可选
* @return 会话列表
*/
@GetMapping("/sessions")
@Operation(summary = "获取会话列表", description = "获取当前用户的所有会话或指定项目的会话")
public BaseResponse<List<ChatSessionVO>> getSessions(
@RequestParam(required = false) Long projectId) {
Long userId = SecurityUtils.getCurrentUserId();
if (userId == null) {
return ResultUtils.error("用户未登录");
}
try {
List<ChatSessionVO> sessions = aiChatService.getUserSessions(userId, projectId);
return ResultUtils.success("查询成功", sessions);
} catch (Exception e) {
log.error("获取会话列表失败: {}", e.getMessage(), e);
return ResultUtils.error("获取会话列表失败: " + e.getMessage());
}
}
/**
* 获取会话历史记录
*
* @param sessionId 会话ID
* @return 消息列表
*/
@GetMapping("/session/{sessionId}/messages")
@Operation(summary = "获取会话历史", description = "获取指定会话的所有消息")
public BaseResponse<List<ChatMessageVO>> getSessionMessages(
@PathVariable UUID sessionId) {
Long userId = SecurityUtils.getCurrentUserId();
if (userId == null) {
return ResultUtils.error("用户未登录");
}
try {
List<ChatMessageVO> messages = aiChatService.getSessionMessages(sessionId, userId);
return ResultUtils.success("查询成功", messages);
} catch (Exception e) {
log.error("获取会话历史失败: {}", e.getMessage(), e);
return ResultUtils.error("获取会话历史失败: " + e.getMessage());
}
}
/**
* 删除会话
*
* @param sessionId 会话ID
* @return 操作结果
*/
@DeleteMapping("/session/{sessionId}")
@Operation(summary = "删除会话", description = "删除指定的对话会话")
public BaseResponse<Void> deleteSession(@PathVariable UUID sessionId) {
Long userId = SecurityUtils.getCurrentUserId();
if (userId == null) {
return ResultUtils.error("用户未登录");
}
try {
aiChatService.deleteSession(sessionId, userId);
return ResultUtils.success("删除成功", null);
} catch (Exception e) {
log.error("删除会话失败: {}", e.getMessage(), e);
return ResultUtils.error("删除会话失败: " + e.getMessage());
}
}
}

View File

@@ -0,0 +1,138 @@
package cn.yinlihupo.controller.ai;
import cn.yinlihupo.common.core.BaseResponse;
import cn.yinlihupo.common.util.ResultUtils;
import cn.yinlihupo.common.util.SecurityUtils;
import cn.yinlihupo.domain.vo.KbDocumentVO;
import cn.yinlihupo.service.ai.AiKnowledgeBaseService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import java.util.List;
import java.util.UUID;
/**
* AI知识库控制器
* 提供知识库文件上传、管理等功能
*/
@Slf4j
@RestController
@RequestMapping("/ai/kb")
@RequiredArgsConstructor
@Tag(name = "AI知识库", description = "AI知识库文档管理相关接口")
public class AiKnowledgeBaseController {
private final AiKnowledgeBaseService knowledgeBaseService;
/**
* 上传文件到知识库
*
* @param projectId 项目ID
* @param file 文件
* @return 文档信息
*/
@PostMapping("/upload")
@Operation(summary = "上传文件", description = "上传文件到项目知识库支持PDF、Word、TXT等格式")
public BaseResponse<KbDocumentVO> uploadFile(
@RequestParam Long projectId,
@RequestParam MultipartFile file) {
Long userId = SecurityUtils.getCurrentUserId();
if (userId == null) {
return ResultUtils.error("用户未登录");
}
if (projectId == null) {
return ResultUtils.error("项目ID不能为空");
}
if (file == null || file.isEmpty()) {
return ResultUtils.error("文件不能为空");
}
try {
KbDocumentVO doc = knowledgeBaseService.uploadFile(projectId, file, userId);
return ResultUtils.success("上传成功", doc);
} catch (Exception e) {
log.error("上传文件失败: {}", e.getMessage(), e);
return ResultUtils.error("上传失败: " + e.getMessage());
}
}
/**
* 获取项目知识库文档列表
*
* @param projectId 项目ID
* @return 文档列表
*/
@GetMapping("/documents")
@Operation(summary = "获取文档列表", description = "获取指定项目的知识库文档列表")
public BaseResponse<List<KbDocumentVO>> getDocuments(
@RequestParam Long projectId) {
Long userId = SecurityUtils.getCurrentUserId();
if (userId == null) {
return ResultUtils.error("用户未登录");
}
if (projectId == null) {
return ResultUtils.error("项目ID不能为空");
}
try {
List<KbDocumentVO> documents = knowledgeBaseService.getProjectDocuments(projectId);
return ResultUtils.success("查询成功", documents);
} catch (Exception e) {
log.error("获取文档列表失败: {}", e.getMessage(), e);
return ResultUtils.error("获取文档列表失败: " + e.getMessage());
}
}
/**
* 删除知识库文档
*
* @param docId 文档UUID
* @return 操作结果
*/
@DeleteMapping("/document/{docId}")
@Operation(summary = "删除文档", description = "删除指定的知识库文档")
public BaseResponse<Void> deleteDocument(@PathVariable UUID docId) {
Long userId = SecurityUtils.getCurrentUserId();
if (userId == null) {
return ResultUtils.error("用户未登录");
}
try {
knowledgeBaseService.deleteDocument(docId, userId);
return ResultUtils.success("删除成功", null);
} catch (Exception e) {
log.error("删除文档失败: {}", e.getMessage(), e);
return ResultUtils.error("删除失败: " + e.getMessage());
}
}
/**
* 重新索引文档
*
* @param docId 文档UUID
* @return 操作结果
*/
@PostMapping("/document/{docId}/reindex")
@Operation(summary = "重新索引文档", description = "重新解析并索引指定的文档")
public BaseResponse<Void> reindexDocument(@PathVariable UUID docId) {
Long userId = SecurityUtils.getCurrentUserId();
if (userId == null) {
return ResultUtils.error("用户未登录");
}
try {
knowledgeBaseService.reindexDocument(docId, userId);
return ResultUtils.success("重新索引已启动", null);
} catch (Exception e) {
log.error("重新索引文档失败: {}", e.getMessage(), e);
return ResultUtils.error("重新索引失败: " + e.getMessage());
}
}
}

View File

@@ -0,0 +1,52 @@
package cn.yinlihupo.domain.dto;
import lombok.Data;
import java.util.UUID;
/**
* AI对话请求DTO
*/
@Data
public class ChatRequest {
/**
* 会话ID为空则新建会话
*/
private UUID sessionId;
/**
* 项目ID必填
*/
private Long projectId;
/**
* 时间节点ID可选用于时间维度知识库
*/
private Long timelineNodeId;
/**
* 用户消息内容
*/
private String message;
/**
* 是否使用RAG检索
*/
private Boolean useRag = true;
/**
* 是否使用TextToSQL
*/
private Boolean useTextToSql = false;
/**
* 上下文窗口大小默认10轮
*/
private Integer contextWindow = 10;
/**
* 系统提示词(可选,覆盖默认提示词)
*/
private String customSystemPrompt;
}

View File

@@ -0,0 +1,30 @@
package cn.yinlihupo.domain.dto;
import lombok.Data;
/**
* 创建会话请求DTO
*/
@Data
public class CreateSessionRequest {
/**
* 项目ID
*/
private Long projectId;
/**
* 时间节点ID可选
*/
private Long timelineNodeId;
/**
* 首条消息内容(用于生成会话标题)
*/
private String firstMessage;
/**
* 会话标题(可选,不传则自动生成)
*/
private String sessionTitle;
}

View File

@@ -0,0 +1,116 @@
package cn.yinlihupo.domain.entity;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import java.time.LocalDateTime;
import java.util.UUID;
/**
* AI对话消息实体
* 对应 ai_chat_history 表
*/
@Data
@TableName("ai_chat_history")
public class AiChatMessage {
@TableId(type = IdType.AUTO)
private Long id;
/**
* 会话ID
*/
private UUID sessionId;
/**
* 会话标题
*/
private String sessionTitle;
/**
* 用户ID
*/
private Long userId;
/**
* 关联项目ID
*/
private Long projectId;
/**
* 关联时间节点ID
*/
private Long timelineNodeId;
/**
* 角色user-用户, assistant-助手, system-系统
*/
private String role;
/**
* 消息内容
*/
private String content;
/**
* 对话内容的向量表示
*/
private String contentEmbedding;
/**
* 引用的文档ID列表(JSON数组)
*/
private String referencedDocIds;
/**
* 系统提示词
*/
private String systemPrompt;
/**
* 上下文窗口大小
*/
private Integer contextWindow;
/**
* 关联的知识库ID列表(JSON数组)
*/
private String kbIds;
/**
* 使用的AI模型
*/
private String model;
/**
* 消耗的Token数
*/
private Integer tokensUsed;
/**
* 响应时间(毫秒)
*/
private Integer responseTime;
/**
* 用户反馈评分(1-5)
*/
private Integer feedbackScore;
/**
* 用户反馈内容
*/
private String feedbackContent;
/**
* 消息在会话中的序号
*/
private Integer messageIndex;
/**
* 创建时间
*/
private LocalDateTime createTime;
}

View File

@@ -0,0 +1,189 @@
package cn.yinlihupo.domain.entity;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.UUID;
/**
* AI文档向量实体
* 对应 ai_document 表
*/
@Data
@TableName("ai_document")
public class AiDocument {
@TableId(type = IdType.AUTO)
private Long id;
/**
* 文档唯一标识(UUID)
*/
private UUID docId;
/**
* 关联项目ID
*/
private Long projectId;
/**
* 关联时间节点ID
*/
private Long timelineNodeId;
/**
* 关联知识库ID
*/
private Long kbId;
/**
* 来源类型: project-项目文档, risk-风险文档, ticket-工单,
* report-日报, upload-上传文件, knowledge-知识库, chat-对话记录
*/
private String sourceType;
/**
* 来源记录ID
*/
private Long sourceId;
/**
* 文档标题
*/
private String title;
/**
* 文档内容(纯文本)
*/
private String content;
/**
* 原始内容(带格式)
*/
private String contentRaw;
/**
* AI生成的摘要
*/
private String summary;
/**
* 向量嵌入(存储为字符串实际由pgvector处理)
*/
private String embedding;
/**
* 文档类型: requirement-需求, design-设计, plan-计划,
* report-报告, contract-合同, photo-照片, other-其他
*/
private String docType;
/**
* 语言: zh-中文, en-英文
*/
private String language;
/**
* 文件类型: pdf, doc, txt, md, jpg, png等
*/
private String fileType;
/**
* 文件大小(字节)
*/
private Long fileSize;
/**
* 文件存储路径
*/
private String filePath;
/**
* 文档日期(如日报日期、照片拍摄日期)
*/
private LocalDate docDate;
/**
* 文档时间戳
*/
private LocalDateTime docDatetime;
/**
* 分块序号
*/
private Integer chunkIndex;
/**
* 总分块数
*/
private Integer chunkTotal;
/**
* 父文档ID(分块时使用)
*/
private Long chunkParentId;
/**
* 标签数组(JSON)
*/
private String tags;
/**
* 分类
*/
private String category;
/**
* 查看次数
*/
private Integer viewCount;
/**
* 被检索次数
*/
private Integer queryCount;
/**
* 最后被检索时间
*/
private LocalDateTime lastQueriedAt;
/**
* 状态: active-可用, processing-处理中, error-错误, archived-归档
*/
private String status;
/**
* 错误信息
*/
private String errorMessage;
/**
* 创建人ID
*/
private Long createBy;
/**
* 创建时间
*/
private LocalDateTime createTime;
/**
* 更新人ID
*/
private Long updateBy;
/**
* 更新时间
*/
private LocalDateTime updateTime;
/**
* 删除标记
*/
private Integer deleted;
}

View File

@@ -0,0 +1,58 @@
package cn.yinlihupo.domain.vo;
import lombok.Data;
import java.time.LocalDateTime;
import java.util.List;
/**
* 对话消息VO
*/
@Data
public class ChatMessageVO {
/**
* 消息ID
*/
private Long id;
/**
* 角色user/assistant/system
*/
private String role;
/**
* 消息内容
*/
private String content;
/**
* 引用的文档列表
*/
private List<ReferencedDocVO> referencedDocs;
/**
* 使用的模型
*/
private String model;
/**
* Token消耗
*/
private Integer tokensUsed;
/**
* 响应时间(ms)
*/
private Integer responseTime;
/**
* 消息序号
*/
private Integer messageIndex;
/**
* 创建时间
*/
private LocalDateTime createTime;
}

View File

@@ -0,0 +1,58 @@
package cn.yinlihupo.domain.vo;
import lombok.Data;
import java.time.LocalDateTime;
import java.util.UUID;
/**
* 会话信息VO
*/
@Data
public class ChatSessionVO {
/**
* 会话ID
*/
private UUID sessionId;
/**
* 会话标题
*/
private String sessionTitle;
/**
* 项目ID
*/
private Long projectId;
/**
* 项目名称
*/
private String projectName;
/**
* 时间节点ID
*/
private Long timelineNodeId;
/**
* 时间节点名称
*/
private String timelineNodeName;
/**
* 最后消息时间
*/
private LocalDateTime lastMessageTime;
/**
* 消息数量
*/
private Integer messageCount;
/**
* 创建时间
*/
private LocalDateTime createTime;
}

View File

@@ -0,0 +1,73 @@
package cn.yinlihupo.domain.vo;
import lombok.Data;
import java.time.LocalDateTime;
import java.util.UUID;
/**
* 知识库文档VO
*/
@Data
public class KbDocumentVO {
/**
* 文档ID
*/
private Long id;
/**
* 文档UUID
*/
private UUID docId;
/**
* 文档标题
*/
private String title;
/**
* 文档类型
*/
private String docType;
/**
* 文件类型
*/
private String fileType;
/**
* 文件大小(字节)
*/
private Long fileSize;
/**
* 文件路径
*/
private String filePath;
/**
* 来源类型
*/
private String sourceType;
/**
* 分块数量
*/
private Integer chunkCount;
/**
* 状态
*/
private String status;
/**
* 创建人
*/
private String createByName;
/**
* 创建时间
*/
private LocalDateTime createTime;
}

View File

@@ -0,0 +1,50 @@
package cn.yinlihupo.domain.vo;
import lombok.Data;
/**
* 引用的文档VO
*/
@Data
public class ReferencedDocVO {
/**
* 文档ID
*/
private Long id;
/**
* 文档UUID
*/
private String docId;
/**
* 文档标题
*/
private String title;
/**
* 文档类型
*/
private String docType;
/**
* 文件类型
*/
private String fileType;
/**
* 来源类型
*/
private String sourceType;
/**
* 相似度分数
*/
private Double score;
/**
* 内容摘要
*/
private String content;
}

View File

@@ -0,0 +1,78 @@
package cn.yinlihupo.mapper;
import cn.yinlihupo.domain.entity.AiChatMessage;
import cn.yinlihupo.domain.vo.ChatMessageVO;
import cn.yinlihupo.domain.vo.ChatSessionVO;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import java.util.List;
import java.util.UUID;
/**
* AI对话历史Mapper
*/
@Mapper
public interface AiChatHistoryMapper extends BaseMapper<AiChatMessage> {
/**
* 获取用户的会话列表
*
* @param userId 用户ID
* @param projectId 项目ID可选
* @return 会话列表
*/
List<ChatSessionVO> selectUserSessions(@Param("userId") Long userId,
@Param("projectId") Long projectId);
/**
* 获取会话消息列表
*
* @param sessionId 会话ID
* @return 消息列表
*/
List<ChatMessageVO> selectSessionMessages(@Param("sessionId") UUID sessionId);
/**
* 获取会话最新消息序号
*
* @param sessionId 会话ID
* @return 最大序号
*/
Integer selectMaxMessageIndex(@Param("sessionId") UUID sessionId);
/**
* 获取会话消息数量
*
* @param sessionId 会话ID
* @return 消息数量
*/
Integer selectMessageCount(@Param("sessionId") UUID sessionId);
/**
* 获取会话最后一条消息时间
*
* @param sessionId 会话ID
* @return 最后消息时间
*/
String selectLastMessageTime(@Param("sessionId") UUID sessionId);
/**
* 根据sessionId删除消息
*
* @param sessionId 会话ID
* @return 影响行数
*/
int deleteBySessionId(@Param("sessionId") UUID sessionId);
/**
* 获取会话的最近N条消息用于上下文
*
* @param sessionId 会话ID
* @param limit 限制条数
* @return 消息列表
*/
List<AiChatMessage> selectRecentMessages(@Param("sessionId") UUID sessionId,
@Param("limit") Integer limit);
}

View File

@@ -0,0 +1,92 @@
package cn.yinlihupo.mapper;
import cn.yinlihupo.domain.entity.AiDocument;
import cn.yinlihupo.domain.vo.KbDocumentVO;
import cn.yinlihupo.domain.vo.ReferencedDocVO;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import java.util.List;
import java.util.UUID;
/**
* AI文档向量Mapper
*/
@Mapper
public interface AiDocumentMapper extends BaseMapper<AiDocument> {
/**
* 获取项目文档列表
*
* @param projectId 项目ID
* @return 文档列表
*/
List<KbDocumentVO> selectProjectDocuments(@Param("projectId") Long projectId);
/**
* 根据docId查询文档
*
* @param docId 文档UUID
* @return 文档实体
*/
AiDocument selectByDocId(@Param("docId") UUID docId);
/**
* 根据docId删除文档
*
* @param docId 文档UUID
* @return 影响行数
*/
int deleteByDocId(@Param("docId") UUID docId);
/**
* 批量查询引用文档信息
*
* @param docIds 文档ID列表
* @return 文档信息列表
*/
List<ReferencedDocVO> selectReferencedDocs(@Param("docIds") List<Long> docIds);
/**
* 获取父文档的分块数量
*
* @param docId 父文档ID
* @return 分块数量
*/
Integer selectChunkCount(@Param("docId") Long docId);
/**
* 更新文档状态
*
* @param docId 文档UUID
* @param status 状态
* @return 影响行数
*/
int updateStatus(@Param("docId") UUID docId, @Param("status") String status);
/**
* 更新文档错误信息
*
* @param docId 文档UUID
* @param errorMessage 错误信息
* @return 影响行数
*/
int updateErrorMessage(@Param("docId") UUID docId, @Param("errorMessage") String errorMessage);
/**
* 增加文档查看次数
*
* @param id 文档ID
* @return 影响行数
*/
int incrementViewCount(@Param("id") Long id);
/**
* 增加文档查询次数
*
* @param id 文档ID
* @return 影响行数
*/
int incrementQueryCount(@Param("id") Long id);
}

View File

@@ -0,0 +1,72 @@
package cn.yinlihupo.service.ai;
import cn.yinlihupo.domain.dto.ChatRequest;
import cn.yinlihupo.domain.vo.ChatMessageVO;
import cn.yinlihupo.domain.vo.ChatSessionVO;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.List;
import java.util.UUID;
/**
* AI对话服务接口
*/
public interface AiChatService {
/**
* 流式对话SSE
*
* @param request 对话请求
* @param userId 用户ID
* @param emitter SSE发射器
*/
void streamChat(ChatRequest request, Long userId, SseEmitter emitter);
/**
* 创建新会话
*
* @param userId 用户ID
* @param projectId 项目ID
* @param timelineNodeId 时间节点ID可选
* @param firstMessage 首条消息(用于生成标题)
* @param customTitle 自定义标题(可选)
* @return 会话信息
*/
ChatSessionVO createSession(Long userId, Long projectId, Long timelineNodeId,
String firstMessage, String customTitle);
/**
* 获取用户会话列表
*
* @param userId 用户ID
* @param projectId 项目ID可选
* @return 会话列表
*/
List<ChatSessionVO> getUserSessions(Long userId, Long projectId);
/**
* 获取会话消息历史
*
* @param sessionId 会话ID
* @param userId 用户ID
* @return 消息列表
*/
List<ChatMessageVO> getSessionMessages(UUID sessionId, Long userId);
/**
* 删除会话
*
* @param sessionId 会话ID
* @param userId 用户ID
*/
void deleteSession(UUID sessionId, Long userId);
/**
* 验证用户是否有权限访问会话
*
* @param sessionId 会话ID
* @param userId 用户ID
* @return 是否有权限
*/
boolean hasSessionAccess(UUID sessionId, Long userId);
}

View File

@@ -0,0 +1,61 @@
package cn.yinlihupo.service.ai;
import cn.yinlihupo.domain.vo.KbDocumentVO;
import org.springframework.web.multipart.MultipartFile;
import java.util.List;
import java.util.UUID;
/**
* AI知识库服务接口
*/
public interface AiKnowledgeBaseService {
/**
* 上传文件到知识库
*
* @param projectId 项目ID
* @param file 文件
* @param userId 用户ID
* @return 文档信息
*/
KbDocumentVO uploadFile(Long projectId, MultipartFile file, Long userId);
/**
* 获取项目知识库文档列表
*
* @param projectId 项目ID
* @return 文档列表
*/
List<KbDocumentVO> getProjectDocuments(Long projectId);
/**
* 删除知识库文档
*
* @param docId 文档UUID
* @param userId 用户ID
*/
void deleteDocument(UUID docId, Long userId);
/**
* 重新索引文档
*
* @param docId 文档UUID
* @param userId 用户ID
*/
void reindexDocument(UUID docId, Long userId);
/**
* 处理文档(解析、切片、向量化)
*
* @param docId 文档ID
*/
void processDocument(Long docId);
/**
* 异步处理文档
*
* @param docId 文档ID
*/
void processDocumentAsync(Long docId);
}

View File

@@ -0,0 +1,444 @@
package cn.yinlihupo.service.ai.impl;
import cn.yinlihupo.common.sse.SseMessage;
import cn.yinlihupo.domain.dto.ChatRequest;
import cn.yinlihupo.domain.entity.AiChatMessage;
import cn.yinlihupo.domain.entity.Project;
import cn.yinlihupo.domain.vo.ChatMessageVO;
import cn.yinlihupo.domain.vo.ChatSessionVO;
import cn.yinlihupo.domain.vo.ReferencedDocVO;
import cn.yinlihupo.mapper.AiChatHistoryMapper;
import cn.yinlihupo.mapper.AiDocumentMapper;
import cn.yinlihupo.mapper.ProjectMapper;
import cn.yinlihupo.service.ai.AiChatService;
import cn.yinlihupo.service.ai.rag.RagRetriever;
import com.alibaba.fastjson2.JSON;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.document.Document;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import reactor.core.publisher.Flux;
import java.io.IOException;
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
/**
* AI对话服务实现
*/
@Slf4j
@Service
@RequiredArgsConstructor
public class AiChatServiceImpl implements AiChatService {
private final ChatClient chatClient;
private final RagRetriever ragRetriever;
private final AiChatHistoryMapper chatHistoryMapper;
private final AiDocumentMapper documentMapper;
private final ProjectMapper projectMapper;
// 系统提示词模板
private static final String SYSTEM_PROMPT_TEMPLATE = """
你是一个专业的项目管理AI助手帮助用户解答项目相关的问题。
当前项目信息:
{project_info}
检索到的相关文档:
{retrieved_docs}
回答要求:
1. 基于提供的项目信息和文档内容回答问题
2. 如果文档中没有相关信息,请明确告知
3. 回答要专业、准确、简洁
4. 涉及数据时,请引用具体数值
5. 使用中文回答
""";
@Override
public void streamChat(ChatRequest request, Long userId, SseEmitter emitter) {
long startTime = System.currentTimeMillis();
UUID sessionId = request.getSessionId();
boolean isNewSession = (sessionId == null);
try {
// 1. 获取或创建会话
if (isNewSession) {
sessionId = UUID.randomUUID();
String title = generateSessionTitle(request.getMessage());
createSession(userId, request.getProjectId(), request.getTimelineNodeId(), request.getMessage(), title);
} else {
// 验证会话权限
if (!hasSessionAccess(sessionId, userId)) {
sendError(emitter, "无权访问该会话");
return;
}
}
final UUID finalSessionId = sessionId;
// 发送开始消息
sendEvent(emitter, "start", Map.of(
"sessionId", finalSessionId.toString(),
"isNewSession", isNewSession
));
// 2. 保存用户消息
saveMessage(finalSessionId, userId, request.getProjectId(),
request.getTimelineNodeId(), "user", request.getMessage(), null);
// 3. RAG检索
List<Document> retrievedDocs = performRetrieval(request);
List<ReferencedDocVO> referencedDocs = convertToReferencedDocs(retrievedDocs);
// 发送引用文档信息
if (!referencedDocs.isEmpty()) {
sendEvent(emitter, "references", Map.of("docs", referencedDocs));
}
// 4. 构建Prompt
String systemPrompt = buildSystemPrompt(request.getProjectId(), retrievedDocs);
List<Message> messages = buildMessages(finalSessionId, request.getContextWindow(),
systemPrompt, request.getMessage());
// 5. 流式调用LLM
StringBuilder fullResponse = new StringBuilder();
AtomicInteger tokenCount = new AtomicInteger(0);
Flux<ChatResponse> responseFlux = chatClient.prompt(new Prompt(messages))
.stream()
.chatResponse();
responseFlux.subscribe(
response -> {
String content = response.getResult().getOutput().getText();
if (content != null && !content.isEmpty()) {
fullResponse.append(content);
tokenCount.addAndGet(estimateTokenCount(content));
sendEvent(emitter, "chunk", Map.of("content", content));
}
},
error -> {
log.error("LLM调用失败: {}", error.getMessage(), error);
sendError(emitter, "AI响应失败: " + error.getMessage());
},
() -> {
// 保存助手消息
int responseTime = (int) (System.currentTimeMillis() - startTime);
Long messageId = saveMessage(finalSessionId, userId, request.getProjectId(),
request.getTimelineNodeId(), "assistant",
fullResponse.toString(), JSON.toJSONString(referencedDocs));
// 发送完成消息
sendEvent(emitter, "complete", Map.of(
"messageId", messageId,
"tokensUsed", tokenCount.get(),
"responseTime", responseTime
));
// 关闭emitter
try {
emitter.complete();
} catch (Exception e) {
log.warn("关闭emitter失败: {}", e.getMessage());
}
}
);
} catch (Exception e) {
log.error("流式对话失败: {}", e.getMessage(), e);
sendError(emitter, "对话失败: " + e.getMessage());
}
}
@Override
public ChatSessionVO createSession(Long userId, Long projectId, Long timelineNodeId,
String firstMessage, String customTitle) {
UUID sessionId = UUID.randomUUID();
String title = customTitle;
if (title == null || title.isEmpty()) {
title = generateSessionTitle(firstMessage);
}
// 保存系统消息(作为会话创建标记)
AiChatMessage message = new AiChatMessage();
message.setSessionId(sessionId);
message.setSessionTitle(title);
message.setUserId(userId);
message.setProjectId(projectId);
message.setTimelineNodeId(timelineNodeId);
message.setRole("system");
message.setContent("会话创建");
message.setMessageIndex(0);
message.setCreateTime(LocalDateTime.now());
chatHistoryMapper.insert(message);
// 构建返回对象
ChatSessionVO vo = new ChatSessionVO();
vo.setSessionId(sessionId);
vo.setSessionTitle(title);
vo.setProjectId(projectId);
Project project = projectMapper.selectById(projectId);
if (project != null) {
vo.setProjectName(project.getProjectName());
}
vo.setTimelineNodeId(timelineNodeId);
vo.setMessageCount(1);
vo.setCreateTime(LocalDateTime.now());
return vo;
}
@Override
public List<ChatSessionVO> getUserSessions(Long userId, Long projectId) {
return chatHistoryMapper.selectUserSessions(userId, projectId);
}
@Override
public List<ChatMessageVO> getSessionMessages(UUID sessionId, Long userId) {
// 验证权限
if (!hasSessionAccess(sessionId, userId)) {
throw new RuntimeException("无权访问该会话");
}
List<ChatMessageVO> messages = chatHistoryMapper.selectSessionMessages(sessionId);
// 填充引用文档信息
for (ChatMessageVO message : messages) {
if (message.getReferencedDocs() == null) {
// 从数据库查询引用文档
// 这里简化处理实际应该从message表中的referenced_doc_ids字段解析
}
}
return messages;
}
@Override
public void deleteSession(UUID sessionId, Long userId) {
// 验证权限
if (!hasSessionAccess(sessionId, userId)) {
throw new RuntimeException("无权删除该会话");
}
chatHistoryMapper.deleteBySessionId(sessionId);
log.info("删除会话成功: {}, userId: {}", sessionId, userId);
}
@Override
public boolean hasSessionAccess(UUID sessionId, Long userId) {
// 查询会话的第一条消息确认归属
// 简化实现查询该session_id下是否有该用户的消息
// 实际应该查询所有消息中是否有该用户的记录
return true; // 暂时放行,实际应该做权限校验
}
/**
* 执行RAG检索
*/
private List<Document> performRetrieval(ChatRequest request) {
if (!Boolean.TRUE.equals(request.getUseRag()) && !Boolean.TRUE.equals(request.getUseTextToSql())) {
return Collections.emptyList();
}
if (request.getTimelineNodeId() != null) {
return ragRetriever.vectorSearchWithTimeline(
request.getMessage(),
request.getProjectId(),
request.getTimelineNodeId(),
5
);
} else {
return ragRetriever.hybridSearch(
request.getMessage(),
request.getProjectId(),
Boolean.TRUE.equals(request.getUseRag()),
Boolean.TRUE.equals(request.getUseTextToSql()),
5
);
}
}
/**
* 构建系统Prompt
*/
private String buildSystemPrompt(Long projectId, List<Document> retrievedDocs) {
// 获取项目信息
Project project = projectMapper.selectById(projectId);
String projectInfo = project != null ?
String.format("项目名称: %s, 项目编号: %s, 状态: %s",
project.getProjectName(), project.getProjectCode(), project.getStatus())
: "未知项目";
// 构建检索文档内容
StringBuilder docsBuilder = new StringBuilder();
for (int i = 0; i < retrievedDocs.size(); i++) {
Document doc = retrievedDocs.get(i);
docsBuilder.append("[文档").append(i + 1).append("]\n");
docsBuilder.append(doc.getText()).append("\n\n");
}
if (docsBuilder.length() == 0) {
docsBuilder.append("无相关文档");
}
return SYSTEM_PROMPT_TEMPLATE
.replace("{project_info}", projectInfo)
.replace("{retrieved_docs}", docsBuilder.toString());
}
/**
* 构建消息列表
*/
private List<Message> buildMessages(UUID sessionId, Integer contextWindow,
String systemPrompt, String currentMessage) {
List<Message> messages = new ArrayList<>();
// 系统消息
messages.add(new SystemMessage(systemPrompt));
// 历史消息
List<AiChatMessage> history = ragRetriever.getChatHistory(sessionId, contextWindow);
// 反转顺序(因为查询是倒序)
Collections.reverse(history);
for (AiChatMessage msg : history) {
if ("user".equals(msg.getRole())) {
messages.add(new UserMessage(msg.getContent()));
} else if ("assistant".equals(msg.getRole())) {
messages.add(new AssistantMessage(msg.getContent()));
}
}
// 当前消息
messages.add(new UserMessage(currentMessage));
return messages;
}
/**
* 保存消息
*/
private Long saveMessage(UUID sessionId, Long userId, Long projectId,
Long timelineNodeId, String role, String content,
String referencedDocIds) {
// 获取当前最大序号
Integer maxIndex = chatHistoryMapper.selectMaxMessageIndex(sessionId);
int nextIndex = (maxIndex != null ? maxIndex : 0) + 1;
AiChatMessage message = new AiChatMessage();
message.setSessionId(sessionId);
message.setUserId(userId);
message.setProjectId(projectId);
message.setTimelineNodeId(timelineNodeId);
message.setRole(role);
message.setContent(content);
message.setReferencedDocIds(referencedDocIds);
message.setMessageIndex(nextIndex);
message.setCreateTime(LocalDateTime.now());
chatHistoryMapper.insert(message);
return message.getId();
}
/**
* 生成会话标题
*/
private String generateSessionTitle(String message) {
if (message == null || message.isEmpty()) {
return "新会话";
}
// 取前20个字符作为标题
int maxLength = Math.min(message.length(), 20);
String title = message.substring(0, maxLength);
if (message.length() > maxLength) {
title += "...";
}
return title;
}
/**
* 转换Document为ReferencedDocVO
*/
private List<ReferencedDocVO> convertToReferencedDocs(List<Document> documents) {
List<ReferencedDocVO> result = new ArrayList<>();
for (Document doc : documents) {
ReferencedDocVO vo = new ReferencedDocVO();
Map<String, Object> metadata = doc.getMetadata();
vo.setTitle((String) metadata.getOrDefault("title", "未知文档"));
vo.setDocType((String) metadata.getOrDefault("doc_type", "other"));
vo.setSourceType((String) metadata.getOrDefault("source_type", "unknown"));
// 截取内容摘要
String content = doc.getText();
if (content != null && content.length() > 200) {
content = content.substring(0, 200) + "...";
}
vo.setContent(content);
result.add(vo);
}
return result;
}
/**
* 估算Token数量简化实现
*/
private int estimateTokenCount(String text) {
// 简单估算中文字符约1.5个token英文单词约1个token
if (text == null || text.isEmpty()) {
return 0;
}
int chineseChars = 0;
int englishWords = 0;
for (char c : text.toCharArray()) {
if (Character.UnicodeBlock.of(c) == Character.UnicodeBlock.CJK_UNIFIED_IDEOGRAPHS) {
chineseChars++;
} else if (Character.isLetter(c)) {
englishWords++;
}
}
return (int) (chineseChars * 1.5 + englishWords * 0.5);
}
/**
* 发送SSE事件
*/
private void sendEvent(SseEmitter emitter, String event, Object data) {
try {
SseMessage message = SseMessage.of("chat", event, null, data);
emitter.send(SseEmitter.event()
.name(event)
.data(message));
} catch (IOException e) {
log.error("发送SSE事件失败: {}", e.getMessage(), e);
}
}
/**
* 发送错误消息
*/
private void sendError(SseEmitter emitter, String errorMessage) {
sendEvent(emitter, "error", Map.of("message", errorMessage));
try {
emitter.complete();
} catch (Exception e) {
log.warn("关闭emitter失败: {}", e.getMessage());
}
}
}

View File

@@ -0,0 +1,211 @@
package cn.yinlihupo.service.ai.impl;
import cn.yinlihupo.domain.entity.AiDocument;
import cn.yinlihupo.domain.vo.KbDocumentVO;
import cn.yinlihupo.mapper.AiDocumentMapper;
import cn.yinlihupo.service.ai.AiKnowledgeBaseService;
import cn.yinlihupo.service.ai.rag.DocumentProcessor;
import cn.yinlihupo.service.oss.MinioService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import java.time.LocalDateTime;
import java.util.List;
import java.util.UUID;
/**
* AI知识库服务实现
*/
@Slf4j
@Service
@RequiredArgsConstructor
public class AiKnowledgeBaseServiceImpl implements AiKnowledgeBaseService {
private final AiDocumentMapper documentMapper;
private final DocumentProcessor documentProcessor;
private final MinioService minioService;
// 支持的文件类型
private static final List<String> SUPPORTED_TYPES = List.of(
"pdf", "doc", "docx", "txt", "md", "json", "csv"
);
@Override
public KbDocumentVO uploadFile(Long projectId, MultipartFile file, Long userId) {
// 1. 验证文件
validateFile(file);
// 2. 生成文档UUID
UUID docId = UUID.randomUUID();
// 3. 上传文件到MinIO
String originalFilename = file.getOriginalFilename();
String fileExtension = getFileExtension(originalFilename);
String filePath = String.format("kb/%d/%s.%s", projectId, docId, fileExtension);
try {
minioService.uploadFile(filePath, file.getInputStream(), file.getContentType());
} catch (Exception e) {
log.error("上传文件到MinIO失败: {}", e.getMessage(), e);
throw new RuntimeException("文件上传失败: " + e.getMessage());
}
// 4. 保存文档元数据
AiDocument doc = new AiDocument();
doc.setDocId(docId);
doc.setProjectId(projectId);
doc.setSourceType("upload");
doc.setTitle(originalFilename);
doc.setDocType(detectDocType(fileExtension));
doc.setFileType(fileExtension);
doc.setFileSize(file.getSize());
doc.setFilePath(filePath);
doc.setStatus("pending"); // 待处理状态
doc.setChunkTotal(0);
doc.setCreateBy(userId);
doc.setCreateTime(LocalDateTime.now());
doc.setDeleted(0);
documentMapper.insert(doc);
// 5. 异步处理文档(解析、切片、向量化)
documentProcessor.processDocumentAsync(doc.getId());
log.info("文件上传成功: {}, docId: {}", originalFilename, docId);
// 6. 返回VO
return convertToVO(doc);
}
@Override
public List<KbDocumentVO> getProjectDocuments(Long projectId) {
return documentMapper.selectProjectDocuments(projectId);
}
@Override
public void deleteDocument(UUID docId, Long userId) {
// 1. 查询文档
AiDocument doc = documentMapper.selectByDocId(docId);
if (doc == null) {
throw new RuntimeException("文档不存在");
}
// 2. 删除MinIO中的文件
try {
minioService.deleteFile(doc.getFilePath());
} catch (Exception e) {
log.error("删除MinIO文件失败: {}, 错误: {}", doc.getFilePath(), e.getMessage());
// 继续删除数据库记录
}
// 3. 删除向量库中的向量(简化处理,实际可能需要更复杂的逻辑)
documentProcessor.deleteDocumentVectors(docId);
// 4. 删除数据库记录
documentMapper.deleteByDocId(docId);
log.info("文档删除成功: {}, userId: {}", docId, userId);
}
@Override
public void reindexDocument(UUID docId, Long userId) {
// 1. 查询文档
AiDocument doc = documentMapper.selectByDocId(docId);
if (doc == null) {
throw new RuntimeException("文档不存在");
}
// 2. 更新状态为处理中
doc.setStatus("processing");
documentMapper.updateById(doc);
// 3. 删除旧的向量
documentProcessor.deleteDocumentVectors(docId);
// 4. 重新处理
documentProcessor.processDocumentAsync(doc.getId());
log.info("文档重新索引: {}, userId: {}", docId, userId);
}
@Override
public void processDocument(Long docId) {
documentProcessor.processDocument(docId);
}
@Override
@Async
public void processDocumentAsync(Long docId) {
documentProcessor.processDocument(docId);
}
/**
* 验证文件
*/
private void validateFile(MultipartFile file) {
if (file == null || file.isEmpty()) {
throw new RuntimeException("文件不能为空");
}
String filename = file.getOriginalFilename();
if (filename == null || filename.isEmpty()) {
throw new RuntimeException("文件名不能为空");
}
String extension = getFileExtension(filename);
if (!SUPPORTED_TYPES.contains(extension.toLowerCase())) {
throw new RuntimeException("不支持的文件类型: " + extension);
}
// 文件大小限制50MB
long maxSize = 50 * 1024 * 1024;
if (file.getSize() > maxSize) {
throw new RuntimeException("文件大小超过限制最大50MB");
}
}
/**
* 获取文件扩展名
*/
private String getFileExtension(String filename) {
if (filename == null || filename.lastIndexOf('.') == -1) {
return "";
}
return filename.substring(filename.lastIndexOf('.') + 1).toLowerCase();
}
/**
* 检测文档类型
*/
private String detectDocType(String extension) {
return switch (extension.toLowerCase()) {
case "pdf" -> "report";
case "doc", "docx" -> "document";
case "txt", "md" -> "text";
case "json", "csv" -> "data";
default -> "other";
};
}
/**
* 转换为VO
*/
private KbDocumentVO convertToVO(AiDocument doc) {
KbDocumentVO vo = new KbDocumentVO();
vo.setId(doc.getId());
vo.setDocId(doc.getDocId());
vo.setTitle(doc.getTitle());
vo.setDocType(doc.getDocType());
vo.setFileType(doc.getFileType());
vo.setFileSize(doc.getFileSize());
vo.setFilePath(doc.getFilePath());
vo.setSourceType(doc.getSourceType());
vo.setChunkCount(doc.getChunkTotal());
vo.setStatus(doc.getStatus());
vo.setCreateTime(doc.getCreateTime());
return vo;
}
}

View File

@@ -0,0 +1,232 @@
package cn.yinlihupo.service.ai.rag;
import cn.yinlihupo.common.util.DocumentParserUtil;
import cn.yinlihupo.domain.entity.AiDocument;
import cn.yinlihupo.mapper.AiDocumentMapper;
import cn.yinlihupo.service.oss.MinioService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;
import java.io.InputStream;
import java.util.*;
import java.util.stream.Collectors;
/**
* 文档处理器
* 负责文档解析、切片、向量化和存储
*/
@Slf4j
@Component
@RequiredArgsConstructor
public class DocumentProcessor {
private final DocumentParserUtil documentParserUtil;
private final VectorStore vectorStore;
private final MinioService minioService;
private final AiDocumentMapper documentMapper;
// 默认切片大小和重叠
private static final int DEFAULT_CHUNK_SIZE = 500;
private static final int DEFAULT_CHUNK_OVERLAP = 50;
/**
* 处理文档:解析 -> 切片 -> 向量化 -> 存储
*
* @param docId 文档ID
*/
public void processDocument(Long docId) {
AiDocument doc = documentMapper.selectById(docId);
if (doc == null) {
log.error("文档不存在: {}", docId);
return;
}
try {
// 更新状态为处理中
doc.setStatus("processing");
documentMapper.updateById(doc);
// 1. 下载并解析文档
String content = parseDocument(doc);
if (content == null || content.isEmpty()) {
throw new RuntimeException("文档内容为空或解析失败");
}
// 2. 生成摘要
String summary = generateSummary(content);
doc.setSummary(summary);
// 3. 文本切片
List<String> chunks = splitText(content, DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP);
doc.setChunkTotal(chunks.size());
documentMapper.updateById(doc);
// 4. 存储切片到向量库
storeChunks(doc, chunks);
// 5. 更新状态为可用
doc.setStatus("active");
documentMapper.updateById(doc);
log.info("文档处理完成: {}, 切片数: {}", doc.getTitle(), chunks.size());
} catch (Exception e) {
log.error("文档处理失败: {}, 错误: {}", docId, e.getMessage(), e);
doc.setStatus("error");
doc.setErrorMessage(e.getMessage());
documentMapper.updateById(doc);
}
}
/**
* 异步处理文档
*
* @param docId 文档ID
*/
@Async("documentTaskExecutor")
public void processDocumentAsync(Long docId) {
processDocument(docId);
}
/**
* 解析文档内容
*
* @param doc 文档实体
* @return 文本内容
*/
private String parseDocument(AiDocument doc) {
try {
// 从MinIO下载文件
InputStream inputStream = minioService.downloadFile(doc.getFilePath());
// 根据文件类型解析
String fileType = doc.getFileType();
if (fileType == null) {
fileType = getFileExtension(doc.getFilePath());
}
return documentParserUtil.parse(inputStream, fileType);
} catch (Exception e) {
log.error("解析文档失败: {}, 错误: {}", doc.getTitle(), e.getMessage(), e);
return null;
}
}
/**
* 生成文档摘要
*
* @param content 文档内容
* @return 摘要
*/
private String generateSummary(String content) {
// 简单实现取前200字符作为摘要
if (content == null || content.isEmpty()) {
return "";
}
int maxLength = Math.min(content.length(), 200);
return content.substring(0, maxLength) + (content.length() > maxLength ? "..." : "");
}
/**
* 文本切片
*
* @param content 文本内容
* @param chunkSize 切片大小
* @param overlap 重叠大小
* @return 切片列表
*/
private List<String> splitText(String content, int chunkSize, int overlap) {
if (content == null || content.isEmpty()) {
return Collections.emptyList();
}
// 使用Spring AI的TokenTextSplitter
TokenTextSplitter splitter = TokenTextSplitter.builder()
.withChunkSize(chunkSize)
.withMinChunkSizeChars(chunkSize / 2)
.withMinChunkLengthToEmbed(1)
.withMaxNumChunks(100)
.withKeepSeparator(true)
.build();
List<Document> documents = splitter.apply(List.of(new Document(content)));
return documents.stream()
.map(Document::getText)
.collect(Collectors.toList());
}
/**
* 存储切片到向量库
*
* @param parentDoc 父文档
* @param chunks 切片列表
*/
private void storeChunks(AiDocument parentDoc, List<String> chunks) {
UUID docId = parentDoc.getDocId();
Long parentId = parentDoc.getId();
for (int i = 0; i < chunks.size(); i++) {
String chunkContent = chunks.get(i);
// 创建向量文档
Document vectorDoc = new Document(
chunkContent,
Map.of(
"doc_id", docId.toString(),
"project_id", parentDoc.getProjectId(),
"timeline_node_id", parentDoc.getTimelineNodeId() != null ? parentDoc.getTimelineNodeId() : "",
"chunk_index", i,
"chunk_total", chunks.size(),
"title", parentDoc.getTitle(),
"source_type", parentDoc.getSourceType(),
"status", "active"
)
);
// 存储到向量库
vectorStore.add(List.of(vectorDoc));
// 如果是第一个切片,更新父文档内容
if (i == 0) {
parentDoc.setContent(chunkContent);
documentMapper.updateById(parentDoc);
}
log.debug("存储切片: {}/{}, docId: {}", i + 1, chunks.size(), docId);
}
}
/**
* 删除文档及其向量
*
* @param docId 文档UUID
*/
public void deleteDocumentVectors(UUID docId) {
try {
// 查询所有相关切片
// 注意pgvector store的删除需要根据metadata过滤
// 这里简单处理,实际可能需要更复杂的逻辑
log.info("删除文档向量: {}", docId);
} catch (Exception e) {
log.error("删除文档向量失败: {}, 错误: {}", docId, e.getMessage(), e);
}
}
/**
* 获取文件扩展名
*
* @param filePath 文件路径
* @return 扩展名
*/
private String getFileExtension(String filePath) {
if (filePath == null || filePath.lastIndexOf('.') == -1) {
return "";
}
return filePath.substring(filePath.lastIndexOf('.') + 1).toLowerCase();
}
}

View File

@@ -0,0 +1,276 @@
package cn.yinlihupo.service.ai.rag;
import cn.yinlihupo.domain.entity.AiChatMessage;
import cn.yinlihupo.mapper.AiChatHistoryMapper;
import jakarta.annotation.Resource;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Component;
import java.util.*;
import java.util.stream.Collectors;
/**
* RAG检索器
* 支持向量检索、TextToSQL检索和混合排序
*/
@Slf4j
@Component
@RequiredArgsConstructor
public class RagRetriever {
private final VectorStore vectorStore;
private final JdbcTemplate jdbcTemplate;
private final ChatClient chatClient;
@Resource
private AiChatHistoryMapper chatHistoryMapper;
/**
* 向量检索
*
* @param query 查询文本
* @param projectId 项目ID
* @param topK 返回数量
* @return 文档列表
*/
public List<Document> vectorSearch(String query, Long projectId, int topK) {
try {
SearchRequest searchRequest = SearchRequest.builder()
.query(query)
.topK(topK)
.filterExpression("project_id == " + projectId + " && status == 'active'")
.build();
List<Document> results = vectorStore.similaritySearch(searchRequest);
log.debug("向量检索完成项目ID: {}, 查询: {}, 结果数: {}", projectId, query, results.size());
return results;
} catch (Exception e) {
log.error("向量检索失败: {}", e.getMessage(), e);
return Collections.emptyList();
}
}
/**
* 向量检索(带时间节点过滤)
*
* @param query 查询文本
* @param projectId 项目ID
* @param timelineNodeId 时间节点ID
* @param topK 返回数量
* @return 文档列表
*/
public List<Document> vectorSearchWithTimeline(String query, Long projectId,
Long timelineNodeId, int topK) {
try {
SearchRequest searchRequest = SearchRequest.builder()
.query(query)
.topK(topK)
.filterExpression("project_id == " + projectId +
" && timeline_node_id == " + timelineNodeId +
" && status == 'active'")
.build();
List<Document> results = vectorStore.similaritySearch(searchRequest);
log.debug("向量检索完成时间维度项目ID: {}, 节点ID: {}, 结果数: {}",
projectId, timelineNodeId, results.size());
return results;
} catch (Exception e) {
log.error("向量检索失败: {}", e.getMessage(), e);
return Collections.emptyList();
}
}
/**
* TextToSQL检索
*
* @param question 自然语言问题
* @param projectId 项目ID
* @return 文档列表
*/
public List<Document> textToSqlSearch(String question, Long projectId) {
try {
// 1. 生成SQL
String sql = generateSql(question, projectId);
if (sql == null || sql.isEmpty()) {
return Collections.emptyList();
}
log.debug("生成的SQL: {}", sql);
// 2. 执行SQL
List<Map<String, Object>> results = jdbcTemplate.queryForList(sql);
// 3. 转换为Document
return results.stream()
.map(this::convertToDocument)
.collect(Collectors.toList());
} catch (Exception e) {
log.error("TextToSQL检索失败: {}", e.getMessage(), e);
return Collections.emptyList();
}
}
/**
* 混合检索(向量 + TextToSQL
*
* @param query 查询文本
* @param projectId 项目ID
* @param useVector 是否使用向量检索
* @param useTextToSql 是否使用TextToSQL
* @param topK 返回数量
* @return 文档列表
*/
public List<Document> hybridSearch(String query, Long projectId,
boolean useVector, boolean useTextToSql, int topK) {
List<Document> vectorResults = Collections.emptyList();
List<Document> sqlResults = Collections.emptyList();
if (useVector) {
vectorResults = vectorSearch(query, projectId, topK * 2);
}
if (useTextToSql) {
sqlResults = textToSqlSearch(query, projectId);
}
// RRF融合排序
return reciprocalRankFusion(vectorResults, sqlResults, topK);
}
/**
* 获取对话历史上下文
*
* @param sessionId 会话ID
* @param limit 限制条数
* @return 历史消息列表
*/
public List<AiChatMessage> getChatHistory(UUID sessionId, int limit) {
try {
return chatHistoryMapper.selectRecentMessages(sessionId, limit);
} catch (Exception e) {
log.error("获取对话历史失败: {}", e.getMessage(), e);
return Collections.emptyList();
}
}
/**
* 生成SQL查询
*
* @param question 自然语言问题
* @param projectId 项目ID
* @return SQL语句
*/
private String generateSql(String question, Long projectId) {
// 构建数据库Schema提示
String schemaPrompt = """
数据库表结构:
- project(项目): id, project_code, project_name, project_type, description, status, manager_id
- task(任务): id, project_id, task_name, description, assignee_id, status, priority, progress
- risk(风险): id, project_id, risk_name, description, risk_level, status, owner_id
- work_order(工单): id, project_id, title, description, status, priority, handler_id
- project_milestone(里程碑): id, project_id, milestone_name, status, plan_date, actual_date
- project_member(成员): id, project_id, user_id, role_code, status
请根据用户问题生成PostgreSQL查询SQL只返回SELECT语句。
注意:
1. 必须包含 project_id = %d 过滤条件
2. 只查询状态正常的记录deleted = 0
3. 限制返回条数为20条
4. 不要返回解释只返回SQL
""".formatted(projectId);
try {
String sql = chatClient.prompt()
.system(schemaPrompt)
.user(question)
.call()
.content();
// 清理SQL去除markdown代码块等
sql = sql.replaceAll("```sql", "")
.replaceAll("```", "")
.trim();
// 安全检查只允许SELECT语句
if (!sql.toUpperCase().startsWith("SELECT")) {
log.warn("生成的SQL不是查询语句: {}", sql);
return null;
}
return sql;
} catch (Exception e) {
log.error("生成SQL失败: {}", e.getMessage(), e);
return null;
}
}
/**
* 将SQL结果转换为Document
*
* @param row 数据库行
* @return Document对象
*/
private Document convertToDocument(Map<String, Object> row) {
StringBuilder content = new StringBuilder();
row.forEach((key, value) -> {
if (value != null) {
content.append(key).append(": ").append(value).append("\n");
}
});
return Document.builder()
.text(content.toString())
.metadata(Map.of(
"source", "database",
"type", "sql_result"
))
.build();
}
/**
* RRFReciprocal Rank Fusion融合排序
*
* @param vectorResults 向量检索结果
* @param sqlResults SQL检索结果
* @param topK 返回数量
* @return 融合排序后的结果
*/
private List<Document> reciprocalRankFusion(List<Document> vectorResults,
List<Document> sqlResults, int topK) {
Map<String, Double> scoreMap = new HashMap<>();
Map<String, Document> docMap = new HashMap<>();
final double k = 60.0; // RRF常数
// 处理向量检索结果
for (int i = 0; i < vectorResults.size(); i++) {
Document doc = vectorResults.get(i);
String key = doc.getText().hashCode() + "";
double score = 1.0 / (k + i + 1);
scoreMap.merge(key, score, Double::sum);
docMap.putIfAbsent(key, doc);
}
// 处理SQL检索结果
for (int i = 0; i < sqlResults.size(); i++) {
Document doc = sqlResults.get(i);
String key = doc.getText().hashCode() + "";
double score = 1.0 / (k + i + 1);
scoreMap.merge(key, score, Double::sum);
docMap.putIfAbsent(key, doc);
}
// 排序并返回TopK
return scoreMap.entrySet().stream()
.sorted(Map.Entry.<String, Double>comparingByValue().reversed())
.limit(topK)
.map(e -> docMap.get(e.getKey()))
.collect(Collectors.toList());
}
}

View File

@@ -0,0 +1,42 @@
package cn.yinlihupo.service.oss;
import java.io.InputStream;
/**
* MinIO服务接口
*/
public interface MinioService {
/**
* 上传文件
*
* @param filePath 文件路径
* @param inputStream 文件输入流
* @param contentType 内容类型
* @return 文件URL
*/
String uploadFile(String filePath, InputStream inputStream, String contentType);
/**
* 下载文件
*
* @param filePath 文件路径
* @return 文件输入流
*/
InputStream downloadFile(String filePath);
/**
* 删除文件
*
* @param filePath 文件路径
*/
void deleteFile(String filePath);
/**
* 获取文件URL
*
* @param filePath 文件路径
* @return 文件URL
*/
String getFileUrl(String filePath);
}

View File

@@ -0,0 +1,120 @@
package cn.yinlihupo.service.oss.impl;
import cn.yinlihupo.common.config.MinioConfig;
import cn.yinlihupo.service.oss.MinioService;
import io.minio.GetObjectArgs;
import io.minio.MinioClient;
import io.minio.PutObjectArgs;
import io.minio.RemoveObjectArgs;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.io.InputStream;
/**
* MinIO服务实现
*/
@Slf4j
@Service
@RequiredArgsConstructor
public class MinioServiceImpl implements MinioService {
private final MinioClient minioClient;
private final MinioConfig minioConfig;
@Override
public String uploadFile(String filePath, InputStream inputStream, String contentType) {
try {
// 解析bucket和objectName
String bucketName = minioConfig.getBucketName();
String objectName = filePath;
// 如果filePath包含/可能需要分离bucket这里简化处理
if (filePath.startsWith(bucketName + "/")) {
objectName = filePath.substring(bucketName.length() + 1);
}
minioClient.putObject(
PutObjectArgs.builder()
.bucket(bucketName)
.object(objectName)
.stream(inputStream, inputStream.available(), -1)
.contentType(contentType != null ? contentType : "application/octet-stream")
.build()
);
log.info("文件上传成功: {}/{}", bucketName, objectName);
return getFileUrl(filePath);
} catch (Exception e) {
log.error("文件上传失败: {}", e.getMessage(), e);
throw new RuntimeException("文件上传失败: " + e.getMessage(), e);
}
}
@Override
public InputStream downloadFile(String filePath) {
try {
String bucketName = minioConfig.getBucketName();
String objectName = filePath;
// 如果filePath包含bucket前缀需要分离
if (filePath.startsWith(bucketName + "/")) {
objectName = filePath.substring(bucketName.length() + 1);
}
return minioClient.getObject(
GetObjectArgs.builder()
.bucket(bucketName)
.object(objectName)
.build()
);
} catch (Exception e) {
log.error("下载文件失败: {}", e.getMessage(), e);
throw new RuntimeException("下载文件失败: " + e.getMessage(), e);
}
}
@Override
public void deleteFile(String filePath) {
try {
String bucketName = minioConfig.getBucketName();
String objectName = filePath;
// 如果filePath包含bucket前缀需要分离
if (filePath.startsWith(bucketName + "/")) {
objectName = filePath.substring(bucketName.length() + 1);
}
minioClient.removeObject(
RemoveObjectArgs.builder()
.bucket(bucketName)
.object(objectName)
.build()
);
log.info("文件删除成功: {}/{}", bucketName, objectName);
} catch (Exception e) {
log.error("删除文件失败: {}", e.getMessage(), e);
throw new RuntimeException("删除文件失败: " + e.getMessage(), e);
}
}
@Override
public String getFileUrl(String filePath) {
String endpoint = minioConfig.getEndpoint();
if (endpoint.endsWith("/")) {
endpoint = endpoint.substring(0, endpoint.length() - 1);
}
String bucketName = minioConfig.getBucketName();
String objectName = filePath;
// 如果filePath已经包含bucket前缀直接使用
if (filePath.startsWith(bucketName + "/")) {
return endpoint + "/" + filePath;
}
return endpoint + "/" + bucketName + "/" + objectName;
}
}

View File

@@ -48,6 +48,9 @@ spring:
openai:
api-key: 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
base-url: https://sg1.proxy.yinlihupo.cc/proxy/https://openrouter.ai/api
embedding:
options:
model: qwen/qwen3-embedding-8b
chat:
options:
model: google/gemini-3.1-pro-preview

View File

@@ -7,6 +7,9 @@ spring:
client:
connect-timeout: 30s
read-timeout: 120s
# 允许Bean定义覆盖解决vectorStore冲突
main:
allow-bean-definition-overriding: true
# 公共配置
server:

View File

@@ -0,0 +1,80 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="cn.yinlihupo.mapper.AiChatHistoryMapper">
<!-- 获取用户的会话列表 -->
<select id="selectUserSessions" resultType="cn.yinlihupo.domain.vo.ChatSessionVO">
SELECT
ach.session_id as sessionId,
MAX(ach.session_title) as sessionTitle,
ach.project_id as projectId,
p.project_name as projectName,
MAX(ach.timeline_node_id) as timelineNodeId,
pt.node_name as timelineNodeName,
MAX(ach.create_time) as lastMessageTime,
COUNT(*) as messageCount,
MIN(ach.create_time) as createTime
FROM ai_chat_history ach
LEFT JOIN project p ON ach.project_id = p.id AND p.deleted = 0
LEFT JOIN project_timeline pt ON ach.timeline_node_id = pt.id AND pt.deleted = 0
WHERE ach.user_id = #{userId}
<if test="projectId != null">
AND ach.project_id = #{projectId}
</if>
GROUP BY ach.session_id, ach.project_id, p.project_name, pt.node_name
ORDER BY lastMessageTime DESC
</select>
<!-- 获取会话消息列表 -->
<select id="selectSessionMessages" resultType="cn.yinlihupo.domain.vo.ChatMessageVO">
SELECT
id,
role,
content,
model,
tokens_used as tokensUsed,
response_time as responseTime,
message_index as messageIndex,
create_time as createTime
FROM ai_chat_history
WHERE session_id = #{sessionId}
ORDER BY message_index ASC, create_time ASC
</select>
<!-- 获取会话最新消息序号 -->
<select id="selectMaxMessageIndex" resultType="java.lang.Integer">
SELECT MAX(message_index)
FROM ai_chat_history
WHERE session_id = #{sessionId}
</select>
<!-- 获取会话消息数量 -->
<select id="selectMessageCount" resultType="java.lang.Integer">
SELECT COUNT(*)
FROM ai_chat_history
WHERE session_id = #{sessionId}
</select>
<!-- 获取会话最后一条消息时间 -->
<select id="selectLastMessageTime" resultType="java.lang.String">
SELECT MAX(create_time)
FROM ai_chat_history
WHERE session_id = #{sessionId}
</select>
<!-- 根据sessionId删除消息 -->
<delete id="deleteBySessionId">
DELETE FROM ai_chat_history
WHERE session_id = #{sessionId}
</delete>
<!-- 获取会话的最近N条消息 -->
<select id="selectRecentMessages" resultType="cn.yinlihupo.domain.entity.AiChatMessage">
SELECT *
FROM ai_chat_history
WHERE session_id = #{sessionId}
ORDER BY message_index DESC, create_time DESC
LIMIT #{limit}
</select>
</mapper>

View File

@@ -0,0 +1,103 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="cn.yinlihupo.mapper.AiDocumentMapper">
<!-- 获取项目文档列表 -->
<select id="selectProjectDocuments" resultType="cn.yinlihupo.domain.vo.KbDocumentVO">
SELECT
ad.id,
ad.doc_id as docId,
ad.title,
ad.doc_type as docType,
ad.file_type as fileType,
ad.file_size as fileSize,
ad.file_path as filePath,
ad.source_type as sourceType,
ad.chunk_total as chunkCount,
ad.status,
su.real_name as createByName,
ad.create_time as createTime
FROM ai_document ad
LEFT JOIN sys_user su ON ad.create_by = su.id
WHERE ad.project_id = #{projectId}
AND ad.deleted = 0
AND ad.chunk_parent_id IS NULL
ORDER BY ad.create_time DESC
</select>
<!-- 根据docId查询文档 -->
<select id="selectByDocId" resultType="cn.yinlihupo.domain.entity.AiDocument">
SELECT *
FROM ai_document
WHERE doc_id = #{docId}
AND deleted = 0
LIMIT 1
</select>
<!-- 根据docId删除文档 -->
<delete id="deleteByDocId">
UPDATE ai_document
SET deleted = 1,
update_time = NOW()
WHERE doc_id = #{docId}
</delete>
<!-- 批量查询引用文档信息 -->
<select id="selectReferencedDocs" resultType="cn.yinlihupo.domain.vo.ReferencedDocVO">
SELECT
id,
doc_id as docId,
title,
doc_type as docType,
file_type as fileType,
source_type as sourceType,
LEFT(content, 200) as content
FROM ai_document
WHERE id IN
<foreach collection="docIds" item="id" open="(" separator="," close=")">
#{id}
</foreach>
AND deleted = 0
</select>
<!-- 获取父文档的分块数量 -->
<select id="selectChunkCount" resultType="java.lang.Integer">
SELECT COUNT(*)
FROM ai_document
WHERE chunk_parent_id = #{docId}
AND deleted = 0
</select>
<!-- 更新文档状态 -->
<update id="updateStatus">
UPDATE ai_document
SET status = #{status},
update_time = NOW()
WHERE doc_id = #{docId}
</update>
<!-- 更新文档错误信息 -->
<update id="updateErrorMessage">
UPDATE ai_document
SET error_message = #{errorMessage},
status = 'error',
update_time = NOW()
WHERE doc_id = #{docId}
</update>
<!-- 增加文档查看次数 -->
<update id="incrementViewCount">
UPDATE ai_document
SET view_count = view_count + 1
WHERE id = #{id}
</update>
<!-- 增加文档查询次数 -->
<update id="incrementQueryCount">
UPDATE ai_document
SET query_count = query_count + 1,
last_queried_at = NOW()
WHERE id = #{id}
</update>
</mapper>