import torch
import torch.nn.functional as F
from typing import Dict, List, Tuple

from transformers import AutoModelForCausalLM, AutoTokenizer

from OHS.Script import ObjectionHandleScript


class Reranker:
    def __init__(self, model_dir: str, device):
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, padding_side='left', use_fast=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            dtype="bfloat16",
            attn_implementation="flash_attention_2"
        ).eval()
        self.model.to(device)

        # Token ids used to read out the binary "yes"/"no" relevance judgement.
        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
        self.max_length = 8192

        # Chat-template scaffolding wrapped around each (instruction, query, document) triple.
        self.prefix = (
            "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query "
            "and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n"
            "<|im_start|>user\n"
        )
        self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
        self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)

        # Task instructions (kept in Chinese, the deployment language).
        # 4document: "Given a customer objection, judge whether the objection-handling script can
        # resolve it, based on the problem the script targets and its target customer profile."
        self.task_description4document = "给定客户异议,根据话术的目标解决问题和目标客户特征,判断该异议处理话术是否能够解决该异议。"
        # 4script: "Given a customer objection and a related handling script, judge whether the
        # script can resolve the objection."
        self.task_description4script = "给定客户异议和相关处理话术,判断该异议处理话术是否能够解决该异议。"

    def format_instruction4document(self, query: str, document: str) -> str:
        return f"<Instruct>: {self.task_description4document}\n<Query>: {query}\n<Document>: {document}"

    def format_instruction4script(self, query: str, document: str) -> str:
        return f"<Instruct>: {self.task_description4script}\n<Query>: {query}\n<Document>: {document}"

    def _process_inputs(self, texts: List[str]) -> Dict[str, torch.Tensor]:
        # Tokenize without padding so the prefix/suffix tokens can be attached per sample,
        # then pad the batch and move it onto the model's device.
        inputs = self.tokenizer(
            texts,
            padding=False,
            truncation='longest_first',
            return_attention_mask=False,
            max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
        )
        for i, ele in enumerate(inputs['input_ids']):
            inputs['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens
        inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=self.max_length)
        for key in inputs:
            inputs[key] = inputs[key].to(self.model.device)
        return inputs

    @torch.no_grad()
    def _compute_scores(self, pairs: List[str]) -> List[float]:
        # Score = P("yes") from the last-token logits, normalized against P("no").
        inputs = self._process_inputs(pairs)
        batch_scores = self.model(**inputs).logits[:, -1, :]
        true_vector = batch_scores[:, self.token_true_id]
        false_vector = batch_scores[:, self.token_false_id]
        batch_scores = torch.stack([false_vector, true_vector], dim=1)
        batch_scores = F.log_softmax(batch_scores, dim=1)
        return batch_scores[:, 1].exp().tolist()

    def compute_scores4document(self, query: str, documents: List[str]) -> List[float]:
        pairs = [self.format_instruction4document(query, doc) for doc in documents]
        return self._compute_scores(pairs)

    def compute_scores4script(self, query: str, documents: List[str]) -> List[float]:
        pairs = [self.format_instruction4script(query, doc) for doc in documents]
        return self._compute_scores(pairs)

    def rerank(self, query: str, top_scripts: List[ObjectionHandleScript],
               rerank_N: int = 10) -> List[Tuple[ObjectionHandleScript, float]]:
        """Rerank recalled scripts against the query.

        Args:
            query: Query text (the customer objection).
            top_scripts: Recalled candidate scripts, each a script object.
            rerank_N: Number of top-scoring scripts to return.

        Returns:
            The top rerank_N scripts as (script object, rerank score) pairs,
            sorted by score in descending order.
        """
        if not top_scripts:
            return []

        # Both views are currently scored on the plain script text; they differ only in
        # the instruction used ("document" view vs. "script" view).
        scripts = [script.get_script() for script in top_scripts]
        documents = [script.get_script() for script in top_scripts]
        document_scores = self.compute_scores4document(query, documents)
        script_scores = self.compute_scores4script(query, scripts)

        # Weighted fusion of the two relevance views.
        results = []
        for script, doc_score, scr_score in zip(top_scripts, document_scores, script_scores):
            final_score = 0.6 * doc_score + 0.4 * scr_score
            results.append((script, final_score))

        results.sort(key=lambda x: x[1], reverse=True)
        return results[:rerank_N]
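

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the production pipeline. The checkpoint
# path "Qwen/Qwen3-Reranker-0.6B", the "cuda:0" device, the _DemoScript
# stand-in, and the sample Chinese query/scripts below are illustrative
# assumptions only; in the real pipeline the candidates are ObjectionHandleScript
# objects produced by the retrieval stage, and the model directory / device
# come from deployment config (flash_attention_2 requires a compatible GPU
# with flash-attn installed).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _DemoScript:
        """Stand-in exposing the get_script() interface that rerank() relies on."""

        def __init__(self, text: str) -> None:
            self._text = text

        def get_script(self) -> str:
            return self._text

    reranker = Reranker("Qwen/Qwen3-Reranker-0.6B", device="cuda:0")
    query = "你们的价格比别家贵。"  # objection: "Your price is higher than others'."
    candidates = [
        _DemoScript("我们的报价包含一年的上门售后服务,综合成本其实更低。"),
        _DemoScript("您可以先试用一个月,觉得值再继续合作。"),
    ]
    for script, score in reranker.rerank(query, candidates, rerank_N=2):
        print(f"{score:.4f}\t{script.get_script()}")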