修复代码
This commit is contained in:
46
inference.py
46
inference.py
@@ -60,7 +60,7 @@ class InferenceEngine:
|
|||||||
请注意Json文件中的词条数必须大于等于10.
|
请注意Json文件中的词条数必须大于等于10.
|
||||||
"""
|
"""
|
||||||
# print(111111)
|
# print(111111)
|
||||||
assert len(json_list) <= 8, "单次输入json文件数量不可超过8。"
|
assert len(json_list) <= 10, "单次输入json文件数量不可超过10。"
|
||||||
id2feature = extract_json_data(json_list)
|
id2feature = extract_json_data(json_list)
|
||||||
print(json.dumps(id2feature ,indent=2 ,ensure_ascii=False))
|
print(json.dumps(id2feature ,indent=2 ,ensure_ascii=False))
|
||||||
# id2feature
|
# id2feature
|
||||||
@@ -110,6 +110,50 @@ class InferenceEngine:
|
|||||||
"""
|
"""
|
||||||
return self.inference_batch([json_path])
|
return self.inference_batch([json_path])
|
||||||
|
|
||||||
|
def inference(
    self,
    featurs: dict[str, dict],
):
    """Classify a batch of samples given their pre-extracted features.

    Args:
        featurs: mapping from sample id to that sample's feature dict
            (at most 10 entries per call).  NOTE(review): the original
            (misspelled) parameter name is kept so keyword callers
            still work.

    Returns:
        dict with three parallel lists, one entry per input sample:
            "labels"     - argmax class index (int)
            "probs"      - probability of class 1 (float)
            "confidence" - abs(p - 0.5) * 2, distance from the 0.5
                           decision boundary rescaled to [0, 1]

    Raises:
        AssertionError: if more than 10 samples are passed in one call.
    """
    # Fixed: the message previously said the limit was 8 while the check
    # allowed 10 (stale text from before the limit was raised).
    assert len(featurs) <= 10, "单次输入json文件数量不可超过10。"

    # Build one chat prompt per sample.  Iterate values() directly:
    # the keys were unused and the old loop variable shadowed builtin `id`.
    message_list = [
        self.formatter.get_llm_prompt(feature)
        for feature in featurs.values()
    ]

    # Render chat templates to plain strings; tokenization happens below.
    inputs = self.tokenizer.apply_chat_template(
        message_list,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )

    # Tokenize the whole batch with padding/truncation to a fixed window.
    model_inputs = self.tokenizer(
        inputs,
        padding=True,
        truncation=True,
        max_length=2048,
        return_tensors="pt",
    ).to(self.device)

    # NOTE(review): autocast's device_type must be "cuda"/"cpu"; if
    # self.device can be e.g. "cuda:0" this call would need the bare
    # device type instead — confirm against InferenceEngine.__init__.
    with torch.inference_mode(), torch.amp.autocast(device_type=self.device, dtype=torch.bfloat16):
        # Assumed to return [B, 2] logits — TODO confirm model head.
        outputs = self.model(model_inputs)

    # 1. Hard labels via argmax over the class dimension.
    preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()

    # 2. Softmax probabilities.  Cast to float32 first so bf16 logits do
    #    not lose precision, then move to CPU / plain lists so the result
    #    is JSON-serializable (no live Tensors in the return value).
    probs = torch.softmax(outputs.float(), dim=1).cpu().numpy().tolist()
    probs = [p[1] for p in probs]  # keep only the class-1 probability

    # 3. Confidence: distance from the 0.5 boundary, scaled to [0, 1].
    confidence = [abs(p - 0.5) * 2 for p in probs]

    # Parallel lists: labels / class-1 probs / confidences per sample.
    return {"labels": preds, "probs": probs, "confidence": confidence}
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 配置参数
|
# 配置参数
|
||||||
backbone_dir = "Qwen3-1.7B"
|
backbone_dir = "Qwen3-1.7B"
|
||||||
|
|||||||
12
main.py
12
main.py
@@ -1,10 +1,18 @@
|
|||||||
from feature_extraction import process_single
|
from feature_extraction import process_single
|
||||||
from inference import engine
|
from inference import InferenceEngine
|
||||||
from services.mongo import voice_collection
|
from services.mongo import voice_collection
|
||||||
import json,uuid
|
import json,uuid
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
# --- Inference-engine bootstrap -------------------------------------------
# Configuration for the classifier backend.
backbone_dir = "Qwen3-1.7B"   # HF model directory of the Qwen3 backbone
ckpt_path = "best_ckpt.pth"   # fine-tuned checkpoint loaded on top of it
device = "cuda"               # NOTE(review): assumes a GPU is present — confirm

# One shared engine instance, built once at import time so every request
# handler reuses the same loaded model instead of re-loading per call.
engine = InferenceEngine(backbone_dir, ckpt_path, device)
||||||
async def get_customer_record():
|
async def get_customer_record():
|
||||||
cursor = voice_collection.find({
|
cursor = voice_collection.find({
|
||||||
"tag": "20分钟通话",
|
"tag": "20分钟通话",
|
||||||
@@ -14,7 +22,7 @@ async def get_customer_record():
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}).sort([('_id', -1)]).limit(24)
|
}).sort([('_id', -1)]).limit(24)
|
||||||
return await cursor.to_list(length=4)
|
return await cursor.to_list(length=1)
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
|||||||
Reference in New Issue
Block a user