from uu import decode
import einops
import numpy as np
from pyparsing import nested_expr
import torch
from torch import nn
from torch.nn import functional as F

from einops import rearrange
from einops.layers.torch import Rearrange
from slot_attention.model.slatn_pos_embeddings import SoftPositionEmbed
from slot_attention.model.transformer_blocks.posemb_1d import PositionalEncoding1d
from slot_attention.model.transformer_blocks.posemb_2d import posemb_sincos_2d

from slot_attention.model.slatn_decoder import build_decoder
from slot_attention.model.model_utils import assert_shape
from slot_attention.model.model_utils import pair
from slot_attention.model.transformer_blocks.vito_encoder import VitoEncoder

# Adapted from lucidrains' SimpleViT: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/simple_vit.py
# Background on pre-LayerNorm transformers: https://sh-tsang.medium.com/review-pre-ln-transformer-on-layer-normalization-in-the-transformer-architecture-b6c91a89e9ab


class VitoModel(nn.Module):
    """Slot autoencoder: ViT-style patch embedding -> ``VitoEncoder`` -> per-slot decoder.

    The input image is cut into patches and linearly embedded, a 2-D positional
    embedding is added, and ``VitoEncoder`` turns the patch tokens plus an
    initial slot sequence into ``num_slots`` slot vectors. Each slot is then
    broadcast over a spatial grid, decoded into ``channels + 1`` feature maps
    (image channels plus one mask logit), and the per-slot reconstructions are
    blended with masks that are softmax-normalised across slots.

    Args:
        params: Configuration object. The attributes read here are
            ``hidden_dims``, ``decoder_resolution``, ``vit_patch_size``,
            ``vit_dim``, ``vit_depth``, ``vit_n_heads``, ``vit_mlp_dim``,
            ``vit_qk_dim``, ``vit_dropout``, ``vito_layernorm_bias``,
            ``num_slots``, ``slot_size``, ``decoder_hidden_dims``,
            ``kernel_size``, ``decoder_stride``, ``decoder_padding`` and
            ``decoder_output_padding``. ``decoder_resolution`` doubles as the
            image size used for patching, and ``forward`` indexes it as
            ``[0]`` / ``[1]``, so it must be a (height, width) sequence there.
        channels: Number of input image channels used to size the patch
            embedding.

    Raises:
        AssertionError: If the image size is not divisible by the patch size,
            or if ``params.num_slots`` is odd.
    """

    def __init__(self, params, channels=3):
        super().__init__()
        self.params = params
        out_features = params.hidden_dims[-1]

        image_size = params.decoder_resolution
        patch_size = params.vit_patch_size
        dim = params.vit_dim
        depth = params.vit_depth
        n_heads = params.vit_n_heads
        mlp_dim = params.vit_mlp_dim
        qk_dim = params.vit_qk_dim

        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        patch_dim = channels * patch_height * patch_width

        # Flatten each (p1 x p2 x c) patch into a token, then project to `dim`.
        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_height, p2=patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        assert self.params.num_slots % 2 == 0, "number of slots must be divisible by 2 for cross attention"

        # Registered as a non-persistent buffer so it moves with the module
        # (`.to()`, `.cuda()`, `.half()`) instead of being copied to the input
        # device on every forward pass, while staying out of `state_dict` so
        # existing checkpoints keep loading unchanged.
        self.register_buffer(
            "pos_embedding_2d",
            posemb_sincos_2d(
                h=image_height // patch_height,
                w=image_width // patch_width,
                dim=dim,
            ),
            persistent=False,
        )

        self.to_seq_embedding = PositionalEncoding1d(dim, dropout=params.vit_dropout)

        self.encoder = VitoEncoder(params, dim, depth, n_heads, mlp_dim, qk_dim=qk_dim, layernorm_bias=self.params.vito_layernorm_bias)

        self.decoder_pos_embedding = SoftPositionEmbed(out_features, params.slot_size, params.decoder_resolution)

        self.decoder = build_decoder(
            decoder_hidden_dims=params.decoder_hidden_dims,
            kernel_size=params.kernel_size,
            decoder_stride=params.decoder_stride,
            decoder_padding=params.decoder_padding,
            decoder_output_padding=params.decoder_output_padding,
            resolution=params.decoder_resolution,
            out_features=out_features,
            decoder_resolution=params.decoder_resolution,
            slot_size=params.slot_size,
        )

    def forward(self, img):
        """Encode ``img`` into slots and decode them back into an image.

        Args:
            img: Tensor of shape ``[batch_size, num_channels, height, width]``.

        Returns:
            A tuple ``(recon_combined, recons, masks, slots)``:

            * ``recon_combined`` -- ``[batch_size, num_channels, height, width]``,
              the mask-weighted sum of the per-slot reconstructions.
            * ``recons`` -- ``[batch_size, num_slots, num_channels, height, width]``.
            * ``masks`` -- ``[batch_size, num_slots, 1, height, width]``,
              softmax-normalised over the slot axis.
            * ``slots`` -- ``[batch_size, num_slots, slot_size]``.
        """
        batch_size, num_channels, height, width = img.shape
        num_slots = self.params.num_slots

        x = self.to_patch_embedding(img)
        # Usually a no-op because the buffer already lives on the module's
        # device; kept so a dtype mismatch (e.g. under autocast) is handled.
        x = x + self.pos_embedding_2d.to(img.device, dtype=x.dtype)

        # Initial slot sequence; presumably [batch_size, num_slots, dim] -- TODO confirm
        # against PositionalEncoding1d.
        slots = self.to_seq_embedding(batch_size, num_slots)

        # Only the flash scaled-dot-product-attention backend is enabled while
        # the encoder runs.
        # NOTE(review): with the math and mem-efficient backends disabled there
        # is no fallback if the flash kernel cannot handle the input -- confirm
        # this restriction is intended. Newer PyTorch releases also deprecate
        # this context manager in favour of `torch.nn.attention.sdpa_kernel`.
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            slots = self.encoder(x, slots)

        # `slots` has shape: [batch_size, num_slots, slot_size].
        assert_shape(slots.size(), (batch_size, num_slots, self.params.slot_size))

        # Broadcast every slot vector over the decoder's spatial grid and fold
        # the slot axis into the batch axis so the decoder sees one slot per item.
        decoder_in = einops.repeat(
            slots,
            "b n d -> (b n) d h w",
            h=self.params.decoder_resolution[0],
            w=self.params.decoder_resolution[1],
        )

        out = self.decoder_pos_embedding(decoder_in)
        out = self.decoder(out)
        # `out` has shape: [batch_size*num_slots, num_channels+1, height, width].
        assert_shape(out.size(), (batch_size * num_slots, num_channels + 1, height, width))

        out = einops.rearrange(out, "(b n) c h w -> b n c h w", b=batch_size)
        recons = out[:, :, :num_channels, :, :]
        # The last channel holds mask logits; normalise them across slots so
        # the masks sum to one at every pixel.
        masks = F.softmax(out[:, :, -1:, :, :], dim=1)
        recon_combined = torch.sum(recons * masks, dim=1)
        return recon_combined, recons, masks, slots
    
    
