Files
object-handle-RAG/__main__.py
2026-01-29 18:49:31 +08:00

65 lines
2.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import numpy as np
from pathlib import Path
from typing import List, Tuple, Dict, Any
import json
import os
from OHS.Embedder import *
from OHS.Reranker import *
from OHS.Script import *
class _InvalidRetrievalSizeError(ValueError, AssertionError):
    """Raised when the ``N`` / ``rerank_N`` pair given to the RAG system is invalid.

    It is a ``ValueError`` because the arguments themselves are wrong. It also
    stays an ``AssertionError`` so that callers which caught the failures of the
    former ``assert`` statements keep working.
    """


class OHS_RAG_System:
    """Two-stage retrieval pipeline: embedding-based recall, then reranking."""

    def __init__(
        self,
        embedder_model_dir: str,
        reranker_model_dir: str,
        json_path: str,
        update_vector_database: bool = False,
        N: int = 50,
        rerank_N: int = 10,
    ) -> None:
        """
        Initialise the RAG system.

        Args:
            embedder_model_dir: Path of the embedding model.
            reranker_model_dir: Path of the reranker model.
            json_path: Path of the JSON file holding the script data.
            update_vector_database: Forwarded to the retriever as
                ``force_rebuild``. True forces a rebuild of the vector
                database, False uses the existing cache. Set it to True when
                the objection-handling data has changed since the cached
                vector database was built.
            N: Number of scripts to recall.
            rerank_N: Number of scripts to return after reranking.

        Raises:
            ValueError: If ``N`` or ``rerank_N`` is not positive, or if ``N``
                is smaller than ``rerank_N``. The raised exception is also an
                ``AssertionError``.
        """
        # Explicit raises instead of ``assert``: assertions are removed when
        # Python runs with -O, which would let invalid sizes through.
        # Positivity is checked first so that every message is reachable.
        if N <= 0:
            raise _InvalidRetrievalSizeError("N must be greater than 0")
        if rerank_N <= 0:
            raise _InvalidRetrievalSizeError("rerank_N must be greater than 0")
        if N < rerank_N:
            raise _InvalidRetrievalSizeError(
                "N must be greater than or equal to rerank_N"
            )

        # Use the GPU when torch reports one, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.N = N
        self.rerank_N = rerank_N
        self.embedder = Embedder(embedder_model_dir, device=self.device)
        self.reranker = Reranker(reranker_model_dir, device=self.device)
        self.dataset = ScriptDataset(json_path)
        self.retriever = ScriptRetriever(
            self.embedder,
            self.dataset,
            N=self.N,
            force_rebuild=update_vector_database,
        )

    # The return annotation is a string (forward reference) so that creating
    # the class does not depend on a wildcard import exposing the name.
    def run(self, query: str) -> "List[Tuple[ObjectionHandleScript, float]]":
        """
        Query the RAG system.

        Args:
            query: Query text.

        Returns:
            The reranked scripts as ``(script, score)`` pairs, as produced by
            the reranker. An empty list when the recall stage returns nothing.
        """
        # Recall: keep only the scripts and drop the retrieval scores.
        recalled = self.retriever.get_top_scripts(query)
        candidates = [script for script, _ in recalled]
        if not candidates:
            # Nothing to rank, so skip the reranker call entirely.
            return []

        # Rerank the recalled candidates and keep at most ``rerank_N`` of them.
        return self.reranker.rerank(query, candidates, self.rerank_N)
def main() -> None:
    """Run a demo query against the RAG system and print the reranked scripts."""
    # Anchor the asset paths on this file's directory rather than the current
    # working directory, so the demo also works when launched from elsewhere.
    assets_dir = Path(__file__).resolve().parent / "OHS"
    embedder_model_dir = str(assets_dir / "Qwen3-Embedding-0.6B")
    reranker_model_dir = str(assets_dir / "Qwen3-Reranker-0.6B")
    json_path = str(assets_dir / "scripts_deduplicated.json")

    rag_system = OHS_RAG_System(
        embedder_model_dir,
        reranker_model_dir,
        json_path,
        N=50,
        rerank_N=10,
    )

    query = "赵士杰老师太贵了。"
    for script, score in rag_system.run(query):
        print(f"Score: {score:.4f}, Script: {script.get_script()}")


if __name__ == "__main__":
    main()