Files
DotsOCR/inference.py
2026-03-06 18:00:13 +08:00

94 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
from qwen_vl_utils import process_vision_info
from dots_ocr.utils import dict_promptmode_to_prompt
from PIL import Image
import io
import json
from typing import List, Dict, Tuple
import warnings
# Install an "ignore" filter for every warning category. This runs at import
# time, so it affects the whole process, including any code that imports this module.
# NOTE(review): consider narrowing this to the specific noisy categories.
warnings.simplefilter("ignore")
class DotsOcr:
    """Wrapper around a dots.ocr checkpoint for single-image layout OCR.

    The model and processor are loaded once in ``__init__``. Each call to
    ``sample_inference`` runs one generation and returns the JSON-decoded output.
    """

    # Layout-extraction prompt sent with every image: one instruction per line,
    # each terminated by "\n".
    _LAYOUT_PROMPT = (
        "Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.\n"
        "1. Bbox format: [x1, y1, x2, y2]\n"
        "2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].\n"
        "3. Text Extraction & Formatting Rules:\n"
        "- Picture: For the 'Picture' category, the text field should be omitted.\n"
        "- Formula: Format its text as LaTeX.\n"
        "- Table: Format its text as HTML.\n"
        "- All Others (Text, Title, etc.): Format their text as Markdown.\n"
        "4. Constraints:\n"
        "- The output text must be the original text from the image, with no translation.\n"
        "- All layout elements must be sorted according to human reading order.\n"
        "5. Final Output: The entire output must be a single JSON object.\n"
    )

    def __init__(self, model_path, device="cuda"):
        """Load the model and processor and move the model to ``device``.

        Args:
            model_path: Path or identifier passed to ``from_pretrained`` for both
                the model and the processor.
            device: Torch device string; the model and every batch of inputs
                are moved here.
        """
        self.device = device
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            # Optionally add attn_implementation="flash_attention_2" here when
            # flash-attn is installed.
            dtype="bfloat16",
            # Allows loading checkpoints that ship their own modelling code.
            trust_remote_code=True,
        )
        self.model.to(self.device)
        self.processor = AutoProcessor.from_pretrained(
            model_path, trust_remote_code=True, use_fast=True
        )
        # Public attribute so callers can swap the prompt per instance.
        self.prompt = self._LAYOUT_PROMPT

    def sample_inference(self, image: "str | Image.Image", max_new_tokens=24000):
        """Run layout OCR on one image and return the JSON-decoded model output.

        Only one image is processed per call, to avoid running out of GPU memory.

        Args:
            image: A PIL image, or any other value ``process_vision_info``
                accepts as an ``"image"`` entry. For example, the ``__main__``
                demo passes a local file path.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            Whatever ``json.loads`` produces from the model's answer
            (presumably a list of layout elements — depends on the model).

        Raises:
            json.JSONDecodeError: The model output is not valid JSON. The message
                notes when generation used the full ``max_new_tokens`` budget.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": self.prompt},
                ],
            }
        ]
        # Render the chat template as text; tokenization happens in the
        # processor call below, together with the vision inputs.
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.device)

        generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # generate() returns prompt + continuation; keep only the new tokens.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        # A continuation that used the whole budget was probably cut off mid-JSON.
        truncated = len(generated_ids_trimmed[0]) >= max_new_tokens
        return self._parse_output(output_text[0], truncated=truncated)

    @staticmethod
    def _parse_output(text, truncated=False):
        """Decode ``text`` with ``json.loads``, adding context to failures.

        Failures are re-raised as ``json.JSONDecodeError``, the same type
        ``json.loads`` raises, so existing ``except`` clauses keep working.
        """
        try:
            return json.loads(text)
        except json.JSONDecodeError as err:
            reason = "model output is not valid JSON"
            if truncated:
                reason += " (generation hit max_new_tokens, so the output may be cut off)"
            raise json.JSONDecodeError(f"{reason}: {err.msg}", err.doc, err.pos) from err
if __name__ == "__main__":
    # Demo: load the checkpoint from "DotsOCR", OCR two sample images in order,
    # and print each parsed result.
    ocr = DotsOcr("DotsOCR")
    samples = ("20260306-065852.webp", "20260306-065909.webp")
    for sample in samples:
        parsed = ocr.sample_inference(sample)
        print(parsed)