fix(ai-chat): 优化引用文档ID处理支持字符串数组类型
- 将数据库中 referenced_doc_ids 字段从 BIGINT[] 修改为 VARCHAR(255)[] - 在实体类 AiChatMessage 中将 referencedDocIds 类型改为 String[] 并添加自定义类型处理器 - 新增 PostgresArrayTypeHandler 用于处理 PostgreSQL varchar 数组与 Java String[] 的映射 - 修改查询时 project_id 和 timeline_node_id 的过滤表达式,使用字符串匹配避免类型错误 - AiChatServiceImpl 中保存消息时改用字符串数组保存引用文档ID - KbDocumentVO 新增 fileUrl 字段映射数据库中对应字段 - 数据库映射文件 AiDocumentMapper.xml 增加 file_url 字段映射
This commit is contained in:
@@ -1106,7 +1106,7 @@ CREATE TABLE ai_chat_history (
|
|||||||
content_embedding vector(1024),
|
content_embedding vector(1024),
|
||||||
|
|
||||||
-- 引用的知识库文档
|
-- 引用的知识库文档
|
||||||
referenced_doc_ids BIGINT[],
|
referenced_doc_ids VARCHAR(255)[],
|
||||||
|
|
||||||
-- 上下文配置
|
-- 上下文配置
|
||||||
system_prompt TEXT,
|
system_prompt TEXT,
|
||||||
|
|||||||
@@ -0,0 +1,57 @@
|
|||||||
|
package cn.yinlihupo.common.handler;
|
||||||
|
|
||||||
|
import org.apache.ibatis.type.BaseTypeHandler;
|
||||||
|
import org.apache.ibatis.type.JdbcType;
|
||||||
|
|
||||||
|
import java.sql.Array;
|
||||||
|
import java.sql.CallableStatement;
|
||||||
|
import java.sql.PreparedStatement;
|
||||||
|
import java.sql.ResultSet;
|
||||||
|
import java.sql.SQLException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* PostgreSQL 数组类型处理器
|
||||||
|
* 用于处理 Java String[] 与 PostgreSQL varchar[] 类型之间的转换
|
||||||
|
*/
|
||||||
|
public class PostgresArrayTypeHandler extends BaseTypeHandler<String[]> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setNonNullParameter(PreparedStatement ps, int i, String[] parameter, JdbcType jdbcType) throws SQLException {
|
||||||
|
// 使用 PostgreSQL 的 createArrayOf 方法创建数组
|
||||||
|
Array array = ps.getConnection().createArrayOf("varchar", parameter);
|
||||||
|
ps.setArray(i, array);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String[] getNullableResult(ResultSet rs, String columnName) throws SQLException {
|
||||||
|
Array array = rs.getArray(columnName);
|
||||||
|
return extractArray(array);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String[] getNullableResult(ResultSet rs, int columnIndex) throws SQLException {
|
||||||
|
Array array = rs.getArray(columnIndex);
|
||||||
|
return extractArray(array);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String[] getNullableResult(CallableStatement cs, int columnIndex) throws SQLException {
|
||||||
|
Array array = cs.getArray(columnIndex);
|
||||||
|
return extractArray(array);
|
||||||
|
}
|
||||||
|
|
||||||
|
private String[] extractArray(Array array) throws SQLException {
|
||||||
|
if (array == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
Object[] objArray = (Object[]) array.getArray();
|
||||||
|
if (objArray == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
String[] result = new String[objArray.length];
|
||||||
|
for (int i = 0; i < objArray.length; i++) {
|
||||||
|
result[i] = objArray[i] != null ? objArray[i].toString() : null;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package cn.yinlihupo.domain.entity;
|
package cn.yinlihupo.domain.entity;
|
||||||
|
|
||||||
import com.baomidou.mybatisplus.annotation.IdType;
|
import com.baomidou.mybatisplus.annotation.IdType;
|
||||||
|
import com.baomidou.mybatisplus.annotation.TableField;
|
||||||
import com.baomidou.mybatisplus.annotation.TableId;
|
import com.baomidou.mybatisplus.annotation.TableId;
|
||||||
import com.baomidou.mybatisplus.annotation.TableName;
|
import com.baomidou.mybatisplus.annotation.TableName;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -59,9 +60,10 @@ public class AiChatMessage {
|
|||||||
private String contentEmbedding;
|
private String contentEmbedding;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 引用的文档ID列表(JSON数组)
|
* 引用的文档ID列表(字符串数组)
|
||||||
*/
|
*/
|
||||||
private String referencedDocIds;
|
@TableField(typeHandler = cn.yinlihupo.common.handler.PostgresArrayTypeHandler.class)
|
||||||
|
private String[] referencedDocIds;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 系统提示词
|
* 系统提示词
|
||||||
|
|||||||
@@ -45,6 +45,11 @@ public class KbDocumentVO {
|
|||||||
*/
|
*/
|
||||||
private String filePath;
|
private String filePath;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 文件URL
|
||||||
|
*/
|
||||||
|
private String fileUrl;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 来源类型
|
* 来源类型
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -134,9 +134,14 @@ public class AiChatServiceImpl implements AiChatService {
|
|||||||
() -> {
|
() -> {
|
||||||
// 保存助手消息
|
// 保存助手消息
|
||||||
int responseTime = (int) (System.currentTimeMillis() - startTime);
|
int responseTime = (int) (System.currentTimeMillis() - startTime);
|
||||||
|
// 提取引用的文档ID列表
|
||||||
|
String[] docIds = retrievedDocs.stream()
|
||||||
|
.map(doc -> doc.getMetadata().getOrDefault("doc_id", "").toString())
|
||||||
|
.filter(id -> !id.isEmpty())
|
||||||
|
.toArray(String[]::new);
|
||||||
Long messageId = saveMessage(finalSessionId, userId, request.getProjectId(),
|
Long messageId = saveMessage(finalSessionId, userId, request.getProjectId(),
|
||||||
request.getTimelineNodeId(), "assistant",
|
request.getTimelineNodeId(), "assistant",
|
||||||
fullResponse.toString(), JSON.toJSONString(referencedDocs));
|
fullResponse.toString(), docIds);
|
||||||
|
|
||||||
// 发送完成消息
|
// 发送完成消息
|
||||||
sendEvent(emitter, "complete", Map.of(
|
sendEvent(emitter, "complete", Map.of(
|
||||||
@@ -334,7 +339,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|||||||
*/
|
*/
|
||||||
private Long saveMessage(String sessionId, Long userId, Long projectId,
|
private Long saveMessage(String sessionId, Long userId, Long projectId,
|
||||||
Long timelineNodeId, String role, String content,
|
Long timelineNodeId, String role, String content,
|
||||||
String referencedDocIds) {
|
String[] referencedDocIds) {
|
||||||
// 获取当前最大序号
|
// 获取当前最大序号
|
||||||
Integer maxIndex = chatHistoryMapper.selectMaxMessageIndex(sessionId);
|
Integer maxIndex = chatHistoryMapper.selectMaxMessageIndex(sessionId);
|
||||||
int nextIndex = (maxIndex != null ? maxIndex : 0) + 1;
|
int nextIndex = (maxIndex != null ? maxIndex : 0) + 1;
|
||||||
|
|||||||
@@ -41,10 +41,11 @@ public class RagRetriever {
|
|||||||
*/
|
*/
|
||||||
public List<Document> vectorSearch(String query, Long projectId, int topK) {
|
public List<Document> vectorSearch(String query, Long projectId, int topK) {
|
||||||
try {
|
try {
|
||||||
|
// project_id 在 metadata 中是字符串类型,需要用字符串比较
|
||||||
SearchRequest searchRequest = SearchRequest.builder()
|
SearchRequest searchRequest = SearchRequest.builder()
|
||||||
.query(query)
|
.query(query)
|
||||||
.topK(topK)
|
.topK(topK)
|
||||||
.filterExpression("project_id == " + projectId + " && status == 'active'")
|
.filterExpression("project_id == '" + projectId + "' && status == 'active'")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
List<Document> results = vectorStore.similaritySearch(searchRequest);
|
List<Document> results = vectorStore.similaritySearch(searchRequest);
|
||||||
@@ -68,12 +69,13 @@ public class RagRetriever {
|
|||||||
public List<Document> vectorSearchWithTimeline(String query, Long projectId,
|
public List<Document> vectorSearchWithTimeline(String query, Long projectId,
|
||||||
Long timelineNodeId, int topK) {
|
Long timelineNodeId, int topK) {
|
||||||
try {
|
try {
|
||||||
|
// project_id 和 timeline_node_id 在 metadata 中是字符串类型
|
||||||
SearchRequest searchRequest = SearchRequest.builder()
|
SearchRequest searchRequest = SearchRequest.builder()
|
||||||
.query(query)
|
.query(query)
|
||||||
.topK(topK)
|
.topK(topK)
|
||||||
.filterExpression("project_id == " + projectId +
|
.filterExpression("project_id == '" + projectId +
|
||||||
" && timeline_node_id == " + timelineNodeId +
|
"' && timeline_node_id == '" + timelineNodeId +
|
||||||
" && status == 'active'")
|
"' && status == 'active'")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
List<Document> results = vectorStore.similaritySearch(searchRequest);
|
List<Document> results = vectorStore.similaritySearch(searchRequest);
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
vs.file_type as fileType,
|
vs.file_type as fileType,
|
||||||
vs.file_size as fileSize,
|
vs.file_size as fileSize,
|
||||||
vs.file_path as filePath,
|
vs.file_path as filePath,
|
||||||
|
vs.file_url as fileUrl,
|
||||||
vs.source_type as sourceType,
|
vs.source_type as sourceType,
|
||||||
vs.chunk_total as chunkCount,
|
vs.chunk_total as chunkCount,
|
||||||
vs.status,
|
vs.status,
|
||||||
|
|||||||
Reference in New Issue
Block a user