Update app.py

This commit is contained in:
2026-03-10 16:59:53 +08:00
parent 252bc49307
commit ee38f71f40

117
app.py
View File

@@ -1,58 +1,61 @@
from fastapi import FastAPI, File, UploadFile, HTTPException from fastapi import FastAPI, File, UploadFile, HTTPException
import io import io
from PIL import Image from PIL import Image
import logging import logging
from inference import DotsOcr from inference import DotsOcr
from datetime import datetime from datetime import datetime
import asyncio import asyncio
# Logging: INFO level on the root logger, one module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Directory/identifier handed to DotsOcr when the model is loaded.
model_path = "DotsOCR"

app = FastAPI(title="DotsOCR API")

# Admits a single request at a time into the inference section.
# NOTE(review): on Python < 3.10 asyncio primitives bind to the event loop
# that is current at construction time — confirm the deployment's Python
# version if "attached to a different loop" errors ever show up.
inference_semaphore = asyncio.Semaphore(1)

# Load the model once at import time. A failure is logged and recorded as
# ``ocr_engine = None`` so the app still starts and can report its state.
try:
    ocr_engine = DotsOcr(model_path)
    logger.info(f"{datetime.now()} - DotsOCR model loaded successfully")
except Exception as exc:
    logger.error(f"{datetime.now()} - DotsOCR model loading failed: {exc}")
    ocr_engine = None
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run layout recognition on one uploaded image.

    The upload is decoded with Pillow, converted to RGB and passed to
    ``ocr_engine.sample_inference``; whatever that call returns is sent back
    as the response body. ``inference_semaphore`` admits one request at a
    time into the read/decode/inference section to guard against OOM.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        The value returned by ``ocr_engine.sample_inference``.

    Raises:
        HTTPException: 500 when the model failed to load at startup, when
            inference returns a dict containing an ``"error"`` key, or when
            any unexpected exception occurs; 400 when the declared content
            type is not ``image/*`` or the bytes cannot be decoded as an
            image.
    """
    if ocr_engine is None:
        raise HTTPException(status_code=500, detail="模型未正确加载")

    # A missing content type is tolerated; only an explicit non-image type
    # is rejected up front.
    if file.content_type is not None and not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="文件必须是图像格式")

    async with inference_semaphore:
        try:
            image_bytes = await file.read()
            try:
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            except Exception:
                # ``from None`` keeps the decoder traceback out of the
                # client-error path; the 400 itself says what went wrong.
                raise HTTPException(
                    status_code=400,
                    detail="无法识别的图像文件,请上传有效的图像",
                ) from None

            logger.info(f"{datetime.now()} - 处理图像: {file.filename}, 尺寸: {image.size}")

            # sample_inference is a plain synchronous call. Running it
            # directly inside this coroutine would block the event loop for
            # the whole inference, stalling /health and every other request.
            # Off-load it to the default executor; the semaphore is still
            # held, so at most one inference runs at a time.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None, ocr_engine.sample_inference, image
            )

            if isinstance(result, dict) and "error" in result:
                logger.error(f"{datetime.now()} - 推理返回错误: {result}")
                raise HTTPException(status_code=500, detail=result)
            return result
        except HTTPException:
            # Re-raise HTTP errors untouched so the generic handler below
            # does not rewrite their status code.
            raise
        except Exception as e:
            logger.exception(f"{datetime.now()} - 推理过程中发生错误: {e}")
            raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    """Return a liveness payload that also reports whether the model loaded."""
    model_ready = ocr_engine is not None
    return {"status": "ok", "model_loaded": model_ready}