From 8008d367e8d520b35ce71164f90008337bb08d59 Mon Sep 17 00:00:00 2001 From: JiaoTianBo Date: Mon, 30 Mar 2026 18:59:52 +0800 Subject: [PATCH] =?UTF-8?q?fix(ai-chat):=20=E4=BC=98=E5=8C=96=E5=BC=95?= =?UTF-8?q?=E7=94=A8=E6=96=87=E6=A1=A3ID=E5=A4=84=E7=90=86=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E5=AD=97=E7=AC=A6=E4=B8=B2=E6=95=B0=E7=BB=84=E7=B1=BB?= =?UTF-8?q?=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将数据库中 referenced_doc_ids 字段从 BIGINT[] 修改为 VARCHAR(255)[] - 在实体类 AiChatMessage 中将 referencedDocIds 类型改为 String[] 并添加自定义类型处理器 - 新增 PostgresArrayTypeHandler 用于处理 PostgreSQL varchar 数组与 Java String[] 的映射 - 修改查询时 project_id 和 timeline_node_id 的过滤表达式,使用字符串匹配避免类型错误 - AiChatServiceImpl 中保存消息时改用字符串数组保存引用文档ID - KbDocumentVO 新增 fileUrl 字段映射数据库中对应字段 - 数据库映射文件 AiDocumentMapper.xml 增加 file_url 字段映射 --- docs/dev-ops/pgsql/sql/weform_run.sql | 2 +- .../handler/PostgresArrayTypeHandler.java | 57 +++++++++++++++++++ .../domain/entity/AiChatMessage.java | 6 +- .../cn/yinlihupo/domain/vo/KbDocumentVO.java | 5 ++ .../service/ai/impl/AiChatServiceImpl.java | 9 ++- .../service/ai/rag/RagRetriever.java | 10 ++-- .../resources/mapper/AiDocumentMapper.xml | 1 + 7 files changed, 81 insertions(+), 9 deletions(-) create mode 100644 src/main/java/cn/yinlihupo/common/handler/PostgresArrayTypeHandler.java diff --git a/docs/dev-ops/pgsql/sql/weform_run.sql b/docs/dev-ops/pgsql/sql/weform_run.sql index 5d54dfa..c0b2e57 100644 --- a/docs/dev-ops/pgsql/sql/weform_run.sql +++ b/docs/dev-ops/pgsql/sql/weform_run.sql @@ -1106,7 +1106,7 @@ CREATE TABLE ai_chat_history ( content_embedding vector(1024), -- 引用的知识库文档 - referenced_doc_ids BIGINT[], + referenced_doc_ids VARCHAR(255)[], -- 上下文配置 system_prompt TEXT, diff --git a/src/main/java/cn/yinlihupo/common/handler/PostgresArrayTypeHandler.java b/src/main/java/cn/yinlihupo/common/handler/PostgresArrayTypeHandler.java new file mode 100644 index 0000000..166a9d5 --- /dev/null +++ b/src/main/java/cn/yinlihupo/common/handler/PostgresArrayTypeHandler.java @@ -0,0 +1,57 @@ +package cn.yinlihupo.common.handler; + +import org.apache.ibatis.type.BaseTypeHandler; +import org.apache.ibatis.type.JdbcType; + +import java.sql.Array; +import java.sql.CallableStatement; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +/** + * PostgreSQL 数组类型处理器 + * 用于处理 Java String[] 与 PostgreSQL varchar[] 类型之间的转换 + */ +public class PostgresArrayTypeHandler extends BaseTypeHandler { + + @Override + public void setNonNullParameter(PreparedStatement ps, int i, String[] parameter, JdbcType jdbcType) throws SQLException { + // 使用 PostgreSQL 的 createArrayOf 方法创建数组 + Array array = ps.getConnection().createArrayOf("varchar", parameter); + ps.setArray(i, array); + } + + @Override + public String[] getNullableResult(ResultSet rs, String columnName) throws SQLException { + Array array = rs.getArray(columnName); + return extractArray(array); + } + + @Override + public String[] getNullableResult(ResultSet rs, int columnIndex) throws SQLException { + Array array = rs.getArray(columnIndex); + return extractArray(array); + } + + @Override + public String[] getNullableResult(CallableStatement cs, int columnIndex) throws SQLException { + Array array = cs.getArray(columnIndex); + return extractArray(array); + } + + private String[] extractArray(Array array) throws SQLException { + if (array == null) { + return null; + } + Object[] objArray = (Object[]) array.getArray(); + if (objArray == null) { + return null; + } + String[] result = new String[objArray.length]; + for (int i = 0; i < objArray.length; i++) { + result[i] = objArray[i] != null ? objArray[i].toString() : null; + } + return result; + } +} diff --git a/src/main/java/cn/yinlihupo/domain/entity/AiChatMessage.java b/src/main/java/cn/yinlihupo/domain/entity/AiChatMessage.java index 3779adf..1534691 100644 --- a/src/main/java/cn/yinlihupo/domain/entity/AiChatMessage.java +++ b/src/main/java/cn/yinlihupo/domain/entity/AiChatMessage.java @@ -1,6 +1,7 @@ package cn.yinlihupo.domain.entity; import com.baomidou.mybatisplus.annotation.IdType; +import com.baomidou.mybatisplus.annotation.TableField; import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableName; import lombok.Data; @@ -59,9 +60,10 @@ public class AiChatMessage { private String contentEmbedding; /** - * 引用的文档ID列表(JSON数组) + * 引用的文档ID列表(字符串数组) */ - private String referencedDocIds; + @TableField(typeHandler = cn.yinlihupo.common.handler.PostgresArrayTypeHandler.class) + private String[] referencedDocIds; /** * 系统提示词 diff --git a/src/main/java/cn/yinlihupo/domain/vo/KbDocumentVO.java b/src/main/java/cn/yinlihupo/domain/vo/KbDocumentVO.java index 1913d95..75854df 100644 --- a/src/main/java/cn/yinlihupo/domain/vo/KbDocumentVO.java +++ b/src/main/java/cn/yinlihupo/domain/vo/KbDocumentVO.java @@ -45,6 +45,11 @@ public class KbDocumentVO { */ private String filePath; + /** + * 文件URL + */ + private String fileUrl; + /** * 来源类型 */ diff --git a/src/main/java/cn/yinlihupo/service/ai/impl/AiChatServiceImpl.java b/src/main/java/cn/yinlihupo/service/ai/impl/AiChatServiceImpl.java index dd81ac1..fb0a769 100644 --- a/src/main/java/cn/yinlihupo/service/ai/impl/AiChatServiceImpl.java +++ b/src/main/java/cn/yinlihupo/service/ai/impl/AiChatServiceImpl.java @@ -134,9 +134,14 @@ public class AiChatServiceImpl implements AiChatService { () -> { // 保存助手消息 int responseTime = (int) (System.currentTimeMillis() - startTime); + // 提取引用的文档ID列表 + String[] docIds = retrievedDocs.stream() + .map(doc -> doc.getMetadata().getOrDefault("doc_id", "").toString()) + .filter(id -> !id.isEmpty()) + .toArray(String[]::new); Long messageId = saveMessage(finalSessionId, userId, request.getProjectId(), request.getTimelineNodeId(), "assistant", - fullResponse.toString(), JSON.toJSONString(referencedDocs)); + fullResponse.toString(), docIds); // 发送完成消息 sendEvent(emitter, "complete", Map.of( @@ -334,7 +339,7 @@ public class AiChatServiceImpl implements AiChatService { */ private Long saveMessage(String sessionId, Long userId, Long projectId, Long timelineNodeId, String role, String content, - String referencedDocIds) { + String[] referencedDocIds) { // 获取当前最大序号 Integer maxIndex = chatHistoryMapper.selectMaxMessageIndex(sessionId); int nextIndex = (maxIndex != null ? maxIndex : 0) + 1; diff --git a/src/main/java/cn/yinlihupo/service/ai/rag/RagRetriever.java b/src/main/java/cn/yinlihupo/service/ai/rag/RagRetriever.java index b936e41..ec77444 100644 --- a/src/main/java/cn/yinlihupo/service/ai/rag/RagRetriever.java +++ b/src/main/java/cn/yinlihupo/service/ai/rag/RagRetriever.java @@ -41,10 +41,11 @@ public class RagRetriever { */ public List vectorSearch(String query, Long projectId, int topK) { try { + // project_id 在 metadata 中是字符串类型,需要用字符串比较 SearchRequest searchRequest = SearchRequest.builder() .query(query) .topK(topK) - .filterExpression("project_id == " + projectId + " && status == 'active'") + .filterExpression("project_id == '" + projectId + "' && status == 'active'") .build(); List results = vectorStore.similaritySearch(searchRequest); @@ -68,12 +69,13 @@ public class RagRetriever { public List vectorSearchWithTimeline(String query, Long projectId, Long timelineNodeId, int topK) { try { + // project_id 和 timeline_node_id 在 metadata 中是字符串类型 SearchRequest searchRequest = SearchRequest.builder() .query(query) .topK(topK) - .filterExpression("project_id == " + projectId + - " && timeline_node_id == " + timelineNodeId + - " && status == 'active'") + .filterExpression("project_id == '" + projectId + + "' && timeline_node_id == '" + timelineNodeId + + "' && status == 'active'") .build(); List results = vectorStore.similaritySearch(searchRequest); diff --git a/src/main/resources/mapper/AiDocumentMapper.xml b/src/main/resources/mapper/AiDocumentMapper.xml index 9721354..24f241c 100644 --- a/src/main/resources/mapper/AiDocumentMapper.xml +++ b/src/main/resources/mapper/AiDocumentMapper.xml @@ -11,6 +11,7 @@ vs.file_type as fileType, vs.file_size as fileSize, vs.file_path as filePath, + vs.file_url as fileUrl, vs.source_type as sourceType, vs.chunk_total as chunkCount, vs.status,