from open_r1.vlm_modules.vlm_module import VLMBaseModule
from typing import Dict, Any, Union
from transformers import AutoModel, AutoProcessor, AutoConfig
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers.feature_extraction_sequence_utils import BatchFeature

IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class InternVLModule(VLMBaseModule):
    def __init__(self):
        super().__init__()
        self.conv_template = None
        self.num_image_token = None

    def get_vlm_key(self):
        return "internvl"

    def get_model_class(self, model_id: str, model_init_kwargs: dict):
        assert "InternVL" in model_id, f"model_id must contain 'InternVL', but got {model_id}"
        self.model_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
        # The concrete InternVL model class is resolved from its config by AutoModel
        model_cls = AutoModel
        # InternVL must be loaded with trust_remote_code=True
        model_init_kwargs["trust_remote_code"] = True
        # "use_cache" is not accepted by InternVL's constructor and must be removed
        model_init_kwargs.pop("use_cache", None)
        # InternVL expects "use_flash_attn" instead of attn_implementation="flash_attention_2"
        if "flash_attention_2" in model_init_kwargs.get("attn_implementation", ""):
            model_init_kwargs["use_flash_attn"] = True
            model_init_kwargs.pop("attn_implementation")
        return model_cls

    def post_model_init(self, model, processing_class):
        self.conv_template = model.conv_template if self.conv_template is None else self.conv_template
        self.num_image_token = model.num_image_token if self.num_image_token is None else self.num_image_token
        img_context_token_id = processing_class.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        model.img_context_token_id = img_context_token_id

    def is_embeds_input(self):
        return True

    def get_processing_class(self):
        return AutoProcessor

    def get_eos_token_id(self, processing_class):
        eos_token_id = processing_class.convert_tokens_to_ids(self.conv_template.sep.strip())
        return eos_token_id

    def get_vision_modules_keywords(self):
        return ['vision_model']

    def get_custom_multimodal_keywords(self):
        return ['pixel_values', 'image_flags']

    def get_non_generate_params(self):
        return ['image_flags']

    def get_custom_processing_keywords(self):
        return ['max_anyres_num']

    def prepare_prompt(self, processing_class, inputs: dict[str, Union[torch.Tensor, Any]]):
        prompts_text = []
        for example in inputs:
            template = self.conv_template.copy()
            conversation_list = example["prompt"]
            system_message = extract_system_message(conversation_list)
            if system_message is not None:
                template.system_message = system_message
            processed_list = process_conversation_list(conversation_list, system_message)
            # Alternate turns between the two conversation roles (user, assistant)
            for i, processed_item in enumerate(processed_list):
                if i % 2 == 0:
                    template.append_message(template.roles[0], processed_item)
                else:
                    template.append_message(template.roles[1], processed_item)
            # An odd number of turns means the assistant reply is still open-ended
            if len(processed_list) % 2 == 1:
                template.append_message(template.roles[1], None)
            query = template.get_prompt()
            prompts_text.append(query)
        return prompts_text
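    # A minimal sketch of the conversation structure `prepare_prompt` expects,
    # assuming the TRL-style GRPO dataset format used elsewhere in this repo;
    # the exact keys in real training data may differ:
    #
    #   example = {
    #       "prompt": [
    #           {"role": "system", "content": "You are a helpful assistant."},
    #           {"role": "user", "content": [
    #               {"type": "image"},
    #               {"type": "text", "text": "Locate the red cup."},
    #           ]},
    #       ]
    #   }
    #
    # `process_conversation_list` flattens each turn to a string, replacing image
    # entries with the "<image>" placeholder that `prepare_model_inputs` later
    # expands into the model's IMG_CONTEXT tokens.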
    def prepare_model_inputs(self, processing_class, prompts_text, images, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False):
        # Process images: tile each image and record how many patches it produced
        full_pixel_values = []
        num_patches_list = []
        for img in images:
            pixel_values = self._load_image(img, input_size=self.model_config.vision_config.image_size, max_num=processing_class.max_anyres_num)
            full_pixel_values.append(pixel_values)
            num_patches_list.append(pixel_values.shape[0])
        full_pixel_values = torch.cat(full_pixel_values, dim=0)

        # Process prompts: expand every "<image>" placeholder into
        # <img> + num_image_token * num_patches context tokens + </img>
        queries = []
        image_idx = 0
        for query in prompts_text:
            while "<image>" in query:
                num_patches = num_patches_list[image_idx]
                image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
                query = query.replace("<image>", image_tokens, 1)
                image_idx += 1
            queries.append(query)
        assert image_idx == len(num_patches_list)

        model_inputs = processing_class(
            queries,
            return_tensors=return_tensors,
            padding=padding,
            padding_side=padding_side,
            add_special_tokens=add_special_tokens,
        )
        model_inputs["pixel_values"] = full_pixel_values
        # Only pure-image data is supported currently (every sample must contain an image)
        model_inputs['image_flags'] = torch.ones(full_pixel_values.shape[0], dtype=torch.long)
        model_inputs = BatchFeature(data=model_inputs)
        return model_inputs

    def _load_image(self, image: Image.Image, input_size: int = 448, max_num: int = 12):
        transform = build_transform(input_size=input_size)
        images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
        pixel_values = [transform(image) for image in images]
        pixel_values = torch.stack(pixel_values)
        return pixel_values

    @staticmethod
    def get_question_template(task_type: str):
        match task_type:
            case _:
                return "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags."

    @staticmethod
    def format_reward_rec(completions, **kwargs):
        """Check if the InternVL model output matches the <think>/<answer> format with a bounding box."""
        import re
        pattern = r"<think>.*?</think>\s*<answer>.*?\[\d+,\s*\d+,\s*\d+,\s*\d+\].*?</answer>"
        completion_contents = [completion[0]["content"] for completion in completions]
        matches = [re.search(pattern, content, re.DOTALL) is not None for content in completion_contents]
        return [1.0 if match else 0.0 for match in matches]

    @staticmethod
    def iou_reward(completions, solution, **kwargs):
        """Calculate a soft IoU reward between the bounding box predicted by the
        InternVL model and the ground-truth bounding box."""
        import re
        import os
        from datetime import datetime

        def iou(box1, box2):
            inter_x1 = max(box1[0], box2[0])
            inter_y1 = max(box1[1], box2[1])
            inter_x2 = min(box1[2] - 1, box2[2] - 1)
            inter_y2 = min(box1[3] - 1, box2[3] - 1)
            if inter_x1 < inter_x2 and inter_y1 < inter_y2:
                inter = (inter_x2 - inter_x1 + 1) * (inter_y2 - inter_y1 + 1)
            else:
                inter = 0
            union = (box1[2] - box1[0]) * (box1[3] - box1[1]) + (box2[2] - box2[0]) * (box2[3] - box2[1]) - inter
            return float(inter) / union

        contents = [completion[0]["content"] for completion in completions]
        rewards = []
        current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
        answer_tag_pattern = r'<answer>(.*?)</answer>'
        bbox_pattern = r'\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]'
        for content, sol in zip(contents, solution):
            reward = 0.0
            # Try symbolic verification: extract the <answer> block, then the bbox inside it
            try:
                content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
                if content_answer_match:
                    content_answer = content_answer_match.group(1).strip()
                    bbox_match = re.search(bbox_pattern, content_answer)
                    if bbox_match:
                        bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
                        reward = iou(bbox, sol)
            except Exception:
                pass  # Leave the reward at 0.0 if verification fails
            rewards.append(reward)
            if os.getenv("DEBUG_MODE") == "true":
                log_path = os.getenv("LOG_PATH")
                with open(log_path, "a", encoding='utf-8') as f:
                    f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                    f.write(f"Content: {content}\n")
                    f.write(f"Solution: {sol}\n")
        return rewards
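# A minimal sketch of how the two reward functions above behave, assuming the
# <think>/<answer> output format produced by `get_question_template`. The
# completion text, box coordinates, and helper name below are made up for
# illustration and are not part of the training pipeline.
def _demo_rewards():  # hypothetical helper, not used in training
    completions = [[{"content": "<think>The cup is on the left.</think>"
                                "<answer>[10, 20, 110, 220]</answer>"}]]
    solution = [[12, 18, 100, 210]]
    fmt = InternVLModule.format_reward_rec(completions)     # -> [1.0], format matches
    iou = InternVLModule.iou_reward(completions, solution)  # -> [~0.83], soft IoU in (0, 1]
    print(fmt, iou)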
def process_conversation_list(conversation_list, system_message=None, image_newline=True):
    # The system turn, if present, is handled separately via the conversation template
    if system_message is not None:
        conversation_list = conversation_list[1:]
    processed_list = []
    for item in conversation_list:
        role = item["role"]
        content = item["content"]
        if isinstance(content, list):
            overall_str = ""
            for content_item in content:
                if content_item.get("type") == "image":
                    overall_str += "<image>\n" if image_newline else "<image>"
                elif content_item.get("type") == "text":
                    overall_str += content_item.get("text")
                else:
                    raise ValueError(f"Unsupported content type: {content_item.get('type')}")
            processed_list.append(overall_str)
        elif isinstance(content, str):
            processed_list.append(content)
        else:
            raise ValueError(f"Unsupported content type: {type(content)}")
    return processed_list


def extract_system_message(conversation_list):
    if conversation_list[0]["role"] == "system":
        if isinstance(conversation_list[0]["content"], list):
            return conversation_list[0]["content"][0]["text"]
        else:
            return conversation_list[0]["content"]
    return None


def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # Break ties in favor of grids with more tiles when the image is large enough
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # enumerate all candidate tile grids (i columns x j rows) within the patch budget
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the grid whose aspect ratio is closest to the image's
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image to the grid and crop it into image_size x image_size tiles
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    # optionally append a full-image thumbnail as a global view
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images
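
# A minimal sketch of the tiling pipeline above, run on a synthetic image.
# The 448px tile size matches common InternVL defaults, but the image size and
# expected tile count here are illustrative assumptions, not repo test values.
if __name__ == "__main__":
    demo = Image.new("RGB", (896, 448), color=(120, 60, 30))  # synthetic 896x448 image
    tiles = dynamic_preprocess(demo, image_size=448, use_thumbnail=True, max_num=12)
    # An 896x448 input maps exactly onto a 2x1 tile grid, plus one thumbnail tile.
    print(len(tiles))  # -> 3
    transform = build_transform(input_size=448)
    pixel_values = torch.stack([transform(t) for t in tiles])
    print(pixel_values.shape)  # -> torch.Size([3, 3, 448, 448])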