import dis
import einops
import wandb
from slot_attention.helpers.k_means_pp_as_sbp import k_means_plus_plus_init
from slot_attention.helpers.soft_k_means import soft_k_means
from slot_attention.model.model_utils import assert_shape


import torch
from torch import Tensor, nn
from torch.nn import functional as F


class SlotAttentionPlusPLus(nn.Module):
    """Single-pass slot attention whose queries come from a k-means++ / soft k-means clustering of the keys.

    Rather than iteratively refining learned or sampled slots, ``forward``
    clusters the projected keys (several seeded trials, best one kept), builds
    one query per cluster from the keys, runs one round of attention over the
    values and applies a residual MLP.

    ``num_iterations``, ``project_q`` and ``gru`` are not used by ``forward``.
    They are still created so that the constructor signature and the module's
    ``state_dict`` layout stay unchanged.

    ``params`` must provide the attributes read in ``forward``:
    ``slatn_pp_layernorm_k``, ``slatn_pp_n_trials``, ``slatn_pp_dist_func``,
    ``slatn_pp_n_iterations``, ``slatn_use_competition`` and ``use_vit``.
    """

    def __init__(self, params, in_features, num_iterations, num_slots, slot_size, mlp_hidden_size, epsilon=1e-8):
        """Build the normalisation, projection and slot-update sub-modules.

        Args:
            params: configuration object; see the class docstring for the
                attributes read from it.
            in_features: size of the last axis of the tensor given to ``forward``.
            num_iterations: stored only; not used by ``forward``.
            num_slots: number of centroids requested from the k-means++ init.
            slot_size: feature dimension of the keys, values, queries and slots.
            mlp_hidden_size: hidden width of the residual MLP.
            epsilon: constant added to the attention weights before they are
                renormalised over the inputs.
        """
        super().__init__()
        self.params = params
        self.in_features = in_features
        self.num_iterations = num_iterations
        self.num_slots = num_slots
        self.slot_size = slot_size  # feature dimension of each slot
        self.mlp_hidden_size = mlp_hidden_size
        self.epsilon = epsilon

        # NOTE: sub-modules are created in a fixed order so that parameter
        # initialisation under a fixed RNG seed is reproducible.
        self.norm_inputs = nn.LayerNorm(self.in_features)
        # LayerNorm over the last (slot_size) axis; `forward` optionally applies it to the keys.
        self.norm_slots = nn.LayerNorm(self.slot_size)
        self.norm_mlp = nn.LayerNorm(self.slot_size)

        # Linear maps for the attention module.
        # `project_q` is unused by `forward` (queries are built from the keys).
        self.project_q = nn.Linear(self.slot_size, self.slot_size, bias=False)
        self.project_k = nn.Linear(self.in_features, self.slot_size, bias=False)
        self.project_v = nn.Linear(self.in_features, self.slot_size, bias=False)

        # Slot update functions. `gru` is unused by `forward`.
        self.gru = nn.GRUCell(self.slot_size, self.slot_size)
        self.mlp = nn.Sequential(
            nn.Linear(self.slot_size, self.mlp_hidden_size),
            nn.ReLU(),
            nn.Linear(self.mlp_hidden_size, self.slot_size),
        )

    def forward(self, inputs: Tensor, **kwargs):
        """Compute slots for ``inputs`` in a single attention pass.

        Args:
            inputs: tensor of shape ``[batch_size, num_inputs, in_features]``.
            **kwargs: forwarded unchanged to ``k_means_plus_plus_init`` and
                ``soft_k_means``. An optional ``vis_carrier`` entry receives
                intermediate masks / queries / keys of the first batch element
                for visualisation.

        Returns:
            Tensor whose last axis has size ``slot_size``; presumably
            ``[batch_size, num_slots, slot_size]``, assuming the selection
            matrices returned by the helpers have ``num_slots`` columns.
        """
        # `inputs` has shape [batch_size, num_inputs, inputs_size].
        batch_size, num_inputs, _ = inputs.shape
        inputs = self.norm_inputs(inputs)  # Apply layer norm to the input.
        k = self.project_k(inputs)  # Shape: [batch_size, num_inputs, slot_size].
        assert_shape(k.size(), (batch_size, num_inputs, self.slot_size))
        v = self.project_v(inputs)  # Shape: [batch_size, num_inputs, slot_size].
        assert_shape(v.size(), (batch_size, num_inputs, self.slot_size))

        vis_carrier = kwargs.get('vis_carrier', None)

        # Optionally layer-normalise the keys before clustering and attention.
        if self.params.slatn_pp_layernorm_k:
            k = self.norm_slots(k)

        n_trials = self.params.slatn_pp_n_trials

        # The clustering only decides *which* keys are combined into each
        # query, so it runs without gradient tracking.
        with torch.no_grad():
            # Replicate every batch element once per trial; trials of the same
            # element are adjacent along the leading axis.
            k_rep = einops.repeat(k, 'b n d -> (b s) n d', s=n_trials)

            # k-means++ initialisation: one selection column per centroid.
            centroid_selection_mat = k_means_plus_plus_init(self.params, k_rep, self.num_slots, dist_func=self.params.slatn_pp_dist_func, **kwargs)
            init_centroids = torch.einsum('b n k, b n d -> b k d', centroid_selection_mat, k_rep)

            # Refine with soft k-means. The returned centroids are discarded;
            # the queries are rebuilt from the selection matrix further down.
            _, converged_selection_mat, dissim_measure = soft_k_means(k_rep, init_centroids, n_iterations=self.params.slatn_pp_n_iterations, **kwargs)

            # Per batch element, keep the trial with the smallest dissimilarity.
            converged_selection_mat = einops.rearrange(converged_selection_mat, '(b s) n k -> b s n k', s=n_trials)
            dissim_measure = einops.rearrange(dissim_measure, '(b s) -> b s', s=n_trials)
            best_idcs = torch.argmin(dissim_measure, dim=-1)
            # Build the batch index on the same device as `best_idcs` so the
            # two index tensors are never on different devices.
            batch_idcs = torch.arange(batch_size, device=best_idcs.device)
            converged_selection_mat = converged_selection_mat[batch_idcs, best_idcs]

        # Recompute the queries outside `no_grad`: the selection matrix is a
        # constant, but gradients can flow into `k` (and thus `project_k`).
        q = torch.einsum('b n k, b n d -> b k d', converged_selection_mat, k)
        if vis_carrier is not None and not self.params.use_vit:
            # log histogram of project_k weights
            print(f'of the trials with dissimilarity-measures: {dissim_measure[0]}, chose {dissim_measure[0, best_idcs[0]]}')
            wandb.log({'project_k weights histogram': wandb.Histogram(self.project_k.weight.detach().cpu().numpy().flatten())})
            # Index 0 of the replicated batch is the first trial of the first batch element.
            vis_carrier.add_qk_masks(name='Centroid selection mat', mask=centroid_selection_mat[0].detach().cpu().numpy())
            vis_carrier.add_qk_masks(name='Converged selection mat', mask=converged_selection_mat[0].detach().cpu().numpy())
            vis_carrier.add_queries_keys(name='Q,K before Slot Attention softmax', queries=q[0].detach().cpu().numpy(), keys=k[0].detach().cpu().numpy())

        # Scaled dot-product attention; the softmax runs over the query (slot)
        # axis, i.e. the last axis of the [batch, num_inputs, num_queries] logits.
        attn_norm_factor = self.slot_size ** -0.5
        attn_logits = attn_norm_factor * torch.matmul(k, q.transpose(2, 1))
        attn = F.softmax(attn_logits, dim=-1)

        if self.params.slatn_use_competition:
            # Weighted mean: renormalise over the inputs axis so each slot's
            # weights sum to one; epsilon keeps the denominator away from zero.
            attn = attn + self.epsilon
            attn = attn / torch.sum(attn, dim=1, keepdim=True)

        if vis_carrier is not None:
            vis_carrier.add_qk_masks(name='Slot attention', mask=attn[0].detach().cpu().numpy())

        # Aggregate the values per slot, then apply the residual MLP.
        slots = torch.matmul(attn.transpose(1, 2), v)
        slots = slots + self.mlp(self.norm_mlp(slots))

        return slots