import torch
import numpy as np
from pathlib import Path
from typing import List, Tuple, Dict, Any
import json
import os

from OHS.Embedder import *
from OHS.Reranker import *
from OHS.Script import *


class OHS_RAG_System:
    """Two-stage retrieval pipeline over objection-handling scripts.

    Stage 1 recalls ``N`` candidate scripts with a ``ScriptRetriever`` built
    on an ``Embedder``. Stage 2 passes those candidates to a ``Reranker``,
    which returns the top ``rerank_N`` results.
    """

    def __init__(self, embedder_model_dir: str, reranker_model_dir: str, json_path: str,
                 update_vector_database: bool = False, N: int = 50, rerank_N: int = 10):
        """Initialize the RAG system.

        Args:
            embedder_model_dir: Path to the embedding model.
            reranker_model_dir: Path to the reranker model.
            json_path: Path to the JSON file containing the script data.
            update_vector_database: Whether to rebuild the vector database
                (True: force a rebuild; False: reuse the existing cache).
                Set this to True if the objection-handling data has changed
                relative to the cached vector store.
            N: Number of candidates recalled by the retriever.
            rerank_N: Number of results returned after reranking.

        Raises:
            ValueError: If ``N`` or ``rerank_N`` is not positive, or if
                ``N`` is smaller than ``rerank_N``.
        """
        # Explicit checks rather than ``assert``: assertions are stripped
        # under ``python -O``, which would silently disable this validation.
        # Positivity is checked first so every branch is reachable.
        if N <= 0:
            raise ValueError("N must be greater than 0")
        if rerank_N <= 0:
            raise ValueError("rerank_N must be greater than 0")
        if N < rerank_N:
            raise ValueError("N must be greater than or equal to rerank_N")

        # Run on the GPU when one is available, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.N = N
        self.rerank_N = rerank_N

        self.embedder = Embedder(embedder_model_dir, device=self.device)
        self.reranker = Reranker(reranker_model_dir, device=self.device)
        self.dataset = ScriptDataset(json_path)
        # ``force_rebuild`` controls whether the cached vector store is reused.
        self.retriever = ScriptRetriever(self.embedder, self.dataset, N=self.N,
                                         force_rebuild=update_vector_database)

    def run(self, query: str) -> List[Tuple[ObjectionHandleScript, float]]:
        """Query the RAG system.

        Args:
            query: The query text.

        Returns:
            List[Tuple[ObjectionHandleScript, float]]: The reranked script
            results, as returned by ``Reranker.rerank`` with ``self.rerank_N``.
        """
        # Recall: the retriever yields pairs; only the script (first element)
        # is passed on. The second element is presumably the retrieval
        # score -- NOTE(review): confirm against ScriptRetriever.
        recalled = self.retriever.get_top_scripts(query)
        candidates = [script for script, _ in recalled]

        # Rerank the recalled candidates and keep the top ``rerank_N``.
        return self.reranker.rerank(query, candidates, self.rerank_N)


if __name__ == "__main__":
    embedder_model_dir = os.path.join("OHS", "Qwen3-Embedding-0.6B")
    reranker_model_dir = os.path.join("OHS", "Qwen3-Reranker-0.6B")
    json_path = os.path.join("OHS", "scripts_deduplicated.json")

    rag_system = OHS_RAG_System(embedder_model_dir, reranker_model_dir, json_path,
                                N=50, rerank_N=10)

    query = "赵士杰老师太贵了。"
    results = rag_system.run(query)
    for script, score in results:
        print(f"Score: {score:.4f}, Script: {script.get_script()}")