Fix code

2026-02-03 16:49:00 +08:00
parent 06bcaad8d4
commit fed16f04e1
2 changed files with 55 additions and 3 deletions

@@ -60,7 +60,7 @@ class InferenceEngine:
        Note: each JSON file must contain at least 10 entries.
        """
        # print(111111)
        assert len(json_list) <= 8, "No more than 8 JSON files may be passed in a single call."
        assert len(json_list) <= 10, "No more than 10 JSON files may be passed in a single call."
        id2feature = extract_json_data(json_list)
        print(json.dumps(id2feature, indent=2, ensure_ascii=False))
        # id2feature
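
For context, a minimal sketch of calling inference_batch under the relaxed limit; the constructor arguments below are hypothetical, and only inference_batch plus the 10-file / 10-entry constraints come from this diff:

engine = InferenceEngine(model_path="path/to/model", device="cuda")  # hypothetical arguments
# At most 10 JSON files per call; each file must contain at least 10 entries.
results = engine.inference_batch(["samples_a.json", "samples_b.json"])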
@@ -109,6 +109,50 @@ class InferenceEngine:
        Note: each JSON file must contain at least 10 entries.
        """
        return self.inference_batch([json_path])

    def inference(
        self,
        features: dict[str, dict],
    ):
        """Classify up to 10 feature dicts; return labels, class-1 probabilities, and confidence scores."""
        assert len(features) <= 10, "No more than 10 entries may be passed in a single call."
        message_list = []
        for _id, feature in features.items():
            messages = self.formatter.get_llm_prompt(feature)
            message_list.append(messages)
        inputs = self.tokenizer.apply_chat_template(
            message_list,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False
        )
        model_inputs = self.tokenizer(
            inputs,
            padding=True,
            truncation=True,
            max_length=2048,
            return_tensors="pt"
        ).to(self.device)
        with torch.inference_mode():
            with torch.amp.autocast(device_type=self.device, dtype=torch.bfloat16):
                outputs = self.model(model_inputs)
        # 1. Predicted class labels via argmax.
        preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
        # 2. Softmax probabilities (core fix: move to CPU, convert to numpy, then to lists so the tensors can be serialized).
        outputs_float = outputs.float()  # cast to float32 to avoid precision issues
        probs = torch.softmax(outputs_float, dim=1)  # probs: [B, 2]
        # Convert to a CPU numpy array, then to a list: two class probabilities per sample.
        probs = probs.cpu().numpy().tolist()
        probs = [p[1] for p in probs]  # keep only the probability of class 1
        # 3. Confidence score per sample.
        confidence = [abs(p - 0.5) * 2 for p in probs]
        # Return format: "labels" holds each sample's predicted label, "probs" each sample's class-1 probability,
        # and "confidence" each sample's confidence score.
        return {"labels": preds, "probs": probs, "confidence": confidence}
if __name__ == "__main__":
    # Configuration parameters