import torch
from .transformer import CloudGPT, MatchingGPT
from .pointnet import PointNetEncoder

import sys
sys.path.append('..')
from utils.distance import chamfer

def init_weights(m):
    """Initialise ``torch.nn.Linear`` layers in place; every other module type is left untouched.

    Meant to be used as ``module.apply(init_weights)``. Linear weights receive
    Xavier-uniform initialisation and biases, when the layer has one, are set
    to the constant 0.01.

    Args:
        m: Any ``torch.nn.Module``; only ``torch.nn.Linear`` instances are modified.
    """
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        # ``Linear(..., bias=False)`` stores ``bias = None``; dereferencing it
        # unconditionally used to raise AttributeError.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.01)

class SimplePCEmbedder(torch.nn.Module):
    """Point-cloud embedder: a per-point MLP followed by mean pooling over the points.

    Args:
        input_dim: Number of coordinates per point (last axis of the input).
        dim: Size of the output embedding.
        mlp_dims: Hidden widths of the per-point MLP. Any iterable of ints
            (list, tuple, ...) is accepted.
        mlp_activation: Activation class, instantiated between consecutive
            linear layers; no activation follows the final layer.
    """

    def __init__(self, input_dim=2, dim=2, mlp_dims=(4, 8, 16), mlp_activation=torch.nn.ReLU):
        super().__init__()
        self.dim = dim

        # Copy into a fresh list: lets callers pass tuples and avoids aliasing
        # (or sharing a mutable default) with the caller's sequence.
        hidden = list(mlp_dims)

        layers = []
        for dim_in, dim_out in zip([input_dim] + hidden, hidden + [dim]):
            layers.append(torch.nn.Linear(dim_in, dim_out))
            layers.append(mlp_activation())

        # Drop the trailing activation so the embedding is a plain linear output.
        layers = layers[:-1]

        self.fc = torch.nn.Sequential(*layers)
        self.fc.apply(init_weights)

    def forward(self, pc):
        '''
            pc : [torch.Tensor] B x M x input_dim
            returns : [torch.Tensor] B x dim -- per-point features averaged over the M points
        '''
        x = self.fc(pc) # B x M x dim
        x = x.mean(1) # B x dim
        return x

class TransformerPCEmbedder(torch.nn.Module):
    """Point-cloud embedder that mean-pools the per-point outputs of a ``CloudGPT``.

    Every constructor argument is forwarded positionally, in the same order,
    to ``CloudGPT``.
    """

    def __init__(self, input_dim=2, latent_dim=16, dim_model=64, dim_keys=64, dim_hidden=64, nb_heads=4, nb_blocks=4, dropout=0.):
        super().__init__()
        gpt_args = (input_dim, latent_dim, dim_model, dim_keys, dim_hidden, nb_heads, nb_blocks, dropout)
        self.transformer = CloudGPT(*gpt_args)

    def forward(self, pc):
        '''
            pc : [torch.Tensor] B x M x input_dim
            returns : [torch.Tensor] per-point transformer outputs averaged over axis 1
                      (presumably B x latent_dim -- depends on CloudGPT; TODO confirm)
        '''
        return self.transformer(pc).mean(dim=1)

class DistanceNetwork(torch.nn.Module):
    """Learned, symmetric distance between two point clouds.

    Each cloud is embedded with one shared ``SimplePCEmbedder``. The two
    embeddings are concatenated and scored by an MLP, and the score is averaged
    over both concatenation orders, so ``d(a, b) == d(b, a)``.

    Args:
        input_dim: Number of coordinates per point.
        latent_dim: Size of each cloud embedding; the scoring MLP input is ``2 * latent_dim``.
        mlp_dims: Hidden widths of the scoring MLP (any iterable of ints).
        mlp_activation: Activation class used in both the embedder and the scoring MLP.
        embed_mlp_dims: Keyword-only. Hidden widths of the per-point embedder
            MLP. Defaults to ``(4, 8, 16)``, the value previously hard-coded here.
        *args, **kwargs: Accepted and ignored.
    """
    def __init__(self, input_dim=2, latent_dim=16, mlp_dims=(32, 64, 32, 16), mlp_activation=torch.nn.ReLU, *args, embed_mlp_dims=(4, 8, 16), **kwargs):
        super().__init__()
        mlp_dims = list(mlp_dims)

        # list() because SimplePCEmbedder builds its layer widths by list concatenation.
        self.model = SimplePCEmbedder(input_dim, latent_dim, mlp_dims=list(embed_mlp_dims), mlp_activation=mlp_activation)

        layers = []
        for dim_in, dim_out in zip([2*latent_dim] + mlp_dims, mlp_dims + [1]):
            layers.append(torch.nn.Linear(dim_in, dim_out))
            layers.append(mlp_activation())

        # No activation after the last Linear: the score is an unconstrained scalar.
        layers = layers[:-1]

        self.mlp = torch.nn.Sequential(*layers)
        self.mlp.apply(init_weights)

    def forward(self, pc_source, pc_target):
        '''
            pc_source : [torch.Tensor] B x M x input_dim
            pc_target : [torch.Tensor] B x M x input_dim
            returns   : [torch.Tensor] B x 1 symmetric distance score
        '''
        # Compute embeddings with the shared embedder
        emb_source = self.model(pc_source) # B x latent_dim
        emb_target = self.model(pc_target) # B x latent_dim

        # Average both argument orders so the score is symmetric
        dists = 0.5*self.mlp(torch.cat((emb_source, emb_target), 1)) # B x 1
        dists += 0.5*self.mlp(torch.cat((emb_target, emb_source), 1))

        return dists

class DistanceNetworkTransformer(torch.nn.Module):
    """Learned, symmetric point-cloud distance using a transformer embedder.

    Each cloud is embedded with one shared ``TransformerPCEmbedder``. The
    embeddings are concatenated and scored by an MLP, and the score is averaged
    over both concatenation orders, so ``d(a, b) == d(b, a)``.

    Args:
        input_dim, latent_dim, dim_model, dim_keys, dim_hidden, nb_heads,
            nb_blocks, dropout: Forwarded to ``TransformerPCEmbedder``.
        mlp_dims: Hidden widths of the scoring MLP (any iterable of ints).
        mlp_activation: Activation class used between scoring-MLP layers.
        final_activation: Keyword-only. If True, also apply ``mlp_activation``
            to the scalar output.
        *args, **kwargs: Accepted and ignored.

    Notes on ``final_activation``:
        Older versions always applied the activation to the output. Unlike
        ``DistanceNetwork`` / ``PointNet``, they did not strip it. With the
        default ReLU, that clamps the output to >= 0 and gives zero gradient
        whenever the pre-activation is negative. Pass ``final_activation=True``
        to reproduce the old behaviour. Checkpoints load either way, because the
        activation has no parameters and sits last in the ``Sequential``.
    """
    def __init__(self, input_dim=2, latent_dim=16, dim_model=64, dim_keys=64, dim_hidden=64, nb_heads=4, nb_blocks=4, dropout=0., mlp_dims=(32, 64, 32, 16), mlp_activation=torch.nn.ReLU, *args, final_activation=False, **kwargs):
        super().__init__()
        mlp_dims = list(mlp_dims)

        self.model = TransformerPCEmbedder(input_dim, latent_dim, dim_model, dim_keys, dim_hidden, nb_heads, nb_blocks, dropout)

        layers = []
        for dim_in, dim_out in zip([2*latent_dim] + mlp_dims, mlp_dims + [1]):
            layers.append(torch.nn.Linear(dim_in, dim_out))
            layers.append(mlp_activation())

        # Match the sibling distance heads: the score is a raw scalar unless
        # the caller explicitly asks for the legacy output activation.
        if not final_activation:
            layers = layers[:-1]

        self.mlp = torch.nn.Sequential(*layers)
        self.mlp.apply(init_weights)

    def forward(self, pc_source, pc_target):
        '''
            pc_source : [torch.Tensor] B x M x input_dim
            pc_target : [torch.Tensor] B x M x input_dim
            returns   : [torch.Tensor] B x 1 symmetric distance score
        '''
        # Compute embeddings with the shared embedder
        emb_source = self.model(pc_source)
        emb_target = self.model(pc_target)

        # Average both argument orders so the score is symmetric
        dists = 0.5*self.mlp(torch.cat((emb_source, emb_target), 1)) # B x 1
        dists += 0.5*self.mlp(torch.cat((emb_target, emb_source), 1))

        return dists

class MatchingTransformer(torch.nn.Module):
    """Point-cloud distance derived from the attention map of a ``MatchingGPT``.

    Both clouds are fed to the model as one sequence. Each point is then
    hard-matched to the point of the other cloud it attends to most (argmax),
    and the distance is the mean of the two directional sums of matched
    Euclidean distances.
    """
    def __init__(self, input_dim=2, latent_dim=16, dim_model=64, dim_keys=64, dim_hidden=64, nb_heads=4, nb_blocks=4, dropout=0., *args, **kwargs):
        super().__init__()

        gpt_args = (input_dim, latent_dim, dim_model, dim_keys, dim_hidden, nb_heads, nb_blocks, dropout)
        self.model = MatchingGPT(*gpt_args)

    def forward(self, pc_source, pc_target):
        '''
            pc_source : [torch.Tensor] B x M x input_dim
            pc_target : [torch.Tensor] B x M x input_dim
            returns   : (dists [B], source->target attention block, target->source attention block)
        '''
        n_src = pc_source.shape[1]
        n_tgt = pc_target.shape[1]

        # One sequence holding both clouds; segment id 0 tags source points, 1 tags target points.
        seq_a_b = torch.cat([pc_source, pc_target], dim=1)
        dev = seq_a_b.device
        segment_ids = torch.cat([torch.zeros(n_src, device=dev), torch.ones(n_tgt, device=dev)], dim=0)

        # squeeze(1) removes a singleton axis 1; the original comment describes it as the
        # head axis (B x H x 2M x 2M), so this presumably assumes H == 1 -- TODO confirm.
        attn = self.model(seq_a_b, segment_ids).squeeze(1)
        mask_s_t = attn[:, :n_src, n_src:]  # rows: source points, cols: target points
        mask_t_s = attn[:, n_src:, :n_src]  # rows: target points, cols: source points

        rows_t = torch.arange(pc_target.shape[0], device=pc_target.device)[:, None]
        rows_s = torch.arange(pc_source.shape[0], device=pc_source.device)[:, None]

        # Hard matching via argmax over the other cloud's points (non-differentiable index choice)
        match_s_t = mask_s_t.argmax(dim=-1)
        match_t_s = mask_t_s.argmax(dim=-1)

        dists = torch.norm(pc_source - pc_target[rows_t, match_s_t], dim=-1).sum(-1)  # B
        dists += torch.norm(pc_target - pc_source[rows_s, match_t_s], dim=-1).sum(-1)  # B
        dists /= 2

        return dists, mask_s_t, mask_t_s
    

class PointNet(torch.nn.Module):
    """Learned, symmetric point-cloud distance built on a ``PointNetEncoder``.

    The encoder output (created with ``global_feat=True``) is projected to
    ``latent_dim`` by ``self.fc``. The two projected embeddings are
    concatenated and scored by an MLP, averaged over both concatenation orders.

    Args:
        input_dim: Number of coordinates per point (encoder ``channel``).
        latent_dim: Size of each projected cloud embedding.
        mlp_dims: Hidden widths of the scoring MLP (any iterable of ints).
        mlp_activation: Activation class used between scoring-MLP layers.
        *args, **kwargs: Accepted and ignored.
    """
    def __init__(self, input_dim=2, latent_dim=16, mlp_dims=[32, 64, 32, 16], mlp_activation=torch.nn.ReLU, *args, **kwargs):
        super().__init__()
        hidden = list(mlp_dims)

        self.model = PointNetEncoder(global_feat=True, feature_transform=False, channel=input_dim)
        # The layer order is kept exactly as-is so state_dict keys (fc.0 ... fc.8) stay stable;
        # the Dropout(p=0.) is a no-op but occupies index 4.
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(1024, 512),
            torch.nn.BatchNorm1d(512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 256),
            torch.nn.Dropout(p=0.),
            torch.nn.BatchNorm1d(256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, latent_dim),
        )

        widths = [2*latent_dim] + hidden + [1]
        blocks = []
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            blocks += [torch.nn.Linear(width_in, width_out), mlp_activation()]
        # The score is a raw scalar: discard the activation after the final Linear.
        blocks.pop()

        self.mlp = torch.nn.Sequential(*blocks)
        self.mlp.apply(init_weights)

    def _embed(self, pc):
        # Encoder is fed B x input_dim x M (permuted from B x M x input_dim); only its first output is used.
        feat, _, _ = self.model(pc.permute(0, 2, 1))
        return self.fc(feat)

    def forward(self, pc_source, pc_target):
        '''
            pc_source : [torch.Tensor] B x M x input_dim
            pc_target : [torch.Tensor] B x M x input_dim
            returns   : [torch.Tensor] B x 1 symmetric distance score
        '''
        # Source is embedded before target, as before (matters for BatchNorm running stats).
        emb_source = self._embed(pc_source)
        emb_target = self._embed(pc_target)

        dists = 0.5*self.mlp(torch.cat((emb_source, emb_target), 1)) # B x 1
        dists += 0.5*self.mlp(torch.cat((emb_target, emb_source), 1))
        return dists
    

class ChamferNet(torch.nn.Module):
    """Non-learned baseline: wraps ``chamfer`` and returns its matchings as one-hot masks.

    ``input_dim`` and any extra positional/keyword arguments are accepted and ignored.
    """
    def __init__(self, input_dim=2, *args, **kwargs):
        super().__init__()
        self.model = chamfer

    def forward(self, pc_source, pc_target):
        '''
            pc_source : [torch.Tensor] B x M x input_dim
            pc_target : [torch.Tensor] B x M x input_dim
            returns   : (dists, one_hot(left_idx), one_hot(right_idx)) as produced by chamfer
        '''
        to_one_hot = torch.nn.functional.one_hot
        dists, _, left_idx, _, right_idx = self.model(pc_source, pc_target, return_matching=True)

        # NOTE(review): num_classes is the *source* size for left_idx and the *target* size for
        # right_idx. If left_idx holds, per source point, an index into pc_target, these look
        # swapped and only agree when both clouds have the same number of points -- confirm
        # against utils.distance.chamfer.
        mask_s_t = to_one_hot(left_idx, num_classes=pc_source.shape[1])
        mask_t_s = to_one_hot(right_idx, num_classes=pc_target.shape[1])

        return dists, mask_s_t, mask_t_s