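"""Embedding and FAISS-based retrieval of objection-handling scripts.

`Embedder` wraps a HuggingFace encoder (left padding, last-token pooling) and
produces L2-normalised query/document vectors; `ScriptRetriever` builds,
caches, and searches a FAISS inner-product index over the script corpus
defined in `OHS.Script`.
"""
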
import torch
from torch import Tensor
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

import pickle
import faiss
import time
import numpy as np
from tqdm import tqdm
from typing import List, Tuple

from OHS.Script import *  # expected to provide ScriptDataset and ObjectionHandleScript


class Embedder:
    def __init__(self, model_dir: str, device):
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, padding_side="left", use_fast=True)
        self.model = AutoModel.from_pretrained(
            model_dir,
            dtype="bfloat16",  # newer transformers name for torch_dtype
            attn_implementation="flash_attention_2"
        ).eval()

        self.device = device
        self.model.to(self.device)

        self.embedding_dim = self.model.config.hidden_size
        self.max_length = 8192
        # Instruction prepended to every query (kept in Chinese to match the corpus):
        # "Given a customer objection, retrieve the objection-handling script that can
        # resolve it, based on the script's target problem and target customer profile."
        self.task_description = "给定客户异议,根据话术的目标解决问题和目标客户特征,检索能够解决该异议的相关异议处理话术。"

    @staticmethod
    def last_token_pool(
        last_hidden_states: Tensor,
        attention_mask: Tensor
    ) -> Tensor:
        # With left padding the final position is always a real token, so the
        # embedding is simply the last hidden state of each sequence.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        else:
            # With right padding, take the hidden state at each sequence's last
            # non-padding position.
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
            return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

    def get_detailed_instruct(self, query: str) -> str:
        return f'Instruct: {self.task_description}\nQuery:{query}'

    def embed_query(self, query: str) -> np.ndarray:
        query_text = self.get_detailed_instruct(query)
        inputs = self.tokenizer(
            [query_text],
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        embeddings = self.last_token_pool(outputs.last_hidden_state, inputs['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1).cpu()
        return embeddings.detach().float().numpy()

    def embed_document(self, texts: List[str]) -> np.ndarray:
        inputs = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=self.max_length,
            return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        embeddings = self.last_token_pool(outputs.last_hidden_state, inputs['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1).cpu()
        return embeddings.detach().float().numpy()

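# Minimal sketch of using Embedder on its own. The model path and example texts
# below are placeholders, not part of this repo:
#
#   embedder = Embedder("/path/to/embedding-model", device="cuda")
#   q = embedder.embed_query("价格太贵了")           # hypothetical objection: "the price is too high"
#   d = embedder.embed_document(["话术A", "话术B"])   # (2, hidden_size), L2-normalised float32
#   scores = q @ d.T                                  # inner product == cosine similarity

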
class ScriptRetriever:
    def __init__(self, embedder: Embedder, dataset: ScriptDataset, N: int = 100, force_rebuild: bool = False):
        self.embedder = embedder
        self.dataset = dataset
        self.N = N  # number of candidates to recall
        self.force_rebuild = force_rebuild

        self.index_file = dataset.index_file
        self.vectors_file = dataset.vectors_file
        self.metadata_file = dataset.metadata_file

        self.index = None
        self.document_vectors = None
        self.document_metadata = []

        self._setup_index()

    def _setup_index(self):
        if self.force_rebuild:
            print("Force rebuild requested, clearing cached vector store...")
            self.clear_cache()

        if self.index_file.exists() and self.vectors_file.exists() and self.metadata_file.exists():
            print("Cache files found, loading...")
            self._load_cached_index()
        else:
            print("No cache files found, building index...")
            self._build_index()
            self._save_index()

    def _build_index(self):
        document_texts = self.dataset.get_all_document_texts()
        print(f"Encoding {len(document_texts)} documents...")

        batch_size = 32
        all_embeddings = []

        for i in tqdm(range(0, len(document_texts), batch_size), desc="encoding documents"):
            batch_texts = document_texts[i:i + batch_size]
            batch_embeddings = self.embedder.embed_document(batch_texts)
            all_embeddings.append(batch_embeddings)

        self.document_vectors = np.vstack(all_embeddings)
        # Vectors are L2-normalised, so inner-product search equals cosine similarity.
        self.index = faiss.IndexFlatIP(self.embedder.embedding_dim)
        self.index.add(self.document_vectors)
        self.document_metadata = list(range(len(document_texts)))

    def _load_cached_index(self):
        self.index = faiss.read_index(str(self.index_file))
        self.document_vectors = np.load(str(self.vectors_file))
        with open(self.metadata_file, 'rb') as f:
            self.document_metadata = pickle.load(f)

    def _save_index(self):
        faiss.write_index(self.index, str(self.index_file))
        np.save(str(self.vectors_file), self.document_vectors)
        with open(self.metadata_file, 'wb') as f:
            pickle.dump(self.document_metadata, f)
        print(f"Index saved to: {self.index_file.parent}")

    def retrieve(self, query: str) -> List[Tuple[int, float]]:
        """
        Recall the N most relevant scripts.

        Args:
            query: query text

        Returns:
            List[Tuple[int, float]]: list of (document index, similarity score) pairs
        """
        query_vector = self.embedder.embed_query(query)

        if query_vector.ndim == 1:
            query_vector = query_vector.reshape(1, -1)

        distances, indices = self.index.search(query_vector, min(self.N, len(self.document_metadata)))
        results = []
        for i in range(len(indices[0])):
            doc_idx = indices[0][i]
            score = distances[0][i]

            # Skip invalid indices (FAISS returns -1 when fewer hits are available).
            if 0 <= doc_idx < len(self.document_metadata):
                script_id = self.document_metadata[doc_idx]
                results.append((script_id, float(score)))
        return results

    def get_top_scripts(self, query: str) -> List[Tuple[ObjectionHandleScript, float]]:
        """
        Get the script objects most relevant to the query.

        Args:
            query: query text

        Returns:
            List[Tuple[ObjectionHandleScript, float]]: list of (script object, similarity score) pairs
        """
        top_indices = self.retrieve(query)

        results = []
        for script_id, score in top_indices:
            script = self.dataset.get_script(script_id)
            results.append((script, score))

        return results

    def clear_cache(self):
        """Delete cached index files and reset in-memory state."""
        for file_path in [self.index_file, self.vectors_file, self.metadata_file]:
            if file_path.exists():
                file_path.unlink()
        self.index = None
        self.document_vectors = None
        self.document_metadata = []
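

# End-to-end usage sketch. Illustrative only: the model path, sample query, and the
# ScriptDataset construction are assumptions and may not match the real entry point
# in OHS.Script.
if __name__ == "__main__":
    embedder = Embedder("/path/to/embedding-model", device="cuda")
    dataset = ScriptDataset(...)  # construct however OHS.Script expects
    retriever = ScriptRetriever(embedder, dataset, N=100)

    # Hypothetical customer objection: "the price is too high, noticeably higher than elsewhere"
    for script, score in retriever.get_top_scripts("价格太贵了,比别家贵不少")[:5]:
        print(f"{score:.4f}\t{script}")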