Fortrain/siglip_interface.py
2025-04-23 06:30:03 +00:00

111 lines
3.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from PIL import Image
from transformers import AutoProcessor, AutoModel, SiglipTextConfig
import torch
from typing import List, Union, Optional
import os
class SigLIPInterface:
    """Wrapper around a SigLIP checkpoint for image-text similarity scoring.

    The model and its processor are loaded once in the constructor; the
    methods score one image against one or several texts.
    """

    def __init__(
        self,
        model_name: str = "google/siglip-base-patch16-224",
        max_position_embeddings: int = 64,
        device: Optional[str] = None
    ):
        """
        Initialise the SigLIP interface.

        Args:
            model_name: Model name or path.
            max_position_embeddings: Maximum text position-embedding length
                (defaults to 64). Also used as the tokenizer's padding and
                truncation length.
            device: Device name (cuda/cpu); selected automatically when None.
        """
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Kept so that tokenisation in compute_similarity pads/truncates to the
        # same length the text tower was configured with.
        self.max_position_embeddings = max_position_embeddings

        # Load the text configuration and override its sequence length.
        text_config = SiglipTextConfig.from_pretrained(model_name)
        text_config.max_position_embeddings = max_position_embeddings

        # Load the model and the processor. ignore_mismatched_sizes allows the
        # load to proceed when the overridden length no longer matches the
        # shapes stored in the checkpoint.
        self.model = AutoModel.from_pretrained(
            model_name,
            text_config=text_config,
            ignore_mismatched_sizes=True
        ).to(self.device)
        self.processor = AutoProcessor.from_pretrained(model_name)

    def compute_similarity(
        self,
        image: Union[str, Image.Image],
        texts: List[str],
        return_probs: bool = True
    ) -> torch.Tensor:
        """
        Compute the similarity between one image and a list of texts.

        Args:
            image: Image path or PIL Image object.
            texts: List of texts; must not be empty.
            return_probs: Return sigmoid probabilities (True) or the raw
                logits (False).

        Returns:
            Tensor of similarity scores taken from ``logits_per_image``.

        Raises:
            ValueError: If ``texts`` is empty.
            FileNotFoundError: If ``image`` is a path that does not exist.
        """
        # Fail early with a clear message instead of an opaque tokenizer error.
        if not texts:
            raise ValueError("texts must contain at least one string")

        # Normalise the image input.
        if isinstance(image, str):
            if not os.path.exists(image):
                raise FileNotFoundError(f"Image file not found: {image}")
            # The context manager releases the file handle; convert() returns
            # a separate image that remains usable afterwards.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        elif isinstance(image, Image.Image) and image.mode != "RGB":
            # Treat PIL inputs the same way as path inputs.
            image = image.convert("RGB")

        # Tokenise and preprocess. Truncating to the configured length keeps
        # over-long texts from exceeding the text tower's position table.
        inputs = self.processor(
            text=texts,
            images=image,
            padding="max_length",
            truncation=True,
            max_length=self.max_position_embeddings,
            return_tensors="pt"
        ).to(self.device)

        # Forward pass without gradient tracking.
        with torch.no_grad():
            outputs = self.model(**inputs)

        logits = outputs.logits_per_image
        return torch.sigmoid(logits) if return_probs else logits

    def get_similarity_score(
        self,
        image: Union[str, Image.Image],
        text: str
    ) -> float:
        """
        Get the similarity score for a single image/text pair.

        Args:
            image: Image path or PIL Image object.
            text: Text string.

        Returns:
            Similarity score as a Python float (a sigmoid output, so between
            0 and 1).
        """
        scores = self.compute_similarity(image, [text])
        return scores[0][0].item()
# Usage example
if __name__ == "__main__":
    # Build the interface with its default checkpoint and device selection.
    interface = SigLIPInterface()

    # Sample inputs: one image on disk and the captions to compare against it.
    sample_image = "image.png"
    captions = [
        "这里有一个非常精美的盒子。",
        "这是一段示例文本",
        "这里有一个非常精美的商品。",
    ]

    # Single pair: score the image against the first caption only.
    score = interface.get_similarity_score(sample_image, captions[0])
    print(f"相似度分数: {score:.2%}")

    # Batch: score the same image against every caption at once.
    scores = interface.compute_similarity(sample_image, captions)
    print(f"批量相似度分数: {scores}")