import torch
import torch.nn as nn
from torch.nn.functional import cross_entropy


class ChamferDistance(nn.Module):
    """
    Chamfer distance between point sets x and y.

    For every point in x the squared Euclidean distance to its nearest neighbour in y is taken (and vice versa);
    these nearest-neighbour distances are summed over all points and over the whole batch.
    """

    def __init__(self):
        super(ChamferDistance, self).__init__()

    @staticmethod
    def batch_pairwise_dist(x, y):
        """
        Squared Euclidean distance between every pair of points of x and y.

        :param x: tensor with size of (B, N, D)
        :param y: tensor with size of (B, M, D)
        :return: tensor with size of (B, N, M); entry [b, i, j] is ||x[b, i] - y[b, j]||^2, never negative.
        """
        xx = x.pow(2).sum(dim=-1)  # (B, N)
        yy = y.pow(2).sum(dim=-1)  # (B, M)
        zz = torch.matmul(x, y.transpose(-1, -2))  # (B, N, M)
        # ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>; broadcasting expands
        # (B, N, 1) and (B, 1, M) to (B, N, M).
        P = xx.unsqueeze(-1) + yy.unsqueeze(-2) - 2 * zz
        # Floating-point cancellation can leave tiny negative values for (nearly) coincident points;
        # a squared distance must not be negative, so clamp at zero.
        return P.clamp(min=0)

    def forward(self, x, y):
        """
        :param x: tensor with size of (B, N, 3)
        :param y: tensor with size of (B, M, 3)
        :return: a scalar, Chamfer distance between point clouds x and y (summed, not averaged, over the batch).
        """
        Dist = self.batch_pairwise_dist(x, y)  # (B, N, M)
        dist1, _ = torch.min(Dist, dim=2)  # (B, N): nearest point of y for each point of x
        dist2, _ = torch.min(Dist, dim=1)  # (B, M): nearest point of x for each point of y
        return dist1.sum() + dist2.sum()

class KLDivergence(nn.Module):
    """
    KL Divergence between N(mu, sigma^2) and N(0, 1).
    """

    def __init__(self):
        super(KLDivergence, self).__init__()

    @staticmethod
    def forward(mu, log_var):
        """
        :param mu: mean tensor with size of (B, feat_dims, 1)
        :param log_var: log(sigma^2) tensor with size of (B, feat_dims, 1)
        :return: a scalar tensor of KL Divergence, summed over features and averaged over the batch.
        """
        # Flatten everything after the batch dimension rather than calling squeeze(): a bare squeeze()
        # also removes the batch dimension when B == 1 (or the feature dimension when feat_dims == 1),
        # which makes the reduction over dim=1 below fail.
        mu = mu.flatten(start_dim=1)  # (B, feat_dims)
        log_var = log_var.flatten(start_dim=1)  # (B, feat_dims)
        per_sample = 0.5 * torch.sum(mu.pow(2) + log_var.exp() - log_var - 1, dim=1)  # (B,)
        return per_sample.mean(dim=0)


class ULIPContrastiveLoss(nn.Module):
    """
    Symmetric contrastive loss between point cloud / text and point cloud / image logits, where the i-th point
    cloud is paired with the i-th text and the i-th image of the batch.
    """

    def __init__(self):
        super(ULIPContrastiveLoss, self).__init__()

    @staticmethod
    def forward(logits_per_pc_text, logits_per_pc_image):
        """
        :param logits_per_pc_text: logits array with size of [B, B], which row is point cloud and column is text.
        :param logits_per_pc_image: logits array with size of [B, B], which row is point cloud and column is image.
        :return: a dictionary with keys 'loss', 'pc_text_acc', 'pc_img_acc'.
        """
        n = logits_per_pc_text.shape[0]
        # The matching pair of row i sits on the diagonal, so the target class of row i is i.
        targets = torch.arange(n, device=logits_per_pc_text.device, dtype=torch.long)

        def symmetric_ce(logits):
            # Average of the row-wise (pc -> other) and column-wise (other -> pc) cross entropy.
            return (cross_entropy(logits, targets) + cross_entropy(logits.T, targets)) / 2

        def top1_accuracy(logits):
            # Percentage of point clouds whose highest logit is their own pair.
            hits = torch.argmax(logits, dim=-1).eq(targets).sum()
            return 100 * hits / n

        text_loss = symmetric_ce(logits_per_pc_text)
        image_loss = symmetric_ce(logits_per_pc_image)

        # Accuracies are monitoring values only, so they are computed without gradient tracking.
        with torch.no_grad():
            pc_text_acc = top1_accuracy(logits_per_pc_text)
            pc_img_acc = top1_accuracy(logits_per_pc_image)

        return {'loss': text_loss + image_loss, 'pc_text_acc': pc_text_acc, 'pc_img_acc': pc_img_acc}


class MultimodalSupConLoss(nn.Module):
    """
    This loss is the supervised version of ULIPContrastiveLoss.
    Although different point clouds in the same category have different geometric structures in detail, they share the
    same category information. So we can use the SupCon to improve the original ULIP loss.

    Only the point cloud / text term is supervised by the category labels; the point cloud / image term keeps the
    one-to-one (diagonal) targets.
    """
    def __init__(self):
        super(MultimodalSupConLoss, self).__init__()

    @staticmethod
    def compute_cross_entropy(p, q):
        """
        Cross entropy between a target distribution and a set of logits.

        :param p: target probabilities, same shape as q, normalised along the last dimension.
        :param q: unnormalised logits; a log-softmax is applied along the last dimension.
        :return: a scalar tensor, the cross entropy averaged over all rows.
        """
        log_q = nn.functional.log_softmax(q, dim=-1)
        return -torch.sum(p * log_q, dim=-1).mean()

    def forward(self, logits_per_pc_text, logits_per_pc_image, category):
        """
        :param logits_per_pc_text: logits array with size of [B, B], which row is point cloud and column is text.
        :param logits_per_pc_image: logits array with size of [B, B], which row is point cloud and column is image.
        :param category: category label array with size of [B], each element is the category of the point cloud.
        :return: a dictionary with keys 'loss', 'pc_text_acc', 'pc_img_acc'.
        """
        batch_size = logits_per_pc_text.shape[0]
        device = logits_per_pc_text.device
        labels = torch.arange(batch_size, device=device, dtype=torch.long)

        # Building mask matrix for pc-text contrastive loss: mask[i, j] is 1 when samples i and j share a category.
        # reshape (rather than view) also accepts non-contiguous label tensors.
        mask = torch.eq(category.reshape(-1, 1), category.reshape(1, -1)).float().to(device)
        # Each row becomes a uniform target distribution over all same-category samples.
        prob = mask / mask.sum(1, keepdim=True).clamp(min=1.0)

        # mask is symmetric, so the same targets serve both the pc -> text and the text -> pc direction.
        loss1 = (self.compute_cross_entropy(prob, logits_per_pc_text)
                 + self.compute_cross_entropy(prob, logits_per_pc_text.T)) / 2
        loss2 = (cross_entropy(logits_per_pc_image, labels) + cross_entropy(logits_per_pc_image.T, labels)) / 2
        loss = loss1 + loss2

        # compute accuracy
        with torch.no_grad():
            # A text prediction counts as correct when it belongs to the same category as the point cloud.
            pred = torch.argmax(logits_per_pc_text, dim=-1)
            correct = mask[labels, pred].eq(1).sum()
            pc_text_acc = 100 * correct / batch_size

            # An image prediction counts as correct only when it is the point cloud's own image.
            pred = torch.argmax(logits_per_pc_image, dim=-1)
            correct = pred.eq(labels).sum()
            pc_img_acc = 100 * correct / batch_size

        return {'loss': loss, 'pc_text_acc': pc_text_acc, 'pc_img_acc': pc_img_acc}


if __name__ == '__main__':
    # Smoke test: run MultimodalSupConLoss on random, row-normalised logits and random category labels.
    batch = 16
    text_logits = torch.randn(batch, batch)
    image_logits = torch.randn(batch, batch)
    categories = torch.randint(low=0, high=5, size=[batch])

    # Scale every row to unit L2 norm.
    text_logits = text_logits / text_logits.norm(p=2, dim=1, keepdim=True)
    image_logits = image_logits / image_logits.norm(p=2, dim=1, keepdim=True)

    criterion = MultimodalSupConLoss()
    result = criterion(text_logits, image_logits, categories)
    print(result['loss'], result['pc_text_acc'], result['pc_img_acc'])
