import torch
import torch.nn.functional as F
from torch_linear_assignment import batch_linear_assignment
from einops import einsum, rearrange

from src.utils.misc import fold, unfold
from src.models.modules.gta import CrossAttention


def apply_sequential_permutation(assign, target, first_identity=False, inverse=False):
    '''
    Compose per-step permutations cumulatively and reorder the slots of `target` with them.

    The per-step permutation is the argsort of `assign` along its last axis.
    The cumulative permutation is built recursively as c_0 = p_0 and
    c_t[b, i] = c_{t-1}[b, p_t[b, i]], and the result is
    out[t, b, i, :] = target[t, b, c_t[b, i], :].

    Input:
        assign: (seq, bsize, num_slots) integer tensor of assignments
        target: (seq, bsize, num_slots, dim) tensor
        first_identity: if True, an identity permutation is prepended to the
            per-step permutations (so there is one more step than in `assign`)
        inverse: if True, every cumulative permutation is replaced by its
            argsort before it is applied
    Output:
        tensor gathered from `target` along the slot axis
    '''
    _, bsize, num_slots, dim = target.shape

    step_perms = torch.argsort(assign, dim=-1)

    if first_identity:
        # One identity row per batch element, placed in front of the sequence.
        eye_perm = torch.arange(num_slots, device=target.device).repeat(bsize, 1)
        step_perms = torch.cat([eye_perm[None], step_perms], dim=0)

    # Running composition: index the previous cumulative permutation with the
    # current step's permutation along the slot axis of a (bsize, num_slots) tensor.
    running = step_perms[0]
    composed = [running]
    for step in step_perms[1:]:
        running = torch.gather(running, 1, step)
        composed.append(running)
    composed = torch.stack(composed, dim=0)  # [seq, bsize, num_slots]

    if inverse:
        composed = torch.argsort(composed, dim=-1)

    # Broadcast the slot indices over the feature axis and reorder the slots.
    gather_index = composed.unsqueeze(-1).expand(-1, -1, -1, dim)
    return torch.gather(target, 2, gather_index)

# def apply_sequential_permutation(seq_perms, target, first_identity=False, inverse=False):
#     '''
#     Given a batch of sequential permutations, compute the cumulative permutations and apply them to the target tensor.
#     Input:
#         seq_perms: (seq, bsize, num_slots) tensor
#         target: (seq, bsize, num_slots, dim) tensor
#     '''
#     seq, bsize, num_slots, dim = target.shape
#     slot_axis = 2

#     # Add identity permutation for the first time step
#     if first_identity:
#         identity = torch.arange(num_slots, device=target.device).repeat(bsize, 1)
#         seq_perms = torch.cat([identity.unsqueeze(0), seq_perms], dim=0)

#     # Here we want to compute the cumulative permutation P'(t) = Prod_{i=0}^{t} P(i) by using torch.gather
#     cum_perms = [seq_perms[0]]
#     for perm in seq_perms[1:]:
#         cum_perms.append(cum_perms[-1].gather(slot_axis - 1, perm))
#     cum_perms = torch.stack(cum_perms, dim=0) # [seq, bsize, num_slots]
    
#     if inverse:
#         cum_perms = cum_perms.argsort(dim=-1)

#     # print('cumulative permutation:')
#     # print(cum_perms[:, 0])

#     # Apply the cumulative permutation to the target tensor
#     expanded_cum_perms = cum_perms.unsqueeze(-1).expand(-1, -1, -1, dim)
#     target = target.gather(slot_axis, expanded_cum_perms)

#     return target


def seq_batch_linear_assignment(cost_matrix):
    """
    Solve a linear assignment problem for every (seq, batch) entry.

    The input is passed through `fold` before calling
    `batch_linear_assignment`, and the result is passed through `unfold`
    with the batch size returned by `fold`.
    NOTE(review): presumably `fold` merges the seq and batch axes and
    `unfold` restores them -- confirm against src.utils.misc.

    Args:
      cost_matrix: tensor of shape [seq, batch, n, m]
    Returns:
      tensor of shape [seq, batch, n]
    """
    flat_costs, bsize = fold(cost_matrix)
    return unfold(batch_linear_assignment(flat_costs), b=bsize)
    

def permutation_matrices_to_indices(P):
    """
    Convert permutation matrices to index form.

    For every row, returns the column index of its largest entry
    (for a 0/1 permutation matrix this is the position of the 1).

    Args:
      P: tensor of shape [n, k, k]
    Returns:
      tensor of shape [n, k]
    """
    column_of_one = torch.argmax(P, dim=-1)
    return column_of_one


def create_permutation_matrices(p):
    """
    Build one permutation matrix per row of indices.

    Row i of the n-th output matrix is row p[n, i] of the k x k identity,
    i.e. out[n, i, p[n, i]] == 1 and every other entry is 0.

    Args:
      p: integer tensor of shape [n, k]
    Returns:
      tensor of shape [n, k, k], created on p's device
    """
    n, k = p.shape

    # Identity broadcast over the batch (a view, no copy per batch element).
    eye_batch = torch.eye(k, device=p.device).expand(n, k, k)

    # Each index is repeated across the column axis so that gather along the
    # row axis selects whole rows of the identity.
    row_index = p[:, :, None].expand(n, k, k)

    return torch.gather(eye_batch, 1, row_index)


def to_action_invariant(X, action_dim=2):
    """
    Collapse the last axis into per-group norms.

    The last axis is split into groups of `action_dim` consecutive entries
    and each group is replaced by its norm, so a trailing size of
    d * action_dim becomes d.
    NOTE(review): presumably each group is a sub-vector the action operates
    on, making its norm invariant to that action -- confirm with the model code.
    """
    grouped = rearrange(X, '... (d a) -> ... d a', a=action_dim)
    return torch.norm(grouped, dim=-1)


def solve_assignment(X, Y, normalize=True, action_invariant=True):
    """
    Match the rows of Y to the rows of X by solving a linear assignment problem.

    Args:
      X, Y: batched feature tensors; the last axis holds the features.
      normalize: if True, L2-normalize both inputs along the last axis first.
      action_invariant: if True, both inputs are first mapped through
        `to_action_invariant` (with its default `action_dim`).
    Returns:
      The output of `batch_linear_assignment` on the pairwise Euclidean
      distance matrix `cdist(Y, X)`; rows of the cost matrix correspond
      to Y and columns to X.
    """
    feats_x, feats_y = X, Y

    if action_invariant:
        feats_x, feats_y = to_action_invariant(X), to_action_invariant(Y)

    if normalize:
        feats_x = F.normalize(feats_x, p=2, dim=-1)
        feats_y = F.normalize(feats_y, p=2, dim=-1)

    # Pairwise Euclidean distances: rows index Y, columns index X.
    pairwise_cost = torch.cdist(feats_y, feats_x, p=2)

    return batch_linear_assignment(pairwise_cost)


def cumprod_matrix(U):
    """
    Cumulative matrix product along the first axis.

    U : tensor of dimension (d ... n n)
    Returns a tensor of the same dimension whose t-th entry is
    U[0] @ U[1] @ ... @ U[t] (later factors are multiplied on the right).
    """
    running = U[0]
    prefix_products = [running]
    for factor in U[1:]:
        running = running @ factor
        prefix_products.append(running)
    return torch.stack(prefix_products)


class HungarianLayer(torch.nn.Module):
    '''
    Reorders slot tokens over time using linear assignments between consecutive time steps.

    For every pair of consecutive time steps an assignment is computed with
    `solve_assignment`; an identity permutation is used for the first step.
    The assignments are turned into permutation matrices, multiplied
    cumulatively along the sequence axis, converted back to indices and used
    to gather the slots of each time step.
        Input: slots (seq, batch, slot, dim)
        Output: permuted_slots (seq, batch, slot, dim)

    NOTE(review): the assignment indices are used directly for the gather,
    whereas `apply_sequential_permutation` in this file argsorts the
    assignment first -- confirm which direction is intended here.
    '''
    def __init__(self, normalize=False, invariant=False):
        super().__init__()
        # Both flags are forwarded to solve_assignment in forward().
        self.normalize = normalize
        self.invariant = invariant

    def forward(self, slots):
        '''
        slots: torch.Tensor, shape [seq, batch, slot, dim]
        '''
        dim = slots.shape[-1]
        num_slots = slots.shape[-2]

        # Pair every time step with its successor and solve all matchings at once.
        prev_flat, next_flat, bsize = fold(slots[:-1], slots[1:])
        matches = solve_assignment(
            prev_flat,
            next_flat,
            normalize=self.normalize,
            action_invariant=self.invariant,
        )

        # The first time step keeps its slot order: prepend identity rows.
        first_step = torch.arange(num_slots, device=slots.device).repeat(bsize, 1)
        step_perms = torch.cat([first_step, matches], dim=0)

        # Per-step permutation matrices, with the leading axis unfolded again.
        step_mats = unfold(create_permutation_matrices(step_perms), b=bsize)

        # Cumulative product along the sequence axis, converted back to indices.
        cum_indices = permutation_matrices_to_indices(cumprod_matrix(step_mats))

        # Reorder the slots of each time step with its cumulative permutation.
        gather_index = cum_indices.unsqueeze(-1).expand(-1, -1, -1, dim)
        return slots.gather(-2, gather_index)


class MatchingAttention(CrossAttention):
    '''
    Merges/permutes slot tokens over time with the parent cross-attention.

    The first time step is kept unchanged. Every later time step is replaced
    by the output of `CrossAttention.forward(first_step_slots, step_slots)`.

    NOTE(review): the first argument is always the slots of the first time
    step, not the most recently updated step -- confirm that anchoring on
    step 0 (rather than attending between t and t+1) is intended.
    '''
    def forward(self, slots):
        '''
        slots: torch.Tensor, shape [seq, batch, slot, dim]
        '''
        anchor = slots[0]
        outputs = [anchor]

        # Explicit loop (no comprehension) so zero-argument super() stays valid.
        for step_slots in slots[1:]:
            matched = super().forward(anchor, step_slots)
            outputs.append(matched)

        return torch.stack(outputs)