from typing import Tuple
from sympy import true

import torch
from torch import nn
from torch.nn import functional as F
from slot_attention.model.slatn_pos_embeddings import SoftPositionEmbed
from slot_attention.model.slatn_decoder import build_decoder
from slot_attention.losses.ari import adjusted_rand_index
from slot_attention.model.model_utils import Tensor
from slot_attention.model.model_utils import assert_shape


class HopfieldLayerProjectBeforeOnlyQQQ(nn.Module):
    """Attention layer that iteratively refines query slots over projected inputs.

    Inputs are layer-normalised and projected to ``slot_size`` once to form the
    keys. Queries are layer-normalised and projected once before the loop; on
    every later iteration the current slots are only layer-normalised (no
    projection). Each iteration applies a softmax over the slot axis of the
    scaled key/query dot products, then replaces the slots with the
    attention-weighted sum of the keys.

    Args:
        in_features: Feature size of each input vector.
        num_iterations: Number of refinement iterations.
        num_slots: Stored only; ``forward`` reads the slot count from ``query``.
        slot_size: Size of keys, queries and returned slots.
        mlp_hidden_size: Stored only; unused by this layer.
        hopfield_beta: Stored only; unused by ``forward``.
        epsilon: Stored only; unused by ``forward``.
    """

    def __init__(self, in_features, num_iterations, num_slots, slot_size, mlp_hidden_size, hopfield_beta=50, epsilon=1e-8):
        super().__init__()
        # Plain hyper-parameters, kept as attributes for external readers.
        self.in_features = in_features
        self.num_iterations = num_iterations
        self.num_slots = num_slots
        self.slot_size = slot_size  # feature width of every slot
        self.mlp_hidden_size = mlp_hidden_size
        self.epsilon = epsilon
        self.hopfield_beta = hopfield_beta

        # LayerNorm over the last (feature) dimension of inputs and slots.
        self.norm_inputs = nn.LayerNorm(self.in_features)
        self.norm_slots = nn.LayerNorm(self.slot_size)

        # Bias-free projections. They are created q-before-k so that seeded
        # parameter initialisation matches earlier checkpoints.
        self.project_q = nn.Linear(self.slot_size, self.slot_size, bias=False)
        self.project_k = nn.Linear(self.in_features, self.slot_size, bias=False)

    def forward(self, query: Tensor, inputs: Tensor, **kwargs):
        """Refine ``query`` against ``inputs``.

        Args:
            query: Initial slots, ``[batch_size, num_slots, slot_size]``.
            inputs: Input features, ``[batch_size, num_inputs, in_features]``.
            **kwargs: Only ``vis_carrier`` is read. If present, its
                ``add_qk_masks`` receives the first batch item's attention
                logits and attention weights on every iteration.

        Returns:
            Slots of shape ``[batch_size, num_slots, slot_size]``. When
            ``num_iterations == 0`` these are the layer-normalised queries.
        """
        carrier = kwargs.get('vis_carrier', None)

        # Keys: normalised inputs projected to slot_size.
        batch_size, num_inputs, _ = inputs.shape
        keys = self.project_k(self.norm_inputs(inputs))
        assert_shape(keys.size(), (batch_size, num_inputs, self.slot_size))

        batch_size, num_slots, _ = query.shape
        slots = self.norm_slots(query)
        # Only the first iteration sees projected queries:
        # [batch_size, num_slots, slot_size].
        projected_queries = self.project_q(slots)
        assert_shape(slots.size(), (batch_size, num_slots, self.slot_size))

        for step in range(self.num_iterations):
            queries = projected_queries if step == 0 else self.norm_slots(slots)

            # NOTE(review): the logits are scaled by num_slots ** -0.5 rather
            # than by the dot-product width (slot_size), and hopfield_beta is
            # never applied. Confirm this is intended.
            scale = num_slots ** -0.5
            logits = scale * torch.matmul(keys, queries.transpose(2, 1))
            # [batch_size, num_inputs, num_slots]. The softmax runs over the
            # slot axis, so slots compete for each input.
            weights = F.softmax(logits, dim=-1)

            if carrier is not None:
                carrier.add_qk_masks(name='Attention Logits', mask=logits[0].detach().cpu().numpy())
                carrier.add_qk_masks(name='Attention', mask=weights[0].detach().cpu().numpy())

            assert_shape(weights.size(), (batch_size, num_inputs, num_slots))

            # Each new slot is a weighted sum (not a weighted mean) of the keys.
            slots = torch.matmul(weights.transpose(1, 2), keys)
            assert_shape(slots.size(), (batch_size, num_slots, self.slot_size))

        return slots


class HopfieldModelSprites(nn.Module):
    """Conv encoder + attention bottleneck + per-slot broadcast decoder.

    Pipeline in ``forward``:
      1. ``encoder`` (stride-1 conv blocks) and ``encoder_pos_embedding``
         produce a feature map, which is optionally average-pooled
         (``use_average_pool``).
      2. ``query_compressor`` downsamples that map. Each spatial position of
         its output becomes one query/slot, so the effective slot count is the
         compressor's output ``h * w``. The ``num_slots`` argument is only
         forwarded to the attention layer.
      3. Both maps are flattened to token sequences and passed through small
         per-token MLPs.
      4. With ``use_hopfield``, the queries attend over the encoder tokens via
         ``slot_attention`` (only the ``use_hopfield_only_qqq`` variant is
         implemented). Otherwise the queries are used as slots directly.
      5. Each slot is tiled over ``decoder_resolution``, position-embedded
         and decoded into ``num_channels`` reconstruction channels plus one
         mask-logit channel. The masks are normalised across slots and used
         to blend the per-slot reconstructions.
    """

    def __init__(
        self,
        resolution: Tuple[int, int],
        num_slots: int,
        num_iterations,
        in_channels: int = 3,
        kernel_size: int = 5,
        slot_size: int = 64,
        hidden_dims: Tuple[int, ...] = (64, 64, 64, 64),
        hidden_dims_query: Tuple[int, ...] = (64, 64, 64, 64),
        decoder_resolution: Tuple[int, int] = (8, 8),
        decoder_stride: int = 1,
        decoder_padding: int = 1,
        decoder_output_padding: int = 0,
        decoder_hidden_dims: Tuple[int, ...] = (64, 64, 64),
        
        empty_cache=False,
        hopfield_steps_eps=1e-6,
        use_hopfield=True,
        use_hopfield_norm_before=False,
        use_hopfield_project_before=False,
        use_hopfield_actual=False,
        use_hopfield_qqq=False,
        use_hopfield_reverse_softmax=False,
        use_hopfield_only_qqq=False,
        use_double_softmax=False,
        use_competetion=False,
        use_gumble_softmax=False,
        l1_loss=False,
        hopfield_beta=50,
        use_average_pool=False,
        average_pool_size=2,
        average_pool_stride=2,
        use_max_pool=False,
        use_residual_path=False,
        max_pool_size=2,
        max_pool_stride=2,
    ):
        super().__init__()
        self.resolution = resolution
        self.num_slots = num_slots
        self.num_iterations = num_iterations
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.slot_size = slot_size
        self.empty_cache = empty_cache
        self.hidden_dims = hidden_dims
        self.hidden_dims_query = hidden_dims_query
        self.decoder_resolution = decoder_resolution
        self.decoder_stride = decoder_stride
        self.decoder_padding = decoder_padding
        self.decoder_output_padding = decoder_output_padding
        self.decoder_hidden_dims = decoder_hidden_dims
        
        self.out_features = self.hidden_dims[-1]
        self.out_features_query = self.hidden_dims_query[-1]
        self.use_hopfield = use_hopfield
        self.hopfield_steps_eps = hopfield_steps_eps
        self.hopfield_beta = hopfield_beta
        self.use_hopfield_norm_before = use_hopfield_norm_before
        self.use_hopfield_project_before = use_hopfield_project_before
        self.use_hopfield_actual = use_hopfield_actual 
        self.use_hopfield_qqq = use_hopfield_qqq
        self.use_average_pool = use_average_pool
        self.average_pool_size = average_pool_size
        self.average_pool_stride = average_pool_stride
        self.use_hopfield_reverse_softmax = use_hopfield_reverse_softmax
        self.use_hopfield_only_qqq = use_hopfield_only_qqq
        self.use_double_softmax = use_double_softmax
        self.use_competetion = use_competetion
        self.use_gumble_softmax = use_gumble_softmax
        self.l1_loss = l1_loss
        self.use_max_pool = use_max_pool
        self.use_residual_path = use_residual_path
        self.max_pool_size = max_pool_size
        self.max_pool_stride = max_pool_stride

        # Build Encoder - 1: stride-1 conv blocks with kernel_size // 2 padding.
        # With an odd kernel_size this keeps the spatial size.
        # TODO: Paper uses ReLU, page 23 (the blocks use LeakyReLU).
        modules = []
        channels = self.in_channels
        for h_dim in self.hidden_dims:
            modules.append(self._conv_block(channels, h_dim))
            channels = h_dim

        # Query compressor: conv blocks that downsample the encoder features.
        # Each spatial position of its output later becomes one query/slot.
        modules_query = []
        if use_max_pool:
            for h_dim in hidden_dims_query:
                modules_query.append(
                    self._conv_block(
                        channels,
                        h_dim,
                        nn.MaxPool2d(self.max_pool_size, stride=self.max_pool_stride),
                    )
                )
                channels = h_dim
            # Extra stride-1 max-pool block. Its width follows the last query
            # block, so it always matches query_out_layer (out_features_query).
            # A hard-coded width of 64 would break any other hidden_dims_query.
            modules_query.append(
                self._conv_block(
                    channels,
                    channels,
                    nn.MaxPool2d(self.max_pool_size, stride=1),
                )
            )
        else:
            for h_dim in hidden_dims_query:
                modules_query.append(
                    self._conv_block(
                        channels,
                        h_dim,
                        nn.AvgPool2d(self.average_pool_size, stride=self.average_pool_stride),
                    )
                )
                channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.encoder_pos_embedding = SoftPositionEmbed(self.in_channels, self.out_features, resolution)
        # Parameter-free; forward applies it only when use_average_pool is set.
        self.average_pool_layer = nn.AvgPool2d(self.average_pool_size, stride=self.average_pool_stride)
        self.query_compressor = nn.Sequential(*modules_query)
        # Per-token MLPs, applied after the feature maps are flattened.
        self.query_out_layer = nn.Sequential(
            nn.Linear(self.out_features_query, self.out_features_query),
            nn.LeakyReLU(),
            nn.Linear(self.out_features_query, self.out_features_query),
        )
        self.encoder_out_layer = nn.Sequential(
            nn.Linear(self.out_features, self.out_features),
            nn.LeakyReLU(),
            nn.Linear(self.out_features, self.out_features),
        )
        if self.use_residual_path:
            self.norm_mlp = nn.LayerNorm(self.slot_size)

        self.decoder = build_decoder(
            decoder_hidden_dims=self.decoder_hidden_dims, 
            kernel_size=self.kernel_size, 
            decoder_stride=self.decoder_stride, 
            decoder_padding=self.decoder_padding, 
            decoder_output_padding=self.decoder_output_padding,
            resolution=self.resolution,
            out_features=self.out_features,
            decoder_resolution=self.decoder_resolution,
            slot_size=self.slot_size
        )
        self.decoder_pos_embedding = SoftPositionEmbed(self.in_channels, self.slot_size, self.decoder_resolution)

        if self.use_hopfield:
            if self.use_hopfield_only_qqq:
                self.slot_attention = HopfieldLayerProjectBeforeOnlyQQQ(
                    in_features=self.out_features,
                    num_iterations=self.num_iterations,
                    num_slots=self.num_slots,
                    slot_size=self.slot_size,
                    mlp_hidden_size=128,
                    hopfield_beta=self.hopfield_beta
                )

            else:
                raise NotImplementedError(
                    "HopfieldModelSprites: only use_hopfield_only_qqq=True is "
                    "implemented when use_hopfield=True"
                )

    def _conv_block(self, in_channels, out_channels, pool=None):
        """Return ``Conv2d -> LeakyReLU [-> pool]`` as an ``nn.Sequential``.

        The convolution uses stride 1 and ``kernel_size // 2`` padding, so it
        keeps the spatial size for odd kernel sizes. ``pool`` (if given) is
        appended as the third layer.
        """
        layers = [
            nn.Conv2d(
                in_channels,
                out_channels=out_channels,
                kernel_size=self.kernel_size,
                stride=1,
                padding=self.kernel_size // 2,
            ),
            nn.LeakyReLU(),
        ]
        if pool is not None:
            layers.append(pool)
        return nn.Sequential(*layers)

    def forward(self, x, **kwargs):
        """Encode ``x`` into slots, decode each slot and blend the results.

        Args:
            x: Images, ``[batch_size, num_channels, height, width]``.
            **kwargs: Forwarded unchanged to ``slot_attention`` when
                ``use_hopfield`` is set.

        Returns:
            Tuple ``(recon_combined, recons, masks, slots)``:
            - ``recon_combined``: ``[batch_size, num_channels, height, width]``.
            - ``recons``: ``[batch_size, num_slots, num_channels, height, width]``.
            - ``masks``: ``[batch_size, num_slots, 1, height, width]``,
              normalised across slots.
            - ``slots``: reshaped to ``[batch_size * num_slots, slot_size, 1, 1]``.
            Here ``num_slots`` is the number of spatial positions in the query
            compressor's output.
        """
        if self.empty_cache:
            torch.cuda.empty_cache()

        batch_size, num_channels, height, width = x.shape
        encoder_out = self.encoder(x)
        encoder_out = self.encoder_pos_embedding(encoder_out)

        # `encoder_out` has shape: [batch_size, out_features, H', W']
        if self.use_average_pool:
            encoder_out = self.average_pool_layer(encoder_out)

        query_out = self.query_compressor(encoder_out)

        # [B, C, h, w] -> [B, h*w, C]: every spatial position becomes a token.
        encoder_out = torch.flatten(encoder_out, start_dim=2, end_dim=3).permute(0, 2, 1)
        query_out = torch.flatten(query_out, start_dim=2, end_dim=3).permute(0, 2, 1)

        encoder_out = self.encoder_out_layer(encoder_out)
        query_out = self.query_out_layer(query_out)

        # One slot per query token.
        batch_size, num_slots, slot_size = query_out.shape
        if self.use_hopfield:
            slots = self.slot_attention(query_out, encoder_out, **kwargs)

            # `slots` has shape: [batch_size, num_slots, slot_size].
            assert_shape(slots.size(), (batch_size, num_slots, self.slot_size))
            batch_size, num_slots, slot_size = slots.shape

            if self.use_residual_path:
                slots = slots + self.norm_mlp(query_out)
        
        else:
            slots = query_out

        # Spatial broadcast: tile every slot over the decoder grid.
        slots = slots.view(batch_size * num_slots, slot_size, 1, 1)
        decoder_in = slots.repeat(1, 1, self.decoder_resolution[0], self.decoder_resolution[1])

        out = self.decoder_pos_embedding(decoder_in)
        out = self.decoder(out)
        # `out` has shape: [batch_size*num_slots, num_channels+1, height, width].
        assert_shape(out.size(), (batch_size * num_slots, num_channels + 1, height, width))

        out = out.view(batch_size, num_slots, num_channels + 1, height, width)
        # The last channel is the mask logit; the rest is the per-slot reconstruction.
        recons = out[:, :, :num_channels, :, :]
        masks = out[:, :, -1:, :, :]
        # Normalise the masks across slots (dim=1).
        if self.use_gumble_softmax:
            masks = F.gumbel_softmax(masks, dim=1, hard=False) 
        else:
            masks = F.softmax(masks, dim=1)

        # TODO: try doing a max here instead of a sum, forcing the model to use one slot for each patch
        recon_combined = torch.sum(recons * masks, dim=1)
        return recon_combined, recons, masks, slots

    def loss_function(self, input):
        """Compute the reconstruction loss and a mask ARI for one batch.

        Args:
            input: Pair ``(true_imgs, true_masks)``. ``true_imgs`` is passed to
                ``forward``. ``true_masks`` is sliced as
                ``[batch, num_objects, H, W]`` with channel 0 dropped as
                background. This is assumed from the original Tetrominoes
                usage; TODO confirm for other datasets.

        Returns:
            Dict with two entries:
            - ``"loss"``: L1 if ``l1_loss`` is set, otherwise MSE, between the
              combined reconstruction and ``true_imgs``.
            - ``'ari'``: batch-mean adjusted Rand index, computed under
              ``torch.no_grad()``.
        """
        true_imgs, true_masks = input
        recon_combined, _, masks, _ = self.forward(true_imgs)
        loss_fn = F.l1_loss if self.l1_loss else F.mse_loss
        loss = loss_fn(recon_combined, true_imgs)

        # For Tetrominoes: exclude background labels from the ARI.
        with torch.no_grad():
            true_masks_wo_bg = true_masks[:,1:,:,:]
            # NOTE(review): predicted slot 0 is dropped too, but slots here
            # come from spatial query positions, so slot 0 is not a background
            # slot. FG-ARI usually drops only the ground-truth background.
            # Confirm this is intended.
            masks_wo_bg = masks[:,1:,:,:]
            # Flatten the masks along the H and W dimensions.
            true_masks_wo_bg = true_masks_wo_bg.view(true_masks_wo_bg.shape[0], true_masks_wo_bg.shape[1], -1)
            masks_wo_bg = masks_wo_bg.view(masks_wo_bg.shape[0], masks_wo_bg.shape[1], -1)
            # Swap axes into the (batch, points, labels) format.
            true_masks_wo_bg = true_masks_wo_bg.permute(0, 2, 1)
            masks_wo_bg = masks_wo_bg.permute(0, 2, 1)
            ari = adjusted_rand_index(true_masks_wo_bg, masks_wo_bg)
            mean_ari = torch.mean(ari)
        
        return {
                "loss": loss,
                'ari': mean_ari,
            }
 

