import torch
import open3d as o3d

from extensions.chamfer_distance.chamfer_distance import ChamferDistance
import numpy as np
# from extensions.earth_movers_distance.emd import EarthMoverDistance


# Module-level ChamferDistance instance, created once at import time and
# called by l2_cd() and l1_cd() below.
CD = ChamferDistance()
# EMD = EarthMoverDistance()


def l2_cd(pcs1, pcs2):
    """
    Chamfer distance built directly from the values returned by ``CD``.

    The two outputs of ``CD(pcs1, pcs2)`` are each averaged along dim 1, the
    two averages are added, and the result is summed over the remaining
    dimension(s).

    NOTE(review): presumably ``CD`` returns per-point squared nearest-neighbour
    distances shaped (batch, num_points), one tensor per direction - confirm
    against the chamfer_distance extension.

    Args:
        pcs1: first batch of point clouds, passed to ``CD`` unchanged.
        pcs2: second batch of point clouds, passed to ``CD`` unchanged.

    Returns:
        torch.Tensor: scalar tensor, the sum over dim 0 of the per-row totals.
    """
    fwd, bwd = CD(pcs1, pcs2)
    per_sample = fwd.mean(dim=1) + bwd.mean(dim=1)
    return per_sample.sum()


def l1_cd(pcs1, pcs2):
    """
    Chamfer distance computed on the square roots of the ``CD`` outputs.

    Each output of ``CD(pcs1, pcs2)`` is passed through an element-wise square
    root and averaged along dim 1. The two averages are added, summed over the
    remaining dimension(s), and the total is halved.

    NOTE(review): presumably ``CD`` returns per-point squared nearest-neighbour
    distances shaped (batch, num_points), so the square root yields Euclidean
    distances - confirm against the chamfer_distance extension.

    Args:
        pcs1: first batch of point clouds, passed to ``CD`` unchanged.
        pcs2: second batch of point clouds, passed to ``CD`` unchanged.

    Returns:
        torch.Tensor: scalar tensor, half the sum of the per-row totals.
    """
    fwd, bwd = CD(pcs1, pcs2)
    fwd_mean = fwd.sqrt().mean(1)
    bwd_mean = bwd.sqrt().mean(1)
    return (fwd_mean + bwd_mean).sum() / 2


# def emd(pcs1, pcs2):
#     dists = EMD(pcs1, pcs2)
#     return torch.sum(dists)

def dense_score(pred, gt, th=0.01, radius=0.02):
    """
    Compare the local point density of a prediction with the ground truth.

    Every ground-truth point is used as a query centre. The number of
    ground-truth points and the number of predicted points returned by a
    radius search around that centre are counted, and the root-mean-square of
    the per-centre count differences is returned (0 means identical counts at
    every centre).

    Args:
        pred   (np.ndarray): (N1, 3)
        gt     (np.ndarray): (N2, 3)
        th     (float): unused; kept only so existing callers keep working
        radius (float): search radius around each query point; the default
            is the value that used to be hard-coded

    Returns:
        float: sqrt(sum((gt_count - pred_count) ** 2)) / sqrt(N2)

    Note:
        With no ground-truth points there is nothing to average over and the
        final division raises ZeroDivisionError (assuming Open3D accepts the
        empty cloud - not verified here).
    """
    pred = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(pred))
    gt = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(gt))

    gt_tree = o3d.geometry.KDTreeFlann(gt)
    pred_tree = o3d.geometry.KDTreeFlann(pred)

    squared_error = 0
    num_queries = 0
    for query_point in gt.points:
        # Only the returned index list is needed; its length is the
        # neighbour count for this query centre.
        _, gt_idx, _ = gt_tree.search_radius_vector_3d(query_point, radius)
        _, pred_idx, _ = pred_tree.search_radius_vector_3d(query_point, radius)
        squared_error += (len(gt_idx) - len(pred_idx)) ** 2
        num_queries += 1

    # Same expression as before the rewrite so results stay bit-identical.
    return float(squared_error ** 0.5) / (float(num_queries) ** 0.5)

def f_score(pred, gt, th=0.01):
    """
    F-score between a predicted and a ground-truth point cloud.

    Precision is the fraction of predicted points whose distance to the
    ground-truth cloud is below ``th``; recall is the fraction of ground-truth
    points whose distance to the predicted cloud is below ``th``. The score is
    their harmonic mean.

    References: https://github.com/lmb-freiburg/what3d/blob/master/util.py

    Args:
        pred (np.ndarray): (N1, 3)
        gt   (np.ndarray): (N2, 3)
        th   (float): a distance threshold

    Returns:
        float: 2 * precision * recall / (precision + recall), or 0.0 when
        either cloud is empty or when precision + recall is zero.
    """
    # An empty cloud has no distances to average over; report the worst score
    # instead of failing on a division by zero.
    if len(pred) == 0 or len(gt) == 0:
        return 0.0

    pred = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(pred))
    gt = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(gt))

    pred_to_gt = np.asarray(pred.compute_point_cloud_distance(gt))
    gt_to_pred = np.asarray(gt.compute_point_cloud_distance(pred))

    precision = np.count_nonzero(pred_to_gt < th) / pred_to_gt.size
    recall = np.count_nonzero(gt_to_pred < th) / gt_to_pred.size

    total = recall + precision
    return 2 * recall * precision / total if total else 0.0
