import torch
from torch import nn
from torchtune.modules import RMSNorm

import einops as E
from einops.layers.torch import Rearrange

from src.models.modules.matching import seq_batch_linear_assignment, apply_sequential_permutation


class NoncausalParallel(nn.Module):
    '''
    Predict deltas for all slots in parallel with a 1-D convolution over time.

    The pipeline is: fold the (f, h, a) axes of each action component into one
    feature axis (`delta_prep`), convolve along the sequence axis
    (`delta_conv`), then project every component to `delta_dim` (`delta_proj`).
    The convolution is zero-padded on both sides, so each output step mixes
    past and future steps.
    '''

    # Temporal width of the convolution. With padding = kernel_size // 2 an
    # even width produces T + 1 conv outputs for T inputs; forward() then
    # trims one step from each end.
    kernel_size = 4

    def __init__(self, embed_dim, num_actions, num_slots, delta_dim, lnorm=False):
        '''
        Args:
            embed_dim (int): total slot width; split evenly across actions.
            num_actions (int): number of action components per slot.
            num_slots (int): number of slots per step.
            delta_dim (int): width of each predicted delta.
            lnorm (bool): apply LayerNorm to each component before the conv.
        '''
        super().__init__()

        # Plain (non-module) attributes, also read by subclasses.
        self.component_size = embed_dim // num_actions
        self.delta_dim = delta_dim
        self.num_actions = num_actions
        self.num_slots = num_slots

        channels = self.component_size

        normalizer = nn.LayerNorm(channels) if lnorm else nn.Identity()
        self.delta_prep = nn.Sequential(
            Rearrange('t ... f n h a -> t ... n (f h a)'),
            normalizer,
        )

        temporal_conv = nn.Conv1d(
            channels,
            channels,
            self.kernel_size,
            padding=self.kernel_size // 2,
        )
        self.delta_conv = nn.Sequential(
            # Conv1d wants (batch, channels, time): merge every non-time,
            # non-feature axis into the batch axis and move time last.
            Rearrange('t ... fha -> (...) fha t'),
            temporal_conv,
            Rearrange('(b s n) fha t -> t b s n fha', n=num_actions, s=num_slots),
        )

        self.delta_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, delta_dim),
        )

    def forward(self, slots, actions=None):
        '''
        Args:
            slots (torch.Tensor): expected layout (t, batch, slot, f, n, h, a),
                as implied by the rearrange patterns above.
            actions: accepted for interface compatibility; not used here.
        Output:
            slots (torch.Tensor): the input, returned unchanged.
            deltas (torch.Tensor): conv output with its first and last steps
                removed, i.e. len(slots) - 1 steps for an even kernel_size.
        '''
        prepared = self.delta_prep(slots)
        mixed = self.delta_conv(prepared)
        deltas = self.delta_proj(mixed)

        # Drop the two boundary steps of the conv output.
        trimmed = deltas[1:-1]
        return slots, trimmed


class CausalParallel(NoncausalParallel):
    '''
    Variant of NoncausalParallel whose temporal convolution only looks backwards.

    The parent constructor runs first (building its own `delta_conv`), after
    which `delta_conv` is replaced by a left-padded, unpadded-conv stack.
    '''
    kernel_size = 4

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        width = self.kernel_size
        channels = self.component_size

        # Padding `width` zeros on the left only (not width - 1) gives T + 1
        # outputs for T inputs, and output step i covers input steps
        # i - width .. i - 1. The inherited forward() then trims one step from
        # each end, exactly as it does for the parent's conv.
        left_pad = nn.ZeroPad1d((width, 0))
        causal_conv = nn.Conv1d(channels, channels, width)

        self.delta_conv = nn.Sequential(
            Rearrange('t ... fha -> (...) fha t'),
            left_pad,
            causal_conv,
            Rearrange('(b s n) fha t -> t b s n fha', n=self.num_actions, s=self.num_slots),
        )

def kinetic_energy(joint_deltas, latents):
    '''
    Pairwise kinetic-energy cost between candidate deltas δ and latents z:
    KE = 1/2 * mean_i( ||δ_i||^2 * ||z_i||^2 )
    where i runs over the action components and ||.|| is the L2 norm.
    Note that the action axis is averaged, not summed.

    Args:
        joint_deltas (torch.Tensor): (seq, batch, A, B, num_actions, dim).
            The latents' slot axis is broadcast against axis B (axis 3), so B
            must equal the number of slots (or be 1).
        latents (torch.Tensor): (seq, batch, slot, freq, num_actions, ...)
    Output:
        ke (torch.Tensor): (seq, batch, A, slot)
    '''
    # Put the action axis ahead of the features, fold all remaining feature
    # axes together, and insert a singleton axis that broadcasts over axis A.
    per_action = E.rearrange(latents, 'seq batch slot f na ... -> seq batch () slot na (f ...)')

    latent_sq = per_action.norm(p=2, dim=-1).pow(2)
    delta_sq = joint_deltas.norm(p=2, dim=-1).pow(2)

    energy_per_action = delta_sq * latent_sq
    return 0.5 * energy_per_action.mean(dim=-1)


class NoncausalJoint(NoncausalParallel):
    '''
    Joint slot matching and delta prediction.

    For every pair of consecutive steps, a delta is computed for each
    (current slot, next slot) combination. A cost built from those deltas is
    handed to `seq_batch_linear_assignment`, the resulting assignment is used
    to re-order the slots, and the deltas of the matched pairs are returned.
    '''

    # When True, forward() returns ONLY the aligned latents (a single tensor)
    # instead of the usual (latents, deltas) tuple.
    alignment_only = False

    def __init__(self, *args, **kwargs):
        '''Accepts the same arguments as NoncausalParallel.__init__.'''
        super().__init__(*args, **kwargs)

        # The parent already stored embed_dim // num_actions as
        # self.component_size. Reading it from there (rather than from
        # kwargs['embed_dim']) keeps the constructor working when embed_dim
        # is passed positionally; the old lookup raised KeyError in that case.
        component_size = self.component_size

        # Replaces the parent's temporal conv: maps the concatenated
        # (current, next) component pair back to a single component width.
        self.delta_conv = nn.Linear(2 * component_size, component_size)

    def _get_joint_deltas(self, latents):
        '''
        Compute a delta for every (current slot, next slot) pair.

        Args:
            latents (torch.Tensor): expected layout (seq, batch, slot, f, n, h, a),
                as implied by the `delta_prep` rearrange pattern.
        Output:
            joint_deltas (torch.Tensor):
                (seq - 1, batch, slot, slot_next, num_actions, delta_dim)
        '''
        latents = self.delta_prep(latents)  # [seq, batch, slot, num_actions, (f h a)]

        # Current-step features, copied along a new slot_next axis.
        deltas_current = E.repeat(latents[:-1],
                                  'seq batch slot ... -> seq batch slot slot_next ...',
                                  slot_next=self.num_slots)

        # Next-step features, copied along a new (current) slot axis.
        deltas_next = E.repeat(latents[1:],
                               'seq batch slot_next ... -> seq batch slot slot_next ...',
                               slot=self.num_slots)

        joint_deltas = torch.cat([deltas_current, deltas_next], dim=-1)
        joint_deltas = self.delta_conv(joint_deltas)
        joint_deltas = self.delta_proj(joint_deltas)
        return joint_deltas   # len(joint_deltas) = len(latents) - 1

    def align_slots(self, latents, assignments):
        '''
        Re-order the slot axis of `latents` with `apply_sequential_permutation`.

        The trailing axes (everything after seq, batch, slot) are flattened
        for the call and restored afterwards.
        '''
        dims = latents.shape[3:]
        latents = latents.flatten(3, -1)

        latents = apply_sequential_permutation(assignments, latents, first_identity=True, inverse=True)
        latents = latents.reshape(*latents.shape[:-1], *dims)
        return latents

    def forward(self, latents, actions=None):
        '''
        Args:
            latents (torch.Tensor): expected layout (seq, batch, slot, f, n, h, a).
            actions: accepted for interface compatibility; not used here.
        Output:
            latents (torch.Tensor): slot-aligned latents, same shape as the input.
            deltas (torch.Tensor): (seq - 1, batch, slot, num_actions, delta_dim).
            If `alignment_only` is set, only the aligned latents are returned.
        '''
        joint_deltas = self._get_joint_deltas(latents)  # [seq-1, batch, slot, slot_next, num_actions, dim]

        # NOTE(review): kinetic_energy broadcasts the latents' slot axis
        # against axis 3 of joint_deltas, which in this layout is slot_next,
        # while its docstring describes a (slot_next, slot) layout. Confirm
        # that weighting pair (slot, slot_next) by the current latent at index
        # slot_next is intended.
        cost_matrix = kinetic_energy(joint_deltas, latents[:-1])
        assignments = seq_batch_linear_assignment(cost_matrix)  # [seq-1, batch, slot]

        latents = self.align_slots(latents, assignments)
        if self.alignment_only:
            return latents

        # assignments[t, b, s] = k is taken to mean that
        # joint_deltas[t, b, s, k] is the delta of the matched pair.
        # Broadcast the index over the trailing (num_acts, dim) axes so that
        # it can drive a gather along the slot_next axis (dim 3).
        expanded_assignments = E.repeat(assignments,
                                        'seq batch slot -> seq batch slot () num_acts dim',
                                        num_acts=joint_deltas.shape[-2],
                                        dim=joint_deltas.shape[-1])

        # Pick the matched delta for every slot and drop the singleton axis.
        deltas = joint_deltas.gather(3, expanded_assignments).squeeze(3)

        # Apply the sequential permutation to the deltas. The helper is called
        # on a flat feature axis, so fold (nacts, dim) and unfold afterwards.
        deltas = E.rearrange(deltas, 'seq batch slot nacts dim -> seq batch slot (nacts dim)')
        deltas = apply_sequential_permutation(assignments[:-1], deltas, first_identity=True, inverse=True)
        deltas = E.rearrange(deltas, 'seq batch slot (nacts dim) -> seq batch slot nacts dim', dim=self.delta_dim)

        return latents, deltas


class Controller(nn.Module):
    '''
    Action-conditioned delta predictor.

    Slots and an embedding of the discrete action label are packed into one
    token sequence per step, passed through the supplied transformer, and the
    resulting slot tokens are projected to deltas.
    '''
    kernel_size = 4

    def __init__(self, embed_dim, num_actions, num_slots, delta_dim,
                 transformer, max_action_vocab_size=100):
        '''
        Args:
            embed_dim (int): width of the action embedding; split evenly
                across actions for the projection.
            num_actions (int): number of action components per slot.
            num_slots (int): not used by this constructor.
            delta_dim (int): width of each predicted delta.
            transformer (nn.Module): module applied to the packed tokens.
            max_action_vocab_size (int): size of the action embedding table.
        '''
        super().__init__()
        self.to_action_embedding = nn.Embedding(max_action_vocab_size, embed_dim)

        per_action_width = embed_dim // num_actions
        to_delta = nn.Linear(per_action_width, delta_dim)
        self.proj = nn.Sequential(
            Rearrange('... f n h a -> ... n (f h a)'),
            nn.SiLU(),
            to_delta,
        )

        self.transformer = transformer

    def forward(self, latents, actions):
        '''
        Args:
            latents (torch.Tensor): (seq, batch, slot, f, n, h, a)
            actions (torch.Tensor): discrete action labels; see `_forward`.
        Output:
            latents (torch.Tensor): the input, returned unchanged.
            deltas (torch.Tensor): projection of the transformer output,
                (seq, batch, slot, n, delta_dim).
        '''
        # Remember the four trailing axes so they can be restored below.
        f, n, h, a = latents.shape[-4:]

        # One flat feature vector per slot for the transformer.
        flat_slots = E.rearrange(latents, 'seq batch slot ... -> seq batch slot (...)')
        mixed = self._forward(flat_slots, actions)

        # Back to the structured layout expected by the projection head.
        mixed = E.rearrange(
            mixed,
            'seq batch slot (f n h a) -> seq batch slot f n h a',
            f=f, n=n, h=h, a=a,
        )

        return latents, self.proj(mixed)

    def _forward(self, slots, actions):
        '''
        Run the slots, together with the embedded actions, through the transformer.
        The discrete action labels are embedded and packed with the slots as
        one extra token per step; after the transformer, the action token is
        discarded and only the slot tokens are returned. The projection to
        deltas is done by the caller, not here.

        Args:
            slots (torch.Tensor): (seq_len, batch_size, num_slots, dim)
            actions (torch.Tensor): (seq_len, batch_size)
        Output:
            slots (torch.Tensor): slot tokens taken from the transformer output
        '''
        action_tokens = self.to_action_embedding(actions)

        # Pack along the token axis: the slots first, then the action token.
        tokens, packed_shapes = E.pack([slots, action_tokens], 't b * d')

        tokens = self.transformer(tokens)

        # Undo the packing and keep only the slot part.
        return E.unpack(tokens, packed_shapes, 't b * d')[0]
