import time
from tracemalloc import start
import einops
from einops.layers.torch import Rearrange
import numpy as np
from sklearn.metrics import adjusted_rand_score
from sympy import true
from slot_attention.helpers.k_means_pp_as_sbp import k_means_plus_plus_init
from slot_attention.helpers.soft_k_means import soft_k_means
from slot_attention.model.slot_attention import SlotAttention
from slot_attention.model.slot_attention_pp import SlotAttentionPlusPLus
from slot_attention.model.transformer_blocks.transformer_encoder import TransformerEncoder
from slot_attention.model.slatn_decoder import build_decoder
from slot_attention.model.slatn_pos_embeddings import SoftPositionEmbed
from slot_attention.model.model_vit import pair
from slot_attention.model.model_utils import Tensor, assert_shape

import torch
from torch import nn
from torch.nn import functional as F

from typing import Tuple
from slot_attention.model.transformer_blocks.cross_attention_block import CrossAttentionBlock

from slot_attention.model.transformer_blocks.posemb_2d import posemb_sincos_2d


class SlotAttentionModelVitEncoder(nn.Module):
    """Slot-based auto-encoder whose image encoder is a patch transformer.

    Pipeline implemented by :meth:`forward`:

    1. the image is cut into non-overlapping patches which are linearly
       embedded and summed with a fixed 2-D sin/cos position embedding;
    2. a ``TransformerEncoder`` processes the patch tokens;
    3. a grouping module (``SlotAttention`` or ``SlotAttentionPlusPLus``,
       chosen by ``params.grouping_module``) turns the tokens into slots;
    4. every slot is broadcast over a ``decoder_resolution`` grid and decoded
       into ``num_channels + 1`` maps (reconstruction channels plus one mask
       logit), and the per-slot reconstructions are blended with the masks.

    Attributes read from ``params``: ``vit_dim``, ``vit_depth``,
    ``vit_n_heads``, ``vit_mlp_dim``, ``vit_qk_dim``, ``vit_patch_size``,
    ``resolution``, ``grouping_module`` and, optionally, ``slatn_mask_sig``
    (``params`` must therefore support the ``in`` operator).
    """

    def __init__(
        self,
        params,
        resolution: Tuple[int, int],
        num_slots: int,
        num_iterations,
        in_channels: int = 3,
        kernel_size: int = 5,
        slot_size: int = 64,
        hidden_dims: Tuple[int, ...] = (64, 64, 64, 64),
        hidden_dims_query: Tuple[int, ...] = (64, 64, 64, 64),
        decoder_resolution: Tuple[int, int] = (8, 8),
        decoder_hidden_dims: Tuple[int, ...] = (64, 64, 64, 64),
        decoder_stride: int = 1,
        decoder_padding: int = 1,
        decoder_output_padding: int = 0,
        empty_cache=False,
    ):
        """Build the patch encoder, grouping module and decoder.

        Args:
            params: Configuration object (see the class docstring for the
                attributes that are read from it).
            resolution: Image resolution handed to ``build_decoder`` and to
                the encoder ``SoftPositionEmbed``.
            num_slots: Number of slots produced by the grouping module.
            num_iterations: Iteration count forwarded to the grouping module.
            in_channels: Number of image channels.
            kernel_size: Kernel size forwarded to ``build_decoder``.
            slot_size: Feature size of each slot.
            hidden_dims: Only the last entry is used; it becomes
                ``self.out_features``.
            hidden_dims_query: Stored on the instance; not used in this class.
            decoder_resolution: Spatial grid each slot is broadcast to before
                decoding.
            decoder_hidden_dims: Forwarded to ``build_decoder``.
            decoder_stride: Forwarded to ``build_decoder``.
            decoder_padding: Forwarded to ``build_decoder``.
            decoder_output_padding: Forwarded to ``build_decoder``.
            empty_cache: Stored on the instance; not used in this class.

        Raises:
            AssertionError: If the image size from ``params.resolution`` is
                not divisible by ``params.vit_patch_size``.
        """
        super().__init__()
        self.params = params
        self.resolution = resolution
        self.num_slots = num_slots
        self.num_iterations = num_iterations
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.hidden_dims_query = hidden_dims_query
        self.decoder_resolution = decoder_resolution
        self.decoder_stride = decoder_stride
        self.decoder_padding = decoder_padding
        self.decoder_output_padding = decoder_output_padding
        self.decoder_hidden_dims = decoder_hidden_dims
        self.slot_size = slot_size
        self.empty_cache = empty_cache
        self.hidden_dims = hidden_dims
        self.out_features = self.hidden_dims[-1]

        # Transformer hyper-parameters.
        dim = self.params.vit_dim
        depth = self.params.vit_depth
        n_heads = self.params.vit_n_heads
        mlp_dim = self.params.vit_mlp_dim
        qk_dim = self.params.vit_qk_dim

        ## Encoder
        # NOTE(review): the patch grid is derived from `params.resolution`
        # while the decoder below is built from the `resolution` argument;
        # presumably both are expected to be equal -- confirm with callers.
        image_height, image_width = pair(self.params.resolution)
        patch_height, patch_width = pair(self.params.vit_patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        patch_dim = self.in_channels * patch_height * patch_width

        # Flatten each (p1 x p2 x c) patch into one token and project to `dim`.
        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_height, p2=patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # Plain attribute (not registered as a parameter/buffer here), so
        # `forward` moves it to the input's device and dtype on every call.
        self.pos_embedding = posemb_sincos_2d(
            h=image_height // patch_height,
            w=image_width // patch_width,
            dim=dim,
        )

        self.encoder = TransformerEncoder(params, dim, depth, n_heads, mlp_dim, qk_dim=qk_dim)

        ## Decoder
        self.decoder = build_decoder(
            decoder_hidden_dims=self.decoder_hidden_dims,
            kernel_size=self.kernel_size,
            decoder_stride=self.decoder_stride,
            decoder_padding=self.decoder_padding,
            decoder_output_padding=self.decoder_output_padding,
            resolution=self.resolution,
            out_features=self.out_features,
            decoder_resolution=self.decoder_resolution,
            slot_size=self.slot_size,
        )
        self.decoder_pos_embedding = SoftPositionEmbed(self.in_channels, self.slot_size, self.decoder_resolution)

        ## Grouping module
        # NOTE(review): `in_features` is `slot_size` although the encoder is
        # built with width `vit_dim`; presumably the two are configured to be
        # equal -- confirm.
        # Any other `grouping_module` value builds nothing here and makes
        # `forward` raise NotImplementedError.
        if self.params.grouping_module == 'slot_attention':
            self.slot_attention = SlotAttention(
                in_features=self.slot_size,
                num_iterations=self.num_iterations,
                num_slots=self.num_slots,
                slot_size=self.slot_size,
                mlp_hidden_size=128,
            )
        elif self.params.grouping_module == 'slot_attention_pp':
            self.slot_attention_pp = SlotAttentionPlusPLus(
                params=self.params,
                in_features=self.slot_size,
                num_iterations=self.num_iterations,
                num_slots=self.num_slots,
                slot_size=self.slot_size,
                mlp_hidden_size=128,
            )

        # The two modules below are not called by `forward`; they are still
        # created so the set of registered sub-modules (and hence state_dict
        # keys) stays the same.
        self.encoder_pos_embedding = SoftPositionEmbed(self.in_channels, self.out_features, resolution)
        self.encoder_out_layer = nn.Sequential(
            nn.Linear(self.out_features, self.out_features),
            nn.LeakyReLU(),
            nn.Linear(self.out_features, self.out_features),
        )

    @staticmethod
    def _composite_slots(
        out: torch.Tensor,
        batch_size: int,
        num_slots: int,
        num_channels: int,
        height: int,
        width: int,
        use_sigmoid_gate: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split decoder output into per-slot images and masks and blend them.

        Args:
            out: Decoder output that can be viewed as
                ``[batch_size, num_slots, num_channels + 1, height, width]``.
            batch_size: Number of images in the batch.
            num_slots: Number of slots per image.
            num_channels: Number of reconstruction channels; the one
                remaining channel holds the mask logit.
            height: Output height.
            width: Output width.
            use_sigmoid_gate: When true, the slot-softmax mask is additionally
                multiplied by the sigmoid of the same logits.

        Returns:
            ``(recon_combined, recons, masks)`` with shapes
            ``[batch_size, num_channels, height, width]``,
            ``[batch_size, num_slots, num_channels, height, width]`` and
            ``[batch_size, num_slots, 1, height, width]``.
        """
        out = out.view(batch_size, num_slots, num_channels + 1, height, width)
        recons = out[:, :, :num_channels, :, :]
        mask_logits = out[:, :, -1:, :, :]

        # Normalise across the slot axis so slots compete for every pixel.
        masks = F.softmax(mask_logits, dim=1)
        if use_sigmoid_gate:
            masks = torch.sigmoid(mask_logits) * masks

        recon_combined = torch.sum(recons * masks, dim=1)
        return recon_combined, recons, masks

    def forward(self, img, **kwargs):
        """Encode an image batch into slots and reconstruct it from them.

        Args:
            img: Image batch of shape ``[batch_size, num_channels, height, width]``.
            **kwargs: Forwarded unchanged to the grouping module.

        Returns:
            Tuple ``(recon_combined, recons, masks, slots)``:

            * ``recon_combined`` -- ``[batch_size, num_channels, height, width]``
            * ``recons`` -- ``[batch_size, num_slots, num_channels, height, width]``
            * ``masks`` -- ``[batch_size, num_slots, 1, height, width]``
            * ``slots`` -- ``[batch_size, num_slots, slot_size]``

        Raises:
            NotImplementedError: If ``params.grouping_module`` is neither
                ``'slot_attention'`` nor ``'slot_attention_pp'``.
        """
        batch_size, num_channels, height, width = img.shape

        tokens = self.to_patch_embedding(img)
        # Out-of-place add: does not mutate the embedding module's output.
        tokens = tokens + self.pos_embedding.to(img.device, dtype=tokens.dtype)

        encoder_out = self.encoder(tokens)
        # `encoder_out` has shape: [batch_size, n_patches, d]

        grouping_module = self.params.grouping_module
        if grouping_module == 'slot_attention':
            slots = self.slot_attention(encoder_out, **kwargs)
        elif grouping_module == 'slot_attention_pp':
            slots = self.slot_attention_pp(encoder_out, **kwargs)
        else:
            raise NotImplementedError(f'Grouping module {self.params.grouping_module} not implemented.')

        assert_shape(slots.size(), (batch_size, self.num_slots, self.slot_size))
        # `slots` has shape: [batch_size, num_slots, slot_size].
        batch_size, num_slots, slot_size = slots.shape

        # Broadcast every slot over the decoder grid. Use the same
        # `decoder_resolution` the decoder and its position embedding were
        # built with, so the grid always matches those modules.
        grid_height, grid_width = self.decoder_resolution
        decoder_in = einops.repeat(slots, "b n d -> (b n) d h w", h=grid_height, w=grid_width)

        out = self.decoder_pos_embedding(decoder_in)
        out = self.decoder(out)
        # `out` has shape: [batch_size*num_slots, num_channels+1, height, width].
        assert_shape(out.size(), (batch_size * num_slots, num_channels + 1, height, width))

        use_sigmoid_gate = bool('slatn_mask_sig' in self.params and self.params.slatn_mask_sig)
        recon_combined, recons, masks = self._composite_slots(
            out, batch_size, num_slots, num_channels, height, width, use_sigmoid_gate
        )
        return recon_combined, recons, masks, slots