"""HTTP service exposing a DotsOCR model through FastAPI.

Endpoints:
    POST /predict  -- run inference on one uploaded image.
    GET  /health   -- report whether the model was loaded at startup.
"""

import asyncio
import io
import logging
import threading
from datetime import datetime

from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image

from inference import DotsOcr

# Configure logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="DotsOCR API")

# Admits one request at a time into the inference section.
# NOTE(review): on Python < 3.10 asyncio primitives bind to the event loop
# that is current at construction time -- confirm the deployment's Python
# version if "attached to a different loop" errors ever show up.
inference_semaphore = asyncio.Semaphore(1)

# Guards the model call inside the worker thread. If a request task is
# cancelled while awaiting the executor, the semaphore above is released
# although the worker thread is still running; this lock keeps inference
# strictly serialized in that case as well.
_inference_thread_lock = threading.Lock()

model_path = "DotsOCR"

# Load the model once at import time. A failure is logged and leaves
# ``ocr_engine`` as None so the app still starts and /health can report it.
try:
    ocr_engine = DotsOcr(model_path)
    logger.info(f"{datetime.now()} - DotsOCR model loaded successfully")
except Exception as e:
    logger.error(f"{datetime.now()} - DotsOCR model loading failed: {e}")
    ocr_engine = None


def _decode_image(image_bytes):
    """Decode raw bytes into an RGB PIL image, closing the source image."""
    with Image.open(io.BytesIO(image_bytes)) as raw_image:
        return raw_image.convert("RGB")


def _run_inference(image):
    """Run the blocking model call; meant to execute in a worker thread."""
    with _inference_thread_lock:
        return ocr_engine.sample_inference(image)


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run OCR inference on one uploaded image and return the model's result.

    The upload is read and decoded before the inference slot is requested,
    so invalid files are rejected without waiting behind a running
    inference. The model call itself runs in a worker thread so the event
    loop stays responsive, and a semaphore ensures only one request performs
    inference at a time (to prevent OOM).

    Raises:
        HTTPException(500): the model failed to load at startup, the model
            returned a dict containing an "error" key, or an unexpected
            error occurred while handling the request.
        HTTPException(400): the declared content type is not ``image/*`` or
            the payload cannot be decoded as an image.
    """
    if ocr_engine is None:
        raise HTTPException(status_code=500, detail="模型未正确加载")

    content_type = file.content_type
    if content_type is not None and not content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="文件必须是图像格式")

    try:
        image_bytes = await file.read()
        try:
            image = _decode_image(image_bytes)
        except Exception as decode_error:
            raise HTTPException(
                status_code=400,
                detail="无法识别的图像文件,请上传有效的图像",
            ) from decode_error

        logger.info(f"{datetime.now()} - 处理图像: {file.filename}, 尺寸: {image.size}")

        async with inference_semaphore:
            # sample_inference is synchronous; calling it directly here
            # would block the event loop (and /health) for its duration.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(None, _run_inference, image)

        if isinstance(result, dict) and "error" in result:
            logger.error(f"{datetime.now()} - 推理返回错误: {result}")
            raise HTTPException(status_code=500, detail=result)
        return result
    except HTTPException:
        # Re-raise HTTP errors untouched so the generic handler below does
        # not rewrite their status code.
        raise
    except Exception as e:
        logger.exception(f"{datetime.now()} - 推理过程中发生错误: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
    """Return a fixed "ok" status plus whether the OCR model is loaded."""
    return {"status": "ok", "model_loaded": ocr_engine is not None}