'''
@Author: Wenhao Ding
@Email: wenhaod@andrew.cmu.edu
@Date: 2019-11-18 22:13:22
LastEditTime: 2021-05-05 23:01:55
@Description:
'''

import torch
import torch.nn as nn


class DiceLoss(nn.Module):
    ''' Soft overlap loss on sigmoid probabilities.

    Computes ``1 - sum(min(p, t)) / sum(max(p, t))`` with ``p = sigmoid(predict)``.
    That ratio is a soft IoU (Jaccard) score, so minimising this loss maximises IoU.
    Entries whose target equals ``ignore_target`` are excluded from both sums.
    '''
    def __init__(self, ignore_target=-1):
        super().__init__()
        self.ignore_target = ignore_target

    def forward(self, predict, target):
        """
        :param predict: (N), logit
        :param target: (N), {0, 1}, or ``ignore_target`` for entries to skip
        :return: scalar loss in [0, 1] when all targets are in {0, 1}
        """
        predict = torch.sigmoid(predict)
        # Without this mask an ignore label (e.g. -1) would make min(p, -1) = -1
        # leak into the intersection and corrupt the loss.
        mask = (target != self.ignore_target).float()
        intersection = (torch.min(predict, target) * mask).sum()
        union = (torch.max(predict, target) * mask).sum()
        # Clamp keeps the ratio finite when the (masked) union is tiny or empty.
        return 1.0 - intersection / torch.clamp(union, min=1.0)


class CrosEntropyLoss(nn.Module):
    """ Negative log-likelihood loss over ``number_class`` classes.

    No softmax is applied here: ``nn.functional.nll_loss`` expects ``predict``
    to already contain log-probabilities (for example the output of
    ``log_softmax``). Pass raw logits through ``log_softmax`` first.
    """
    def __init__(self, number_class=3):
        super().__init__()
        self.number_class = number_class

    def forward(self, predict, target):
        """
        :param predict: any shape viewable as (N, number_class); log-probabilities
        :param target: any shape viewable as (N), values in {0, ..., number_class-1}
        :return: scalar mean NLL
        """
        log_probs = predict.view(-1, self.number_class)
        labels = target.view(-1)
        return nn.functional.nll_loss(log_probs, labels)


class MSELoss(nn.Module):
    """ Module wrapper around ``nn.functional.mse_loss`` with mean reduction. """
    def __init__(self):
        super().__init__()

    def forward(self, predict, target):
        """ Return the mean squared error between ``predict`` and ``target``. """
        loss = nn.functional.mse_loss(predict, target, reduction='mean')
        return loss


class MeanMSELoss(nn.Module):
    """ MSE between the row-means of the first prediction and the first target.

    Only ``predict[0]`` and ``target[0]`` are used; each is averaged over its
    first dimension before ``mse_loss`` is applied.
    """
    def __init__(self):
        super().__init__()

    def forward(self, predict, target):
        """
        :param predict: indexable along dim 0; ``predict[0]`` is averaged over dim 0
            (presumably a (B, N, D) point-set batch — TODO confirm with callers)
        :param target: same layout as ``predict``
        :return: scalar MSE between the two means
        """
        mean_1 = predict[0].mean(dim=0)
        mean_2 = target[0].mean(dim=0)
        return nn.functional.mse_loss(mean_1, mean_2)


# NOTE(review): the two Chamfer losses below previously had no ``class`` line,
# so their methods were parsed into MeanMSELoss and overrode its __init__/forward.
# Their class names are reconstructed; align them with any callers.
class NaiveChamferDistance(nn.Module):
    """ Straightforward Chamfer-style distance built from all pairwise differences.

    The full (N, M, D) difference tensor is materialised, so memory grows with
    N * M * D. Not recommended for large point sets.
    """
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    @staticmethod
    def chamfer_distance_without_batch(p1, p2):
        """ One-directional Chamfer distance between two single point sets.

        Adapted from https://gist.github.com/WangZixuan/4c4cdf49ce9989175e94524afc946726
        Args:
            p1: [1, N, D]
            p2: [1, M, D]
        Return:
            dist: mean over points of p1 of the Euclidean distance to the
                nearest point of p2 (p2 -> p1 is not included)
        """
        assert p1.size(0) == 1 and p2.size(0) == 1
        assert p1.size(2) == p2.size(2)
        # [N, 1, D] - [1, M, D] -> [N, M, D]; same values as the explicit repeat.
        diff = p1[0][:, None, :] - p2[0][None, :, :]
        dist = torch.norm(diff, 2, dim=2)          # [N, M]
        return torch.min(dist, dim=1)[0].mean()

    @staticmethod
    def chamfer_distance_with_batch(p1, p2, debug=None):
        """ Batched one-directional Chamfer distance.

        Adapted from https://gist.github.com/WangZixuan/4c4cdf49ce9989175e94524afc946726
        Args:
            p1: [B, N, D]
            p2: [B, M, D]
            debug: unused; kept for call compatibility
        Return:
            dist: mean over all batches and all points of p1 of the Euclidean
                distance to the nearest point of the matching p2 set
        """
        assert p1.size(0) == p2.size(0) and p1.size(2) == p2.size(2)
        # [B, N, 1, D] - [B, 1, M, D] -> [B, N, M, D]
        diff = p1[:, :, None, :] - p2[:, None, :, :]
        dist = torch.norm(diff, 2, dim=3)          # [B, N, M]
        return torch.min(dist, dim=2)[0].mean()

    def forward(self, predict, target):
        """ Debug forward: Chamfer distance of the first sample plus 2-D centroids.

        :param predict: indexable along dim 0; ``predict[0]`` must be 2-D (N, D), D >= 2
        :param target: same layout as ``predict``
        :return: (distance, predict_mean, groundtruth_mean), where the means are
            taken over the first two columns of ``predict[0]`` / ``target[0]``
        """
        # DEBUG: also expose the distance between mean points
        predict_mean = torch.mean(predict[0][:, 0:2], dim=0)
        groundtruth_mean = torch.mean(target[0][:, 0:2], dim=0)
        # retain_grad() raises on tensors that do not require grad (e.g. under
        # torch.no_grad()), so only request it when a gradient will exist.
        if predict_mean.requires_grad:
            predict_mean.retain_grad()

        # the input of chamfer distance should be [1, N, D] and [1, M, D]
        all_dist = self.chamfer_distance_without_batch(predict[0][None], target[0][None])
        return all_dist, predict_mean, groundtruth_mean


class ChamferDistance(nn.Module):
    """ Chamfer distance using the extension from
        https://github.com/ThibaultGROUEIX/ChamferDistancePytorch
    """
    def __init__(self):
        super().__init__()
        # Imported lazily: this file never imported ``chamfer3D``, so the bare
        # reference raised NameError. This also keeps the module importable
        # when the compiled extension is not installed.
        import chamfer3D.dist_chamfer_3D
        self.chamLoss = chamfer3D.dist_chamfer_3D.chamfer_3DDist()

    def forward(self, predict, target):
        """ Return mean(dist1) + mean(dist2) as produced by the extension. """
        dist1, dist2, _, _ = self.chamLoss(predict, target)
        return dist1.mean() + dist2.mean()