111 lines
3.4 KiB
Python
111 lines
3.4 KiB
Python
|
from PIL import Image
|
|||
|
from transformers import AutoProcessor, AutoModel, SiglipTextConfig
|
|||
|
import torch
|
|||
|
from typing import List, Union, Optional
|
|||
|
import os
|
|||
|
|
|||
|
class SigLIPInterface:
    """Thin wrapper around a Hugging Face SigLIP checkpoint for image-text scoring."""

    def __init__(
        self,
        model_name: str = "google/siglip-base-patch16-224",
        max_position_embeddings: int = 64,
        device: Optional[str] = None,
    ):
        """
        Initialize the SigLIP interface.

        Args:
            model_name: Model name on the Hugging Face Hub, or a local path.
            max_position_embeddings: Maximum text length (in tokens) of the text
                encoder; defaults to 64. Texts are padded/truncated to this length.
                NOTE: the text config is overridden with this value and the model
                is loaded with ``ignore_mismatched_sizes=True``, so if it differs
                from the checkpoint's own value the mismatched weights (e.g. the
                position embedding) are freshly initialized instead of loaded.
            device: Device name ("cuda"/"cpu"); chosen automatically when None.
        """
        # Prefer the GPU when one is available, otherwise fall back to the CPU.
        self.device = device if device else "cuda" if torch.cuda.is_available() else "cpu"

        # Load the text config and override its maximum sequence length.
        text_config = SiglipTextConfig.from_pretrained(model_name)
        text_config.max_position_embeddings = max_position_embeddings

        # Load the model with the modified text config, plus the matching processor.
        self.model = AutoModel.from_pretrained(
            model_name,
            text_config=text_config,
            ignore_mismatched_sizes=True,
        ).to(self.device)
        self.processor = AutoProcessor.from_pretrained(model_name)

    def compute_similarity(
        self,
        image: Union[str, Image.Image],
        texts: List[str],
        return_probs: bool = True,
    ) -> torch.Tensor:
        """
        Compute similarity scores between one image and a list of texts.

        Args:
            image: Path to an image file, or a PIL Image object.
            texts: Candidate texts to score against the image.
            return_probs: If True, return sigmoid probabilities; if False, return
                the raw logits.

        Returns:
            The model's ``logits_per_image`` (sigmoid-applied when
            ``return_probs`` is True): one row for the image, one column per text.

        Raises:
            FileNotFoundError: If ``image`` is a path that does not exist.
            ValueError: If ``texts`` is empty.
        """
        if isinstance(image, str):
            if not os.path.exists(image):
                raise FileNotFoundError(f"Image file not found: {image}")
            # Context manager closes the underlying file handle; convert()
            # returns a new, fully loaded image that outlives the file.
            with Image.open(image) as img:
                image = img.convert("RGB")
        elif isinstance(image, Image.Image) and image.mode != "RGB":
            # Treat in-memory PIL images the same way as path-loaded ones.
            image = image.convert("RGB")

        if not texts:
            raise ValueError("texts must contain at least one string")

        # SigLIP was trained with fixed-length ("max_length") padding. Pad AND
        # truncate to the text encoder's configured length so token ids never
        # exceed its position embeddings and always match a custom
        # max_position_embeddings chosen in __init__.
        max_length = self.model.config.text_config.max_position_embeddings
        inputs = self.processor(
            text=texts,
            images=image,
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
        logits = outputs.logits_per_image

        # Sigmoid (not softmax): each image-text pair gets an independent score.
        return torch.sigmoid(logits) if return_probs else logits

    def get_similarity_score(
        self,
        image: Union[str, Image.Image],
        text: str,
    ) -> float:
        """
        Return the similarity score for a single image-text pair.

        Args:
            image: Path to an image file, or a PIL Image object.
            text: The text to score against the image.

        Returns:
            Sigmoid probability in the range [0, 1].
        """
        scores = self.compute_similarity(image, [text])
        # Shape is (1 image, 1 text); pull out the single scalar.
        return scores[0][0].item()
|
|||
|
|
|||
|
# Usage example: run this module directly to score a local image.
if __name__ == "__main__":
    interface = SigLIPInterface()

    image_path = "image.png"
    # NOTE(review): the default checkpoint is, as far as I know, trained on
    # English captions; Chinese prompts may need a multilingual SigLIP
    # checkpoint to give meaningful scores -- verify before relying on them.
    candidate_texts = [
        "这里有一个非常精美的盒子。",
        "这是一段示例文本",
        "这里有一个非常精美的商品。",
    ]

    # Score a single image-text pair (the first candidate).
    score = interface.get_similarity_score(image_path, candidate_texts[0])
    print(f"相似度分数: {score:.2%}")

    # Score the image against every candidate text in one batch.
    scores = interface.compute_similarity(image_path, candidate_texts)
    print(f"批量相似度分数: {scores}")
|