import torch
import torch.nn as nn
import torchvision.ops as ops
import torch.nn.functional as F
import re

class TAC(nn.Module):
    """Temporal alignment connector: fuse current and prior visual features.

    Both inputs to ``forward`` are 4-D tensors whose dim 0 has size 12.
    ``LEF`` receives them after ``permute(1, 0, 2, 3)``, so that axis becomes
    its 12 input channels. The last dim must equal ``dinov2_dim``, which is
    the attention / LayerNorm width. ``forward`` returns a tensor of shape
    ``(batch, seq, hidden_dim)``.

    NOTE(review): the 12 slices are presumably per-layer DINOv2 features;
    confirm against the caller.
    """

    def __init__(self, dinov2_dim=768, hidden_dim=4096, dropout=0.1, num_heads=32):
        """Build all sub-modules.

        Args:
            dinov2_dim: Width of the incoming features. It is also the embed
                dim of every attention layer and must be divisible by
                ``num_heads``.
            hidden_dim: Output width of ``mlp_final``. Its fourth root is the
                exponent used in ``calculate_cosine_similarity``.
            dropout: Dropout probability for attention outputs and ``mlp_ff``.
            num_heads: Number of heads in each ``nn.MultiheadAttention``.
        """
        super().__init__()
        self.dinov2_dim = dinov2_dim
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        # LEF: collapse the 12 input channels to 1 through alternating
        # squeeze-excitation and 1x1 conv stages (12 -> 6 -> 3 -> 1).
        self.LEF = nn.Sequential(
            ops.SqueezeExcitation(12, 6, activation=nn.GELU),
            nn.Conv2d(12, 6, kernel_size=1, bias=False),
            ops.SqueezeExcitation(6, 3, activation=nn.GELU),
            nn.Conv2d(6, 3, kernel_size=1, bias=False),
            ops.SqueezeExcitation(3, 1, activation=nn.GELU),
            nn.Conv2d(3, 1, kernel_size=1, bias=False),
        )
        # Learnable scale of the similarity-weighted bias added to the prior
        # features. It starts at 0, so the bias is initially a no-op.
        self.selector_prior_bias = nn.Parameter(torch.tensor(0.0))
        self.selector_cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
        self.cur_self_attention = nn.MultiheadAttention(
            embed_dim=self.dinov2_dim, num_heads=num_heads,
            batch_first=True, add_bias_kv=True,
        )
        self.prior_self_attention = nn.MultiheadAttention(
            embed_dim=self.dinov2_dim, num_heads=num_heads,
            batch_first=True, add_bias_kv=True,
        )
        self.cros_attention = nn.MultiheadAttention(
            embed_dim=self.dinov2_dim, num_heads=num_heads,
            batch_first=True, add_bias_kv=True,
        )

        self.norm1 = nn.LayerNorm(self.dinov2_dim)
        self.norm2 = nn.LayerNorm(self.dinov2_dim)
        self.norm3 = nn.LayerNorm(self.dinov2_dim)
        self.norm4 = nn.LayerNorm(self.dinov2_dim)

        # Position-wise feed-forward applied to the cross-attended features.
        self.mlp_ff = nn.Sequential(
            nn.Linear(self.dinov2_dim, self.dinov2_dim),
            nn.GELU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.dinov2_dim, self.dinov2_dim),
            nn.Dropout(self.dropout),
        )

        # Final projection from dinov2_dim to hidden_dim.
        self.mlp_final = nn.Sequential(
            nn.Linear(self.dinov2_dim, self.hidden_dim),
            nn.GELU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.GELU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.GELU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
        )
        self.dropout1 = nn.Dropout(self.dropout)
        self.dropout2 = nn.Dropout(self.dropout)
        self.dropout3 = nn.Dropout(self.dropout)

    def calculate_cosine_similarity(self, tensor1, tensor2):
        """Return a per-sample similarity weight in ``[0, 1]``, shaped ``(B, 1, 1)``.

        The method works as follows:

        - Each sample (dim 0) of both tensors is flattened.
        - The cosine similarity ``c`` of each pair of flattened vectors is
          computed.
        - ``c`` is mapped to ``(c + 1) / 2``.
        - The result is raised to the power ``hidden_dim ** 0.25`` (8 for the
          default 4096). This pushes the weight toward 0 unless the samples
          are nearly aligned.

        Raises:
            ValueError: If ``tensor1`` and ``tensor2`` differ in shape.
        """
        if tensor1.shape != tensor2.shape:
            raise ValueError(
                f"shape mismatch: {tuple(tensor1.shape)} vs {tuple(tensor2.shape)}"
            )
        # reshape (not view) so non-contiguous inputs are accepted as well.
        flat1 = tensor1.reshape(tensor1.size(0), -1)
        flat2 = tensor2.reshape(tensor2.size(0), -1)
        # CosineSimilarity normalizes internally (guarded by eps). Dividing by
        # the norm beforehand was redundant and produced NaN for all-zero rows.
        cos = self.selector_cos(flat1, flat2)
        # hidden_dim is a Python int, so use ** (int has no .pow method).
        exponent = self.hidden_dim ** 0.25
        # Clamp guards against rounding pushing cos slightly below -1, which
        # would give NaN for a non-integer exponent.
        weight = ((cos + 1) / 2).clamp(min=0.0).pow(exponent)
        return weight.view(-1, 1, 1)

    # self-attention block
    def cur_self_att_block(self, x):
        """Self-attention over ``x`` with ``cur_self_attention``, then dropout."""
        x = self.cur_self_attention(x, x, x)[0]
        return self.dropout1(x)

    # self-attention block
    def prior_self_att_block(self, x):
        """Self-attention over ``x`` with ``prior_self_attention``, then dropout."""
        x = self.prior_self_attention(x, x, x)[0]
        return self.dropout2(x)

    # cross attention block
    def cros_att_block(self, x, y):
        """Cross-attention: queries from ``x``, keys/values from ``y``, then dropout."""
        x = self.cros_attention(x, y, y)[0]
        return self.dropout3(x)

    def forward(self, cur_features, prev_features, *args, **kwargs):
        """Fuse current and prior features and project them to ``hidden_dim``.

        Args:
            cur_features: 4-D tensor with size 12 on dim 0 and ``dinov2_dim``
                on the last dim (see the class docstring).
            prev_features: Same layout as ``cur_features``.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape ``(B, S, hidden_dim)``.
        """
        # Move dim 0 onto the channel axis for LEF, then drop the singleton
        # output channel: (12, B, S, D) -> (B, 12, S, D) -> (B, S, D).
        cur_features = self.LEF(cur_features.permute(1, 0, 2, 3)).squeeze(1)
        prev_features = self.LEF(prev_features.permute(1, 0, 2, 3)).squeeze(1)
        # Kept for the final residual. No clone is needed because
        # cur_features is only rebound below, never modified in place.
        cur_features_orig = cur_features
        # Add a per-sample scalar bias to the prior features, proportional to
        # how similar the two samples are: (B, 1, 1) broadcast over (B, S, D).
        cos = self.calculate_cosine_similarity(cur_features, prev_features)
        prev_weight = cos * self.selector_prior_bias
        prev_features = prev_features + prev_weight
        cur_features = self.norm1(cur_features + self.cur_self_att_block(cur_features))
        prev_features = self.norm2(prev_features + self.prior_self_att_block(prev_features))
        combined_features = self.norm3(
            cur_features + self.cros_att_block(cur_features, prev_features)
        )
        # NOTE(review): this residual adds the pre-attention LEF output, not
        # combined_features; confirm this is intended.
        output = self.norm4(cur_features_orig + self.mlp_ff(combined_features))
        output = self.mlp_final(output)
        return output

    @property
    def config(self):
        """Projector config dict identifying this connector type."""
        return {"mm_projector_type": 'temporal_alignment_connector'}

    

def build_vision_projector(config, delay_load=False, *args, **kwargs):
    """Construct the projector module selected by ``config.mm_projector_type``.

    The projector type defaults to ``'linear'`` when the attribute is absent.
    Supported types:

    - ``'linear'``: a single ``nn.Linear(mm_hidden_size, hidden_size)``.
    - ``'mlp<N>x_gelu'``: ``nn.Sequential`` with N Linear layers separated by
      GELU. The first layer maps ``mm_hidden_size -> hidden_size``; the rest
      map ``hidden_size -> hidden_size``.
    - ``'temporal_alignment_connector'``: a default-constructed ``TAC``.

    ``delay_load``, ``*args`` and ``**kwargs`` are accepted but unused.

    Raises:
        ValueError: If the projector type is not recognized.
    """
    kind = getattr(config, 'mm_projector_type', 'linear')

    if kind == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    depth_match = re.match(r'^mlp(\d+)x_gelu$', kind)
    if depth_match is not None:
        depth = int(depth_match.group(1))
        layers = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        # Each extra depth step contributes a (GELU, Linear) pair, built in
        # that order.
        layers.extend(
            layer
            for _ in range(depth - 1)
            for layer in (nn.GELU(), nn.Linear(config.hidden_size, config.hidden_size))
        )
        return nn.Sequential(*layers)

    if kind == 'temporal_alignment_connector':
        return TAC()

    raise ValueError(f'Unknown projector type: {kind}')
