# Taken from https://github.com/singhgautam/slate/blob/master/slot_attn.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from .networks import linear, gru_cell


class SlotAttention(nn.Module):
    """Multi-head Slot Attention (adapted from SLATE).

    For each iteration, every input distributes its attention with a softmax
    taken jointly over all (head, slot) pairs, so slots compete for inputs.
    Each slot receives the attention-weighted mean of the values, is updated
    by a GRU cell from its previous state, and is refined by a residual MLP.

    Args:
        num_iterations: Number of attention/update rounds; must be >= 1.
        num_slots: Default slot count (stored on the module; ``forward`` uses
            the slot count of the ``slots`` tensor it is given).
        input_size: Feature dimension of ``inputs``.
        slot_size: Feature dimension of each slot; must be divisible by
            ``heads``.
        mlp_hidden_size: Hidden width of the residual slot MLP.
        heads: Number of attention heads.
        epsilon: Added to the attention weights before renormalizing over
            inputs, so a slot that receives no attention does not divide by 0.

    Raises:
        ValueError: If ``num_iterations < 1`` or ``slot_size`` is not
            divisible by ``heads``.
    """

    def __init__(
        self,
        num_iterations,
        num_slots,
        input_size,
        slot_size,
        mlp_hidden_size,
        heads,
        epsilon=1e-8,
    ):
        super().__init__()

        # Both conditions would otherwise only fail later inside forward()
        # (unbound `attn_vis` / an invalid per-head view).
        if num_iterations < 1:
            raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
        if slot_size % heads != 0:
            raise ValueError(
                f"slot_size ({slot_size}) must be divisible by heads ({heads})"
            )

        self.num_iterations = num_iterations
        self.num_slots = num_slots
        self.input_size = input_size
        self.slot_size = slot_size
        self.mlp_hidden_size = mlp_hidden_size
        self.epsilon = epsilon
        self.num_heads = heads

        self.norm_inputs = nn.LayerNorm(input_size)
        self.norm_slots = nn.LayerNorm(slot_size)
        self.norm_mlp = nn.LayerNorm(slot_size)

        # Linear maps for the attention module.
        self.project_q = linear(slot_size, slot_size, bias=False)
        self.project_k = linear(input_size, slot_size, bias=False)
        self.project_v = linear(input_size, slot_size, bias=False)

        # Slot update functions. forward() calls self.gru, so it must exist.
        self.gru = gru_cell(slot_size, slot_size)

        self.mlp = nn.Sequential(
            linear(slot_size, mlp_hidden_size, weight_init="kaiming"),
            nn.ReLU(),
            linear(mlp_hidden_size, slot_size),
        )

    def forward(self, inputs, slots):
        """Run iterative slot attention.

        Args:
            inputs: Tensor of shape [batch_size, num_inputs, input_size].
            slots: Initial slots, shape [batch_size, num_slots, slot_size].

        Returns:
            ``(slots, attn_vis)``: the updated slots, shape
            [batch_size, num_slots, slot_size], and the last iteration's
            attention summed over heads, shape [batch_size, num_inputs,
            num_slots] (taken before the epsilon/per-slot renormalization).
        """
        B, N_kv, _ = inputs.size()
        _, N_q, _ = slots.size()

        inputs = self.norm_inputs(inputs)
        # Shape: [batch_size, num_heads, num_inputs, slot_size // num_heads].
        k = self.project_k(inputs).view(B, N_kv, self.num_heads, -1).transpose(1, 2)
        v = self.project_v(inputs).view(B, N_kv, self.num_heads, -1).transpose(1, 2)
        # Scaled dot-product attention: scale by 1/sqrt(head_dim).
        k = ((self.slot_size // self.num_heads) ** (-0.5)) * k

        # Multiple rounds of attention.
        for _ in range(self.num_iterations):
            slots_prev = slots
            slots = self.norm_slots(slots)

            # Shape: [batch_size, num_heads, num_slots, slot_size // num_heads].
            q = self.project_q(slots).view(B, N_q, self.num_heads, -1).transpose(1, 2)
            # Shape: [batch_size, num_heads, num_inputs, num_slots].
            attn_logits = torch.matmul(k, q.transpose(-1, -2))
            # Softmax jointly over all (head, slot) pairs for each input, then
            # restore [batch_size, num_heads, num_inputs, num_slots].
            attn = (
                F.softmax(
                    attn_logits.transpose(1, 2).reshape(B, N_kv, self.num_heads * N_q),
                    dim=-1,
                )
                .view(B, N_kv, self.num_heads, N_q)
                .transpose(1, 2)
            )
            attn_vis = attn.sum(1)  # Shape: [batch_size, num_inputs, num_slots].

            # Weighted mean: normalize each slot's weights over the inputs.
            attn = attn + self.epsilon
            attn = attn / torch.sum(attn, dim=-2, keepdim=True)
            # Shape: [batch_size, num_heads, num_slots, slot_size // num_heads].
            updates = torch.matmul(attn.transpose(-1, -2), v)
            # Merge heads -> [batch_size, num_slots, slot_size].
            updates = updates.transpose(1, 2).reshape(B, N_q, self.slot_size)

            # Slot update: the GRU works on flat [batch * slots, slot_size] rows.
            # reshape (not view) tolerates non-contiguous caller-provided slots.
            slots = self.gru(
                updates.reshape(-1, self.slot_size),
                slots_prev.reshape(-1, self.slot_size),
            )
            # Use the actual slot count N_q rather than self.num_slots.
            slots = slots.reshape(B, N_q, self.slot_size)
            slots = slots + self.mlp(self.norm_mlp(slots))

        return slots, attn_vis


class SlotAttentionEncoder(nn.Module):
    """Encode a sequence of feature tokens into a set of slots.

    Input tokens are layer-normalized and passed through a two-layer MLP, then
    grouped into slots by :class:`SlotAttention`. When no initial slots are
    supplied, they are sampled from a learned diagonal Gaussian whose mean and
    log-std are shared by all slots.

    Args:
        num_iterations: Number of slot attention iterations.
        num_slots: Number of slots sampled when ``slots`` is not given.
        input_channels: Feature dimension of each input token.
        slot_size: Dimension of each slot.
        mlp_hidden_size: Hidden width of the slot MLP inside SlotAttention.
        pos_channels: Stored as an attribute; not used by this class itself.
        num_heads: Number of attention heads in SlotAttention.
    """

    def __init__(
        self,
        num_iterations,
        num_slots,
        input_channels,
        slot_size,
        mlp_hidden_size,
        pos_channels,
        num_heads,
    ):
        super().__init__()

        self.num_iterations = num_iterations
        self.num_slots = num_slots
        self.input_channels = input_channels
        self.slot_size = slot_size
        self.mlp_hidden_size = mlp_hidden_size
        self.pos_channels = pos_channels

        # Per-token preprocessing applied to the last (channel) dimension.
        self.layer_norm = nn.LayerNorm(input_channels)
        self.mlp = nn.Sequential(
            linear(input_channels, input_channels, weight_init="kaiming"),
            nn.ReLU(),
            linear(input_channels, input_channels),
        )

        # Parameters for Gaussian init (shared by all slots).
        self.slot_mu = nn.Parameter(torch.zeros(1, 1, slot_size))
        self.slot_log_sigma = nn.Parameter(torch.zeros(1, 1, slot_size))
        nn.init.xavier_uniform_(self.slot_mu)
        nn.init.xavier_uniform_(self.slot_log_sigma)

        self.slot_attention = SlotAttention(
            num_iterations,
            num_slots,
            input_channels,
            slot_size,
            mlp_hidden_size,
            num_heads,
        )

    def forward(self, x, slots=None):
        """Group input tokens into slots.

        Args:
            x: Input tokens of shape [batch_size, num_inputs, input_channels]
                (SlotAttention unpacks exactly three dimensions).
            slots: Optional initial slots of shape
                [batch_size, num_slots, slot_size]. If None, they are sampled
                as ``slot_mu + exp(slot_log_sigma) * N(0, I)``.

        Returns:
            ``(slots, attn)`` where ``slots`` has shape
            [batch_size, num_slots, slot_size] and ``attn`` has shape
            [batch_size, num_inputs, num_slots].
        """
        batch_size = x.size(0)
        x = self.mlp(self.layer_norm(x))

        if slots is None:
            # Reparameterized sample; created on x's device/dtype.
            noise = x.new_empty(batch_size, self.num_slots, self.slot_size).normal_()
            slots = self.slot_mu + torch.exp(self.slot_log_sigma) * noise

        slots, attn = self.slot_attention(x, slots)
        return slots, attn
