# -*- coding: utf-8 -*-
# @Author: Thibault GROUEIX
# @Date:   2019-08-07 20:54:24
# @Last Modified by:   Haozhe Xie
# @Last Modified time: 2019-12-18 15:06:25
# @Email:  cshzxie@gmail.com

import torch
import chamfer
# Added: SEINT distance, used by the SEINT-based loss modules in this file.
from SEINT.SEINT_torch import SEINT_batch_vmap
class ChamferFunction(torch.autograd.Function):
    """Autograd bridge to the imported ``chamfer`` module.

    The forward pass delegates to ``chamfer.forward`` and the backward pass
    to ``chamfer.backward``; this class only wires the two into autograd.
    """

    @staticmethod
    def forward(ctx, xyz1, xyz2):
        """Run ``chamfer.forward`` and return its two distance outputs.

        NOTE(review): the distances are presumably squared nearest-neighbour
        distances (the L1 loss in this file takes their square root) --
        confirm against the ``chamfer`` implementation.
        """
        first_dist, second_dist, first_idx, second_idx = chamfer.forward(xyz1, xyz2)
        # The index outputs are not returned; they are only kept for backward.
        ctx.save_for_backward(xyz1, xyz2, first_idx, second_idx)
        return first_dist, second_dist

    @staticmethod
    def backward(ctx, grad_dist1, grad_dist2):
        """Map the gradients of both distance outputs back onto the inputs."""
        # saved_tensors holds (xyz1, xyz2, idx1, idx2), in the order saved by forward.
        input_grads = chamfer.backward(*ctx.saved_tensors, grad_dist1, grad_dist2)
        grad_xyz1, grad_xyz2 = input_grads
        return grad_xyz1, grad_xyz2


class ChamferDistanceL2(torch.nn.Module):
    """Chamfer Distance L2.

    The loss is the sum of the means of the two distance tensors returned by
    ``ChamferFunction``.

    Args:
        ignore_zeros: If true, points whose coordinates are all zero are
            removed from both inputs before the distance is computed. The
            filter is only applied when the batch size is 1; for any other
            batch size the flag has no effect.
    """

    def __init__(self, ignore_zeros=False):
        super().__init__()
        self.ignore_zeros = ignore_zeros

    def forward(self, xyz1, xyz2):
        """Return ``mean(dist1) + mean(dist2)`` for the two inputs.

        The inputs are indexed along ``dim=2`` as the coordinate axis, so
        they are expected to be 3-D tensors with the batch on ``dim=0``.
        """
        batch_size = xyz1.size(0)
        if batch_size == 1 and self.ignore_zeros:
            # Keep a point when ANY coordinate is non-zero. Testing the
            # coordinate sum instead would also discard genuine points whose
            # coordinates cancel out, e.g. (1, -1, 0).
            non_zeros1 = xyz1.ne(0).any(dim=2)
            non_zeros2 = xyz2.ne(0).any(dim=2)
            # Boolean indexing flattens the batch axis; restore it.
            xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
            xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)

        dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
        return torch.mean(dist1) + torch.mean(dist2)

class ChamferDistanceL2_split(torch.nn.Module):
    """Chamfer Distance L2, with the two directional terms returned separately.

    Unlike ``ChamferDistanceL2`` the two means are not summed; they are
    returned as a pair.

    Args:
        ignore_zeros: If true, points whose coordinates are all zero are
            removed from both inputs before the distance is computed. The
            filter is only applied when the batch size is 1; for any other
            batch size the flag has no effect.
    """

    def __init__(self, ignore_zeros=False):
        super().__init__()
        self.ignore_zeros = ignore_zeros

    def forward(self, xyz1, xyz2):
        """Return the tuple ``(mean(dist1), mean(dist2))``.

        The inputs are indexed along ``dim=2`` as the coordinate axis, so
        they are expected to be 3-D tensors with the batch on ``dim=0``.
        """
        batch_size = xyz1.size(0)
        if batch_size == 1 and self.ignore_zeros:
            # Keep a point when ANY coordinate is non-zero. Testing the
            # coordinate sum instead would also discard genuine points whose
            # coordinates cancel out, e.g. (1, -1, 0).
            non_zeros1 = xyz1.ne(0).any(dim=2)
            non_zeros2 = xyz2.ne(0).any(dim=2)
            # Boolean indexing flattens the batch axis; restore it.
            xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
            xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)

        dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
        return torch.mean(dist1), torch.mean(dist2)

class ChamferDistanceL1(torch.nn.Module):
    """Chamfer Distance L1.

    The loss is the average of the means of the square roots of the two
    distance tensors returned by ``ChamferFunction``.

    Args:
        ignore_zeros: If true, points whose coordinates are all zero are
            removed from both inputs before the distance is computed. The
            filter is only applied when the batch size is 1; for any other
            batch size the flag has no effect.
    """

    def __init__(self, ignore_zeros=False):
        super().__init__()
        self.ignore_zeros = ignore_zeros

    def forward(self, xyz1, xyz2):
        """Return ``(mean(sqrt(dist1)) + mean(sqrt(dist2))) / 2``.

        The inputs are indexed along ``dim=2`` as the coordinate axis, so
        they are expected to be 3-D tensors with the batch on ``dim=0``.
        """
        batch_size = xyz1.size(0)
        if batch_size == 1 and self.ignore_zeros:
            # Keep a point when ANY coordinate is non-zero. Testing the
            # coordinate sum instead would also discard genuine points whose
            # coordinates cancel out, e.g. (1, -1, 0).
            non_zeros1 = xyz1.ne(0).any(dim=2)
            non_zeros2 = xyz2.ne(0).any(dim=2)
            # Boolean indexing flattens the batch axis; restore it.
            xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
            xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)

        dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
        # NOTE(review): sqrt has an unbounded derivative at 0, so exactly
        # coincident points may yield non-finite gradients -- confirm whether
        # an epsilon is needed for the data this is trained on.
        dist1 = torch.sqrt(dist1)
        dist2 = torch.sqrt(dist2)
        return (torch.mean(dist1) + torch.mean(dist2)) / 2


# Added: Chamfer Distance L2 losses fused with the SEINT distance.

class CD2_fused_ISEINT(torch.nn.Module):
    """Chamfer Distance L2 + ISEINT.

    The loss is a weighted sum of the L2 Chamfer term
    (``mean(dist1) + mean(dist2)``) and the value returned by a
    ``SEINT_DISTANCE`` module.

    Args:
        ignore_zeros: If true, points whose coordinates are all zero are
            removed from both inputs first. Only applied when the batch size
            is 1; otherwise the flag has no effect.
        iseint_rep: Forwarded to ``SEINT_DISTANCE`` as ``rep``.
        rd_rad: Forwarded to ``SEINT_DISTANCE`` as ``rd_rad``.
        maxed: Forwarded to ``SEINT_DISTANCE`` as ``maxed``.
        chamfer_weight: Multiplier of the Chamfer term (default 0.9).
        iseint_weight: Multiplier of the ISEINT term (default 0.1).
    """

    def __init__(self, ignore_zeros=False, iseint_rep=100, rd_rad=3, maxed=False,
                 chamfer_weight=0.9, iseint_weight=0.1):
        super().__init__()
        self.ignore_zeros = ignore_zeros
        self.rd_rad = rd_rad
        self.maxed = maxed
        self.chamfer_weight = chamfer_weight
        self.iseint_weight = iseint_weight
        self.iseint = SEINT_DISTANCE(rep=iseint_rep, rd_rad=rd_rad, maxed=self.maxed)

    def forward(self, xyz1, xyz2):
        """Return ``chamfer_weight * chamfer + iseint_weight * iseint``.

        Both terms are computed on the same (optionally zero-filtered)
        inputs, which are indexed along ``dim=2`` as the coordinate axis.
        """
        batch_size = xyz1.size(0)
        if batch_size == 1 and self.ignore_zeros:
            # Keep a point when ANY coordinate is non-zero. Testing the
            # coordinate sum instead would also discard genuine points whose
            # coordinates cancel out, e.g. (1, -1, 0).
            non_zeros1 = xyz1.ne(0).any(dim=2)
            non_zeros2 = xyz2.ne(0).any(dim=2)
            # Boolean indexing flattens the batch axis; restore it.
            xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
            xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)

        dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
        chamfer_loss = torch.mean(dist1) + torch.mean(dist2)

        iseint_loss = self.iseint(xyz1, xyz2)
        return self.chamfer_weight * chamfer_loss + self.iseint_weight * iseint_loss

        

class CD2_fused_SEINT(torch.nn.Module):
    """Chamfer Distance L2 + SEINT.

    The loss is a weighted sum of the L2 Chamfer term
    (``mean(dist1) + mean(dist2)``) and the value returned by a
    ``SEINT_DISTANCE`` module.

    Args:
        ignore_zeros: If true, points whose coordinates are all zero are
            removed from both inputs first. Only applied when the batch size
            is 1; otherwise the flag has no effect.
        seint_rep: Forwarded to ``SEINT_DISTANCE`` as ``rep``.
        rd_rad: Forwarded to ``SEINT_DISTANCE`` as ``rd_rad``.
        maxed: Forwarded to ``SEINT_DISTANCE`` as ``maxed``.
        chamfer_weight: Multiplier of the Chamfer term (default 0.9).
        seint_weight: Multiplier of the SEINT term (default 0.1).
    """

    def __init__(self, ignore_zeros=False, seint_rep=100, rd_rad=3, maxed=True,
                 chamfer_weight=0.9, seint_weight=0.1):
        super().__init__()
        self.ignore_zeros = ignore_zeros
        self.rd_rad = rd_rad
        self.maxed = maxed
        self.chamfer_weight = chamfer_weight
        self.seint_weight = seint_weight
        self.seint = SEINT_DISTANCE(rep=seint_rep, rd_rad=rd_rad, maxed=self.maxed)

    def forward(self, xyz1, xyz2):
        """Return ``chamfer_weight * chamfer + seint_weight * seint``.

        Both terms are computed on the same (optionally zero-filtered)
        inputs, which are indexed along ``dim=2`` as the coordinate axis.
        """
        batch_size = xyz1.size(0)
        if batch_size == 1 and self.ignore_zeros:
            # Keep a point when ANY coordinate is non-zero. Testing the
            # coordinate sum instead would also discard genuine points whose
            # coordinates cancel out, e.g. (1, -1, 0).
            non_zeros1 = xyz1.ne(0).any(dim=2)
            non_zeros2 = xyz2.ne(0).any(dim=2)
            # Boolean indexing flattens the batch axis; restore it.
            xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
            xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)

        dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
        chamfer_loss = torch.mean(dist1) + torch.mean(dist2)

        seint_loss = self.seint(xyz1, xyz2)
        return self.chamfer_weight * chamfer_loss + self.seint_weight * seint_loss


class SEINT_DISTANCE(torch.nn.Module):
    """Loss module wrapping ``SEINT_batch_vmap``.

    The loss is the mean of ``SEINT_batch_vmap``'s output divided by a fixed
    normalisation factor.

    Args:
        rep: Forwarded to ``SEINT_batch_vmap`` as ``rep``.
        rd_rad: Forwarded to ``SEINT_batch_vmap`` as ``rd_rad``.
        maxed: Forwarded to ``SEINT_batch_vmap`` as ``maxed``.
        norm_factor: Divisor applied to the mean. Defaults to 500000, the
            value that was previously hard-coded.
    """

    def __init__(self, rep=100, rd_rad=None, maxed=True, norm_factor=500000):
        super().__init__()
        self.rep = rep
        self.rd_rad = rd_rad
        self.maxed = maxed
        self.norm_factor = norm_factor

    def forward(self, xyz1, xyz2):
        """Compute the normalised SEINT loss.

        xyz1: [B, N, 3]
        xyz2: [B, N, 3]
        return: ``torch.mean`` of the SEINT values divided by
            ``norm_factor`` (the result of ``torch.mean``, not a Python
            float).
        """
        seint_values = SEINT_batch_vmap(
            xyz1, xyz2, rep=self.rep, rd_rad=self.rd_rad, maxed=self.maxed
        )
        return torch.mean(seint_values) / self.norm_factor


