Update inference.py

This commit is contained in:
2026-03-10 19:49:08 +08:00
parent 057f64cc37
commit e9d67f409d

View File

@@ -1,11 +1,11 @@
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
from qwen_vl_utils import process_vision_info from qwen_vl_utils import process_vision_info
from dots_ocr.utils import dict_promptmode_to_prompt from dots_ocr.utils import dict_promptmode_to_prompt
from PIL import Image from PIL import Image
import io import re
import json import json
from typing import List, Dict, Tuple from typing import List, Dict, Tuple
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
@@ -14,7 +14,7 @@ class DotsOcr:
self.device = device self.device = device
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
model_path, model_path,
# attn_implementation="flash_attention_2", attn_implementation="flash_attention_2",
dtype="bfloat16", dtype="bfloat16",
trust_remote_code=True trust_remote_code=True
) )
@@ -39,21 +39,33 @@ class DotsOcr:
5. Final Output: The entire output must be a single JSON object. 5. Final Output: The entire output must be a single JSON object.
""" """
def _format_prompt(self, image: Image.Image):
m = {
"role": "user",
"content": [
{
"type": "image",
"image": image
},
{"type": "text", "text": self.prompt}
]
}
return m
def _process_output(self, output_text: str):
match = re.search(r"\[.*\]", output_text, re.DOTALL)
if match:
return match.group(0)
else:
return output_text
def sample_inference(self, image: Image.Image): def sample_inference(self, image: Image.Image):
""" """
处理 PIL Image 对象返回解析后的结果列表。To avoid OOM, process one image at a time. 处理 PIL Image 对象返回解析后的结果列表。To avoid OOM, process one image at a time.
""" """
messages = [ messages = [
{ self._format_prompt(image)
"role": "user",
"content": [
{
"type": "image",
"image": image
},
{"type": "text", "text": self.prompt}
]
}
] ]
# Preparation for inference # Preparation for inference
@@ -80,7 +92,7 @@ class DotsOcr:
output_text = self.processor.batch_decode( output_text = self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
) )
result_str = output_text[0] result_str = self._process_output(output_text[0])
try: try:
result_list = json.loads(result_str) result_list = json.loads(result_str)
@@ -89,10 +101,52 @@ class DotsOcr:
print(f"无法解析 JSON 输出: {result_str}") print(f"无法解析 JSON 输出: {result_str}")
return {"error": "无法解析 JSON 输出"} return {"error": "无法解析 JSON 输出"}
def batch_inference(self, images: List[Image.Image]):
"""
处理 PIL Image 对象,返回解析后的结果列表。
"""
messages = [[self._format_prompt(img)] for img in images]
# Preparation for inference
texts = self.processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=texts,
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
).to(self.device)
# Inference: Generation of the output
generated_ids = self.model.generate(**inputs, max_new_tokens=24000)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
result_list = []
for text in output_text:
try:
result_list.append(json.loads(self._process_output(text)))
except json.JSONDecodeError:
print(f"无法解析 JSON 输出: {text} \n类型:{type(text)}")
result_list.append({"error": "无法解析 JSON 输出"})
return result_list
if __name__=="__main__": if __name__=="__main__":
model_path = "DotsOCR" model_path = "DotsOCR"
dots_ocr = DotsOcr(model_path) dots_ocr = DotsOcr(model_path)
image_paths = ["20260310-162729.webp"] image_paths = ["20260306-065852.webp", "20260306-065909.webp"]
for image_path in image_paths: imgs = [Image.open(image_path) for image_path in image_paths]
output_text = dots_ocr.sample_inference(image_path) results = dots_ocr.batch_inference(imgs)
print(output_text) print(results)