From 06bcaad8d4a30413c5aa94dde6da65170e02fc2b Mon Sep 17 00:00:00 2001
From: WangZiFan
Date: Tue, 3 Feb 2026 14:17:33 +0800
Subject: [PATCH] Update inference.py

---
 inference.py | 368 +++++++++++++++++----------------------------------
 1 file changed, 119 insertions(+), 249 deletions(-)

diff --git a/inference.py b/inference.py
index b347644..70adbe9 100644
--- a/inference.py
+++ b/inference.py
@@ -1,249 +1,119 @@
-from model import TransClassifier
-from transformers import AutoTokenizer
-from data_process import extract_json_data, Formatter
-import torch
-import json
-from typing import Dict, List, Optional
-import os
-import random
-import warnings
-warnings.filterwarnings("ignore")
-
-valid_keys = [
-    "Core_Fear_Source", "Pain_Threshold", "Time_Window_Pressure", "Helplessness_Index",
-    "Social_Shame", "Payer_Decision_Maker", "Hidden_Wealth_Proof", "Price_Sensitivity",
-    "Sunk_Cost", "Compensatory_Spending", "Trust_Deficit", "Secret_Resistance", "Family_Sabotage",
-    "Low_Self_Efficacy", "Attribution_Barrier", "Emotional_Trigger", "Ultimatum_Event", "Expectation_Bonus",
-    "Competitor_Mindset", "Cognitive_Stage", "Follow_up_Priority", "Last_Interaction", "Referral_Potential"
-]
-ch_valid_keys = [
-    "核心恐惧源", "疼痛阈值", "时间窗口压力", "无助指数",
-    "社会羞耻感", "付款决策者", "隐藏财富证明", "价格敏感度",
-    "沉没成本", "补偿性消费", "信任赤字", "秘密抵触情绪", "家庭破坏",
-    "低自我效能感", "归因障碍", "情绪触发点", "最后通牒事件", "期望加成",
-    "竞争者心态", "认知阶段", "跟进优先级", "最后互动时间", "推荐潜力"
-]
-all_keys = valid_keys + ["session_id", "label"]
-en2ch = {en: ch for en, ch in zip(valid_keys, ch_valid_keys)}
-d1_keys = valid_keys[:5]
-d2_keys = valid_keys[5:10]
-d3_keys = valid_keys[10:15]
-d4_keys = valid_keys[15:19]
-d5_keys = valid_keys[19:23]
-
-class InferenceEngine:
-    def __init__(self, backbone_dir: str, ckpt_path: str = "best_ckpt.pth", device: str = "cuda"):
-        self.backbone_dir = backbone_dir
-        self.ckpt_path = ckpt_path
-        self.device = device
-
-        # Load the tokenizer
-        self.tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
-        print(f"Tokenizer loaded from {backbone_dir}")
-
-        # Load the model
-        self.model = TransClassifier(backbone_dir, device)
-        self.model.to(device)
-        if self.ckpt_path:
-            self.model.load_state_dict(torch.load(ckpt_path, map_location=device))
-            print(f"Model loaded from {ckpt_path}")
-        else:
-            print("Warning: No checkpoint path provided. Using untrained model.")
-        self.model.eval()
-        print("Inference engine initialized successfully.")
-
-        self.formatter = Formatter(en2ch)
-
-    def inference_batch(self, json_list: List[str]) -> dict:
-        """
-        Batch inference: takes a list of JSON file paths and returns a dict with
-        predicted labels and class probabilities. To avoid OOM, the list may
-        contain at most 8 paths. Note that each JSON file must hold at least 10 entries.
-        """
-        assert len(json_list) <= 8, "No more than 8 JSON files may be passed per call."
-        id2feature = extract_json_data(json_list)
-        # print(id2feature) # id2feature
-
-        message_list = []
-        for id, feature in id2feature.items():
-            messages = self.formatter.get_llm_prompt(feature)
-            message_list.append(messages)
-
-        inputs = self.tokenizer.apply_chat_template(
-            message_list,
-            tokenize=False,
-            add_generation_prompt=True,
-            enable_thinking=False
-        )
-        model_inputs = self.tokenizer(
-            inputs,
-            padding=True,
-            truncation=True,
-            max_length=2048,
-            return_tensors="pt"
-        ).to(self.device)
-
-        with torch.inference_mode():
-            with torch.amp.autocast(device_type=self.device, dtype=torch.bfloat16):
-                outputs = self.model(model_inputs)
-
-        # 1. Predicted labels (argmax)
-        preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
-
-        # 2. Softmax probabilities (key fix: move to CPU, convert to numpy, then to a list, so the tensors are JSON-serializable)
-        outputs_float = outputs.float()  # cast to float32 to avoid precision issues
-        probs = torch.softmax(outputs_float, dim=1)  # probs: [B, 2]
-        # Convert to a CPU numpy array, then to a list (two class probabilities per sample)
-        probs = probs.cpu().numpy().tolist()
-
-        # Return format: "labels" holds each sample's predicted label, "probs" its class probabilities
-        return {"labels": preds, "probs": probs}
-
-    def inference(
-        self,
-        featurs : dict[str ,dict]
-    ):
-        assert len(featurs) <= 8, "No more than 8 JSON files may be passed per call."
-        message_list = []
-        for id, feature in featurs.items():
-            messages = self.formatter.get_llm_prompt(feature)
-            message_list.append(messages)
-
-        inputs = self.tokenizer.apply_chat_template(
-            message_list,
-            tokenize=False,
-            add_generation_prompt=True,
-            enable_thinking=False
-        )
-        model_inputs = self.tokenizer(
-            inputs,
-            padding=True,
-            truncation=True,
-            max_length=2048,
-            return_tensors="pt"
-        ).to(self.device)
-
-        with torch.inference_mode():
-            with torch.amp.autocast(device_type=self.device, dtype=torch.bfloat16):
-                outputs = self.model(model_inputs)
-
-        # 1. Predicted labels (argmax)
-        preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
-
-        # 2. Softmax probabilities (key fix: move to CPU, convert to numpy, then to a list, so the tensors are JSON-serializable)
-        outputs_float = outputs.float()  # cast to float32 to avoid precision issues
-        probs = torch.softmax(outputs_float, dim=1)  # probs: [B, 2]
-        # Convert to a CPU numpy array, then to a list (two class probabilities per sample)
-        probs = probs.cpu().numpy().tolist()
-
-        # Return format: "labels" holds each sample's predicted label, "probs" its class probabilities
-        return {"labels": preds, "probs": probs}
-
-    def inference_sample(self, json_path: str) -> dict:
-        """
-        Single-sample inference: takes one JSON file path and returns a dict with
-        the conversion probability. Note that the JSON file must hold at least 10 entries.
-        """
-        return self.inference_batch([json_path])
-
-# Configuration
-backbone_dir = "Qwen3-1.7B"
-ckpt_path = "best_ckpt.pth"
-device = "cuda"
-
-engine = InferenceEngine(backbone_dir, ckpt_path, device)
-
-if __name__ == "__main__":
-    # Configuration
-    backbone_dir = "Qwen3-1.7B"
-    ckpt_path = "best_ckpt.pth"
-    device = "cuda"
-
-    engine = InferenceEngine(backbone_dir, ckpt_path, device)
-
-    from data_process import extract_json_files
-    import random
-
-    # Collect the JSON file paths for deal and not-deal sessions
-    deal_files = extract_json_files("deal")
-    not_deal_files = extract_json_files("not_deal")
-
-    def filter_json_files_by_key_count(files: List[str], min_keys: int = 10) -> List[str]:
-        """
-        Keep only the JSON files whose top-level dict has at least the given number of keys.
-
-        Args:
-            files: list of JSON file paths
-            min_keys: minimum number of keys required, defaults to 10
-
-        Returns:
-            list of file paths that meet the requirement
-        """
-        valid_files = []
-
-        for file_path in files:
-            try:
-                with open(file_path, 'r', encoding='utf-8') as f:
-                    data = json.load(f)
-
-                # Check that it is a dict with enough keys
-                if isinstance(data, dict) and len(data) >= min_keys:
-                    valid_files.append(file_path)
-                else:
-                    print(f"Skipping {os.path.basename(file_path)}: too few keys ({len(data)} < {min_keys})")
-            except Exception as e:
-                print(f"Error reading {file_path}: {e}")
-
-        return valid_files
-
-    deal_files_filtered = filter_json_files_by_key_count(deal_files, min_keys=10)
-    not_deal_files_filtered = filter_json_files_by_key_count(not_deal_files, min_keys=10)
-
-    num_samples = 8
-
-    # Number of files to select per class
-    num_deal_needed = min(4, len(deal_files_filtered))  # at most 4 deal files
-    num_not_deal_needed = min(4, len(not_deal_files_filtered))  # at most 4 not-deal files
-
-    # If one class falls short, top up from the other
-    if num_deal_needed + num_not_deal_needed < num_samples:
-        if len(deal_files_filtered) > num_deal_needed:
-            num_deal_needed = min(num_samples, len(deal_files_filtered))
-        elif len(not_deal_files_filtered) > num_not_deal_needed:
-            num_not_deal_needed = min(num_samples, len(not_deal_files_filtered))
-
-    # Randomly select the files
-    selected_deal_files = random.sample(deal_files_filtered, min(num_deal_needed, len(deal_files_filtered))) if deal_files_filtered else []
-    selected_not_deal_files = random.sample(not_deal_files_filtered, min(num_not_deal_needed, len(not_deal_files_filtered))) if not_deal_files_filtered else []
-
-    # Merge the selected files
-    selected_files = selected_deal_files + selected_not_deal_files
-
-    # If still fewer than 8 in total, top up randomly from the unfiltered pool
-    if len(selected_files) < num_samples:
-        all_files = deal_files + not_deal_files
-        # Exclude files that were already selected
-        remaining_files = [f for f in all_files if f not in selected_files]
-        additional_needed = num_samples - len(selected_files)
-        if remaining_files:
-            additional_files = random.sample(remaining_files, min(additional_needed, len(remaining_files)))
-            selected_files.extend(additional_files)
-
-    true_labels = []
-    for i, file_path in enumerate(selected_files):
-        folder_type = "not_deal" if "not_deal" in file_path else "deal"
-        true_labels.append(folder_type)
-
-    # Run batch inference through the inference_batch interface
-    if selected_files:
-        print("\nStarting batch inference...")
-        try:
-            batch_result = engine.inference_batch(selected_files)
-            print(batch_result)
-            print(true_labels)
-
-        except Exception as e:
-            print(f"Error during inference: {e}")
-    else:
-        print("No eligible files found for inference")
-
-    print("\nInference interface test complete!")
\ No newline at end of file
+from model import TransClassifier
+from transformers import AutoTokenizer
+from data_process import extract_json_data, Formatter
+import torch
+import json
+from typing import Dict, List, Optional
+import os
+import random
+import warnings
+warnings.filterwarnings("ignore")
+
+valid_keys = [
+    "Core_Fear_Source", "Pain_Threshold", "Time_Window_Pressure", "Helplessness_Index",
+    "Social_Shame", "Payer_Decision_Maker", "Hidden_Wealth_Proof", "Price_Sensitivity",
+    "Sunk_Cost", "Compensatory_Spending", "Trust_Deficit", "Secret_Resistance", "Family_Sabotage",
+    "Low_Self_Efficacy", "Attribution_Barrier", "Emotional_Trigger", "Ultimatum_Event", "Expectation_Bonus",
+    "Competitor_Mindset", "Cognitive_Stage", "Follow_up_Priority", "Last_Interaction", "Referral_Potential"
+]
+ch_valid_keys = [
+    "核心恐惧源", "疼痛阈值", "时间窗口压力", "无助指数",
+    "社会羞耻感", "付款决策者", "隐藏财富证明", "价格敏感度",
+    "沉没成本", "补偿性消费", "信任赤字", "秘密抵触情绪", "家庭破坏",
+    "低自我效能感", "归因障碍", "情绪触发点", "最后通牒事件", "期望加成",
+    "竞争者心态", "认知阶段", "跟进优先级", "最后互动时间", "推荐潜力"
+]
+all_keys = valid_keys + ["session_id", "label"]
+en2ch = {en: ch for en, ch in zip(valid_keys, ch_valid_keys)}
+d1_keys = valid_keys[:5]
+d2_keys = valid_keys[5:10]
+d3_keys = valid_keys[10:15]
+d4_keys = valid_keys[15:19]
+d5_keys = valid_keys[19:23]
+
+class InferenceEngine:
+    def __init__(self, backbone_dir: str, ckpt_path: str = "best_ckpt.pth", device: str = "cuda"):
+        self.backbone_dir = backbone_dir
+        self.ckpt_path = ckpt_path
+        self.device = device
+
+        # Load the tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(backbone_dir)
+        print(f"Tokenizer loaded from {backbone_dir}")
+
+        # Load the model
+        self.model = TransClassifier(backbone_dir, device)
+        self.model.to(device)
+        if self.ckpt_path:
+            self.model.load_state_dict(torch.load(ckpt_path, map_location=device))
+            print(f"Model loaded from {ckpt_path}")
+        else:
+            print("Warning: No checkpoint path provided. Using untrained model.")
+        self.model.eval()
+        print("Inference engine initialized successfully.")
+
+        self.formatter = Formatter(en2ch)
+
+    def inference_batch(self, json_list: List[str]) -> dict:
+        """
+        Batch inference: takes a list of JSON file paths and returns a dict with
+        predicted labels, class-1 probabilities, and confidence scores. To avoid
+        OOM, the list may contain at most 8 paths. Note that each JSON file must
+        hold at least 10 entries.
+ """ + # print(111111) + assert len(json_list) <= 8, "单次输入json文件数量不可超过8。" + id2feature = extract_json_data(json_list) + print(json.dumps(id2feature ,indent=2 ,ensure_ascii=False)) + # id2feature + + message_list = [] + for id, feature in id2feature.items(): + messages = self.formatter.get_llm_prompt(feature) + message_list.append(messages) + + inputs = self.tokenizer.apply_chat_template( + message_list, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False + ) + model_inputs = self.tokenizer( + inputs, + padding=True, + truncation=True, + max_length=2048, + return_tensors="pt" + ).to(self.device) + + with torch.inference_mode(): + with torch.amp.autocast(device_type=self.device, dtype=torch.bfloat16): + outputs = self.model(model_inputs) + + # 1. 计算分类标签(argmax) + preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist() + + # 2. 计算softmax概率(核心修正:转CPU、转numpy、转列表,解决Tensor序列化问题) + outputs_float = outputs.float() # 转换为 float32 避免精度问题 + probs = torch.softmax(outputs_float, dim=1) # probs: [B, 2] + # 转换为CPU的numpy数组,再转列表(每个样本对应2个类别的概率) + probs = probs.cpu().numpy().tolist() + probs = [p[1] for p in probs] # 只保留类别1的概率 + + # 3. 计算置信度 + confidence = [abs(p - 0.5) * 2 for p in probs] + # 返回格式:labels是每个样本的分类标签列表,probs是每个样本的类别概率列表,confidence是每个样本的置信度列表 + return {"labels": preds, "probs": probs, "confidence": confidence} + + def inference_sample(self, json_path: str) -> dict: + """ + 单样本推理函数,输入为 JSON 字符串路径,输出为包含转换概率的字典。 + 请注意Json文件中的词条数必须大于等于10. + """ + return self.inference_batch([json_path]) + +if __name__ == "__main__": + # 配置参数 + backbone_dir = "Qwen3-1.7B" + ckpt_path = "best_ckpt.pth" + device = "cuda" + + engine = InferenceEngine(backbone_dir, ckpt_path, device) \ No newline at end of file