from math import prod
from einops import rearrange, repeat, reduce, unpack
from einops.layers.torch import Rearrange

import torch
import torch.nn.functional as F
from torch import nn
from src.models.modules.gta import Attention, CrossAttention
from src.models.modules.rmsnorm import RMSNorm
# from src.models.modules.posemb import PositionalEncoding, RotaryPositionalEncoding
from src.utils.misc import default, pair, fold, unfold


# def pack(tensors: list):
#     return einops.pack(tensors, 'b * d')


# def unpack(tensor: torch.Tensor, ps: list):
#     return einops.unpack(tensor, ps, 'b * d')


class FeedForward(nn.Module):
    """Pre-normalised two-layer MLP: RMSNorm -> Linear -> GELU -> Linear.

    Features go from ``dim`` to ``hidden_dim`` and back to ``dim``. No residual
    connection is added here; callers add it themselves.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        # Built in execution order so the Sequential indices (0..3) stay stable.
        stages = [
            RMSNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the MLP and return the result."""
        out = self.net(x)
        return out


# class GTASlotEncoder(nn.Module):
#     def __init__(self, dim, depth, heads, mlp_dim, 
#                  resolutions, **_
#                  ):
#         super().__init__()
#         attn_args = dict(dim=dim, num_heads=heads, resolutions=resolutions)
#         self.layers = nn.ModuleList([])
#         for _ in range(depth):
#             self.layers.append(nn.ModuleList([
#                 Attention(
#                     gta='qkv', 
#                     qk_norm=True,
#                     **attn_args
#                 ),
#                 CrossAttention(
#                     gta='kv', 
#                     qk_norm='q', 
#                     inverted=True, 
#                     **attn_args
#                 ),
#                 FeedForward(dim, mlp_dim)
#             ]))
    
#     def forward(self, patch, slot):
#         patch, slot, bsize = fold(patch, slot)
#         for patch_attn, slot_attn, ff in self.layers:

#             ### spatial patch-patch attention w/ GTA
#             patch = patch_attn(patch) + patch
#             patch = ff(patch) + patch

#             slot = slot_attn(slot, patch) + slot
#             slot = ff(slot) + slot

#         slot = unfold(slot, b=bsize)
#         return slot


class GTAPatchEncoder(nn.Module):
    """Stack of residual GTA self-attention + feed-forward layers over patch tokens.

    Only the patch tokens are processed and returned; extra keyword arguments
    passed to the constructor are ignored.
    """

    def __init__(self, dim, depth, heads, mlp_dim, 
                 resolutions, **_
                 ):
        super().__init__()
        shared = {'dim': dim, 'num_heads': heads, 'qk_norm': True, 'resolutions': resolutions}
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(gta='qkv', **shared),
                FeedForward(dim, mlp_dim),
            ])
            for _ in range(depth)
        ])

    def forward(self, patch, slot):
        # ``slot`` is accepted but never read by this encoder.
        # NOTE(review): fold/unfold presumably merge/split the leading (seq, batch)
        # axes -- confirm against src.utils.misc.
        tokens, bsize = fold(patch)
        for self_attn, mlp in self.layers:
            # residual spatial patch-patch self-attention
            tokens = self_attn(tokens) + tokens
            # residual feed-forward
            tokens = mlp(tokens) + tokens
        return unfold(tokens, b=bsize)


class GTAInvertedEncoder(nn.Module):
    """Slot encoder whose layers cross-attend slots to patches, then apply a feed-forward.

    Patch tokens are only read (as the second argument of the cross-attention);
    they are never updated. Extra constructor keyword arguments are ignored.
    """

    def __init__(self, dim, depth, heads, mlp_dim, 
                 resolutions, **_
                 ):
        super().__init__()
        shared = {'dim': dim, 'num_heads': heads, 'resolutions': resolutions}
        self.layers = nn.ModuleList([
            nn.ModuleList([
                CrossAttention(gta='kv', qk_norm='q', inverted=True, **shared),
                FeedForward(dim, mlp_dim),
            ])
            for _ in range(depth)
        ])

    def forward(self, patch, slot):
        # NOTE(review): fold/unfold presumably merge/split the leading (seq, batch)
        # axes -- confirm against src.utils.misc.
        patch, slots, bsize = fold(patch, slot)
        for cross_attn, mlp in self.layers:
            # residual slot-to-patch cross-attention
            slots = cross_attn(slots, patch) + slots
            # residual feed-forward on the slots
            slots = mlp(slots) + slots
        return unfold(slots, b=bsize)


class GTAInvertedMemoryEncoder(nn.Module):
    """Slot encoder with a temporal memory attention before each slot-patch cross-attention.

    Each layer applies, with residual connections: MemoryAttention over the
    slots, cross-attention from slots to patches, then a feed-forward. Patch
    tokens are never updated. Extra constructor keyword arguments are ignored.
    """

    def __init__(self, dim, depth, heads, mlp_dim, 
                 resolutions, **_
                 ):
        super().__init__()
        shared = {'dim': dim, 'num_heads': heads, 'resolutions': resolutions}
        self.layers = nn.ModuleList([
            nn.ModuleList([
                MemoryAttention(gta='', qk_norm=True, **shared),
                CrossAttention(gta='kv', qk_norm='q', inverted=True, **shared),
                FeedForward(dim, mlp_dim),
            ])
            for _ in range(depth)
        ])

    def forward(self, patch, slot):
        # NOTE(review): fold/unfold presumably merge/split the leading (seq, batch)
        # axes -- confirm against src.utils.misc.
        patch, slots, bsize = fold(patch, slot)
        for temporal_attn, cross_attn, mlp in self.layers:
            # residual temporal attention over the slot history
            slots = temporal_attn(slots, bsize) + slots
            # residual slot-to-patch cross-attention
            slots = cross_attn(slots, patch) + slots
            # residual feed-forward on the slots
            slots = mlp(slots) + slots
        return unfold(slots, b=bsize)


class GTATransformerEncoder(nn.Module):
    """Joint patch/slot encoder.

    Every layer runs patch self-attention, slot-to-patch cross-attention, and a
    single feed-forward module that is applied to both token sets. Only the
    slots are returned. ``qk_norm`` and ``inverted`` are passed straight to the
    attention modules; ``v_transform`` picks gta='qkv' (True) or gta='qk'
    (False) for the patch self-attention.
    """

    def __init__(self, dim, depth, heads, mlp_dim, 
                 qk_norm: bool,
                 resolutions,
                 v_transform: bool,
                 inverted: bool = False,
                 **_
                 ):
        super().__init__()
        if v_transform:
            self_attn_gta = 'qkv'
        else:
            self_attn_gta = 'qk'
        shared = {'dim': dim, 'num_heads': heads, 'qk_norm': qk_norm, 'resolutions': resolutions}
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(gta=self_attn_gta, **shared),
                CrossAttention(gta='kv', inverted=inverted, **shared),
                FeedForward(dim, mlp_dim),
            ])
            for _ in range(depth)
        ])

    def forward(self, patch, slot):
        # NOTE(review): fold/unfold presumably merge/split the leading (seq, batch)
        # axes -- confirm against src.utils.misc.
        patches, slots, bsize = fold(patch, slot)
        for self_attn, cross_attn, mlp in self.layers:
            # residual spatial patch-patch attention w/ GTA
            patches = self_attn(patches) + patches

            # residual slot-to-patch cross-attention (sees the updated patches)
            slots = cross_attn(slots, patches) + slots

            # the same feed-forward module serves both token sets
            patches = mlp(patches) + patches
            slots = mlp(slots) + slots
        return unfold(slots, b=bsize)



class IdentityCA(nn.Module):
    """Cross-attention placeholder that contributes nothing to a residual sum."""

    def forward(self, x, y):
        # Both inputs are ignored; returning the integer 0 leaves
        # ``module(x, y) + residual`` equal to the residual.
        del x, y
        return 0

class MemoryAttention(CrossAttention):
    """Cross-attention from each step's slots to a sliding window of recent slots.

    ``history_size`` is the number of time steps (current one included) that
    make up the memory each step attends to. All other constructor arguments
    go to ``CrossAttention`` unchanged.
    """

    def __init__(self, *args, history_size=4, **kwargs):
        super().__init__(*args, **kwargs)
        self.history_size = history_size

    def forward(self, slot, batch_size):
        """Attend ``slot`` (laid out as ``(seq*batch, slot, dim)``) to its own history."""
        window = self.history_size
        slots_per_step = slot.shape[1]

        # Move time to the last axis: (seq*b, slot, dim) -> (b, dim, slot, seq).
        timeline = rearrange(slot, '(seq b) slot dim -> b dim slot seq', b=batch_size)

        # Left-pad time with window-1 zero steps so every step has a full window.
        left_pad = (window - 1, 0)
        timeline = F.pad(timeline, left_pad)

        # One (all slots x window) block per step; unfold flattens each block
        # into the channel axis as (dim, slot, time).
        windows = F.unfold(timeline, kernel_size=(slots_per_step, window))

        # Back to token layout: slots_per_step * window memory tokens per step.
        memory = rearrange(windows, 'b (dim slot) seq -> (seq b) slot dim',
                           slot=slots_per_step * window)
        return super().forward(slot, memory)


class GTATransformerEncoder5(nn.Module):
    """Joint patch/slot encoder with temporal slot memory.

    Every layer applies, with residual connections:

    1. memory attention: slots attend to a sliding window of the last
       ``history_size`` steps of slots (skipped in the first layer, which
       holds an ``IdentityCA`` placeholder);
    2. patch self-attention with GTA;
    3. cross-attention from slots to patches;
    4. one feed-forward module applied to both patches and slots.

    Only the slots are returned. ``qk_norm`` and ``inverted`` are passed to the
    attention modules; ``v_transform`` picks gta='qkv' (True) or gta='qk'
    (False) for the patch self-attention. Extra keyword arguments are ignored.
    """

    def __init__(self, dim, depth, heads, mlp_dim, 
                 qk_norm: bool,
                 resolutions,
                 v_transform: bool,
                 inverted: bool = False,
                 history_size: int = 4,
                 **_
                 ):
        super().__init__()
        self.layers = nn.ModuleList([])
        attn_args = dict(dim=dim, num_heads=heads, qk_norm=qk_norm)
        attn_type = 'qkv' if v_transform else 'qk'
        for layer in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(gta=attn_type, resolutions=resolutions, **attn_args),
                # The first layer has no memory attention: IdentityCA returns 0.
                CrossAttention(gta='', **attn_args) if layer > 0 else IdentityCA(),
                CrossAttention(gta='', inverted=inverted, **attn_args),
                FeedForward(dim, mlp_dim)
            ]))
        self.history_size = history_size

    def _build_memory(self, slot, batch_size, num_slots):
        """Return, for every step, the slots of the last ``history_size`` steps.

        ``slot`` is ``(seq*batch, slot, dim)``; the result is
        ``(seq*batch, num_slots * history_size, dim)``. Steps before the start
        of the sequence are filled with zeros.
        """
        # Time goes to the last axis so it can be padded and windowed.
        memory = rearrange(slot, '(seq b) slot dim -> b dim slot seq', b=batch_size)
        # Left-pad time so the window ending at step t always has history_size entries.
        memory = F.pad(memory, (self.history_size - 1, 0))
        # One (all slots x history_size) block per step, flattened as (dim, slot, time).
        memory = F.unfold(memory, kernel_size=(num_slots, self.history_size))
        return rearrange(memory, 'b (dim slot) seq -> (seq b) slot dim',
                         slot=num_slots * self.history_size)

    def forward(self, patch, slot):
        num_slots = slot.shape[2]
        # NOTE(review): fold/unfold presumably merge/split the leading (seq, batch)
        # axes -- confirm against src.utils.misc.
        patch, slot, batch_size = fold(patch, slot)
        for patch_attn, memory_attn, slot_attn, ff in self.layers:

            ### temporal slot-memory attention
            # IdentityCA ignores its inputs and adds 0, so building the memory
            # window for it would be wasted work.
            if not isinstance(memory_attn, IdentityCA):
                memory = self._build_memory(slot, batch_size, num_slots)
                slot = memory_attn(slot, memory) + slot

            ### spatial patch-patch attention w/ GTA
            patch = patch_attn(patch) + patch

            ### spatial slot-patch attention
            slot = slot_attn(slot, patch) + slot

            # the same feed-forward module serves both token sets
            patch = ff(patch) + patch
            slot = ff(slot) + slot

        slot = unfold(slot, b=batch_size)
        return slot


# class GTATransformerEncoderTC(nn.Module):
#     def __init__(self, dim, depth, heads, mlp_dim, 
#                  qk_norm: bool,
#                  resolutions,
#                  v_transform: bool,
#                  inverted: bool = False,
#                  TC_layers: int = 6,
#                  # projector: nn.Module = nn.Identity(),
#                  ):
#         super().__init__()
#         attn_args = dict(dim=dim, num_heads=heads, qk_norm=qk_norm)
#         attn_type = 'qkv' if v_transform else 'qk'
#         layers = nn.ModuleList([])
#         for _ in range(depth):
#             layers.append(nn.ModuleList([
#                 Attention(gta=attn_type, resolutions=resolutions, **attn_args),
#                 CrossAttention(inverted=inverted, **attn_args),
#                 FeedForward(dim, mlp_dim)
#             ]))
#         self.init_layers = layers[:TC_layers]
#         self.post_layers = layers[TC_layers:]

#         self.slot_norm = RMSNorm(dim)
#         # self.projector = FeedForward(dim, mlp_dim)
#         self.projector = nn.TransformerEncoder(
#             nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=mlp_dim, 
#                                        activation=nn.SiLU(),
#                                        batch_first=True, norm_first=True, dropout=0.0),
#             num_layers=2,
#             norm=RMSNorm(dim),
#         )

#     def _forward(self, patches, slots, layers):
#         for patch_attn, slot_attn, ff in layers:
#             ### spatial patch-patch attention w/ GTA
#             patches = patch_attn(patches) + patches

#             ### spatial slot attention
#             slots = slot_attn(slots, patches) + slots

#             patches = ff(patches) + patches
#             slots = ff(slots) + slots
#         return patches, slots
    
#     def forward(self, patches, slots, batch_size):
#         seq_len = patches.shape[0] // batch_size

#         ### initialize slot tokens to keep time consistency
#         init_slot = slots[ : batch_size]
#         pe = slots[ : batch_size]
#         updated_slots = []
#         updated_patches = []
#         for t in range(seq_len):
#             patch_t = patches[t * batch_size : (t + 1) * batch_size]
#             updated_patch, updated_slot = self._forward(patch_t, init_slot, self.init_layers)
#             updated_patches.append(updated_patch)
#             updated_slots.append(updated_slot)

#             # init_slot = self.projector(updated_slot) + pe
#             init_slot = self.projector(updated_slot)
        
#         patches = torch.cat(updated_patches, dim=0)
#         slots = torch.cat(updated_slots, dim=0)

#         ### apply post layers that are time independent
#         patches, slots = self._forward(patches, slots, self.post_layers)

#         slots = self.slot_norm(slots)
#         return patches, slots    




# class GTATransformerDecoder4(nn.Module):
#     # initializer_range = 0.02
#     initializer_range = 0.1

#     def __init__(self, dim, depth, heads, mlp_dim, 
#                  qk_norm: bool,
#                  resolutions,
#                  v_transform: bool,
#                  inverted: bool = False,
#                  ):
#         super().__init__()
#         self.layers = nn.ModuleList([])
#         attn_args = dict(dim=dim, num_heads=heads, qk_norm=qk_norm, resolutions=resolutions)
#         attn_type = 'qkv' if v_transform else 'qk'
#         for _ in range(depth):
#             self.layers.append(nn.ModuleList([
#                 CrossAttention(gta='q', **attn_args, inverted=inverted),
#                 Attention(gta=attn_type, **attn_args),
#                 FeedForward(dim, mlp_dim)
#             ]))
        
#         self.patch_embed = nn.Parameter(torch.zeros(dim))
#         torch.nn.init.normal_(self.patch_embed, std=self.initializer_range)
    
#     def forward(self, _patch, slot):
#         patch = repeat(self.patch_embed, 'd -> sb n d', sb=_patch.shape[0], n=_patch.shape[1])
#         for cross_attn, patch_attn, ff in self.layers:

#             patch = cross_attn(patch, slot) + patch
            
#             patch = patch_attn(patch) + patch

#             patch = ff(patch) + patch
#             slot = ff(slot) + slot
#         return patch, slot


class GTATransformerDecoderIndep(nn.Module):
    """Decode every slot independently into a grid of patch tokens.

    Each slot vector is copied to ``prod(resolutions)`` patch positions and the
    copies are refined by residual GTA self-attention + feed-forward layers.
    Slots are folded into the batch axis, so attention never mixes slots.
    ``inverted`` and extra keyword arguments are accepted but unused.
    """

    def __init__(self, dim, depth, heads, mlp_dim, 
                 qk_norm: bool,
                 resolutions,
                 v_transform: bool,
                 inverted: bool = False,
                 **kwargs
                 ):
        super().__init__()
        if v_transform:
            self_attn_gta = 'qkv'
        else:
            self_attn_gta = 'qk'
        shared = {'dim': dim, 'num_heads': heads, 'qk_norm': qk_norm, 'resolutions': resolutions}
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(gta=self_attn_gta, **shared),
                FeedForward(dim, mlp_dim),
            ])
            for _ in range(depth)
        ])

        # number of patch tokens generated per slot
        self.num_patches = prod(resolutions)

    def forward(self, slot):
        slots_per_step = slot.shape[2]
        # NOTE(review): fold/unfold presumably merge/split the leading (seq, batch)
        # axes -- confirm against src.utils.misc.
        folded, bsize = fold(slot)

        # Every slot becomes its own batch entry holding num_patches copies of itself.
        tokens = repeat(folded, 'tb k d -> (tb k) n d', n=self.num_patches)
        for self_attn, mlp in self.layers:
            tokens = self_attn(tokens) + tokens
            tokens = mlp(tokens) + tokens

        # Restore the slot axis before splitting seq and batch again.
        tokens = rearrange(tokens, '(tb k) n d -> tb k n d', k=slots_per_step)
        return unfold(tokens, b=bsize)


class GTATransformerPatchDecoder(nn.Module):
    """Refine patch tokens with residual GTA self-attention + feed-forward layers.

    The output gets a singleton axis inserted before the patch axis.
    ``v_transform``, ``inverted`` and extra keyword arguments are accepted but
    unused; the self-attention always uses gta='qkv'.
    """

    def __init__(self, dim, depth, heads, mlp_dim, 
                 qk_norm: bool,
                 resolutions,
                 v_transform: bool,
                 inverted: bool = False,
                 **kwargs
                 ):
        super().__init__()
        shared = {'dim': dim, 'num_heads': heads, 'qk_norm': qk_norm, 'resolutions': resolutions}
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(gta='qkv', **shared),
                FeedForward(dim, mlp_dim),
            ])
            for _ in range(depth)
        ])

        # total number of patch positions on the grid
        self.num_patches = prod(resolutions)

    def forward(self, patch):
        # NOTE(review): fold/unfold presumably merge/split the leading (seq, batch)
        # axes -- confirm against src.utils.misc.
        tokens, bsize = fold(patch)

        for self_attn, mlp in self.layers:
            tokens = self_attn(tokens) + tokens
            tokens = mlp(tokens) + tokens

        # Insert a singleton axis in front of the patch axis.
        tokens = rearrange(tokens, 'tb n d -> tb () n d')
        return unfold(tokens, b=bsize)



class GTAViT(nn.Module):
    """Common base for the GTA ViT encoder/decoder.

    Validates the image/patch geometry, stores the derived sizes and builds
    ``self.transformer`` from ``TransformerClass``.

    Attributes set here:
        patch_dim: flattened size of one patch (channels * patch_h * patch_w).
        dim: token embedding size (``embed_dim``).
        patch_height, patch_width: patch size in pixels.
        h, w: number of patches along the image height / width.
        transformer: instance of ``TransformerClass``.
    """
    # Transformer stack class instantiated in __init__; subclasses must set it.
    TransformerClass: nn.Module = None
    
    def __init__(self, img_size, patch_size, embed_dim, depth, num_heads, mlp_ratio, img_channels, 
                 **kwargs):
        """Build the transformer for the given geometry.

        Args:
            img_size: image size, converted to (height, width) by ``pair``.
            patch_size: patch size, converted to (height, width) by ``pair``.
            embed_dim: token embedding size.
            depth: number of transformer layers.
            num_heads: number of attention heads.
            mlp_ratio: feed-forward hidden size as a multiple of ``embed_dim``;
                may be an int or a float (the product is truncated to int).
            img_channels: number of image channels.
            **kwargs: forwarded to ``TransformerClass``.

        Raises:
            AssertionError: if the image size is not divisible by the patch size.
            TypeError: if ``TransformerClass`` has not been set.
        """
        super().__init__()
        image_height, image_width = pair(img_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0,\
            'Image dimensions must be divisible by the patch size.'

        self.patch_dim = img_channels * patch_height * patch_width
        self.dim, self.patch_height, self.patch_width = embed_dim, patch_height, patch_width
        self.h = image_height // patch_height
        self.w = image_width // patch_width

        if self.TransformerClass is None:
            # Same exception type as calling None, but with an actionable message.
            raise TypeError(
                f'{type(self).__name__}.TransformerClass is not set; '
                'use a subclass that defines it.'
            )

        # A float mlp_ratio (e.g. 4.0) would make the hidden size a float;
        # layer sizes must be integers.
        mlp_dim = int(embed_dim * mlp_ratio)
        self.transformer = self.TransformerClass(
            embed_dim, depth, num_heads, mlp_dim,
            resolutions=[self.h, self.w], **kwargs
        )


class GTAViTEncoder(GTAViT):
    """Encode a sequence of images into slot tokens.

    Images are split into patches and embedded; a learned set of slot
    embeddings is repeated over time, passed through ``temporal_pe``, repeated
    over the batch, and refined against the patches by the transformer.
    """
    # Std of the normal distribution used to initialise the slot embeddings.
    initializer_range = 0.1
    # Default transformer stack; ``arch`` in __init__ selects another one by name.
    TransformerClass = GTATransformerEncoder5
    
    def __init__(self, num_slots, z_norm = False, temporal_pe = None, 
                 arch = None, **kwargs):
        """
        Args:
            num_slots: number of learned slot embeddings.
            z_norm: if true, apply RMSNorm to the output slots.
            temporal_pe: module applied to the ``(seq, slot, dim)`` slot tensor
                before it is repeated over the batch; ``None`` means identity.
            arch: optional name of a transformer class defined in this module;
                overrides ``TransformerClass`` for this instance.
            **kwargs: forwarded to ``GTAViT.__init__``.
        """
        if arch:
            # Must be set before GTAViT.__init__, which instantiates the transformer.
            self.TransformerClass = globals()[arch]
        super().__init__(**kwargs)

        self.to_patch_embedding = nn.Sequential(
            Rearrange("... c (h p1) (w p2) -> ... (h w) (p1 p2 c)", 
                      p1 = self.patch_height, p2 = self.patch_width),
            nn.LayerNorm(self.patch_dim),
            nn.Linear(self.patch_dim, self.dim),
            nn.LayerNorm(self.dim),
        )

        self.slot_embed = nn.Parameter(torch.zeros(num_slots, self.dim))
        torch.nn.init.normal_(self.slot_embed, std=self.initializer_range)
        # A module instance in the signature would be created once at import
        # and shared by every encoder; build a fresh identity per instance.
        self.temporal_pe = nn.Identity() if temporal_pe is None else temporal_pe
        self.norm = RMSNorm(self.dim) if z_norm else nn.Identity()

    def forward(self, images):
        """Return slot tokens for ``images``.

        The first two axes of ``images`` are read as (seq, batch); the last
        three are (channels, height, width).
        """
        patches = self.to_patch_embedding(images)
        seq_len, batch_size = patches.shape[0], patches.shape[1]

        ### apply time emb to slot tokens
        slots = repeat(self.slot_embed, 'n d -> s n d', s = seq_len)
        slots = self.temporal_pe(slots)
        slots = repeat(slots, 's n d -> s b n d', b = batch_size)

        slots = self.transformer(patches, slots)

        slots = self.norm(slots)
        return slots
    
    
class GTAViTDecoder(GTAViT):
    """Decode slot tokens into per-slot images and masks.

    The transformer turns slots into patch tokens, and a ``MaskedComposer``
    turns those into a ``SlotImage``.
    """
    TransformerClass = GTATransformerDecoderIndep

    def __init__(self, arch=None, **kwargs):
        if arch:
            # Pick the transformer class by name; must happen before
            # GTAViT.__init__, which instantiates it.
            self.TransformerClass = globals()[arch]
        super().__init__(**kwargs)
        # The composer names the embedding size ``dim``; the remaining settings
        # are passed through unchanged.
        composer_kwargs = {**kwargs, 'dim': kwargs['embed_dim']}
        self.to_pixels = MaskedComposer(**composer_kwargs)

    def forward(self, slots):
        return self.to_pixels(self.transformer(slots))


class SlotImage():
    """Per-slot RGB images and masks, with optional random slot dropping.

    ``rgb`` and ``mask`` are expected to be laid out as
    ``(seq, batch, slot, channel, height, width)``, the layout produced by
    ``MaskedComposer``; ``mask`` has a single channel. The slot axis is
    addressed from the end via ``slot_dim``.

    A random slot order is drawn lazily, once per instance, and shared by
    ``shuffled_rgb`` and ``shuffled_mask`` so the two stay aligned.
    """
    # Index of the slot axis, counted from the end: (..., slot, c, h, w).
    slot_dim = -4
    # Cached per-(seq, batch) slot permutations; filled by set_permutations().
    permutations = None

    def __init__(self, rgb, mask) -> None:
        self.rgb = rgb
        self.mask = mask

    @property
    def num_slots(self):
        """Number of slots (size of the slot axis of ``rgb``)."""
        return self.rgb.shape[self.slot_dim]
    
    def set_permutations(self):
        """Draw one random slot permutation per leading index, if not drawn yet."""
        if self.permutations is not None:
            return
        # Leading axes up to and including the slot axis, e.g. (seq, batch, slot).
        slot_axis = self.rgb.dim() + self.slot_dim
        lead_shape = self.rgb.shape[:slot_axis + 1]

        # Arg-sorting i.i.d. noise along the slot axis gives an independent
        # random permutation for every leading index without a Python loop.
        noise = torch.rand(*lead_shape, device=self.rgb.device)
        self.permutations = noise.argsort(dim=-1)

    def _shuffle(self, tensor):
        """Reorder ``tensor`` along the slot axis with the cached permutations."""
        self.set_permutations()
        # gather() needs an index with the same rank as its input: append
        # singleton axes for (c, h, w) and broadcast them to the tensor shape.
        trailing = tensor.dim() - self.permutations.dim()
        index = self.permutations.reshape(*self.permutations.shape, *([1] * trailing))
        return tensor.gather(self.slot_dim, index.expand_as(tensor))

    @property
    def shuffled_rgb(self):
        """``rgb`` with slots reordered by the cached permutations."""
        return self._shuffle(self.rgb)

    @property
    def shuffled_mask(self):
        """``mask`` with slots reordered by the cached permutations."""
        return self._shuffle(self.mask)

    def dropped_mask(self, num_drops: int):
        """Sum of the masks of the first ``num_drops`` shuffled slots."""
        return self.shuffled_mask[:, :, :num_drops].sum(dim=self.slot_dim)

    def compose(self, keepslot: bool = False, num_drops: int = 0):
        '''
        Blend the slot images with their masks.

        Args:
            keepslot: bool
                keep the slot axis (True) or sum over it (False)
            num_drops: int
                number of slots to drop; when positive, slots are shuffled
                first so the dropped ones are random

        Returns:
            The masked images, with or without the slot axis.
        '''
        if num_drops > 0:
            image = self.shuffled_rgb * self.shuffled_mask
        else:
            image = self.rgb * self.mask
        image = image[:, :, num_drops:]
        if keepslot:
            return image
        # Sum the remaining slots into a single image per (seq, batch).
        return image.sum(dim=self.slot_dim)


class MaskedComposer(nn.Module):
    """
    Turn per-slot patch tokens into per-slot images plus normalised masks.

    Each token is projected to the pixels of one patch with ``img_channels``
    colour channels and one extra mask channel, and the patches are stitched
    back into full images. Extra keyword arguments are ignored.
    """

    def __init__(self, img_size, patch_size, dim, img_channels, softmax=True, **kwargs):
        super().__init__()
        image_height, image_width = pair(img_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, \
            'Image dimensions must be divisible by the patch size.'

        # colour channels + one mask channel
        rgba_channels = img_channels + 1
        pixels_per_token = patch_height * patch_width * rgba_channels
        grid_rows = image_height // patch_height
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, pixels_per_token),
            Rearrange("... (h w) (ph pw c) -> ... c (h ph) (w pw)", 
                      ph = patch_height, pw = patch_width, h = grid_rows),
        ]
        self.unpatchfy = nn.Sequential(*stages)
        self.img_channels = img_channels
        self.softmax = softmax

    def forward(self, components):
        """Return a ``SlotImage`` built from the patch tokens in ``components``."""
        rgba = self.unpatchfy(components)
        # Split the channel axis into the colour channels and the single mask channel.
        images, masks = unpack(rgba, [(self.img_channels,), (1,)], 's b k * h w')
        if self.softmax:
            # masks compete across slots
            masks = masks.softmax(dim=SlotImage.slot_dim)
        else:
            # each mask is squashed independently
            masks = masks.sigmoid()
        return SlotImage(images, masks)
