Files
DotsOCR/app.py
2026-03-10 16:59:53 +08:00

61 lines
2.2 KiB
Python

from fastapi import FastAPI, File, UploadFile, HTTPException
import io
from PIL import Image
import logging
from inference import DotsOcr
from datetime import datetime
import asyncio
# Configure logging for the whole service.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="DotsOCR API")

# One slot only: at most one request may hold the model at a time.
inference_semaphore = asyncio.Semaphore(1)

model_path = "DotsOCR"
try:
    ocr_engine = DotsOcr(model_path)
    logger.info(f"{datetime.now()} - DotsOCR model loaded successfully")
except Exception as e:
    # Keep the module importable when the model cannot be loaded: the
    # endpoints check for ``ocr_engine is None`` and report it themselves.
    # ``logger.exception`` keeps the traceback so the root cause is visible.
    logger.exception(f"{datetime.now()} - DotsOCR model loading failed: {e}")
    ocr_engine = None
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run layout recognition on one uploaded image.

    A semaphore ensures that only one request performs inference at a time,
    to prevent out-of-memory failures. The blocking model call is executed
    in a worker thread so the event loop stays responsive (for example for
    ``/health``) while inference is in progress.

    Args:
        file: The uploaded image. If a content type is supplied it must
            start with ``image/``.

    Returns:
        Whatever ``ocr_engine.sample_inference`` returns for the decoded
        RGB image (the original docstring describes it as a list of dicts).

    Raises:
        HTTPException: 500 if the model was not loaded, if inference returns
            a dict containing an ``"error"`` key, or on any unexpected
            failure; 400 if the upload is not an image or cannot be decoded.
    """
    if ocr_engine is None:
        raise HTTPException(status_code=500, detail="模型未正确加载")

    content_type = file.content_type
    if content_type is not None and not content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="文件必须是图像格式")

    async with inference_semaphore:
        try:
            image_bytes = await file.read()
            try:
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            except Exception as exc:
                raise HTTPException(
                    status_code=400,
                    detail="无法识别的图像文件,请上传有效的图像",
                ) from exc

            logger.info(f"{datetime.now()} - 处理图像: {file.filename}, 尺寸: {image.size}")

            # sample_inference is a synchronous call; running it directly in
            # this coroutine would block the event loop for the whole
            # inference. Off-load it to the default thread pool while still
            # holding the semaphore, so at most one inference is submitted
            # at a time.
            # NOTE(review): if this task is cancelled while awaiting, the
            # worker thread keeps running after the semaphore is released.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None, ocr_engine.sample_inference, image
            )

            if isinstance(result, dict) and "error" in result:
                logger.error(f"{datetime.now()} - 推理返回错误: {result}")
                raise HTTPException(status_code=500, detail=result)
            return result
        except HTTPException:
            # Re-raise HTTP errors untouched so the generic handler below
            # does not rewrite their status code.
            raise
        except Exception as e:
            logger.exception(f"{datetime.now()} - 推理过程中发生错误: {e}")
            raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
    """Report that the service is up and whether the OCR model is loaded."""
    model_ready = ocr_engine is not None
    payload = {"status": "ok"}
    payload["model_loaded"] = model_ready
    return payload