"""Thin wrapper around a Hugging Face SigLIP checkpoint for image-text similarity."""

import os
from typing import List, Union, Optional

import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel, SiglipTextConfig


class SigLIPInterface:
    """Load a SigLIP model and score how well texts match an image."""

    def __init__(
        self,
        model_name: str = "google/siglip-base-patch16-224",
        max_position_embeddings: int = 64,
        device: Optional[str] = None
    ):
        """
        Initialise the SigLIP interface.

        Args:
            model_name: Model name or local path passed to ``from_pretrained``.
            max_position_embeddings: Length of the text position-embedding
                table forced onto the text config (default 64). It is also
                used as the tokenizer's padding/truncation length.
                NOTE(review): with ``ignore_mismatched_sizes=True`` a value
                that differs from the checkpoint presumably leaves the
                position embeddings freshly initialised -- confirm before
                relying on scores in that case.
            device: Device name (e.g. "cuda" / "cpu"). When None (or empty),
                CUDA is used if available, otherwise CPU.
        """
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Remembered so the tokenizer pads/truncates to the same length the
        # model's position-embedding table was configured with.
        self.max_position_embeddings = max_position_embeddings

        # Load the text config and override its maximum sequence length.
        text_config = SiglipTextConfig.from_pretrained(model_name)
        text_config.max_position_embeddings = max_position_embeddings

        # Load the model (with the overridden text config) and the processor.
        self.model = AutoModel.from_pretrained(
            model_name,
            text_config=text_config,
            ignore_mismatched_sizes=True
        ).to(self.device)

        self.processor = AutoProcessor.from_pretrained(model_name)

    @staticmethod
    def _load_image(image):
        """Return *image* as an RGB PIL image; non-path, non-PIL inputs pass through."""
        if isinstance(image, (str, os.PathLike)):
            if not os.path.exists(image):
                raise FileNotFoundError(f"Image file not found: {image}")
            # convert() returns a fully loaded copy, so the underlying file
            # can be closed as soon as the with-block exits.
            with Image.open(image) as opened:
                return opened.convert("RGB")
        if isinstance(image, Image.Image) and image.mode != "RGB":
            # Treat in-memory images the same way as images loaded from disk.
            return image.convert("RGB")
        return image

    def compute_similarity(
        self,
        image: Union[str, Image.Image],
        texts: List[str],
        return_probs: bool = True
    ) -> torch.Tensor:
        """
        Compute the similarity between one image and a list of texts.

        Args:
            image: Path to an image file, or a PIL Image object.
            texts: Texts to compare against the image; must not be empty.
            return_probs: If True, return ``sigmoid(logits)``; if False,
                return the raw ``logits_per_image``.

        Returns:
            Tensor of similarity scores taken from the model's
            ``logits_per_image`` output (indexed as ``[image][text]``).

        Raises:
            FileNotFoundError: If ``image`` is a path that does not exist.
            ValueError: If ``texts`` is empty.
        """
        if not texts:
            raise ValueError("texts must contain at least one string")

        image = self._load_image(image)

        # Pad AND truncate to the configured text length: without truncation
        # an over-long text produces more tokens than the model has position
        # embeddings for.
        inputs = self.processor(
            text=texts,
            images=image,
            padding="max_length",
            truncation=True,
            max_length=self.max_position_embeddings,
            return_tensors="pt"
        ).to(self.device)

        # Inference only; no gradients needed.
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits_per_image

        if return_probs:
            return torch.sigmoid(logits)
        return logits

    def get_similarity_score(
        self,
        image: Union[str, Image.Image],
        text: str
    ) -> float:
        """
        Return the similarity score for a single image/text pair.

        Args:
            image: Path to an image file, or a PIL Image object.
            text: Text to compare against the image.

        Returns:
            Similarity score as a float between 0 and 1 (sigmoid of the logit).
        """
        scores = self.compute_similarity(image, [text])
        return scores[0][0].item()


def main() -> None:
    """Usage example: score one text, then a batch of texts, against an image."""
    # Create the interface instance.
    siglip = SigLIPInterface()

    # Example inputs.
    image_path = "image.png"
    text = "这里有一个非常精美的盒子。"

    # Single image/text score.
    score = siglip.get_similarity_score(image_path, text)
    print(f"相似度分数: {score:.2%}")

    # Batch scoring.
    texts = ["这里有一个非常精美的盒子。", "这是一段示例文本", "这里有一个非常精美的商品。"]
    scores = siglip.compute_similarity(image_path, texts)
    print(f"批量相似度分数: {scores}")


if __name__ == "__main__":
    main()