Update inference.py
inference.py
@@ -1,249 +1,119 @@
 from model import TransClassifier
 from transformers import AutoTokenizer
 from data_process import extract_json_data, Formatter
 import torch
 import json
 from typing import Dict, List, Optional
 import os
 import random
 import warnings
 warnings.filterwarnings("ignore")
 
 
 valid_keys = [
     "Core_Fear_Source", "Pain_Threshold", "Time_Window_Pressure", "Helplessness_Index",
     "Social_Shame", "Payer_Decision_Maker", "Hidden_Wealth_Proof", "Price_Sensitivity",
     "Sunk_Cost", "Compensatory_Spending", "Trust_Deficit", "Secret_Resistance", "Family_Sabotage",
     "Low_Self_Efficacy", "Attribution_Barrier", "Emotional_Trigger", "Ultimatum_Event", "Expectation_Bonus",
     "Competitor_Mindset", "Cognitive_Stage", "Follow_up_Priority", "Last_Interaction", "Referral_Potential"
 ]
 ch_valid_keys = [
     "核心恐惧源", "疼痛阈值", "时间窗口压力", "无助指数",
     "社会羞耻感", "付款决策者", "隐藏财富证明", "价格敏感度",
     "沉没成本", "补偿性消费", "信任赤字", "秘密抵触情绪", "家庭破坏",
     "低自我效能感", "归因障碍", "情绪触发点", "最后通牒事件", "期望加成",
     "竞争者心态", "认知阶段", "跟进优先级", "最后互动时间", "推荐潜力"
 ]
 all_keys = valid_keys + ["session_id", "label"]
 en2ch = {en: ch for en, ch in zip(valid_keys, ch_valid_keys)}
 d1_keys = valid_keys[:5]
 d2_keys = valid_keys[5:10]
 d3_keys = valid_keys[10:15]
 d4_keys = valid_keys[15:19]
 d5_keys = valid_keys[19:23]
 
 
 class InferenceEngine:
     def __init__(self, backbone_dir: str, ckpt_path: str = "best_ckpt.pth", device: str = "cuda"):
         self.backbone_dir = backbone_dir
         self.ckpt_path = ckpt_path
         self.device = device
 
         # Load the tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
         print(f"Tokenizer loaded from {backbone_dir}")
 
         # Load the model
         self.model = TransClassifier(backbone_dir, device)
         self.model.to(device)
         if self.ckpt_path:
             self.model.load_state_dict(torch.load(ckpt_path, map_location=device))
             print(f"Model loaded from {ckpt_path}")
         else:
             print("Warning: No checkpoint path provided. Using untrained model.")
         self.model.eval()
         print("Inference engine initialized successfully.")
 
         self.formatter = Formatter(en2ch)
 
     def inference_batch(self, json_list: List[str]) -> dict:
         """
         Batch inference: takes a list of JSON file paths and returns a dict with the
         conversion probability for each sample. To avoid OOM, at most 8 files per call.
         Note: each JSON file must contain at least 10 entries.
         """
+        # print(111111)
         assert len(json_list) <= 8, "At most 8 JSON files per call."
         id2feature = extract_json_data(json_list)
-        # print(id2feature)  # id2feature
+        print(json.dumps(id2feature, indent=2, ensure_ascii=False))  # id2feature
 
         message_list = []
         for id, feature in id2feature.items():
             messages = self.formatter.get_llm_prompt(feature)
             message_list.append(messages)
 
         inputs = self.tokenizer.apply_chat_template(
             message_list,
             tokenize=False,
             add_generation_prompt=True,
             enable_thinking=False
         )
         model_inputs = self.tokenizer(
             inputs,
             padding=True,
             truncation=True,
             max_length=2048,
             return_tensors="pt"
         ).to(self.device)
 
         with torch.inference_mode():
             with torch.amp.autocast(device_type=self.device, dtype=torch.bfloat16):
                 outputs = self.model(model_inputs)
 
         # 1. Classification labels (argmax)
         preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
 
         # 2. Softmax probabilities (key fix: move to CPU, then numpy, then list, so the tensors serialize)
         outputs_float = outputs.float()  # cast to float32 to avoid precision issues
         probs = torch.softmax(outputs_float, dim=1)  # probs: [B, 2]
         # Convert to a CPU numpy array, then to a list (two class probabilities per sample)
         probs = probs.cpu().numpy().tolist()
+        probs = [p[1] for p in probs]  # keep only the probability of class 1
 
-        # Return format: labels is each sample's predicted label, probs each sample's class probabilities
-        return {"labels": preds, "probs": probs}
+        # 3. Confidence
+        confidence = [abs(p - 0.5) * 2 for p in probs]
+        # Return format: labels is each sample's predicted label, probs each sample's
+        # class-1 probability, confidence each sample's confidence score
+        return {"labels": preds, "probs": probs, "confidence": confidence}
 
-    def inference(
-        self,
-        featurs: dict[str, dict]
-    ):
-        assert len(featurs) <= 8, "At most 8 JSON files per call."
-        message_list = []
-        for id, feature in featurs.items():
-            messages = self.formatter.get_llm_prompt(feature)
-            message_list.append(messages)
-
-        inputs = self.tokenizer.apply_chat_template(
-            message_list,
-            tokenize=False,
-            add_generation_prompt=True,
-            enable_thinking=False
-        )
-        model_inputs = self.tokenizer(
-            inputs,
-            padding=True,
-            truncation=True,
-            max_length=2048,
-            return_tensors="pt"
-        ).to(self.device)
-
-        with torch.inference_mode():
-            with torch.amp.autocast(device_type=self.device, dtype=torch.bfloat16):
-                outputs = self.model(model_inputs)
-
-        # 1. Classification labels (argmax)
-        preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
-
-        # 2. Softmax probabilities (key fix: move to CPU, then numpy, then list, so the tensors serialize)
-        outputs_float = outputs.float()  # cast to float32 to avoid precision issues
-        probs = torch.softmax(outputs_float, dim=1)  # probs: [B, 2]
-        # Convert to a CPU numpy array, then to a list (two class probabilities per sample)
-        probs = probs.cpu().numpy().tolist()
-
-        # Return format: labels is each sample's predicted label, probs each sample's class probabilities
-        return {"labels": preds, "probs": probs}
-
     def inference_sample(self, json_path: str) -> dict:
         """
         Single-sample inference: takes one JSON file path and returns a dict with the conversion probability.
         Note: the JSON file must contain at least 10 entries.
         """
         return self.inference_batch([json_path])
 
-
-# Configuration
-backbone_dir = "Qwen3-1.7B"
-ckpt_path = "best_ckpt.pth"
-device = "cuda"
-
-engine = InferenceEngine(backbone_dir, ckpt_path, device)
-
 if __name__ == "__main__":
     # Configuration
     backbone_dir = "Qwen3-1.7B"
     ckpt_path = "best_ckpt.pth"
     device = "cuda"
 
     engine = InferenceEngine(backbone_dir, ckpt_path, device)
-
-    from data_process import extract_json_files
-    import random
-
-    # Get the JSON file paths for closed ("deal") and not-closed ("not_deal") sessions
-    deal_files = extract_json_files("deal")
-    not_deal_files = extract_json_files("not_deal")
-
-    def filter_json_files_by_key_count(files: List[str], min_keys: int = 10) -> List[str]:
-        """
-        Keep only the JSON files whose top-level dict has at least min_keys keys.
-
-        Args:
-            files: list of JSON file paths
-            min_keys: minimum number of keys required, defaults to 10
-
-        Returns:
-            the file paths that meet the requirement
-        """
-        valid_files = []
-
-        for file_path in files:
-            try:
-                with open(file_path, 'r', encoding='utf-8') as f:
-                    data = json.load(f)
-
-                # Check that the data is a dict with enough keys
-                if isinstance(data, dict) and len(data) >= min_keys:
-                    valid_files.append(file_path)
-                else:
-                    print(f"Skipping {os.path.basename(file_path)}: not enough keys ({len(data)} < {min_keys})")
-            except Exception as e:
-                print(f"Error reading {file_path}: {e}")
-
-        return valid_files
-
-    deal_files_filtered = filter_json_files_by_key_count(deal_files, min_keys=10)
-    not_deal_files_filtered = filter_json_files_by_key_count(not_deal_files, min_keys=10)
-
-    num_samples = 8
-
-    # How many files to draw from each class
-    num_deal_needed = min(4, len(deal_files_filtered))          # at most 4 "deal" files
-    num_not_deal_needed = min(4, len(not_deal_files_filtered))  # at most 4 "not_deal" files
-
-    # If one class falls short, top up from the other
-    if num_deal_needed + num_not_deal_needed < num_samples:
-        if len(deal_files_filtered) > num_deal_needed:
-            num_deal_needed = min(num_samples, len(deal_files_filtered))
-        elif len(not_deal_files_filtered) > num_not_deal_needed:
-            num_not_deal_needed = min(num_samples, len(not_deal_files_filtered))
-
-    # Randomly select the files
-    selected_deal_files = random.sample(deal_files_filtered, min(num_deal_needed, len(deal_files_filtered))) if deal_files_filtered else []
-    selected_not_deal_files = random.sample(not_deal_files_filtered, min(num_not_deal_needed, len(not_deal_files_filtered))) if not_deal_files_filtered else []
-
-    # Merge the selections
-    selected_files = selected_deal_files + selected_not_deal_files
-
-    # If still short of 8, top up at random from the unfiltered pool
-    if len(selected_files) < num_samples:
-        all_files = deal_files + not_deal_files
-        # Exclude files that were already selected
-        remaining_files = [f for f in all_files if f not in selected_files]
-        additional_needed = num_samples - len(selected_files)
-        if remaining_files:
-            additional_files = random.sample(remaining_files, min(additional_needed, len(remaining_files)))
-            selected_files.extend(additional_files)
-
-    true_labels = []
-    for i, file_path in enumerate(selected_files):
-        folder_type = "not_deal" if "not_deal" in file_path else "deal"
-        true_labels.append(folder_type)
-
-    # Run batch inference through the inference_batch interface
-    if selected_files:
-        print("\nStarting batch inference...")
-        try:
-            batch_result = engine.inference_batch(selected_files)
-            print(batch_result)
-            print(true_labels)
-        except Exception as e:
-            print(f"Error during inference: {e}")
-    else:
-        print("No qualifying files found for inference")
-
-    print("\nInference endpoint test complete!")
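Calling the updated engine follows the `__main__` block above; a condensed usage sketch (the session file paths are hypothetical, and `from inference import InferenceEngine` assumes the module keeps the name inference.py):

```python
from inference import InferenceEngine

# Same configuration as the __main__ block in the diff.
engine = InferenceEngine(backbone_dir="Qwen3-1.7B", ckpt_path="best_ckpt.pth", device="cuda")

# Single file; wraps inference_batch(). The JSON must contain at least 10 entries.
single = engine.inference_sample("session_001.json")  # hypothetical path

# Batches are capped at 8 files to avoid OOM.
batch = engine.inference_batch(["session_001.json", "session_002.json"])  # hypothetical paths

# The updated return format: three parallel lists over the batch.
print(batch["labels"], batch["probs"], batch["confidence"])
```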