# File: object-handle-RAG/OHS/Embedder.py
# Last modified: 2026-01-29 18:52:35 +08:00 — 188 lines, 6.7 KiB, Python
import torch
from torch import Tensor
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
import pickle
import faiss
import time
import numpy as np
from tqdm import tqdm
from OHS.Script import *
class Embedder:
    """Wraps a HuggingFace encoder to produce L2-normalized text embeddings.

    Queries are prefixed with a task instruction (instruction-tuned retrieval
    style); documents are embedded as-is. Both paths pool the final token's
    hidden state and return unit-norm float32 numpy arrays.
    """

    def __init__(self, model_dir: str, device):
        """Load tokenizer and model from `model_dir` and move the model to `device`.

        Left padding is configured so that last-token pooling picks up the true
        final token of every sequence in a padded batch.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, padding_side="left", use_fast=True)
        self.model = AutoModel.from_pretrained(
            model_dir,
            dtype="bfloat16",
            attn_implementation="flash_attention_2"
        ).eval()
        self.device = device
        self.model.to(self.device)
        # Output vector dimensionality, taken from the model config.
        self.embedding_dim = self.model.config.hidden_size
        # Token truncation limit shared by query and document encoding.
        self.max_length = 8192
        # Retrieval task instruction (runtime prompt text — kept verbatim).
        self.task_description = "给定客户异议,根据话术的目标解决问题和目标客户特征,检索能够解决该异议的相关异议处理话术。"

    @staticmethod
    def last_token_pool(
        last_hidden_states: Tensor,
        attention_mask: Tensor
    ) -> Tensor:
        """Return the hidden state of each sequence's final real token.

        With left padding every row ends in a real token (last mask column is
        all ones), so position -1 can be taken directly; otherwise the index of
        the last non-pad token is computed per row from the attention mask.
        """
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

    def get_detailed_instruct(self, query: str) -> str:
        """Prepend the task instruction to a raw query (instruction-tuned prompt format)."""
        return f'Instruct: {self.task_description}\nQuery:{query}'

    def embed_query(self, query: str) -> np.ndarray:
        """Embed one query string; returns a (1, embedding_dim) float32 unit-norm array."""
        query_text = self.get_detailed_instruct(query)
        inputs = self.tokenizer(
            [query_text],
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        ).to(self.device)
        # Fix: inference only — the original tracked gradients here (unlike
        # embed_document), needlessly building an autograd graph.
        with torch.no_grad():
            outputs = self.model(**inputs)
            embeddings = self.last_token_pool(outputs.last_hidden_state, inputs['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1).cpu()
        return embeddings.detach().float().numpy()

    def embed_document(self, texts: list) -> np.ndarray:
        """Embed a batch of document strings; returns (len(texts), embedding_dim) float32."""
        inputs = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            # Fix: was a hard-coded 8192; keep in sync with embed_query.
            max_length=self.max_length,
            return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
            embeddings = self.last_token_pool(outputs.last_hidden_state, inputs['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1).cpu()
        return embeddings.detach().float().numpy()
class ScriptRetriever:
    """Dense retriever: FAISS inner-product search over embedded script documents.

    Embeddings are L2-normalized by `Embedder`, so inner product equals cosine
    similarity. The index, raw vectors and metadata are cached on disk (paths
    supplied by the dataset) and reloaded on later runs unless `force_rebuild`.
    """

    def __init__(self, embedder: Embedder, dataset: ScriptDataset, N: int = 100, force_rebuild: bool = False):
        """Record cache paths and build or load the FAISS index.

        Args:
            embedder: encodes queries and documents into normalized vectors.
            dataset: provides document texts, script lookup, and cache file paths.
            N: number of candidates to recall per query.
            force_rebuild: if True, delete cached files and re-encode everything.
        """
        self.embedder = embedder
        self.dataset = dataset
        self.N = N  # 召回数量 (recall count)
        self.force_rebuild = force_rebuild
        self.index_file = dataset.index_file
        self.vectors_file = dataset.vectors_file
        self.metadata_file = dataset.metadata_file
        self.index = None               # faiss.IndexFlatIP once built/loaded
        self.document_vectors = None    # np.ndarray of all document embeddings
        self.document_metadata = []     # index position -> script id
        self._setup_index()

    def _setup_index(self):
        """Load the cached index when all three cache files exist; else build and save."""
        if self.force_rebuild:
            print("强制重建向量数据库,正在清除缓存...")
            self.clear_cache()
        if (self.index_file.exists() and self.vectors_file.exists() and self.metadata_file.exists()):
            print("发现缓存文件,正在加载...")
            self._load_cached_index()
        else:
            print("未发现缓存文件,正在构建索引...")
            self._build_index()
            self._save_index()

    def _build_index(self):
        """Encode all documents in batches and build a flat inner-product index."""
        document_texts = self.dataset.get_all_document_texts()
        print(f"正在编码 {len(document_texts)} 个文档...")
        batch_size = 32
        all_embeddings = []
        for start in tqdm(range(0, len(document_texts), batch_size), desc="encoding documents"):
            batch_texts = document_texts[start:start + batch_size]
            all_embeddings.append(self.embedder.embed_document(batch_texts))
        self.document_vectors = np.vstack(all_embeddings)
        self.index = faiss.IndexFlatIP(self.embedder.embedding_dim)
        self.index.add(self.document_vectors)
        # Identity mapping: index position i holds script id i.
        self.document_metadata = list(range(len(document_texts)))

    def _load_cached_index(self):
        """Restore index, vectors and metadata from the three cache files."""
        self.index = faiss.read_index(str(self.index_file))
        self.document_vectors = np.load(str(self.vectors_file))
        with open(self.metadata_file, 'rb') as f:
            self.document_metadata = pickle.load(f)

    def _save_index(self):
        """Persist index, vectors and metadata to their cache files."""
        faiss.write_index(self.index, str(self.index_file))
        np.save(str(self.vectors_file), self.document_vectors)
        with open(self.metadata_file, 'wb') as f:
            pickle.dump(self.document_metadata, f)
        print(f"索引已保存到: {self.index_file.parent}")

    def retrieve(self, query: str) -> List[Tuple[int, float]]:
        """Recall the N scripts most similar to the query.

        Args:
            query: query text.
        Returns:
            List[Tuple[int, float]]: (script id, similarity score) pairs, best first.
        """
        query_vector = self.embedder.embed_query(query)
        if query_vector.ndim == 1:
            query_vector = query_vector.reshape(1, -1)
        top_k = min(self.N, len(self.document_metadata))
        distances, indices = self.index.search(query_vector, top_k)
        results = []
        for doc_idx, score in zip(indices[0], distances[0]):
            # FAISS pads missing results with -1; also guard against a stale
            # cache whose metadata is shorter than the index.
            if 0 <= doc_idx < len(self.document_metadata):
                results.append((self.document_metadata[doc_idx], float(score)))
        return results

    def get_top_scripts(self, query: str) -> List[Tuple[ObjectionHandleScript, float]]:
        """Recall the script objects most relevant to the query.

        Args:
            query: query text.
        Returns:
            List[Tuple[ObjectionHandleScript, float]]: (script object, similarity
            score) pairs, best first. The result count is bounded by `self.N`.
        """
        return [(self.dataset.get_script(script_id), score)
                for script_id, score in self.retrieve(query)]

    def clear_cache(self):
        """Delete any cache files and reset in-memory index state."""
        for file_path in [self.index_file, self.vectors_file, self.metadata_file]:
            if file_path.exists():
                file_path.unlink()
        self.index = None
        self.document_vectors = None
        self.document_metadata = []