from utils import *

from dvae import dVAE


class SlotAttentionVideo(nn.Module):
    """Slot attention over a clip, with object slots shared across frames.

    Each frame has its own view latent. On every iteration the shared object
    slots are paired with each frame's view latent to form the queries,
    attention is computed per frame, the object part is updated per frame by
    a GRU, and the per-frame results are averaged back into one set of
    object slots.
    """

    def __init__(self, num_iterations, num_slots, view_size,
                 input_size, slot_size, mlp_hidden_size,
                 epsilon=1e-8):
        super().__init__()

        full_size = slot_size + view_size

        # Plain hyper-parameters.
        self.num_iterations = num_iterations
        self.num_slots = num_slots
        self.view_size = view_size
        self.input_size = input_size
        self.slot_size = slot_size
        self.mlp_hidden_size = mlp_hidden_size
        self.epsilon = epsilon
        # Width of a query slot (view latent + object latent) ...
        self.full_slots_size = full_size
        # ... versus the width of the part that is actually updated.
        self.upd_slots_size = slot_size

        # Mean / log-std of the Gaussian the initial object slots are drawn
        # from; one set of statistics is shared by every slot.
        self.slot_mu = nn.Parameter(torch.Tensor(1, 1, slot_size))
        self.slot_log_sigma = nn.Parameter(torch.Tensor(1, 1, slot_size))
        for stat in (self.slot_mu, self.slot_log_sigma):
            nn.init.xavier_uniform_(stat)

        # Layer norms.
        self.norm_inputs = nn.LayerNorm(input_size)
        self.norm_slots = nn.LayerNorm(full_size)
        self.norm_mlp = nn.LayerNorm(slot_size)

        # Attention projections; queries, keys and values all live in the
        # joint (view + object) space.
        self.project_q = linear(full_size, full_size, bias=False)
        self.project_k = linear(input_size, full_size, bias=False)
        self.project_v = linear(input_size, full_size, bias=False)

        # Slot update: GRU from the joint-space update to the object latent,
        # followed by a residual MLP.
        self.gru = gru_cell(full_size, slot_size)
        self.mlp = nn.Sequential(
            linear(slot_size, mlp_hidden_size, weight_init='kaiming'),
            nn.ReLU(),
            linear(mlp_hidden_size, slot_size),
        )

    def forward(self, slots_view, inputs):
        """Run the iterative attention.

        slots_view: [B, T, view_size] per-frame view latents.
        inputs:     [B, T, num_inputs, input_size] per-frame feature sets.

        Returns ``(slots_obj, attn_vis)`` where ``slots_obj`` is
        [B, num_slots, slot_size] and ``attn_vis`` is the last iteration's
        softmax over slots, shaped [B, T, num_inputs, num_slots].
        """
        batch, frames, _, _ = inputs.size()

        # Draw the initial object slots from the learned Gaussian.
        noise = inputs.new_empty(batch, self.num_slots, self.slot_size).normal_()
        slots_obj = self.slot_mu + torch.exp(self.slot_log_sigma) * noise

        # Keys and values are fixed for all iterations.
        feats = self.norm_inputs(inputs)
        keys = self.project_k(feats)    # [B, T, num_inputs, full_slots_size]
        values = self.project_v(feats)  # [B, T, num_inputs, full_slots_size]
        keys = (self.full_slots_size ** (-0.5)) * keys

        # The view latent does not change between iterations, so broadcast it
        # to every slot once: [B, T, num_slots, view_size].
        view_per_slot = slots_view[:, :, None].expand(-1, -1, self.num_slots, -1)

        final_iter = self.num_iterations - 1
        for it in range(self.num_iterations):
            # Copy the shared object slots to every frame:
            # [B, T, num_slots, slot_size].
            obj_per_frame = slots_obj[:, None].expand(-1, frames, -1, -1)

            # Queries from the normalised (view, object) pair.
            joint = torch.cat([view_per_slot, obj_per_frame], dim=-1)
            queries = self.project_q(self.norm_slots(joint))  # [B, T, num_slots, full_slots_size]

            # Slots compete for each input location (softmax over slots).
            logits = torch.matmul(keys, queries.transpose(-1, -2))
            attn_vis = F.softmax(logits, dim=-1)  # [B, T, num_inputs, num_slots]

            # Weighted mean of the values: renormalise over input locations.
            weights = attn_vis + self.epsilon
            weights = weights / torch.sum(weights, dim=-2, keepdim=True)
            updates = torch.matmul(weights.transpose(-1, -2), values)  # [B, T, num_slots, full_slots_size]

            # GRU update of the object latent, independently per frame/slot.
            hidden = self.gru(updates.reshape(-1, self.full_slots_size),
                              obj_per_frame.reshape(-1, self.upd_slots_size))
            hidden = hidden.reshape(batch, frames, self.num_slots, self.upd_slots_size)

            # Residual MLP on every iteration except the last one.
            if it < final_iter:
                hidden = hidden + self.mlp(self.norm_mlp(hidden))

            # Collapse the per-frame estimates into one set of object slots.
            slots_obj = hidden.mean(1)

        return slots_obj, attn_vis


class LearnedPositionalEmbedding1D(nn.Module):
    """Learned additive positional embedding along a sequence axis, followed by dropout."""

    def __init__(self, num_inputs, input_size, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # One trainable vector per position, with a leading singleton axis so
        # it broadcasts over the batch. Initialised from a truncated normal.
        table = torch.zeros(1, num_inputs, input_size)
        self.pe = nn.Parameter(table, requires_grad=True)
        nn.init.trunc_normal_(self.pe)

    def forward(self, input, offset=0):
        """
        input: batch_size x seq_len x d_model
        offset: index of the first positional vector to use
        return: batch_size x seq_len x d_model
        """
        seq_len = input.shape[1]
        window = self.pe[:, offset:offset + seq_len]
        return self.dropout(input + window)


class CartesianPositionalEmbedding(nn.Module):
    """Adds a projected, fixed Cartesian coordinate grid to a feature map.

    The grid has four channels (x, y, 1 - x, 1 - y) sampled at cell centres
    in [0, 1]. It is stored as a non-trainable parameter and mapped to
    ``channels`` feature channels by the ``conv2d`` helper (called with
    kernel size 1).
    """

    def __init__(self, channels, image_size):
        super().__init__()

        self.projection = conv2d(4, channels, 1)
        # Shape [1, 4, image_size, image_size]; requires_grad=False keeps the
        # grid fixed while it still moves with the module across devices.
        self.pe = nn.Parameter(self.build_grid(image_size).unsqueeze(0), requires_grad=False)

    def build_grid(self, side_length):
        """Return a [4, side_length, side_length] grid of cell-centre coordinates.

        Channel order is (x, y, 1 - x, 1 - y); x varies along the last axis
        and y along the second-to-last axis.
        """
        # Cell edges -> cell centres, e.g. side_length=2 gives [0.25, 0.75].
        coords = torch.linspace(0., 1., side_length + 1)
        coords = 0.5 * (coords[:-1] + coords[1:])
        # Matrix ('ij') indexing is requested explicitly: it is what the
        # (grid_y, grid_x) unpacking assumes, and omitting the argument warns
        # because PyTorch plans to change the default.
        grid_y, grid_x = torch.meshgrid(coords, coords, indexing='ij')
        return torch.stack((grid_x, grid_y, 1 - grid_x, 1 - grid_y), dim=0)

    def forward(self, inputs):
        # `inputs` has shape: [batch_size, channels, height, width].
        # `self.pe` has shape: [1, 4, height, width]; its projection is
        # broadcast over the batch dimension by the addition.
        return inputs + self.projection(self.pe)


class LORMEncoder(nn.Module):
    """Holds the encoder-side sub-networks used by LORM.

    This class only constructs modules and defines no ``forward`` of its
    own. Sub-modules are built in a fixed order: ``cnn``, ``pos``,
    ``layer_norm``, ``mlp``, ``savi``, ``slot_combine``, ``slot_proj``,
    ``enc_viewattr``, ``mlp_view``.
    """

    def __init__(self, args):
        super().__init__()

        d_model = args.d_model
        joint_size = args.slot_size + args.view_size

        # 64-pixel inputs keep full resolution; any other size gets a
        # stride-2 first conv, so the positional grid is half as large.
        keep_resolution = args.image_size == 64
        first_stride = 1 if keep_resolution else 2
        grid_side = args.image_size if keep_resolution else args.image_size // 2

        # Object branch: conv features -> positional embedding -> LN -> MLP.
        self.cnn = self._conv_backbone(args, first_stride)
        self.pos = CartesianPositionalEmbedding(d_model, grid_side)
        self.layer_norm = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            linear(d_model, d_model, weight_init='kaiming'),
            nn.ReLU(),
            linear(d_model, d_model),
        )

        self.savi = SlotAttentionVideo(
            args.num_iterations, args.num_slots, args.view_size,
            d_model, args.slot_size, args.mlp_hidden_size,
        )

        # Mix the concatenated (view, object) latent, then project to d_model.
        self.slot_combine = linear(joint_size, joint_size, bias=False)
        self.slot_proj = linear(joint_size, d_model, bias=False)

        # View branch: same conv layout as `cnn`, separate weights.
        self.enc_viewattr = self._conv_backbone(args, first_stride)
        self.mlp_view = nn.Sequential(
            linear(d_model, d_model, weight_init='kaiming'),
            nn.ReLU(),
            linear(d_model, args.view_size),
        )

    @staticmethod
    def _conv_backbone(args, first_stride):
        """Build the conv stack shared in layout by `cnn` and `enc_viewattr`."""
        hidden = args.cnn_hidden_size
        return nn.Sequential(
            Conv2dBlock(args.img_channels, hidden, 5, first_stride, 2),
            Conv2dBlock(hidden, hidden, 5, 1, 2),
            Conv2dBlock(hidden, hidden, 5, 1, 2),
            conv2d(hidden, args.d_model, 5, 1, 2),
        )


class PatchDecoder(nn.Module):
    """Holds the decoder-side sub-networks used by LORM (no ``forward`` of its own)."""

    def __init__(self, args):
        super().__init__()

        width = args.d_model

        # MLP built with sizes (d_model, d_model + 1) and three hidden widths
        # of 1024. LORM splits its output into d_model features plus one
        # extra channel that it softmaxes over slots.
        self.decoder = build_mlp(width, width + 1, features=[1024] * 3)

        # Learned embedding with one vector per patch position.
        self.pos = LearnedPositionalEmbedding1D(args.num_patches, width)

        # Bias-free projection from d_model features to vocabulary logits.
        self.head = linear(width, args.vocab_size, bias=False)


class LORM(nn.Module):
    """Video model built from a dVAE, a slot/view encoder and a per-slot patch decoder.

    `forward` returns the training losses and visualisations, `encode` maps a
    video to object slots and per-frame view latents, `decode` maps those
    latents back to frames through the dVAE decoder, and `reconstruct` chains
    `encode` and `decode`.
    """
    
    def __init__(self, args):
        """Copy hyper-parameters from `args` and build the dVAE, encoder and decoder."""
        super().__init__()
        
        self.num_iterations = args.num_iterations
        self.num_slots = args.num_slots
        self.cnn_hidden_size = args.cnn_hidden_size
        self.slot_size = args.slot_size
        self.mlp_hidden_size = args.mlp_hidden_size
        self.img_channels = args.img_channels
        self.image_size = args.image_size
        self.vocab_size = args.vocab_size
        self.d_model = args.d_model
        self.num_patches = args.num_patches
        self.view_size = args.view_size

        # dvae
        self.dvae = dVAE(args.vocab_size, args.img_channels)

        # encoder networks
        self.lorm_encoder = LORMEncoder(args)

        # decoder networks
        self.lorm_decoder = PatchDecoder(args)

    def forward(self, video, tau, hard):
        """Compute the training losses and visualisations for a batch of videos.

        video: B x T x C x H x W
        tau, hard: passed straight to `gumbel_softmax` (presumably the
            temperature and the hard-sample flag -- confirm against utils).

        Returns a tuple of:
            dvae_recon    -- dVAE reconstruction clamped to [0, 1], B x T x C x H x W
            cross_entropy -- scalar; token cross-entropy summed over patches
                             and vocabulary, divided by B * T
            dvae_mse      -- scalar; squared error summed over all elements,
                             divided by B * T
            attns         -- video masked by encoder attention, B x T x num_slots x C x H x W
            masks         -- video masked by decoder alpha, B x T x num_slots x C x H x W
        """
        B, T, C, H, W = video.size()

        video_flat = video.flatten(end_dim=1)                               # B * T, C, H, W

        # dvae encode
        # NOTE(review): the token-grid size below comes from the dVAE encoder, which is not
        # defined in this file; `h, w` denote its output grid -- confirm against dvae.py.
        z_logits = F.log_softmax(self.dvae.encoder(video_flat), dim=1)       # B * T, vocab_size, h, w
        z_soft = gumbel_softmax(z_logits, tau, hard, dim=1)                  # B * T, vocab_size, h, w
        z_hard = gumbel_softmax(z_logits, tau, True, dim=1).detach()         # B * T, vocab_size, h, w (no gradient)
        z_hard = z_hard.permute(0, 2, 3, 1).flatten(start_dim=1, end_dim=2)  # B * T, h * w, vocab_size

        # Encode Viewpoint attribute
        view_feat = self.lorm_encoder.enc_viewattr(video_flat)   # B * T, d_model, H_enc, W_enc
        # Average over all spatial positions, then map to the view latent: B, T, view_size
        slots_view = self.lorm_encoder.mlp_view(view_feat.reshape(*view_feat.shape[:2],-1).mean(-1)).reshape(B, T, -1)
        
        # savi
        emb = self.lorm_encoder.cnn(video_flat)      # B * T, d_model, H_enc, W_enc
        emb = self.lorm_encoder.pos(emb)             # B * T, d_model, H_enc, W_enc
        H_enc, W_enc = emb.shape[-2:]

        emb_set = emb.permute(0, 2, 3, 1).flatten(start_dim=1, end_dim=2)                                   # B * T, H_enc * W_enc, d_model
        emb_set = self.lorm_encoder.mlp(self.lorm_encoder.layer_norm(emb_set))                            # B * T, H_enc * W_enc, d_model
        emb_set = emb_set.reshape(B, T, H_enc * W_enc, self.d_model)                                        # B, T, H_enc * W_enc, d_model
        
        slots_obj, attns = self.lorm_encoder.savi(slots_view, emb_set)         # slots_obj: B, num_slots, slot_size (shared across frames)
                                                                # attns: B, T, num_inputs, num_slots (num_inputs = H_enc * W_enc)
                                   
        # Put slots first, unflatten the spatial axis and upsample to image resolution by repetition.
        attns = attns\
            .transpose(-1, -2)\
            .reshape(B, T, self.num_slots, 1, H_enc, W_enc)\
            .repeat_interleave(H // H_enc, dim=-2)\
            .repeat_interleave(W // W_enc, dim=-1)          # B, T, num_slots, 1, H, W
        # Blend towards 1.0 where the attention is low.
        attns = video.unsqueeze(2) * attns + (1. - attns)   # B, T, num_slots, C, H, W

        # decode

        # Combine the view and objs attributes 
        # (per-frame view latent paired with every object slot; object slots repeated for every frame)
        slots = torch.cat(
            [
                slots_view[:, :, None].expand(-1, -1, self.num_slots, -1),
                slots_obj[:, None].expand(-1, T, -1, -1)
            ],dim=-1)                                      # B, T, num_slots, view_size + slot_size
        slots = self.lorm_encoder.slot_combine(slots)

        slots = self.lorm_encoder.slot_proj(slots)         # B, T, num_slots, d_model

        slots_flat = slots.flatten(0, -2)   # B * T * num_slots, d_model
        slots_flat = slots_flat.unsqueeze(1).expand(-1, self.num_patches, -1)   # B * T * num_slots, num_patches, d_model

        # Simple learned additive embedding as in ViT
        slots_flat = self.lorm_decoder.pos(slots_flat)  # B * T * num_slots, num_patches, d_model

        output = self.lorm_decoder.decoder(slots_flat) # B * T * num_slots, num_patches, d_model + 1
        output = output.reshape(B * T, self.num_slots, self.num_patches, -1) # B * T, num_slots, num_patches, d_model + 1

        # Split out alpha channel and normalize over slots.
        decoded_patches, alpha = output.split([self.d_model, 1], dim=-1) # B * T, num_slots, num_patches, d_model / 1
        alpha = alpha.softmax(dim=-3)   # B * T, num_slots, num_patches, 1

        # Alpha-weighted mixture of the per-slot features, then project to token logits.
        pred = torch.sum(decoded_patches * alpha, dim=-3) # B * T, num_patches, d_model
        pred = self.lorm_decoder.head(pred)    # B * T, num_patches, vocab_size

        # The patch grid is hard-coded to image_size // 4 per side; the reshape below
        # therefore requires num_patches == (image_size // 4) ** 2.
        H_enc, W_enc = (self.image_size // 4), (self.image_size // 4)

        masks = alpha.squeeze(-1)   # B * T, num_slots, num_patches
        masks = masks\
            .reshape(B, T, self.num_slots, 1, H_enc, W_enc)\
            .repeat_interleave(H // H_enc, dim=-2)\
            .repeat_interleave(W // W_enc, dim=-1)          # B, T, num_slots, 1, H, W
        masks = video.unsqueeze(2) * masks + (1. - masks)                               # B, T, num_slots, C, H, W

        # Summed over patches and vocabulary, averaged over the B * T frames.
        # Assumes the dVAE token grid has num_patches cells so z_hard lines up with pred -- TODO confirm.
        cross_entropy = -(z_hard * torch.log_softmax(pred, dim=-1)).sum() / (B * T)                         # scalar

        # dvae recon
        dvae_recon = self.dvae.decoder(z_soft).reshape(B, T, C, H, W)               # B, T, C, H, W
        dvae_mse = ((video - dvae_recon) ** 2).sum() / (B * T)                      # scalar

        return (dvae_recon.clamp(0., 1.),
                cross_entropy,
                dvae_mse,
                attns,
                masks)

    def encode(self, video):
        """Encode a video into object slots and per-frame view latents.

        video: B x T x C x H x W

        Returns a tuple of:
            slots_obj  -- B x num_slots x slot_size, shared across frames
            slots_view -- B x T x view_size
            attns_vis  -- video masked by attention, B x T x num_slots x C x H x W
            attns      -- upsampled attention maps, B x T x num_slots x 1 x H x W
        """
        B, T, C, H, W = video.size()

        video_flat = video.flatten(end_dim=1)        # B * T, C, H, W

        # savi
        emb = self.lorm_encoder.cnn(video_flat)      # B * T, d_model, H_enc, W_enc
        emb = self.lorm_encoder.pos(emb)             # B * T, d_model, H_enc, W_enc
        H_enc, W_enc = emb.shape[-2:]

        emb_set = emb.permute(0, 2, 3, 1).flatten(start_dim=1, end_dim=2)                                   # B * T, H_enc * W_enc, d_model
        emb_set = self.lorm_encoder.mlp(self.lorm_encoder.layer_norm(emb_set))                            # B * T, H_enc * W_enc, d_model
        emb_set = emb_set.reshape(B, T, H_enc * W_enc, self.d_model)                                                # B, T, H_enc * W_enc, d_model

        # Encode Viewpoint attribute
        view_feat = self.lorm_encoder.enc_viewattr(video_flat)   # B * T, d_model, H_enc, W_enc
        view_feat_re = view_feat.reshape(*view_feat.shape[:2], -1)   # B * T, d_model, H_enc * W_enc
        slots_view = self.lorm_encoder.mlp_view(view_feat_re.mean(-1))   # spatial average -> B * T, view_size
        slots_view = slots_view.reshape(B, T, -1)                        # B, T, view_size

        slots_obj, attns = self.lorm_encoder.savi(slots_view, emb_set)     # slots_obj: B, num_slots, slot_size
                                                            # attns: B, T, num_inputs, num_slots (num_inputs = H_enc * W_enc)

        # Put slots first, unflatten the spatial axis and upsample to image resolution by repetition.
        attns = attns \
            .transpose(-1, -2) \
            .reshape(B, T, self.num_slots, 1, H_enc, W_enc) \
            .repeat_interleave(H // H_enc, dim=-2) \
            .repeat_interleave(W // W_enc, dim=-1)                      # B, T, num_slots, 1, H, W

        attns_vis = video.unsqueeze(2) * attns + (1. - attns)           # B, T, num_slots, C, H, W

        return slots_obj, slots_view, attns_vis, attns

    def decode(self, slots_obj, slots_view):
        """Decode object slots and view latents into frames via argmax tokens and the dVAE decoder.

        slots_obj:  B x num_slots x slot_size
        slots_view: B x T x view_size

        Returns a tuple of:
            gen_mlp   -- dVAE decoding of the mixed prediction, clamped to [0, 1];
                         leading dimension is B * T (not reshaped here)
            masks_vis -- per-slot dVAE decodings blended with the alpha masks,
                         B x T x num_slots x C x H x W
            masks     -- upsampled alpha masks, B x T x num_slots x 1 x H x W
        """
        B, T, _ = slots_view.size()

        H, W = self.image_size, self.image_size
        # The patch grid is hard-coded to image_size // 4 per side; the reshapes below
        # therefore require num_patches == (image_size // 4) ** 2.
        H_enc, W_enc = (self.image_size // 4), (self.image_size // 4)

        # Combine the view and objs attributes 
        slots = torch.cat(
            [
                slots_view[:, :, None].expand(-1, -1, self.num_slots, -1),
                slots_obj[:, None].expand(-1, T, -1, -1)
            ], dim=-1)                                 # B, T, num_slots, view_size + slot_size
        slots = self.lorm_encoder.slot_combine(slots)

        # decode
        slots = self.lorm_encoder.slot_proj(slots) # B, T, num_slots, d_model

        slots_flat = slots.flatten(0, -2)   # B * T * num_slots, d_model
        slots_flat = slots_flat.unsqueeze(1).expand(-1, self.num_patches, -1)   # B * T * num_slots, num_patches, d_model

        # Simple learned additive embedding as in ViT
        slots_flat = self.lorm_decoder.pos(slots_flat) # B * T * num_slots, num_patches, d_model

        output = self.lorm_decoder.decoder(slots_flat) # B * T * num_slots, num_patches, d_model + 1
        output = output.reshape(B * T, self.num_slots, self.num_patches, -1) # B * T, num_slots, num_patches, d_model + 1

        # Split out alpha channel and normalize over slots.
        decoded_patches, alpha = output.split([self.d_model, 1], dim=-1) # B * T, num_slots, num_patches, d_model / 1
        alpha = alpha.softmax(dim=-3)   # B * T, num_slots, num_patches, 1

        # Mixed prediction: alpha-weighted features -> logits -> one-hot of the argmax token.
        pred = torch.sum(decoded_patches * alpha, dim=-3) # B * T, num_patches, d_model
        pred = F.one_hot(self.lorm_decoder.head(pred).argmax(dim=-1), self.vocab_size)   # B * T, num_patches, vocab_size
        pred = pred.permute(0, 2, 1).float().reshape(B * T, self.vocab_size, H_enc, W_enc)   # B * T, vocab_size, H_enc, W_enc

        gen_mlp = self.dvae.decoder(pred)

        masks = alpha.squeeze(-1)   # B * T, num_slots, num_patches
        masks = masks\
            .reshape(B, T, self.num_slots, 1, H_enc, W_enc)\
            .repeat_interleave(H // H_enc, dim=-2)\
            .repeat_interleave(W // W_enc, dim=-1)          # B, T, num_slots, 1, H, W
        
        # Per-slot prediction: each slot's features are decoded on their own (no alpha mixing).
        decoded_patches = decoded_patches.flatten(end_dim=1)   # B * T * num_slots, num_patches, d_model
        slots_pred = F.one_hot(self.lorm_decoder.head(decoded_patches).argmax(dim=-1), self.vocab_size)    # B * T * num_slots, num_patches, vocab_size
        slots_pred = slots_pred.permute(0, 2, 1).float().reshape(B * T * self.num_slots, self.vocab_size, H_enc, W_enc) # B * T * num_slots, vocab_size, H_enc, W_enc

        masks_vis = self.dvae.decoder(slots_pred).reshape(B, T, self.num_slots, -1, H, W)
        masks_vis = masks_vis.clamp(0., 1.)

        masks_vis = masks_vis * masks + (1. - masks)           # B, T, num_slots, C, H, W

        return gen_mlp.clamp(0., 1.), masks_vis, masks

    def reconstruct(self, video):
        """
        Encode a video and decode it again.

        video: batch_size x T x img_channels x H x W
        return: batch_size x T x img_channels x H x W, values in [0, 1]
        """
        B, T, C, H, W = video.size()
        slots_obj,slots_view, _, _ = self.encode(video)
        recon_mlp, _, _ = self.decode(slots_obj, slots_view)
        recon_mlp = recon_mlp.reshape(B, T, C, H, W)

        return recon_mlp
