from PIL import Image from transformers import AutoProcessor, AutoModel, SiglipTextConfig import torch from typing import List, Union, Optional import os class SigLIPInterface: def __init__( self, model_name: str = "google/siglip-base-patch16-224", max_position_embeddings: int = 64, device: Optional[str] = None ): """ 初始化 SigLIP 接口 Args: model_name: 模型名称或路径 max_position_embeddings: 最大位置编码长度,默认为64 device: 设备名称 (cuda/cpu),如果为None则自动选择 """ self.device = device if device else "cuda" if torch.cuda.is_available() else "cpu" # 加载模型配置 text_config = SiglipTextConfig.from_pretrained(model_name) text_config.max_position_embeddings = max_position_embeddings # 加载模型和处理器 self.model = AutoModel.from_pretrained( model_name, text_config=text_config, ignore_mismatched_sizes=True ).to(self.device) self.processor = AutoProcessor.from_pretrained(model_name) def compute_similarity( self, image: Union[str, Image.Image], texts: List[str], return_probs: bool = True ) -> torch.Tensor: """ 计算图像和文本的相似度 Args: image: 图像路径或 PIL Image 对象 texts: 文本列表 return_probs: 是否返回概率值(True)或原始logits(False) Returns: 相似度分数张量 """ # 处理输入图像 if isinstance(image, str): if not os.path.exists(image): raise FileNotFoundError(f"Image file not found: {image}") image = Image.open(image).convert("RGB") # 处理输入 inputs = self.processor( text=texts, images=image, padding="max_length", return_tensors="pt" ).to(self.device) # 计算相似度 with torch.no_grad(): outputs = self.model(**inputs) logits = outputs.logits_per_image if return_probs: return torch.sigmoid(logits) return logits def get_similarity_score( self, image: Union[str, Image.Image], text: str ) -> float: """ 获取单个图像和文本的相似度分数 Args: image: 图像路径或 PIL Image 对象 text: 文本字符串 Returns: 相似度分数(0-1之间) """ scores = self.compute_similarity(image, [text]) return scores[0][0].item() # 使用示例 if __name__ == "__main__": # 创建接口实例 siglip = SigLIPInterface() # 示例用法 image_path = "image.png" text = "这里有一个非常精美的盒子。" # 计算相似度 score = siglip.get_similarity_score(image_path, text) print(f"相似度分数: {score:.2%}") # 批量计算 texts = ["这里有一个非常精美的盒子。", "这是一段示例文本", "这里有一个非常精美的商品。"] scores = siglip.compute_similarity(image_path, texts) print(f"批量相似度分数: {scores}")