Upload files to "OHS"
This commit is contained in:
110
OHS/Reranker.py
Normal file
110
OHS/Reranker.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from typing import List, Tuple, Dict, Any
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
from OHS.Script import *
|
||||||
|
|
||||||
|
class Reranker:
    """Re-rank objection-handling scripts with a yes/no causal-LM judge.

    Every (query, candidate) pair is rendered into a chat-style prompt that
    asks the model to answer "yes" or "no". The relevance score of a
    candidate is the probability of "yes" relative to "no", read from the
    logits at the last sequence position.
    """

    # Weights used by ``rerank`` to blend the metadata-based score and the
    # script-text-based score into one final score.
    DOCUMENT_WEIGHT = 0.6
    SCRIPT_WEIGHT = 0.4

    def __init__(self, model_dir: str, device):
        """Load the tokenizer and model and pre-tokenize the prompt frame.

        Args:
            model_dir: Path or identifier handed to ``from_pretrained``.
            device: Target device passed to ``model.to``.
        """
        # Left padding keeps the final real token of every sequence at
        # position -1, which is where the scoring code reads the logits.
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, padding_side='left', use_fast=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            dtype="bfloat16",
            attn_implementation="flash_attention_2"
        ).eval()

        self.model.to(device)

        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
        self.max_length = 8192

        self.prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
        self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
        self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)

        self.task_description4document = "给定客户异议,根据话术的目标解决问题和目标客户特征,判断该异议处理话术是否能够解决该异议。"
        self.task_description4script = "给定客户异议和相关处理话术,判断该异议处理话术是否能够解决该异议。"

    def format_instruction4document(self, query: str, document: str) -> str:
        """Render the prompt body that judges a script by its metadata text."""
        return f"<Instruct>: {self.task_description4document}\n<Query>: {query}\n<Document>: {document}"

    def format_instruction4script(self, query: str, document: str) -> str:
        """Render the prompt body that judges a script by its own wording."""
        return f"<Instruct>: {self.task_description4script}\n<Query>: {query}\n<Document>: {document}"

    def _process_inputs(self, texts: List[str]) -> Dict[str, torch.Tensor]:
        """Tokenize prompt bodies, wrap them in the chat frame, pad, and move to the model device.

        Args:
            texts: Prompt bodies produced by the ``format_instruction4*`` methods.

        Returns:
            The padded batch returned by ``tokenizer.pad``, with every entry
            moved to ``self.model.device``.
        """
        # Truncate the body only, leaving room for the fixed prefix and
        # suffix so the complete prompt never exceeds ``max_length``.
        inputs = self.tokenizer(
            texts,
            padding=False,
            truncation='longest_first',
            return_attention_mask=False,
            max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
        )

        for i, ele in enumerate(inputs['input_ids']):
            inputs['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens

        inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=self.max_length)

        for key in inputs:
            inputs[key] = inputs[key].to(self.model.device)

        return inputs

    @torch.no_grad()
    def _score_pairs(self, pairs: List[str]) -> List[float]:
        """Return the "yes" probability for each already-formatted prompt body."""
        if not pairs:
            # Nothing to score; skip tokenization and the forward pass.
            return []

        inputs = self._process_inputs(pairs)
        last_logits = self.model(**inputs).logits[:, -1, :]

        true_vector = last_logits[:, self.token_true_id]
        false_vector = last_logits[:, self.token_false_id]
        # Normalize over just the two answer tokens; column 1 is "yes".
        pair_logits = torch.stack([false_vector, true_vector], dim=1)
        log_probs = F.log_softmax(pair_logits, dim=1)
        return log_probs[:, 1].exp().tolist()

    def compute_scores4document(self, query: str, documents: List[str]) -> List[float]:
        """Score metadata texts against ``query``; one value in [0, 1] per document."""
        pairs = [self.format_instruction4document(query, doc) for doc in documents]
        return self._score_pairs(pairs)

    def compute_scores4script(self, query: str, documents: List[str]) -> List[float]:
        """Score script texts against ``query``; one value in [0, 1] per script."""
        pairs = [self.format_instruction4script(query, doc) for doc in documents]
        return self._score_pairs(pairs)

    def rerank(self, query: str, top_scripts: List["ObjectionHandleScript"], rerank_N: int = 10) -> List[Tuple["ObjectionHandleScript", float]]:
        """Re-rank recalled scripts for a query.

        Args:
            query: Query text (the customer objection).
            top_scripts: Recalled script objects to re-rank.
            rerank_N: Maximum number of results to return.

        Returns:
            Up to ``rerank_N`` ``(script object, final score)`` tuples sorted
            by descending score. The final score is
            ``DOCUMENT_WEIGHT * document score + SCRIPT_WEIGHT * script score``.
        """
        if not top_scripts:
            return []

        scripts = [script.get_script() for script in top_scripts]
        # The document prompt judges by target problem / target customer, so
        # it must receive the metadata text rather than the script wording.
        documents = [script.get_document_text() for script in top_scripts]

        document_scores = self.compute_scores4document(query, documents)
        script_scores = self.compute_scores4script(query, scripts)

        results = [
            (script, self.DOCUMENT_WEIGHT * doc_score + self.SCRIPT_WEIGHT * scr_score)
            for script, doc_score, scr_score in zip(top_scripts, document_scores, script_scores)
        ]

        results.sort(key=lambda item: item[1], reverse=True)
        return results[:rerank_N]
|
||||||
77
OHS/Script.py
Normal file
77
OHS/Script.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
from typing import List, Dict, Tuple, Any, Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
@dataclass
class ObjectionHandleScript:
    """One objection-handling script together with its targeting metadata.

    The fields are declared as real dataclass fields so that the generated
    ``__eq__`` and ``__repr__`` reflect the instance contents. The generated
    constructor takes ``(id, script, target_problem, target_customer)``,
    positionally or by keyword.
    """

    id: int  # position of the script in its source collection
    script: str  # the script wording itself
    target_problem: str  # the problem the script is meant to resolve
    target_customer: str  # the customer profile the script is aimed at

    def get_id(self) -> int:
        """Return the script id."""
        return self.id

    def get_script(self) -> str:
        """Return the script wording."""
        return self.script

    def get_target_problem(self) -> str:
        """Return the problem this script is meant to resolve."""
        return self.target_problem

    def get_target_customer(self) -> str:
        """Return the target customer profile.

        Used for re-ranking in the Reranker stage.
        """
        return self.target_customer

    def get_document_text(self) -> str:
        """Return the metadata text built from target problem and target customer.

        Used for recall in the Embedding stage and for re-ranking in the
        Reranker stage.
        """
        return f"目标解决问题: {self.target_problem}\n目标客户特征: {self.target_customer}"
|
||||||
|
|
||||||
|
class ScriptDataset:
    """In-memory collection of objection-handling scripts loaded from a JSON file."""

    def __init__(self, json_file: str):
        """Load every script from ``json_file`` and derive the sidecar file paths.

        Args:
            json_file: Path to a UTF-8 JSON file whose top-level value is a
                list of objects with the keys ``话术本身``, ``目标解决问题``
                and ``目标客户特征``.

        Raises:
            ValueError: If the top-level JSON value is not a list.
            KeyError: If an entry lacks one of the required keys.
        """
        self.json_file = Path(json_file)

        with self.json_file.open('r', encoding='utf-8') as f:
            raw_data = json.load(f)

        # Explicit check instead of ``assert`` so the validation still runs
        # when Python is started with -O.
        if not isinstance(raw_data, list):
            raise ValueError("JSON file should contain a list of objects")

        # The id of each script is its position in the JSON list.
        self.scripts = [
            ObjectionHandleScript(
                id=idx,
                script=entry['话术本身'],
                target_problem=entry['目标解决问题'],
                target_customer=entry['目标客户特征'],
            )
            for idx, entry in enumerate(raw_data)
        ]

        # Sidecar files share the JSON file's directory and stem.
        self.index_file = self.json_file.with_suffix('.faiss')
        self.vectors_file = self.json_file.with_suffix('.vectors.npy')
        self.metadata_file = self.json_file.with_suffix('.metadata.pkl')
        # Left unset here; presumably populated by indexing code elsewhere.
        self.faiss_index = None
        self.document_vectors = None

    def get_script(self, idx) -> "ObjectionHandleScript":
        """Return the script stored at position ``idx``."""
        return self[idx]

    def __len__(self):
        """Return the number of loaded scripts."""
        return len(self.scripts)

    def __getitem__(self, idx):
        """Return the script(s) selected by ``idx`` from the underlying list."""
        return self.scripts[idx]

    def get_all_document_texts(self) -> List[str]:
        """Return the metadata text of every script, in dataset order."""
        return [script.get_document_text() for script in self.scripts]

    def get_all_script_texts(self) -> List[str]:
        """Return the wording of every script, in dataset order."""
        return [script.get_script() for script in self.scripts]
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual smoke run: load the script collection, resolved relative to the
    # current working directory.
    json_path = "scripts_deduplicated.json"
    dataset = ScriptDataset(json_path)
|
||||||
BIN
OHS/scripts_deduplicated.faiss
Normal file
BIN
OHS/scripts_deduplicated.faiss
Normal file
Binary file not shown.
3252
OHS/scripts_deduplicated.json
Normal file
3252
OHS/scripts_deduplicated.json
Normal file
File diff suppressed because one or more lines are too long
BIN
OHS/scripts_deduplicated.metadata.pkl
Normal file
BIN
OHS/scripts_deduplicated.metadata.pkl
Normal file
Binary file not shown.
BIN
OHS/scripts_deduplicated.vectors.npy
Normal file
BIN
OHS/scripts_deduplicated.vectors.npy
Normal file
Binary file not shown.
Reference in New Issue
Block a user