fix(ai-chat): 优化引用文档ID处理支持字符串数组类型

- 将数据库中 referenced_doc_ids 字段从 BIGINT[] 修改为 VARCHAR(255)[]
- 在实体类 AiChatMessage 中将 referencedDocIds 类型改为 String[] 并添加自定义类型处理器
- 新增 PostgresArrayTypeHandler 用于处理 PostgreSQL varchar 数组与 Java String[] 的映射
- 修改查询时 project_id 和 timeline_node_id 的过滤表达式,使用字符串匹配避免类型错误
- AiChatServiceImpl 中保存消息时改用字符串数组保存引用文档ID
- KbDocumentVO 新增 fileUrl 字段映射数据库中对应字段
- 数据库映射文件 AiDocumentMapper.xml 增加 file_url 字段映射
This commit is contained in:
2026-03-30 18:59:52 +08:00
parent 0b011bf1a7
commit 8008d367e8
7 changed files with 81 additions and 9 deletions

View File

@@ -1106,7 +1106,7 @@ CREATE TABLE ai_chat_history (
content_embedding vector(1024), content_embedding vector(1024),
-- 引用的知识库文档 -- 引用的知识库文档
referenced_doc_ids BIGINT[], referenced_doc_ids VARCHAR(255)[],
-- 上下文配置 -- 上下文配置
system_prompt TEXT, system_prompt TEXT,

View File

@@ -0,0 +1,57 @@
package cn.yinlihupo.common.handler;
import org.apache.ibatis.type.BaseTypeHandler;
import org.apache.ibatis.type.JdbcType;
import java.sql.Array;
import java.sql.CallableStatement;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
/**
* PostgreSQL 数组类型处理器
* 用于处理 Java String[] 与 PostgreSQL varchar[] 类型之间的转换
*/
public class PostgresArrayTypeHandler extends BaseTypeHandler<String[]> {
@Override
public void setNonNullParameter(PreparedStatement ps, int i, String[] parameter, JdbcType jdbcType) throws SQLException {
// 使用 PostgreSQL 的 createArrayOf 方法创建数组
Array array = ps.getConnection().createArrayOf("varchar", parameter);
ps.setArray(i, array);
}
@Override
public String[] getNullableResult(ResultSet rs, String columnName) throws SQLException {
Array array = rs.getArray(columnName);
return extractArray(array);
}
@Override
public String[] getNullableResult(ResultSet rs, int columnIndex) throws SQLException {
Array array = rs.getArray(columnIndex);
return extractArray(array);
}
@Override
public String[] getNullableResult(CallableStatement cs, int columnIndex) throws SQLException {
Array array = cs.getArray(columnIndex);
return extractArray(array);
}
private String[] extractArray(Array array) throws SQLException {
if (array == null) {
return null;
}
Object[] objArray = (Object[]) array.getArray();
if (objArray == null) {
return null;
}
String[] result = new String[objArray.length];
for (int i = 0; i < objArray.length; i++) {
result[i] = objArray[i] != null ? objArray[i].toString() : null;
}
return result;
}
}

View File

@@ -1,6 +1,7 @@
package cn.yinlihupo.domain.entity; package cn.yinlihupo.domain.entity;
import com.baomidou.mybatisplus.annotation.IdType; import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName; import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data; import lombok.Data;
@@ -59,9 +60,10 @@ public class AiChatMessage {
private String contentEmbedding; private String contentEmbedding;
/** /**
* 引用的文档ID列表(JSON数组) * 引用的文档ID列表(字符串数组)
*/ */
private String referencedDocIds; @TableField(typeHandler = cn.yinlihupo.common.handler.PostgresArrayTypeHandler.class)
private String[] referencedDocIds;
/** /**
* 系统提示词 * 系统提示词

View File

@@ -45,6 +45,11 @@ public class KbDocumentVO {
*/ */
private String filePath; private String filePath;
/**
* 文件URL
*/
private String fileUrl;
/** /**
* 来源类型 * 来源类型
*/ */

View File

@@ -134,9 +134,14 @@ public class AiChatServiceImpl implements AiChatService {
() -> { () -> {
// 保存助手消息 // 保存助手消息
int responseTime = (int) (System.currentTimeMillis() - startTime); int responseTime = (int) (System.currentTimeMillis() - startTime);
// 提取引用的文档ID列表
String[] docIds = retrievedDocs.stream()
.map(doc -> doc.getMetadata().getOrDefault("doc_id", "").toString())
.filter(id -> !id.isEmpty())
.toArray(String[]::new);
Long messageId = saveMessage(finalSessionId, userId, request.getProjectId(), Long messageId = saveMessage(finalSessionId, userId, request.getProjectId(),
request.getTimelineNodeId(), "assistant", request.getTimelineNodeId(), "assistant",
fullResponse.toString(), JSON.toJSONString(referencedDocs)); fullResponse.toString(), docIds);
// 发送完成消息 // 发送完成消息
sendEvent(emitter, "complete", Map.of( sendEvent(emitter, "complete", Map.of(
@@ -334,7 +339,7 @@ public class AiChatServiceImpl implements AiChatService {
*/ */
private Long saveMessage(String sessionId, Long userId, Long projectId, private Long saveMessage(String sessionId, Long userId, Long projectId,
Long timelineNodeId, String role, String content, Long timelineNodeId, String role, String content,
String referencedDocIds) { String[] referencedDocIds) {
// 获取当前最大序号 // 获取当前最大序号
Integer maxIndex = chatHistoryMapper.selectMaxMessageIndex(sessionId); Integer maxIndex = chatHistoryMapper.selectMaxMessageIndex(sessionId);
int nextIndex = (maxIndex != null ? maxIndex : 0) + 1; int nextIndex = (maxIndex != null ? maxIndex : 0) + 1;

View File

@@ -41,10 +41,11 @@ public class RagRetriever {
*/ */
public List<Document> vectorSearch(String query, Long projectId, int topK) { public List<Document> vectorSearch(String query, Long projectId, int topK) {
try { try {
// project_id 在 metadata 中是字符串类型,需要用字符串比较
SearchRequest searchRequest = SearchRequest.builder() SearchRequest searchRequest = SearchRequest.builder()
.query(query) .query(query)
.topK(topK) .topK(topK)
.filterExpression("project_id == " + projectId + " && status == 'active'") .filterExpression("project_id == '" + projectId + "' && status == 'active'")
.build(); .build();
List<Document> results = vectorStore.similaritySearch(searchRequest); List<Document> results = vectorStore.similaritySearch(searchRequest);
@@ -68,12 +69,13 @@ public class RagRetriever {
public List<Document> vectorSearchWithTimeline(String query, Long projectId, public List<Document> vectorSearchWithTimeline(String query, Long projectId,
Long timelineNodeId, int topK) { Long timelineNodeId, int topK) {
try { try {
// project_id 和 timeline_node_id 在 metadata 中是字符串类型
SearchRequest searchRequest = SearchRequest.builder() SearchRequest searchRequest = SearchRequest.builder()
.query(query) .query(query)
.topK(topK) .topK(topK)
.filterExpression("project_id == " + projectId + .filterExpression("project_id == '" + projectId +
" && timeline_node_id == " + timelineNodeId + "' && timeline_node_id == '" + timelineNodeId +
" && status == 'active'") "' && status == 'active'")
.build(); .build();
List<Document> results = vectorStore.similaritySearch(searchRequest); List<Document> results = vectorStore.similaritySearch(searchRequest);

View File

@@ -11,6 +11,7 @@
vs.file_type as fileType, vs.file_type as fileType,
vs.file_size as fileSize, vs.file_size as fileSize,
vs.file_path as filePath, vs.file_path as filePath,
vs.file_url as fileUrl,
vs.source_type as sourceType, vs.source_type as sourceType,
vs.chunk_total as chunkCount, vs.chunk_total as chunkCount,
vs.status, vs.status,