From 0b168057cabc3df17828449c56ecb8bfc88730fb Mon Sep 17 00:00:00 2001 From: Tordor <3262978839@qq.com> Date: Fri, 30 Jan 2026 15:50:52 +0800 Subject: [PATCH] Refactor inference and main modules for improved functionality and add uvicorn dependency --- feature_extraction.py | 3 ++- inference.py | 43 ++++++++++++++++++++++++++++++++++- main.py | 41 +++++++++++++++++++++++++++++---- pyproject.toml | 2 ++ uv.lock | 53 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 136 insertions(+), 6 deletions(-) diff --git a/feature_extraction.py b/feature_extraction.py index f550dea..f453f81 100644 --- a/feature_extraction.py +++ b/feature_extraction.py @@ -23,7 +23,8 @@ OUTPUT_DIR = os.getenv("OUTPUT_DIR", "time_12_1/data_ch_1") client = OpenAI( api_key=API_KEY, # base_url="https://dashscope.aliyuncs.com/compatible-mode/v1" - base_url="https://sg1.proxy.yinlihupo.cc/proxy/https://openrouter.ai/api/v1" + # https://sg1.proxy.yinlihupo.cc/proxy/ + base_url="https://openrouter.ai/api/v1" ) diff --git a/inference.py b/inference.py index f8347d3..b347644 100644 --- a/inference.py +++ b/inference.py @@ -60,7 +60,8 @@ class InferenceEngine: 请注意Json文件中的词条数必须大于等于10. """ assert len(json_list) <= 8, "单次输入json文件数量不可超过8。" - id2feature = extract_json_data(json_list) # id2feature + id2feature = extract_json_data(json_list) + # print(id2feature) # id2feature message_list = [] for id, feature in id2feature.items(): @@ -97,6 +98,46 @@ class InferenceEngine: # 返回格式:labels是每个样本的分类标签列表,probs是每个样本的类别概率列表 return {"labels": preds, "probs": probs} + def inference( + self, + featurs : dict[str ,dict] + ): + assert len(featurs) <= 8, "单次输入json文件数量不可超过8。" + message_list = [] + for id, feature in featurs.items(): + messages = self.formatter.get_llm_prompt(feature) + message_list.append(messages) + + inputs = self.tokenizer.apply_chat_template( + message_list, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False + ) + model_inputs = self.tokenizer( + inputs, + padding=True, + truncation=True, + max_length=2048, + return_tensors="pt" + ).to(self.device) + + with torch.inference_mode(): + with torch.amp.autocast(device_type=self.device, dtype=torch.bfloat16): + outputs = self.model(model_inputs) + + # 1. 计算分类标签(argmax) + preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist() + + # 2. 计算softmax概率(核心修正:转CPU、转numpy、转列表,解决Tensor序列化问题) + outputs_float = outputs.float() # 转换为 float32 避免精度问题 + probs = torch.softmax(outputs_float, dim=1) # probs: [B, 2] + # 转换为CPU的numpy数组,再转列表(每个样本对应2个类别的概率) + probs = probs.cpu().numpy().tolist() + + # 返回格式:labels是每个样本的分类标签列表,probs是每个样本的类别概率列表 + return {"labels": preds, "probs": probs} + def inference_sample(self, json_path: str) -> dict: """ 单样本推理函数,输入为 JSON 字符串路径,输出为包含转换概率的字典。 diff --git a/main.py b/main.py index b54a1f5..fed4a9d 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,9 @@ from feature_extraction import process_single from inference import engine from services.mongo import voice_collection - +import json,uuid +from fastapi import FastAPI +from pydantic import BaseModel async def get_customer_record(): cursor = voice_collection.find({ @@ -12,12 +14,43 @@ async def get_customer_record(): } } }).sort([('_id', -1)]).limit(24) - return await cursor.to_list(length=24) + return await cursor.to_list(length=4) + async def main(): - records = await get_customer_record() + records = await get_customer_record() for record in records: - print(record) + # print(len(record["text_content"])) + data = await process_single(record["text_content"][:10000]) + # print(json.dumps(data, indent=2 , ensure_ascii=False)) + temp = {} + res = {} + for key ,value in data.items(): + temp[key] = value.get("value") or "" + res[uuid.uuid4().hex] = temp + print(engine.inference(res)) + + + +app = FastAPI() + +class Predictbody(BaseModel): + content : str + +@app.post("/predict") +async def endpoint(body : Predictbody): + data = await process_single(body.content[:10000]) + temp = {} + res = {} + for key ,value in data.items(): + temp[key] = value.get("value") or "" + res[uuid.uuid4().hex] = temp + return { + "feature" : data, + "predict" : engine.inference(res) + } + + if __name__ == "__main__": import asyncio diff --git a/pyproject.toml b/pyproject.toml index f48caca..c029ba2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,7 @@ description = "Add your description here" readme = "README.md" requires-python = ">=3.13" dependencies = [ + "fastapi>=0.128.0", "matplotlib>=3.10.8", "motor>=3.7.1", "openai>=2.16.0", @@ -12,4 +13,5 @@ dependencies = [ "python-dotenv>=1.2.1", "torch>=2.10.0", "transformers>=5.0.0", + "uvicorn>=0.40.0", ] diff --git a/uv.lock b/uv.lock index dbd1670..ecb567a 100644 --- a/uv.lock +++ b/uv.lock @@ -10,6 +10,15 @@ resolution-markers = [ "python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", ] +[[package]] +name = "annotated-doc" +version = "0.0.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288, upload-time = "2025-11-10T22:07:42.062Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303, upload-time = "2025-11-10T22:07:40.673Z" }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -152,6 +161,7 @@ name = "deal-classification" version = "0.1.0" source = { virtual = "." } dependencies = [ + { name = "fastapi" }, { name = "matplotlib" }, { name = "motor" }, { name = "openai" }, @@ -159,10 +169,12 @@ dependencies = [ { name = "python-dotenv" }, { name = "torch" }, { name = "transformers" }, + { name = "uvicorn" }, ] [package.metadata] requires-dist = [ + { name = "fastapi", specifier = ">=0.128.0" }, { name = "matplotlib", specifier = ">=3.10.8" }, { name = "motor", specifier = ">=3.7.1" }, { name = "openai", specifier = ">=2.16.0" }, @@ -170,6 +182,7 @@ requires-dist = [ { name = "python-dotenv", specifier = ">=1.2.1" }, { name = "torch", specifier = ">=2.10.0" }, { name = "transformers", specifier = ">=5.0.0" }, + { name = "uvicorn", specifier = ">=0.40.0" }, ] [[package]] @@ -190,6 +203,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ba/5a/18ad964b0086c6e62e2e7500f7edc89e3faa45033c71c1893d34eed2b2de/dnspython-2.8.0-py3-none-any.whl", hash = "sha256:01d9bbc4a2d76bf0db7c1f729812ded6d912bd318d3b1cf81d30c0f845dbf3af", size = 331094, upload-time = "2025-09-07T18:57:58.071Z" }, ] +[[package]] +name = "fastapi" +version = "0.128.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-doc" }, + { name = "pydantic" }, + { name = "starlette" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/52/08/8c8508db6c7b9aae8f7175046af41baad690771c9bcde676419965e338c7/fastapi-0.128.0.tar.gz", hash = "sha256:1cc179e1cef10a6be60ffe429f79b829dce99d8de32d7acb7e6c8dfdf7f2645a", size = 365682, upload-time = "2025-12-27T15:21:13.714Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/05/5cbb59154b093548acd0f4c7c474a118eda06da25aa75c616b72d8fcd92a/fastapi-0.128.0-py3-none-any.whl", hash = "sha256:aebd93f9716ee3b4f4fcfe13ffb7cf308d99c9f3ab5622d8877441072561582d", size = 103094, upload-time = "2025-12-27T15:21:12.154Z" }, +] + [[package]] name = "filelock" version = "3.20.3" @@ -1207,6 +1235,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, ] +[[package]] +name = "starlette" +version = "0.50.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ba/b8/73a0e6a6e079a9d9cfa64113d771e421640b6f679a52eeb9b32f72d871a1/starlette-0.50.0.tar.gz", hash = "sha256:a2a17b22203254bcbc2e1f926d2d55f3f9497f769416b3190768befe598fa3ca", size = 2646985, upload-time = "2025-11-01T15:25:27.516Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d9/52/1064f510b141bd54025f9b55105e26d1fa970b9be67ad766380a3c9b74b0/starlette-0.50.0-py3-none-any.whl", hash = "sha256:9e5391843ec9b6e472eed1365a78c8098cfceb7a74bfd4d6b1c0c0095efb3bca", size = 74033, upload-time = "2025-11-01T15:25:25.461Z" }, +] + [[package]] name = "sympy" version = "1.14.0" @@ -1380,3 +1420,16 @@ sdist = { url = "https://files.pythonhosted.org/packages/5e/a7/c202b344c5ca7daf3 wheels = [ { url = "https://files.pythonhosted.org/packages/c7/b0/003792df09decd6849a5e39c28b513c06e84436a54440380862b5aeff25d/tzdata-2025.3-py2.py3-none-any.whl", hash = "sha256:06a47e5700f3081aab02b2e513160914ff0694bce9947d6b76ebd6bf57cfc5d1", size = 348521, upload-time = "2025-12-13T17:45:33.889Z" }, ] + +[[package]] +name = "uvicorn" +version = "0.40.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c3/d1/8f3c683c9561a4e6689dd3b1d345c815f10f86acd044ee1fb9a4dcd0b8c5/uvicorn-0.40.0.tar.gz", hash = "sha256:839676675e87e73694518b5574fd0f24c9d97b46bea16df7b8c05ea1a51071ea", size = 81761, upload-time = "2025-12-21T14:16:22.45Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3d/d8/2083a1daa7439a66f3a48589a57d576aa117726762618f6bb09fe3798796/uvicorn-0.40.0-py3-none-any.whl", hash = "sha256:c6c8f55bc8bf13eb6fa9ff87ad62308bbbc33d0b67f84293151efe87e0d5f2ee", size = 68502, upload-time = "2025-12-21T14:16:21.041Z" }, +]