添加 SigLIP 接口脚本

This commit is contained in:
taodj 2025-04-23 06:30:03 +00:00
parent d456bf0410
commit 0f194c620f

111
siglip_interface.py Normal file
View File

@ -0,0 +1,111 @@
from PIL import Image
from transformers import AutoProcessor, AutoModel, SiglipTextConfig
import torch
from typing import List, Union, Optional
import os
class SigLIPInterface:
    """Convenience wrapper that scores image/text pairs with a SigLIP model.

    The constructor loads the model and its processor once; the scoring
    methods then reuse them for every call.
    """

    def __init__(
        self,
        model_name: str = "google/siglip-base-patch16-224",
        max_position_embeddings: int = 64,
        device: Optional[str] = None
    ):
        """Load the SigLIP model and processor.

        Args:
            model_name: Model name or path handed to ``from_pretrained``.
            max_position_embeddings: Size of the text position-embedding
                table (defaults to 64). Texts are padded/truncated to the
                length the loaded model reports for this setting.
            device: Device name ("cuda"/"cpu"). When None or empty, "cuda"
                is chosen if available, otherwise "cpu".

        Raises:
            ValueError: If ``max_position_embeddings`` is not positive.
        """
        if max_position_embeddings <= 0:
            raise ValueError(
                "max_position_embeddings must be a positive integer, "
                f"got {max_position_embeddings}"
            )

        self.device = device if device else "cuda" if torch.cuda.is_available() else "cpu"

        # Load the text config and override the position-embedding length.
        text_config = SiglipTextConfig.from_pretrained(model_name)
        text_config.max_position_embeddings = max_position_embeddings

        # NOTE(review): with ignore_mismatched_sizes=True, checkpoint weights
        # whose shape no longer matches the overridden config are presumably
        # skipped and freshly initialised instead of loaded -- confirm before
        # relying on scores produced with a non-default length.
        self.model = AutoModel.from_pretrained(
            model_name,
            text_config=text_config,
            ignore_mismatched_sizes=True
        ).to(self.device)
        self.processor = AutoProcessor.from_pretrained(model_name)

    @staticmethod
    def _load_image(image: Union[str, "os.PathLike", "Image.Image"]):
        """Return *image* ready for the processor, opening paths as RGB."""
        if isinstance(image, (str, os.PathLike)):
            if not os.path.exists(image):
                raise FileNotFoundError(f"Image file not found: {image}")
            # convert() returns a fully loaded copy, so the underlying file
            # can be closed as soon as the with-block exits.
            with Image.open(image) as opened:
                return opened.convert("RGB")
        # Treat in-memory PIL images like images loaded from disk. Any other
        # object is passed through untouched for the processor to handle.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            return image.convert("RGB")
        return image

    def compute_similarity(
        self,
        image: Union[str, "Image.Image"],
        texts: List[str],
        return_probs: bool = True
    ) -> torch.Tensor:
        """Compute the similarity between one image and a list of texts.

        Args:
            image: Image path or PIL Image object.
            texts: Texts to compare against the image; must not be empty.
            return_probs: Return sigmoid probabilities (True) or the raw
                logits (False).

        Returns:
            The model's ``logits_per_image`` tensor, passed through a sigmoid
            when ``return_probs`` is True.

        Raises:
            ValueError: If ``texts`` is empty.
            FileNotFoundError: If ``image`` is a path that does not exist.
        """
        if not texts:
            raise ValueError("texts must contain at least one string")

        image = self._load_image(image)

        # Pad AND truncate to the length of the model's position-embedding
        # table. Relying on the tokenizer default lets the token count
        # disagree with the table when max_position_embeddings was changed,
        # and lets over-long texts run past it.
        text_length = self.model.config.text_config.max_position_embeddings
        inputs = self.processor(
            text=texts,
            images=image,
            padding="max_length",
            truncation=True,
            max_length=text_length,
            return_tensors="pt"
        ).to(self.device)

        # Inference only: no gradients needed.
        with torch.no_grad():
            outputs = self.model(**inputs)

        logits = outputs.logits_per_image
        return torch.sigmoid(logits) if return_probs else logits

    def get_similarity_score(
        self,
        image: Union[str, "Image.Image"],
        text: str
    ) -> float:
        """Return the similarity score for a single image/text pair.

        Args:
            image: Image path or PIL Image object.
            text: Text string.

        Returns:
            The sigmoid probability (between 0 and 1) as a Python float.
        """
        scores = self.compute_similarity(image, [text])
        return scores[0][0].item()
# Usage example
if __name__ == "__main__":
    # Demo inputs: one image on disk and a handful of candidate captions.
    demo_image = "image.png"
    captions = [
        "这里有一个非常精美的盒子。",
        "这是一段示例文本",
        "这里有一个非常精美的商品。",
    ]

    # Build the interface with its default arguments.
    interface = SigLIPInterface()

    # Single pair: score the first caption on its own and print it as a percentage.
    single_score = interface.get_similarity_score(demo_image, captions[0])
    print(f"相似度分数: {single_score:.2%}")

    # Batch: score every caption against the same image in one call.
    batch_scores = interface.compute_similarity(demo_image, captions)
    print(f"批量相似度分数: {batch_scores}")