from uu import decode
import einops
import numpy as np
from pyparsing import nested_expr
import torch
from torch import nn
from torch.nn import functional as F

from einops import rearrange
from einops.layers.torch import Rearrange
from slot_attention.model.slatn_pos_embeddings import SoftPositionEmbed
from slot_attention.model.transformer_blocks.cross_attention_block import CrossAttentionBlock
from slot_attention.model.transformer_blocks.transformer_encoder import TransformerEncoder
from slot_attention.model.transformer_blocks.posemb_2d import posemb_sincos_2d
from slot_attention.helpers.clustering_heuristic import extract_query_mat

from slot_attention.model.slatn_decoder import build_decoder
from slot_attention.model.model_utils import assert_shape
from slot_attention.model.model_utils import pair

# from: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/simple_vit.py
# pre-layernorm: https://sh-tsang.medium.com/review-pre-ln-transformer-on-layer-normalization-in-the-transformer-architecture-b6c91a89e9ab


class VitModel(nn.Module):
    def __init__(self, params, channels = 3):
        super().__init__()
        self.params = params
        out_features = params.hidden_dims[-1]
        
        image_size = params.decoder_resolution
        patch_size = params.vit_patch_size
        dim = params.vit_dim
        depth = params.vit_depth
        n_heads = params.vit_n_heads
        mlp_dim = params.vit_mlp_dim
        qk_dim = params.vit_qk_dim
        
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        assert self.params.num_slots % 2 == 0, "number of slots must be divisible by 2 for cross attention"

        self.pos_embedding = posemb_sincos_2d(
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        )

        self.vit_encoder = TransformerEncoder(dim, depth, n_heads, mlp_dim, qk_dim=qk_dim)

        # self.grouping_module = CrossAttentionBlock(dim=dim, depth=1, mlp_dim=mlp_dim, qk_dim=dim)
        self.grouping_module = CrossAttentionBlock(dim, n_heads=n_heads, hidden_dim=mlp_dim, qk_dim=qk_dim)

        self.decoder_pos_embedding = SoftPositionEmbed(self.in_channels, self.slot_size, self.decoder_resolution)
        
        self.decoder = build_decoder(
            decoder_hidden_dims=params.decoder_hidden_dims,
            kernel_size=params.kernel_size,
            decoder_stride=params.decoder_stride,
            decoder_padding=params.decoder_padding,
            decoder_output_padding=params.decoder_output_padding,
            resolution=params.decoder_resolution,
            out_features=out_features,
            decoder_resolution=params.decoder_resolution,
            slot_size=params.slot_size,
        )

    def forward(self, img):
        device = img.device

        batch_size, num_channels, height, width = img.shape

        x = self.to_patch_embedding(img)
        x += self.pos_embedding.to(device, dtype=x.dtype)

        # # add cls-queries to the input
        x = torch.cat((x, torch.rand((batch_size, self.params.num_slots, self.params.slot_size)).to(device, dtype=x.dtype)), dim=1)
        
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            x = self.vit_encoder(x)
        
        # # extract query vectors by clustering
        # query_mat_np = extract_query_mat(x.detach().cpu().numpy(), self.params.num_slots)
        # query_mat_torch = torch.from_numpy(query_mat_np).to(device, dtype=x.dtype)
        # queries = einops.einsum(query_mat_torch, x, "b s n, b n d -> b s d")        
        
        # # extract cls-slots
        queries = x[:, :self.params.num_slots]
            
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            slots = self.grouping_module(queries, x)

        assert_shape(slots.size(), (batch_size, self.params.num_slots, self.params.slot_size))
        # `slots` has shape: [batch_size, num_slots, slot_size].
        batch_size, num_slots, slot_size = slots.shape

        # slots = slots.view(batch_size * num_slots, slot_size, 1, 1)
        # decoder_in = slots.repeat(1, 1, self.decoder_resolution[0], self.decoder_resolution[1])
        decoder_in = einops.repeat(slots, "b n d -> (b n) d h w", h=self.params.decoder_resolution[0], w=self.params.decoder_resolution[1])


        # hint: no positional encoding for the decoder
        out = self.decoder_pos_embedding(decoder_in)
        out = self.decoder(decoder_in)
        # `out` has shape: [batch_size*num_slots, num_channels+1, height, width].
        assert_shape(out.size(), (batch_size * self.params.num_slots, num_channels + 1, height, width))

        # out = out.view(batch_size, self.params.num_slots, num_channels + 1, height, width)
        out = einops.rearrange(out, "(b n) c h w -> b n c h w", b=batch_size)
        recons = out[:, :, :num_channels, :, :]
        masks = out[:, :, -1:, :, :]
        masks = F.softmax(masks, dim=1)
        recon_combined = torch.sum(recons * masks, dim=1)
        return recon_combined, recons, masks, slots
    
    
