From fed16f04e191a2702792d5281a933ef2598c18fa Mon Sep 17 00:00:00 2001
From: Tordor <3262978839@qq.com>
Date: Tue, 3 Feb 2026 16:49:00 +0800
Subject: [PATCH] Fix code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 inference.py | 46 +++++++++++++++++++++++++++++++++++++++++++++-
 main.py      | 12 ++++++++++--
 2 files changed, 55 insertions(+), 3 deletions(-)

diff --git a/inference.py b/inference.py
index 70adbe9..d13e641 100644
--- a/inference.py
+++ b/inference.py
@@ -60,7 +60,7 @@ class InferenceEngine:
         Note: each JSON file must contain at least 10 entries.
         """
         # print(111111)
-        assert len(json_list) <= 8, "No more than 8 JSON files per batch."
+        assert len(json_list) <= 10, "No more than 10 JSON files per batch."
         id2feature = extract_json_data(json_list)
         print(json.dumps(id2feature ,indent=2 ,ensure_ascii=False))
         # id2feature
@@ -109,6 +109,50 @@ class InferenceEngine:
         Note: each JSON file must contain at least 10 entries.
         """
         return self.inference_batch([json_path])
+
+    def inference(
+        self,
+        features: dict[str, dict]
+    ):
+        assert len(features) <= 10, "No more than 10 feature entries per call."
+        message_list = []
+        for _id, feature in features.items():
+            messages = self.formatter.get_llm_prompt(feature)
+            message_list.append(messages)
+
+        inputs = self.tokenizer.apply_chat_template(
+            message_list,
+            tokenize=False,
+            add_generation_prompt=True,
+            enable_thinking=False
+        )
+
+        model_inputs = self.tokenizer(
+            inputs,
+            padding=True,
+            truncation=True,
+            max_length=2048,
+            return_tensors="pt"
+        ).to(self.device)
+
+        with torch.inference_mode():
+            with torch.amp.autocast(device_type=self.device, dtype=torch.bfloat16):
+                outputs = self.model(**model_inputs)
+
+        # 1. Predicted class labels via argmax
+        preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
+
+        # 2. Softmax probabilities (key fix: move to CPU, to numpy, to list, so results are JSON-serializable)
+        outputs_float = outputs.float()  # cast to float32 to avoid precision issues
+        probs = torch.softmax(outputs_float, dim=1)  # probs: [B, 2]
+        # convert to a CPU numpy array, then to a list (two class probabilities per sample)
+        probs = probs.cpu().numpy().tolist()
+        probs = [p[1] for p in probs]  # keep only the probability of class 1
+
+        # 3. Confidence score
+        confidence = [abs(p - 0.5) * 2 for p in probs]
+        # Return format: labels, probs (class-1 probability), and confidence are parallel per-sample lists
+        return {"labels": preds, "probs": probs, "confidence": confidence}
 
 if __name__ == "__main__":
     # Configuration
diff --git a/main.py b/main.py
index fed4a9d..68ab0b5 100644
--- a/main.py
+++ b/main.py
@@ -1,10 +1,18 @@
 from feature_extraction import process_single
-from inference import engine
+from inference import InferenceEngine
 from services.mongo import voice_collection
 import json,uuid
 from fastapi import FastAPI
 from pydantic import BaseModel
 
+backbone_dir = "Qwen3-1.7B"
+ckpt_path = "best_ckpt.pth"
+device = "cuda"
+
+engine = InferenceEngine(backbone_dir, ckpt_path, device)
+
+
+
 async def get_customer_record():
     cursor = voice_collection.find({
         "tag": "20分钟通话",
@@ -14,7 +22,7 @@ async def get_customer_record():
             }
         }
     }).sort([('_id', -1)]).limit(24)
-    return await cursor.to_list(length=4)
+    return await cursor.to_list(length=1)
 
 async def main():
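
For reference, a usage sketch of the module-level engine wiring that main.py now sets up. The backbone and checkpoint paths match the constants added in this patch and must exist locally; the inner structure of each feature dict is hypothetical, since formatter.get_llm_prompt() is not part of this diff:

    from inference import InferenceEngine

    # Paths as hard-coded in main.py by this patch.
    engine = InferenceEngine("Qwen3-1.7B", "best_ckpt.pth", "cuda")

    # Hypothetical feature dicts keyed by record id; the inner keys are
    # placeholders for whatever formatter.get_llm_prompt() expects.
    features = {
        "rec-001": {"text": "..."},
        "rec-002": {"text": "..."},
    }

    result = engine.inference(features)
    # e.g. {"labels": [1, 0], "probs": [0.93, 0.18], "confidence": [0.86, 0.64]}
    print(result)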
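
The post-processing inside inference() can also be checked in isolation. A minimal sketch on made-up logits, so it runs with only torch installed and no model or checkpoint:

    import torch

    # Dummy [B, 2] logits standing in for the classifier output.
    logits = torch.tensor([[1.2, -0.3], [-2.0, 2.5], [0.1, 0.0]])

    preds = torch.argmax(logits, dim=1).cpu().numpy().tolist()  # [0, 1, 0]
    probs = torch.softmax(logits.float(), dim=1).cpu().numpy().tolist()
    probs = [p[1] for p in probs]                   # class-1 probability per sample
    confidence = [abs(p - 0.5) * 2 for p in probs]  # 0 = coin flip, 1 = certain

    print({"labels": preds, "probs": probs, "confidence": confidence})

The confidence term maps a class-1 probability of 0.5 (maximally uncertain) to 0 and either extreme (0.0 or 1.0) to 1, which is why it is reported alongside the raw probability.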