feat(ai-chat): 实现AI对话流式响应及会话管理功能
- 新增AiChatController,提供SSE流式对话、新建会话、获取会话列表、获取会话历史记录和删除会话接口 - 实现AiChatService及其Impl,支持多轮对话流式响应、RAG检索辅助回答、会话权限校验和消息持久化 - 定义AiChatMessage实体映射ai_chat_history表,存储对话消息及元信息 - 新增MyBatis映射文件AiChatHistoryMapper.xml,支持会话查询、消息查询和会话删除等数据库操作 - 添加ChatRequest、CreateSessionRequest、ChatSessionVO等数据传输对象,规范接口参数与返回 - 实现DocumentProcessor用于文档解析、切片、向量化及存储,支持异步处理文档以供知识检索 - 设计系统提示词模板,结合项目及检索文档内容构建上下文,提高AI回答准确性和专业性 - 增加错误处理和权限验证,确保用户只能访问及操作自己的会话 - 优化流式响应的SSE事件推送,支持开始、块数据、引用文档、错误和完成事件通知 - 实现Token估算方法,统计交流token数用于统计和性能监控
This commit is contained in:
@@ -112,7 +112,6 @@ public class AiChatController {
|
|||||||
ChatSessionVO session = aiChatService.createSession(
|
ChatSessionVO session = aiChatService.createSession(
|
||||||
userId,
|
userId,
|
||||||
request.getProjectId(),
|
request.getProjectId(),
|
||||||
request.getTimelineNodeId(),
|
|
||||||
request.getFirstMessage(),
|
request.getFirstMessage(),
|
||||||
request.getSessionTitle()
|
request.getSessionTitle()
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -18,11 +18,6 @@ public class ChatRequest {
|
|||||||
*/
|
*/
|
||||||
private Long projectId;
|
private Long projectId;
|
||||||
|
|
||||||
/**
|
|
||||||
* 时间节点ID(可选,用于时间维度知识库)
|
|
||||||
*/
|
|
||||||
private Long timelineNodeId;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 用户消息内容
|
* 用户消息内容
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -13,11 +13,6 @@ public class CreateSessionRequest {
|
|||||||
*/
|
*/
|
||||||
private Long projectId;
|
private Long projectId;
|
||||||
|
|
||||||
/**
|
|
||||||
* 时间节点ID(可选)
|
|
||||||
*/
|
|
||||||
private Long timelineNodeId;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 首条消息内容(用于生成会话标题)
|
* 首条消息内容(用于生成会话标题)
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -39,11 +39,6 @@ public class AiChatMessage {
|
|||||||
*/
|
*/
|
||||||
private Long projectId;
|
private Long projectId;
|
||||||
|
|
||||||
/**
|
|
||||||
* 关联时间节点ID
|
|
||||||
*/
|
|
||||||
private Long timelineNodeId;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 角色:user-用户, assistant-助手, system-系统
|
* 角色:user-用户, assistant-助手, system-系统
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -29,11 +29,6 @@ public class AiDocument {
|
|||||||
*/
|
*/
|
||||||
private Long projectId;
|
private Long projectId;
|
||||||
|
|
||||||
/**
|
|
||||||
* 关联时间节点ID
|
|
||||||
*/
|
|
||||||
private Long timelineNodeId;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 关联知识库ID
|
* 关联知识库ID
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -1,96 +0,0 @@
|
|||||||
package cn.yinlihupo.domain.entity;
|
|
||||||
|
|
||||||
import cn.yinlihupo.common.handler.JsonbTypeHandler;
|
|
||||||
import com.baomidou.mybatisplus.annotation.*;
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
import java.time.LocalDate;
|
|
||||||
import java.time.LocalDateTime;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 项目时间节点实体类
|
|
||||||
* 对应数据库表: project_timeline
|
|
||||||
*/
|
|
||||||
@Data
|
|
||||||
@TableName("project_timeline")
|
|
||||||
public class ProjectTimeline {
|
|
||||||
|
|
||||||
@TableId(type = IdType.ASSIGN_ID)
|
|
||||||
private Long id;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 项目ID
|
|
||||||
*/
|
|
||||||
private Long projectId;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 节点名称
|
|
||||||
*/
|
|
||||||
private String nodeName;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 节点类型: phase-阶段, milestone-里程碑, event-事件, checkpoint-检查点
|
|
||||||
*/
|
|
||||||
private String nodeType;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 父节点ID
|
|
||||||
*/
|
|
||||||
private Long parentId;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 计划日期
|
|
||||||
*/
|
|
||||||
private LocalDate planDate;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 实际日期
|
|
||||||
*/
|
|
||||||
private LocalDate actualDate;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 描述
|
|
||||||
*/
|
|
||||||
private String description;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 状态: pending-待开始, in_progress-进行中, completed-已完成, delayed-延期
|
|
||||||
*/
|
|
||||||
private String status;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 排序
|
|
||||||
*/
|
|
||||||
private Integer sortOrder;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 知识库范围配置["report","file","risk","ticket"]
|
|
||||||
*/
|
|
||||||
@TableField(typeHandler = JsonbTypeHandler.class)
|
|
||||||
private List<String> kbScope;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 扩展数据
|
|
||||||
*/
|
|
||||||
@TableField(typeHandler = JsonbTypeHandler.class)
|
|
||||||
private Object extraData;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 创建时间
|
|
||||||
*/
|
|
||||||
@TableField(fill = FieldFill.INSERT)
|
|
||||||
private LocalDateTime createTime;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 更新时间
|
|
||||||
*/
|
|
||||||
@TableField(fill = FieldFill.INSERT_UPDATE)
|
|
||||||
private LocalDateTime updateTime;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 删除标记
|
|
||||||
*/
|
|
||||||
@TableLogic
|
|
||||||
private Integer deleted;
|
|
||||||
}
|
|
||||||
@@ -30,16 +30,6 @@ public class ChatSessionVO {
|
|||||||
*/
|
*/
|
||||||
private String projectName;
|
private String projectName;
|
||||||
|
|
||||||
/**
|
|
||||||
* 时间节点ID
|
|
||||||
*/
|
|
||||||
private Long timelineNodeId;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 时间节点名称
|
|
||||||
*/
|
|
||||||
private String timelineNodeName;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 最后消息时间
|
* 最后消息时间
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -586,55 +586,4 @@ public class ProjectDetailVO {
|
|||||||
*/
|
*/
|
||||||
private LocalDateTime discoverTime;
|
private LocalDateTime discoverTime;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 时间节点信息
|
|
||||||
*/
|
|
||||||
@Data
|
|
||||||
public static class TimelineInfo {
|
|
||||||
/**
|
|
||||||
* 节点ID
|
|
||||||
*/
|
|
||||||
private Long id;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 节点名称
|
|
||||||
*/
|
|
||||||
private String nodeName;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 节点类型
|
|
||||||
*/
|
|
||||||
private String nodeType;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 计划日期
|
|
||||||
*/
|
|
||||||
private LocalDate planDate;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 实际日期
|
|
||||||
*/
|
|
||||||
private LocalDate actualDate;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 描述
|
|
||||||
*/
|
|
||||||
private String description;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 状态
|
|
||||||
*/
|
|
||||||
private String status;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 排序
|
|
||||||
*/
|
|
||||||
private Integer sortOrder;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 知识库范围
|
|
||||||
*/
|
|
||||||
private List<String> kbScope;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,13 +26,11 @@ public interface AiChatService {
|
|||||||
*
|
*
|
||||||
* @param userId 用户ID
|
* @param userId 用户ID
|
||||||
* @param projectId 项目ID
|
* @param projectId 项目ID
|
||||||
* @param timelineNodeId 时间节点ID(可选)
|
|
||||||
* @param firstMessage 首条消息(用于生成标题)
|
* @param firstMessage 首条消息(用于生成标题)
|
||||||
* @param customTitle 自定义标题(可选)
|
* @param customTitle 自定义标题(可选)
|
||||||
* @return 会话信息
|
* @return 会话信息
|
||||||
*/
|
*/
|
||||||
ChatSessionVO createSession(Long userId, Long projectId, Long timelineNodeId,
|
ChatSessionVO createSession(Long userId, Long projectId, String firstMessage, String customTitle);
|
||||||
String firstMessage, String customTitle);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取用户会话列表
|
* 获取用户会话列表
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|||||||
if (isNewSession) {
|
if (isNewSession) {
|
||||||
sessionId = UUID.randomUUID().toString();
|
sessionId = UUID.randomUUID().toString();
|
||||||
String title = generateSessionTitle(request.getMessage());
|
String title = generateSessionTitle(request.getMessage());
|
||||||
createSession(userId, request.getProjectId(), request.getTimelineNodeId(), request.getMessage(), title);
|
createSession(userId, request.getProjectId(), request.getMessage(), title);
|
||||||
} else {
|
} else {
|
||||||
// 验证会话权限
|
// 验证会话权限
|
||||||
if (!hasSessionAccess(sessionId, userId)) {
|
if (!hasSessionAccess(sessionId, userId)) {
|
||||||
@@ -93,8 +93,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|||||||
));
|
));
|
||||||
|
|
||||||
// 2. 保存用户消息
|
// 2. 保存用户消息
|
||||||
saveMessage(finalSessionId, userId, request.getProjectId(),
|
saveMessage(finalSessionId, userId, request.getProjectId(), "user", request.getMessage(), null);
|
||||||
request.getTimelineNodeId(), "user", request.getMessage(), null);
|
|
||||||
|
|
||||||
// 3. RAG检索
|
// 3. RAG检索
|
||||||
List<Document> retrievedDocs = performRetrieval(request);
|
List<Document> retrievedDocs = performRetrieval(request);
|
||||||
@@ -140,7 +139,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|||||||
.filter(id -> !id.isEmpty())
|
.filter(id -> !id.isEmpty())
|
||||||
.toArray(String[]::new);
|
.toArray(String[]::new);
|
||||||
Long messageId = saveMessage(finalSessionId, userId, request.getProjectId(),
|
Long messageId = saveMessage(finalSessionId, userId, request.getProjectId(),
|
||||||
request.getTimelineNodeId(), "assistant",
|
"assistant",
|
||||||
fullResponse.toString(), docIds);
|
fullResponse.toString(), docIds);
|
||||||
|
|
||||||
// 发送完成消息
|
// 发送完成消息
|
||||||
@@ -166,8 +165,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ChatSessionVO createSession(Long userId, Long projectId, Long timelineNodeId,
|
public ChatSessionVO createSession(Long userId, Long projectId, String firstMessage, String customTitle) {
|
||||||
String firstMessage, String customTitle) {
|
|
||||||
UUID sessionId = UUID.randomUUID();
|
UUID sessionId = UUID.randomUUID();
|
||||||
String sessionIdStr = sessionId.toString();
|
String sessionIdStr = sessionId.toString();
|
||||||
String title = customTitle;
|
String title = customTitle;
|
||||||
@@ -182,7 +180,6 @@ public class AiChatServiceImpl implements AiChatService {
|
|||||||
message.setSessionTitle(title);
|
message.setSessionTitle(title);
|
||||||
message.setUserId(userId);
|
message.setUserId(userId);
|
||||||
message.setProjectId(projectId);
|
message.setProjectId(projectId);
|
||||||
message.setTimelineNodeId(timelineNodeId);
|
|
||||||
message.setRole("system");
|
message.setRole("system");
|
||||||
message.setContent("会话创建");
|
message.setContent("会话创建");
|
||||||
message.setMessageIndex(0);
|
message.setMessageIndex(0);
|
||||||
@@ -200,7 +197,6 @@ public class AiChatServiceImpl implements AiChatService {
|
|||||||
vo.setProjectName(project.getProjectName());
|
vo.setProjectName(project.getProjectName());
|
||||||
}
|
}
|
||||||
|
|
||||||
vo.setTimelineNodeId(timelineNodeId);
|
|
||||||
vo.setMessageCount(1);
|
vo.setMessageCount(1);
|
||||||
vo.setCreateTime(LocalDateTime.now());
|
vo.setCreateTime(LocalDateTime.now());
|
||||||
|
|
||||||
@@ -258,15 +254,6 @@ public class AiChatServiceImpl implements AiChatService {
|
|||||||
if (!Boolean.TRUE.equals(request.getUseRag()) && !Boolean.TRUE.equals(request.getUseTextToSql())) {
|
if (!Boolean.TRUE.equals(request.getUseRag()) && !Boolean.TRUE.equals(request.getUseTextToSql())) {
|
||||||
return Collections.emptyList();
|
return Collections.emptyList();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (request.getTimelineNodeId() != null) {
|
|
||||||
return ragRetriever.vectorSearchWithTimeline(
|
|
||||||
request.getMessage(),
|
|
||||||
request.getProjectId(),
|
|
||||||
request.getTimelineNodeId(),
|
|
||||||
5
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
return ragRetriever.hybridSearch(
|
return ragRetriever.hybridSearch(
|
||||||
request.getMessage(),
|
request.getMessage(),
|
||||||
request.getProjectId(),
|
request.getProjectId(),
|
||||||
@@ -275,7 +262,6 @@ public class AiChatServiceImpl implements AiChatService {
|
|||||||
5
|
5
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 构建系统Prompt
|
* 构建系统Prompt
|
||||||
@@ -338,8 +324,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|||||||
* 保存消息
|
* 保存消息
|
||||||
*/
|
*/
|
||||||
private Long saveMessage(String sessionId, Long userId, Long projectId,
|
private Long saveMessage(String sessionId, Long userId, Long projectId,
|
||||||
Long timelineNodeId, String role, String content,
|
String role, String content, String[] referencedDocIds) {
|
||||||
String[] referencedDocIds) {
|
|
||||||
// 获取当前最大序号
|
// 获取当前最大序号
|
||||||
Integer maxIndex = chatHistoryMapper.selectMaxMessageIndex(sessionId);
|
Integer maxIndex = chatHistoryMapper.selectMaxMessageIndex(sessionId);
|
||||||
int nextIndex = (maxIndex != null ? maxIndex : 0) + 1;
|
int nextIndex = (maxIndex != null ? maxIndex : 0) + 1;
|
||||||
@@ -348,7 +333,6 @@ public class AiChatServiceImpl implements AiChatService {
|
|||||||
message.setSessionId(sessionId);
|
message.setSessionId(sessionId);
|
||||||
message.setUserId(userId);
|
message.setUserId(userId);
|
||||||
message.setProjectId(projectId);
|
message.setProjectId(projectId);
|
||||||
message.setTimelineNodeId(timelineNodeId);
|
|
||||||
message.setRole(role);
|
message.setRole(role);
|
||||||
message.setContent(content);
|
message.setContent(content);
|
||||||
message.setReferencedDocIds(referencedDocIds);
|
message.setReferencedDocIds(referencedDocIds);
|
||||||
|
|||||||
@@ -183,7 +183,6 @@ public class DocumentProcessor {
|
|||||||
Map<String, Object> metadata = new HashMap<>();
|
Map<String, Object> metadata = new HashMap<>();
|
||||||
// 项目关联属性(用于检索过滤)
|
// 项目关联属性(用于检索过滤)
|
||||||
metadata.put("project_id", doc.getProjectId() != null ? doc.getProjectId().toString() : "");
|
metadata.put("project_id", doc.getProjectId() != null ? doc.getProjectId().toString() : "");
|
||||||
metadata.put("timeline_node_id", doc.getTimelineNodeId() != null ? doc.getTimelineNodeId().toString() : "");
|
|
||||||
metadata.put("kb_id", doc.getKbId() != null ? doc.getKbId().toString() : "");
|
metadata.put("kb_id", doc.getKbId() != null ? doc.getKbId().toString() : "");
|
||||||
// 文档来源信息
|
// 文档来源信息
|
||||||
metadata.put("source_type", doc.getSourceType() != null ? doc.getSourceType() : "");
|
metadata.put("source_type", doc.getSourceType() != null ? doc.getSourceType() : "");
|
||||||
|
|||||||
@@ -57,37 +57,6 @@ public class RagRetriever {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 向量检索(带时间节点过滤)
|
|
||||||
*
|
|
||||||
* @param query 查询文本
|
|
||||||
* @param projectId 项目ID
|
|
||||||
* @param timelineNodeId 时间节点ID
|
|
||||||
* @param topK 返回数量
|
|
||||||
* @return 文档列表
|
|
||||||
*/
|
|
||||||
public List<Document> vectorSearchWithTimeline(String query, Long projectId,
|
|
||||||
Long timelineNodeId, int topK) {
|
|
||||||
try {
|
|
||||||
// project_id 和 timeline_node_id 在 metadata 中是字符串类型
|
|
||||||
SearchRequest searchRequest = SearchRequest.builder()
|
|
||||||
.query(query)
|
|
||||||
.topK(topK)
|
|
||||||
.filterExpression("project_id == '" + projectId +
|
|
||||||
"' && timeline_node_id == '" + timelineNodeId +
|
|
||||||
"' && status == 'active'")
|
|
||||||
.build();
|
|
||||||
|
|
||||||
List<Document> results = vectorStore.similaritySearch(searchRequest);
|
|
||||||
log.debug("向量检索完成(时间维度),项目ID: {}, 节点ID: {}, 结果数: {}",
|
|
||||||
projectId, timelineNodeId, results.size());
|
|
||||||
return results;
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("向量检索失败: {}", e.getMessage(), e);
|
|
||||||
return Collections.emptyList();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* TextToSQL检索
|
* TextToSQL检索
|
||||||
*
|
*
|
||||||
|
|||||||
@@ -9,19 +9,16 @@
|
|||||||
MAX(ach.session_title) as sessionTitle,
|
MAX(ach.session_title) as sessionTitle,
|
||||||
ach.project_id as projectId,
|
ach.project_id as projectId,
|
||||||
p.project_name as projectName,
|
p.project_name as projectName,
|
||||||
MAX(ach.timeline_node_id) as timelineNodeId,
|
|
||||||
pt.node_name as timelineNodeName,
|
|
||||||
MAX(ach.create_time) as lastMessageTime,
|
MAX(ach.create_time) as lastMessageTime,
|
||||||
COUNT(*) as messageCount,
|
COUNT(*) as messageCount,
|
||||||
MIN(ach.create_time) as createTime
|
MIN(ach.create_time) as createTime
|
||||||
FROM ai_chat_history ach
|
FROM ai_chat_history ach
|
||||||
LEFT JOIN project p ON ach.project_id = p.id AND p.deleted = 0
|
LEFT JOIN project p ON ach.project_id = p.id AND p.deleted = 0
|
||||||
LEFT JOIN project_timeline pt ON ach.timeline_node_id = pt.id AND pt.deleted = 0
|
|
||||||
WHERE ach.user_id = #{userId}
|
WHERE ach.user_id = #{userId}
|
||||||
<if test="projectId != null">
|
<if test="projectId != null">
|
||||||
AND ach.project_id = #{projectId}
|
AND ach.project_id = #{projectId}
|
||||||
</if>
|
</if>
|
||||||
GROUP BY ach.session_id, ach.project_id, p.project_name, pt.node_name
|
GROUP BY ach.session_id, ach.project_id, p.project_name
|
||||||
ORDER BY lastMessageTime DESC
|
ORDER BY lastMessageTime DESC
|
||||||
</select>
|
</select>
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user