



import os
import torch
import torch.nn.parallel
import concurrent.futures
import numpy as np
import pickle as pkl
from datasets.ycb.dataset_prim_tsn import PoseDataset
import torch.utils.data
import tqdm

# Number of class slots. Slot 0 is the pool of all objects; slots 1..21 are
# the individual YCB objects listed in cls_name.
cls_num = 22
cls_lst = list(range(cls_num))
# Object names. Class id ``i`` (1-based) is printed as ``cls_name[i - 1]``.
cls_name = [
    '002_master_chef_can',    # 1
    '003_cracker_box',        # 2
    '004_sugar_box',          # 3
    '005_tomato_soup_can',    # 4
    '006_mustard_bottle',     # 5
    '007_tuna_fish_can',      # 6
    '008_pudding_box',        # 7
    '009_gelatin_box',        # 8
    '010_potted_meat_can',    # 9
    '011_banana',             # 10
    '019_pitcher_base',       # 11
    '021_bleach_cleanser',    # 12
    '024_bowl',               # 13
    '025_mug',                # 14
    '035_power_drill',        # 15
    '036_wood_block',         # 16
    '037_scissors',           # 17
    '040_large_marker',       # 18
    '051_large_clamp',        # 19
    '052_extra_large_clamp',  # 20
    '061_foam_brick',         # 21
]
# Default evaluation output directory picked up by YCBEval.
log_eval_dir = './'

class YCBEval:
    """Accumulate per-class pose-distance samples and report AUC metrics.

    Every per-class list has ``n_cls`` slots. Slot 0 pools the samples of all
    classes (``eval_metric`` appends there as well). Slot ``i >= 1`` holds the
    samples of the object printed as ``cls_name[i - 1]``.
    """

    def __init__(self):
        n_cls = cls_num
        self.n_cls = cls_num
        self.cls_add_dis = [[] for _ in range(n_cls)]
        self.cls_adds_dis = [[] for _ in range(n_cls)]
        # Rebuilt by cal_auc(): ADDS samples for symmetric classes, ADD otherwise.
        self.cls_add_s_dis = [[] for _ in range(n_cls)]
        # Class ids whose ADD(-S) entry uses the ADDS samples.
        self.sym_cls_ids = [13, 16, 19, 20, 21]
        self.log_eval_dir = log_eval_dir

    def cal_auc(self, save=False):
        """Print per-class and averaged metrics and return the pooled ADD(-S) AUC.

        For each slot, the module-level ``cal_auc`` is applied to the ADD, ADDS
        and ADD(-S) samples with ``max_dis=0.1`` and ``thr_m=0.02``. ADD(-S)
        uses the ADDS samples for classes in ``self.sym_cls_ids`` and the ADD
        samples otherwise. The ADD(-S) lists are rebuilt from scratch on every
        call, so calling this repeatedly does not double-count slot 0.

        Args:
            save: when True, pickle the distance lists and AUC lists into
                ``self.log_eval_dir``. The file name embeds the pooled
                ADDS/ADD/ADD(-S) AUCs. Defaults to False (nothing is written).

        Returns:
            ``{'auc': <pooled ADD(-S) AUC>}``.
        """
        # Fresh lists each call. Copying also keeps later in-place merges into
        # the ADD/ADDS lists from leaking into ADD(-S).
        add_s_all = []
        for cls_id in range(1, self.n_cls):
            src = self.cls_adds_dis if cls_id in self.sym_cls_ids else self.cls_add_dis
            self.cls_add_s_dis[cls_id] = list(src[cls_id])
            add_s_all += self.cls_add_s_dis[cls_id]
        self.cls_add_s_dis[0] = add_s_all

        add_auc_lst, adds_auc_lst, add_s_auc_lst = [], [], []
        add_2cm_lst, adds_2cm_lst, add_s_2cm_lst = [], [], []
        for i in range(self.n_cls):
            # These calls resolve to the module-level cal_auc(), not this method.
            add_auc, add_2cm = cal_auc(self.cls_add_dis[i], max_dis=0.1, thr_m=0.02)
            adds_auc, adds_2cm = cal_auc(self.cls_adds_dis[i], max_dis=0.1, thr_m=0.02)
            add_s_auc, add_s_2cm = cal_auc(self.cls_add_s_dis[i], max_dis=0.1, thr_m=0.02)
            add_auc_lst.append(add_auc)
            adds_auc_lst.append(adds_auc)
            add_s_auc_lst.append(add_s_auc)

            add_2cm_lst.append(add_2cm)
            adds_2cm_lst.append(adds_2cm)
            add_s_2cm_lst.append(add_s_2cm)
            if i == 0:
                continue  # the pooled slot is reported separately below
            print(cls_name[i-1])
            print("**** add: {:.2f}, adds: {:.2f}, add(-s): {:.2f}".format(add_auc, adds_auc, add_s_auc))
            print("<2cm add: {:.2f}, adds: {:.2f}, add(-s): {:.2f}".format(add_2cm, adds_2cm, add_s_2cm))

        print("Average of all object:")
        print("**** add: {:.2f}, adds: {:.2f}, add(-s): {:.2f}".format(np.mean(add_auc_lst[1:]), np.mean(adds_auc_lst[1:]), np.mean(add_s_auc_lst[1:])))
        print("<2cm add: {:.2f}, adds: {:.2f}, add(-s): {:.2f}".format(np.mean(add_2cm_lst[1:]), np.mean(adds_2cm_lst[1:]), np.mean(add_s_2cm_lst[1:])))

        print("All object (following PoseCNN):")
        print("**** add: {:.2f}, adds: {:.2f}, add(-s): {:.2f}".format(add_auc_lst[0], adds_auc_lst[0], add_s_auc_lst[0]))
        print("<2cm add: {:.2f}, adds: {:.2f}, add(-s): {:.2f}".format(add_2cm_lst[0], adds_2cm_lst[0], add_s_2cm_lst[0]))

        if save:
            sv_info = dict(
                add_dis_lst=self.cls_add_dis,
                adds_dis_lst=self.cls_adds_dis,
                add_auc_lst=add_auc_lst,
                adds_auc_lst=adds_auc_lst,
                add_s_auc_lst=add_s_auc_lst,
            )
            sv_pth = os.path.join(
                self.log_eval_dir,
                'pvn3d_eval_cuda_{}_{}_{}.pkl'.format(
                    adds_auc_lst[0], add_auc_lst[0], add_s_auc_lst[0]
                )
            )
            with open(sv_pth, 'wb') as f:
                pkl.dump(sv_info, f)

        return {'auc': add_s_auc_lst[0]}

    def eval_pose_parallel(self, pred_RT_lst, pred_clsID_lst, gt_RT_lst, gt_clsID_lst, models_pts_lst):
        """Score a batch of predictions with eval_metric and merge the results.

        One worker thread is used per batch element. Results are merged in
        input order into ``self.cls_add_dis`` / ``self.cls_adds_dis``. An empty
        batch is a no-op.
        """
        bs = len(pred_clsID_lst)
        if bs == 0:
            # ThreadPoolExecutor rejects max_workers=0, and there is nothing to score.
            return
        with concurrent.futures.ThreadPoolExecutor(max_workers=bs) as executor:
            results = executor.map(
                eval_metric, pred_RT_lst, pred_clsID_lst, gt_RT_lst, gt_clsID_lst, models_pts_lst
            )
            # Merging happens in this (calling) thread, so no locking is needed.
            for cls_add_dis_lst, cls_adds_dis_lst in results:
                self.cls_add_dis = self.merge_lst(self.cls_add_dis, cls_add_dis_lst)
                self.cls_adds_dis = self.merge_lst(self.cls_adds_dis, cls_adds_dis_lst)

    def merge_lst(self, targ, src):
        """Extend each list ``targ[i]`` in place with ``src[i]``; return ``targ``."""
        for i in range(len(targ)):
            targ[i] += src[i]
        return targ

def eval_metric(pred_RT, pred_cls_id, gt_RT, gt_cls_id, models_pts):
    """Compute the closest-point pose error of one prediction, bucketed by class.

    Each model-point group is transformed by both poses. For every
    ground-truth-transformed point, the distance to the nearest
    prediction-transformed point is taken. The group score is the mean of
    those distances, or the max for the third group when there are exactly 3
    groups. Group scores are then averaged, or max-reduced when there are
    exactly 3 groups and the third has more than one point.

    Args:
        pred_RT, gt_RT: numpy pose arrays. Columns ``[:, :3]`` are used as the
            rotation and column ``[:, 3]`` as the translation (presumably 3x4).
            They are cast to float32 CPU tensors.
        pred_cls_id, gt_cls_id: indexables whose element 0 is the class id.
        models_pts: sequence of 2-D torch tensors, each (3, P_j). They must be
            float32 CPU tensors to be multiplied with the pose tensors.

    Returns:
        ``(cls_add_dis, cls_adds_dis)``: two lists of 22 per-class lists. When
        scored, the distance is appended to slot ``pred_cls_id[0]`` and to the
        pooled slot 0. Both lists stay empty if the predicted class is 0 or
        differs from the ground-truth class.
    """
    n_cls = 22  # must stay in sync with the module-level cls_num
    cls_add_dis = [[] for _ in range(n_cls)]
    cls_adds_dis = [[] for _ in range(n_cls)]
    if pred_cls_id[0] == 0 or pred_cls_id[0] != gt_cls_id[0]:
        return cls_add_dis, cls_adds_dis

    pred_RT = torch.from_numpy(pred_RT.astype(np.float32))
    gt_RT = torch.from_numpy(gt_RT.astype(np.float32))

    R_pre, pt = pred_RT[:, :3], pred_RT[:, 3].unsqueeze(1)
    R_tar, t_tar = gt_RT[:, :3], gt_RT[:, 3].unsqueeze(1)

    obj_grps = models_pts
    n_grps = len(obj_grps)
    # Allocate on the poses' device instead of forcing CUDA. The per-group
    # distances are computed there anyway, so .cuda() only added a GPU
    # requirement and host/device copies.
    add_j = torch.zeros(n_grps, device=pred_RT.device)
    for j, grp in enumerate(obj_grps):
        pred = R_pre @ grp + pt     # translation column broadcasts over points
        targ = R_tar @ grp + t_tar
        # dist[a, b] = ||pred[:, b] - targ[:, a]||. Broadcasting avoids
        # materialising repeated (P, 3, P) copies of pred and targ.
        dist = torch.norm(pred.unsqueeze(0) - targ.t().unsqueeze(2), dim=1)
        min_dist, _ = torch.min(dist, dim=1)  # nearest predicted point per target point

        if n_grps == 3 and j == 2:
            add_j[j] = torch.max(min_dist, dim=0)[0]
        else:
            add_j[j] = torch.mean(min_dist, dim=0)

    if n_grps == 3 and obj_grps[2].size()[1] > 1:
        add_i = torch.max(add_j, dim=0)[0]
    else:
        add_i = torch.mean(add_j, dim=0)

    dis = add_i.item()
    cls_id = pred_cls_id[0]
    # NOTE(review): ADD and ADDS receive the same closest-point value here; a
    # point-to-point ADD is never computed. Confirm this is intended.
    cls_add_dis[cls_id].append(dis)
    cls_adds_dis[cls_id].append(dis)
    cls_add_dis[0].append(dis)
    cls_adds_dis[0].append(dis)

    return cls_add_dis, cls_adds_dis

def cal_add_cuda(pred_RT, gt_RT, p3ds):
    """Return the ADD error: mean distance between corresponding transformed points.

    ``pred_RT`` / ``gt_RT`` hold the rotation in ``[:, :3]`` and the
    translation in ``[:, 3]``. The translation is viewed as a 3x1 column, so
    it must have 3 entries. ``p3ds`` is a 2-D tensor with one point per column.
    The result is a 0-dim tensor.
    """
    _rows, _n_pts = p3ds.size()  # rejects non-2-D point sets
    trans_pred = pred_RT[:, 3].view(3, 1)
    trans_gt = gt_RT[:, 3].view(3, 1)
    # The (3, 1) translations broadcast across every point column.
    moved_pred = torch.mm(pred_RT[:, :3], p3ds) + trans_pred
    moved_gt = torch.mm(gt_RT[:, :3], p3ds) + trans_gt
    per_point = (moved_pred - moved_gt).norm(dim=0)
    return per_point.mean()

def cal_adds_cuda(pred_RT, gt_RT, p3ds):
    """Return the ADD-S error: mean over predicted points of the nearest-GT-point distance.

    ``pred_RT`` / ``gt_RT`` hold the rotation in ``[:, :3]`` and the
    translation in ``[:, 3]``, which must have 3 entries. ``p3ds`` is a 2-D
    tensor with one point per column. This builds an N x N distance matrix, so
    memory grows quadratically with the number of points.
    """
    _rows, _n_pts = p3ds.size()  # rejects non-2-D point sets
    moved_pred = torch.mm(pred_RT[:, :3], p3ds) + pred_RT[:, 3].view(3, 1)
    moved_gt = torch.mm(gt_RT[:, :3], p3ds) + gt_RT[:, 3].view(3, 1)
    # pairwise[a, :, b] = moved_pred[:, a] - moved_gt[:, b], via broadcasting.
    pairwise = moved_pred.t().unsqueeze(2) - moved_gt.unsqueeze(0)
    nearest, _ = pairwise.norm(dim=1).min(dim=1)
    return nearest.mean()

def cal_auc(add_dis, max_dis=0.1, thr_m=0.02):
    """Compute the accuracy-vs-threshold AUC and the fraction below ``thr_m``.

    Args:
        add_dis: sequence of per-sample distances. Values above ``max_dis``
            are treated as failures (set to ``np.inf``).
        max_dis: cutoff distance. NOTE(review): VOCap always integrates up to
            a fixed 0.1, so values other than 0.1 give an inconsistent AUC.
        thr_m: threshold for the "below threshold" rate (0.02 = 2 cm if the
            distances are in metres; units are not established here).

    Returns:
        ``(auc, below_thr)``, both scaled by 100. Empty input yields
        ``(0.0, 0.0)`` instead of a division by zero.
    """
    n = len(add_dis)
    if n == 0:
        # No samples: previously hit 0 / 0 when computing the below-threshold rate.
        return 0.0, 0.0
    D = np.array(add_dis)  # copy, so the caller's data is not modified below
    D[np.where(D > max_dis)] = np.inf
    D = np.sort(D)
    # Cumulative accuracy: after the k-th smallest distance, k / n samples pass.
    acc = np.cumsum(np.ones((1, n)), dtype=np.float32) / n
    aps = VOCap(D, acc)

    add_t_cm = np.where(D < thr_m)[0].size / D.size

    return aps * 100, add_t_cm * 100

def VOCap(rec, prec):
    """Return the area under the accuracy-vs-threshold step curve on [0, 0.1], times 10.

    Args:
        rec: numpy array of distance thresholds. ``cal_auc`` passes them
            sorted ascending. Entries equal to ``np.inf`` are ignored.
        prec: numpy array of accuracies aligned element-wise with ``rec``.

    Returns:
        The normalised area, or 0 when every ``rec`` entry is ``np.inf``.
    """
    keep = rec != np.inf
    if not keep.any():
        return 0
    thresholds = rec[keep]
    accuracies = prec[keep]
    n_valid = accuracies.shape[0]
    # Anchor the curve at (0, 0) and hold the last accuracy out to x = 0.1.
    mrec = np.concatenate(([0.0], thresholds, [0.1]))
    mpre = np.concatenate(([0.0], accuracies, [accuracies[-1]]))
    # Running max over mpre[0:n_valid] (the padded tail is left untouched).
    for k in range(1, n_valid):
        mpre[k] = max(mpre[k], mpre[k - 1])
    # Positions where the threshold changes; each contributes one rectangle.
    steps = np.nonzero(mrec[1:] != mrec[:-1])[0] + 1
    return np.sum((mrec[steps] - mrec[steps - 1]) * mpre[steps]) * 10


























