# Files
# deal-classification/inference.py
#
# 152 lines
# 6.5 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters.
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from model import TransClassifier
from transformers import AutoTokenizer
from data_process import extract_json_data, Formatter, load_data_from_dict
import torch
import json
from typing import Dict, List, Optional
import os
import warnings
# NOTE(review): this silences every warning process-wide as soon as the module
# is imported, not just warnings raised by this file.
warnings.filterwarnings("ignore")

# English feature keys, in canonical order.
valid_keys = [
    "Core_Fear_Source", "Pain_Threshold", "Time_Window_Pressure", "Helplessness_Index",
    "Social_Shame", "Payer_Decision_Maker", "Hidden_Wealth_Proof", "Price_Sensitivity",
    "Sunk_Cost", "Compensatory_Spending", "Trust_Deficit", "Secret_Resistance", "Family_Sabotage",
    "Low_Self_Efficacy", "Attribution_Barrier", "Emotional_Trigger", "Ultimatum_Event", "Expectation_Bonus",
    "Competitor_Mindset", "Cognitive_Stage", "Last_Interaction", "Referral_Potential",
]

# Chinese names for the same features, index-aligned with ``valid_keys``.
ch_valid_keys = [
    "核心恐惧源", "疼痛阈值", "时间窗口压力", "无助指数",
    "社会羞耻感", "付款决策者", "隐藏财富证明", "价格敏感度",
    "沉没成本", "补偿性消费", "信任赤字", "秘密抵触情绪", "家庭破坏",
    "低自我效能感", "归因障碍", "情绪触发点", "最后通牒事件", "期望加成",
    "竞争者心态", "认知阶段", "最后互动时间", "推荐潜力",
]

# Feature keys plus the two bookkeeping fields.
all_keys = [*valid_keys, "session_id", "label"]

# English key -> Chinese name lookup (pairs matched by position).
en2ch = dict(zip(valid_keys, ch_valid_keys))

# Consecutive slices of ``valid_keys`` (sizes 5, 5, 5, 4, 3); presumably one
# group per analysis dimension -- confirm with the prompt formatter.
d1_keys, d2_keys, d3_keys, d4_keys, d5_keys = (
    valid_keys[lo:hi]
    for lo, hi in ((0, 5), (5, 10), (10, 15), (15, 19), (19, 22))
)
class InferenceEngine:
    """Batched inference wrapper around a ``TransClassifier`` checkpoint.

    The engine loads a tokenizer and a classifier from ``backbone_dir``,
    optionally restores weights from ``ckpt_path``, and scores per-session
    feature data. Each session is rendered into a chat-template prompt by
    ``Formatter`` and then classified.
    """

    # Maximum number of inputs per call, to avoid running out of memory.
    MAX_BATCH_SIZE = 8

    def __init__(self, backbone_dir: str, ckpt_path: Optional[str] = "best_ckpt.pth", device: str = "cuda"):
        """Load the tokenizer, the classifier and the prompt formatter.

        Args:
            backbone_dir: Model directory used for both ``AutoTokenizer`` and
                the ``TransClassifier`` backbone.
            ckpt_path: Path to a ``state_dict`` file for the classifier. A
                falsy value (``None`` / ``""``) keeps the freshly constructed,
                untrained model.
            device: Torch device string, e.g. ``"cuda"``, ``"cuda:0"`` or
                ``"cpu"``.
        """
        self.backbone_dir = backbone_dir
        self.ckpt_path = ckpt_path
        self.device = device
        # Load the tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
        print(f"Tokenizer loaded from {backbone_dir}")
        # Load the model. The literal 2 is presumably the number of output
        # classes -- TODO confirm against TransClassifier.
        self.model = TransClassifier(backbone_dir, 2, device)
        self.model.to(device)
        if self.ckpt_path:
            # NOTE(review): torch.load unpickles the file; only load trusted
            # checkpoints (consider weights_only=True on supported torch versions).
            self.model.load_state_dict(torch.load(ckpt_path, map_location=device))
            print(f"Model loaded from {ckpt_path}")
        else:
            print("Warning: No checkpoint path provided. Using untrained model.")
        self.model.eval()
        print("Inference engine initialized successfully.")
        self.formatter = Formatter(en2ch)

    def _check_batch_size(self, items, message: str) -> None:
        """Raise ``AssertionError(message)`` if *items* exceeds ``MAX_BATCH_SIZE``.

        The guard raises explicitly instead of using ``assert``, so it stays
        active under ``python -O``. It keeps the ``AssertionError`` type that
        callers may already catch.
        """
        if len(items) > self.MAX_BATCH_SIZE:
            raise AssertionError(message)

    def _autocast_device_type(self) -> str:
        """Return the bare device type (``"cuda"``, ``"cpu"``, ...) of ``self.device``.

        ``torch.amp.autocast`` expects a device *type*. It rejects indexed
        strings such as ``"cuda:0"``, so the index is dropped here.
        """
        return torch.device(self.device).type

    def _predict(self, id2feature: Dict) -> dict:
        """Score already-loaded features. Shared by the public batch methods.

        Args:
            id2feature: Mapping from a (pseudo) session id to that session's
                features, as returned by ``extract_json_data`` /
                ``load_data_from_dict``. Only the values are used, in
                iteration order.

        Returns:
            ``{"labels": [...], "probs": [...]}``. ``labels`` holds the argmax
            class index and ``probs`` the softmax probability of class 1, one
            entry per item of ``id2feature`` in the same order. Both lists are
            empty when there is nothing to score.
        """
        if not id2feature:
            # Nothing to score (empty input or everything filtered out).
            # Tokenizing an empty batch would fail.
            return {"labels": [], "probs": []}

        message_list = [self.formatter.get_llm_prompt(feature) for feature in id2feature.values()]
        prompts = self.tokenizer.apply_chat_template(
            message_list, tokenize=False, add_generation_prompt=True, enable_thinking=False
        )
        # NOTE(review): right-side truncation at 2048 tokens can cut the end of
        # long prompts; confirm this matches how TransClassifier pools tokens.
        model_inputs = self.tokenizer(
            prompts,
            padding=True,
            truncation=True,
            max_length=2048,
            return_tensors="pt",
        ).to(self.device)

        with torch.inference_mode():
            with torch.amp.autocast(device_type=self._autocast_device_type(), dtype=torch.bfloat16):
                outputs = self.model(model_inputs)
            preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
            # Upcast from bfloat16 before softmax for a numerically stable result.
            probs = torch.softmax(outputs.float(), dim=1)  # probs: [B, 2]
            positive_probs = [row[1] for row in probs.cpu().numpy().tolist()]
        return {"labels": preds, "probs": positive_probs}

    def inference_batch(self, json_list: List[str]) -> dict:
        """Run batched inference on JSON inputs loaded via ``extract_json_data``.

        Args:
            json_list: At most ``MAX_BATCH_SIZE`` (8) inputs, passed as
                ``extract_json_data(json_files=...)``. These are presumably
                JSON file paths -- confirm against data_process.

        Returns:
            ``{"labels": [...], "probs": [...]}`` (see ``_predict``).
            ``extract_json_data`` is called with ``threshold=5``, so inputs
            with fewer than 5 entries are presumably dropped. In that case the
            lists are shorter than ``json_list``.

        Raises:
            AssertionError: if more than ``MAX_BATCH_SIZE`` inputs are given.
        """
        self._check_batch_size(json_list, f"单次输入json文件数量不可超过{self.MAX_BATCH_SIZE}。")
        id2feature = extract_json_data(json_files=json_list, threshold=5)
        return self._predict(id2feature)

    def inference_sample(self, json_path: str) -> dict:
        """Run inference on a single JSON input.

        Equivalent to ``inference_batch([json_path])``. The result lists hold
        at most one element. The input must have at least 5 entries (the
        threshold used by ``inference_batch``).
        """
        return self.inference_batch([json_path])

    def inference_batch_json_data(self, json_data: List[dict]) -> dict:
        """Run batched inference on already-parsed JSON objects.

        Args:
            json_data: At most ``MAX_BATCH_SIZE`` (8) parsed JSON objects,
                converted by ``load_data_from_dict``.

        Returns:
            ``{"labels": [...], "probs": [...]}`` (see ``_predict``). Unlike
            ``inference_batch``, no minimum-entry filtering is applied here.
            Callers should discard results for inputs with fewer than 5 entries.

        Raises:
            AssertionError: if more than ``MAX_BATCH_SIZE`` inputs are given.
        """
        self._check_batch_size(json_data, f"单次输入json数据数量不可超过{self.MAX_BATCH_SIZE}。")
        pseudo_id2feature = load_data_from_dict(json_data)
        return self._predict(pseudo_id2feature)
if __name__ == "__main__":
    # Smoke test: score a few "deal" and "not deal" sessions in one batch.
    import glob

    backbone_dir = "Qwen3-1.7B"
    ckpt_path = "best_ckpt.pth"
    device = "cuda"

    # glob order depends on the filesystem; sort so every run picks the same files.
    deal_files = sorted(glob.glob(os.path.join("filtered_deal", "*.json")))
    test_deal_files = deal_files[:4]
    not_deal_files = sorted(glob.glob(os.path.join("filtered_not_deal", "*.json")))
    test_not_deal_files = not_deal_files[:4]
    test_files = test_deal_files + test_not_deal_files

    # Parsed JSON objects, one per selected file.
    samples = []
    for test_file in test_files:
        with open(test_file, "r", encoding="utf-8") as f:
            samples.append(json.load(f))

    if not samples:
        # Avoid loading the model (and crashing on an empty batch) when no
        # input files are present.
        print("No JSON files found in filtered_deal/ or filtered_not_deal/.")
    else:
        engine = InferenceEngine(backbone_dir, ckpt_path, device)
        results = engine.inference_batch_json_data(samples)
        print(results)