Update app.py
This commit is contained in:
117
app.py
117
app.py
@@ -1,58 +1,61 @@
|
|||||||
from fastapi import FastAPI, File, UploadFile, HTTPException
|
from fastapi import FastAPI, File, UploadFile, HTTPException
|
||||||
import io
|
import io
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import logging
|
import logging
|
||||||
from inference import DotsOcr
|
from inference import DotsOcr
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
# 配置日志
|
# 配置日志
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
app = FastAPI(title="DotsOCR API")
|
app = FastAPI(title="DotsOCR API")
|
||||||
inference_semaphore = asyncio.Semaphore(1)
|
inference_semaphore = asyncio.Semaphore(1)
|
||||||
|
|
||||||
model_path = "DotsOCR"
|
model_path = "DotsOCR"
|
||||||
try:
|
try:
|
||||||
ocr_engine = DotsOcr(model_path)
|
ocr_engine = DotsOcr(model_path)
|
||||||
logger.info(f"{datetime.now()} - DotsOCR model loaded successfully")
|
logger.info(f"{datetime.now()} - DotsOCR model loaded successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{datetime.now()} - DotsOCR model loading failed: {e}")
|
logger.error(f"{datetime.now()} - DotsOCR model loading failed: {e}")
|
||||||
ocr_engine = None
|
ocr_engine = None
|
||||||
|
|
||||||
@app.post("/predict")
|
@app.post("/predict")
|
||||||
async def predict(file: UploadFile = File(...)):
|
async def predict(file: UploadFile = File(...)):
|
||||||
"""
|
"""
|
||||||
上传一张图像,返回布局识别结果(字典列表)。
|
上传一张图像,返回布局识别结果(字典列表)。
|
||||||
使用信号量确保同一时间只有一个请求进行推理,防止 OOM。
|
使用信号量确保同一时间只有一个请求进行推理,防止 OOM。
|
||||||
"""
|
"""
|
||||||
if ocr_engine is None:
|
if ocr_engine is None:
|
||||||
raise HTTPException(status_code=500, detail="模型未正确加载")
|
raise HTTPException(status_code=500, detail="模型未正确加载")
|
||||||
|
|
||||||
if file.content_type is not None and not file.content_type.startswith("image/"):
|
if file.content_type is not None and not file.content_type.startswith("image/"):
|
||||||
raise HTTPException(status_code=400, detail="文件必须是图像格式")
|
raise HTTPException(status_code=400, detail="文件必须是图像格式")
|
||||||
|
|
||||||
async with inference_semaphore:
|
async with inference_semaphore:
|
||||||
try:
|
try:
|
||||||
image_bytes = await file.read()
|
image_bytes = await file.read()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
||||||
except Exception:
|
except Exception:
|
||||||
raise HTTPException(status_code=400, detail="无法识别的图像文件,请上传有效的图像")
|
raise HTTPException(status_code=400, detail="无法识别的图像文件,请上传有效的图像")
|
||||||
|
logger.info(f"{datetime.now()} - 处理图像: {file.filename}, 尺寸: {image.size}")
|
||||||
logger.info(f"处理图像: {file.filename}, 尺寸: {image.size}")
|
|
||||||
|
result = ocr_engine.sample_inference(image)
|
||||||
result = ocr_engine.sample_inference(image)
|
|
||||||
return result
|
if isinstance(result, dict) and "error" in result:
|
||||||
except HTTPException:
|
logger.error(f"{datetime.now()} - 推理返回错误: {result}")
|
||||||
# 直接抛出 HTTP 异常,避免被通用异常捕获导致状态码错误
|
raise HTTPException(status_code=500, detail=result)
|
||||||
raise
|
return result
|
||||||
except Exception as e:
|
except HTTPException:
|
||||||
logger.exception("推理过程中发生错误")
|
# 直接抛出 HTTP 异常,避免被通用异常捕获导致状态码错误
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise
|
||||||
|
except Exception as e:
|
||||||
@app.get("/health")
|
logger.exception(f"{datetime.now()} - 推理过程中发生错误: {e}")
|
||||||
async def health_check():
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check():
|
||||||
return {"status": "ok", "model_loaded": ocr_engine is not None}
|
return {"status": "ok", "model_loaded": ocr_engine is not None}
|
||||||
Reference in New Issue
Block a user