Update inference.py

This commit is contained in:
2026-03-10 17:00:06 +08:00
parent ee38f71f40
commit 057f64cc37

View File

@@ -1,93 +1,98 @@
import torch import torch
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
from qwen_vl_utils import process_vision_info from qwen_vl_utils import process_vision_info
from dots_ocr.utils import dict_promptmode_to_prompt from dots_ocr.utils import dict_promptmode_to_prompt
from PIL import Image from PIL import Image
import io import io
import json import json
from typing import List, Dict, Tuple from typing import List, Dict, Tuple
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
class DotsOcr: class DotsOcr:
def __init__(self, model_path, device="cuda"): def __init__(self, model_path, device="cuda"):
self.device = device self.device = device
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
model_path, model_path,
# attn_implementation="flash_attention_2", # attn_implementation="flash_attention_2",
dtype="bfloat16", dtype="bfloat16",
trust_remote_code=True trust_remote_code=True
) )
self.model.to(self.device) self.model.to(self.device)
self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, use_fast=True) self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, use_fast=True)
self.prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox. self.prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
1. Bbox format: [x1, y1, x2, y2] 1. Bbox format: [x1, y1, x2, y2]
2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']. 2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
3. Text Extraction & Formatting Rules: 3. Text Extraction & Formatting Rules:
- Picture: For the 'Picture' category, the text field should be omitted. - Picture: For the 'Picture' category, the text field should be omitted.
- Formula: Format its text as LaTeX. - Formula: Format its text as LaTeX.
- Table: Format its text as HTML. - Table: Format its text as HTML.
- All Others (Text, Title, etc.): Format their text as Markdown. - All Others (Text, Title, etc.): Format their text as Markdown.
4. Constraints: 4. Constraints:
- The output text must be the original text from the image, with no translation. - The output text must be the original text from the image, with no translation.
- All layout elements must be sorted according to human reading order. - All layout elements must be sorted according to human reading order.
5. Final Output: The entire output must be a single JSON object. 5. Final Output: The entire output must be a single JSON object.
""" """
def sample_inference(self, image: Image.Image): def sample_inference(self, image: Image.Image):
""" """
处理 PIL Image 对象返回解析后的结果列表。To avoid OOM, process one image at a time. 处理 PIL Image 对象返回解析后的结果列表。To avoid OOM, process one image at a time.
""" """
messages = [ messages = [
{ {
"role": "user", "role": "user",
"content": [ "content": [
{ {
"type": "image", "type": "image",
"image": image "image": image
}, },
{"type": "text", "text": self.prompt} {"type": "text", "text": self.prompt}
] ]
} }
] ]
# Preparation for inference # Preparation for inference
text = self.processor.apply_chat_template( text = self.processor.apply_chat_template(
messages, messages,
tokenize=False, tokenize=False,
add_generation_prompt=True add_generation_prompt=True
) )
image_inputs, video_inputs = process_vision_info(messages) image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor( inputs = self.processor(
text=[text], text=[text],
images=image_inputs, images=image_inputs,
videos=video_inputs, videos=video_inputs,
padding=True, padding=True,
return_tensors="pt", return_tensors="pt",
).to(self.device) ).to(self.device)
# Inference: Generation of the output # Inference: Generation of the output
generated_ids = self.model.generate(**inputs, max_new_tokens=24000) generated_ids = self.model.generate(**inputs, max_new_tokens=24000)
generated_ids_trimmed = [ generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
] ]
output_text = self.processor.batch_decode( output_text = self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
) )
result_str = output_text[0] result_str = output_text[0]
result_list = json.loads(result_str)
return result_list try:
result_list = json.loads(result_str)
if __name__=="__main__": return result_list
model_path = "DotsOCR" except json.JSONDecodeError:
dots_ocr = DotsOcr(model_path) print(f"无法解析 JSON 输出: {result_str}")
image_paths = ["20260306-065852.webp", "20260306-065909.webp"] return {"error": "无法解析 JSON 输出"}
for image_path in image_paths:
output_text = dots_ocr.sample_inference(image_path) if __name__=="__main__":
print(output_text) model_path = "DotsOCR"
dots_ocr = DotsOcr(model_path)
image_paths = ["20260310-162729.webp"]
for image_path in image_paths:
output_text = dots_ocr.sample_inference(image_path)
print(output_text)