from feature_extraction import process_single
|
|
from inference import InferenceEngine
|
|
from services.mongo import voice_collection
|
|
import json,uuid
|
|
from fastapi import FastAPI
|
|
from pydantic import BaseModel
|
|
|
|
# Directory holding the Qwen3-1.7B backbone model weights/config.
backbone_dir = "Qwen3-1.7B"

# Fine-tuned checkpoint applied on top of the backbone.
ckpt_path = "best_ckpt.pth"

# NOTE(review): hard-coded CUDA device — assumes a GPU is present; confirm
# whether a CPU fallback is needed for this deployment.
device = "cuda"

# Module-level singleton: the model is loaded once at import time and shared
# by both the CLI batch path (main) and the FastAPI endpoint below.
engine = InferenceEngine(backbone_dir, ckpt_path, device)
|
|
|
|
|
|
|
|
async def get_customer_record():
    """Fetch the newest tagged call record that has a usable WeCom contact.

    Queries the voice collection for documents tagged "20分钟通话" whose
    ``matched_contacts`` contain at least one entry with a non-empty
    ``wecom_id``, ordered newest first.

    NOTE(review): the cursor is capped at 24 documents but ``to_list(length=1)``
    materializes only the first — confirm whether more were intended.

    Returns:
        list: at most one matching document.
    """
    query = {
        "tag": "20分钟通话",
        "matched_contacts": {
            "$elemMatch": {
                "wecom_id": {"$exists": True, "$ne": ""},
            },
        },
    }
    cursor = voice_collection.find(query).sort([('_id', -1)]).limit(24)
    return await cursor.to_list(length=1)
|
|
|
|
|
|
async def main():
    """Batch entry point: pull recent call records, extract features, predict.

    For each fetched record: truncate the transcript, run feature extraction,
    flatten each feature to its scalar ``value``, wrap the flat features in a
    single-record payload keyed by a fresh UUID, and print the engine's
    prediction.
    """
    records = await get_customer_record()
    for record in records:
        # Cap transcript length to keep feature extraction bounded.
        data = await process_single(record["text_content"][:10000])
        # Flatten each feature to its scalar "value" ("" when missing/falsy).
        features = {key: value.get("value") or "" for key, value in data.items()}
        # The engine expects {record_id: feature_dict}. Build the payload once
        # per record — the original re-assigned a new UUID key (all aliasing
        # the same dict) on every inner-loop iteration.
        payload = {uuid.uuid4().hex: features}
        print(engine.inference(payload))
|
|
|
|
|
|
|
|
# ASGI application exposing the prediction service.
app = FastAPI()


class Predictbody(BaseModel):
    """Request body for POST /predict: the raw call transcript to score."""

    # Full transcript text; the endpoint truncates it to 10,000 characters.
    content : str
|
|
|
|
@app.post("/predict")
async def endpoint(body: Predictbody):
    """Extract features from the posted transcript and run the prediction model.

    Args:
        body: request payload carrying the raw transcript in ``content``.

    Returns:
        dict: ``feature`` — the raw extracted feature map, and
        ``predict`` — the engine's prediction for the flattened features.
    """
    # Cap transcript length to keep feature extraction bounded.
    data = await process_single(body.content[:10000])
    # Flatten each feature to its scalar "value" ("" when missing/falsy).
    features = {key: value.get("value") or "" for key, value in data.items()}
    # The engine expects {record_id: feature_dict}. Build the payload once —
    # the original re-assigned a new UUID key (all aliasing the same dict)
    # on every loop iteration.
    payload = {uuid.uuid4().hex: features}
    return {
        "feature": data,
        "predict": engine.inference(payload),
    }
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI mode: run the batch pipeline instead of serving the FastAPI app.
    # asyncio is imported lazily because the server path (uvicorn) drives the
    # event loop itself and never reaches this branch.
    import asyncio

    asyncio.run(main())