77 lines
2.4 KiB
Python
77 lines
2.4 KiB
Python
import json
|
|
import numpy as np
|
|
from typing import List, Dict, Tuple, Any, Optional
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
@dataclass
class ObjectionHandleScript:
    """One objection-handling script plus the metadata used to retrieve it.

    The fields are declared at class level so that ``@dataclass`` generates
    ``__init__``, ``__repr__`` and ``__eq__`` from them. Without field
    declarations the decorator sees zero fields, which makes every instance
    compare equal and print as ``ObjectionHandleScript()``.

    Attributes:
        id: Identifier of the script.
        script: The script text itself.
        target_problem: The problem this script is meant to solve.
        target_customer: The customer characteristics this script targets.
    """

    id: int
    script: str
    target_problem: str
    target_customer: str

    def get_id(self) -> int:
        """Return the script identifier."""
        return self.id

    def get_script(self) -> str:
        """Return the script text."""
        return self.script

    def get_target_problem(self) -> str:
        """Return the problem this script is meant to solve."""
        return self.target_problem

    def get_target_customer(self) -> str:
        """Return the target customer description.

        Used for re-ranking in the Reranker stage.
        """
        return self.target_customer

    def get_document_text(self) -> str:
        """Return the combined target-problem / target-customer text.

        Used for recall in the Embedding stage and for re-ranking in the
        Reranker stage.
        """
        return f"目标解决问题: {self.target_problem}\n目标客户特征: {self.target_customer}"
|
|
|
|
class ScriptDataset:
    """In-memory list of objection-handling scripts loaded from a JSON file.

    Supports ``len()`` and integer/slice indexing over the loaded scripts and
    exposes the texts needed for embedding and re-ranking.
    """

    def __init__(self, json_file: str):
        """Load all scripts from ``json_file``.

        Args:
            json_file: Path to a UTF-8 JSON file whose top-level value is a
                list of objects, each with the keys ``话术本身``,
                ``目标解决问题`` and ``目标客户特征``.

        Raises:
            ValueError: If the top-level JSON value is not a list.
            KeyError: If an entry is missing one of the required keys.
        """
        self.json_file = Path(json_file)

        with open(self.json_file, 'r', encoding='utf-8') as f:
            raw_data = json.load(f)

        # Raise explicitly instead of using ``assert``: assertions are removed
        # under ``python -O``, which would let a malformed file through.
        if not isinstance(raw_data, list):
            raise ValueError("JSON file should contain a list of objects")

        # A script's id is its position in the JSON list.
        self.scripts: List["ObjectionHandleScript"] = [
            ObjectionHandleScript(
                id=idx,
                script=entry['话术本身'],
                target_problem=entry['目标解决问题'],
                target_customer=entry['目标客户特征'],
            )
            for idx, entry in enumerate(raw_data)
        ]

        # Sibling paths derived from the JSON path; nothing is read from or
        # written to them in this constructor.
        self.index_file = self.json_file.with_suffix('.faiss')
        self.vectors_file = self.json_file.with_suffix('.vectors.npy')
        self.metadata_file = self.json_file.with_suffix('.metadata.pkl')

        # Left unset here; NOTE(review): presumably filled in by code outside
        # this class once an index is built or loaded -- confirm.
        self.faiss_index = None
        self.document_vectors = None

    def get_script(self, idx) -> "ObjectionHandleScript":
        """Return the script stored at position ``idx``."""
        return self.scripts[idx]

    def __len__(self) -> int:
        """Return the number of loaded scripts."""
        return len(self.scripts)

    def __getitem__(self, idx):
        """Return ``self.scripts[idx]`` (same lookup as ``get_script``)."""
        return self.scripts[idx]

    def get_all_document_texts(self) -> List[str]:
        """Return ``get_document_text()`` of every script, in load order."""
        return [script.get_document_text() for script in self.scripts]

    def get_all_script_texts(self) -> List[str]:
        """Return the script text of every script, in load order."""
        return [script.get_script() for script in self.scripts]
|
|
|
|
if __name__ == "__main__":
    # Manual smoke run: load the dataset from a JSON file expected in the
    # current working directory.
    json_path = "scripts_deduplicated.json"
    dataset = ScriptDataset(json_path)