# Origin: contents of __main__.py as added by git patch
# 7847e4abf221235bd0051f56d7bc35674a801dc8
# (author: WangZiFan, date: Thu, 29 Jan 2026 18:49:31 +0800,
# subject: Upload files to "/").
"""Entry point for the objection-handling-script (OHS) RAG demo.

Builds an ``OHS_RAG_System`` (recall with an embedding model, then rerank
with a reranker model) and, when executed as a script, runs one sample query
and prints the reranked scripts with their scores.
"""
import torch
import numpy as np
from pathlib import Path
from typing import List, Tuple, Dict, Any
import json
import os

from OHS.Embedder import *
from OHS.Reranker import *
from OHS.Script import *


class OHS_RAG_System:
    """Two-stage pipeline: recall candidate scripts, then rerank them."""

    def __init__(
        self,
        embedder_model_dir: str,
        reranker_model_dir: str,
        json_path: str,
        update_vector_database: bool = False,
        N: int = 50,
        rerank_N: int = 10,
    ):
        """Initialise the RAG system.

        Args:
            embedder_model_dir: Path of the embedding model.
            reranker_model_dir: Path of the reranker model.
            json_path: Path of the JSON file holding the script data.
            update_vector_database: Whether to update the vector database
                (True: force a rebuild, False: use the existing cache).
                Set this to True when the objection-handling database has
                changed compared with the cached vector store.
            N: Number of scripts to recall.
            rerank_N: Number of scripts returned after reranking.

        Raises:
            AssertionError: If ``N < rerank_N``, ``rerank_N <= 0`` or
                ``N <= 0``.
        """
        # Raised explicitly rather than via `assert`: assert statements are
        # removed when Python runs with -O, which would let invalid sizes
        # through. The exception type, messages and check order are the same
        # as the assert statements these replace, so callers are unaffected.
        if not (N >= rerank_N):
            raise AssertionError("N must be greater than or equal to rerank_N")
        if not (rerank_N > 0):
            raise AssertionError("rerank_N must be greater than 0")
        # Implied by the two checks above for ordinary numbers; kept as a
        # guard so the documented contract is spelled out in one place.
        if not (N > 0):
            raise AssertionError("N must be greater than 0")

        # Use the GPU when torch reports CUDA as available, else the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.N = N
        self.rerank_N = rerank_N

        self.embedder = Embedder(embedder_model_dir, device=self.device)
        self.reranker = Reranker(reranker_model_dir, device=self.device)
        self.dataset = ScriptDataset(json_path)

        self.retriever = ScriptRetriever(
            self.embedder,
            self.dataset,
            N=self.N,
            force_rebuild=update_vector_database,
        )

    def run(self, query: str) -> List[Tuple[ObjectionHandleScript, float]]:
        """Query the RAG system.

        Args:
            query: Query text.

        Returns:
            The reranked script results, as returned by the reranker.
        """
        # Recall stage. The retriever yields 2-tuples; only the first element
        # (the script) is forwarded. The second is presumably a recall
        # score -- TODO confirm against ScriptRetriever.
        recalled = self.retriever.get_top_scripts(query)
        candidates = [script for script, _ in recalled]

        # Rerank stage.
        return self.reranker.rerank(query, candidates, self.rerank_N)


def main() -> None:
    """Run one sample query and print each returned script with its score."""
    # Relative paths: they are resolved against the current working directory.
    embedder_model_dir = os.path.join("OHS", "Qwen3-Embedding-0.6B")
    reranker_model_dir = os.path.join("OHS", "Qwen3-Reranker-0.6B")
    json_path = os.path.join("OHS", "scripts_deduplicated.json")

    rag_system = OHS_RAG_System(
        embedder_model_dir,
        reranker_model_dir,
        json_path,
        N=50,
        rerank_N=10,
    )
    query = "赵士杰老师太贵了。"
    results = rag_system.run(query)
    for script, score in results:
        print(f"Score: {score:.4f}, Script: {script.get_script()}")


if __name__ == "__main__":
    main()