Refactor inference and main modules for improved functionality and add uvicorn dependency
This commit is contained in:
43
inference.py
43
inference.py
@@ -60,7 +60,8 @@ class InferenceEngine:
|
||||
请注意Json文件中的词条数必须大于等于10.
|
||||
"""
|
||||
assert len(json_list) <= 8, "单次输入json文件数量不可超过8。"
|
||||
id2feature = extract_json_data(json_list) # id2feature
|
||||
id2feature = extract_json_data(json_list)
|
||||
# print(id2feature) # id2feature
|
||||
|
||||
message_list = []
|
||||
for id, feature in id2feature.items():
|
||||
@@ -97,6 +98,46 @@ class InferenceEngine:
|
||||
# 返回格式:labels是每个样本的分类标签列表,probs是每个样本的类别概率列表
|
||||
return {"labels": preds, "probs": probs}
|
||||
|
||||
def inference(
    self,
    featurs: dict[str, dict],
) -> dict:
    """Classify a batch of already-extracted feature dicts in one forward pass.

    Args:
        featurs: Mapping from a sample id to that sample's feature dict.
            Only the values are used; results follow the mapping's
            iteration order. At most 8 entries are accepted per call.

    Returns:
        A dict ``{"labels": [...], "probs": [...]}``. ``labels`` holds the
        argmax class index of each sample and ``probs`` holds the softmax
        probabilities of each sample, both as plain Python lists. Both
        lists are empty when ``featurs`` is empty.

    Raises:
        AssertionError: If more than 8 samples are supplied.
    """
    # Raised explicitly instead of via an `assert` statement so the batch
    # limit is still enforced when Python runs with -O; the exception type
    # and message are unchanged for callers.
    if len(featurs) > 8:
        raise AssertionError("单次输入json文件数量不可超过8。")

    # Nothing to classify: skip the tokenizer/model, which would otherwise
    # be handed an empty batch.
    if not featurs:
        return {"labels": [], "probs": []}

    # One chat-style prompt per sample, in the mapping's iteration order.
    message_list = [
        self.formatter.get_llm_prompt(feature) for feature in featurs.values()
    ]

    # Render the prompts to plain text first (tokenize=False), then
    # tokenize the whole batch with padding/truncation.
    inputs = self.tokenizer.apply_chat_template(
        message_list,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    model_inputs = self.tokenizer(
        inputs,
        padding=True,
        truncation=True,
        max_length=2048,
        return_tensors="pt",
    ).to(self.device)

    # NOTE(review): autocast's device_type is fed self.device directly;
    # presumably that is a bare device type such as "cuda" -- confirm.
    with torch.inference_mode():
        with torch.amp.autocast(device_type=self.device, dtype=torch.bfloat16):
            outputs = self.model(model_inputs)

    # 1. Class label per sample: argmax over dim 1.
    preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()

    # 2. Softmax probabilities, computed in float32 to avoid precision
    #    issues, then moved to CPU and converted to nested lists so the
    #    result is serializable (no Tensor objects in the return value).
    #    The original comment describes probs as [B, 2] -- TODO confirm.
    outputs_float = outputs.float()
    probs = torch.softmax(outputs_float, dim=1).cpu().numpy().tolist()

    # labels: per-sample class labels; probs: per-sample class probabilities.
    return {"labels": preds, "probs": probs}
|
||||
|
||||
def inference_sample(self, json_path: str) -> dict:
|
||||
"""
|
||||
单样本推理函数,输入为 JSON 字符串路径,输出为包含转换概率的字典。
|
||||
|
||||
Reference in New Issue
Block a user