"""DotsOCR HTTP service.

Exposes a FastAPI application with two routes:

* ``POST /predict`` -- run the OCR engine on one uploaded image.
* ``GET /health``   -- report whether the model was loaded at startup.

NOTE(review): this file previously held the collapsed text of a
``git format-patch`` e-mail rather than Python source. The code below is
the post-image (``+`` side) of that patch, restored to normal formatting.
"""
import asyncio
import io
import logging
from datetime import datetime

from fastapi import FastAPI, File, UploadFile, HTTPException
from PIL import Image

from inference import DotsOcr

# Configure logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="DotsOCR API")

# A single permit: only one request at a time may be inside the guarded
# section of `predict`.
inference_semaphore = asyncio.Semaphore(1)

model_path = "DotsOCR"
try:
    ocr_engine = DotsOcr(model_path)
    logger.info(f"{datetime.now()} - DotsOCR model loaded successfully")
except Exception as e:
    # Keep the process alive so /health can report the failure; every
    # /predict call is rejected while `ocr_engine` is None.
    # `logger.exception` keeps the traceback of the load failure in the log.
    logger.exception(f"{datetime.now()} - DotsOCR model loading failed: {e}")
    ocr_engine = None


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run OCR inference on one uploaded image and return the engine result.

    The upload is read, decoded with Pillow, converted to RGB and passed to
    ``ocr_engine.sample_inference``. Reading, decoding and inference all
    happen while holding ``inference_semaphore`` (one permit), so requests
    are processed one at a time; per the original author this is meant to
    prevent out-of-memory failures.

    Args:
        file: The uploaded file. If a content type is supplied it must
            start with ``image/``.

    Returns:
        Whatever ``ocr_engine.sample_inference`` returns (described by the
        original author as a list of dicts with layout results).

    Raises:
        HTTPException: 500 if the model failed to load at startup; 400 if
            the content type is not an image or the bytes cannot be decoded
            as an image; 500 if the engine returns a dict containing an
            ``"error"`` key or any other exception occurs.
    """
    if ocr_engine is None:
        raise HTTPException(status_code=500, detail="模型未正确加载")

    if file.content_type is not None and not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="文件必须是图像格式")

    async with inference_semaphore:
        try:
            image_bytes = await file.read()

            try:
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            except Exception as exc:
                # Any decode failure is reported as a client error; chain
                # the cause explicitly so the traceback stays readable.
                raise HTTPException(
                    status_code=400,
                    detail="无法识别的图像文件,请上传有效的图像",
                ) from exc
            logger.info(f"{datetime.now()} - 处理图像: {file.filename}, 尺寸: {image.size}")

            # NOTE(review): this appears to be a synchronous call made
            # inside a coroutine; if it blocks, the event loop (including
            # /health) is stalled for the duration -- confirm whether it
            # should be moved to a worker thread.
            result = ocr_engine.sample_inference(image)

            # The engine may signal failure in-band via a dict with an
            # "error" key; surface that as a server error.
            if isinstance(result, dict) and "error" in result:
                logger.error(f"{datetime.now()} - 推理返回错误: {result}")
                raise HTTPException(status_code=500, detail=result)
            return result
        except HTTPException:
            # Re-raise HTTP errors untouched so the generic handler below
            # does not rewrite their status code.
            raise
        except Exception as e:
            logger.exception(f"{datetime.now()} - 推理过程中发生错误: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/health")
async def health_check():
    """Return a liveness payload with a flag for whether the model loaded."""
    return {"status": "ok", "model_loaded": ocr_engine is not None}