Refactor inference and main modules for improved functionality and add uvicorn dependency
This commit is contained in:
@@ -23,7 +23,8 @@ OUTPUT_DIR = os.getenv("OUTPUT_DIR", "time_12_1/data_ch_1")
|
||||
client = OpenAI(
|
||||
api_key=API_KEY,
|
||||
# base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
base_url="https://sg1.proxy.yinlihupo.cc/proxy/https://openrouter.ai/api/v1"
|
||||
# https://sg1.proxy.yinlihupo.cc/proxy/
|
||||
base_url="https://openrouter.ai/api/v1"
|
||||
)
|
||||
|
||||
|
||||
|
||||
43
inference.py
43
inference.py
@@ -60,7 +60,8 @@ class InferenceEngine:
|
||||
请注意Json文件中的词条数必须大于等于10.
|
||||
"""
|
||||
assert len(json_list) <= 8, "单次输入json文件数量不可超过8。"
|
||||
id2feature = extract_json_data(json_list) # id2feature
|
||||
id2feature = extract_json_data(json_list)
|
||||
# print(id2feature) # id2feature
|
||||
|
||||
message_list = []
|
||||
for id, feature in id2feature.items():
|
||||
@@ -97,6 +98,46 @@ class InferenceEngine:
|
||||
# 返回格式:labels是每个样本的分类标签列表,probs是每个样本的类别概率列表
|
||||
return {"labels": preds, "probs": probs}
|
||||
|
||||
def inference(
|
||||
self,
|
||||
featurs : dict[str ,dict]
|
||||
):
|
||||
assert len(featurs) <= 8, "单次输入json文件数量不可超过8。"
|
||||
message_list = []
|
||||
for id, feature in featurs.items():
|
||||
messages = self.formatter.get_llm_prompt(feature)
|
||||
message_list.append(messages)
|
||||
|
||||
inputs = self.tokenizer.apply_chat_template(
|
||||
message_list,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False
|
||||
)
|
||||
model_inputs = self.tokenizer(
|
||||
inputs,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=2048,
|
||||
return_tensors="pt"
|
||||
).to(self.device)
|
||||
|
||||
with torch.inference_mode():
|
||||
with torch.amp.autocast(device_type=self.device, dtype=torch.bfloat16):
|
||||
outputs = self.model(model_inputs)
|
||||
|
||||
# 1. 计算分类标签(argmax)
|
||||
preds = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
|
||||
|
||||
# 2. 计算softmax概率(核心修正:转CPU、转numpy、转列表,解决Tensor序列化问题)
|
||||
outputs_float = outputs.float() # 转换为 float32 避免精度问题
|
||||
probs = torch.softmax(outputs_float, dim=1) # probs: [B, 2]
|
||||
# 转换为CPU的numpy数组,再转列表(每个样本对应2个类别的概率)
|
||||
probs = probs.cpu().numpy().tolist()
|
||||
|
||||
# 返回格式:labels是每个样本的分类标签列表,probs是每个样本的类别概率列表
|
||||
return {"labels": preds, "probs": probs}
|
||||
|
||||
def inference_sample(self, json_path: str) -> dict:
|
||||
"""
|
||||
单样本推理函数,输入为 JSON 字符串路径,输出为包含转换概率的字典。
|
||||
|
||||
41
main.py
41
main.py
@@ -1,7 +1,9 @@
|
||||
from feature_extraction import process_single
|
||||
from inference import engine
|
||||
from services.mongo import voice_collection
|
||||
|
||||
import json,uuid
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
|
||||
async def get_customer_record():
|
||||
cursor = voice_collection.find({
|
||||
@@ -12,12 +14,43 @@ async def get_customer_record():
|
||||
}
|
||||
}
|
||||
}).sort([('_id', -1)]).limit(24)
|
||||
return await cursor.to_list(length=24)
|
||||
return await cursor.to_list(length=4)
|
||||
|
||||
|
||||
async def main():
|
||||
records = await get_customer_record()
|
||||
records = await get_customer_record()
|
||||
for record in records:
|
||||
print(record)
|
||||
# print(len(record["text_content"]))
|
||||
data = await process_single(record["text_content"][:10000])
|
||||
# print(json.dumps(data, indent=2 , ensure_ascii=False))
|
||||
temp = {}
|
||||
res = {}
|
||||
for key ,value in data.items():
|
||||
temp[key] = value.get("value") or ""
|
||||
res[uuid.uuid4().hex] = temp
|
||||
print(engine.inference(res))
|
||||
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
class Predictbody(BaseModel):
|
||||
content : str
|
||||
|
||||
@app.post("/predict")
|
||||
async def endpoint(body : Predictbody):
|
||||
data = await process_single(body.content[:10000])
|
||||
temp = {}
|
||||
res = {}
|
||||
for key ,value in data.items():
|
||||
temp[key] = value.get("value") or ""
|
||||
res[uuid.uuid4().hex] = temp
|
||||
return {
|
||||
"feature" : data,
|
||||
"predict" : engine.inference(res)
|
||||
}
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
@@ -5,6 +5,7 @@ description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"fastapi>=0.128.0",
|
||||
"matplotlib>=3.10.8",
|
||||
"motor>=3.7.1",
|
||||
"openai>=2.16.0",
|
||||
@@ -12,4 +13,5 @@ dependencies = [
|
||||
"python-dotenv>=1.2.1",
|
||||
"torch>=2.10.0",
|
||||
"transformers>=5.0.0",
|
||||
"uvicorn>=0.40.0",
|
||||
]
|
||||
|
||||
53
uv.lock
generated
53
uv.lock
generated
@@ -10,6 +10,15 @@ resolution-markers = [
|
||||
"python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "annotated-doc"
|
||||
version = "0.0.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288, upload-time = "2025-11-10T22:07:42.062Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303, upload-time = "2025-11-10T22:07:40.673Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "annotated-types"
|
||||
version = "0.7.0"
|
||||
@@ -152,6 +161,7 @@ name = "deal-classification"
|
||||
version = "0.1.0"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "fastapi" },
|
||||
{ name = "matplotlib" },
|
||||
{ name = "motor" },
|
||||
{ name = "openai" },
|
||||
@@ -159,10 +169,12 @@ dependencies = [
|
||||
{ name = "python-dotenv" },
|
||||
{ name = "torch" },
|
||||
{ name = "transformers" },
|
||||
{ name = "uvicorn" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "fastapi", specifier = ">=0.128.0" },
|
||||
{ name = "matplotlib", specifier = ">=3.10.8" },
|
||||
{ name = "motor", specifier = ">=3.7.1" },
|
||||
{ name = "openai", specifier = ">=2.16.0" },
|
||||
@@ -170,6 +182,7 @@ requires-dist = [
|
||||
{ name = "python-dotenv", specifier = ">=1.2.1" },
|
||||
{ name = "torch", specifier = ">=2.10.0" },
|
||||
{ name = "transformers", specifier = ">=5.0.0" },
|
||||
{ name = "uvicorn", specifier = ">=0.40.0" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -190,6 +203,21 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ba/5a/18ad964b0086c6e62e2e7500f7edc89e3faa45033c71c1893d34eed2b2de/dnspython-2.8.0-py3-none-any.whl", hash = "sha256:01d9bbc4a2d76bf0db7c1f729812ded6d912bd318d3b1cf81d30c0f845dbf3af", size = 331094, upload-time = "2025-09-07T18:57:58.071Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastapi"
|
||||
version = "0.128.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "annotated-doc" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "starlette" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/52/08/8c8508db6c7b9aae8f7175046af41baad690771c9bcde676419965e338c7/fastapi-0.128.0.tar.gz", hash = "sha256:1cc179e1cef10a6be60ffe429f79b829dce99d8de32d7acb7e6c8dfdf7f2645a", size = 365682, upload-time = "2025-12-27T15:21:13.714Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/05/5cbb59154b093548acd0f4c7c474a118eda06da25aa75c616b72d8fcd92a/fastapi-0.128.0-py3-none-any.whl", hash = "sha256:aebd93f9716ee3b4f4fcfe13ffb7cf308d99c9f3ab5622d8877441072561582d", size = 103094, upload-time = "2025-12-27T15:21:12.154Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "filelock"
|
||||
version = "3.20.3"
|
||||
@@ -1207,6 +1235,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "starlette"
|
||||
version = "0.50.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "anyio" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ba/b8/73a0e6a6e079a9d9cfa64113d771e421640b6f679a52eeb9b32f72d871a1/starlette-0.50.0.tar.gz", hash = "sha256:a2a17b22203254bcbc2e1f926d2d55f3f9497f769416b3190768befe598fa3ca", size = 2646985, upload-time = "2025-11-01T15:25:27.516Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d9/52/1064f510b141bd54025f9b55105e26d1fa970b9be67ad766380a3c9b74b0/starlette-0.50.0-py3-none-any.whl", hash = "sha256:9e5391843ec9b6e472eed1365a78c8098cfceb7a74bfd4d6b1c0c0095efb3bca", size = 74033, upload-time = "2025-11-01T15:25:25.461Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sympy"
|
||||
version = "1.14.0"
|
||||
@@ -1380,3 +1420,16 @@ sdist = { url = "https://files.pythonhosted.org/packages/5e/a7/c202b344c5ca7daf3
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c7/b0/003792df09decd6849a5e39c28b513c06e84436a54440380862b5aeff25d/tzdata-2025.3-py2.py3-none-any.whl", hash = "sha256:06a47e5700f3081aab02b2e513160914ff0694bce9947d6b76ebd6bf57cfc5d1", size = 348521, upload-time = "2025-12-13T17:45:33.889Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "uvicorn"
|
||||
version = "0.40.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "click" },
|
||||
{ name = "h11" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c3/d1/8f3c683c9561a4e6689dd3b1d345c815f10f86acd044ee1fb9a4dcd0b8c5/uvicorn-0.40.0.tar.gz", hash = "sha256:839676675e87e73694518b5574fd0f24c9d97b46bea16df7b8c05ea1a51071ea", size = 81761, upload-time = "2025-12-21T14:16:22.45Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/3d/d8/2083a1daa7439a66f3a48589a57d576aa117726762618f6bb09fe3798796/uvicorn-0.40.0-py3-none-any.whl", hash = "sha256:c6c8f55bc8bf13eb6fa9ff87ad62308bbbc33d0b67f84293151efe87e0d5f2ee", size = 68502, upload-time = "2025-12-21T14:16:21.041Z" },
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user