from transformers import CLIPModel
from typing import Optional
import torch 
from bayesvlm.hessians import KroneckerFactorizedCovariance

class CLIP(torch.nn.Module):
    """Probabilistic CLIP similarity head.

    Stores a (log-space) logit scale together with optional Kronecker-factorized
    covariances for the image ("source") and text ("target") projection layers,
    and turns embeddings plus the activations that feed those projections into
    uncertainty-adjusted logits.
    """

    # Whether the image / text projection layer carries a bias term. When True,
    # a constant 1 column is appended to the activations so that the covariance
    # (which then covers the bias as well) has matching dimensions.
    source_projection_has_bias = False
    target_projection_has_bias = False

    def __init__(
        self,
        logit_scale: float,
        logit_bias: float = 0,
        source_covariance: Optional[KroneckerFactorizedCovariance] = None,
        target_covariance: Optional[KroneckerFactorizedCovariance] = None,
        device: Optional[str] = None,
    ):
        """
        Args:
            logit_scale: log of the logit multiplier; ``exp`` is applied when
                logits are computed.
            logit_bias: initial value of the ``logit_bias`` parameter.
                NOTE(review): this parameter is not used by any computation in
                this class — confirm whether subclasses rely on it.
            source_covariance: optional covariance for the image projection,
                stored as given (not cloned or moved).
            target_covariance: optional covariance for the text projection,
                stored as given (not cloned or moved).
            device: device on which the scalar parameters are created.
        """
        super().__init__()
        self.logit_scale = torch.nn.Parameter(torch.ones([], device=device) * logit_scale)
        self.logit_bias = torch.nn.Parameter(torch.ones([], device=device) * logit_bias)
        self.source_covariance = source_covariance
        self.target_covariance = target_covariance

    @property
    def device(self):
        """Device of the module's parameters (taken from ``logit_scale``)."""
        return self.logit_scale.data.device

    def _clone_to_device(self, covariance):
        """Return a copy of *covariance* with both factors cloned onto ``self.device``, or None."""
        if covariance is None:
            return None
        return KroneckerFactorizedCovariance(
            A_inv=covariance.A_inv.clone().to(self.device),
            B_inv=covariance.B_inv.clone().to(self.device),
        )

    def set_covariances(
        self,
        source_covariance: Optional[KroneckerFactorizedCovariance] = None,
        target_covariance: Optional[KroneckerFactorizedCovariance] = None,
    ):
        """Replace both covariances with private copies placed on ``self.device``.

        Passing ``None`` for either side clears that covariance.
        """
        self.source_covariance = self._clone_to_device(source_covariance)
        self.target_covariance = self._clone_to_device(target_covariance)

    @classmethod
    def from_huggingface(
        cls,
        model_name: str,
        device: Optional[str] = None,
    ):
        """Create an instance whose logit scale is read from a Hugging Face ``CLIPModel``.

        Only ``logit_scale`` is copied from the checkpoint; no covariances are set.
        """
        clip = CLIPModel.from_pretrained(model_name)
        model = cls(
            logit_scale=clip.logit_scale.item(),
        )
        model = model.to(device) if device is not None else model
        return model

    @staticmethod
    def _diagonal_embedding_variance(activations, covariance, has_bias, embeds):
        """Per-dimension variance of the projected embeddings, shape like *embeds*.

        With Kronecker factors (A_inv, B_inv), the variance for row i and
        embedding dimension d is ``(a_i^T A_inv a_i) * B_inv[d, d]``. A missing
        covariance means a deterministic projection, i.e. zero variance.
        """
        if covariance is None:
            return torch.zeros_like(embeds)

        if has_bias:
            # Augment with a constant 1 so the bias column is covered by A_inv.
            activations = torch.cat([activations, torch.ones_like(activations[:, :1])], dim=-1)

        # Covariances are plain attributes (not buffers), so module.to(...) does
        # not move them; align devices here. No-op when already on that device.
        target_device = activations.device
        A_inv = covariance.A_inv.to(target_device)
        B_diag = covariance.B_inv.diagonal().to(target_device)

        quad_form = torch.einsum('ij,jk,ik->i', activations, A_inv, activations)
        return quad_form[:, None] * B_diag

    def _compute_probabilistic_logits_smith(
        self,
        image_embeds,
        image_activations,
        text_embeds,
        text_activations
    ):
        """Mean and variance of the scaled cosine similarity between probabilistic embeddings.

        The derivation adopts the approach by Smith et al. (2023). Embeddings are
        treated as Gaussians with the given means and diagonal variances derived
        from the source (image) / target (text) covariances; a missing covariance
        contributes zero variance.

        Args:
            image_embeds (torch.Tensor): [#images, embed_dim] embedding means.
            image_activations (torch.Tensor): [#images, in_dim] inputs to the image projection.
            text_embeds (torch.Tensor): [#texts, embed_dim] embedding means.
            text_activations (torch.Tensor): [#texts, in_dim] inputs to the text projection.

        Returns:
            (mean, var): both [#images, #texts], already multiplied by
            ``exp(logit_scale)`` and its square respectively.
        """
        source_embeds = image_embeds
        target_embeds = text_embeds

        source_diag_cov = self._diagonal_embedding_variance(
            image_activations, self.source_covariance, self.source_projection_has_bias, source_embeds,
        )
        target_diag_cov = self._diagonal_embedding_variance(
            text_activations, self.target_covariance, self.target_projection_has_bias, target_embeds,
        )

        # Per-dimension second moments E[x_d^2] = mu_d^2 + sigma_d^2, and their
        # sums E[||x||^2], used to normalise in place of the random norms.
        norm_source = source_embeds**2 + source_diag_cov
        expect_norm_source = norm_source.sum(dim=-1, keepdim=True)
        norm_target = target_embeds**2 + target_diag_cov
        expect_norm_target = norm_target.sum(dim=-1, keepdim=True)

        # Expected similarity (approximated by normalising with sqrt(E[||x||^2])).
        expected_similarity = torch.matmul(
            source_embeds / torch.sqrt(expect_norm_source),
            (target_embeds / torch.sqrt(expect_norm_target)).t(),
        )

        # Var(x^T y) for independent diagonal Gaussians:
        #   sum_d (mu_x^2 + s_x) s_y + s_x mu_y^2
        term1 = torch.matmul(norm_source, target_diag_cov.t())
        term2 = torch.matmul(source_diag_cov, (target_embeds**2).t())

        variance_similarity = (term1 + term2) / (expect_norm_source * expect_norm_target.t())

        scale = self.logit_scale.exp()
        mean = expected_similarity * scale
        var = variance_similarity * (scale ** 2)
        return mean, var

    def forward(
            self,
            images
        ):
        """Compute uncertainty-adjusted logits between images and stored text embeddings.

        Args:
            images (torch.Tensor): input batch passed to ``self.open_clip_model.visual``,
                which must return ``(embeds, activations)``.

        Returns:
            logits (torch.Tensor): [#images, #texts]

        NOTE(review): ``open_clip_model``, ``text_embeds`` and ``text_activations``
        are not defined in this class; presumably a subclass or caller sets them
        before ``forward`` is used — confirm.
        """
        image_embeds, image_activations = self.open_clip_model.visual(images)
        text_embeds, text_activations = self.text_embeds, self.text_activations

        mean, var = self._compute_probabilistic_logits_smith(
            image_embeds,
            image_activations,
            text_embeds,
            text_activations)

        # Probit approximation: shrink the mean logit by 1 / sqrt(1 + pi/8 * var).
        kappa = 1 / torch.sqrt(1. + torch.pi / 8 * var)
        logits = kappa * mean
        return logits