import torch
import torch.nn.functional as F
from torch import nn
from timm.models.layers import DropPath
from einops import pack, repeat, unpack
from timm.models.layers import trunc_normal_

def visualize_cross_attention(attn_weights, text_tokens, text_inputs="",
                              save_path="/home/sojungan/GLEE/vis/test.png"):
    """Plot head-averaged vision->text attention maps and save the figure.

    Fast version using ``matplotlib.imshow`` instead of ``seaborn.heatmap``.
    Up to six samples are drawn on a 2x3 grid. Panels 1-3 are titled "Pos" and
    panels 4-6 "Neg". All panels share one colour scale, set from the 5th and
    95th percentiles of the head-averaged attention.

    Args:
        attn_weights: attention tensor laid out as
            ``(bsz * num_heads, vis_len, text_len)``, where ``bsz`` is
            ``len(text_tokens)``.
        text_tokens: per-sample list of token strings used as x-tick labels.
            Its length is taken as the batch size.
        text_inputs: optional per-sample raw strings appended to panel titles.
            Missing entries (including the default ``""``) give an empty suffix.
        save_path: destination file of the saved figure.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    bsz = len(text_tokens)
    num_heads = attn_weights.size(0) // bsz
    vis_len = attn_weights.size(1)
    text_len = attn_weights.size(2)

    # Average over heads -> (bsz, vis_len, text_len).
    # reshape() also accepts non-contiguous inputs.
    # float() makes reduced-precision dtypes (e.g. bf16) numpy-convertible.
    attn_avg = (
        attn_weights.detach()
        .reshape(bsz, num_heads, vis_len, text_len)
        .mean(dim=1)
        .float()
        .cpu()
        .numpy()
    )

    # Shared colour range for all panels, ignoring the extreme 5% on each side.
    vmin = np.percentile(attn_avg, 5)
    vmax = np.percentile(attn_avg, 95)

    fig, axs = plt.subplots(2, 3, figsize=(18, 8), constrained_layout=True)
    axs = axs.flatten()
    # Never index past the batch (the original always drew six panels).
    n_panels = min(len(axs), bsz)
    try:
        for i in range(n_panels):
            ax = axs[i]

            # X axis: text tokens, truncated to the attention width.
            t_len = min(text_len, len(text_tokens[i]))
            # Y axis: all vision tokens.
            ax.imshow(attn_avg[i][:, :t_len], aspect='auto', cmap='Blues', vmin=vmin, vmax=vmax)

            ax.set_xticks(list(range(t_len)))
            ax.set_xticklabels(text_tokens[i][:t_len], rotation=25, fontsize=16)
            ax.set_yticks([])

            suffix = text_inputs[i] if i < len(text_inputs) else ""
            ax.set_title(f"{'Pos' if i < 3 else 'Neg'} #{i % 3 + 1} {suffix}", fontsize=16)

        # Hide grid cells left empty when the batch has fewer than six samples.
        for ax in axs[n_panels:]:
            ax.axis("off")

        fig.savefig(save_path, dpi=150)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


class VLFuse(torch.nn.Module):
    """
    Early Fusion Module.

    Wraps a single BiAttentionBlockForCheckpoint that fuses visual and language
    tokens in both directions (text->image and image->text).
    """

    def __init__(self, register: bool=False):
        # ``register`` is accepted for interface compatibility but is currently unused.
        super(VLFuse, self).__init__()
        self.init_configs()

        # Early fusion module: bi-directional attention (text->image, image->text).
        self.b_attn = BiAttentionBlockForCheckpoint(
            v_dim=self.img_dim,        # 256
            l_dim=self.lang_dim,       # 256
            embed_dim=self.embed_dim,  # 2048
            num_heads=self.n_head,     # 8
            dropout=0.1,
            drop_path=.0,
            init_values=1.0 / 6,
        )

    def init_configs(self, ):
        """Set the hard-coded hyper-parameters used to build the fusion block."""
        # common params
        self.img_dim = 256

        # NOTE(review): max_query_len and n_layers are not read anywhere in this class.
        self.max_query_len = 256
        self.n_layers = 1

        # mha params
        self.n_head = 8
        self.embed_dim = 2048  # 2048 by default

        self.lang_dim = 256

    def forward(self, x, task=None, visual=False):
        """Fuse visual and language features.

        Args:
            x: dict with ``"visual"`` (visual token features) and ``"lang"``, a
                dict holding at least ``"hidden"`` (language token features)
                and ``"masks"`` (language attention mask).
            task: forwarded unchanged to the attention block.
            visual: debug flag forwarded to the attention block, enabling its
                cross-attention visualization path.

        Returns:
            dict with the fused ``"visual"`` features and the ``"lang"`` dict.
            The ``"hidden"`` entry of ``x["lang"]`` is overwritten in place with
            the fused language features.
        """
        visual_features = x["visual"]
        language_dict_features = x["lang"]

        # ``visual`` was previously accepted here but never passed on, so the flag
        # had no effect.
        fused_visual_features, language_features = self.b_attn(
            visual_features,
            language_dict_features["hidden"],
            language_dict_features["masks"],
            task,
            visual=visual,
        )
        language_dict_features["hidden"] = language_features
        return {"visual": fused_visual_features, "lang": language_dict_features}


class BiMultiHeadAttention(nn.Module):
    """Bidirectional multi-head cross-attention between visual and language tokens.

    One similarity matrix (visual queries x language keys) is computed and used
    in both directions:

    - Row-wise softmax gives vision->language attention. Language padding is
      masked out here via ``attention_mask_l``.
    - Softmax over the transposed matrix gives language->vision attention.

    Args:
        v_dim: feature size of the visual tokens.
        l_dim: feature size of the language tokens.
        embed_dim: inner attention width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to both attention maps in training.
    """

    def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1):
        super(BiMultiHeadAttention, self).__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.v_dim = v_dim
        self.l_dim = l_dim

        assert (
                self.head_dim * self.num_heads == self.embed_dim
        ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
        self.scale = self.head_dim ** (-0.5)
        self.dropout = dropout

        self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
        self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
        self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
        self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)

        self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
        self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)

        # Numerical-stability switches used in forward().
        self.stable_softmax_2d = False
        self.clamp_min_for_underflow = True
        self.clamp_max_for_overflow = True

        self._reset_parameters()

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Split the last dim into heads: (bsz, seq, embed) -> (bsz, heads, seq, head_dim)."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _reset_parameters(self):
        """Xavier-uniform weights and zero biases for every projection."""
        nn.init.xavier_uniform_(self.v_proj.weight)
        self.v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.l_proj.weight)
        self.l_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.values_v_proj.weight)
        self.values_v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.values_l_proj.weight)
        self.values_l_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.out_v_proj.weight)
        self.out_v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.out_l_proj.weight)
        self.out_l_proj.bias.data.fill_(0)

    def _visualize_and_exit(self, attn_weights):
        """Debug hook: plot vision->text attention for a fixed prompt set, then stop.

        Tokenizes six hard-coded prompts with the CLIP tokenizer at ``CLIP_PATH``.
        It then passes the row-softmaxed (unmasked, unclamped) attention to
        ``visualize_cross_attention`` and terminates the process.
        """
        import re
        import sys
        from transformers import CLIPTokenizer

        CLIP_PATH = "/home/sojungan/weight/GLEE/clip_vit_base_patch32"
        tokenizer = CLIPTokenizer.from_pretrained(CLIP_PATH)
        visual_weight = torch.softmax(attn_weights, dim=-1)
        text_inputs = ['plate', 'White plate', 'White plate with a fork sitting on it',
                       'dressage', 'home plate', 'White plate with red, blue, and yellow square pattern']
        short_token_map = {
            "<|startoftext|>": "[SOS]",
            "<|endoftext|>": "[EOS]"
        }

        tokenized = [tokenizer.convert_ids_to_tokens(tokenizer(text)["input_ids"]) for text in text_inputs]
        # Strip CLIP's end-of-word marker, then abbreviate the special tokens.
        cleaned = [[re.sub(r"</[Ww]>", "", tok) for tok in toks] for toks in tokenized]
        cleaned = [[short_token_map.get(tok, tok) for tok in toks] for toks in cleaned]
        visualize_cross_attention(visual_weight, text_tokens=cleaned, text_inputs=text_inputs)
        # sys.exit instead of the site-module ``exit`` builtin, which is not
        # guaranteed to exist (e.g. under ``python -S``).
        sys.exit()

    def forward(self, v, l, attention_mask_l=None, visual=False):
        """Run bidirectional cross-attention.

        Args:
            v: visual features, ``(bsz, tgt_len, v_dim)``.
            l: language features, ``(bsz, src_len, l_dim)``.
            attention_mask_l: optional ``(bsz, src_len)`` mask over language
                tokens. Zero entries are excluded from vision->language attention.
                Boolean masks are treated like 0/1 integer masks.
            visual: debug flag. When True, the attention map is plotted and the
                process exits.

        Returns:
            ``(attn_output_v, attn_output_l)`` of shapes ``(bsz, tgt_len, v_dim)``
            and ``(bsz, src_len, l_dim)``.

        Raises:
            ValueError: if an intermediate tensor has an unexpected shape.
        """
        bsz, tgt_len, _ = v.size()

        # Only the visual queries are pre-scaled by 1/sqrt(head_dim).
        query_states = self.v_proj(v) * self.scale
        key_states = self._shape(self.l_proj(l), -1, bsz)
        value_v_states = self._shape(self.values_v_proj(v), -1, bsz)
        value_l_states = self._shape(self.values_l_proj(l), -1, bsz)

        # Fold heads into the batch dim: (bsz * num_heads, seq_len, head_dim).
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_v_states = value_v_states.view(*proj_shape)
        value_l_states = value_l_states.view(*proj_shape)

        src_len = key_states.size(1)
        # (bsz * num_heads, tgt_len, src_len)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if visual:
            self._visualize_and_exit(attn_weights)

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
            )

        if self.stable_softmax_2d:
            attn_weights = attn_weights - attn_weights.max()

        if self.clamp_min_for_underflow:
            attn_weights = torch.clamp(attn_weights, min=-50000) # Do not increase -50000, data type half has quite limited range
        if self.clamp_max_for_overflow:
            attn_weights = torch.clamp(attn_weights, max=50000) # Do not increase 50000, data type half has quite limited range

        # Language->vision attention: softmax over visual tokens of the transposed
        # scores, shifted by the row max for stability. No mask is needed, since
        # the keys here are visual tokens.
        attn_weights_T = attn_weights.transpose(1, 2)
        attn_weights_l = attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[0]
        if self.clamp_min_for_underflow:
            attn_weights_l = torch.clamp(attn_weights_l, min=-50000) # Do not increase -50000, data type half has quite limited range
        if self.clamp_max_for_overflow:
            attn_weights_l = torch.clamp(attn_weights_l, max=50000) # Do not increase 50000, data type half has quite limited range

        attn_weights_l = attn_weights_l.softmax(dim=-1)

        if attention_mask_l is not None:
            assert (attention_mask_l.dim() == 2) # (bs, seq_len)
            if attention_mask_l.dtype == torch.bool:
                # A bool tensor cannot hold the -9e15 fill value used below,
                # so convert it to a 0/1 integer mask first.
                attention_mask_l = attention_mask_l.long()
            attention_mask = attention_mask_l.unsqueeze(1).unsqueeze(1) # (bs, 1, 1, seq_len)
            attention_mask = attention_mask.expand(bsz, 1, tgt_len, src_len)
            # Additive mask: padded positions get -9e15. Valid positions keep their
            # mask value; this uniform shift does not change the softmax.
            attention_mask = attention_mask.masked_fill(attention_mask == 0, -9e15)

            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # Vision->language attention: softmax over language tokens.
        attn_weights_v = nn.functional.softmax(attn_weights, dim=-1)

        attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
        attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training)

        attn_output_v = torch.bmm(attn_probs_v, value_l_states)
        attn_output_l = torch.bmm(attn_probs_l, value_v_states)

        if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output_v` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}"
            )

        if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim):
            raise ValueError(
                f"`attn_output_l` should be of size {(bsz * self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}"
            )

        # Merge heads back: (bsz * heads, seq, head_dim) -> (bsz, seq, embed_dim).
        attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output_v = attn_output_v.transpose(1, 2)
        attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)

        attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim)
        attn_output_l = attn_output_l.transpose(1, 2)
        attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)

        attn_output_v = self.out_v_proj(attn_output_v)
        attn_output_l = self.out_l_proj(attn_output_l)

        return attn_output_v, attn_output_l


class BiAttentionBlockForCheckpoint(nn.Module):
    """Pre-norm bidirectional vision-language attention block with layer scale.

    Both streams are layer-normalized and passed through BiMultiHeadAttention.
    Each returned delta is scaled by a learnable per-channel gain, initialised
    to ``init_values``, and added back onto the *normalized* input of its stream.
    """

    def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1,
                 drop_path=.0, init_values=1e-4):
        """
        Inputs:
            v_dim - Feature size of the visual tokens
            l_dim - Feature size of the language tokens
            embed_dim - Inner width of the attention projections
            num_heads - Number of heads to use in the Multi-Head Attention block
            dropout - Dropout probability applied to the attention maps
            drop_path - Stochastic-depth rate; 0 disables it
            init_values - Initial value of the layer-scale gains
        """
        super().__init__()

        # Separate pre-normalisation for each modality.
        self.layer_norm_v = nn.LayerNorm(v_dim)
        self.layer_norm_l = nn.LayerNorm(l_dim)
        self.attn = BiMultiHeadAttention(
            v_dim=v_dim,
            l_dim=l_dim,
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
        )

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        # Layer scale for training stability: a learnable per-channel gain on
        # each residual branch.
        self.gamma_v = nn.Parameter(torch.ones((v_dim)) * init_values, requires_grad=True)
        self.gamma_l = nn.Parameter(torch.ones((l_dim)) * init_values, requires_grad=True)

    def forward(self, v, l, attention_mask_l=None, task=None, visual=False):
        """Fuse one visual/language pair of token sequences.

        Args:
            v: visual features, assumed ``(bs, n_visual_tokens, v_dim)``.
            l: language features, assumed ``(bs, seq_len, l_dim)``.
            attention_mask_l: optional language mask passed to the attention layer.
            task: accepted for interface compatibility; not used here.
            visual: debug flag passed to the attention layer.

        Returns:
            ``(v_out, l_out)``: the normalized inputs plus their scaled attention deltas.
        """
        v_normed = self.layer_norm_v(v)
        l_normed = self.layer_norm_l(l)
        delta_v, delta_l = self.attn(
            v_normed, l_normed, attention_mask_l=attention_mask_l, visual=visual
        )
        v_out = v_normed + self.drop_path(self.gamma_v * delta_v)
        l_out = l_normed + self.drop_path(self.gamma_l * delta_l)
        return v_out, l_out


