import torch
import torch.nn as nn
from pytorch3d.loss import chamfer_distance
from src.utils.ssim import ssim
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from extension.Chamfer3D.dist_chamfer_3D import chamfer_3DDist
from extension.emd.emd_module import emdFunction


class LossFN(nn.Module):
    """
    Combine several named loss terms into one weighted objective.

    ``loss_list`` is a sequence of dicts, each providing:
        name      -- key under which the term's scalar value is reported
        weight    -- multiplier applied when the term enters the total
        loss_fn   -- callable ``loss_fn(pred, target)``; ``.item()`` is called
                     on its result, so it must return a one-element tensor
        pred      -- key into ``preds`` selecting the prediction
        target    -- key into ``targets`` selecting the target
        used_for_optimization -- optional (default True); when False the term
                     is only reported, not added to the total
    """

    def __init__(self, loss_list):
        super().__init__()
        # NOTE: kept as a plain list, so any nn.Module loss functions inside
        # are NOT registered submodules (.to()/.train()/.eval() on this module
        # do not propagate to them).
        self.loss_list = loss_list

    def forward(self, preds, targets):
        """
        Evaluate every configured term.

        Returns a dict mapping each term's ``name`` to its Python float value,
        plus ``'total_loss'``: the weighted sum of the optimised terms (kept as
        a tensor for backprop, or the int ``0`` when no term is optimised).
        """
        info = {}
        total_loss = 0
        for spec in self.loss_list:
            name, weight, criterion = spec['name'], spec['weight'], spec['loss_fn']
            value = criterion(preds[spec['pred']], targets[spec['target']])
            info[name] = value.item()
            optimise = spec.get('used_for_optimization', True)
            if optimise:
                total_loss += weight * value
        info['total_loss'] = total_loss
        return info
    
def knn_point(nsample, xyz, new_xyz):
    """
    Find the ``nsample`` nearest neighbours in ``xyz`` of every query point.

    Input:
        nsample: max sample number in local region
        xyz: all points, [B, N, C] (an unbatched [N, C] also works)
        new_xyz: query points, [B, S, C] (or [S, C])
    Return:
        group_idx: grouped points index, [B, S, nsample]
                   ([S, nsample] when unbatched). Neighbours are NOT ordered
                   by distance (``sorted=False``).
    """
    # Squared distances via ||q - p||^2 = ||q||^2 + ||p||^2 - 2 q.p, computed
    # inline with batch-aware ops so arbitrary leading batch dims are supported.
    sqrdists = -2 * torch.matmul(new_xyz, xyz.transpose(-1, -2))
    sqrdists += torch.sum(new_xyz ** 2, dim=-1, keepdim=True)
    sqrdists += torch.sum(xyz ** 2, dim=-1).unsqueeze(-2)
    _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
    return group_idx

def square_distance(src, dst):
    """
    Calculate the squared Euclidean distance between every pair of points.

    Uses the expansion
        dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
             = sum(src**2, dim=-1) + sum(dst**2, dim=-1) - 2 * src^T * dst
    so the whole distance matrix comes from a single matrix multiply.

    Input:
        src: source points, [..., N, C]
        dst: target points, [..., M, C]; leading (batch) dims must broadcast
             with those of ``src``. Plain [N, C] / [M, C] inputs work as before.
    Output:
        dist: per-pair squared distance, [..., N, M]. Floating-point
              cancellation in the expansion can make entries for (near-)
              coincident points slightly negative; clamp before dividing by
              or taking the sqrt of the result.
    """
    dist = -2 * torch.matmul(src, dst.transpose(-1, -2))
    dist += torch.sum(src ** 2, dim=-1).unsqueeze(-1)
    dist += torch.sum(dst ** 2, dim=-1).unsqueeze(-2)
    return dist


def CoulombPotential(means, nsample=10):
    """
    Mean inverse-squared-distance repulsion between points and their neighbours.

    For every point, the ``nsample - 1`` nearest other points are found and
    ``1 / ||p_i - p_j||^2`` is averaged over all such pairs. Note this uses the
    squared distance (1/r^2), not the 1/r of a physical Coulomb potential.

    Input:
        means: point positions, [N, C]
        nsample: neighbourhood size including the point itself (default 10,
                 must be >= 2). It is clipped to N, so point sets with fewer
                 than ``nsample`` points no longer make ``topk`` fail.
    Return:
        Scalar tensor. It is 0 when there are fewer than two points.
    Raises:
        ValueError: if ``nsample < 2`` (no neighbour would remain).
    """
    if nsample < 2:
        raise ValueError(f"nsample must be >= 2, got {nsample}")
    num_points = means.shape[0]
    if num_points < 2:
        # No pairs means no repulsion; also avoids a NaN mean over an empty tensor.
        return means.new_zeros(())
    k = min(nsample, num_points)

    dists = square_distance(means, means)
    # Guard the reciprocal against zero self/duplicate distances and the tiny
    # negative values the expanded-form distance computation can produce.
    dists = torch.clamp(dists, min=1e-6)
    knn_dist, _ = torch.topk(dists, k, dim=-1, largest=False, sorted=True)
    # Column 0 is each row's smallest entry, i.e. the (near-zero) self-distance.
    knn_dist = knn_dist[:, 1:]
    return torch.mean(1 / knn_dist)


class ChamferLoss(nn.Module):
    """
    Chamfer distance between point clouds via ``pytorch3d.loss.chamfer_distance``.

    Accepted layouts:
        pred:   [N, C] (unbatched), [B, N, C], or [B, K, N, C]; in the 4-D
                case only index 0 along dim 1 is used.
        target: [M, C] (unbatched) or [B, M, C]
    Only the first element of ``chamfer_distance``'s returned pair is used.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        # Promote unbatched clouds to a batch of one.
        pred = pred.unsqueeze(0) if pred.dim() == 2 else pred
        target = target.unsqueeze(0) if target.dim() == 2 else target
        # 4-D predictions: keep only the first slice along dim 1.
        if pred.dim() == 4:
            pred = pred[:, 0]
        loss, _unused = chamfer_distance(pred, target)
        return loss
    

class InfoChamferLoss(nn.Module):
    """
    InfoCD-style Chamfer loss built on the ``chamfer_3DDist`` extension.

    Each squared nearest-neighbour distance returned by the extension is
    clamped to >= 1e-9, square-rooted to ``d`` and mapped to

        -log( exp(-0.5 * d) / S ** 1e-7 ),   S = sum_j (exp(-0.5 * d_j) + 1e-7)

    where ``S`` is summed over the last dim and broadcast back. The per-point
    terms are summed over the whole tensor. With ``one_side=False`` the sums
    for ``dist1`` and ``dist2`` (the extension's first two outputs) are
    averaged; with ``one_side=True`` only ``dist1`` is used.

    Input layouts: pred [N, C], [B, N, C] or [B, K, N, C] (4-D uses index 0
    along dim 1); target [M, C] or [B, M, C].
    """

    def __init__(self, one_side=False):
        super().__init__()
        self.chamfer_dist = chamfer_3DDist()
        self.one_side = one_side

    @staticmethod
    def _info_term(sq_dist, eps=1e-7, power=1e-7):
        """
        Per-point InfoCD term for a tensor of squared distances.

        Mathematically equal to ``-log(exp(-0.5*d) / S**power)`` with
        ``S = sum(exp(-0.5*d) + eps, dim=-1)``, but evaluated as
        ``0.5*d + power*log(S)``. The direct form underflows ``exp(-0.5*d)``
        to 0 once ``d`` is large (roughly d > 200 in float32), which gives an
        infinite loss and NaN gradients.
        """
        d = torch.sqrt(torch.clamp(sq_dist, min=1e-9))
        # eps keeps S > 0 (for a non-empty last dim), so the log stays finite.
        norm = torch.sum(torch.exp(-0.5 * d) + eps, dim=-1, keepdim=True)
        return 0.5 * d + power * torch.log(norm)

    def forward(self, pred, target):
        if pred.dim() == 2:
            pred = pred.unsqueeze(0)
        if target.dim() == 2:
            target = target.unsqueeze(0)
        if pred.dim() == 4:
            pred = pred[:, 0]
        dist1, dist2, _, _ = self.chamfer_dist(pred, target)
        loss1 = torch.sum(self._info_term(dist1))
        if self.one_side:
            return loss1
        loss2 = torch.sum(self._info_term(dist2))
        return (loss1 + loss2) / 2


def emd_loss(preds, gts, eps=0.005, iters=50):
    """
    Sum of the distances returned by the ``emdFunction`` autograd extension.

    ``eps`` and ``iters`` are forwarded unchanged to ``emdFunction.apply``.
    Only the first of its two outputs is used, reduced by summing over every
    element.
    """
    dist, _ = emdFunction.apply(preds, gts, eps, iters)
    return dist.sum()

class EMDLoss(nn.Module):
    """
    Earth Mover's Distance loss module wrapping ``emd_loss``.

    Accepted layouts:
        pred:   [N, C] (unbatched), [B, N, C], or [B, K, N, C]; in the 4-D
                case only index 0 along dim 1 is used.
        target: [M, C] (unbatched) or [B, M, C]
    """

    def __init__(self, eps=0.005, iters=50):
        super().__init__()
        # Forwarded unchanged to emd_loss on every call.
        self.eps, self.iters = eps, iters

    def forward(self, pred, target, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        if pred.dim() == 2:
            pred = pred[None]
        if target.dim() == 2:
            target = target[None]
        if pred.dim() == 4:
            pred = pred[:, 0]
        return emd_loss(pred, target, eps=self.eps, iters=self.iters)


class CoulombLoss(nn.Module):
    """
    Repulsion regulariser on a set of points, computed with ``CoulombPotential``.

    ``target`` is accepted only so the module matches the ``(pred, target)``
    signature that ``LossFN`` calls loss functions with; it is ignored.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        """
        Input:
            pred: points, [N, C], or a batch of point sets [..., N, C]
            target: unused
        Return:
            Scalar potential. For batched input it is the mean of the
            per-set potentials.
        """
        if pred.dim() <= 2:
            return CoulombPotential(pred)
        # Each trailing [N, C] slice is an independent point set. Scoring them
        # separately avoids spurious repulsion between points of different sets.
        point_sets = pred.reshape(-1, *pred.shape[-2:])
        return torch.stack([CoulombPotential(points) for points in point_sets]).mean()

class SSIMLoss(nn.Module):
    """
    Structural-dissimilarity loss: ``1 - ssim(pred, target)``.

    A 5-D ``pred`` of shape [B, N, C, H, W] is flattened to [B*N, C, H, W]
    before scoring. ``target`` is reshaped using pred's trailing (C, H, W).
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        if pred.dim() == 5:
            chw = pred.shape[2:]
            pred = pred.view(-1, *chw)
            target = target.view(-1, *chw)
        return 1 - ssim(pred, target)


class LPIPSLoss(nn.Module):
    """
    Perceptual loss using torchmetrics' ``LearnedPerceptualImagePatchSimilarity``.

    The metric is built with default arguments. Inputs are rescaled with
    ``x * 2 - 1`` before scoring, which maps [0, 1] to [-1, 1]; presumably
    callers pass images in [0, 1] — confirm. A 5-D ``pred`` of shape
    [B, N, C, H, W] is flattened to [B*N, C, H, W], and ``target`` is
    reshaped using pred's trailing (C, H, W).
    """

    def __init__(self):
        super().__init__()
        self.lpips_loss = LearnedPerceptualImagePatchSimilarity()

    def forward(self, pred, target):
        if pred.dim() == 5:
            chw = pred.shape[2:]
            pred, target = pred.view(-1, *chw), target.view(-1, *chw)
        return self.lpips_loss(pred * 2 - 1, target * 2 - 1)