from chamfer_distance import ChamferDistance
# Single module-level ChamferDistance instance, created once at import time and
# called by the loss functions in this file.
chamfer_dist = ChamferDistance()
import torch

def chamfer_loss(pred, gt):
    """Return the sum of the two mean distance terms between ``gt`` and ``pred``.

    The module-level ``chamfer_dist`` is called as ``chamfer_dist(gt, pred)``
    and must yield exactly two values. Each one is reduced with
    ``torch.mean`` and the two means are added; no 1/2 factor is applied.

    Args:
        pred: Predicted points (presumably a batched point-cloud tensor --
            confirm against the caller).
        gt: Ground-truth points, in whatever layout ``chamfer_dist`` expects.

    Returns:
        The result of adding the two ``torch.mean`` reductions.
    """
    term_a, term_b = chamfer_dist(gt, pred)
    mean_a = torch.mean(term_a)
    mean_b = torch.mean(term_b)
    return mean_a + mean_b

def chamfer_loss_cf(pred, gt, alpha):
    """Combine coarse and fine distance losses into one weighted value.

    ``pred`` is unpacked into exactly two items, ``(coarse, fine)``. For each
    of them, ``chamfer_dist(item, gt)`` is evaluated, both returned terms are
    reduced with ``torch.mean``, and the two means are averaged. The result
    is ``coarse_loss + alpha * fine_loss``.

    Args:
        pred: Two-element iterable ``(coarse, fine)`` of predictions.
        gt: Ground-truth points compared against both predictions.
        alpha: Multiplier applied to the fine-level loss only.

    Returns:
        The coarse loss plus ``alpha`` times the fine loss.
    """
    coarse, fine = pred

    # Evaluate the coarse prediction first, then the fine one.
    stage_losses = []
    for cloud in (coarse, fine):
        term_a, term_b = chamfer_dist(cloud, gt)
        stage_losses.append((torch.mean(term_a) + torch.mean(term_b)) / 2)

    coarse_term, fine_term = stage_losses
    return coarse_term + alpha * fine_term

