from torchmetrics import Metric
import torch

class Chamfer_FScore(Metric):
    """Accumulate Chamfer distance and F-score statistics over point-cloud pairs.

    Each ``update`` call adds per-sample values to running sums; ``compute``
    divides those sums by the number of samples seen so far. The order of the
    values returned by ``compute`` is listed by ``returned_metrics``.

    Precision is measured over the predicted points (fraction lying closer
    than the threshold to some target point) and recall over the target
    points (fraction lying closer than the threshold to some predicted point).

    Args:
        dist_sync_on_step: Forwarded unchanged to ``Metric.__init__``.
        p: Norm degree handed to ``torch.cdist``.
        tau: Base distance threshold for the F-score. Metrics are reported at
            ``tau`` (``*_tau``) and at ``2 * tau`` (``*_2tau``).
        **kwargs: Forwarded unchanged to ``Metric.__init__``.
    """

    def __init__(self, dist_sync_on_step=False, p=2, tau=3e-2, **kwargs):
        super().__init__(dist_sync_on_step=dist_sync_on_step, **kwargs)
        # Every state is a scalar running sum that is summed across processes.
        for state_name in (
            "chamfer",
            "f_tau",
            "f_2tau",
            "p_tau",
            "r_tau",
            "p_2tau",
            "r_2tau",
            "num_pcs",
        ):
            self.add_state(
                state_name, default=torch.tensor(0.0), dist_reduce_fx="sum"
            )
        self.p = p
        self.tau = tau

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Add the metrics of one batch of point clouds to the running sums.

        ``preds[i]`` and ``target[i]`` are compared pairwise for every index
        of ``preds``.

        Args:
            preds: Indexable batch of predicted point clouds; each entry must
                be 2-dimensional.
            target: Indexable batch of reference point clouds, aligned with
                ``preds``.

        Raises:
            ValueError: If an entry of ``preds`` is not 2-dimensional.
        """
        for i, pred in enumerate(preds):
            if pred.dim() != 2:
                raise ValueError(
                    f"preds[{i}] must be 2-dimensional, got shape {tuple(pred.shape)}"
                )
            if pred.shape[0] == 0:
                # An empty prediction adds a fixed penalty of 1.0 to the
                # Chamfer sum and nothing to the F-score/precision/recall
                # sums, while still counting as a sample in ``num_pcs``.
                self.chamfer += 1.0
                continue

            cdist = torch.cdist(pred, target[i], p=self.p)  # (P, T)
            target_to_pred = cdist.min(dim=0).values  # (T,)
            pred_to_target = cdist.min(dim=1).values  # (P,)
            self.chamfer += chamfer(target_to_pred, pred_to_target)

            # ``fscore`` derives precision from its first argument and recall
            # from its second, so the per-prediction distances go first.
            f_tau, p_tau, r_tau = fscore(
                pred_to_target, target_to_pred, tau=self.tau
            )
            self.f_tau += f_tau
            self.p_tau += p_tau
            self.r_tau += r_tau

            f_2tau, p_2tau, r_2tau = fscore(
                pred_to_target, target_to_pred, tau=2 * self.tau
            )
            self.f_2tau += f_2tau
            self.p_2tau += p_2tau
            self.r_2tau += r_2tau
        self.num_pcs += len(preds)

    def compute(self):
        """Return the per-sample averages of all accumulated metrics.

        Returns:
            A 7-tuple of scalar tensors ordered as in ``returned_metrics``.
            Each is the running sum divided by ``num_pcs``; no guard is applied
            for the case where ``num_pcs`` is still zero.
        """
        totals = (
            self.chamfer,
            self.f_tau,
            self.p_tau,
            self.r_tau,
            self.f_2tau,
            self.p_2tau,
            self.r_2tau,
        )
        return tuple(total / self.num_pcs for total in totals)

    def returned_metrics(self):
        """Return the metric names, in the order ``compute`` yields them."""
        return ['chamfer', 'f_tau', 'p_tau', 'r_tau', 'f_2tau', 'p_2tau', 'r_2tau']

def chamfer(min_dist1, min_dist2):
    """Return the equally weighted average of the means of two distance tensors.

    Args:
        min_dist1: Tensor of distances for the first direction.
        min_dist2: Tensor of distances for the second direction.

    Returns:
        Scalar tensor equal to ``0.5 * mean(min_dist1) + 0.5 * mean(min_dist2)``.
    """
    half_first = 0.5 * min_dist1.mean()
    half_second = 0.5 * min_dist2.mean()
    return half_first + half_second

def fscore(min_dist1, min_dist2, tau=1e-4):
    """Compute F-score, precision and recall from two distance tensors.

    Precision is the fraction of ``min_dist1`` entries strictly below ``tau``;
    recall is the same fraction taken over ``min_dist2``. The F-score is their
    harmonic mean.

    Args:
        min_dist1: Tensor of distances used for precision.
        min_dist2: Tensor of distances used for recall.
        tau: Distance threshold (strict ``<`` comparison).

    Returns:
        Tuple ``(fscore, precision, recall)`` of scalar tensors. When
        precision and recall sum to zero, all three are a zero tensor placed
        on the device of ``min_dist1``.
    """
    hits_first = (min_dist1 < tau).sum().float()
    hits_second = (min_dist2 < tau).sum().float()
    precision = hits_first / min_dist1.numel()
    recall = hits_second / min_dist2.numel()

    denom = precision + recall
    if denom != 0:
        return 2 * precision * recall / denom, precision, recall

    # No entry is below the threshold in either tensor; avoid dividing by zero.
    zero = torch.tensor(0.).to(min_dist1.device)
    return zero, zero, zero

