diff --git a/app.py b/app.py
new file mode 100644
index 0000000..e401962
--- /dev/null
+++ b/app.py
@@ -0,0 +1,77 @@
+from fastapi import FastAPI, File, UploadFile, HTTPException
+import io
+from PIL import Image
+import logging
+from inference import DotsOcr
+from datetime import datetime
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+app = FastAPI(title="DotsOCR API")
+# Admits one request at a time into the inference section to avoid OOM.
+inference_semaphore = asyncio.Semaphore(1)
+# Inference is a blocking call, so it runs on a worker thread rather than on
+# the event loop. A single worker keeps inference serialized even if the
+# awaiting request is cancelled while generation is still running.
+inference_executor = ThreadPoolExecutor(max_workers=1)
+
+model_path = "DotsOCR"
+try:
+    ocr_engine = DotsOcr(model_path)
+    logger.info(f"{datetime.now()} - DotsOCR model loaded successfully")
+except Exception as e:
+    # Keep the app importable so /health can report that the model is missing.
+    logger.error(f"{datetime.now()} - DotsOCR model loading failed: {e}")
+    ocr_engine = None
+
+@app.post("/predict")
+async def predict(file: UploadFile = File(...)):
+    """
+    Run layout recognition on one uploaded image and return the parsed result.
+
+    A semaphore ensures only one request performs inference at a time, which
+    prevents OOM. Inference itself runs on a worker thread so the event loop
+    (and /health) stays responsive while the model is generating.
+
+    Raises:
+        HTTPException: 500 if the model failed to load or inference fails,
+            400 if the upload is not a decodable image.
+    """
+    if ocr_engine is None:
+        raise HTTPException(status_code=500, detail="模型未正确加载")
+
+    if file.content_type is not None and not file.content_type.startswith("image/"):
+        raise HTTPException(status_code=400, detail="文件必须是图像格式")
+
+    async with inference_semaphore:
+        try:
+            image_bytes = await file.read()
+
+            try:
+                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+            except Exception:
+                raise HTTPException(status_code=400, detail="无法识别的图像文件,请上传有效的图像")
+
+            logger.info(f"处理图像: {file.filename}, 尺寸: {image.size}")
+
+            loop = asyncio.get_running_loop()
+            result = await loop.run_in_executor(
+                inference_executor, ocr_engine.sample_inference, image
+            )
+            return result
+        except HTTPException:
+            # Re-raise HTTP errors untouched so the generic handler below does
+            # not turn a 400 into a 500.
+            raise
+        except Exception as e:
+            logger.exception("推理过程中发生错误")
+            raise HTTPException(status_code=500, detail=str(e)) from e
+
+@app.get("/health")
+async def health_check():
+    """Report liveness and whether the OCR model was loaded at startup."""
+    return {"status": "ok", "model_loaded": ocr_engine is not None}
\ No newline at end of file
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..e904c49
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,108 @@
+import torch
+from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
+from qwen_vl_utils import process_vision_info
+from dots_ocr.utils import dict_promptmode_to_prompt
+from PIL import Image
+import io
+import json
+from typing import List, Dict, Tuple
+import warnings
+warnings.filterwarnings("ignore")
+
+class DotsOcr:
+    """Load a DotsOCR model and run the layout-extraction prompt on single images."""
+
+    def __init__(self, model_path, device="cuda"):
+        """Load the model and processor from model_path and move the model to device."""
+        self.device = device
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            # attn_implementation="flash_attention_2",
+            dtype="bfloat16",
+            trust_remote_code=True
+        )
+        self.model.to(self.device)
+        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, use_fast=True)
+        # NOTE(review): the indentation inside this literal is part of the prompt
+        # text sent to the model; confirm it matches the intended prompt.
+        self.prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
+
+        1. Bbox format: [x1, y1, x2, y2]
+
+        2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
+
+        3. Text Extraction & Formatting Rules:
+            - Picture: For the 'Picture' category, the text field should be omitted.
+            - Formula: Format its text as LaTeX.
+            - Table: Format its text as HTML.
+            - All Others (Text, Title, etc.): Format their text as Markdown.
+
+        4. Constraints:
+            - The output text must be the original text from the image, with no translation.
+            - All layout elements must be sorted according to human reading order.
+
+        5. Final Output: The entire output must be a single JSON object.
+        """
+
+    def sample_inference(self, image: Image.Image, max_new_tokens: int = 24000):
+        """
+        Process a PIL Image object and return the parsed result list.
+        To avoid OOM, process one image at a time.
+
+        Args:
+            image: Image to analyse.
+            max_new_tokens: Upper bound on the number of generated tokens.
+
+        Raises:
+            json.JSONDecodeError: If the decoded model output is not valid JSON.
+        """
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": image
+                    },
+                    {"type": "text", "text": self.prompt}
+                ]
+            }
+        ]
+
+        # Preparation for inference
+        text = self.processor.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        image_inputs, video_inputs = process_vision_info(messages)
+        inputs = self.processor(
+            text=[text],
+            images=image_inputs,
+            videos=video_inputs,
+            padding=True,
+            return_tensors="pt",
+        ).to(self.device)
+
+        # Inference: Generation of the output
+        generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
+        # Drop the prompt tokens so only newly generated tokens are decoded.
+        generated_ids_trimmed = [
+            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+        output_text = self.processor.batch_decode(
+            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+        result_str = output_text[0]
+        result_list = json.loads(result_str)
+        return result_list
+
+if __name__ == "__main__":
+    model_path = "DotsOCR"
+    dots_ocr = DotsOcr(model_path)
+    image_paths = ["20260306-065852.webp", "20260306-065909.webp"]
+    for image_path in image_paths:
+        # sample_inference is typed to take a PIL image, so load the file first.
+        with Image.open(image_path) as image:
+            output_text = dots_ocr.sample_inference(image.convert("RGB"))
+        print(output_text)
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..f703105
--- /dev/null
+++ b/test.py
@@ -0,0 +1,7 @@
+import requests
+
+url = "http://10.200.0.118:8000/predict"
+# Use a context manager so the upload handle is closed after the request.
+with open("20260306-065909.webp", "rb") as image_file:
+    response = requests.post(url, files={"file": image_file})
+print(response.text)  # Print the recognition result
\ No newline at end of file