Update inference.py

This commit is contained in:
2026-02-03 14:17:33 +08:00
parent 0b168057ca
commit 06bcaad8d4

View File

@@ -59,9 +59,11 @@ class InferenceEngine:
批量推理函数,输入为 JSON 字符串列表输出为包含转换概率的字典列表。为防止OOM列表最大长度为8。
请注意Json文件中的词条数必须大于等于10.
"""
# print(111111)
assert len(json_list) <= 8, "单次输入json文件数量不可超过8。"
id2feature = extract_json_data(json_list)
# print(json.dumps(id2feature ,indent=2 ,ensure_ascii=False))
message_list = []
for id, feature in id2feature.items():
@@ -94,49 +96,12 @@ class InferenceEngine:
probs = torch.softmax(outputs_float, dim=1)  # probs: [B, 2]
# 转换为CPU的numpy数组再转列表每个样本对应2个类别的概率
probs = probs.cpu().numpy().tolist()
probs = [p[1] for p in probs]  # 只保留类别1的概率
# 3. 计算置信度
confidence = [abs(p - 0.5) * 2 for p in probs]
# 返回格式labels是每个样本的分类标签列表probs是每个样本的类别概率列表confidence是每个样本的置信度列表
def inference(self, featurs: dict[str, dict]):
    """Run batched classification over pre-extracted feature dicts.

    Args:
        featurs: mapping of sample id -> feature dict; at most 8 entries,
            mirroring the OOM guard used by ``inference_batch``.

    Returns:
        dict with ``labels`` (argmax class index per sample) and ``probs``
        (per-class softmax probabilities per sample), both as plain lists
        so the result is JSON-serializable.
    """
    assert len(featurs) <= 8, "单次输入json文件数量不可超过8。"
    # One chat prompt per sample, built by the shared formatter.
    prompts = [self.formatter.get_llm_prompt(feat) for feat in featurs.values()]
    rendered = self.tokenizer.apply_chat_template(
        prompts,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False
    )
    batch = self.tokenizer(
        rendered,
        padding=True,
        truncation=True,
        max_length=2048,
        return_tensors="pt"
    ).to(self.device)
    # inference_mode + bf16 autocast: no autograd bookkeeping, lower memory.
    with torch.inference_mode(), torch.amp.autocast(device_type=self.device, dtype=torch.bfloat16):
        logits = self.model(batch)
    # Argmax labels; softmax in float32 to avoid bf16 precision issues.
    # Everything is moved to CPU and converted to plain Python lists.
    labels = torch.argmax(logits, dim=1).cpu().numpy().tolist()
    class_probs = torch.softmax(logits.float(), dim=1).cpu().numpy().tolist()
    return {"labels": labels, "probs": class_probs}
def inference_sample(self, json_path: str) -> dict:
    """
@@ -145,13 +110,6 @@ class InferenceEngine:
    """
    return self.inference_batch([json_path])
# Configuration parameters for the inference engine.
backbone_dir = "Qwen3-1.7B"
ckpt_path = "best_ckpt.pth"
device = "cuda"
# NOTE(review): constructing the engine here loads the model as an
# import-time side effect, duplicating the __main__ setup below — confirm
# this module-level instance is intentional.
engine = InferenceEngine(backbone_dir, ckpt_path, device)
if __name__ == "__main__":
    # 配置参数
    backbone_dir = "Qwen3-1.7B"
@@ -159,91 +117,3 @@ if __name__ == "__main__":
    device = "cuda"
    engine = InferenceEngine(backbone_dir, ckpt_path, device)
# Smoke test: sample real JSON files and push them through the engine.
from data_process import extract_json_files
import random

# Gather JSON file paths for the two classes; "deal" / "not_deal" are
# presumably folder names scanned by extract_json_files — confirm in data_process.
deal_files = extract_json_files("deal")
not_deal_files = extract_json_files("not_deal")
def filter_json_files_by_key_count(files: List[str], min_keys: int = 10) -> List[str]:
    """Keep JSON files whose top-level dict has at least ``min_keys`` keys.

    Files that are unreadable, contain invalid JSON, or whose top level is
    not a dict (or has too few keys) are skipped with a diagnostic message,
    keeping the scan best-effort.

    Args:
        files: JSON file paths to inspect.
        min_keys: minimum number of top-level keys required (default 10).

    Returns:
        The subset of ``files`` that passes the check, in input order.
    """
    valid_files: List[str] = []
    for file_path in files:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError (and UnicodeDecodeError);
            # narrowed from `except Exception` so real bugs are not swallowed.
            print(f"读取文件 {file_path} 时出错: {e}")
            continue
        if isinstance(data, dict) and len(data) >= min_keys:
            valid_files.append(file_path)
        else:
            # len() raises TypeError on scalar JSON (int/bool/null), which the
            # old broad except misreported as a read error — size it safely.
            size = len(data) if isinstance(data, (dict, list, str)) else 0
            print(f"跳过文件 {os.path.basename(file_path)}: 键数量不足 ({size} < {min_keys})")
    return valid_files
# Keep only files whose top-level dict has >= 10 keys (engine requirement).
deal_files_filtered = filter_json_files_by_key_count(deal_files, min_keys=10)
not_deal_files_filtered = filter_json_files_by_key_count(not_deal_files, min_keys=10)
# The engine accepts at most 8 files per batch; aim for a balanced 4/4 split.
num_samples = 8
# How many to pick from each class.
num_deal_needed = min(4, len(deal_files_filtered))  # at most 4 "deal" files
num_not_deal_needed = min(4, len(not_deal_files_filtered))  # at most 4 "not_deal" files
# If one class cannot fill its quota, top up from the other class.
if num_deal_needed + num_not_deal_needed < num_samples:
    if len(deal_files_filtered) > num_deal_needed:
        num_deal_needed = min(num_samples, len(deal_files_filtered))
    elif len(not_deal_files_filtered) > num_not_deal_needed:
        num_not_deal_needed = min(num_samples, len(not_deal_files_filtered))
# Randomly sample the requested number of files from each class.
selected_deal_files = random.sample(deal_files_filtered, min(num_deal_needed, len(deal_files_filtered))) if deal_files_filtered else []
selected_not_deal_files = random.sample(not_deal_files_filtered, min(num_not_deal_needed, len(not_deal_files_filtered))) if not_deal_files_filtered else []
# Merge the two selections.
selected_files = selected_deal_files + selected_not_deal_files
# If still short of 8, top up from the original pools.
# NOTE(review): this top-up draws from the *unfiltered* lists, so padded
# files may have fewer than 10 keys — confirm this is intended.
if len(selected_files) < num_samples:
    all_files = deal_files + not_deal_files
    # Exclude files already selected.
    remaining_files = [f for f in all_files if f not in selected_files]
    additional_needed = num_samples - len(selected_files)
    if remaining_files:
        additional_files = random.sample(remaining_files, min(additional_needed, len(remaining_files)))
        selected_files.extend(additional_files)
# Ground-truth label per file, derived from its path containing "not_deal".
true_labels = []
for i, file_path in enumerate(selected_files):
    folder_type = "未成交" if "not_deal" in file_path else "成交"
    true_labels.append(folder_type)
# Run the batch inference entry point on the selected files (best-effort:
# a failure is printed rather than crashing the smoke test).
if selected_files:
    print("\n开始批量推理...")
    try:
        batch_result = engine.inference_batch(selected_files)
        print(batch_result)
        print(true_labels)
    except Exception as e:
        print(f"推理过程中出错: {e}")
else:
    print("未找到符合条件的文件进行推理")
print("\n推理端口测试完成!")