111 lines
3.4 KiB
Python
111 lines
3.4 KiB
Python
import os
from typing import List, Union, Optional

import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel, SiglipTextConfig
class SigLIPInterface:
    """Convenience wrapper that scores image/text pairs with a SigLIP model.

    The constructor loads the model and its processor once; the scoring
    methods then reuse them for every call.
    """

    def __init__(
        self,
        model_name: str = "google/siglip-base-patch16-224",
        max_position_embeddings: int = 64,
        device: Optional[str] = None,
    ):
        """Load the SigLIP model and processor.

        Args:
            model_name: Model name or local path handed to ``from_pretrained``.
            max_position_embeddings: Maximum text sequence length written
                into the text config; also used as the tokenizer length
                limit in ``compute_similarity``. Defaults to 64.
            device: Device name (e.g. "cuda" / "cpu"). When falsy, CUDA is
                used if available, otherwise the CPU.
        """
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Remembered so tokenization can never emit more positions than the
        # text tower was configured with.
        self.max_position_embeddings = max_position_embeddings

        # Load the text config and override its sequence-length limit.
        text_config = SiglipTextConfig.from_pretrained(model_name)
        text_config.max_position_embeddings = max_position_embeddings

        # ignore_mismatched_sizes lets loading proceed when the overridden
        # config no longer matches the checkpoint's tensor shapes.
        # NOTE(review): per the Transformers docs, mismatched weights are
        # newly initialised rather than loaded - confirm this is acceptable
        # when max_position_embeddings differs from the checkpoint value.
        self.model = AutoModel.from_pretrained(
            model_name,
            text_config=text_config,
            ignore_mismatched_sizes=True,
        ).to(self.device)

        self.processor = AutoProcessor.from_pretrained(model_name)

    def compute_similarity(
        self,
        image: Union[str, "Image.Image"],
        texts: List[str],
        return_probs: bool = True,
    ) -> torch.Tensor:
        """Compute the similarity between one image and a list of texts.

        Args:
            image: Path to an image file, or an already loaded PIL image.
            texts: Texts to compare against the image; must not be empty.
            return_probs: If True, return ``sigmoid(logits)``; if False,
                return the raw ``logits_per_image``.

        Returns:
            The model's ``logits_per_image`` tensor (passed through a
            sigmoid when ``return_probs`` is True). Presumably one row for
            the image and one column per text - TODO confirm.

        Raises:
            ValueError: If ``texts`` is empty.
            FileNotFoundError: If ``image`` is a path that does not exist.
        """
        if not texts:
            raise ValueError("texts must contain at least one string")

        # A string is treated as a file path and loaded as an RGB image.
        if isinstance(image, str):
            if not os.path.exists(image):
                raise FileNotFoundError(f"Image file not found: {image}")
            # convert() returns an independent, fully loaded copy, so the
            # underlying file handle can be released right away.
            with Image.open(image) as opened:
                image = opened.convert("RGB")

        # Pad every text to the same fixed length and truncate anything
        # longer, so the sequence never exceeds the text position table.
        inputs = self.processor(
            text=texts,
            images=image,
            padding="max_length",
            truncation=True,
            max_length=self.max_position_embeddings,
            return_tensors="pt",
        ).to(self.device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits_per_image

        if return_probs:
            return torch.sigmoid(logits)
        return logits

    def get_similarity_score(
        self,
        image: Union[str, "Image.Image"],
        text: str,
    ) -> float:
        """Return the similarity score for a single image/text pair.

        Args:
            image: Path to an image file, or an already loaded PIL image.
            text: The text to compare against the image.

        Returns:
            The sigmoid-activated score as a Python float (between 0 and 1).
        """
        scores = self.compute_similarity(image, [text])
        return scores[0][0].item()
def main():
    """Demo: score one caption, then a batch of captions, against an image."""
    # Build the interface with its default model and device selection.
    interface = SigLIPInterface()

    # Sample inputs.
    sample_image = "image.png"
    caption = "这里有一个非常精美的盒子。"

    # Single image/text similarity.
    score = interface.get_similarity_score(sample_image, caption)
    print(f"相似度分数: {score:.2%}")

    # Batch similarity: one image against several captions.
    captions = [
        "这里有一个非常精美的盒子。",
        "这是一段示例文本",
        "这里有一个非常精美的商品。",
    ]
    scores = interface.compute_similarity(sample_image, captions)
    print(f"批量相似度分数: {scores}")


# Usage example
if __name__ == "__main__":
    main()