# DotsOCR/inference.py
# Snapshot captured from a repository web viewer (2026-03-10 19:49:08 +08:00; 152 lines, 5.6 KiB, Python).
# Viewer note: this file contains Unicode characters that might be confused with
# other characters (ambiguous Unicode); review them if that is not intentional.
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
from qwen_vl_utils import process_vision_info
from dots_ocr.utils import dict_promptmode_to_prompt
from PIL import Image
import re
import json
from typing import List, Dict, Tuple
import warnings
warnings.filterwarnings("ignore")
class DotsOcr:
    """Layout OCR wrapper around a dots.ocr checkpoint loaded with transformers.

    Every image is sent to the model together with a layout-extraction prompt.
    The generated text is expected to contain a JSON array (everything from the
    first ``[`` to the last ``]`` is extracted), which is then parsed with
    ``json.loads``.
    """

    # Instruction sent with every image unless a custom ``prompt`` is given.
    # This is a runtime string the model is prompted with; keep it verbatim.
    DEFAULT_PROMPT = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
1. Bbox format: [x1, y1, x2, y2]
2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
3. Text Extraction & Formatting Rules:
- Picture: For the 'Picture' category, the text field should be omitted.
- Formula: Format its text as LaTeX.
- Table: Format its text as HTML.
- All Others (Text, Title, etc.): Format their text as Markdown.
4. Constraints:
- The output text must be the original text from the image, with no translation.
- All layout elements must be sorted according to human reading order.
5. Final Output: The entire output must be a single JSON object.
"""

    # Greedy + DOTALL: spans from the first "[" to the last "]" across lines,
    # which strips surrounding chatter such as ```json fences.
    _JSON_ARRAY_RE = re.compile(r"\[.*\]", re.DOTALL)

    def __init__(self, model_path: str, device: str = "cuda", prompt=None,
                 attn_implementation: str = "flash_attention_2"):
        """Load the model and processor.

        Args:
            model_path: Local path or hub id of the dots.ocr checkpoint.
            device: Device the model and input tensors are moved to.
            prompt: Optional replacement for ``DEFAULT_PROMPT``.
            attn_implementation: Attention backend passed to ``from_pretrained``;
                the default requires the flash-attn package.
        """
        self.device = device
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            attn_implementation=attn_implementation,
            # NOTE(review): `dtype=` is the newer spelling of `torch_dtype`;
            # older transformers releases may need `torch_dtype` instead.
            dtype="bfloat16",
            trust_remote_code=True,
        )
        self.model.to(self.device)
        self.processor = AutoProcessor.from_pretrained(
            model_path, trust_remote_code=True, use_fast=True
        )
        # Batched decoder-only generation must pad on the LEFT so that each
        # row's new tokens directly follow its real prompt. Prompts typically
        # differ in length because image token counts depend on image size.
        tokenizer = getattr(self.processor, "tokenizer", None)
        if tokenizer is not None:
            tokenizer.padding_side = "left"
        self.prompt = self.DEFAULT_PROMPT if prompt is None else prompt

    def _format_prompt(self, image: "Image.Image"):
        """Build the single user message (image first, then the text prompt)."""
        return {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": self.prompt},
            ],
        }

    def _process_output(self, output_text: str):
        """Return the bracketed JSON-array span of ``output_text``, or the text unchanged."""
        match = self._JSON_ARRAY_RE.search(output_text)
        return match.group(0) if match else output_text

    def _generate(self, conversations, max_new_tokens: int):
        """Run the model on a batch of conversations and return the decoded completions."""
        # Template each conversation separately: works on transformers versions
        # whose apply_chat_template does not accept a batch of conversations.
        texts = [
            self.processor.apply_chat_template(
                conversation, tokenize=False, add_generation_prompt=True
            )
            for conversation in conversations
        ]
        image_inputs, video_inputs = process_vision_info(conversations)
        inputs = self.processor(
            text=texts,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # generate() returns prompt + completion for every row; keep only the
        # newly generated tokens.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        return self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

    def sample_inference(self, image: "Image.Image", max_new_tokens: int = 24000):
        """Run OCR on a single PIL image and return the parsed result.

        Process one image at a time to avoid running out of GPU memory.

        Args:
            image: The page image.
            max_new_tokens: Generation budget for the model's answer.

        Returns:
            The value decoded by ``json.loads`` (normally a list of layout
            elements), or ``{"error": ...}`` if the output is not valid JSON.
        """
        output_text = self._generate([[self._format_prompt(image)]], max_new_tokens)[0]
        result_str = self._process_output(output_text)
        try:
            return json.loads(result_str)
        except json.JSONDecodeError:
            print(f"无法解析 JSON 输出: {result_str}")
            return {"error": "无法解析 JSON 输出"}

    def batch_inference(self, images: "List[Image.Image]", max_new_tokens: int = 24000):
        """Run OCR on several PIL images in one padded batch.

        Args:
            images: Page images; any iterable is accepted.
            max_new_tokens: Generation budget per image.

        Returns:
            One entry per input image, in order: the ``json.loads`` result, or
            ``{"error": ...}`` for outputs that are not valid JSON. An empty
            input returns ``[]`` without touching the model.
        """
        images = list(images)
        if not images:
            return []
        output_texts = self._generate(
            [[self._format_prompt(img)] for img in images], max_new_tokens
        )
        result_list = []
        for text in output_texts:
            try:
                result_list.append(json.loads(self._process_output(text)))
            except json.JSONDecodeError:
                print(f"无法解析 JSON 输出: {text} \n类型:{type(text)}")
                result_list.append({"error": "无法解析 JSON 输出"})
        return result_list
if __name__ == "__main__":
    # Demo: OCR two local images in one batch and print the parsed results.
    model_path = "DotsOCR"
    dots_ocr = DotsOcr(model_path)
    image_paths = ["20260306-065852.webp", "20260306-065909.webp"]
    imgs = []
    for image_path in image_paths:
        # Image.open is lazy and keeps the file open; copy() loads the pixels
        # into memory (same mode) so the handle can be closed immediately.
        with Image.open(image_path) as img:
            imgs.append(img.copy())
    results = dots_ocr.batch_inference(imgs)
    print(results)