from utils import *

from SA_UNet import SA_UNet

import pdb
from omegaconf import OmegaConf
from einops import rearrange, repeat, reduce

class SlotAttentionImage(nn.Module):
    """Slot Attention over a flat set of input features.

    Slots are sampled from a learned Gaussian. They are then refined for
    ``num_iterations`` rounds. In each round the slots compete for every
    input, because the softmax runs over the slot axis. Each slot is then
    updated from its attention-weighted mean of the values, using a GRU cell
    followed by a residual MLP.

    Args:
        num_iterations: number of attention/update rounds; must be >= 1.
        num_slots: number of slots produced per example.
        input_size: feature dimension of the inputs.
        slot_size: dimension of each slot vector.
        mlp_hidden_size: hidden width of the residual slot MLP.
        epsilon: constant added to the attention weights before they are
            renormalised over the inputs (weighted-mean step).

    Raises:
        ValueError: if ``num_iterations`` < 1.
    """

    def __init__(self, num_iterations, num_slots,
                 input_size, slot_size, mlp_hidden_size,
                 epsilon=1e-8):
        super().__init__()

        # forward() returns the attention map of the last round, so at least
        # one round must run (otherwise it would hit an unbound local).
        if num_iterations < 1:
            raise ValueError(
                f"num_iterations must be >= 1, got {num_iterations}")

        self.num_iterations = num_iterations
        self.num_slots = num_slots
        self.input_size = input_size
        self.slot_size = slot_size
        self.mlp_hidden_size = mlp_hidden_size
        self.epsilon = epsilon

        # Parameters of the Gaussian used to sample initial slots
        # (shared by all slots): mean `mu`, std exp(log_sigma).
        self.mu = nn.Parameter(torch.Tensor(1, 1, slot_size))
        self.log_sigma = nn.Parameter(torch.Tensor(1, 1, slot_size))
        nn.init.xavier_uniform_(self.mu)
        nn.init.xavier_uniform_(self.log_sigma)

        # Norms.
        self.norm_inputs = nn.LayerNorm(input_size)
        self.norm_slots = nn.LayerNorm(self.slot_size)
        self.norm_mlp = nn.LayerNorm(self.slot_size)

        # Linear maps for the attention module.
        self.project_k = linear(input_size, slot_size, bias=False)
        self.project_q = linear(self.slot_size, self.slot_size, bias=False)
        self.project_v = linear(input_size, slot_size, bias=False)

        # Slot update functions.
        self.gru = gru_cell(slot_size, self.slot_size)
        self.mlp = nn.Sequential(
            linear(self.slot_size, mlp_hidden_size, weight_init='kaiming'),
            nn.ReLU(),
            linear(mlp_hidden_size, self.slot_size))

    def forward(self, inputs, sigma=0):
        """Run slot attention on ``inputs``.

        Args:
            inputs: tensor of shape ``[batch_size, num_inputs, input_size]``.
            sigma: unused; kept only so existing callers keep working.

        Returns:
            ``(slots, attn_vis)``:
            - ``slots`` has shape ``[batch_size, num_slots, slot_size]``.
            - ``attn_vis`` is the softmax-over-slots attention from the final
              round, with shape ``[batch_size, num_inputs, num_slots]``.
        """
        batch_size = inputs.size(0)

        # Sample initial slots: mu + exp(log_sigma) * N(0, 1).
        slots = inputs.new_empty(batch_size, self.num_slots, self.slot_size).normal_()
        slots = self.mu + torch.exp(self.log_sigma) * slots

        # Keys and values are fixed across rounds, so compute them once.
        inputs = self.norm_inputs(inputs)
        k = self.project_k(inputs)  # [batch_size, num_inputs, slot_size]
        v = self.project_v(inputs)  # [batch_size, num_inputs, slot_size]
        # Dot-product scaling applied to the keys up front.
        k = (self.slot_size ** (-0.5)) * k

        for _ in range(self.num_iterations):
            slots_prev = slots
            slots = self.norm_slots(slots)

            # Attention.
            q = self.project_q(slots)  # [batch_size, num_slots, slot_size]
            attn_logits = torch.bmm(k, q.transpose(-1, -2))  # [batch_size, num_inputs, num_slots]
            # Softmax over the slot axis: slots compete for each input.
            attn_vis = F.softmax(attn_logits, dim=-1)

            # Weighted mean: renormalise over the inputs so that each slot's
            # weights sum to one.
            attn = attn_vis + self.epsilon
            attn = attn / torch.sum(attn, dim=-2, keepdim=True)
            updates = torch.bmm(attn.transpose(-1, -2), v)  # [batch_size, num_slots, slot_size]

            # Slot update. The GRU cell takes 2-D input, so fold the slot
            # axis into the batch axis and unfold it afterwards.
            slots = self.gru(updates.reshape(-1, self.slot_size),
                             slots_prev.reshape(-1, self.slot_size))
            slots = slots.view(-1, self.num_slots, self.slot_size)

            slots = slots + self.mlp(self.norm_mlp(slots))  # [batch_size, num_slots, slot_size]

        return slots, attn_vis


class CartesianPositionalEmbedding(nn.Module):
    """Add a learned projection of a fixed 2-D coordinate grid to a feature map.

    The grid has four channels: (x, y, 1 - x, 1 - y). They are sampled at the
    centres of ``image_size`` equal bins of [0, 1] along each axis. The
    ``conv2d`` helper from ``utils`` maps the 4 grid channels to ``channels``
    channels (presumably a 1x1 convolution given the kernel argument of 1;
    confirm against utils).
    """

    def __init__(self, channels, image_size):
        super().__init__()

        self.projection = conv2d(4, channels, 1)
        # Fixed grid of shape [1, 4, image_size, image_size]. It is stored as a
        # Parameter with requires_grad=False, so it is never trained.
        self.pe = nn.Parameter(self.build_grid(image_size).unsqueeze(0), requires_grad=False)

    def build_grid(self, side_length):
        """Return a ``[4, side_length, side_length]`` coordinate grid.

        Channel 0 (x) varies along the last (width) axis and channel 1 (y)
        along the height axis. Channels 2 and 3 are ``1 - x`` and ``1 - y``.
        """
        edges = torch.linspace(0., 1., side_length + 1)
        centres = 0.5 * (edges[:-1] + edges[1:])
        # 'ij' matches meshgrid's default. Passing it explicitly avoids the
        # deprecation warning about the missing `indexing` argument.
        grid_y, grid_x = torch.meshgrid(centres, centres, indexing='ij')
        return torch.stack((grid_x, grid_y, 1 - grid_x, 1 - grid_y), dim=0)

    def forward(self, inputs):
        """Add the projected coordinate grid to ``inputs``.

        ``inputs`` is expected to be ``[batch_size, channels, H, W]``. The
        projected grid has batch dimension 1 and broadcasts over the batch.
        For the addition to succeed, its spatial size (image_size x
        image_size, assuming ``conv2d`` preserves spatial size) must match
        H and W.
        """
        return inputs + self.projection(self.pe)



class SAEncoder(nn.Module):
    """Image encoder: SA_UNet features -> positional embedding -> Slot Attention.

    Args:
        config: mapping that provides ``feat_dim``, ``SA_ch_mult``,
            ``image_shape``, ``slot_size`` and
            ``phase_param.train.num_slots``.
    """

    def __init__(self, config):
        super().__init__()

        # 3 input channels (presumably RGB images).
        self.cnn = SA_UNet(3, config['feat_dim'], config['SA_ch_mult'])

        # NOTE(review): the grid is sized image_shape[-1] // 4, which assumes
        # square images and that SA_UNet downsamples by 4 — confirm.
        self.pos = CartesianPositionalEmbedding(config['feat_dim'], config['image_shape'][-1] // 4)

        self.layer_norm = nn.LayerNorm(config['feat_dim'])

        self.mlp = nn.Sequential(
            linear(config['feat_dim'], config['feat_dim'], weight_init='kaiming'),
            nn.ReLU(),
            linear(config['feat_dim'], config['feat_dim']))

        # Uses 3 iterations and an MLP hidden size of 128. It allocates one
        # slot more than the configured num_slots (presumably an extra
        # background slot; TODO confirm).
        self.slot_attention = SlotAttentionImage(
            3, config['phase_param']['train']['num_slots'] + 1,
            config['feat_dim'], config['slot_size'], 128)

        # NOTE(review): not used by forward(); presumably consumed elsewhere.
        self.slot_proj = linear(config['slot_size'], config['slot_size'], bias=False)

    def forward(self, x):
        """Encode an image batch into slots.

        Args:
            x: image batch of shape ``[B, C, H, W]``.

        Returns:
            ``(slots, attns)``:
            - ``slots`` has shape ``[B, S, slot_size]``, where S is the
              slot-attention slot count.
            - ``attns`` is the final-round attention, reshaped to
              ``[B, S, 1, H_enc, W_enc]``.
        """
        B = x.shape[0]
        emb = self.cnn(x)      # B, feat_dim, H_enc, W_enc
        emb = self.pos(emb)    # B, feat_dim, H_enc, W_enc
        # Read the encoder resolution from the feature map itself rather than
        # re-deriving it as H // 4, W // 4. This keeps the reshape below
        # consistent with whatever the CNN actually produced.
        H_enc, W_enc = emb.shape[-2:]

        emb_set = emb.permute(0, 2, 3, 1).flatten(start_dim=1, end_dim=2)  # B, H_enc * W_enc, feat_dim
        emb_set = self.mlp(self.layer_norm(emb_set))                       # B, H_enc * W_enc, feat_dim
        slots, attns = self.slot_attention(emb_set)  # slots: B, S, slot_size; attns: B, H_enc * W_enc, S

        attns = attns.transpose(-1, -2).reshape(B, -1, 1, H_enc, W_enc)   # B, S, 1, H_enc, W_enc
        return slots, attns