import copy

import torch
from torch import nn
import einops as E

from torchtune.modules import (
    CausalSelfAttention,
    FeedForward,
    RMSNorm,
    RotaryPositionalEmbeddings,
)


def _get_clones(module: nn.Module, n: int) -> nn.ModuleList:
    """
    Return a list of ``n`` identical layers.

    Args:
        module (nn.Module): module to be cloned
        n (int): number of clones

    Returns:
        nn.ModuleList: list of ``n`` identical layers
    """
    # FIXME: copy.deepcopy() is not defined on nn.module
    return nn.ModuleList([copy.deepcopy(module) for i in range(n)])


def llama3_mlp(dim: int, hidden_dim: int) -> FeedForward:
    """
    Assemble the feed-forward (MLP) block used by the Llama model.

    All three projections are bias-free linear layers.

    Args:
        dim (int): input and output feature size of the block
        hidden_dim (int): feature size of the gate/up projections

    Returns:
        FeedForward: module wrapping the gate, down and up projections
    """

    def _linear(in_features: int, out_features: int) -> nn.Linear:
        return nn.Linear(in_features, out_features, bias=False)

    # Keyword arguments are evaluated left to right, so the layers are
    # created (and randomly initialised) in gate, down, up order.
    return FeedForward(
        gate_proj=_linear(dim, hidden_dim),
        down_proj=_linear(hidden_dim, dim),
        up_proj=_linear(dim, hidden_dim),
    )


class SelfAttention(nn.MultiheadAttention):
    """
    ``nn.MultiheadAttention`` specialised to self-attention.

    The single input is used as query, key and value, and only the attention
    output is returned (the attention weights are not exposed).
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input sequence, laid out as expected by the
                parent module (batch first when ``batch_first=True``)

        Returns:
            torch.Tensor: attention output with the same shape as ``x``
        """
        # The weights are discarded anyway, so ask the parent not to compute
        # them; this avoids the extra averaging work and lets PyTorch use its
        # optimized scaled-dot-product attention path.
        attn_out, _ = super().forward(x, x, x, need_weights=False)
        return attn_out


class SpatioTemporalTransformerLayer(nn.Module):
    """
    Transformer layer that attends over the time axis, then over the space
    axis, then applies an MLP. Each sub-layer is a pre-norm residual branch:
    ``x = sublayer(norm(x)) + x``.

    Args:
        time_attn (nn.Module): attention over time; called on a tensor of
            shape (batch_size * space, time, dim)
        space_attn (nn.Module): attention over space; called on a tensor of
            shape (batch_size * time, space, dim)
        mlp (nn.Module): feed-forward block; called on a tensor of shape
            (batch_size, time, space, dim)
        time_norm (nn.Module): normalization applied before ``time_attn``
        space_norm (nn.Module): normalization applied before ``space_attn``
        mlp_norm (nn.Module): normalization applied before ``mlp``
    """

    def __init__(self, time_attn, space_attn, mlp, time_norm, space_norm, mlp_norm):
        super().__init__()
        self.time_attn = time_attn
        self.space_attn = space_attn
        self.mlp = mlp

        self.time_norm = time_norm
        self.space_norm = space_norm
        self.mlp_norm = mlp_norm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): (batch_size, time, space, dim)

        Returns:
            torch.Tensor: (batch_size, time, space, dim)
        """
        time, space = x.size(1), x.size(2)

        # Fold space into the batch so attention runs along the time axis.
        x = E.rearrange(x, 'b t s d -> (b s) t d')
        x = self.time_attn(self.time_norm(x)) + x

        # Fold time into the batch so attention runs along the space axis.
        x = E.rearrange(x, '(b s) t d -> (b t) s d', s=space)
        x = self.space_attn(self.space_norm(x)) + x

        # Restore the input layout. The output must be (b, t, s, d), not
        # (b, s, t, d); otherwise stacked layers would swap the roles of
        # time and space on every other layer.
        x = E.rearrange(x, '(b t) s d -> b t s d', t=time)
        x = self.mlp(self.mlp_norm(x)) + x
        return x


class SpatioTemporalTransformer(nn.Module):
    """
    Stack of ``SpatioTemporalTransformerLayer`` modules followed by a final
    ``RMSNorm``.

    Time attention uses torchtune's ``CausalSelfAttention`` with rotary
    positional embeddings; space attention uses ``SelfAttention`` (a thin
    wrapper over ``nn.MultiheadAttention``); the MLP is built by
    ``llama3_mlp`` with a hidden size of ``4 * dim``.

    Args:
        dim (int): model (embedding) dimension
        num_heads (int): number of attention heads; the per-head size is
            ``dim // num_heads``
        num_layers (int): number of stacked layers
        max_seq_len (int): maximum sequence length passed to the rotary
            embeddings and the time attention. Default: 4096
        num_kv_heads (int, optional): number of key/value heads for the time
            attention. Falls back to ``num_heads`` when not given (or falsy).
        rope_base (int): base passed to the rotary embeddings. Default: 500000
    """

    def __init__(self, dim, num_heads, num_layers, max_seq_len=4096, num_kv_heads=None, rope_base=500000):
        super().__init__()
        head_dim = dim // num_heads
        num_kv_heads = num_kv_heads if num_kv_heads else num_heads
        hidden_dim = 4 * dim

        rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)

        time_attn = CausalSelfAttention(
            embed_dim=dim,
            num_heads=num_heads,
            # Must match the head count used to size k_proj / v_proj below;
            # passing num_heads here breaks any num_kv_heads != num_heads.
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            q_proj=nn.Linear(dim, num_heads * head_dim, bias=False),
            k_proj=nn.Linear(dim, num_kv_heads * head_dim, bias=False),
            v_proj=nn.Linear(dim, num_kv_heads * head_dim, bias=False),
            output_proj=nn.Linear(dim, dim, bias=False),
            pos_embeddings=rope,
            max_seq_len=max_seq_len,
            attn_dropout=0.0,
        )

        space_attn = SelfAttention(
            embed_dim=dim,
            num_heads=num_heads,
            dropout=0.0,
            bias=False,
            batch_first=True,
        )

        mlp = llama3_mlp(dim, hidden_dim=hidden_dim)

        norm = RMSNorm(dim)

        # The prototype layer gets three independent copies of the norm; the
        # original ``norm`` instance is kept for the final normalization.
        layer = SpatioTemporalTransformerLayer(time_attn, space_attn, mlp, *_get_clones(norm, 3))
        # Every layer is a deep copy of the prototype, so layers do not share
        # parameters but start from identical initial values.
        self.layers = _get_clones(layer, num_layers)
        self.norm = norm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): (seq_len, batch_size, num_slots, dim)

        Returns:
            torch.Tensor: normalized output of the layer stack, with the
                first two dimensions transposed back to time-first order
        """
        # Make x batch first
        x = x.transpose(0, 1)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)

        # Make x time first
        x = x.transpose(0, 1)
        return x