fix(ai-chat): 优化引用文档ID处理支持字符串数组类型
- 将数据库中 referenced_doc_ids 字段从 BIGINT[] 修改为 VARCHAR(255)[] - 在实体类 AiChatMessage 中将 referencedDocIds 类型改为 String[] 并添加自定义类型处理器 - 新增 PostgresArrayTypeHandler 用于处理 PostgreSQL varchar 数组与 Java String[] 的映射 - 修改查询时 project_id 和 timeline_node_id 的过滤表达式,使用字符串匹配避免类型错误 - AiChatServiceImpl 中保存消息时改用字符串数组保存引用文档ID - KbDocumentVO 新增 fileUrl 字段映射数据库中对应字段 - 数据库映射文件 AiDocumentMapper.xml 增加 file_url 字段映射
This commit is contained in:
@@ -1106,7 +1106,7 @@ CREATE TABLE ai_chat_history (
|
||||
content_embedding vector(1024),
|
||||
|
||||
-- 引用的知识库文档
|
||||
referenced_doc_ids BIGINT[],
|
||||
referenced_doc_ids VARCHAR(255)[],
|
||||
|
||||
-- 上下文配置
|
||||
system_prompt TEXT,
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
package cn.yinlihupo.common.handler;
|
||||
|
||||
import org.apache.ibatis.type.BaseTypeHandler;
|
||||
import org.apache.ibatis.type.JdbcType;
|
||||
|
||||
import java.sql.Array;
|
||||
import java.sql.CallableStatement;
|
||||
import java.sql.PreparedStatement;
|
||||
import java.sql.ResultSet;
|
||||
import java.sql.SQLException;
|
||||
|
||||
/**
|
||||
* PostgreSQL 数组类型处理器
|
||||
* 用于处理 Java String[] 与 PostgreSQL varchar[] 类型之间的转换
|
||||
*/
|
||||
public class PostgresArrayTypeHandler extends BaseTypeHandler<String[]> {
|
||||
|
||||
@Override
|
||||
public void setNonNullParameter(PreparedStatement ps, int i, String[] parameter, JdbcType jdbcType) throws SQLException {
|
||||
// 使用 PostgreSQL 的 createArrayOf 方法创建数组
|
||||
Array array = ps.getConnection().createArrayOf("varchar", parameter);
|
||||
ps.setArray(i, array);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String[] getNullableResult(ResultSet rs, String columnName) throws SQLException {
|
||||
Array array = rs.getArray(columnName);
|
||||
return extractArray(array);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String[] getNullableResult(ResultSet rs, int columnIndex) throws SQLException {
|
||||
Array array = rs.getArray(columnIndex);
|
||||
return extractArray(array);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String[] getNullableResult(CallableStatement cs, int columnIndex) throws SQLException {
|
||||
Array array = cs.getArray(columnIndex);
|
||||
return extractArray(array);
|
||||
}
|
||||
|
||||
private String[] extractArray(Array array) throws SQLException {
|
||||
if (array == null) {
|
||||
return null;
|
||||
}
|
||||
Object[] objArray = (Object[]) array.getArray();
|
||||
if (objArray == null) {
|
||||
return null;
|
||||
}
|
||||
String[] result = new String[objArray.length];
|
||||
for (int i = 0; i < objArray.length; i++) {
|
||||
result[i] = objArray[i] != null ? objArray[i].toString() : null;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package cn.yinlihupo.domain.entity;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.IdType;
|
||||
import com.baomidou.mybatisplus.annotation.TableField;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import lombok.Data;
|
||||
@@ -59,9 +60,10 @@ public class AiChatMessage {
|
||||
private String contentEmbedding;
|
||||
|
||||
/**
|
||||
* 引用的文档ID列表(JSON数组)
|
||||
* 引用的文档ID列表(字符串数组)
|
||||
*/
|
||||
private String referencedDocIds;
|
||||
@TableField(typeHandler = cn.yinlihupo.common.handler.PostgresArrayTypeHandler.class)
|
||||
private String[] referencedDocIds;
|
||||
|
||||
/**
|
||||
* 系统提示词
|
||||
|
||||
@@ -45,6 +45,11 @@ public class KbDocumentVO {
|
||||
*/
|
||||
private String filePath;
|
||||
|
||||
/**
|
||||
* 文件URL
|
||||
*/
|
||||
private String fileUrl;
|
||||
|
||||
/**
|
||||
* 来源类型
|
||||
*/
|
||||
|
||||
@@ -134,9 +134,14 @@ public class AiChatServiceImpl implements AiChatService {
|
||||
() -> {
|
||||
// 保存助手消息
|
||||
int responseTime = (int) (System.currentTimeMillis() - startTime);
|
||||
// 提取引用的文档ID列表
|
||||
String[] docIds = retrievedDocs.stream()
|
||||
.map(doc -> doc.getMetadata().getOrDefault("doc_id", "").toString())
|
||||
.filter(id -> !id.isEmpty())
|
||||
.toArray(String[]::new);
|
||||
Long messageId = saveMessage(finalSessionId, userId, request.getProjectId(),
|
||||
request.getTimelineNodeId(), "assistant",
|
||||
fullResponse.toString(), JSON.toJSONString(referencedDocs));
|
||||
fullResponse.toString(), docIds);
|
||||
|
||||
// 发送完成消息
|
||||
sendEvent(emitter, "complete", Map.of(
|
||||
@@ -334,7 +339,7 @@ public class AiChatServiceImpl implements AiChatService {
|
||||
*/
|
||||
private Long saveMessage(String sessionId, Long userId, Long projectId,
|
||||
Long timelineNodeId, String role, String content,
|
||||
String referencedDocIds) {
|
||||
String[] referencedDocIds) {
|
||||
// 获取当前最大序号
|
||||
Integer maxIndex = chatHistoryMapper.selectMaxMessageIndex(sessionId);
|
||||
int nextIndex = (maxIndex != null ? maxIndex : 0) + 1;
|
||||
|
||||
@@ -41,10 +41,11 @@ public class RagRetriever {
|
||||
*/
|
||||
public List<Document> vectorSearch(String query, Long projectId, int topK) {
|
||||
try {
|
||||
// project_id 在 metadata 中是字符串类型,需要用字符串比较
|
||||
SearchRequest searchRequest = SearchRequest.builder()
|
||||
.query(query)
|
||||
.topK(topK)
|
||||
.filterExpression("project_id == " + projectId + " && status == 'active'")
|
||||
.filterExpression("project_id == '" + projectId + "' && status == 'active'")
|
||||
.build();
|
||||
|
||||
List<Document> results = vectorStore.similaritySearch(searchRequest);
|
||||
@@ -68,12 +69,13 @@ public class RagRetriever {
|
||||
public List<Document> vectorSearchWithTimeline(String query, Long projectId,
|
||||
Long timelineNodeId, int topK) {
|
||||
try {
|
||||
// project_id 和 timeline_node_id 在 metadata 中是字符串类型
|
||||
SearchRequest searchRequest = SearchRequest.builder()
|
||||
.query(query)
|
||||
.topK(topK)
|
||||
.filterExpression("project_id == " + projectId +
|
||||
" && timeline_node_id == " + timelineNodeId +
|
||||
" && status == 'active'")
|
||||
.filterExpression("project_id == '" + projectId +
|
||||
"' && timeline_node_id == '" + timelineNodeId +
|
||||
"' && status == 'active'")
|
||||
.build();
|
||||
|
||||
List<Document> results = vectorStore.similaritySearch(searchRequest);
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
vs.file_type as fileType,
|
||||
vs.file_size as fileSize,
|
||||
vs.file_path as filePath,
|
||||
vs.file_url as fileUrl,
|
||||
vs.source_type as sourceType,
|
||||
vs.chunk_total as chunkCount,
|
||||
vs.status,
|
||||
|
||||
Reference in New Issue
Block a user