Upload files to "/"
This commit is contained in:
65
__main__.py
Normal file
65
__main__.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Dict, Any
|
||||
import json
|
||||
import os
|
||||
|
||||
from OHS.Embedder import *
|
||||
from OHS.Reranker import *
|
||||
from OHS.Script import *
|
||||
|
||||
class OHS_RAG_System:
|
||||
def __init__(self, embedder_model_dir: str, reranker_model_dir: str, json_path: str, update_vector_database: bool = False, N: int = 50, rerank_N: int = 10):
|
||||
"""
|
||||
初始化 RAG 系统
|
||||
|
||||
Args:
|
||||
embedder_model_dir: 嵌入模型路径
|
||||
reranker_model_dir: 重排序模型路径
|
||||
json_path: 话术数据 JSON 文件路径
|
||||
update_vector_database: 是否更新向量数据库(True: 强制重建,False: 使用现有缓存), 若较缓存向量库相比, 异议处理数据库发生变化, 请将该参数设为True。
|
||||
N: 召回数量
|
||||
rerank_N: 重排序后返回数量
|
||||
"""
|
||||
assert N >= rerank_N, "N must be greater than or equal to rerank_N"
|
||||
assert rerank_N > 0, "rerank_N must be greater than 0"
|
||||
assert N > 0, "N must be greater than 0"
|
||||
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
self.N = N
|
||||
self.rerank_N = rerank_N
|
||||
|
||||
self.embedder = Embedder(embedder_model_dir, device=self.device)
|
||||
self.reranker = Reranker(reranker_model_dir, device=self.device)
|
||||
self.dataset = ScriptDataset(json_path)
|
||||
|
||||
self.retriever = ScriptRetriever(self.embedder, self.dataset, N=self.N, force_rebuild=update_vector_database)
|
||||
|
||||
def run(self, query: str) -> List[Tuple[ObjectionHandleScript, float]]:
|
||||
"""
|
||||
查询RAG系统
|
||||
Args:
|
||||
query: 查询文本
|
||||
Returns:
|
||||
List[Tuple[ObjectionHandleScript, float]]: 重排序后的话术结果
|
||||
"""
|
||||
# 召回
|
||||
top_scripts = self.retriever.get_top_scripts(query)
|
||||
top_scripts = [script for script, _ in top_scripts]
|
||||
|
||||
# 重排序
|
||||
results = self.reranker.rerank(query, top_scripts, self.rerank_N)
|
||||
return results
|
||||
|
||||
if __name__=="__main__":
|
||||
embedder_model_dir = os.path.join("OHS", "Qwen3-Embedding-0.6B")
|
||||
reranker_model_dir = os.path.join("OHS", "Qwen3-Reranker-0.6B")
|
||||
json_path = os.path.join("OHS", "scripts_deduplicated.json")
|
||||
|
||||
rag_system = OHS_RAG_System(embedder_model_dir, reranker_model_dir, json_path, N=50, rerank_N=10)
|
||||
query = "赵士杰老师太贵了。"
|
||||
results = rag_system.run(query)
|
||||
for script, score in results:
|
||||
print(f"Score: {score:.4f}, Script: {script.get_script()}")
|
||||
Reference in New Issue
Block a user