import torch
import torch.nn.functional as F
from transformers import Blip2Processor, Blip2Model
from PIL import Image

class BLIP2Extractor:
    """Extract BLIP-2 image/text features and score image-text pairs.

    Wraps a Hugging Face ``Blip2Model``; only its feature pathways are used,
    never text generation.

    ``Blip2Model`` exposes no contrastive (ITC) projection heads, so there is
    no trained joint image-text embedding space here. ``compute_similarity``
    therefore compares an image and a text heuristically, inside the language
    model's last-layer hidden-state space. For calibrated retrieval scores,
    use an ITM/ITC checkpoint instead (e.g. ``Blip2ForImageTextRetrieval``).
    """

    def __init__(self, model_name="Salesforce/blip2-opt-2.7b", device="cuda", torch_dtype=None):
        """Load the BLIP-2 processor and model.

        Args:
            model_name: Hugging Face model id or local path of a BLIP-2 checkpoint.
            device: Device the model is moved to (e.g. "cuda", "cuda:1", "cpu").
            torch_dtype: Dtype the weights are loaded in. Defaults to float16 on
                CUDA devices and float32 elsewhere, since half precision may be
                unsupported or very slow for some ops on CPU.
        """
        self.device = device
        if torch_dtype is None:
            torch_dtype = torch.float16 if str(device).startswith("cuda") else torch.float32
        self.dtype = torch_dtype
        self.processor = Blip2Processor.from_pretrained(model_name)
        # Blip2Model rather than Blip2ForConditionalGeneration: only features are needed.
        self.model = Blip2Model.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
        ).to(device)
        self.model.eval()

    def extract_features(self, image_path=None, text=None):
        """Extract features for an image or for text.

        If both arguments are given, the image takes precedence.

        Args:
            image_path: Path to an image file, or an already-loaded ``PIL.Image``.
            text: A string or a list of strings.

        Returns:
            torch.Tensor: L2-normalized features, one row per input.
              - image: pooled [CLS] output of the vision encoder, in the model dtype.
              - text: attention-mask-weighted mean of the language model's
                last hidden state, in float32.
            The two come from different encoders and typically have different
            widths, so they are not directly comparable. Use
            ``compute_similarity`` for image-text scores.

        Raises:
            ValueError: if neither ``image_path`` nor ``text`` is provided.
        """
        if image_path is not None:
            pixel_values = self._pixel_values(image_path)
            with torch.no_grad():
                outputs = self.model.vision_model(pixel_values=pixel_values)
                # pooler_output: the vision encoder's [CLS] token representation.
                return F.normalize(outputs.pooler_output, dim=-1)

        if text is not None:
            return self._text_features(text)

        raise ValueError("必须提供图像路径或文本")

    def compute_similarity(self, image_path, text):
        """Cosine similarity between an image and one or more texts.

        The image is encoded by the vision encoder, then the Q-Former, then
        ``language_projection``, and is finally fed to the language model as
        input embeddings. The text is fed to the same language model. Both
        sides are mask-mean-pooled from the same last-layer hidden states, so
        they share a space and a width. This is a heuristic score, not a
        trained contrastive one (see the class docstring).

        Args:
            image_path: Path to an image file, or an already-loaded ``PIL.Image``.
            text: A string, or a list of strings.

        Returns:
            float for a single text; a list of floats for a list of texts.
        """
        image_features = self._image_lm_features(image_path)  # (1, hidden)
        text_features = self._text_features(text)  # (num_texts, hidden)
        scores = (image_features @ text_features.T).squeeze(0)
        return scores.item() if scores.numel() == 1 else scores.tolist()

    # ------------------------------------------------------------------ helpers

    def _pixel_values(self, image):
        """Load an image (path or PIL image) and turn it into model-ready pixel values."""
        if isinstance(image, Image.Image):
            rgb = image.convert("RGB")
        else:
            # The context manager closes the file handle; convert() returns a
            # separate, fully loaded image.
            with Image.open(image) as img:
                rgb = img.convert("RGB")
        inputs = self.processor(images=rgb, return_tensors="pt")
        return inputs["pixel_values"].to(self.device, dtype=self.dtype)

    def _lm_last_hidden(self, attention_mask, input_ids=None, inputs_embeds=None):
        """Return the language model's last-layer hidden states (batch, seq_len, hidden_size)."""
        lm = self.model.language_model
        if getattr(self.model.config, "use_decoder_only_language_model", True):
            # A causal LM (e.g. OPT) returns logits, not last_hidden_state, so
            # request all hidden states and take the final one.
            outputs = lm(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_hidden_states=True,
                use_cache=False,
                return_dict=True,
            )
            return outputs.hidden_states[-1]
        # Encoder-decoder LM (e.g. Flan-T5): run only its encoder.
        encoder = lm.get_encoder()
        outputs = encoder(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
        )
        return outputs.last_hidden_state

    @staticmethod
    def _masked_mean(hidden, attention_mask):
        """Mean over the sequence axis, ignoring padded positions, computed in float32."""
        # float32 avoids half-precision overflow/precision loss when summing.
        mask = attention_mask.unsqueeze(-1).to(torch.float32)
        summed = (hidden.to(torch.float32) * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def _text_features(self, text):
        """Normalized, mask-mean-pooled LM last-hidden-state features for text."""
        # Call the tokenizer directly: some Blip2Processor versions prepend
        # image placeholder tokens to text even when no image is passed.
        # Integer ids must not be cast to a floating dtype, so move them by device only.
        encoding = self.processor.tokenizer(text, return_tensors="pt", padding=True).to(self.device)
        attention_mask = encoding["attention_mask"]
        with torch.no_grad():
            hidden = self._lm_last_hidden(attention_mask, input_ids=encoding["input_ids"])
        return F.normalize(self._masked_mean(hidden, attention_mask), dim=-1)

    def _image_lm_features(self, image):
        """Normalized image features in the same LM hidden space as ``_text_features``."""
        pixel_values = self._pixel_values(image)
        with torch.no_grad():
            # Q-Former query outputs, then projection into the LM's input-embedding space.
            query_output = self.model.get_qformer_features(pixel_values=pixel_values)[0]
            lm_inputs = self.model.language_projection(query_output)
            # Every projected query token is a real position: the mask is all ones.
            attention_mask = torch.ones(
                lm_inputs.shape[:-1], dtype=torch.long, device=lm_inputs.device
            )
            hidden = self._lm_last_hidden(attention_mask, inputs_embeds=lm_inputs)
        return F.normalize(self._masked_mean(hidden, attention_mask), dim=-1)
