Files
deal-classification/inference.py
2026-02-03 14:17:33 +08:00

119 lines
5.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from model import TransClassifier
from transformers import AutoTokenizer
from data_process import extract_json_data, Formatter
import torch
import json
from typing import Dict, List, Optional
import os
import random
import warnings
# Silence every warning process-wide at import time (as the original module does).
warnings.filterwarnings("ignore")

# English feature keys, in the canonical order used to split the five dimensions below.
valid_keys = [
    "Core_Fear_Source", "Pain_Threshold", "Time_Window_Pressure", "Helplessness_Index",
    "Social_Shame", "Payer_Decision_Maker", "Hidden_Wealth_Proof", "Price_Sensitivity",
    "Sunk_Cost", "Compensatory_Spending", "Trust_Deficit", "Secret_Resistance", "Family_Sabotage",
    "Low_Self_Efficacy", "Attribution_Barrier", "Emotional_Trigger", "Ultimatum_Event", "Expectation_Bonus",
    "Competitor_Mindset", "Cognitive_Stage", "Follow_up_Priority", "Last_Interaction", "Referral_Potential",
]

# Chinese display names, position-aligned with ``valid_keys``.
ch_valid_keys = [
    "核心恐惧源", "疼痛阈值", "时间窗口压力", "无助指数",
    "社会羞耻感", "付款决策者", "隐藏财富证明", "价格敏感度",
    "沉没成本", "补偿性消费", "信任赤字", "秘密抵触情绪", "家庭破坏",
    "低自我效能感", "归因障碍", "情绪触发点", "最后通牒事件", "期望加成",
    "竞争者心态", "认知阶段", "跟进优先级", "最后互动时间", "推荐潜力",
]

# Feature keys followed by the two bookkeeping fields.
all_keys = [*valid_keys, "session_id", "label"]

# English key -> Chinese display name (pairs matched by position).
en2ch = dict(zip(valid_keys, ch_valid_keys))

# Slice boundaries of the five feature dimensions within ``valid_keys``:
# [0:5], [5:10], [10:15], [15:19], [19:23].
_DIMENSION_BOUNDS = (0, 5, 10, 15, 19, 23)
d1_keys, d2_keys, d3_keys, d4_keys, d5_keys = (
    valid_keys[start:stop]
    for start, stop in zip(_DIMENSION_BOUNDS, _DIMENSION_BOUNDS[1:])
)
class InferenceEngine:
    """Batch inference wrapper around a fine-tuned ``TransClassifier``.

    Loads the tokenizer and classifier from ``backbone_dir``, optionally
    restores weights from ``ckpt_path``, renders extracted JSON features into
    chat prompts via ``Formatter``, and maps the model's logits to labels,
    class-1 probabilities and confidences.
    """

    # Maximum number of inputs per call; kept small to avoid GPU OOM.
    MAX_BATCH_SIZE = 8
    # Token budget per prompt; longer prompts are truncated by the tokenizer.
    MAX_LENGTH = 2048

    def __init__(self, backbone_dir: str, ckpt_path: str = "best_ckpt.pth", device: str = "cuda"):
        """Load the tokenizer, the model and (optionally) checkpoint weights.

        Args:
            backbone_dir: Backbone path/name passed to
                ``AutoTokenizer.from_pretrained`` and ``TransClassifier``.
            ckpt_path: State-dict checkpoint to load into the model; a falsy
                value keeps the untrained model (a warning is printed).
            device: Torch device string, e.g. ``"cuda"``, ``"cuda:0"`` or ``"cpu"``.
        """
        self.backbone_dir = backbone_dir
        self.ckpt_path = ckpt_path
        self.device = device

        # Load the tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
        print(f"Tokenizer loaded from {backbone_dir}")

        # Load the model.
        self.model = TransClassifier(backbone_dir, device)
        self.model.to(device)
        if self.ckpt_path:
            # The checkpoint is a plain state dict; weights_only=True restricts
            # unpickling to tensors/containers so a tampered file cannot run code.
            state_dict = torch.load(ckpt_path, map_location=device, weights_only=True)
            self.model.load_state_dict(state_dict)
            print(f"Model loaded from {ckpt_path}")
        else:
            print("Warning: No checkpoint path provided. Using untrained model.")
        self.model.eval()
        print("Inference engine initialized successfully.")
        self.formatter = Formatter(en2ch)

    def inference_batch(self, json_list: List[str], *, verbose: bool = False) -> dict:
        """Classify a batch of JSON inputs.

        Args:
            json_list: Inputs handed to ``extract_json_data``; at most
                ``MAX_BATCH_SIZE`` (8) items, to avoid OOM. Each JSON input is
                expected to contain at least 10 entries (not checked here).
            verbose: If True, print the extracted features before inference.

        Returns:
            A dict of three parallel lists with one entry per extracted sample,
            in ``extract_json_data``'s order: ``"labels"`` (argmax class
            index), ``"probs"`` (softmax probability of class 1) and
            ``"confidence"`` (``abs(p - 0.5) * 2``).

        Raises:
            ValueError: If ``json_list`` holds more than ``MAX_BATCH_SIZE`` items.
        """
        # Explicit check instead of ``assert`` so the guard survives ``python -O``.
        if len(json_list) > self.MAX_BATCH_SIZE:
            raise ValueError(f"单次输入json文件数量不可超过{self.MAX_BATCH_SIZE}。")
        if not json_list:
            return self._empty_result()

        id2feature = extract_json_data(json_list)
        if verbose:
            print(json.dumps(id2feature, indent=2, ensure_ascii=False))
        if not id2feature:
            # Nothing extracted: the chat template / tokenizer cannot handle an empty batch.
            return self._empty_result()

        model_inputs = self._tokenize(id2feature)
        with torch.inference_mode():
            # autocast accepts only a bare device type ("cuda"/"cpu"), not "cuda:0".
            with torch.amp.autocast(device_type=self._device_type(self.device), dtype=torch.bfloat16):
                logits = self.model(model_inputs)
        return self._logits_to_result(logits)

    def inference_sample(self, json_path: str, *, verbose: bool = False) -> dict:
        """Classify a single JSON input.

        Returns the same dict as ``inference_batch``, with one-element lists.
        The input is expected to contain at least 10 entries.

        NOTE(review): ``json_path`` is forwarded unchanged to
        ``extract_json_data``. The original docs call it a JSON path, while
        ``inference_batch`` describes its items as JSON strings; confirm
        which form ``extract_json_data`` expects.
        """
        return self.inference_batch([json_path], verbose=verbose)

    def _tokenize(self, id2feature: dict):
        """Render one chat prompt per sample and tokenize them as a padded batch on ``self.device``."""
        message_list = [self.formatter.get_llm_prompt(feature) for feature in id2feature.values()]
        prompts = self.tokenizer.apply_chat_template(
            message_list,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        # NOTE(review): truncation keeps MAX_LENGTH tokens. Which end is cut
        # depends on tokenizer.truncation_side; confirm that long prompts
        # keep the part the classifier relies on.
        return self.tokenizer(
            prompts,
            padding=True,
            truncation=True,
            max_length=self.MAX_LENGTH,
            return_tensors="pt",
        ).to(self.device)

    @staticmethod
    def _device_type(device) -> str:
        """Return the bare device type (e.g. ``"cuda"`` for ``"cuda:0"``) expected by autocast."""
        return torch.device(device).type

    @staticmethod
    def _empty_result() -> dict:
        """Result dict for a batch with no samples."""
        return {"labels": [], "probs": [], "confidence": []}

    @staticmethod
    def _logits_to_result(logits: torch.Tensor) -> dict:
        """Convert ``[B, 2]`` logits into plain-Python labels, class-1 probabilities and confidences."""
        # Upcast (e.g. from bfloat16) to float32 so softmax is computed at full
        # precision; .tolist() yields JSON-serializable Python ints/floats.
        logits = logits.float()
        labels = torch.argmax(logits, dim=1).cpu().tolist()
        probs = torch.softmax(logits, dim=1)[:, 1].cpu().tolist()
        # Distance of p from the 0.5 decision boundary, rescaled: 0 at p=0.5, 1 at p=0 or 1.
        confidence = [abs(p - 0.5) * 2 for p in probs]
        return {"labels": labels, "probs": probs, "confidence": confidence}
if __name__ == "__main__":
    # Configuration for running this module directly.
    backbone_dir = "Qwen3-1.7B"  # backbone passed to AutoTokenizer.from_pretrained / TransClassifier
    ckpt_path = "best_ckpt.pth"  # state-dict checkpoint loaded into the classifier
    device = "cuda"  # torch device string; also used to pick the autocast device
    engine = InferenceEngine(backbone_dir, ckpt_path, device)