Fortrain/predata.py
2025-02-13 17:15:57 +08:00

202 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
import requests
from io import BytesIO
from urllib.parse import urlparse
# Per-channel RGB normalization statistics (the standard ImageNet values),
# consumed by T.Normalize in build_transform.
IMAGENET_MEAN = (
    0.485,
    0.456,
    0.406,
)
IMAGENET_STD = (
    0.229,
    0.224,
    0.225,
)
def build_transform(input_size):
    """Build the per-tile preprocessing pipeline.

    Steps, in order: force RGB mode, resize to a square ``input_size`` x
    ``input_size`` with bicubic interpolation, convert to a tensor, and
    normalize with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.

    Args:
        input_size: Side length in pixels of the square output.

    Returns:
        A ``T.Compose`` callable applying the steps above.
    """
    def _ensure_rgb(img):
        # Only convert when needed; RGB inputs pass through unchanged.
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Choose the tiling grid whose aspect ratio is closest to the image's.

    Candidates are scanned in the given order. A candidate with a strictly
    smaller ``|aspect_ratio - cols / rows|`` becomes the new best. On an exact
    tie, the later candidate replaces the current best only when the image's
    pixel area exceeds half the pixel area of that candidate's grid
    (``0.5 * image_size**2 * cols * rows``). When candidates are sorted by
    ascending tile count (as dynamic_preprocess does), this lets larger images
    move to a bigger grid on ties.

    Args:
        aspect_ratio: Width / height of the source image.
        target_ratios: Iterable of ``(cols, rows)`` candidates.
        width: Source image width in pixels.
        height: Source image height in pixels.
        image_size: Side length in pixels of one tile.

    Returns:
        The selected element of ``target_ratios``, or ``(1, 1)`` when
        ``target_ratios`` is empty or no candidate compares smaller than inf.
    """
    pixel_area = width * height
    best = (1, 1)
    best_diff = float('inf')
    for candidate in target_ratios:
        diff = abs(aspect_ratio - candidate[0] / candidate[1])
        if diff < best_diff:
            best, best_diff = candidate, diff
            continue
        if diff == best_diff and pixel_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
            best = candidate
    return best
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Cut an image into a grid of square tiles whose layout best fits its aspect ratio.

    Every ``(cols, rows)`` grid with ``min_num <= cols * rows <= max_num`` is a
    candidate. Candidates are ordered by tile count and passed to
    ``find_closest_aspect_ratio``. The image is then resized to
    ``cols * image_size`` by ``rows * image_size`` pixels and cropped into
    ``cols * rows`` tiles in row-major order.

    Args:
        image: Source image; ``.size``, ``.resize`` and ``.crop`` are used
            (a PIL image in this file).
        min_num: Minimum number of tiles in the grid.
        max_num: Maximum number of tiles in the grid.
        image_size: Side length in pixels of each square tile.
        use_thumbnail: When True and the grid has more than one tile, append
            the whole image resized to ``image_size`` x ``image_size`` last.

    Returns:
        List of tile images, plus the optional thumbnail at the end.
    """
    src_width, src_height = image.size
    src_aspect = src_width / src_height

    # All grids whose tile count falls in [min_num, max_num]; the set drops
    # the duplicates produced by the outer ``n`` loop.
    grids = set(
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if min_num <= cols * rows <= max_num
    )
    grids = sorted(grids, key=lambda g: g[0] * g[1])

    grid = find_closest_aspect_ratio(
        src_aspect, grids, src_width, src_height, image_size)

    out_width = image_size * grid[0]
    out_height = image_size * grid[1]
    tile_count = grid[0] * grid[1]
    tiles_per_row = out_width // image_size

    canvas = image.resize((out_width, out_height))
    tiles = []
    for idx in range(tile_count):
        # Row-major walk over the grid.
        row, col = divmod(idx, tiles_per_row)
        tiles.append(canvas.crop((
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )))
    assert len(tiles) == tile_count

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles
def load_image(image_file, input_size=448, max_num=12):
    """Load an image from one of several sources, tile it, and return a tensor batch.

    Args:
        image_file: One of:
            - a ``PIL.Image.Image``;
            - a string whose ``urlparse(...).netloc`` is non-empty, fetched
              with ``requests`` (10 s timeout);
            - a local file path (``str`` or ``os.PathLike``);
            - raw encoded image data (``bytes`` / ``bytearray``).
        input_size: Side length in pixels of each square tile.
        max_num: Maximum number of grid tiles passed to ``dynamic_preprocess``
            (the thumbnail is added on top when there is more than one tile).

    Returns:
        ``torch.Tensor`` obtained by stacking the transformed tiles along a new
        first dimension, one entry per tile, with the thumbnail last when
        present.

    Raises:
        ValueError: If the URL cannot be fetched or decoded (original error
            chained as ``__cause__``), or ``image_file`` has an unsupported type.
    """
    import os

    if isinstance(image_file, Image.Image):
        image = image_file.convert('RGB')
    elif isinstance(image_file, str) and bool(urlparse(image_file).netloc):
        try:
            response = requests.get(image_file, timeout=10)
            response.raise_for_status()
            with Image.open(BytesIO(response.content)) as remote_img:
                image = remote_img.convert('RGB')
        except Exception as e:
            # Keep the original exception as the cause for easier debugging.
            raise ValueError(f"无法从URL加载图片: {str(e)}") from e
    elif isinstance(image_file, (str, os.PathLike)):
        # The context manager guarantees the file handle is closed; convert()
        # has already loaded the pixel data into the returned image.
        with Image.open(image_file) as local_img:
            image = local_img.convert('RGB')
    elif isinstance(image_file, (bytes, bytearray)):
        image = Image.open(BytesIO(image_file)).convert('RGB')
    else:
        raise ValueError(f"不支持的图片格式: {type(image_file)}")

    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([transform(tile) for tile in tiles])
def _find_subsequence(haystack, needle):
    """Return the first index where list ``needle`` occurs contiguously in list ``haystack``.

    Returns ``None`` when there is no match or when ``needle`` is empty.
    """
    span = len(needle)
    if span == 0:
        return None
    # ``+ 1`` so that a match ending on the very last element is also checked.
    for start in range(len(haystack) - span + 1):
        if haystack[start:start + span] == needle:
            return start
    return None


def prepare_training_data(question: str, answer: str, tokenizer, image_path=None, input_size=448, max_num=12, num_image_token: int = 256):
    """Build one supervised training example (model inputs plus labels).

    The conversation is rendered as ``<|im_start|>system/user/assistant``
    turns and tokenized without extra special tokens. Labels are a copy of the
    token ids with every position before the assistant turn set to ``-100``.

    Args:
        question (str): The user's question. ``<image>`` marks where the image
            goes; it is prepended automatically when an image is supplied and
            the question does not already contain it.
        answer (str): The assistant's reply text.
        tokenizer: Callable as ``tokenizer(text, return_tensors="pt",
            add_special_tokens=False)`` (returning ``input_ids`` and
            ``attention_mask``) and ``tokenizer(text, add_special_tokens=False)``.
        image_path: Anything ``load_image`` accepts (path, URL, PIL Image,
            bytes), or ``None`` for a text-only example.
        input_size (int): Side length in pixels of each image tile.
        max_num (int): Maximum number of tiles passed to ``load_image``.
        num_image_token (int): Number of ``<IMG_CONTEXT>`` tokens per tile.

    Returns:
        dict with ``input_ids`` and ``attention_mask`` (from the tokenizer,
        batch dimension 1), ``labels`` (masked copy of ``input_ids``) and
        ``pixel_values`` (tensor, or ``None`` without an image).

    Warns:
        UserWarning: if the assistant header tokens cannot be found in
            ``input_ids``; ``labels`` are then left completely unmasked.
    """
    import warnings

    # 1. Load the image, keeping only the last entry. When the image is split
    #    into several tiles, load_image puts the global thumbnail last.
    pixel_values = None
    if image_path is not None:
        pixel_values = load_image(image_path, input_size=input_size, max_num=max_num)[-1:]
        if torch.cuda.is_available():
            pixel_values = pixel_values.to(torch.bfloat16)

    # 2. Make sure the question carries an image placeholder when there is an image.
    if pixel_values is not None and '<image>' not in question:
        question = '<image>\n' + question

    # 3. Number of image tiles actually kept (0 when there is no image).
    num_patches = pixel_values.shape[0] if pixel_values is not None else 0

    # 4. System prompt (runtime text; kept verbatim).
    system_msg = "你是书生·万象英文名是InternVL是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。"

    # 5. Expand the first <image> placeholder into num_image_token context
    #    tokens per kept tile.
    if num_patches > 0:
        image_tokens = "<img>" + "<IMG_CONTEXT>" * (num_image_token * num_patches) + "</img>"
        question = question.replace('<image>', image_tokens, 1)

    # 6. Full conversation, including the answer.
    full_prompt = (
        f"<|im_start|>system\n{system_msg}<|im_end|>\n"
        f"<|im_start|>user\n{question}<|im_end|>\n"
        f"<|im_start|>assistant\n{answer}<|im_end|>\n"
    )

    # 7. Tokenize.
    model_inputs = tokenizer(
        full_prompt,
        return_tensors="pt",
        add_special_tokens=False
    )

    # 8. Labels: mask everything before the assistant header with -100.
    #    NOTE(review): the header tokens themselves remain supervised, as in
    #    the original code; many SFT recipes also mask them — confirm intent.
    input_ids = model_inputs["input_ids"]
    labels = input_ids.clone()
    assistant_start_token = "<|im_start|>assistant\n"
    assistant_token_ids = list(tokenizer(assistant_start_token, add_special_tokens=False)["input_ids"])
    # Convert once to a Python list instead of slicing the tensor per window.
    assistant_start_pos = _find_subsequence(input_ids[0].tolist(), assistant_token_ids)
    if assistant_start_pos is None:
        warnings.warn(
            "prepare_training_data: assistant header tokens not found in input_ids; "
            "labels are left unmasked",
            stacklevel=2,
        )
    else:
        labels[0, :assistant_start_pos] = -100

    return {
        "input_ids": input_ids,
        "attention_mask": model_inputs["attention_mask"],
        "labels": labels,
        "pixel_values": pixel_values
    }
# Usage example (kept inside a string literal so it is not executed):
"""
import torch
from transformers import AutoTokenizer
# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained("path_to_tokenizer")
# Prepare sample data
question = "请描述这张图片中的内容"
answer = "这是一张美丽的风景照,画面中有青山绿水。"
image_path = "./examples/image1.jpg"  # or a URL, or a PIL Image object
# Build the training example
training_data = prepare_training_data(
question=question,
answer=answer,
tokenizer=tokenizer,
image_path=image_path,
input_size=448,
max_num=12
)
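# training_data contains:
# {
#     "input_ids": tensor([[...]]),         # token ids of the whole conversation, shape (1, seq_len)
#     "attention_mask": tensor([[...]]),    # attention mask, same shape as input_ids
#     "labels": tensor([[...]]),            # copy of input_ids with -100 before the assistant turn
#     "pixel_values": tensor([[[[...]]]])   # kept image tile, shape (1, 3, 448, 448); None without an image
# }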
"""