import time
import einops
import numpy as np
from sklearn.metrics import adjusted_rand_score
from sympy import true
from slot_attention.model.slot_attention import SlotAttention
from slot_attention.model.slatn_decoder import build_decoder
from slot_attention.model.slatn_pos_embeddings import SoftPositionEmbed
from slot_attention.model.model_utils import Tensor, assert_shape

import torch
from torch import nn
from torch.nn import functional as F

from typing import Tuple

from slot_attention.model.slot_attention_pp import SlotAttentionPlusPLus


class SlotAttentionModel(nn.Module):
    """Slot-based auto-encoder: CNN encoder -> grouping module -> per-slot decoder.

    The encoder is a stack of stride-1 convolutions followed by a position
    embedding and a two-layer MLP applied per spatial location.  The resulting
    feature set is handed to the grouping module selected by
    ``params.grouping_module`` (``'slot_attention'`` or ``'slot_attention_pp'``),
    which produces ``num_slots`` slot vectors.  Every slot is broadcast onto a
    ``decoder_resolution`` grid and decoded independently into
    ``num_channels + 1`` maps (reconstruction channels plus one mask logit).
    The per-slot reconstructions are blended with the normalised masks.

    Any other value of ``params.grouping_module`` is accepted at construction
    time (no grouping module is created) and is rejected with
    ``NotImplementedError`` when ``forward`` is called.
    """

    def __init__(
        self,
        params,
        resolution: Tuple[int, int],
        num_slots: int,
        num_iterations,
        in_channels: int = 3,
        kernel_size: int = 5,
        slot_size: int = 64,
        hidden_dims: Tuple[int, ...] = (64, 64, 64, 64),
        hidden_dims_query: Tuple[int, ...] = (64, 64, 64, 64),
        decoder_resolution: Tuple[int, int] = (8, 8),
        decoder_hidden_dims: Tuple[int, ...] = (64, 64, 64, 64),
        decoder_stride: int = 1,
        decoder_padding: int = 1,
        decoder_output_padding: int = 0,
        empty_cache=False,
    ):
        """Store the configuration and build encoder, decoder and grouping module.

        Args:
            params: Configuration object.  Must expose ``grouping_module`` as an
                attribute and support ``'key' in params``; the optional
                ``slatn_mask_sig`` flag switches the mask normalisation
                (see ``forward``).  It is also forwarded to
                ``SlotAttentionPlusPLus``.
            resolution: Spatial size passed to the encoder position embedding
                and to ``build_decoder``.
            num_slots: Number of slots produced by the grouping module.
            num_iterations: Iteration count forwarded to the grouping module.
            in_channels: Number of channels of the input images.
            kernel_size: Kernel size of the encoder convolutions (also passed
                to ``build_decoder``).
            slot_size: Dimensionality of each slot vector.
            hidden_dims: Output channels of each encoder convolution; the last
                entry is the encoder feature size.
            hidden_dims_query: Stored on the instance; not used in this class.
            decoder_resolution: Grid size every slot is broadcast to before
                decoding.
            decoder_hidden_dims: Forwarded to ``build_decoder``.
            decoder_stride: Forwarded to ``build_decoder``.
            decoder_padding: Forwarded to ``build_decoder``.
            decoder_output_padding: Forwarded to ``build_decoder``.
            empty_cache: When truthy, ``torch.cuda.empty_cache()`` is called at
                the start of every forward pass.
        """
        super().__init__()
        self.params = params
        self.resolution = resolution
        self.num_slots = num_slots
        self.num_iterations = num_iterations
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.hidden_dims_query = hidden_dims_query
        self.decoder_resolution = decoder_resolution
        self.decoder_stride = decoder_stride
        self.decoder_padding = decoder_padding
        self.decoder_output_padding = decoder_output_padding
        self.decoder_hidden_dims = decoder_hidden_dims
        self.slot_size = slot_size
        self.empty_cache = empty_cache
        self.hidden_dims = hidden_dims
        # Feature size seen by the grouping module.
        self.out_features = self.hidden_dims[-1]

        # NOTE: submodules are created in a fixed order (encoder, decoder,
        # grouping module) so that seeded initialisation and parameter order
        # stay stable across versions of this class.
        self.build_slatn_encoder(resolution)

        self.decoder = build_decoder(
            decoder_hidden_dims=self.decoder_hidden_dims,
            kernel_size=self.kernel_size,
            decoder_stride=self.decoder_stride,
            decoder_padding=self.decoder_padding,
            decoder_output_padding=self.decoder_output_padding,
            resolution=self.resolution,
            out_features=self.out_features,
            decoder_resolution=self.decoder_resolution,
            slot_size=self.slot_size,
        )
        # NOTE(review): `in_channels` is passed as the first argument of
        # SoftPositionEmbed here and in the encoder -- confirm that this is
        # what SoftPositionEmbed expects when `in_channels != 3`.
        self.decoder_pos_embedding = SoftPositionEmbed(
            self.in_channels, self.slot_size, self.decoder_resolution
        )

        grouping = self.params.grouping_module
        if grouping == 'slot_attention':
            self.slot_attention = SlotAttention(
                in_features=self.out_features,
                num_iterations=self.num_iterations,
                num_slots=self.num_slots,
                slot_size=self.slot_size,
                mlp_hidden_size=128,
            )
        elif grouping == 'slot_attention_pp':
            self.slot_attention_pp = SlotAttentionPlusPLus(
                params=self.params,
                in_features=self.out_features,
                num_iterations=self.num_iterations,
                num_slots=self.num_slots,
                slot_size=self.slot_size,
                mlp_hidden_size=128,
            )

    def build_slatn_encoder(self, resolution):
        """Create ``encoder``, ``encoder_pos_embedding`` and ``encoder_out_layer``.

        Args:
            resolution: Spatial size handed to the encoder position embedding.
        """
        # Input channels of each conv layer: the image channels first, then
        # the output channels of the preceding layer.
        in_dims = (self.in_channels, *self.hidden_dims[:-1])
        conv_blocks = [
            nn.Sequential(
                nn.Conv2d(
                    c_in,
                    out_channels=c_out,
                    kernel_size=self.kernel_size,
                    stride=1,
                    # "same" padding for odd kernel sizes, so the spatial
                    # size is preserved through the stack.
                    padding=self.kernel_size // 2,
                ),
                nn.LeakyReLU(),
            )
            for c_in, c_out in zip(in_dims, self.hidden_dims)
        ]

        self.encoder = nn.Sequential(*conv_blocks)
        self.encoder_pos_embedding = SoftPositionEmbed(
            self.in_channels, self.out_features, resolution
        )
        # Per-location MLP applied after flattening the spatial dimensions.
        self.encoder_out_layer = nn.Sequential(
            nn.Linear(self.out_features, self.out_features),
            nn.LeakyReLU(),
            nn.Linear(self.out_features, self.out_features),
        )

    @staticmethod
    def _slot_masks(mask_logits: torch.Tensor, gate_with_sigmoid: bool) -> torch.Tensor:
        """Normalise per-slot mask logits over the slot dimension (dim 1).

        Args:
            mask_logits: Mask logits with the slots on dimension 1.
            gate_with_sigmoid: If false, return the softmax over slots.  If
                true, additionally multiply each softmax weight by the sigmoid
                of its own logit, so the weights no longer sum to one.

        Returns:
            Tensor of the same shape as ``mask_logits``.
        """
        competition = F.softmax(mask_logits, dim=1)
        if not gate_with_sigmoid:
            return competition
        # torch.sigmoid is the canonical spelling; F.sigmoid is the legacy
        # alias that warns on a number of PyTorch releases.
        return torch.sigmoid(mask_logits) * competition

    def forward(self, x, **kwargs):
        """Encode ``x`` into slots and decode them into a reconstruction.

        Args:
            x: Input batch of shape ``[batch_size, num_channels, height, width]``.
            **kwargs: Forwarded unchanged to the grouping module.

        Returns:
            A tuple ``(recon_combined, recons, masks, slots)``:

            * ``recon_combined``: ``[batch_size, num_channels, height, width]``,
              the mask-weighted sum of the per-slot reconstructions.
            * ``recons``: ``[batch_size, num_slots, num_channels, height, width]``.
            * ``masks``: ``[batch_size, num_slots, 1, height, width]``.
            * ``slots``: the slot vectors in the layout fed to the decoder,
              i.e. ``[batch_size * num_slots, slot_size, 1, 1]`` (not
              ``[batch_size, num_slots, slot_size]``).

        Raises:
            NotImplementedError: If ``params.grouping_module`` is neither
                ``'slot_attention'`` nor ``'slot_attention_pp'``.
        """
        if self.empty_cache:
            torch.cuda.empty_cache()

        batch_size, num_channels, height, width = x.shape

        # [batch_size, filter_size, height, width]
        encoder_out = self.encoder_pos_embedding(self.encoder(x))
        # [batch_size, filter_size, height*width]
        encoder_out = torch.flatten(encoder_out, start_dim=2, end_dim=3)
        # [batch_size, height*width, filter_size]
        encoder_out = encoder_out.permute(0, 2, 1)
        encoder_out = self.encoder_out_layer(encoder_out)

        grouping = self.params.grouping_module
        if grouping == 'slot_attention':
            slots = self.slot_attention(encoder_out, **kwargs)
        elif grouping == 'slot_attention_pp':
            slots = self.slot_attention_pp(encoder_out, **kwargs)
        else:
            raise NotImplementedError(f'Grouping module {self.params.grouping_module} not implemented.')

        # `slots` has shape: [batch_size, num_slots, slot_size].
        assert_shape(slots.size(), (batch_size, self.num_slots, self.slot_size))
        batch_size, num_slots, slot_size = slots.shape

        # Fold the slots into the batch dimension so each slot is decoded on
        # its own.  `reshape` (rather than `view`) also accepts a
        # non-contiguous tensor from the grouping module; for contiguous
        # input it still returns a view.
        slots = slots.reshape(batch_size * num_slots, slot_size, 1, 1)
        # Spatial broadcast of every slot vector onto the decoder grid.
        decoder_in = slots.repeat(1, 1, self.decoder_resolution[0], self.decoder_resolution[1])

        out = self.decoder(self.decoder_pos_embedding(decoder_in))
        # `out` has shape: [batch_size*num_slots, num_channels+1, height, width].
        assert_shape(out.size(), (batch_size * num_slots, num_channels + 1, height, width))

        out = out.view(batch_size, num_slots, num_channels + 1, height, width)
        # The first `num_channels` maps are the reconstruction, the last map
        # holds the mask logits.
        recons = out[:, :, :num_channels, :, :]
        mask_logits = out[:, :, -1:, :, :]

        # Sigmoid gating is opt-in through `params.slatn_mask_sig`.
        gate_with_sigmoid = 'slatn_mask_sig' in self.params and self.params.slatn_mask_sig
        masks = self._slot_masks(mask_logits, gate_with_sigmoid)

        recon_combined = torch.sum(recons * masks, dim=1)
        return recon_combined, recons, masks, slots