import pickle
from typing import List, Tuple

import faiss
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

from OHS.Script import *


class Embedder:
    """Wraps an instruction-tuned embedding model for query/document encoding."""

    def __init__(self, model_dir: str, device):
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, padding_side="left", use_fast=True)
        # flash_attention_2 requires the flash-attn package and a supported GPU.
        self.model = AutoModel.from_pretrained(
            model_dir,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
        ).eval()
        self.device = device
        self.model.to(self.device)
        self.embedding_dim = self.model.config.hidden_size
        self.max_length = 8192
        # Retrieval instruction (in Chinese, matching the corpus): "Given a customer
        # objection, retrieve the objection-handling scripts that can resolve it,
        # based on each script's target problem and target customer profile."
        self.task_description = "给定客户异议,根据话术的目标解决问题和目标客户特征,检索能够解决该异议的相关异议处理话术。"

    @staticmethod
    def last_token_pool(
            last_hidden_states: Tensor,
            attention_mask: Tensor
    ) -> Tensor:
        """Pool each sequence to the hidden state of its last real (non-pad) token."""
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
            return last_hidden_states[
                torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths
            ]

    def get_detailed_instruct(self, query: str) -> str:
        return f'Instruct: {self.task_description}\nQuery:{query}'

    def embed_query(self, query: str) -> np.ndarray:
        query_text = self.get_detailed_instruct(query)
        inputs = self.tokenizer(
            [query_text],
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        embeddings = self.last_token_pool(outputs.last_hidden_state, inputs['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1).cpu()
        return embeddings.detach().float().numpy()

    def embed_document(self, texts: List[str]) -> np.ndarray:
        inputs = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=self.max_length,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        embeddings = self.last_token_pool(outputs.last_hidden_state, inputs['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1).cpu()
        return embeddings.detach().float().numpy()


class ScriptRetriever:
    """Dense retriever over objection-handling scripts, backed by a FAISS index."""

    def __init__(self, embedder: Embedder, dataset: ScriptDataset, N: int = 100, force_rebuild: bool = False):
        self.embedder = embedder
        self.dataset = dataset
        self.N = N  # number of candidates to recall
        self.force_rebuild = force_rebuild
        self.index_file = dataset.index_file
        self.vectors_file = dataset.vectors_file
        self.metadata_file = dataset.metadata_file
        self.index = None
        self.document_vectors = None
        self.document_metadata = []
        self._setup_index()

    def _setup_index(self):
        if self.force_rebuild:
            print("Force-rebuilding the vector store, clearing cache...")
            self.clear_cache()
        if (self.index_file.exists()
                and self.vectors_file.exists()
                and self.metadata_file.exists()):
            print("Cache files found, loading...")
            self._load_cached_index()
        else:
            print("No cache files found, building index...")
            self._build_index()
            self._save_index()

    def _build_index(self):
        document_texts = self.dataset.get_all_document_texts()
        print(f"Encoding {len(document_texts)} documents...")
        batch_size = 32
        all_embeddings = []
        for i in tqdm(range(0, len(document_texts), batch_size), desc="encoding documents"):
            batch_texts = document_texts[i:i + batch_size]
            batch_embeddings = self.embedder.embed_document(batch_texts)
            all_embeddings.append(batch_embeddings)
        self.document_vectors = np.vstack(all_embeddings)
        # Inner product over L2-normalized vectors is cosine similarity.
        self.index = faiss.IndexFlatIP(self.embedder.embedding_dim)
        self.index.add(self.document_vectors)
        self.document_metadata = list(range(len(document_texts)))

    def _load_cached_index(self):
        self.index = faiss.read_index(str(self.index_file))
        self.document_vectors = np.load(str(self.vectors_file))
        with open(self.metadata_file, 'rb') as f:
            self.document_metadata = pickle.load(f)

    def _save_index(self):
        faiss.write_index(self.index, str(self.index_file))
        np.save(str(self.vectors_file), self.document_vectors)
        with open(self.metadata_file, 'wb') as f:
            pickle.dump(self.document_metadata, f)
        print(f"Index saved to: {self.index_file.parent}")

    def retrieve(self, query: str) -> List[Tuple[int, float]]:
        """
        Recall the N most relevant scripts.
        Args:
            query: query text
        Returns:
            List[Tuple[int, float]]: list of (script id, similarity score) pairs
        """
        query_vector = self.embedder.embed_query(query)
        if query_vector.ndim == 1:
            query_vector = query_vector.reshape(1, -1)
        distances, indices = self.index.search(query_vector, min(self.N, len(self.document_metadata)))
        results = []
        for doc_idx, score in zip(indices[0], distances[0]):
            # FAISS pads missing hits with -1; also guard against a stale cache.
            if 0 <= doc_idx < len(self.document_metadata):
                script_id = self.document_metadata[doc_idx]
                results.append((script_id, float(score)))
        return results

    def get_top_scripts(self, query: str) -> List[Tuple[ObjectionHandleScript, float]]:
        """
        Fetch the script objects most relevant to the query.
        Args:
            query: query text
        Returns:
            List[Tuple[ObjectionHandleScript, float]]: list of (script object, similarity score) pairs
        """
        top_indices = self.retrieve(query)
        results = []
        for script_id, score in top_indices:
            script = self.dataset.get_script(script_id)
            results.append((script, score))
        return results

    def clear_cache(self):
        for file_path in [self.index_file, self.vectors_file, self.metadata_file]:
            if file_path.exists():
                file_path.unlink()
        self.index = None
        self.document_vectors = None
        self.document_metadata = []
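

# --- Usage sketch ---
# A minimal end-to-end example of wiring the embedder and retriever together,
# assuming a CUDA device (flash_attention_2 does not run on CPU). The model
# directory and the no-argument ScriptDataset() constructor are hypothetical
# placeholders: the real ScriptDataset API is defined in OHS.Script, and any
# local instruction-tuned embedding checkpoint with left padding support
# should work here.
if __name__ == "__main__":
    device = torch.device("cuda")
    embedder = Embedder("models/embedding-model", device)  # hypothetical local model dir
    dataset = ScriptDataset()  # hypothetical constructor; see OHS.Script for the real one
    retriever = ScriptRetriever(embedder, dataset, N=10)
    # Example customer objection: "Too expensive, I'll think it over."
    for script, score in retriever.get_top_scripts("太贵了,我再考虑一下。"):
        print(f"{score:.4f}\t{script}")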