"""
    Slot-cuboid-based attention model for affordance learning.
"""
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from slot_attention import SlotAttention

# helper functions
class BuildGrid(nn.Module):
    """Dense 3-D coordinate grid spanning [-0.5, 0.5] along each axis.

    ``forward(expand_shape)`` returns a tensor of shape ``(*expand_shape, 3)``
    whose last axis holds the (axis-0, axis-1, axis-2) coordinate of every
    cell. ``expand_shape`` must be 4-D, ``(batch, *resolution)``, so that each
    axis' ticks broadcast into positions 1, 2 and 3.
    """
    def __init__(self, resolution):
        super().__init__()

        def fixed_axis(steps):
            # evenly spaced ticks, registered as a non-trainable parameter
            return nn.Parameter(torch.linspace(-0.5, 0.5, steps), requires_grad=False)

        self.coord1 = fixed_axis(resolution[0])
        self.coord2 = fixed_axis(resolution[1])
        self.coord3 = fixed_axis(resolution[2])

    def forward(self, expand_shape):
        """Broadcast the three axis ticks to ``expand_shape`` and stack them last."""
        grids = []
        for axis, ticks in enumerate((self.coord1, self.coord2, self.coord3)):
            # position 0 is the batch axis; this axis' ticks go to position axis + 1
            view_shape = [1, 1, 1, 1]
            view_shape[axis + 1] = -1
            grids.append(ticks.view(*view_shape).expand(expand_shape))
        return torch.stack(grids, dim=-1)

class SoftPositionEmbed(nn.Module):
    """Adds soft positional embedding with learnable projection.

    The (x, y, z) coordinate of every cell of a 3-D grid is projected by a
    linear layer to the feature width and added to the feature volume.
    """
    def __init__(self, resolution, hid_dim=64):
        """Builds the soft position embedding layer.

        Args:
          resolution: Sequence of three ints, the grid size along each of the
            three spatial axes of the feature volume.
          hid_dim: Number of feature channels; the 3-D coordinates are
            linearly projected to this width. Defaults to 64.
        """
        super().__init__()
        self.resolution = resolution
        self.soft_embed = nn.Linear(3, hid_dim)
        self.build_grid = BuildGrid(resolution)

    def forward(self, x):
        """Add projected grid coordinates to a channels-first feature volume.

        Args:
          x: Tensor of shape (batch_size, hid_dim, *resolution).

        Returns:
          Channels-last tensor of shape (batch_size, *resolution, hid_dim).
        """
        # grid coordinates for every cell: [batch_size, *resolution, 3]
        coords = self.build_grid(x[:, 0].shape)
        # project to the feature width: [batch_size, *resolution, hid_dim]
        coords = self.soft_embed(coords)
        # move channels last so they align with the projected coordinates
        # [batch_size, *resolution, hid_dim]
        x = x.permute(0, 2, 3, 4, 1) + coords
        return x

def quat2mat(quat):
    """Compute rotation matrices from quaternions.

    Args:
        quat: Tensor of shape (..., 4) holding quaternions in (w, x, y, z)
            order, e.g. (batch_size, num_cuboid, 4). The quaternions are
            assumed to be unit-norm: the diagonal terms use
            w^2 + x^2 + y^2 + z^2 == 1 to drop a constant, so non-normalized
            input does not give a proper rotation.

    Returns:
        Tensor of shape (..., 3, 3), one rotation matrix per quaternion.
    """
    lead_shape = quat.shape[:-1]
    # flatten all leading dims so the formula works on any batch layout
    w, x, y, z = quat.reshape(-1, 4).unbind(dim=1)
    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w*x, w*y, w*z
    xy, xz, yz = x*y, x*z, y*z
    # The matrix can be simplified since the norm of a rotation quaternion is 1.
    rot_mat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
                           2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
                           2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1)
    return rot_mat.reshape(*lead_shape, 3, 3)



class CuboidPred(nn.Module):
    """Cuboid parameter prediction branch.

    Regresses a scale and a rotation for every slot. No translation head is
    needed: the cuboid center can be computed as a weighted sum of voxel
    coordinates instead.
    """
    def __init__(self):
        super().__init__()
        # shared pointwise (kernel_size=1) MLP over the slot axis: 64 -> 256 -> 128
        self.conv_cuboid = nn.Sequential(
            nn.Conv1d(64, 256, kernel_size=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv1d(256, 128, kernel_size=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )
        # heads: 3 scale values and 4 quaternion components per slot
        self.conv_scale = nn.Conv1d(128, 3, kernel_size=1)
        self.conv_rotate = nn.Conv1d(128, 4, kernel_size=1)
        # start the quaternion head at (1, 0, 0, 0), the identity rotation
        with torch.no_grad():
            self.conv_rotate.bias.copy_(torch.tensor([1.0, 0.0, 0.0, 0.0]))

    def forward(self, embed):
        """Predict per-slot cuboid scale and rotation.

        Args:
            embed: channels-first slot features, (batch_size, 64, num_cuboid).

        Returns:
            scale: (batch_size, num_cuboid, 3), squashed into (0, 1) by a sigmoid.
            rotate: (batch_size, num_cuboid, 3, 3) rotation matrices built from
                L2-normalized quaternions.
        """
        hidden = self.conv_cuboid(embed)
        scale = torch.sigmoid(self.conv_scale(hidden).transpose(2, 1))
        quat = F.normalize(self.conv_rotate(hidden).transpose(2, 1), dim=2, p=2)
        return scale, quat2mat(quat)



class SlotCuboidVox(nn.Module):
    """Slot cuboid attention module for affordance learning in voxels.

    A 3-D CNN encodes the voxel grid into a 4x4x4 feature volume, slot
    attention groups the 4*4*4 feature cells into ``num_slots`` slot
    embeddings, and a spatial-broadcast decoder reconstructs each slot's
    voxels together with an alpha mask (softmax-normalized over slots).
    Optional heads predict per-slot affordance scores (``with_afford``) and
    fit one oriented cuboid per slot (``with_cuboid``).

    Args:
        resolution: size of the voxel grid. Only used to build the coordinate
            grid of the cuboid branch, which is expanded to the decoder output
            shape. Because ``encoded_dim`` is fixed to (4, 4, 4) and the
            encoder/decoder each have three stride-2 stages, inputs are
            expected to be 32x32x32.
        num_slots: number of slots produced by slot attention.
        num_iterations: number of slot-attention iterations.
        with_cuboid: build and run the cuboid-fitting branch.
        with_afford: build and run the per-slot affordance classifier.
    """
    def __init__(self, resolution, num_slots, num_iterations, with_cuboid=True, with_afford=False):
        super().__init__()
        self.num_slots = num_slots
        self.num_iterations = num_iterations
        # spatial size after the three stride-2 encoder convs (32 -> 16 -> 8 -> 4)
        self.encoded_dim = [4,4,4]

        # simple convolutional voxel encoder
        self.encoder_cnn = nn.Sequential(
            nn.Conv3d(1,  32, kernel_size=3, padding=1),  nn.ReLU(),
            nn.Conv3d(32, 64, kernel_size=3, padding=1),  nn.ReLU(),
            nn.Conv3d(64, 64, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.Conv3d(64, 64, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.Conv3d(64, 64, kernel_size=3, stride=2, padding=1)
        )

        # positional embedding
        # full-resolution coordinate grid, used only by the cuboid branch
        self.build_grid = BuildGrid(resolution)
        self.encoder_pos = SoftPositionEmbed(self.encoded_dim)
        self.decoder_pos = SoftPositionEmbed(self.encoded_dim)

        # feedforward network
        self.layer_norm = nn.LayerNorm(64)
        self.feed_forward = nn.Sequential(
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 64)
        )

        # slot attention
        self.slot_attention = SlotAttention(
            num_slots = self.num_slots,
            dim = 64,
            iters = self.num_iterations,
            hidden_dim = 128)

        # simple convolutional voxel decoder (4 -> 8 -> 16 -> 32 per axis)
        self.decoder_cnn = nn.Sequential(
            nn.ConvTranspose3d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1), nn.ReLU(),
            nn.ConvTranspose3d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1), nn.ReLU(),
            nn.ConvTranspose3d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1), nn.ReLU()
        )
        # per-voxel head: channel 0 -> reconstruction, channel 1 -> alpha mask logit
        self.decoder_mlp = nn.Linear(64, 2)

        # affordance set prediction
        self.with_afford = with_afford
        if with_afford:
            # from slot embedding: 25 independent sigmoid scores per slot
            self.decoder_cls = nn.Sequential(
                nn.Linear(64, 64), nn.ReLU(),
                nn.Linear(64, 25), nn.Sigmoid()
            )


        self.with_cuboid = with_cuboid
        if not with_cuboid:
            # the buffers below are only needed by the cuboid branch of forward()
            return
        # slot cuboid parameter prediction branch
        self.cube_pred = CuboidPred()

        # zero tensor, passed to torch.heaviside as the value at exactly 0
        self.zero_value = nn.Parameter(torch.FloatTensor([0.0]), requires_grad=False)
        # neighbor sum conv3d filter: 3x3x3 kernel with the six face neighbours
        #   set to 1 and the centre to 0 (not used by forward() below)
        self.neighbor_filter = nn.Parameter(torch.FloatTensor([[[0,0,0],[0,1,0],[0,0,0]],
                                                                [[0,1,0],[1,0,1],[0,1,0]],
                                                                [[0,0,0],[0,1,0],[0,0,0]]]),
                                                                requires_grad=False)

        # when multiplied with point/voxel coordinate after rotation,
        #   it projects the point to three orthogonal planes at the origin
        #   (row k zeroes the coordinate normal to face k)
        self.mask_project = nn.Parameter(torch.Tensor([[0,1,1],[0,1,1],[1,0,1],[1,0,1],[1,1,0],[1,1,0]]).float(), requires_grad=False)
        # when multiplied with the per-slot scale,
        #   it provides three orthogonal planes (each with duplicate) with
        #   the translation vectors to six planes of the scaled cuboid
        # different from cuboid abstraction's original code for simplicity
        self.mask_plane = nn.Parameter(torch.Tensor([[1,0,0],[-1,0,0],[0,1,0],[0,-1,0],[0,0,1],[0,0,-1]]).float(), requires_grad=False)


    def forward(self, x):
        """Encode, group into slots, decode and optionally fit cuboids.

        Args:
            x: voxel grids of shape (batch_size, 32, 32, 32); a channel axis
                is added internally.

        Returns:
            dict with
              'recon_combined': (B, 32, 32, 32) mask-weighted sum of the slot
                  reconstructions (the cuboid branch thresholds its sigmoid
                  at 0.5),
              'recons': (B, N, 32, 32, 32, 1) per-slot reconstructions,
              'masks': (B, N, 32, 32, 32, 1) alpha masks, softmax over slots,
              'slots': (B, N, 64) slot embeddings,
              'one_hots': (B, N, 25) sigmoid scores, or None without
                  ``with_afford``.
            With ``with_cuboid`` the dict additionally holds
              'mask_surface': (B, N, 32*32*32, 1) detached 0/1 mask of the
                  above-threshold voxels whose largest contribution comes
                  from that slot,
              'center': (B, N, 3) weighted voxel centroid per slot,
              'scale': (B, N, 3) cuboid half extents in (0, 0.5),
              'rotate': (B, N, 3, 3) cuboid rotation matrices,
              'vox_rotated': (B, N, 32*32*32, 6, 3) voxel coordinates in each
                  cuboid's local frame, repeated once per face,
              'vox_proj': (B, N, 32*32*32, 6, 3) those coordinates projected
                  onto the six face planes and clamped to the faces.
        """
        batch_size = x.size(0)
        num_slots = self.num_slots

        # 3D CNN encoding
        # [batch_size, 1, 32, 32, 32]
        x = x.unsqueeze(1)
        # [batch_size, 64, 4, 4, 4]
        x = self.encoder_cnn(x)

        # soft positional embedding (returns channels last)
        # [batch_size, 4, 4, 4, 64]
        x = self.encoder_pos(x)

        # flatten the cells into 4*4*4 tokens; layer normalization on the last
        # dimension: feature size
        # [batch_size, 4*4*4, 64]
        x = self.layer_norm(x.view(batch_size, -1, 64))
        x = self.feed_forward(x)

        # slot attention
        # [batch_size, num_slots, 64]
        slots = self.slot_attention(x)

        # spatial broadcast decoding: tile every slot over the latent grid
        x = slots.view(batch_size*num_slots, 1,1,1, -1)
        # [batch_size*num_slots, 4, 4, 4, 64]
        x = x.expand(batch_size*num_slots, *self.encoded_dim, 64)

        # soft positional embedding expects channels first
        # [batch_size*num_slots, 4, 4, 4, 64]
        x = self.decoder_pos(x.permute(0,4,1,2,3))

        # 3D CNN decoding
        # [batch_size*num_slots, 64, 32, 32, 32]
        x = self.decoder_cnn(x.permute(0,4,1,2,3))
        # [batch_size*num_slots, 32, 32, 32, 2]
        x = self.decoder_mlp(x.permute(0,2,3,4,1))

        # undo combination of slot and batch dimension; split alpha masks.
        # [batch_size, num_slots, 32, 32, 32, 2]
        x = x.view(batch_size, num_slots, *x.shape[1:])
        # [batch_size, num_slots, 32, 32, 32, 1] for each
        recons, masks = torch.split(x, [1,1], dim=-1)

        if self.with_afford:
            # from slot embedding: [batch_size, num_slots, 25]
            one_hots = self.decoder_cls(slots)
        else:
            one_hots = None

        # normalize alpha masks over slots.
        masks = F.softmax(masks, dim=1)
        # per-slot contributions, also reused by the cuboid branch
        # [batch_size, num_slots, 32, 32, 32, 1]
        weighted_recons = recons * masks
        # [batch_size, 32, 32, 32]
        recon_combined = torch.sum(weighted_recons, dim=1)[:,:,:,:,0]

        if not self.with_cuboid:
            return {
                # reconstruction results and components
                'recon_combined': recon_combined,
                'recons': recons,
                'masks': masks,
                'slots': slots,
                # affordance
                'one_hots': one_hots
            }

        # slot cuboid parameter prediction branch (Conv1d heads want channels first)
        scale, rotate = self.cube_pred(slots.permute(0,2,1))
        # halve the sigmoid output: half extents lie in (0, 0.5), so a full
        # side (2 * scale) stays below 1, the width of the [-0.5, 0.5] grid
        scale = scale * 0.5
        # scale:  [batch_size, num_slots, 3]
        # rotate: [batch_size, num_slots, 3, 3]

        # which voxels pass the threshold for the combined shape:
        # 1 where sigmoid(recon_combined) > 0.5, else 0 (zero_value at exactly 0.5).
        # [batch_size, 32, 32, 32]
        hard_combo_threshold = .5
        vox_tmp = torch.heaviside(torch.sigmoid(recon_combined)-hard_combo_threshold,
                                  self.zero_value)
        # slot with the largest contribution to each voxel: [batch_size, 32, 32, 32]
        slot_contrib_idx = torch.argmax(weighted_recons, dim=1)[...,0]
        # one-hot over slots, slot axis second: [batch_size, num_slots, 32, 32, 32]
        mask_surface = F.one_hot(slot_contrib_idx, num_classes=num_slots).permute(0,4,1,2,3)
        # keep only above-threshold voxels (broadcast over the slot axis).
        # NOTE: despite the name this covers every occupied voxel, not only boundary ones.
        mask_surface = mask_surface * vox_tmp.unsqueeze(1)
        # [batch_size, num_slots, num_vox, 1]
        mask_surface = mask_surface.reshape(batch_size, num_slots, -1, 1).detach()

        # {-.5~.5, -.5~.5, -.5~.5} meshgrid with a broadcastable slot axis:
        # [batch_size, 1, num_vox=32*32*32, 3]
        vox_grid = self.build_grid(recon_combined.shape).view(batch_size, 1, -1, 3)
        # per-voxel weights, restricted to each slot's voxels; detached so the
        # centers do not backpropagate into the decoder
        # [batch_size, num_slots, num_vox, 1]
        weight = torch.sigmoid(weighted_recons.reshape(batch_size, num_slots, -1, 1)).detach()
        weight = weight * mask_surface
        # normalize per slot; eps keeps slots without voxels finite (their center is the origin)
        eps = 1e-5
        # [batch_size, num_slots, 1, 1]
        weight_total = torch.sum(weight, dim=2, keepdim=True)
        # [batch_size, num_slots, num_vox, 3]
        weighted_sum = vox_grid * weight / (weight_total + eps)

        # center of each cuboid: [batch_size, num_slots, 3]
        center = torch.sum(weighted_sum, dim=2)
        # translate the voxels to the center: [batch_size, num_slots, num_vox, 3]
        vox_centered = vox_grid - center.unsqueeze(2)

        # rotate the voxels into each cuboid's local frame
        # inverse rotation matrix (transpose of the orthonormal R)
        # [batch_size, num_slots, 3, 3]
        rotate_inv = rotate.permute(0,1,3,2)
        # [batch_size, num_slots, 3, num_vox]
        vox_centered_transposed = vox_centered.permute(0,1,3,2)
        # vox_rotated = matmul(R^{-1}, pos_vox^T)
        # [batch_size, num_slots, 3, num_vox]
        vox_rotated = torch.einsum('abcd,abde->abce', rotate_inv, vox_centered_transposed)
        # [batch_size, num_slots, num_vox, 3]
        vox_rotated = vox_rotated.permute(0,1,3,2)
        # one copy per face (materialized because it is returned)
        # [batch_size, num_slots, num_vox, num_faces=6, 3]
        vox_rotated = vox_rotated.unsqueeze(3).repeat(1,1,1,6,1)

        # half extents, broadcast over voxels and faces: [batch_size, num_slots, 1, 1, 3]
        scale_b = scale[:, :, None, None, :]
        # offsets from the origin to the six face planes: [batch_size, num_slots, 1, 6, 3]
        translation_to_cuboid = scale_b * self.mask_plane.view(1,1,1,6,3)
        # calculate the projected voxels (projected to the scaled cuboid's six faces):
        # zero the coordinate normal to each face, then shift onto that face
        # [batch_size, num_slots, num_vox, num_faces, 3]
        vox_proj = vox_rotated * self.mask_project.view(1,1,1,6,3) + translation_to_cuboid
        # cap those points outside the cuboid but on the six planes.
        vox_proj = torch.maximum(torch.minimum(vox_proj, scale_b), -scale_b)

        return {
            # surface mask
            'mask_surface': mask_surface,
            # reconstruction results and components
            'recon_combined': recon_combined,
            'recons': recons,
            'masks': masks,
            'slots': slots,
            # cuboid parameters
            'center': center,
            'scale': scale,
            'rotate': rotate,
            'vox_rotated': vox_rotated,
            'vox_proj': vox_proj,
            # affordance
            'one_hots': one_hots
        }



