Files
deal-classification/inference.py
2026-01-29 19:03:32 +08:00

201 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from model import TransClassifier
from transformers import AutoTokenizer
from data_process import extract_json_data, Formatter
import torch
import json
from typing import Dict, List, Optional
import os
import random
import warnings
# Silence every Python warning process-wide from this point on (affects warnings
# issued by any module, not only this one).
warnings.filterwarnings("ignore")

# English feature keys expected in each record. Order matters: the dimension
# groups below are positional slices of this list.
valid_keys = [
    # dimension 1
    "Core_Fear_Source", "Pain_Threshold", "Time_Window_Pressure", "Helplessness_Index", "Social_Shame",
    # dimension 2
    "Payer_Decision_Maker", "Hidden_Wealth_Proof", "Price_Sensitivity", "Sunk_Cost", "Compensatory_Spending",
    # dimension 3
    "Trust_Deficit", "Secret_Resistance", "Family_Sabotage", "Low_Self_Efficacy", "Attribution_Barrier",
    # dimension 4
    "Emotional_Trigger", "Ultimatum_Event", "Expectation_Bonus", "Competitor_Mindset",
    # dimension 5
    "Cognitive_Stage", "Follow_up_Priority", "Last_Interaction", "Referral_Potential",
]

# Chinese display names, aligned index-by-index with ``valid_keys``.
ch_valid_keys = [
    "核心恐惧源", "疼痛阈值", "时间窗口压力", "无助指数", "社会羞耻感",
    "付款决策者", "隐藏财富证明", "价格敏感度", "沉没成本", "补偿性消费",
    "信任赤字", "秘密抵触情绪", "家庭破坏", "低自我效能感", "归因障碍",
    "情绪触发点", "最后通牒事件", "期望加成", "竞争者心态",
    "认知阶段", "跟进优先级", "最后互动时间", "推荐潜力",
]

# Feature keys followed by the two extra record fields.
all_keys = [*valid_keys, "session_id", "label"]

# English key -> Chinese display name (passed to ``Formatter`` by the engine).
en2ch = dict(zip(valid_keys, ch_valid_keys))

# Positional dimension groups of 5 / 5 / 5 / 4 / 4 keys.
d1_keys, d2_keys, d3_keys, d4_keys, d5_keys = (
    valid_keys[0:5],
    valid_keys[5:10],
    valid_keys[10:15],
    valid_keys[15:19],
    valid_keys[19:23],
)
class InferenceEngine:
    """Batched inference wrapper around a ``TransClassifier`` checkpoint.

    Loads a tokenizer and classifier from ``backbone_dir``, optionally restores
    weights from ``ckpt_path``, and turns extracted JSON feature records into
    chat-formatted prompts via ``Formatter(en2ch)`` before scoring them.
    """

    # Maximum number of inputs per ``inference_batch`` call (limits memory use / OOM).
    MAX_BATCH_SIZE = 8

    def __init__(self, backbone_dir: str, ckpt_path: str = "best_ckpt.pth", device: str = "cuda"):
        """Load the tokenizer, the model and the prompt formatter.

        Args:
            backbone_dir: Directory of the pretrained backbone; used for both the
                tokenizer and ``TransClassifier``.
            ckpt_path: Path of a saved ``state_dict``. A falsy value (``""``/``None``)
                skips loading and keeps the freshly constructed weights.
            device: Torch device string, e.g. ``"cuda"``, ``"cuda:0"`` or ``"cpu"``.
        """
        self.backbone_dir = backbone_dir
        self.ckpt_path = ckpt_path
        self.device = device

        # Load the tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
        print(f"Tokenizer loaded from {backbone_dir}")

        # Build the model and move it to the target device.
        self.model = TransClassifier(backbone_dir, device)
        self.model.to(device)
        if self.ckpt_path:
            # NOTE: torch.load unpickles the file; only load checkpoints from trusted
            # sources (consider weights_only=True on torch versions that support it).
            self.model.load_state_dict(torch.load(ckpt_path, map_location=device))
            print(f"Model loaded from {ckpt_path}")
        else:
            print("Warning: No checkpoint path provided. Using untrained model.")
        self.model.eval()

        # Build the formatter before announcing success so the message is truthful.
        self.formatter = Formatter(en2ch)
        print("Inference engine initialized successfully.")

    def inference_batch(self, json_list: List[str]) -> dict:
        """Classify a batch of JSON inputs.

        Args:
            json_list: Inputs handed to ``extract_json_data`` (the ``__main__`` demo
                passes JSON file paths). At most ``MAX_BATCH_SIZE`` (8) items are
                allowed to avoid OOM. Each JSON record must contain at least 10 entries.

        Returns:
            ``{"labels": [...], "probs": [...]}``: ``labels[i]`` is the argmax class
            index and ``probs[i]`` the softmax class probabilities (plain Python
            floats, JSON-serialisable) for the i-th extracted record. Records follow
            the iteration order of the dict returned by ``extract_json_data`` —
            presumably the input order; TODO confirm. An empty input, or one from
            which no records are extracted, yields empty lists.

        Raises:
            AssertionError: If more than ``MAX_BATCH_SIZE`` inputs are given.
        """
        if len(json_list) > self.MAX_BATCH_SIZE:
            # Raised explicitly rather than via `assert` so the guard survives
            # `python -O`, while keeping the exception type callers already expect.
            raise AssertionError(f"单次输入json文件数量不可超过{self.MAX_BATCH_SIZE}。")
        if not json_list:
            return {"labels": [], "probs": []}

        id2feature = extract_json_data(json_list)
        message_list = [self.formatter.get_llm_prompt(feature) for feature in id2feature.values()]
        if not message_list:
            # Nothing usable was extracted; the tokenizer cannot handle an empty batch.
            return {"labels": [], "probs": []}

        # Render chat-formatted prompt strings (not yet tokenized), thinking mode off.
        prompts = self.tokenizer.apply_chat_template(
            message_list,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        # NOTE(review): truncation keeps only the first 2048 tokens, so an over-long
        # prompt loses its tail (including the generation prompt) — confirm intended.
        model_inputs = self.tokenizer(
            prompts,
            padding=True,
            truncation=True,
            max_length=2048,
            return_tensors="pt",
        ).to(self.device)

        # autocast wants a device *type* ("cuda"/"cpu"), not an indexed device
        # string such as "cuda:0", so normalise it first.
        device_type = torch.device(self.device).type
        with torch.inference_mode():
            with torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):
                outputs = self.model(model_inputs)  # logits, expected shape [B, num_classes]

            # 1. Predicted class index per sample.
            preds = torch.argmax(outputs, dim=1).cpu().tolist()
            # 2. Softmax in float32 (logits may be bf16 under autocast), then move to
            #    CPU and convert to nested Python lists for serialisation.
            probs = torch.softmax(outputs.float(), dim=1).cpu().tolist()

        return {"labels": preds, "probs": probs}

    def inference_sample(self, json_path: str) -> dict:
        """Classify a single JSON input.

        Thin wrapper around ``inference_batch([json_path])``; the result has the same
        ``{"labels": [...], "probs": [...]}`` shape, with one entry in each list.
        The JSON record must contain at least 10 entries.
        """
        return self.inference_batch([json_path])
def filter_json_files_by_key_count(files: List[str], min_keys: int = 10) -> List[str]:
    """Return the paths whose top-level JSON object has at least *min_keys* keys.

    Files that cannot be opened or parsed, whose top level is not a JSON object,
    or that have too few keys are skipped with a printed message.

    Args:
        files: JSON file paths.
        min_keys: Minimum number of top-level keys required (default 10).

    Returns:
        The qualifying paths, in input order.
    """
    valid_files = []
    for file_path in files:
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSON and Unicode decode errors
            print(f"读取文件 {file_path} 时出错: {e}")
            continue
        # Check the type separately so non-objects are not reported as "too few keys".
        if not isinstance(data, dict):
            print(f"跳过文件 {os.path.basename(file_path)}: 顶层不是JSON对象")
        elif len(data) < min_keys:
            print(f"跳过文件 {os.path.basename(file_path)}: 键数量不足 ({len(data)} < {min_keys})")
        else:
            valid_files.append(file_path)
    return valid_files


def select_balanced_files(
    deal_files: List[str],
    not_deal_files: List[str],
    num_samples: int = 8,
    rng: Optional[random.Random] = None,
) -> "tuple[List[str], List[str]]":
    """Randomly pick up to *num_samples* files, split as evenly as possible.

    Each class contributes ``num_samples // 2`` files when it has enough; any
    shortfall in one class is made up from the other. The total never exceeds
    *num_samples* (``inference_batch`` rejects batches larger than 8).

    Args:
        deal_files: Candidate "deal" paths.
        not_deal_files: Candidate "not_deal" paths.
        num_samples: Target total number of files.
        rng: Source of randomness (defaults to the ``random`` module).

    Returns:
        ``(selected_deal_files, selected_not_deal_files)``.
    """
    rng = random if rng is None else rng
    num_deal = min(num_samples // 2, len(deal_files))
    num_not_deal = min(num_samples - num_deal, len(not_deal_files))
    # Hand any slots the "not_deal" class could not fill back to the "deal" class.
    num_deal = min(num_samples - num_not_deal, len(deal_files))
    return rng.sample(deal_files, num_deal), rng.sample(not_deal_files, num_not_deal)


if __name__ == "__main__":
    import traceback

    from data_process import extract_json_files

    # Configuration.
    backbone_dir = "Qwen3-1.7B"
    ckpt_path = "best_ckpt.pth"
    device = "cuda"
    engine = InferenceEngine(backbone_dir, ckpt_path, device)

    # Collect JSON paths from the "deal" and "not_deal" sources.
    deal_files = extract_json_files("deal")
    not_deal_files = extract_json_files("not_deal")

    # Keep only files meeting inference_batch's ">= 10 entries" requirement.
    deal_files_filtered = filter_json_files_by_key_count(deal_files, min_keys=10)
    not_deal_files_filtered = filter_json_files_by_key_count(not_deal_files, min_keys=10)

    num_samples = 8
    selected_deal_files, selected_not_deal_files = select_balanced_files(
        deal_files_filtered, not_deal_files_filtered, num_samples
    )
    # Files rejected by the filter are deliberately NOT used to top up the batch:
    # they violate the model's input requirement.
    selected_files = selected_deal_files + selected_not_deal_files

    # Ground truth comes from the source list rather than substring-matching the
    # path, which would mislabel e.g. ".../not_deal_backup/deal/x.json".
    true_labels = ["成交"] * len(selected_deal_files) + ["未成交"] * len(selected_not_deal_files)

    # Run batched inference through the public inference_batch API.
    if selected_files:
        print("\n开始批量推理...")
        try:
            batch_result = engine.inference_batch(selected_files)
        except Exception as e:
            # Top-level boundary of a smoke test: report with traceback and carry on.
            print(f"推理过程中出错: {e}")
            traceback.print_exc()
        else:
            print(batch_result)
            print(true_labels)
    else:
        print("未找到符合条件的文件进行推理")
    print("\n推理端口测试完成!")