58 lines
2.1 KiB
Python
58 lines
2.1 KiB
Python
from fastapi import FastAPI, File, UploadFile, HTTPException
|
|
import io
|
|
from PIL import Image
|
|
import logging
|
|
from inference import DotsOcr
|
|
from datetime import datetime
|
|
import asyncio
|
|
|
|
# Configure logging for the whole service.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="DotsOCR API")

# Single-permit semaphore: request handlers acquire it so that only one
# inference is in flight at any time.
inference_semaphore = asyncio.Semaphore(1)

model_path = "DotsOCR"

# Load the model once at import time. A failure is deliberately not fatal:
# the app still starts, `ocr_engine` is left as None, and the endpoints check
# for that value.
try:
    ocr_engine = DotsOcr(model_path)
    logger.info(f"{datetime.now()} - DotsOCR model loaded successfully")
except Exception as e:
    # logger.exception logs at ERROR level like logger.error, but also records
    # the traceback, which is needed to diagnose why the model failed to load.
    logger.exception(f"{datetime.now()} - DotsOCR model loading failed: {e}")
    ocr_engine = None
|
|
|
|
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run layout recognition on one uploaded image.

    Only one request at a time is let through ``inference_semaphore`` so that
    concurrent inferences cannot exhaust memory. The model call itself runs on
    a worker thread, so the event loop is not blocked while a prediction is in
    progress and waiting requests genuinely queue on the semaphore.

    Args:
        file: The uploaded image. If a content type is supplied, it must start
            with ``image/``.

    Returns:
        Whatever ``ocr_engine.sample_inference`` returns for the decoded RGB
        image (described by the original author as a list of dicts).

    Raises:
        HTTPException: 500 if the model was not loaded or inference raised;
            400 if the upload is not an image or cannot be decoded.
    """
    if ocr_engine is None:
        raise HTTPException(status_code=500, detail="模型未正确加载")

    if file.content_type is not None and not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="文件必须是图像格式")

    async with inference_semaphore:
        try:
            image_bytes = await file.read()

            try:
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            except Exception as exc:
                raise HTTPException(
                    status_code=400,
                    detail="无法识别的图像文件,请上传有效的图像",
                ) from exc

            logger.info(f"处理图像: {file.filename}, 尺寸: {image.size}")

            # sample_inference is a synchronous call; running it directly in
            # this coroutine would block the event loop for the duration of
            # the inference. Offload it to the default executor and await it
            # while still holding the semaphore.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None, ocr_engine.sample_inference, image
            )
            return result
        except HTTPException:
            # Re-raise HTTP errors untouched so the generic handler below does
            # not turn a 400 into a 500.
            raise
        except Exception as e:
            logger.exception("推理过程中发生错误")
            raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
@app.get("/health")
async def health_check():
    """Report that the service is up and whether the OCR model is loaded."""
    # ocr_engine is None when model loading failed at import time.
    model_ready = ocr_engine is not None
    return {"status": "ok", "model_loaded": model_ready}