添加 SigLIP 接口脚本
This commit is contained in:
parent
d456bf0410
commit
0f194c620f
111
siglip_interface.py
Normal file
111
siglip_interface.py
Normal file
@ -0,0 +1,111 @@
|
||||
from PIL import Image
|
||||
from transformers import AutoProcessor, AutoModel, SiglipTextConfig
|
||||
import torch
|
||||
from typing import List, Union, Optional
|
||||
import os
|
||||
|
||||
class SigLIPInterface:
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "google/siglip-base-patch16-224",
|
||||
max_position_embeddings: int = 64,
|
||||
device: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
初始化 SigLIP 接口
|
||||
|
||||
Args:
|
||||
model_name: 模型名称或路径
|
||||
max_position_embeddings: 最大位置编码长度,默认为64
|
||||
device: 设备名称 (cuda/cpu),如果为None则自动选择
|
||||
"""
|
||||
self.device = device if device else "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# 加载模型配置
|
||||
text_config = SiglipTextConfig.from_pretrained(model_name)
|
||||
text_config.max_position_embeddings = max_position_embeddings
|
||||
|
||||
# 加载模型和处理器
|
||||
self.model = AutoModel.from_pretrained(
|
||||
model_name,
|
||||
text_config=text_config,
|
||||
ignore_mismatched_sizes=True
|
||||
).to(self.device)
|
||||
|
||||
self.processor = AutoProcessor.from_pretrained(model_name)
|
||||
|
||||
def compute_similarity(
|
||||
self,
|
||||
image: Union[str, Image.Image],
|
||||
texts: List[str],
|
||||
return_probs: bool = True
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
计算图像和文本的相似度
|
||||
|
||||
Args:
|
||||
image: 图像路径或 PIL Image 对象
|
||||
texts: 文本列表
|
||||
return_probs: 是否返回概率值(True)或原始logits(False)
|
||||
|
||||
Returns:
|
||||
相似度分数张量
|
||||
"""
|
||||
# 处理输入图像
|
||||
if isinstance(image, str):
|
||||
if not os.path.exists(image):
|
||||
raise FileNotFoundError(f"Image file not found: {image}")
|
||||
image = Image.open(image).convert("RGB")
|
||||
|
||||
# 处理输入
|
||||
inputs = self.processor(
|
||||
text=texts,
|
||||
images=image,
|
||||
padding="max_length",
|
||||
return_tensors="pt"
|
||||
).to(self.device)
|
||||
|
||||
# 计算相似度
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
logits = outputs.logits_per_image
|
||||
|
||||
if return_probs:
|
||||
return torch.sigmoid(logits)
|
||||
return logits
|
||||
|
||||
def get_similarity_score(
|
||||
self,
|
||||
image: Union[str, Image.Image],
|
||||
text: str
|
||||
) -> float:
|
||||
"""
|
||||
获取单个图像和文本的相似度分数
|
||||
|
||||
Args:
|
||||
image: 图像路径或 PIL Image 对象
|
||||
text: 文本字符串
|
||||
|
||||
Returns:
|
||||
相似度分数(0-1之间)
|
||||
"""
|
||||
scores = self.compute_similarity(image, [text])
|
||||
return scores[0][0].item()
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
# 创建接口实例
|
||||
siglip = SigLIPInterface()
|
||||
|
||||
# 示例用法
|
||||
image_path = "image.png"
|
||||
text = "这里有一个非常精美的盒子。"
|
||||
|
||||
# 计算相似度
|
||||
score = siglip.get_similarity_score(image_path, text)
|
||||
print(f"相似度分数: {score:.2%}")
|
||||
|
||||
# 批量计算
|
||||
texts = ["这里有一个非常精美的盒子。", "这是一段示例文本", "这里有一个非常精美的商品。"]
|
||||
scores = siglip.compute_similarity(image_path, texts)
|
||||
print(f"批量相似度分数: {scores}")
|
Loading…
Reference in New Issue
Block a user