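"""InternVL adapter for the open_r1 VLM module interface.

Covers model-class resolution, conversation templating, dynamic-resolution
image preprocessing, and the box-grounding (REC-style) format/IoU reward
functions.
"""
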
from open_r1.vlm_modules.vlm_module import VLMBaseModule
from typing import Any, Union
from transformers import AutoModel, AutoProcessor, AutoConfig
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers.feature_extraction_sequence_utils import BatchFeature

# Special tokens used by InternVL to delimit image embeddings in the prompt.
IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'

# ImageNet normalization statistics used by the InternVL vision tower.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class InvernVLModule(VLMBaseModule):
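    """Adapter that plugs InternVL chat models into the VLMBaseModule interface.

    Handles InternVL-specific loading kwargs, prompt construction with the
    model's conversation template, tile-based image preprocessing, and reward
    helpers for bounding-box tasks.
    """
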
    def __init__(self):
        super().__init__()
        # Filled in lazily by post_model_init from the loaded model.
        self.conv_template = None
        self.num_image_token = None

    def get_vlm_key(self):
        return "internvl"

    def get_model_class(self, model_id: str, model_init_kwargs: dict):
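        """Resolve the model class and normalize init kwargs for InternVL.

        InternVL is loaded through AutoModel with trust_remote_code=True; its
        concrete class is picked from the checkpoint config. The kwargs are
        adjusted in place: "use_cache" is dropped and
        attn_implementation="flash_attention_2" is translated to InternVL's
        use_flash_attn=True.
        """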
        assert "InternVL" in model_id, f"model_id must contain 'InternVL', but got {model_id}"
        self.model_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
        # The concrete InternVL model class is resolved from its config when mapped through AutoModel.
        model_cls = AutoModel
        # InternVL must be loaded with trust_remote_code=True.
        model_init_kwargs["trust_remote_code"] = True
        # "use_cache" should be removed from the init kwargs.
        model_init_kwargs.pop("use_cache", None)
        # InternVL expects "use_flash_attn" instead of attn_implementation="flash_attention_2".
        if "flash_attention_2" in model_init_kwargs.get("attn_implementation", ""):
            model_init_kwargs["use_flash_attn"] = True
            model_init_kwargs.pop("attn_implementation")
        return model_cls

    def post_model_init(self, model, processing_class):
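        """Cache the conversation template and per-tile image-token count from
        the loaded model, and register the <IMG_CONTEXT> token id on it."""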
        self.conv_template = model.conv_template if self.conv_template is None else self.conv_template
        self.num_image_token = model.num_image_token if self.num_image_token is None else self.num_image_token
        img_context_token_id = processing_class.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        model.img_context_token_id = img_context_token_id

    def is_embeds_input(self):
        return True

    def get_processing_class(self):
        return AutoProcessor

    def get_eos_token_id(self, processing_class):
        eos_token_id = processing_class.convert_tokens_to_ids(self.conv_template.sep.strip())
        return eos_token_id

    def get_vision_modules_keywords(self):
        return ['vision_model']

    def get_custom_multimodal_keywords(self):
        return ['pixel_values', 'image_flags']

    def get_non_generate_params(self):
        return ['image_flags']

    def get_custom_processing_keywords(self):
        return ['max_anyres_num']

    def prepare_prompt(self, processing_class, inputs: dict[str, Union[torch.Tensor, Any]]):
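        """Render each example's chat messages into an InternVL prompt string.

        Despite the annotation, `inputs` is iterated as a sequence of example
        dicts whose "prompt" entry is a list of chat messages. The optional
        system message is applied to the conversation template, user and
        assistant turns are appended alternately, and when the conversation
        ends with a user turn an empty assistant turn is appended so the model
        continues from there.
        """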
        prompts_text = []
        for example in inputs:
            template = self.conv_template.copy()
            conversation_list = example["prompt"]
            system_message = extract_system_message(conversation_list)
            if system_message is not None:
                template.system_message = system_message

            processed_list = process_conversation_list(conversation_list, system_message)
            for i, processed_item in enumerate(processed_list):
                if i % 2 == 0:
                    template.append_message(template.roles[0], processed_item)
                else:
                    template.append_message(template.roles[1], processed_item)
            if len(processed_list) % 2 == 1:
                template.append_message(template.roles[1], None)
            query = template.get_prompt()
            prompts_text.append(query)
        return prompts_text

    def prepare_model_inputs(self, processing_class, prompts_text, images, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False):
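        """Tokenize prompts and pack pixel values for a batch of image+text samples.

        Each image is tiled by `_load_image`, and every "<image>" placeholder
        in the prompt is expanded to
        <img> + <IMG_CONTEXT> * num_image_token * num_patches + </img>,
        where num_image_token comes from the model (e.g. 256 for many
        InternVL2-style checkpoints). pixel_values are concatenated over all
        tiles in the batch and image_flags marks every tile as a real image.
        """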
        # Process images: tile each one and track how many tiles it produced.
        full_pixel_values = []
        num_patches_list = []
        for img in images:
            pixel_values = self._load_image(img, input_size=self.model_config.vision_config.image_size, max_num=processing_class.max_anyres_num)
            full_pixel_values.append(pixel_values)
            num_patches_list.append(pixel_values.shape[0])
        full_pixel_values = torch.cat(full_pixel_values, dim=0)

        # Process prompts: expand every "<image>" placeholder into InternVL's image tokens.
        queries = []
        image_idx = 0
        for query in prompts_text:
            while "<image>" in query:
                num_patches = num_patches_list[image_idx]
                image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
                query = query.replace("<image>", image_tokens, 1)
                image_idx += 1
            queries.append(query)
        assert image_idx == len(num_patches_list)

        model_inputs = processing_class(
            queries,
            return_tensors=return_tensors,
            padding=padding,
            padding_side=padding_side,
            add_special_tokens=add_special_tokens,
        )
        model_inputs["pixel_values"] = full_pixel_values
        # Only pure-image data is supported currently (every sample must contain an image).
        model_inputs['image_flags'] = torch.ones(full_pixel_values.shape[0], dtype=torch.long)

        model_inputs = BatchFeature(data=model_inputs)
        return model_inputs

    def _load_image(self, image: Image.Image, input_size: int = 448, max_num: int = 12):
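        """Convert a PIL image into a stacked tensor of normalized tiles of shape
        (num_tiles, 3, input_size, input_size), using dynamic_preprocess with a
        thumbnail tile."""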
        transform = build_transform(input_size=input_size)
        images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
        pixel_values = [transform(img) for img in images]
        pixel_values = torch.stack(pixel_values)
        return pixel_values

    @staticmethod
    def get_question_template(task_type: str):
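        """Return the prompt template for a task type; every task currently uses
        the generic <think>/<answer> reasoning format."""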
        match task_type:
            case _:
                return "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags."

    @staticmethod
    def format_reward_rec(completions, **kwargs):
        """Reward 1.0 if the completion matches the expected format: a <think>
        block followed by an <answer> block containing a [x1, y1, x2, y2] box."""
        import re

        pattern = r"<think>.*?</think>\s*<answer>.*?\[\d+,\s*\d+,\s*\d+,\s*\d+\].*?</answer>"
        completion_contents = [completion[0]["content"] for completion in completions]
        matches = [re.search(pattern, content, re.DOTALL) is not None for content in completion_contents]
        return [1.0 if match else 0.0 for match in matches]

    @staticmethod
    def iou_reward(completions, solution, **kwargs):
        """Soft IoU reward: the IoU between the bounding box parsed from the
        completion's <answer> block and the ground-truth box in `solution`."""
        import re
        import os
        from datetime import datetime

        def iou(box1, box2):
            inter_x1 = max(box1[0], box2[0])
            inter_y1 = max(box1[1], box2[1])
            inter_x2 = min(box1[2] - 1, box2[2] - 1)
            inter_y2 = min(box1[3] - 1, box2[3] - 1)
            if inter_x1 < inter_x2 and inter_y1 < inter_y2:
                inter = (inter_x2 - inter_x1 + 1) * (inter_y2 - inter_y1 + 1)
            else:
                inter = 0
            union = (box1[2] - box1[0]) * (box1[3] - box1[1]) + (box2[2] - box2[0]) * (box2[3] - box2[1]) - inter
            return float(inter) / union

        contents = [completion[0]["content"] for completion in completions]
        rewards = []
        current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
        answer_tag_pattern = r'<answer>(.*?)</answer>'
        bbox_pattern = r'\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]'
        for content, sol in zip(contents, solution):
            reward = 0.0
            # Try to parse a bounding box out of the <answer> block.
            try:
                content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
                if content_answer_match:
                    content_answer = content_answer_match.group(1).strip()
                    bbox_match = re.search(bbox_pattern, content_answer)
                    if bbox_match:
                        bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
                        reward = iou(bbox, sol)
            except Exception:
                pass  # Leave the reward at 0.0 if parsing or IoU computation fails.

            rewards.append(reward)
            if os.getenv("DEBUG_MODE") == "true":
                log_path = os.getenv("LOG_PATH")
                with open(log_path, "a", encoding='utf-8') as f:
                    f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                    f.write(f"Content: {content}\n")
                    f.write(f"Solution: {sol}\n")
        return rewards


def process_conversation_list(conversation_list, system_message=None, image_newline=True):
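    """Flatten a list of chat messages into alternating user/assistant strings.

    Drops the leading system message (handled separately), concatenates text
    parts, and replaces each image part with an "<image>" placeholder
    (followed by a newline when image_newline is True).
    """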
    if system_message is not None:
        conversation_list = conversation_list[1:]
    processed_list = []

    for item in conversation_list:
        role = item["role"]
        content = item["content"]

        if isinstance(content, list):
            overall_str = ""
            for content_item in content:
                if content_item.get("type") == "image":
                    overall_str += "<image>" if not image_newline else "<image>\n"
                elif content_item.get("type") == "text":
                    overall_str += content_item.get("text")
                else:
                    raise ValueError(f"Unsupported content type: {content_item.get('type')}")
            processed_list.append(overall_str)
        elif isinstance(content, str):
            processed_list.append(content)
        else:
            raise ValueError(f"Unsupported content type: {type(content)}")

    return processed_list


def extract_system_message(conversation_list):
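    """Return the system message text if the first message has role "system",
    handling both string and list-of-parts content; otherwise return None."""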
    if conversation_list[0]["role"] == "system":
        if isinstance(conversation_list[0]["content"], list):
            return conversation_list[0]["content"][0]["text"]
        else:
            return conversation_list[0]["content"]
    return None


def build_transform(input_size):
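    """Torchvision pipeline: convert to RGB, resize to input_size x input_size
    with bicubic interpolation, and normalize with ImageNet statistics."""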
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
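    """Pick the tiling grid (cols, rows) whose aspect ratio is closest to the
    image's. Ties are broken in favor of the larger grid only when the original
    image covers more than half of that grid's pixel budget."""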
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
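    """Split an image into image_size x image_size tiles along the grid that best
    matches its aspect ratio, optionally appending a thumbnail of the whole image.

    Illustrative example (assuming the defaults above): a 1000x500 image has
    aspect ratio 2.0, so the (2, 1) grid is chosen, the image is resized to
    896x448 and cut into 2 tiles; with use_thumbnail=True a third 448x448
    thumbnail tile is appended.
    """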
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Enumerate candidate tiling grids (cols, rows) whose tile count lies in [min_num, max_num].
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # Find the grid whose aspect ratio is closest to the image's.
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # Calculate the target width and height.
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # Resize the image to the grid resolution.
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # Crop one tile per grid cell.
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images