Update inference.py
This commit is contained in:
111
inference.py
111
inference.py
@@ -1,11 +1,10 @@
|
||||
from model import TransClassifier
|
||||
from transformers import AutoTokenizer
|
||||
from data_process import extract_json_data, Formatter
|
||||
from data_process import extract_json_data, Formatter, load_data_from_dict
|
||||
import torch
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
@@ -14,14 +13,14 @@ valid_keys = [
|
||||
"Social_Shame", "Payer_Decision_Maker", "Hidden_Wealth_Proof", "Price_Sensitivity",
|
||||
"Sunk_Cost", "Compensatory_Spending", "Trust_Deficit", "Secret_Resistance", "Family_Sabotage",
|
||||
"Low_Self_Efficacy", "Attribution_Barrier", "Emotional_Trigger", "Ultimatum_Event", "Expectation_Bonus",
|
||||
"Competitor_Mindset", "Cognitive_Stage", "Follow_up_Priority", "Last_Interaction", "Referral_Potential"
|
||||
"Competitor_Mindset", "Cognitive_Stage", "Last_Interaction", "Referral_Potential"
|
||||
]
|
||||
ch_valid_keys = [
|
||||
"核心恐惧源", "疼痛阈值", "时间窗口压力", "无助指数",
|
||||
"社会羞耻感", "付款决策者", "隐藏财富证明", "价格敏感度",
|
||||
"沉没成本", "补偿性消费", "信任赤字", "秘密抵触情绪", "家庭破坏",
|
||||
"低自我效能感", "归因障碍", "情绪触发点", "最后通牒事件", "期望加成",
|
||||
"竞争者心态", "认知阶段", "跟进优先级", "最后互动时间", "推荐潜力"
|
||||
"竞争者心态", "认知阶段", "最后互动时间", "推荐潜力"
|
||||
]
|
||||
all_keys = valid_keys + ["session_id", "label"]
|
||||
en2ch = {en:ch for en, ch in zip(valid_keys, ch_valid_keys)}
|
||||
@@ -29,7 +28,7 @@ d1_keys = valid_keys[:5]
|
||||
d2_keys = valid_keys[5:10]
|
||||
d3_keys = valid_keys[10:15]
|
||||
d4_keys = valid_keys[15:19]
|
||||
d5_keys = valid_keys[19:23]
|
||||
d5_keys = valid_keys[19:22]
|
||||
|
||||
class InferenceEngine:
|
||||
def __init__(self, backbone_dir: str, ckpt_path: str = "best_ckpt.pth", device: str = "cuda"):
|
||||
@@ -42,7 +41,7 @@ class InferenceEngine:
|
||||
print(f"Tokenizer loaded from {backbone_dir}")
|
||||
|
||||
# 加载模型
|
||||
self.model = TransClassifier(backbone_dir, device)
|
||||
self.model = TransClassifier(backbone_dir, 2, device)
|
||||
self.model.to(device)
|
||||
if self.ckpt_path:
|
||||
self.model.load_state_dict(torch.load(ckpt_path, map_location=device))
|
||||
@@ -57,25 +56,17 @@ class InferenceEngine:
|
||||
def inference_batch(self, json_list: List[str]) -> dict:
|
||||
"""
|
||||
批量推理函数,输入为 JSON 字符串列表,输出为包含转换概率的字典列表。为防止OOM,列表最大长度为8。
|
||||
请注意Json文件中的词条数必须大于等于10.
|
||||
请注意Json文件中的词条数必须大于等于5.
|
||||
"""
|
||||
# print(111111)
|
||||
assert len(json_list) <= 10, "单次输入json文件数量不可超过8。"
|
||||
id2feature = extract_json_data(json_list)
|
||||
print(json.dumps(id2feature ,indent=2 ,ensure_ascii=False))
|
||||
# id2feature
|
||||
assert len(json_list) <= 8, "单次输入json文件数量不可超过8。"
|
||||
id2feature = extract_json_data(json_files=json_list, threshold=5)
|
||||
|
||||
message_list = []
|
||||
for id, feature in id2feature.items():
|
||||
messages = self.formatter.get_llm_prompt(feature)
|
||||
message_list.append(messages)
|
||||
|
||||
inputs = self.tokenizer.apply_chat_template(
|
||||
message_list,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False
|
||||
)
|
||||
inputs = self.tokenizer.apply_chat_template(message_list, tokenize=False, add_generation_prompt=True, enable_thinking=False)
|
||||
model_inputs = self.tokenizer(
|
||||
inputs,
|
||||
padding=True,
|
||||
@@ -87,21 +78,12 @@ class InferenceEngine:
|
||||
with torch.inference_mode():
|
||||
with torch.amp.autocast(device_type=self.device, dtype=torch.bfloat16):
|
||||
outputs = self.model(model_inputs)
|
||||
|
||||
# 1. 计算分类标签(argmax)
|
||||
preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
|
||||
|
||||
# 2. 计算softmax概率(核心修正:转CPU、转numpy、转列表,解决Tensor序列化问题)
|
||||
outputs_float = outputs.float() # 转换为 float32 避免精度问题
|
||||
probs = torch.softmax(outputs_float, dim=1) # probs: [B, 2]
|
||||
# 转换为CPU的numpy数组,再转列表(每个样本对应2个类别的概率)
|
||||
probs = probs.cpu().numpy().tolist()
|
||||
probs = [p[1] for p in probs] # 只保留类别1的概率
|
||||
|
||||
# 3. 计算置信度
|
||||
confidence = [abs(p - 0.5) * 2 for p in probs]
|
||||
# 返回格式:labels是每个样本的分类标签列表,probs是每个样本的类别概率列表,confidence是每个样本的置信度列表
|
||||
return {"labels": preds, "probs": probs, "confidence": confidence}
|
||||
preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
|
||||
outputs_float = outputs.float()
|
||||
probs = torch.softmax(outputs_float, dim=1) # probs: [B, 2]
|
||||
probs = probs.cpu().numpy().tolist()
|
||||
probs = [p[1] for p in probs]
|
||||
return {"labels": preds, "probs": probs}
|
||||
|
||||
def inference_sample(self, json_path: str) -> dict:
|
||||
"""
|
||||
@@ -109,24 +91,21 @@ class InferenceEngine:
|
||||
请注意Json文件中的词条数必须大于等于10.
|
||||
"""
|
||||
return self.inference_batch([json_path])
|
||||
|
||||
def inference(
|
||||
self,
|
||||
featurs : dict[str ,dict]
|
||||
):
|
||||
assert len(featurs) <= 10, "单次输入json文件数量不可超过8。"
|
||||
|
||||
def inference_batch_json_data(self, json_data: List[dict]) -> dict:
|
||||
"""
|
||||
批量推理函数,输入为 JSON 数据,输出为包含转换概率的字典列表。为防止OOM,列表最大长度为8。
|
||||
请注意Json文件中的词条数必须大于等于5. 但此处不进行过滤,请注意稍后对输出进行过滤。
|
||||
"""
|
||||
assert len(json_data) <= 8, "单次输入json数据数量不可超过8。"
|
||||
pseudo_id2feature = load_data_from_dict(json_data)
|
||||
|
||||
message_list = []
|
||||
for id, feature in featurs.items():
|
||||
for id, feature in pseudo_id2feature.items():
|
||||
messages = self.formatter.get_llm_prompt(feature)
|
||||
message_list.append(messages)
|
||||
|
||||
inputs = self.tokenizer.apply_chat_template(
|
||||
message_list,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False
|
||||
)
|
||||
|
||||
inputs = self.tokenizer.apply_chat_template(message_list, tokenize=False, add_generation_prompt=True, enable_thinking=False)
|
||||
model_inputs = self.tokenizer(
|
||||
inputs,
|
||||
padding=True,
|
||||
@@ -138,26 +117,30 @@ class InferenceEngine:
|
||||
with torch.inference_mode():
|
||||
with torch.amp.autocast(device_type=self.device, dtype=torch.bfloat16):
|
||||
outputs = self.model(model_inputs)
|
||||
|
||||
# 1. 计算分类标签(argmax)
|
||||
preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
|
||||
|
||||
# 2. 计算softmax概率(核心修正:转CPU、转numpy、转列表,解决Tensor序列化问题)
|
||||
outputs_float = outputs.float() # 转换为 float32 避免精度问题
|
||||
probs = torch.softmax(outputs_float, dim=1) # probs: [B, 2]
|
||||
# 转换为CPU的numpy数组,再转列表(每个样本对应2个类别的概率)
|
||||
probs = probs.cpu().numpy().tolist()
|
||||
probs = [p[1] for p in probs] # 只保留类别1的概率
|
||||
|
||||
# 3. 计算置信度
|
||||
confidence = [abs(p - 0.5) * 2 for p in probs]
|
||||
# 返回格式:labels是每个样本的分类标签列表,probs是每个样本的类别概率列表,confidence是每个样本的置信度列表
|
||||
return {"labels": preds, "probs": probs, "confidence": confidence}
|
||||
preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
|
||||
outputs_float = outputs.float()
|
||||
probs = torch.softmax(outputs_float, dim=1) # probs: [B, 2]
|
||||
probs = probs.cpu().numpy().tolist()
|
||||
probs = [p[1] for p in probs]
|
||||
return {"labels": preds, "probs": probs}
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 配置参数
|
||||
backbone_dir = "Qwen3-1.7B"
|
||||
ckpt_path = "best_ckpt.pth"
|
||||
device = "cuda"
|
||||
|
||||
engine = InferenceEngine(backbone_dir, ckpt_path, device)
|
||||
engine = InferenceEngine(backbone_dir, ckpt_path, device)
|
||||
import glob
|
||||
deal_files = glob.glob(os.path.join("filtered_deal", "*.json"))
|
||||
test_deal_files = deal_files[:4]
|
||||
not_deal_files = glob.glob(os.path.join("filtered_not_deal", "*.json"))
|
||||
test_not_deal_files = not_deal_files[:4]
|
||||
|
||||
test_files = test_deal_files + test_not_deal_files
|
||||
test_dict = []
|
||||
for test_file in test_files:
|
||||
with open(test_file, "r", encoding="utf-8") as f:
|
||||
json_data = json.load(f)
|
||||
test_dict.append(json_data)
|
||||
results = engine.inference_batch_json_data(test_dict)
|
||||
print(results)
|
||||
Reference in New Issue
Block a user