65 lines
2.6 KiB
Python
65 lines
2.6 KiB
Python
import torch
|
||
import numpy as np
|
||
from pathlib import Path
|
||
from typing import List, Tuple, Dict, Any
|
||
import json
|
||
import os
|
||
|
||
from OHS.Embedder import *
|
||
from OHS.Reranker import *
|
||
from OHS.Script import *
|
||
|
||
class OHS_RAG_System:
    """Two-stage (recall, then rerank) lookup over objection-handling scripts.

    The constructor wires together an ``Embedder``, a ``Reranker``, a
    ``ScriptDataset`` and a ``ScriptRetriever``; :meth:`run` sends a query
    through the retriever and then through the reranker.
    """

    def __init__(
        self,
        embedder_model_dir: str,
        reranker_model_dir: str,
        json_path: str,
        update_vector_database: bool = False,
        N: int = 50,
        rerank_N: int = 10,
    ):
        """Initialise the RAG system.

        Args:
            embedder_model_dir: Path to the embedding model.
            reranker_model_dir: Path to the reranking model.
            json_path: Path to the JSON file holding the script data.
            update_vector_database: Whether to update the vector database
                (True: force a rebuild, False: use the existing cache). Set
                this to True when the objection-handling data has changed
                relative to the cached vector store.
            N: Number of scripts to recall.
            rerank_N: Number of scripts returned after reranking.

        Raises:
            AssertionError: If ``N`` or ``rerank_N`` is not positive, or if
                ``N`` is smaller than ``rerank_N``.
        """
        # Raised explicitly instead of via `assert` so the validation still
        # runs under `python -O`; the exception type is kept for callers that
        # already catch AssertionError. Positivity is checked first so that
        # every check is reachable and the most specific message is reported.
        if N <= 0:
            raise AssertionError("N must be greater than 0")
        if rerank_N <= 0:
            raise AssertionError("rerank_N must be greater than 0")
        if N < rerank_N:
            raise AssertionError("N must be greater than or equal to rerank_N")

        # Prefer the GPU when one is available, otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.N = N
        self.rerank_N = rerank_N

        self.embedder = Embedder(embedder_model_dir, device=self.device)
        self.reranker = Reranker(reranker_model_dir, device=self.device)
        self.dataset = ScriptDataset(json_path)

        self.retriever = ScriptRetriever(
            self.embedder,
            self.dataset,
            N=self.N,
            force_rebuild=update_vector_database,
        )

    def run(self, query: str) -> "List[Tuple[ObjectionHandleScript, float]]":
        """Query the RAG system.

        Args:
            query: The query text.

        Returns:
            The reranked script results, as produced by the reranker for
            ``rerank_N`` (annotated as ``(script, score)`` pairs).
        """
        # Recall: the retriever yields (script, score) pairs; only the scripts
        # are forwarded, the recall scores are discarded.
        recalled = self.retriever.get_top_scripts(query)
        candidates = [script for script, _ in recalled]

        # Rerank the recalled candidates against the query.
        return self.reranker.rerank(query, candidates, self.rerank_N)
|
||
|
||
if __name__ == "__main__":
    # Demo entry point: the model directories and the script data are all
    # resolved relative to the "OHS" directory.
    asset_root = "OHS"
    embedder_path = os.path.join(asset_root, "Qwen3-Embedding-0.6B")
    reranker_path = os.path.join(asset_root, "Qwen3-Reranker-0.6B")
    scripts_json = os.path.join(asset_root, "scripts_deduplicated.json")

    system = OHS_RAG_System(
        embedder_path,
        reranker_path,
        scripts_json,
        N=50,
        rerank_N=10,
    )

    # Run one sample query and print each returned script with its score.
    ranked = system.run("赵士杰老师太贵了。")
    for entry, relevance in ranked:
        print(f"Score: {relevance:.4f}, Script: {entry.get_script()}")