diff --git a/OHS/Embedder.py b/OHS/Embedder.py
new file mode 100644
index 0000000..b88ddc6
--- /dev/null
+++ b/OHS/Embedder.py
@@ -0,0 +1,188 @@
+import torch
+from torch import Tensor
+import torch.nn.functional as F
+from transformers import AutoTokenizer, AutoModel
+
+import pickle
+import faiss
+import time
+import numpy as np
+from tqdm import tqdm
+from typing import List, Tuple
+
+from OHS.Script import *
+
+class Embedder:
+    def __init__(self, model_dir: str, device):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, padding_side="left", use_fast=True)
+        self.model = AutoModel.from_pretrained(
+            model_dir,
+            dtype="bfloat16",
+            attn_implementation="flash_attention_2"
+        ).eval()
+
+        self.device = device
+        self.model.to(self.device)
+
+        self.embedding_dim = self.model.config.hidden_size
+        self.max_length = 8192
+        # Retrieval instruction (kept in Chinese, as it is fed to the embedding model):
+        # given a customer objection, retrieve objection-handling scripts whose target
+        # problem and target customer profile can resolve that objection.
+        self.task_description = "给定客户异议,根据话术的目标解决问题和目标客户特征,检索能够解决该异议的相关异议处理话术。"
+
+    @staticmethod
+    def last_token_pool(
+        last_hidden_states: Tensor,
+        attention_mask: Tensor
+    ) -> Tensor:
+        # With left padding the last position is always a real token; otherwise
+        # select each sequence's final non-padding position.
+        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+        if left_padding:
+            return last_hidden_states[:, -1]
+        else:
+            sequence_lengths = attention_mask.sum(dim=1) - 1
+            batch_size = last_hidden_states.shape[0]
+            return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
+
+    def get_detailed_instruct(self, query: str) -> str:
+        return f'Instruct: {self.task_description}\nQuery:{query}'
+
+    def embed_query(self, query: str) -> np.ndarray:
+        query_text = self.get_detailed_instruct(query)
+        inputs = self.tokenizer(
+            [query_text],
+            truncation=True,
+            max_length=self.max_length,
+            return_tensors="pt",
+        ).to(self.device)
+
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+
+        embeddings = self.last_token_pool(outputs.last_hidden_state, inputs['attention_mask'])
+        embeddings = F.normalize(embeddings, p=2, dim=1).cpu()
+        return embeddings.detach().float().numpy()
+
+    def embed_document(self, texts: List[str]) -> np.ndarray:
+        inputs = self.tokenizer(
+            texts,
+            truncation=True,
+            padding=True,
+            max_length=self.max_length,
+            return_tensors="pt").to(self.device)
+
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+
+        embeddings = self.last_token_pool(outputs.last_hidden_state, inputs['attention_mask'])
+        embeddings = F.normalize(embeddings, p=2, dim=1).cpu()
+        return embeddings.detach().float().numpy()
+
+class ScriptRetriever:
+    def __init__(self, embedder: Embedder, dataset: ScriptDataset, N: int = 100, force_rebuild: bool = False):
+        self.embedder = embedder
+        self.dataset = dataset
+        self.N = N  # number of candidates to recall
+        self.force_rebuild = force_rebuild
+
+        self.index_file = dataset.index_file
+        self.vectors_file = dataset.vectors_file
+        self.metadata_file = dataset.metadata_file
+
+        self.index = None
+        self.document_vectors = None
+        self.document_metadata = []
+
+        self._setup_index()
+
+    def _setup_index(self):
+        if self.force_rebuild:
+            print("Force rebuilding the vector index, clearing cache...")
+            self.clear_cache()
+
+        if (self.index_file.exists() and self.vectors_file.exists() and self.metadata_file.exists()):
+            print("Found cached index files, loading...")
+            self._load_cached_index()
+        else:
+            print("No cached index found, building index...")
+            self._build_index()
+            self._save_index()
+
+    def _build_index(self):
+        document_texts = self.dataset.get_all_document_texts()
+        print(f"Encoding {len(document_texts)} documents...")
+
+        batch_size = 32
+        all_embeddings = []
+
+        for i in tqdm(range(0, len(document_texts), batch_size), desc="encoding documents"):
+            batch_texts = document_texts[i:i + batch_size]
+            batch_embeddings = self.embedder.embed_document(batch_texts)
+            all_embeddings.append(batch_embeddings)
+
+        self.document_vectors = np.vstack(all_embeddings)
+        self.index = faiss.IndexFlatIP(self.embedder.embedding_dim)
+        self.index.add(self.document_vectors)
+        self.document_metadata = list(range(len(document_texts)))
+
+    def _load_cached_index(self):
+        self.index = faiss.read_index(str(self.index_file))
+        self.document_vectors = np.load(str(self.vectors_file))
+        with open(self.metadata_file, 'rb') as f:
+            self.document_metadata = pickle.load(f)
+
+    def _save_index(self):
+        faiss.write_index(self.index, str(self.index_file))
+        np.save(str(self.vectors_file), self.document_vectors)
+        with open(self.metadata_file, 'wb') as f:
+            pickle.dump(self.document_metadata, f)
+        print(f"Index saved to: {self.index_file.parent}")
+
+    def retrieve(self, query: str) -> List[Tuple[int, float]]:
+        """
+        Recall the N most relevant scripts for a query.
+
+        Args:
+            query: query text
+
+        Returns:
+            List[Tuple[int, float]]: list of (document index, similarity score) pairs
+        """
+        query_vector = self.embedder.embed_query(query)
+
+        if query_vector.ndim == 1:
+            query_vector = query_vector.reshape(1, -1)
+
+        distances, indices = self.index.search(query_vector, min(self.N, len(self.document_metadata)))
+        results = []
+        for i in range(len(indices[0])):
+            doc_idx = indices[0][i]
+            score = distances[0][i]
+
+            # Skip invalid hits (faiss pads results with -1 when it cannot return enough neighbors)
+            if 0 <= doc_idx < len(self.document_metadata):
+                script_id = self.document_metadata[doc_idx]
+                results.append((script_id, float(score)))
+        return results
+
+    def get_top_scripts(self, query: str) -> List[Tuple[ObjectionHandleScript, float]]:
+        """
+        Get the script objects most relevant to a query.
+
+        Args:
+            query: query text
+
+        Returns:
+            List[Tuple[ObjectionHandleScript, float]]: list of (script object, similarity score) pairs
+        """
+        top_indices = self.retrieve(query)
+
+        results = []
+        for script_id, score in top_indices:
+            script = self.dataset.get_script(script_id)
+            results.append((script, score))
+
+        return results
+
+    def clear_cache(self):
+        for file_path in [self.index_file, self.vectors_file, self.metadata_file]:
+            if file_path.exists():
+                file_path.unlink()
+        self.index = None
+        self.document_vectors = None
+        self.document_metadata = []
\ No newline at end of file
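
For reference, a minimal usage sketch of the two classes added above. The checkpoint path, the `ScriptDataset` constructor argument, and the example query are placeholders for illustration only and are not part of this change:

```python
# Illustrative sketch; paths and the ScriptDataset constructor signature are assumptions.
import torch
from OHS.Script import ScriptDataset
from OHS.Embedder import Embedder, ScriptRetriever

device = torch.device("cuda")  # flash_attention_2 in Embedder requires a CUDA device

embedder = Embedder("models/Qwen3-Embedding-0.6B", device)  # placeholder model directory
dataset = ScriptDataset("data/scripts.json")                # assumed constructor argument
retriever = ScriptRetriever(embedder, dataset, N=20)

# Example customer objection ("this product is too expensive")
for script, score in retriever.get_top_scripts("这个产品太贵了"):
    print(score, script)
```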