Upload files to "/"

This commit is contained in:
2026-01-29 18:49:31 +08:00
commit 7847e4abf2

65
__main__.py Normal file
View File

@@ -0,0 +1,65 @@
import torch
import numpy as np
from pathlib import Path
from typing import List, Tuple, Dict, Any
import json
import os
from OHS.Embedder import *
from OHS.Reranker import *
from OHS.Script import *
class OHS_RAG_System:
def __init__(self, embedder_model_dir: str, reranker_model_dir: str, json_path: str, update_vector_database: bool = False, N: int = 50, rerank_N: int = 10):
"""
初始化 RAG 系统
Args:
embedder_model_dir: 嵌入模型路径
reranker_model_dir: 重排序模型路径
json_path: 话术数据 JSON 文件路径
update_vector_database: 是否更新向量数据库True: 强制重建False: 使用现有缓存), 若较缓存向量库相比, 异议处理数据库发生变化, 请将该参数设为True。
N: 召回数量
rerank_N: 重排序后返回数量
"""
assert N >= rerank_N, "N must be greater than or equal to rerank_N"
assert rerank_N > 0, "rerank_N must be greater than 0"
assert N > 0, "N must be greater than 0"
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.N = N
self.rerank_N = rerank_N
self.embedder = Embedder(embedder_model_dir, device=self.device)
self.reranker = Reranker(reranker_model_dir, device=self.device)
self.dataset = ScriptDataset(json_path)
self.retriever = ScriptRetriever(self.embedder, self.dataset, N=self.N, force_rebuild=update_vector_database)
def run(self, query: str) -> List[Tuple[ObjectionHandleScript, float]]:
"""
查询RAG系统
Args:
query: 查询文本
Returns:
List[Tuple[ObjectionHandleScript, float]]: 重排序后的话术结果
"""
# 召回
top_scripts = self.retriever.get_top_scripts(query)
top_scripts = [script for script, _ in top_scripts]
# 重排序
results = self.reranker.rerank(query, top_scripts, self.rerank_N)
return results
if __name__=="__main__":
embedder_model_dir = os.path.join("OHS", "Qwen3-Embedding-0.6B")
reranker_model_dir = os.path.join("OHS", "Qwen3-Reranker-0.6B")
json_path = os.path.join("OHS", "scripts_deduplicated.json")
rag_system = OHS_RAG_System(embedder_model_dir, reranker_model_dir, json_path, N=50, rerank_N=10)
query = "赵士杰老师太贵了。"
results = rag_system.run(query)
for script, score in results:
print(f"Score: {score:.4f}, Script: {script.get_script()}")