import torch
import torch.nn as nn
from collections import OrderedDict
import torch.nn.functional as F
from typing import Optional, Any, Union, Callable
from torch import Tensor
import copy
import math


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    The encoding table is computed once in the constructor and stored in the
    buffer ``pe`` with shape ``[max_len, 1, d_model]``: even feature columns
    hold sines and odd feature columns hold cosines.

    Args:
        d_model: Size of the feature (last) dimension. Odd sizes are allowed.
        dropout: Probability used to build ``self.dropout``. The layer is
            created here but ``forward`` does not apply it.
        max_len: Number of positions stored in the table.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so the cosine half only takes the first d_model // 2 frequencies.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` with the encoding of each position added.

        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]. Positions
                are taken along dimension 1; ``seq_len`` should not exceed the
                ``max_len`` the table was built with.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # [seq_len, 1, dim] -> [1, seq_len, dim]; broadcasting over the batch
        # dimension avoids materialising one copy of the table per sample.
        pos_emb = self.pe[:x.size(1)].transpose(0, 1)
        return x + pos_emb
        
def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))

class PositionHeadLayer(nn.Module):
    """Predict positions for randomly shuffled two-token blocks of a sequence.

    ``forward`` keeps the first token in place, groups the remaining tokens
    into consecutive pairs, shuffles the pairs with a random permutation,
    encodes the shuffled sequence and classifies the second token of every
    pair with a linear head.

    Args:
        width_temp: Feature size of the incoming tokens.
        num_segments: Number of output classes of the linear head.
        transformer_encoder: Callable applied to the shuffled sequence before
            the head. It may also be assigned later through the
            ``transformer_encoder`` attribute; ``forward`` cannot run without it.
    """

    def __init__(self, width_temp=512, num_segments=8, transformer_encoder=None):
        super(PositionHeadLayer, self).__init__()
        self.position_head = nn.Linear(width_temp, num_segments)
        # forward() needs an encoder; the attribute was previously never
        # defined, so every call failed with an opaque AttributeError.
        self.transformer_encoder = transformer_encoder

    def forward(self, videos_features):
        """Return position logits for the shuffled blocks.

        Args:
            videos_features: Tensor indexed as [batch, 1 + 2 * blocks, dim];
                the token count after the first token must be even.

        Returns:
            Tensor of shape [batch, blocks, num_segments]. The permutation
            that was drawn is not returned.

        Raises:
            AttributeError: If no ``transformer_encoder`` has been provided.
        """
        encoder = self.transformer_encoder
        if encoder is None:
            raise AttributeError(
                "PositionHeadLayer.transformer_encoder is not set; pass it to the "
                "constructor or assign it before calling forward()")

        cls_token = videos_features[:, :1, :]
        rest = videos_features[:, 1:, :]
        B, D = rest.size(0), rest.size(2)
        N = rest.size(1) // 2  # number of two-token blocks

        blocks = rest.reshape(B, N, 2, D)
        # Permute however many blocks are actually present, not a fixed 8.
        perm = torch.randperm(N)
        shuffled_rest = blocks[:, perm, :, :].reshape(B, 2 * N, D)
        shuffled = torch.cat([cls_token, shuffled_rest], dim=1)

        encoded = encoder(shuffled)
        # Indices 2, 4, ... are the second token of every block.
        sep_features = encoded[:, 2::2, :]
        position_logits = self.position_head(sep_features)
        return position_logits

class MultiheadAttentionNoWvWo(nn.Module):
    """Multi-head self-attention without value and output projections.

    Queries and keys go through learned linear maps; the values are the raw
    input split into heads, and the concatenated heads are returned without a
    final projection (only dropout is applied).

    Args:
        d_model: Feature size of the input; must be divisible by ``nhead``.
        nhead: Number of attention heads.
        dropout: Dropout probability applied to the merged attention output.
    """

    def __init__(self, d_model, nhead, dropout=0.1):
        super().__init__()
        assert d_model % nhead == 0
        self.d_model = d_model
        self.nhead = nhead
        self.head_dim = d_model // nhead

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, tensor: Tensor, batch: int, steps: int) -> Tensor:
        """Reshape [batch, steps, d_model] into [batch, nhead, steps, head_dim]."""
        return tensor.view(batch, steps, self.nhead, self.head_dim).transpose(1, 2)

    def forward(self, x: Tensor) -> Tensor:
        """Attend over dimension 1 of ``x`` ([batch, steps, d_model]) and return the same shape."""
        batch, steps, width = x.shape

        queries = self._split_heads(self.q_proj(x), batch, steps)
        keys = self._split_heads(self.k_proj(x), batch, steps)
        values = self._split_heads(x, batch, steps)  # values are left unprojected

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
        probs = F.softmax(scores, dim=-1)

        mixed = torch.matmul(probs, values)
        merged = mixed.transpose(1, 2).reshape(batch, steps, width)
        return self.dropout(merged)

class TransformerEncoderLayer(nn.Module):
    """One transformer encoder block: self-attention followed by a feed-forward network.

    Both sub-blocks are residual. With ``norm_first=True`` layer normalisation
    is applied to the input of each sub-block; otherwise it is applied to the
    sum of the residual and the sub-block output.

    Args:
        d_model: Feature size of the input and output.
        nhead: Number of heads passed to ``nn.MultiheadAttention``.
        dim_feedforward: Hidden size of the feed-forward network.
        dropout: Dropout probability used in attention and in both sub-blocks.
        activation: Feed-forward activation, either a callable or one of the
            names accepted by ``_get_activation_fn``.
        layer_norm_eps: ``eps`` of both layer norms.
        batch_first: Forwarded to ``nn.MultiheadAttention``.
        norm_first: Selects pre-norm (True) or post-norm (False) ordering.
        device, dtype: Factory arguments for the parameterised sub-modules.
    """

    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        super().__init__()
        tensor_opts = dict(device=device, dtype=dtype)

        # Attention sub-block.
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first, **tensor_opts)

        # Feed-forward sub-block.
        self.linear1 = nn.Linear(d_model, dim_feedforward, **tensor_opts)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **tensor_opts)

        # Normalisation and residual dropout.
        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **tensor_opts)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **tensor_opts)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation) if isinstance(activation, str) else activation

    def __setstate__(self, state):
        # State lacking an activation entry falls back to ReLU.
        state.setdefault('activation', F.relu)
        super().__setstate__(state)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """Run the block on ``src`` and return a tensor of the same shape.

        Args:
            src: Input sequence.
            src_mask: Optional attention mask forwarded to the attention module.
            src_key_padding_mask: Optional key padding mask forwarded likewise.
        """
        if self.norm_first:
            attended = src + self._sa_block(self.norm1(src), src_mask, src_key_padding_mask)
            return attended + self._ff_block(self.norm2(attended))

        attended = self.norm1(src + self._sa_block(src, src_mask, src_key_padding_mask))
        return self.norm2(attended + self._ff_block(attended))

    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
        """Self-attention (query = key = value = ``x``) followed by dropout."""
        attn_result = self.self_attn(
            x, x, x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        return self.dropout1(attn_result[0])

    def _ff_block(self, x: Tensor) -> Tensor:
        """Two-layer feed-forward network followed by dropout."""
        hidden = self.activation(self.linear1(x))
        projected = self.linear2(self.dropout(hidden))
        return self.dropout2(projected)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

class TransformerEncoder(nn.Module):
    """Stack of identical encoder layers applied one after another.

    Args:
        encoder_layer: Layer that is deep-copied ``num_layers`` times.
        num_layers: Number of copies in the stack.
        norm: Optional callable applied to the output of the last layer.
    """

    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """Pass ``src`` through every layer in order, then through ``norm`` if one is set.

        Args:
            src: Input sequence.
            mask: Handed to each layer as ``src_mask``.
            src_key_padding_mask: Handed to each layer unchanged.
        """
        hidden = src
        for layer in self.layers:
            hidden = layer(hidden, src_mask=mask, src_key_padding_mask=src_key_padding_mask)

        return hidden if self.norm is None else self.norm(hidden)


class Temporal_Module(nn.Module):
    """Transformer encoder over per-segment features with a learned CLS token.

    A single-row embedding provides the CLS token that is prepended to every
    sequence. Depending on the configuration the sequence is additionally
    interleaved with temporal prompts or with difference features before
    positional encoding and the transformer encoder are applied.

    Args:
        device: Stored as ``self.device``. ``forward`` creates its helper
            tensors on the device of the module parameters instead.
        conf: Mapping providing ``'dim_model'``, ``'nhead'``, ``'num_layers'``,
            ``'num_segments'``, ``'is_tem_prompts'``, ``'is_use_diff_feat'``
            and ``'is_auxiliary_training'``.
        tem_prompts: Optional tensor cloned into a trainable parameter.
            ``forward`` pairs it one-to-one with the input tokens, so it is
            presumably shaped [num_segments, dim_model] (confirm with caller).
        task_id: Accepted but not used by this class.
        is_diff: When True, ``forward`` only encodes ``[CLS, features]`` and
            returns a 2-tuple (see ``forward``).
    """

    def __init__(self, device, conf, tem_prompts=None, task_id=-1, is_diff=False):
        super().__init__()

        self.device = device
        dim_model = conf['dim_model']
        nhead = conf['nhead']
        num_layers = conf['num_layers']
        self.num_segments = conf['num_segments']
        self.is_tem_prompts = conf['is_tem_prompts']
        self.is_use_diff_feat = conf['is_use_diff_feat']
        self.is_auxiliary_training = conf['is_auxiliary_training']
        self.is_diff = is_diff

        # Longest sequence the positional table has to cover:
        #   prompts interleaved  -> CLS + 2 tokens per segment
        #   diff features paired -> 2 tokens per (CLS + segment)
        #   otherwise            -> CLS + 1 token per segment
        if self.is_tem_prompts:
            max_len = 2 * self.num_segments + 1
        elif self.is_use_diff_feat and not self.is_diff:
            max_len = 2 * (self.num_segments + 1)
        else:
            max_len = self.num_segments + 1
        self.positional_encoder = PositionalEncoding(d_model=dim_model, max_len=max_len)

        encoder_layer = TransformerEncoderLayer(d_model=dim_model, nhead=nhead, batch_first=True,
                                                norm_first=True)
        self.embedding = nn.Embedding(1, dim_model)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_layers)

        if self.is_auxiliary_training:
            self.position_head = nn.Linear(dim_model, self.num_segments)
        if tem_prompts is not None:
            self.tem_prompts = nn.Parameter(tem_prompts.clone())
        else:
            self.tem_prompts = None

    def _shuffled_position_logits(self, videos_features, B, N, D):
        """Shuffle the two-token blocks after CLS, encode them and classify each block.

        Returns ``(position_logits, perm)`` where ``perm`` is the permutation
        that was applied to the ``N`` blocks.
        """
        cls_token = videos_features[:, :1, :]
        blocks = videos_features[:, 1:, :].reshape(B, N, 2, D)
        # Permute exactly the N blocks present in this batch; a fixed
        # randperm(8) only worked when there were exactly 8 segments.
        perm = torch.randperm(N)
        shuffled_rest = blocks[:, perm, :, :].reshape(B, 2 * N, D)
        shuffled = torch.cat([cls_token, shuffled_rest], dim=1)
        encoded = self.transformer_encoder(shuffled)
        # Indices 2, 4, ... are the second token of every block.
        sep_features = encoded[:, 2::2, :]
        return self.position_head(sep_features), perm

    def forward(self, videos_features, promptModule=None, type_task='TIL', is_pos_distill=False, diff_feat=None):
        """Encode a batch of segment features.

        Args:
            videos_features: Tensor indexed as [batch, segments, dim].
            promptModule: Optional object called as
                ``promptModule(features, 'temp', type_task)`` before encoding;
                its ``num_sel_prompts``, ``type_prompt``, ``L_tp_gn`` and
                ``L_tp_tk`` attributes select how many leading output tokens
                are averaged into the context embedding.
            type_task: Passed through to ``promptModule``.
            is_pos_distill: When True, also encode an all-zero sequence so the
                output reflects the positional encoding only.
            diff_feat: Optional tensor paired token-by-token with
                ``[CLS, features]``; only used when no temporal prompts are set.

        Returns:
            If ``self.is_diff``: ``(encoded_sequence, cls_output)``.
            Otherwise ``(context_emb, pos_embed, position_logits, perm)`` where
            ``pos_embed`` is None unless ``is_pos_distill`` is set and
            ``position_logits`` / ``perm`` are None unless auxiliary training
            is enabled. The auxiliary branch views the tokens after CLS as
            ``segments`` blocks of two tokens, so it needs the interleaved
            (temporal prompt) layout.
        """
        B, N, D = videos_features.size(0), videos_features.size(1), videos_features.size(2)
        # Create helper tensors where the parameters live so the module keeps
        # working after .to(...) or when replicated across devices.
        device = self.embedding.weight.device

        cls_idx = torch.zeros(B, dtype=torch.long, device=device)
        cls_emb = self.embedding(cls_idx).unsqueeze(1)  # [B, 1, D]

        if self.is_diff:
            diff_features = torch.cat((cls_emb, videos_features), 1)
            diff_features = self.positional_encoder(diff_features)
            diff_features = self.transformer_encoder(diff_features)
            return diff_features, diff_features[:, 0, :]

        if self.tem_prompts is not None:
            # Interleave as [frame_0, prompt_0, frame_1, prompt_1, ...].
            expanded_tem_prompts = self.tem_prompts.unsqueeze(0).repeat(B, 1, 1)
            combined_features = torch.stack([videos_features, expanded_tem_prompts], dim=2).view(B, 2 * N, D)
            videos_features = torch.cat((cls_emb, combined_features), 1)
        elif diff_feat is not None:
            # Interleave [CLS, frames] with the difference features token by token.
            paired_len = diff_feat.size(1)
            videos_features = torch.cat((cls_emb, videos_features), 1)
            videos_features = torch.stack([videos_features, diff_feat], dim=2).view(B, 2 * paired_len, D)
        else:
            videos_features = torch.cat((cls_emb, videos_features), 1)

        # Remember the assembled layout before a prompt module can change it.
        total_len, feat_dim = videos_features.size(1), videos_features.size(2)

        videos_features = self.positional_encoder(videos_features)

        if self.is_auxiliary_training:
            position_logits, perm = self._shuffled_position_logits(videos_features, B, N, D)
        else:
            position_logits = None
            perm = None

        if promptModule is not None:
            videos_features = promptModule(videos_features, 'temp', type_task)

        videos_features = self.transformer_encoder(videos_features)
        if promptModule is not None:
            if promptModule.type_prompt == 'general':
                L_tp = promptModule.num_sel_prompts * promptModule.L_tp_gn
            else:
                L_tp = promptModule.num_sel_prompts * promptModule.L_tp_tk
            videos_features = videos_features[:, :L_tp, :]
            context_emb = torch.mean(videos_features, dim=1)
        else:
            context_emb = videos_features[:, 0, :]

        if is_pos_distill:
            # Only allocate the zero sequence when it is actually needed.
            zero_tensor = torch.zeros(1, total_len, feat_dim, device=device)
            pos_encoded = self.positional_encoder(zero_tensor)
            pos_trans = self.transformer_encoder(pos_encoded)
            pos_embed = pos_trans[:, 1:, :].squeeze(0)
        else:
            pos_embed = None
        return context_emb, pos_embed, position_logits, perm