import math
import os
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.init import trunc_normal_
from transformers.models.bert import BertConfig, BertModel

from .resampler import MultiheadfusionAttention as MultiheadAttention
from .ms_deform_attn import MultiScaleDeformableAttention as DeformAttn

class ContrastiveEmbed(nn.Module):
    """Text-visual contrastive embedding layer.

    Scores every visual token against every text token with a batched dot
    product, optionally rescales / biases the scores, sets padded text
    positions to ``-inf`` and pads the text axis up to ``max_text_len``.

    Args:
        max_text_len (int, optional): Length the text axis of the output is
            padded to (with ``-inf``). Must be at least the number of text
            tokens passed to ``forward``. Defaults to 256.
        log_scale (Optional[Union[str, float]]): How the similarity matrix
            is normalized. Defaults to ``None``.
          - If set to 'auto', the similarity matrix is divided by the fixed
            value ``sqrt(d_c)`` where ``d_c`` is the channel number of the
            visual features.
          - If set to 'none' or ``None``, no normalization is applied.
          - If set to a number (``int`` or ``float``), it becomes a
            learnable parameter and the similarity matrix is multiplied by
            ``exp(log_scale)``.
        bias (bool, optional): Whether to add a learnable scalar bias to the
            output. It is initialized to ``-log((1 - 0.01) / 0.01)`` (about
            -4.6), i.e. a sigmoid output of 0.01. Useful when training from
            scratch. Defaults to False.

    Raises:
        ValueError: If ``log_scale`` is not 'auto', 'none', ``None`` or a
            number.
    """

    def __init__(self,
                 max_text_len: int = 256,
                 log_scale=None,
                 bias: bool = False):
        super().__init__()
        self.max_text_len = max_text_len
        self.log_scale = log_scale
        # bool is a subclass of int but is not a meaningful scale value.
        is_number = (isinstance(log_scale, (int, float))
                     and not isinstance(log_scale, bool))
        if is_number:
            self.log_scale = nn.Parameter(
                torch.tensor([float(log_scale)]), requires_grad=True)
        elif log_scale not in ('auto', 'none', None):
            raise ValueError(f'log_scale should be a number or one of '
                             f'"auto", "none", None, but got {log_scale}')

        self.bias = None
        if bias:
            # Prior-probability init: sigmoid(bias_value) == 0.01.
            bias_value = -math.log((1 - 0.01) / 0.01)
            self.bias = nn.Parameter(
                torch.tensor([bias_value]), requires_grad=True)

    def forward(self, visual_feat: Tensor, text_feat: Tensor,
                text_token_mask: Tensor) -> Tensor:
        """Compute visual-to-text classification scores.

        Args:
            visual_feat (Tensor): Visual features, (bs, num_visual, C).
            text_feat (Tensor): Text features, (bs, num_text, C).
            text_token_mask (Tensor): (bs, num_text) mask that is True (or
                nonzero) for real text tokens. Boolean and integer (0/1)
                masks are both accepted.

        Returns:
            Tensor: Scores of shape (bs, num_visual, max_text_len), with the
            dtype and device of the similarity matrix. Masked and padded
            text positions hold ``-inf``.
        """
        similarity = visual_feat @ text_feat.transpose(-1, -2)
        if isinstance(self.log_scale, nn.Parameter):
            similarity = similarity * self.log_scale.exp()
        elif self.log_scale == 'auto':
            # NOTE: similar to the normalizer in self-attention
            similarity = similarity / math.sqrt(visual_feat.shape[-1])
        if self.bias is not None:
            similarity = similarity + self.bias

        # `~` on an integer mask is a bitwise-not (0 -> -1, 1 -> -2), so
        # normalize to bool first; tokenizers emit int64 attention masks.
        valid_text = text_token_mask.bool()
        similarity = similarity.masked_fill(~valid_text[:, None, :],
                                            float('-inf'))

        num_text = similarity.shape[-1]
        # new_full keeps the dtype/device of the scores (e.g. fp16 stays fp16).
        padded = similarity.new_full(
            (*similarity.shape[:-1], self.max_text_len), float('-inf'))
        padded[..., :num_text] = similarity
        return padded


class LVP(nn.Module):
    """Language-guided visual projector.

    Pipeline implemented by ``forward``:

    1. Encode ``text_token`` with a BERT encoder and project it to
       ``embed_dim`` (``k_proj_1``); project the visual input ``x`` to
       ``embed_dim`` (``v_proj_1``).
    2. Cross attention: visual tokens attend to the text
       (``image2text_attn``) and text tokens attend to the visual tokens
       (``text2image_attn``).
    3. Score each visual token against the text with ``ContrastiveEmbed``
       and keep the ``Nq`` highest-scoring visual tokens.
    4. Pass the selected tokens through ``deform_attn`` together with the
       un-projected visual input.

    Args:
        config: BERT configuration, either a ``PretrainedConfig`` instance or
            a path to a JSON config file. Defaults to ``"./tex.json"``.
        num_head (int): Number of heads of the cross-attention and deformable
            attention layers.
        embed_dim (int): Width of the projected text / visual features.
        Nq (int): Number of visual tokens kept by the language-guided
            selection.
        norm_layer: Factory for the ``ln_*_1`` normalization layers.
        hidden_size (int): Output width of ``mlp``.
        visual_dim (int): Channel width of the visual input ``x``.
            Defaults to 4096.
    """

    def __init__(self,
                 config="./tex.json",
                 num_head=8,
                 embed_dim=1024,
                 Nq=144,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 hidden_size=4096,
                 visual_dim=4096):
        # Must run before any submodule is assigned.
        super().__init__()

        # BertModel only accepts a PretrainedConfig instance; load paths first.
        if isinstance(config, (str, os.PathLike)):
            config = BertConfig.from_json_file(config)
        self.text_encoder = BertModel(config=config, add_pooling_layer=False)
        # The text projection must consume whatever width BERT produces.
        text_dim = self.text_encoder.config.hidden_size

        self.text2image_attn = MultiheadAttention(embed_dim, num_head)
        self.image2text_attn = MultiheadAttention(embed_dim, num_head)
        self.lvp = ContrastiveEmbed()
        self.num_queries = Nq

        # (Linear, GELU, Linear): same layout / state-dict keys as before.
        self.k_proj_1 = nn.Sequential(
            nn.Linear(text_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
        )
        self.v_proj_1 = nn.Sequential(
            nn.Linear(visual_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
        )

        # NOTE(review): ln_q_1 / ln_k_1 / ln_v_1 and mlp are created but never
        # used in forward (mlp presumably should map the output to the LLM
        # hidden size) — confirm intent before relying on them.
        self.ln_q_1 = norm_layer(embed_dim)
        self.ln_k_1 = norm_layer(embed_dim)
        self.ln_v_1 = norm_layer(embed_dim)

        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

        self.deform_attn = DeformAttn(num_heads=num_head, embed_dim=embed_dim,
                                      num_levels=1)

        # NOTE(review): apply() also visits the BERT encoder and the attention
        # submodules, overriding any initialization they perform themselves
        # (deformable-attention modules often rely on a special
        # sampling-offset init) — confirm this is intended.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Init Linear weights ~ trunc-normal(std=0.02), biases to 0, and
        LayerNorm to the identity (weight 1, bias 0)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # Affine parameters are None when elementwise_affine=False.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x, text_token, text_token_mask=None):
        """Select and refine the ``Nq`` visual tokens most relevant to the text.

        Args:
            x (Tensor): Visual tokens whose last dim is ``visual_dim``;
                assumed batch-first (bs, num_visual, visual_dim) — TODO
                confirm against MultiheadfusionAttention's layout.
            text_token (Tensor): BERT input ids, (bs, num_text).
            text_token_mask (Tensor, optional): (bs, num_text) mask, True or
                nonzero for real tokens. Forwarded to BERT as its attention
                mask and used to exclude padding from the similarity. If
                ``None`` (default), every text token is treated as valid.

        Returns:
            Whatever ``self.deform_attn`` returns for the selected tokens.
        """
        text_feature = self.text_encoder(
            input_ids=text_token,
            attention_mask=text_token_mask)["last_hidden_state"]
        text_feature = self.k_proj_1(text_feature)
        visual_feature = self.v_proj_1(x)

        # Cross attention. Both modules are indexed with [0], i.e. assumed to
        # return an (output, ...) tuple.
        # NOTE(review): text2image_attn is queried with the text features so
        # its output is text-aligned, matching ContrastiveEmbed's
        # (visual, text, text_mask) contract — confirm intent.
        visual_enhanced = self.image2text_attn(
            visual_feature, text_feature, text_feature)[0]
        text_enhanced = self.text2image_attn(
            text_feature, visual_feature, visual_feature)[0]

        if text_token_mask is None:
            valid_text = torch.ones(text_enhanced.shape[:2], dtype=torch.bool,
                                    device=text_enhanced.device)
        else:
            valid_text = text_token_mask.bool()

        # Language-guided visual token selection: score each visual token by
        # its best text match (padded text slots are -inf) and keep the top Nq.
        similarity = self.lvp(visual_enhanced, text_enhanced, valid_text)
        topk_indices = torch.topk(
            similarity.max(-1)[0], k=self.num_queries, dim=1)[1]

        # gather needs an index of the same rank as its source, so broadcast
        # the (bs, Nq) token indices over the channel dim.
        gather_index = topk_indices.unsqueeze(-1).expand(
            -1, -1, visual_enhanced.shape[-1])
        topk_token = torch.gather(visual_enhanced, 1, gather_index)

        # NOTE(review): deform_attn's forward signature is not visible here;
        # the (query, key, value) call and the un-projected `x` as key/value
        # are kept from the original — verify width and extra inputs
        # (reference points, spatial shapes) against ms_deform_attn.
        return self.deform_attn(topk_token, x, x)
        

