fix(ai-chat): 优化引用文档ID处理支持字符串数组类型

- 将数据库中 referenced_doc_ids 字段从 BIGINT[] 修改为 VARCHAR(255)[]
- 在实体类 AiChatMessage 中将 referencedDocIds 类型改为 String[] 并添加自定义类型处理器
- 新增 PostgresArrayTypeHandler 用于处理 PostgreSQL varchar 数组与 Java String[] 的映射
- 修改查询时 project_id 和 timeline_node_id 的过滤表达式,使用字符串匹配避免类型错误
- AiChatServiceImpl 中保存消息时改用字符串数组保存引用文档ID
- KbDocumentVO 新增 fileUrl 字段映射数据库中对应字段
- 数据库映射文件 AiDocumentMapper.xml 增加 file_url 字段映射
This commit is contained in:
2026-03-30 18:59:52 +08:00
parent 0b011bf1a7
commit 8008d367e8
7 changed files with 81 additions and 9 deletions

View File

@@ -1106,7 +1106,7 @@ CREATE TABLE ai_chat_history (
content_embedding vector(1024),
-- 引用的知识库文档
referenced_doc_ids BIGINT[],
referenced_doc_ids VARCHAR(255)[],
-- 上下文配置
system_prompt TEXT,

View File

@@ -0,0 +1,57 @@
package cn.yinlihupo.common.handler;
import org.apache.ibatis.type.BaseTypeHandler;
import org.apache.ibatis.type.JdbcType;
import java.sql.Array;
import java.sql.CallableStatement;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
/**
* PostgreSQL 数组类型处理器
* 用于处理 Java String[] 与 PostgreSQL varchar[] 类型之间的转换
*/
public class PostgresArrayTypeHandler extends BaseTypeHandler<String[]> {
@Override
public void setNonNullParameter(PreparedStatement ps, int i, String[] parameter, JdbcType jdbcType) throws SQLException {
// 使用 PostgreSQL 的 createArrayOf 方法创建数组
Array array = ps.getConnection().createArrayOf("varchar", parameter);
ps.setArray(i, array);
}
@Override
public String[] getNullableResult(ResultSet rs, String columnName) throws SQLException {
Array array = rs.getArray(columnName);
return extractArray(array);
}
@Override
public String[] getNullableResult(ResultSet rs, int columnIndex) throws SQLException {
Array array = rs.getArray(columnIndex);
return extractArray(array);
}
@Override
public String[] getNullableResult(CallableStatement cs, int columnIndex) throws SQLException {
Array array = cs.getArray(columnIndex);
return extractArray(array);
}
private String[] extractArray(Array array) throws SQLException {
if (array == null) {
return null;
}
Object[] objArray = (Object[]) array.getArray();
if (objArray == null) {
return null;
}
String[] result = new String[objArray.length];
for (int i = 0; i < objArray.length; i++) {
result[i] = objArray[i] != null ? objArray[i].toString() : null;
}
return result;
}
}

View File

@@ -1,6 +1,7 @@
package cn.yinlihupo.domain.entity;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
@@ -59,9 +60,10 @@ public class AiChatMessage {
private String contentEmbedding;
/**
* 引用的文档ID列表(JSON数组)
* 引用的文档ID列表(字符串数组)
*/
private String referencedDocIds;
@TableField(typeHandler = cn.yinlihupo.common.handler.PostgresArrayTypeHandler.class)
private String[] referencedDocIds;
/**
* 系统提示词

View File

@@ -45,6 +45,11 @@ public class KbDocumentVO {
*/
private String filePath;
/**
* 文件URL
*/
private String fileUrl;
/**
* 来源类型
*/

View File

@@ -134,9 +134,14 @@ public class AiChatServiceImpl implements AiChatService {
() -> {
// 保存助手消息
int responseTime = (int) (System.currentTimeMillis() - startTime);
// 提取引用的文档ID列表
String[] docIds = retrievedDocs.stream()
.map(doc -> doc.getMetadata().getOrDefault("doc_id", "").toString())
.filter(id -> !id.isEmpty())
.toArray(String[]::new);
Long messageId = saveMessage(finalSessionId, userId, request.getProjectId(),
request.getTimelineNodeId(), "assistant",
fullResponse.toString(), JSON.toJSONString(referencedDocs));
fullResponse.toString(), docIds);
// 发送完成消息
sendEvent(emitter, "complete", Map.of(
@@ -334,7 +339,7 @@ public class AiChatServiceImpl implements AiChatService {
*/
private Long saveMessage(String sessionId, Long userId, Long projectId,
Long timelineNodeId, String role, String content,
String referencedDocIds) {
String[] referencedDocIds) {
// 获取当前最大序号
Integer maxIndex = chatHistoryMapper.selectMaxMessageIndex(sessionId);
int nextIndex = (maxIndex != null ? maxIndex : 0) + 1;

View File

@@ -41,10 +41,11 @@ public class RagRetriever {
*/
public List<Document> vectorSearch(String query, Long projectId, int topK) {
try {
// project_id 在 metadata 中是字符串类型,需要用字符串比较
SearchRequest searchRequest = SearchRequest.builder()
.query(query)
.topK(topK)
.filterExpression("project_id == " + projectId + " && status == 'active'")
.filterExpression("project_id == '" + projectId + "' && status == 'active'")
.build();
List<Document> results = vectorStore.similaritySearch(searchRequest);
@@ -68,12 +69,13 @@ public class RagRetriever {
public List<Document> vectorSearchWithTimeline(String query, Long projectId,
Long timelineNodeId, int topK) {
try {
// project_id 和 timeline_node_id 在 metadata 中是字符串类型
SearchRequest searchRequest = SearchRequest.builder()
.query(query)
.topK(topK)
.filterExpression("project_id == " + projectId +
" && timeline_node_id == " + timelineNodeId +
" && status == 'active'")
.filterExpression("project_id == '" + projectId +
"' && timeline_node_id == '" + timelineNodeId +
"' && status == 'active'")
.build();
List<Document> results = vectorStore.similaritySearch(searchRequest);

View File

@@ -11,6 +11,7 @@
vs.file_type as fileType,
vs.file_size as fileSize,
vs.file_path as filePath,
vs.file_url as fileUrl,
vs.source_type as sourceType,
vs.chunk_total as chunkCount,
vs.status,