import torch
import math
import os
import sys
import torch.nn as nn
import torch.nn.functional as F

proj_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
from pcdet.ops.pointnet2.pointnet2_stack import pointnet2_modules as pn2 
sys.path.append(os.path.join(proj_dir, "utils/emd"))
import emd_module as emd
sys.path.append(os.path.join(proj_dir, "utils/ChamferDistancePytorch"))
from chamfer3D import dist_chamfer_3D
from fscore import fscore


def calc_cd(output, gt, calc_f1=False, return_raw=False, normalize=False, separate=False, f1_thr=0.15):
    """Chamfer distance between two batched point clouds.

    The project's ``dist_chamfer_3D.chamfer_3DDist`` op is called as
    ``op(gt, output)``. Its distance outputs are presumably squared nearest-
    neighbour distances, since they are square-rooted for ``cd_p``. Confirm
    against the extension before relying on that.

    Args:
        output: predicted cloud (second argument of the chamfer op).
        gt: reference cloud (first argument of the chamfer op).
        calc_f1: if True, also compute an F-score with ``fscore`` on the
            square-rooted distances and ``f1_thr``.
        return_raw: if True, also return the raw op outputs
            ``dist1, dist2, idx1, idx2``.
        normalize: accepted for API compatibility; not used here.
        separate: if True, keep the two directions apart instead of
            combining them.
        f1_thr: threshold forwarded to ``fscore``.

    Returns:
        A list ``[cd_p, cd_t]``, followed by ``f1`` when ``calc_f1`` is set,
        then ``dist1, dist2, idx1, idx2`` when ``return_raw`` is set.

        * ``cd_p``: average of the two per-direction means (over dim 1) of
          ``sqrt(dist)``.
        * ``cd_t``: sum (not average) of the two per-direction means of
          ``dist``.

        With ``separate=True``, each of ``cd_p``/``cd_t`` is instead the two
        per-direction means stacked along a new leading dimension of size 2.
    """
    chamfer = dist_chamfer_3D.chamfer_3DDist()
    d1, d2, nn1, nn2 = chamfer(gt, output)

    # Per-direction means over dim 1, with and without the square root.
    root_mean1 = torch.sqrt(d1).mean(1)
    root_mean2 = torch.sqrt(d2).mean(1)
    sq_mean1 = d1.mean(1)
    sq_mean2 = d2.mean(1)

    if not separate:
        res = [(root_mean1 + root_mean2) / 2, sq_mean1 + sq_mean2]
    else:
        res = [
            torch.stack([root_mean1, root_mean2]),
            torch.stack([sq_mean1, sq_mean2]),
        ]

    if calc_f1:
        f1, _, _ = fscore(torch.sqrt(d1), torch.sqrt(d2), f1_thr)
        res.append(f1)
    if return_raw:
        res += [d1, d2, nn1, nn2]
    return res


def _nn_density_weights(idx, n_target, n_p, frac):
    """Inverse match-count weights for a batch of nearest-neighbour indices.

    Args:
        idx: integer tensor of shape (batch, n_query). ``idx[b, i]`` is the
            index, in a cloud of ``n_target`` points, of the point matched to
            query point ``i``.
        n_target: number of points in the cloud that ``idx`` indexes into.
        n_p: exponent applied to each per-point match count.
        frac: scalar multiplier applied to every weight.

    Returns:
        ``(weight, counts)``:

        * ``weight``: float tensor of shape (batch, n_query), equal to
          ``frac / (counts[b, idx[b, i]] ** n_p + 1e-6)``.
        * ``counts``: int64 tensor of shape (batch, n_target), giving how many
          query points matched each target point (0 where none did).
    """
    idx = idx.long()  # scatter_add_/gather require int64 indices
    counts = torch.zeros(idx.shape[0], n_target, dtype=torch.long, device=idx.device)
    # Batched equivalent of torch.bincount(idx[b]) zero-padded to n_target.
    counts.scatter_add_(1, idx, torch.ones_like(idx))
    # Integer counts carry no gradient, so the weights act as constants.
    matched = counts.gather(1, idx).float() ** n_p
    weight = (matched + 1e-6) ** (-1) * frac
    return weight, counts


def calc_cd_full(x, gt, T=1000, n_p=1, return_raw=False, separate=False, return_freq=False, non_reg=False, f1_thr=0.15):
    """Hausdorff, F-score, density-aware and plain Chamfer metrics in one pass.

    Args:
        x: predicted cloud. Expected to be 3-D, and presumably
            (batch, n_x, 3). It is cast to float32.
        gt: reference cloud with the same batch size, cast to float32.
        T: temperature of the ``exp(-dist * T)`` term in the density-aware loss.
        n_p: exponent applied to the per-point match counts in the weights.
        return_raw: also return ``dist1, dist2, idx1, idx2`` from ``calc_cd``.
        separate: return the two directions separately. This applies to the
            density-aware loss and is also forwarded to ``calc_cd``.
        return_freq: also return per-point match counts for ``x`` and ``gt``.
        non_reg: clamp the point-count ratios used as weight multipliers to
            be at least 1.
        f1_thr: F-score threshold forwarded to ``calc_cd``.

    Returns:
        ``[hd, f1, loss, cd_p, cd_t]``, optionally followed by the raw
        tensors (``return_raw``), then ``x_counted, gt_counted``
        (``return_freq``).

        * ``hd``: per-sample max of ``sqrt`` over both distance sets.
        * ``loss``: per-sample mean of the two directional density-aware
          terms. With ``separate=True`` it is a (2, batch) stack of them.
        * ``x_counted`` / ``gt_counted``: how many points of the other cloud
          matched each point. They have the dtype of ``idx2`` / ``idx1``.

    Raises:
        AssertionError: if the batch sizes differ.
    """
    x = x.float()
    gt = gt.float()
    _, n_x, _ = x.shape
    _, n_gt, _ = gt.shape
    assert x.shape[0] == gt.shape[0], "x and gt must have the same batch size"

    if non_reg:
        frac_12 = max(1, n_x / n_gt)
        frac_21 = max(1, n_gt / n_x)
    else:
        frac_12 = n_x / n_gt
        frac_21 = n_gt / n_x

    cd_p, cd_t, f1, dist1, dist2, idx1, idx2 = calc_cd(
        x, gt, calc_f1=True, return_raw=True, separate=separate, f1_thr=f1_thr)
    exp_dist1, exp_dist2 = torch.exp(-dist1 * T), torch.exp(-dist2 * T)

    hd, _ = torch.max(torch.sqrt(torch.cat((dist1, dist2), dim=1)), dim=1)

    # idx1 selects points of x and idx2 selects points of gt. Each count
    # vector is therefore sized like one row of the other index tensor.
    weight1, counts1 = _nn_density_weights(idx1, idx2.shape[1], n_p, frac_21)
    weight2, counts2 = _nn_density_weights(idx2, idx1.shape[1], n_p, frac_12)
    loss1 = (- exp_dist1 * weight1 + 1.).mean(dim=1)
    loss2 = (- exp_dist2 * weight2 + 1.).mean(dim=1)
    loss = (loss1 + loss2) / 2  # density-aware cd

    if separate:
        res = [hd, f1, torch.stack([loss1, loss2]), cd_p, cd_t]
    else:
        res = [hd, f1, loss, cd_p, cd_t]
    if return_raw:
        res.extend([dist1, dist2, idx1, idx2])
    if return_freq:
        # Keep the index dtype so the outputs match the per-sample
        # zeros_like(idx) padding this replaces.
        res.extend([counts1.to(idx2.dtype), counts2.to(idx1.dtype)])
    return res

