feat(ai-chat): 实现AI对话流式响应及会话管理功能
- 新增AiChatController,提供SSE流式对话、新建会话、获取会话列表、获取会话历史记录和删除会话接口 - 实现AiChatService及其Impl,支持多轮对话流式响应、RAG检索辅助回答、会话权限校验和消息持久化 - 定义AiChatMessage实体映射ai_chat_history表,存储对话消息及元信息 - 新增MyBatis映射文件AiChatHistoryMapper.xml,支持会话查询、消息查询和会话删除等数据库操作 - 添加ChatRequest、CreateSessionRequest、ChatSessionVO等数据传输对象,规范接口参数与返回 - 实现DocumentProcessor用于文档解析、切片、向量化及存储,支持异步处理文档以供知识检索 - 设计系统提示词模板,结合项目及检索文档内容构建上下文,提高AI回答准确性和专业性 - 增加错误处理和权限验证,确保用户只能访问及操作自己的会话 - 优化流式响应的SSE事件推送,支持开始、块数据、引用文档、错误和完成事件通知 - 实现Token估算方法,统计交流token数用于统计和性能监控
This commit is contained in:
@@ -112,7 +112,6 @@ public class AiChatController {
|
||||
ChatSessionVO session = aiChatService.createSession(
|
||||
userId,
|
||||
request.getProjectId(),
|
||||
request.getTimelineNodeId(),
|
||||
request.getFirstMessage(),
|
||||
request.getSessionTitle()
|
||||
);
|
||||
|
||||
@@ -18,11 +18,6 @@ public class ChatRequest {
|
||||
*/
|
||||
private Long projectId;
|
||||
|
||||
/**
|
||||
* 时间节点ID(可选,用于时间维度知识库)
|
||||
*/
|
||||
private Long timelineNodeId;
|
||||
|
||||
/**
|
||||
* 用户消息内容
|
||||
*/
|
||||
|
||||
@@ -13,11 +13,6 @@ public class CreateSessionRequest {
|
||||
*/
|
||||
private Long projectId;
|
||||
|
||||
/**
|
||||
* 时间节点ID(可选)
|
||||
*/
|
||||
private Long timelineNodeId;
|
||||
|
||||
/**
|
||||
* 首条消息内容(用于生成会话标题)
|
||||
*/
|
||||
|
||||
@@ -39,11 +39,6 @@ public class AiChatMessage {
|
||||
*/
|
||||
private Long projectId;
|
||||
|
||||
/**
|
||||
* 关联时间节点ID
|
||||
*/
|
||||
private Long timelineNodeId;
|
||||
|
||||
/**
|
||||
* 角色:user-用户, assistant-助手, system-系统
|
||||
*/
|
||||
|
||||
@@ -29,11 +29,6 @@ public class AiDocument {
|
||||
*/
|
||||
private Long projectId;
|
||||
|
||||
/**
|
||||
* 关联时间节点ID
|
||||
*/
|
||||
private Long timelineNodeId;
|
||||
|
||||
/**
|
||||
* 关联知识库ID
|
||||
*/
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
package cn.yinlihupo.domain.entity;
|
||||
|
||||
import cn.yinlihupo.common.handler.JsonbTypeHandler;
|
||||
import com.baomidou.mybatisplus.annotation.*;
|
||||
import lombok.Data;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 项目时间节点实体类
|
||||
* 对应数据库表: project_timeline
|
||||
*/
|
||||
@Data
|
||||
@TableName("project_timeline")
|
||||
public class ProjectTimeline {
|
||||
|
||||
@TableId(type = IdType.ASSIGN_ID)
|
||||
private Long id;
|
||||
|
||||
/**
|
||||
* 项目ID
|
||||
*/
|
||||
private Long projectId;
|
||||
|
||||
/**
|
||||
* 节点名称
|
||||
*/
|
||||
private String nodeName;
|
||||
|
||||
/**
|
||||
* 节点类型: phase-阶段, milestone-里程碑, event-事件, checkpoint-检查点
|
||||
*/
|
||||
private String nodeType;
|
||||
|
||||
/**
|
||||
* 父节点ID
|
||||
*/
|
||||
private Long parentId;
|
||||
|
||||
/**
|
||||
* 计划日期
|
||||
*/
|
||||
private LocalDate planDate;
|
||||
|
||||
/**
|
||||
* 实际日期
|
||||
*/
|
||||
private LocalDate actualDate;
|
||||
|
||||
/**
|
||||
* 描述
|
||||
*/
|
||||
private String description;
|
||||
|
||||
/**
|
||||
* 状态: pending-待开始, in_progress-进行中, completed-已完成, delayed-延期
|
||||
*/
|
||||
private String status;
|
||||
|
||||
/**
|
||||
* 排序
|
||||
*/
|
||||
private Integer sortOrder;
|
||||
|
||||
/**
|
||||
* 知识库范围配置["report","file","risk","ticket"]
|
||||
*/
|
||||
@TableField(typeHandler = JsonbTypeHandler.class)
|
||||
private List<String> kbScope;
|
||||
|
||||
/**
|
||||
* 扩展数据
|
||||
*/
|
||||
@TableField(typeHandler = JsonbTypeHandler.class)
|
||||
private Object extraData;
|
||||
|
||||
/**
|
||||
* 创建时间
|
||||
*/
|
||||
@TableField(fill = FieldFill.INSERT)
|
||||
private LocalDateTime createTime;
|
||||
|
||||
/**
|
||||
* 更新时间
|
||||
*/
|
||||
@TableField(fill = FieldFill.INSERT_UPDATE)
|
||||
private LocalDateTime updateTime;
|
||||
|
||||
/**
|
||||
* 删除标记
|
||||
*/
|
||||
@TableLogic
|
||||
private Integer deleted;
|
||||
}
|
||||
@@ -30,16 +30,6 @@ public class ChatSessionVO {
|
||||
*/
|
||||
private String projectName;
|
||||
|
||||
/**
|
||||
* 时间节点ID
|
||||
*/
|
||||
private Long timelineNodeId;
|
||||
|
||||
/**
|
||||
* 时间节点名称
|
||||
*/
|
||||
private String timelineNodeName;
|
||||
|
||||
/**
|
||||
* 最后消息时间
|
||||
*/
|
||||
|
||||
@@ -586,55 +586,4 @@ public class ProjectDetailVO {
|
||||
*/
|
||||
private LocalDateTime discoverTime;
|
||||
}
|
||||
|
||||
/**
|
||||
* 时间节点信息
|
||||
*/
|
||||
@Data
|
||||
public static class TimelineInfo {
|
||||
/**
|
||||
* 节点ID
|
||||
*/
|
||||
private Long id;
|
||||
|
||||
/**
|
||||
* 节点名称
|
||||
*/
|
||||
private String nodeName;
|
||||
|
||||
/**
|
||||
* 节点类型
|
||||
*/
|
||||
private String nodeType;
|
||||
|
||||
/**
|
||||
* 计划日期
|
||||
*/
|
||||
private LocalDate planDate;
|
||||
|
||||
/**
|
||||
* 实际日期
|
||||
*/
|
||||
private LocalDate actualDate;
|
||||
|
||||
/**
|
||||
* 描述
|
||||
*/
|
||||
private String description;
|
||||
|
||||
/**
|
||||
* 状态
|
||||
*/
|
||||
private String status;
|
||||
|
||||
/**
|
||||
* 排序
|
||||
*/
|
||||
private Integer sortOrder;
|
||||
|
||||
/**
|
||||
* 知识库范围
|
||||
*/
|
||||
private List<String> kbScope;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,13 +26,11 @@ public interface AiChatService {
|
||||
*
|
||||
* @param userId 用户ID
|
||||
* @param projectId 项目ID
|
||||
* @param timelineNodeId 时间节点ID(可选)
|
||||
* @param firstMessage 首条消息(用于生成标题)
|
||||
* @param customTitle 自定义标题(可选)
|
||||
* @return 会话信息
|
||||
*/
|
||||
ChatSessionVO createSession(Long userId, Long projectId, Long timelineNodeId,
|
||||
String firstMessage, String customTitle);
|
||||
ChatSessionVO createSession(Long userId, Long projectId, String firstMessage, String customTitle);
|
||||
|
||||
/**
|
||||
* 获取用户会话列表
|
||||
|
||||
@@ -75,7 +75,7 @@ public class AiChatServiceImpl implements AiChatService {
|
||||
if (isNewSession) {
|
||||
sessionId = UUID.randomUUID().toString();
|
||||
String title = generateSessionTitle(request.getMessage());
|
||||
createSession(userId, request.getProjectId(), request.getTimelineNodeId(), request.getMessage(), title);
|
||||
createSession(userId, request.getProjectId(), request.getMessage(), title);
|
||||
} else {
|
||||
// 验证会话权限
|
||||
if (!hasSessionAccess(sessionId, userId)) {
|
||||
@@ -93,8 +93,7 @@ public class AiChatServiceImpl implements AiChatService {
|
||||
));
|
||||
|
||||
// 2. 保存用户消息
|
||||
saveMessage(finalSessionId, userId, request.getProjectId(),
|
||||
request.getTimelineNodeId(), "user", request.getMessage(), null);
|
||||
saveMessage(finalSessionId, userId, request.getProjectId(), "user", request.getMessage(), null);
|
||||
|
||||
// 3. RAG检索
|
||||
List<Document> retrievedDocs = performRetrieval(request);
|
||||
@@ -140,7 +139,7 @@ public class AiChatServiceImpl implements AiChatService {
|
||||
.filter(id -> !id.isEmpty())
|
||||
.toArray(String[]::new);
|
||||
Long messageId = saveMessage(finalSessionId, userId, request.getProjectId(),
|
||||
request.getTimelineNodeId(), "assistant",
|
||||
"assistant",
|
||||
fullResponse.toString(), docIds);
|
||||
|
||||
// 发送完成消息
|
||||
@@ -166,8 +165,7 @@ public class AiChatServiceImpl implements AiChatService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatSessionVO createSession(Long userId, Long projectId, Long timelineNodeId,
|
||||
String firstMessage, String customTitle) {
|
||||
public ChatSessionVO createSession(Long userId, Long projectId, String firstMessage, String customTitle) {
|
||||
UUID sessionId = UUID.randomUUID();
|
||||
String sessionIdStr = sessionId.toString();
|
||||
String title = customTitle;
|
||||
@@ -182,7 +180,6 @@ public class AiChatServiceImpl implements AiChatService {
|
||||
message.setSessionTitle(title);
|
||||
message.setUserId(userId);
|
||||
message.setProjectId(projectId);
|
||||
message.setTimelineNodeId(timelineNodeId);
|
||||
message.setRole("system");
|
||||
message.setContent("会话创建");
|
||||
message.setMessageIndex(0);
|
||||
@@ -200,7 +197,6 @@ public class AiChatServiceImpl implements AiChatService {
|
||||
vo.setProjectName(project.getProjectName());
|
||||
}
|
||||
|
||||
vo.setTimelineNodeId(timelineNodeId);
|
||||
vo.setMessageCount(1);
|
||||
vo.setCreateTime(LocalDateTime.now());
|
||||
|
||||
@@ -258,23 +254,13 @@ public class AiChatServiceImpl implements AiChatService {
|
||||
if (!Boolean.TRUE.equals(request.getUseRag()) && !Boolean.TRUE.equals(request.getUseTextToSql())) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
if (request.getTimelineNodeId() != null) {
|
||||
return ragRetriever.vectorSearchWithTimeline(
|
||||
request.getMessage(),
|
||||
request.getProjectId(),
|
||||
request.getTimelineNodeId(),
|
||||
5
|
||||
);
|
||||
} else {
|
||||
return ragRetriever.hybridSearch(
|
||||
request.getMessage(),
|
||||
request.getProjectId(),
|
||||
Boolean.TRUE.equals(request.getUseRag()),
|
||||
Boolean.TRUE.equals(request.getUseTextToSql()),
|
||||
5
|
||||
);
|
||||
}
|
||||
return ragRetriever.hybridSearch(
|
||||
request.getMessage(),
|
||||
request.getProjectId(),
|
||||
Boolean.TRUE.equals(request.getUseRag()),
|
||||
Boolean.TRUE.equals(request.getUseTextToSql()),
|
||||
5
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -338,8 +324,7 @@ public class AiChatServiceImpl implements AiChatService {
|
||||
* 保存消息
|
||||
*/
|
||||
private Long saveMessage(String sessionId, Long userId, Long projectId,
|
||||
Long timelineNodeId, String role, String content,
|
||||
String[] referencedDocIds) {
|
||||
String role, String content, String[] referencedDocIds) {
|
||||
// 获取当前最大序号
|
||||
Integer maxIndex = chatHistoryMapper.selectMaxMessageIndex(sessionId);
|
||||
int nextIndex = (maxIndex != null ? maxIndex : 0) + 1;
|
||||
@@ -348,7 +333,6 @@ public class AiChatServiceImpl implements AiChatService {
|
||||
message.setSessionId(sessionId);
|
||||
message.setUserId(userId);
|
||||
message.setProjectId(projectId);
|
||||
message.setTimelineNodeId(timelineNodeId);
|
||||
message.setRole(role);
|
||||
message.setContent(content);
|
||||
message.setReferencedDocIds(referencedDocIds);
|
||||
|
||||
@@ -183,7 +183,6 @@ public class DocumentProcessor {
|
||||
Map<String, Object> metadata = new HashMap<>();
|
||||
// 项目关联属性(用于检索过滤)
|
||||
metadata.put("project_id", doc.getProjectId() != null ? doc.getProjectId().toString() : "");
|
||||
metadata.put("timeline_node_id", doc.getTimelineNodeId() != null ? doc.getTimelineNodeId().toString() : "");
|
||||
metadata.put("kb_id", doc.getKbId() != null ? doc.getKbId().toString() : "");
|
||||
// 文档来源信息
|
||||
metadata.put("source_type", doc.getSourceType() != null ? doc.getSourceType() : "");
|
||||
|
||||
@@ -57,37 +57,6 @@ public class RagRetriever {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 向量检索(带时间节点过滤)
|
||||
*
|
||||
* @param query 查询文本
|
||||
* @param projectId 项目ID
|
||||
* @param timelineNodeId 时间节点ID
|
||||
* @param topK 返回数量
|
||||
* @return 文档列表
|
||||
*/
|
||||
public List<Document> vectorSearchWithTimeline(String query, Long projectId,
|
||||
Long timelineNodeId, int topK) {
|
||||
try {
|
||||
// project_id 和 timeline_node_id 在 metadata 中是字符串类型
|
||||
SearchRequest searchRequest = SearchRequest.builder()
|
||||
.query(query)
|
||||
.topK(topK)
|
||||
.filterExpression("project_id == '" + projectId +
|
||||
"' && timeline_node_id == '" + timelineNodeId +
|
||||
"' && status == 'active'")
|
||||
.build();
|
||||
|
||||
List<Document> results = vectorStore.similaritySearch(searchRequest);
|
||||
log.debug("向量检索完成(时间维度),项目ID: {}, 节点ID: {}, 结果数: {}",
|
||||
projectId, timelineNodeId, results.size());
|
||||
return results;
|
||||
} catch (Exception e) {
|
||||
log.error("向量检索失败: {}", e.getMessage(), e);
|
||||
return Collections.emptyList();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* TextToSQL检索
|
||||
*
|
||||
|
||||
@@ -9,19 +9,16 @@
|
||||
MAX(ach.session_title) as sessionTitle,
|
||||
ach.project_id as projectId,
|
||||
p.project_name as projectName,
|
||||
MAX(ach.timeline_node_id) as timelineNodeId,
|
||||
pt.node_name as timelineNodeName,
|
||||
MAX(ach.create_time) as lastMessageTime,
|
||||
COUNT(*) as messageCount,
|
||||
MIN(ach.create_time) as createTime
|
||||
FROM ai_chat_history ach
|
||||
LEFT JOIN project p ON ach.project_id = p.id AND p.deleted = 0
|
||||
LEFT JOIN project_timeline pt ON ach.timeline_node_id = pt.id AND pt.deleted = 0
|
||||
WHERE ach.user_id = #{userId}
|
||||
<if test="projectId != null">
|
||||
AND ach.project_id = #{projectId}
|
||||
</if>
|
||||
GROUP BY ach.session_id, ach.project_id, p.project_name, pt.node_name
|
||||
GROUP BY ach.session_id, ach.project_id, p.project_name
|
||||
ORDER BY lastMessageTime DESC
|
||||
</select>
|
||||
|
||||
|
||||
Reference in New Issue
Block a user