Upload files to "/"
This commit is contained in:
65
__main__.py
Normal file
65
__main__.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple, Dict, Any
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from OHS.Embedder import *
|
||||||
|
from OHS.Reranker import *
|
||||||
|
from OHS.Script import *
|
||||||
|
|
||||||
|
class OHS_RAG_System:
|
||||||
|
def __init__(self, embedder_model_dir: str, reranker_model_dir: str, json_path: str, update_vector_database: bool = False, N: int = 50, rerank_N: int = 10):
|
||||||
|
"""
|
||||||
|
初始化 RAG 系统
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedder_model_dir: 嵌入模型路径
|
||||||
|
reranker_model_dir: 重排序模型路径
|
||||||
|
json_path: 话术数据 JSON 文件路径
|
||||||
|
update_vector_database: 是否更新向量数据库(True: 强制重建,False: 使用现有缓存), 若较缓存向量库相比, 异议处理数据库发生变化, 请将该参数设为True。
|
||||||
|
N: 召回数量
|
||||||
|
rerank_N: 重排序后返回数量
|
||||||
|
"""
|
||||||
|
assert N >= rerank_N, "N must be greater than or equal to rerank_N"
|
||||||
|
assert rerank_N > 0, "rerank_N must be greater than 0"
|
||||||
|
assert N > 0, "N must be greater than 0"
|
||||||
|
|
||||||
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
self.N = N
|
||||||
|
self.rerank_N = rerank_N
|
||||||
|
|
||||||
|
self.embedder = Embedder(embedder_model_dir, device=self.device)
|
||||||
|
self.reranker = Reranker(reranker_model_dir, device=self.device)
|
||||||
|
self.dataset = ScriptDataset(json_path)
|
||||||
|
|
||||||
|
self.retriever = ScriptRetriever(self.embedder, self.dataset, N=self.N, force_rebuild=update_vector_database)
|
||||||
|
|
||||||
|
def run(self, query: str) -> List[Tuple[ObjectionHandleScript, float]]:
|
||||||
|
"""
|
||||||
|
查询RAG系统
|
||||||
|
Args:
|
||||||
|
query: 查询文本
|
||||||
|
Returns:
|
||||||
|
List[Tuple[ObjectionHandleScript, float]]: 重排序后的话术结果
|
||||||
|
"""
|
||||||
|
# 召回
|
||||||
|
top_scripts = self.retriever.get_top_scripts(query)
|
||||||
|
top_scripts = [script for script, _ in top_scripts]
|
||||||
|
|
||||||
|
# 重排序
|
||||||
|
results = self.reranker.rerank(query, top_scripts, self.rerank_N)
|
||||||
|
return results
|
||||||
|
|
||||||
|
if __name__=="__main__":
|
||||||
|
embedder_model_dir = os.path.join("OHS", "Qwen3-Embedding-0.6B")
|
||||||
|
reranker_model_dir = os.path.join("OHS", "Qwen3-Reranker-0.6B")
|
||||||
|
json_path = os.path.join("OHS", "scripts_deduplicated.json")
|
||||||
|
|
||||||
|
rag_system = OHS_RAG_System(embedder_model_dir, reranker_model_dir, json_path, N=50, rerank_N=10)
|
||||||
|
query = "赵士杰老师太贵了。"
|
||||||
|
results = rag_system.run(query)
|
||||||
|
for script, score in results:
|
||||||
|
print(f"Score: {score:.4f}, Script: {script.get_script()}")
|
||||||
Reference in New Issue
Block a user