import copy
import torch
import torch.nn.functional as F
from torch import nn

class IterativeDecoder(nn.Module):
    """Three-stage decoder: coarse actions -> joint scene -> fine actions.

    The input queries are flattened over the mode axis, then passed through
    three stacks in sequence:

    1. ``coarse_action_decoder`` refines the action queries against the
       vectorized features and the occupancy grid.
    2. ``scene_decoder`` refines the occupancy grid against the coarse actions.
    3. ``fine_action_decoder`` refines the coarse actions against the
       vectorized features and the refined scene.

    Args:
        d_model: Embedding width shared by all layers.
        nhead: Number of attention heads.
        num_decoder_layers: Layers per stack. When ``None`` no decoder stack
            is built (only ``self.decoder = None`` is set) and ``forward``
            cannot be used.
        dim_feedforward: Hidden width of each feed-forward block.
        dropout: Dropout probability used throughout.
        activation: Activation name forwarded to the layer constructors.
        return_intermediate_dec: If true, every stack returns the stacked
            outputs of all of its layers instead of only the last one.
        extra_track_attn: Accepted for interface compatibility; not used here.
        n_detect_query: Accepted for interface compatibility; not used here.
    """

    def __init__(
        self,
        d_model=256,
        nhead=8,
        num_decoder_layers=6,
        dim_feedforward=1024,
        dropout=0.1,
        activation="relu",
        return_intermediate_dec=False,
        extra_track_attn=False,
        n_detect_query=50,
    ):
        super().__init__()

        self.d_model = d_model
        self.nhead = nhead
        # Remembered so forward() knows whether the stacks return per-layer
        # outputs stacked along a new leading dimension.
        self.return_intermediate_dec = return_intermediate_dec

        # Placeholder attribute; nothing in this class reads it.
        self.encoder = None

        if num_decoder_layers is not None:
            # Prototype layers are built in the same order as before so that
            # parameter registration and RNG consumption are unchanged.
            decoder_layer = CoarseDecoderLayer(
                d_model,
                dim_feedforward,
                dropout,
                activation,
                nhead,
            )
            scene_decoder_layer = SceneDecoderLayer(
                d_model,
                dim_feedforward,
                dropout,
                activation,
                nhead,
            )
            fine_action_decoder_layer = FineDecoderLayer(
                d_model,
                dim_feedforward,
                dropout,
                activation,
                nhead,
            )

            self.coarse_action_decoder = CoarseDecoder(
                decoder_layer,
                num_decoder_layers,
                return_intermediate_dec,
            )
            self.scene_decoder = SceneDecoder(
                scene_decoder_layer,
                num_decoder_layers,
                return_intermediate_dec,
            )
            self.fine_action_decoder = FineDecoder(
                fine_action_decoder_layer,
                num_decoder_layers,
                return_intermediate_dec,
            )
        else:
            self.decoder = None

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dim."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    @staticmethod
    def _repeat_per_mode(tensor, num_mode):
        """Repeat each batch entry ``num_mode`` times along dim 0.

        Row ``b * num_mode + m`` of the result equals ``tensor[b]``, which
        matches how ``tgt`` is flattened over its (batch, mode) axes.
        ``None`` is passed through so optional masks stay optional.
        """
        if tensor is None:
            return None
        return tensor.repeat_interleave(num_mode, dim=0)

    def forward(self, src, tgt=None, tgt_occ=None, src_padding_mask=None, tgt_padding_mask=None):
        """Run the coarse -> scene -> fine decoding pipeline.

        Args:
            src: Vectorized features, ``(bs, num_elements, dim)``.
            tgt: Action queries, ``(bs, num_mode, num_agents, dim)``. Required
                despite the ``None`` default.
            tgt_occ: Occupancy features, ``(bs, H, W, dim)``. Required despite
                the ``None`` default.
            src_padding_mask: Optional ``(bs, num_elements)`` key-padding mask
                for ``src``.
            tgt_padding_mask: Optional ``(bs, num_agents)`` key-padding mask
                for the agent axis of ``tgt``.

        Returns:
            Tuple ``(coarse_action, joint_scene, fine_action)`` exactly as the
            three stacks return them, with the batch and mode axes merged into
            one leading axis of size ``bs * num_mode``. When
            ``return_intermediate_dec`` is set each entry carries an extra
            leading per-layer dimension.

        Raises:
            ValueError: If ``tgt`` or ``tgt_occ`` is ``None``.
        """
        if tgt is None or tgt_occ is None:
            raise ValueError("IterativeDecoder.forward requires both `tgt` and `tgt_occ`.")

        bs, num_mode, num_agents, dim = tgt.shape

        # reshape (not view) so a non-contiguous `tgt` is accepted as well.
        action = tgt.reshape(bs * num_mode, num_agents, dim)
        # Everything that is shared across modes is tiled once per mode.
        vector_feature = self._repeat_per_mode(src, num_mode)
        occ = self._repeat_per_mode(tgt_occ, num_mode)
        tgt_padding_mask_m = self._repeat_per_mode(tgt_padding_mask, num_mode)
        src_padding_mask_m = self._repeat_per_mode(src_padding_mask, num_mode)

        # decode coarse trajectories
        coarse_action = self.coarse_action_decoder(
            action, vector_feature, occ, tgt_padding_mask_m, src_padding_mask_m)
        # With intermediate outputs the stack returns (num_layers, ...); only
        # the final layer's output has the shape the next stage expects.
        coarse_last = coarse_action[-1] if self.return_intermediate_dec else coarse_action

        # decode joint occ scenes; the agent mask pads the coarse-action keys
        joint_scene = self.scene_decoder(occ, coarse_last, tgt_padding_mask_m)
        scene_last = joint_scene[-1] if self.return_intermediate_dec else joint_scene

        # decode fine trajectories
        fine_action = self.fine_action_decoder(
            coarse_last, vector_feature, scene_last, tgt_padding_mask_m, src_padding_mask_m)

        return coarse_action, joint_scene, fine_action


class CoarseDecoderLayer(nn.Module):
    """One layer of the coarse action decoder.

    Order of operations in ``forward``: self-attention over the queries,
    cross-attention to the vectorized features, cross-attention to the
    flattened occupancy grid, then a feed-forward block. Every sub-block adds
    its (dropped-out) result to its input and applies LayerNorm afterwards.

    The attention modules are built with the default sequence-first layout,
    so the batch-first inputs are transposed around every attention call.
    """

    def __init__(self, d_model=256, d_ffn=1024, dropout=0.1, activation="relu", n_heads=8):
        super().__init__()
        self.num_head = n_heads

        # queries -> vectorized features
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout1 = nn.Dropout(p=dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # queries -> occupancy tokens
        self.cross_attn_2 = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout5 = nn.Dropout(p=dropout)
        self.norm4 = nn.LayerNorm(d_model)

        # queries -> queries
        self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(p=dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # position-wise feed-forward block
        self.linear1 = nn.Linear(in_features=d_model, out_features=d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout3 = nn.Dropout(p=dropout)
        self.linear2 = nn.Linear(in_features=d_ffn, out_features=d_model)
        self.dropout4 = nn.Dropout(p=dropout)
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    @staticmethod
    def _attend(attn, query, memory, key_padding_mask):
        """Apply a sequence-first attention module to batch-first tensors.

        ``memory`` is transposed separately for key and value, so the module
        receives two distinct tensor objects for them.
        """
        attended, _ = attn(
            query.transpose(0, 1),
            memory.transpose(0, 1),
            memory.transpose(0, 1),
            key_padding_mask=key_padding_mask,
        )
        return attended.transpose(0, 1)

    def forward_ffn(self, tgt):
        """Feed-forward block with residual connection and LayerNorm."""
        hidden = self.linear1(tgt)
        hidden = self.activation(hidden)
        hidden = self.dropout3(hidden)
        update = self.dropout4(self.linear2(hidden))
        return self.norm3(tgt + update)

    def _forward_self_attn(self, tgt, tgt_padding_mask):
        """Self-attention over ``tgt`` with residual connection and LayerNorm."""
        attended = self._attend(self.self_attn, tgt, tgt, tgt_padding_mask)
        return self.norm2(tgt + self.dropout2(attended))

    def forward(self, tgt, src, src_occ, tgt_padding_mask=None, src_padding_mask=None,):
        """Refine the queries ``tgt`` against ``src`` and ``src_occ``.

        Args:
            tgt: Batch-first query tensor.
            src: Batch-first vectorized features used as keys/values.
            src_occ: Occupancy features; dims 1 and 2 are flattened into a
                single token axis before attention.
            tgt_padding_mask: Optional key-padding mask for the self-attention.
            src_padding_mask: Optional key-padding mask for ``src``.

        Returns:
            The refined queries, same layout as ``tgt``.
        """
        # self attention
        queries = self._forward_self_attn(tgt, tgt_padding_mask)

        # cross attention with vectorized feature
        from_vectors = self._attend(self.cross_attn, queries, src, src_padding_mask)
        queries = self.norm1(queries + self.dropout1(from_vectors))

        # cross attention with occ feature (no padding mask on the grid)
        occ_tokens = src_occ.flatten(start_dim=1, end_dim=2)
        from_occ = self._attend(self.cross_attn_2, queries, occ_tokens, None)
        queries = self.norm4(queries + self.dropout5(from_occ))

        # ffn
        return self.forward_ffn(queries)

class CoarseDecoder(nn.Module):
    """Sequential stack of ``num_layers`` deep copies of ``decoder_layer``.

    Each layer receives the previous layer's output as its queries; ``src``,
    ``src_occ`` and both masks are handed to every layer unchanged.
    """

    def __init__(self, decoder_layer, num_layers, return_intermediate=False):
        super().__init__()
        self.return_intermediate = return_intermediate
        self.num_layers = num_layers
        self.layers = _get_clones(decoder_layer, num_layers)

    def forward(self, tgt, src, src_occ, tgt_padding_mask = None, src_padding_mask = None):
        """Run all layers in order.

        Returns:
            The last layer's output, or, when ``return_intermediate`` is set,
            all layer outputs stacked along a new leading dimension.
        """
        collect = self.return_intermediate
        per_layer = []
        hidden = tgt
        for stage in self.layers:
            hidden = stage(hidden, src, src_occ, tgt_padding_mask, src_padding_mask)
            if collect:
                per_layer.append(hidden)

        if not collect:
            return hidden
        return torch.stack(per_layer)

class SceneDecoderLayer(nn.Module):
    """One layer of the scene decoder.

    The occupancy grid is flattened into ``H * W`` tokens, cross-attends to
    ``src`` (the coarse actions in ``IterativeDecoder``), then goes through
    self-attention and a feed-forward block. Every sub-block is residual and
    followed by LayerNorm. The grid shape is restored on return.

    The attention modules use the default sequence-first layout, so the
    batch-first tensors are transposed around every attention call.
    """

    def __init__(self, d_model=256, d_ffn=1024, dropout=0.1, activation="relu", n_heads=8):
        super().__init__()
        self.num_head = n_heads

        # cross attention
        self.cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, dropout=dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        """Feed-forward block with residual connection and LayerNorm."""
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def _forward_self_attn(self, tgt, tgt_padding_mask):
        """Self-attention over ``tgt`` with residual connection and LayerNorm."""
        tgt2 = self.self_attn(tgt.transpose(0, 1),
                              tgt.transpose(0, 1),
                              tgt.transpose(0, 1),
                              key_padding_mask=tgt_padding_mask)[0].transpose(0, 1)
        tgt = tgt + self.dropout2(tgt2)
        return self.norm2(tgt)

    def forward(self, tgt, src, src_padding_mask=None, tgt_padding_mask=None):
        """Refine the occupancy grid ``tgt`` against ``src``.

        Args:
            tgt: Occupancy features of shape ``(B, H, W, dim)``. Need not be
                contiguous.
            src: Batch-first tensor used as keys/values of the cross-attention.
            src_padding_mask: Optional key-padding mask for ``src``.
            tgt_padding_mask: Optional key-padding mask over the ``H * W``
                grid tokens, used by the self-attention.

        Returns:
            Tensor of shape ``(B, H, W, dim)``.
        """
        _, H, W, dim = tgt.shape
        # reshape instead of view: view raises for grids whose H and W axes
        # cannot be merged without a copy (e.g. cropped or transposed grids).
        tgt = tgt.reshape(-1, H * W, dim)

        # cross attention with coarse action
        tgt2 = self.cross_attn(tgt.transpose(0, 1),
                               src.transpose(0, 1),
                               src.transpose(0, 1),
                               key_padding_mask=src_padding_mask)[0].transpose(0, 1)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # self attention
        tgt = self._forward_self_attn(tgt, tgt_padding_mask)

        # ffn
        tgt = self.forward_ffn(tgt)

        return tgt.reshape(-1, H, W, dim)

class SceneDecoder(nn.Module):
    """Sequential stack of ``num_layers`` deep copies of ``decoder_layer``.

    The output of each layer becomes the ``tgt`` of the next one, while
    ``src`` and both padding masks are passed to every layer as given.
    """

    def __init__(self, decoder_layer, num_layers, return_intermediate=False):
        super().__init__()
        self.return_intermediate = return_intermediate
        self.num_layers = num_layers
        self.layers = _get_clones(decoder_layer, num_layers)

    def forward(self, tgt, src, src_padding_mask=None, tgt_padding_mask=None):
        """Run all layers in order.

        Returns:
            The last layer's output, or, when ``return_intermediate`` is set,
            all layer outputs stacked along a new leading dimension.
        """
        keep_all = self.return_intermediate
        history = []
        current = tgt
        for stage in self.layers:
            current = stage(current, src, src_padding_mask, tgt_padding_mask)
            if keep_all:
                history.append(current)

        return torch.stack(history) if keep_all else current

class FineDecoderLayer(nn.Module):
    """One layer of the fine action decoder.

    ``forward`` applies self-attention over the queries, cross-attention to
    the vectorized features, cross-attention to the flattened occupancy
    grid, and finally a feed-forward block. Each sub-block is residual with a
    LayerNorm applied after the addition.

    All attention modules use the default sequence-first layout, hence the
    transposes around every attention call on the batch-first tensors.
    """

    def __init__(self, d_model=256, d_ffn=1024, dropout=0.1, activation="relu", n_heads=8):
        super().__init__()
        self.num_head = n_heads

        # attention over the vectorized features
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout1 = nn.Dropout(p=dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # attention over the occupancy tokens
        self.cross_attn_2 = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout5 = nn.Dropout(p=dropout)
        self.norm4 = nn.LayerNorm(d_model)

        # attention among the queries themselves
        self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(p=dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # two-layer feed-forward block
        self.linear1 = nn.Linear(in_features=d_model, out_features=d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout3 = nn.Dropout(p=dropout)
        self.linear2 = nn.Linear(in_features=d_ffn, out_features=d_model)
        self.dropout4 = nn.Dropout(p=dropout)
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Add ``pos`` to ``tensor``; a ``None`` ``pos`` leaves it untouched."""
        if pos is None:
            return tensor
        return tensor + pos

    @staticmethod
    def _attend(attn, query, memory, key_padding_mask):
        """Call a sequence-first attention module with batch-first tensors.

        Key and value are produced by two separate transposes of ``memory``,
        so they reach the module as distinct tensor objects.
        """
        result = attn(
            query.transpose(0, 1),
            memory.transpose(0, 1),
            memory.transpose(0, 1),
            key_padding_mask=key_padding_mask,
        )
        return result[0].transpose(0, 1)

    def forward_ffn(self, tgt):
        """Feed-forward block followed by residual add and LayerNorm."""
        expanded = self.activation(self.linear1(tgt))
        projected = self.linear2(self.dropout3(expanded))
        return self.norm3(tgt + self.dropout4(projected))

    def _forward_self_attn(self, tgt, tgt_padding_mask):
        """Self-attention followed by residual add and LayerNorm."""
        context = self._attend(self.self_attn, tgt, tgt, tgt_padding_mask)
        return self.norm2(tgt + self.dropout2(context))

    def forward(self, tgt, src, src_occ, tgt_padding_mask=None, src_padding_mask=None):
        """Refine the queries ``tgt`` against ``src`` and ``src_occ``.

        Args:
            tgt: Batch-first query tensor.
            src: Batch-first vectorized features used as keys/values.
            src_occ: Occupancy features; dims 1 and 2 are flattened into one
                token axis before attention.
            tgt_padding_mask: Optional key-padding mask for the self-attention.
            src_padding_mask: Optional key-padding mask for ``src``.

        Returns:
            The refined queries, same layout as ``tgt``.
        """
        # self attention
        state = self._forward_self_attn(tgt, tgt_padding_mask)

        # cross attention with vector feature
        vector_context = self._attend(self.cross_attn, state, src, src_padding_mask)
        state = self.norm1(state + self.dropout1(vector_context))

        # cross attention with future occ feature (grid tokens are unmasked)
        grid_tokens = src_occ.flatten(start_dim=1, end_dim=2)
        grid_context = self._attend(self.cross_attn_2, state, grid_tokens, None)
        state = self.norm4(state + self.dropout5(grid_context))

        # ffn
        return self.forward_ffn(state)

class FineDecoder(nn.Module):
    """Sequential stack of ``num_layers`` deep copies of ``decoder_layer``.

    Queries flow from one layer to the next; ``src``, ``src_occ`` and the two
    padding masks are forwarded to every layer unchanged.
    """

    def __init__(self, decoder_layer, num_layers, return_intermediate=False):
        super().__init__()
        self.return_intermediate = return_intermediate
        self.num_layers = num_layers
        self.layers = _get_clones(decoder_layer, num_layers)

    def forward(self, tgt, src, src_occ, tgt_padding_mask = None, src_padding_mask = None):
        """Run all layers in order.

        Returns:
            The last layer's output, or, when ``return_intermediate`` is set,
            all layer outputs stacked along a new leading dimension.
        """
        want_all = self.return_intermediate
        outputs = []
        state = tgt
        for stage in self.layers:
            state = stage(state, src, src_occ, tgt_padding_mask, src_padding_mask)
            if want_all:
                outputs.append(state)

        if want_all:
            return torch.stack(outputs)
        return state
def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` deep copies of ``module``.

    Every copy owns its own parameters; the prototype itself is not included.
    """
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones


def _get_activation_fn(activation):
    """Return the activation callable selected by ``activation``.

    Args:
        activation: One of ``"relu"``, ``"gelu"`` or ``"glu"``.

    Returns:
        A new in-place ``nn.ReLU`` module for ``"relu"``; the functional
        ``F.gelu`` or ``F.glu`` for the other two names.

    Raises:
        RuntimeError: If ``activation`` is not one of the supported names.

    NOTE: ``F.glu`` halves the size of the dimension it splits, so a caller
    selecting ``"glu"`` must size the following layer accordingly.
    """
    if activation == "relu":
        # In-place variant, as before: overwrites its input tensor.
        return nn.ReLU(inplace=True)
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    # List every accepted name; the message previously omitted "glu".
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")



