import os
import cv2
import json
import torch
import numpy as np
from tqdm import tqdm, trange
from collections import defaultdict

from scipy.stats import rankdata
from scipy.stats import pearsonr
from .plot_skeleton import draw_skeleton_coco, draw_skeleton_crowdpose, draw_skeleton_exlpose
from .oks_utils import oks_one, OKS, ause_oks


class visualizekpts():

    def __init__(self, dataset_name, subset_name, dataset_type, gt_path, img_dir, dt_path=None, method=None, dt_model_name='unknown'):
        '''
        Load the ground-truth (and optionally the detection) keypoint annotations.

        dataset_name (str): the name of the dataset. Should be chosen from coco, ochuman,
            crowdpose, exlpose and exlpose-ocn. 'crowdpose', 'exlpose' and 'exlpose-ocn' are
            loaded with crowdposetools and 14 joints; any other name with pycocotools and 17 joints.
        subset_name (str): name of the subset; only stored.
        dataset_type (str): identify train, val or test dataset. Should also be the name of the
            image folders for the corresponding dataset.
        gt_path (str): the path of the json file of the gts. should be 'x/results/xxx.json'
        img_dir (str): folder that the gt image `file_name`s are relative to.
        dt_path (str, optional): the path of the json file of the dts. should be 'x/results/xxx.json'.
            Without it `cocoDt` and `cocoeval` are None.
        method (str, optional): name used for the detection image written by visualizeMatchedDT.
        dt_model_name (str): model name used in plot titles and output file names.
        '''
        super(visualizekpts, self).__init__()
        self.dataset_name = dataset_name
        self.subset_name = subset_name
        self.dataset_type = dataset_type
        self.gt_path = gt_path
        self.img_dir = img_dir

        self.dt_path = dt_path
        self.method = method
        self.model_name = dt_model_name
        # json file names; several methods build their default output folders from them
        self.gt_name = os.path.basename(gt_path)
        self.dt_name = os.path.basename(dt_path) if dt_path else None

        if dataset_name in ('crowdpose', 'exlpose', 'exlpose-ocn'):
            from crowdposetools.coco import COCO
            from crowdposetools.cocoeval import COCOeval
            self.num_joints = 14
            self.draw_skeleton = draw_skeleton_crowdpose if dataset_name == 'crowdpose' else draw_skeleton_exlpose
        else:
            print("init coco ...")
            from pycocotools.coco import COCO
            from pycocotools.cocoeval import COCOeval
            self.num_joints = 17
            self.draw_skeleton = draw_skeleton_coco

        self.cocoGt = COCO(self.gt_path)
        # keep only the images that have at least one annotation
        self.img_ids = [
            img_id
            for img_id in self.cocoGt.imgs.keys()
            if len(self.cocoGt.getAnnIds(imgIds=img_id, iscrowd=None)) > 0
        ]

        # always defined, so checks such as `if self.cocoDt:` never hit a missing attribute
        self.cocoDt = None
        self.cocoeval = None
        if dt_path:
            self.cocoDt = self.cocoGt.loadRes(self.dt_path)
            self.cocoeval = COCOeval(self.cocoGt, self.cocoDt, 'keypoints')

        self.data_dir = self.img_dir

    def visualizeGT(self, outdir=None):
        '''
        Draw the ground-truth skeletons of every annotated image and save them to disk.

        outdir (str, optional): output folder. Defaults to
            './visualize/<gt json file name up to its first dot>'. Created (with parents) if missing.
        Each image is saved under its original `file_name`.

        Raises FileNotFoundError when an image cannot be read.
        '''
        if not outdir:
            # derived from gt_path directly rather than relying on a `gt_name` attribute
            gt_name = os.path.basename(self.gt_path)
            outdir = './visualize/{}'.format(gt_name.split('.')[0])
        os.makedirs(outdir, exist_ok=True)

        for image_id in tqdm(self.img_ids):
            file_name = self.cocoGt.loadImgs(image_id)[0]['file_name']
            file_path = os.path.join(self.data_dir, file_name)
            img = cv2.imread(file_path)
            if img is None:
                # cv2.imread reports a missing/unreadable file by returning None
                raise FileNotFoundError('cannot read image {}'.format(file_path))

            anns = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(image_id))
            if len(anns) == 0:
                continue
            kpts = np.array([ann['keypoints'] for ann in anns]).reshape((-1, self.num_joints, 3))

            skeleton = self.draw_skeleton(img, kpts)
            cv2.imwrite(os.path.join(outdir, file_name), skeleton)

    def visualizeDT(self, score_thres=0.0, outdir=None, num_sample=None, brighten_up=False):
        '''
        Draw the detected skeletons whose score is above `score_thres` and save them.

        score_thres (float): detections with score <= score_thres are not drawn.
        outdir (str, optional): output folder, defaults to
            './visualize/<dt json file name up to its first dot>'. Created (with parents) if missing.
        num_sample (int, optional): if set and smaller than the number of images, draw only a
            random subset of that many images.
        brighten_up (bool): rescale pixel intensities so the image mean becomes 128
            (useful for low-light images).
        Images are saved as '<image_id>.jpg' and halved when a side exceeds 1024 px.

        Raises ValueError when the class was created without `dt_path`, FileNotFoundError
        when an image cannot be read.
        '''
        if not self.dt_path:
            # raising a plain string is itself a TypeError in Python 3
            raise ValueError('please initialize the class with the path of the detection json.')

        if not outdir:
            dt_name = os.path.basename(self.dt_path)
            outdir = './visualize/{}'.format(dt_name.split('.')[0])
        os.makedirs(outdir, exist_ok=True)

        if num_sample and len(self.img_ids) > num_sample:
            image_ids = np.random.permutation(np.array(self.img_ids).astype(np.int32))[:num_sample]
            image_ids = image_ids.tolist()
        else:
            image_ids = self.img_ids

        for image_id in tqdm(image_ids):
            file_name = self.cocoGt.loadImgs(image_id)[0]['file_name']
            file_path = os.path.join(self.data_dir, file_name)
            img = cv2.imread(file_path)
            if img is None:
                raise FileNotFoundError('cannot read image {}'.format(file_path))

            if brighten_up:
                img = img.astype(np.float32)
                average_intensity = np.mean(img)
                # an all-black image has nothing to rescale; avoid dividing by zero
                scale_factor = 128 / average_intensity if average_intensity > 0 else 1.0
                img = np.clip(img * scale_factor, 0, 255).astype(np.uint8)

            anns = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(image_id))
            kpts = [ann['keypoints'] for ann in anns if ann['score'] > score_thres]
            kpts = np.array(kpts).reshape((-1, self.num_joints, 3))

            skeleton = self.draw_skeleton(img, kpts)
            if skeleton.shape[0] > 1024 or skeleton.shape[1] > 1024:
                skeleton = cv2.resize(skeleton, (0, 0), fx=0.5, fy=0.5)
            cv2.imwrite(os.path.join(outdir, f'{image_id}.jpg'), skeleton)

    def _get_gt_score_list(self, oks_matrix, conf_s):
        '''
        For every gt (column) pick the detection (row) with the highest OKS.

        oks_matrix: DT x GT OKS matrix; its rows follow the order of `conf_s`.
        conf_s: confidence score of each detection row.
        Returns one entry per gt, in column order: [best det row index, its OKS, its confidence].
        An empty matrix (computeOks returns [] when an image has no dts or no gts) gives [].
        '''
        oks_matrix = np.asarray(oks_matrix)
        if oks_matrix.size == 0:
            # argmax/max raise on empty arrays
            return []
        gt_match_ids = oks_matrix.argmax(0)
        gt_match_scores = oks_matrix.max(0)
        return [[det_id, oks, conf_s[det_id]]
                for det_id, oks in zip(gt_match_ids, gt_match_scores)]

    def analyze_one_image_DT(self, img_id=None, outdir=None):
        '''
        Save one image per matched detection of `img_id`, each drawn with that detection only.

        Detections are ordered by descending score with the same stable sort COCOeval.computeOks
        uses, so detection i here is row i of the OKS matrix. Only detections that are the best
        OKS match of some gt (with non-zero OKS) are written, as
        'gt_score=<oks>_score=<confidence>_<int(file name without extension)>.jpg'.

        img_id (int): id of the image to analyze (required).
        outdir (str, optional): defaults to './visualize/<dataset_name>_<dataset_type>_<img_id>'.

        Raises ValueError when the class has no detections or `img_id` is missing,
        FileNotFoundError when the image cannot be read.
        '''
        if not self.dt_path:
            raise ValueError('please initialize the class with the name of the detection json.')
        if img_id is None:
            raise ValueError('please input the index of the image.')

        if not outdir:
            outdir = './visualize/{}_{}_{}'.format(self.dataset_name, self.dataset_type, img_id)
        os.makedirs(outdir, exist_ok=True)

        file_name = self.cocoDt.loadImgs(img_id)[0]['file_name']
        file_path = os.path.join(self.data_dir, file_name)
        img = cv2.imread(file_path)
        if img is None:
            raise FileNotFoundError('cannot read image {}'.format(file_path))

        self.cocoeval.evaluate()  # populates the per-image gts/dts that computeOks reads
        dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(img_id))
        if len(dts) == 0:
            return
        # Stable descending sort, identical to COCOeval.computeOks; argsort()[::-1] would
        # order tied scores differently and misalign detections with OKS rows.
        order = np.argsort([-dt['score'] for dt in dts], kind='mergesort')
        scores = np.array([dts[i]['score'] for i in order])
        kpts = np.array([dts[i]['keypoints'] for i in order]).reshape((-1, self.num_joints, 3))

        oks_scores = np.array(self.cocoeval.computeOks(img_id, 1))  # DT x GT
        if oks_scores.size == 0:
            return
        match_list = self._get_gt_score_list(oks_scores, scores)

        # OKS of the first gt each detection is the best match for
        matched_oks = {}
        for det_id, oks, _ in match_list:
            matched_oks.setdefault(int(det_id), oks)

        for det_id, (kpt, score) in enumerate(zip(kpts, scores)):
            gt_score = matched_oks.get(det_id, 0.0)
            if gt_score == 0.0:
                continue
            skeleton = self.draw_skeleton(img.copy(), kpt[None, :, :])
            save_path = os.path.join(outdir, 'gt_score={}_score={}_{}.jpg'.format(gt_score, score, int(file_name[:-4])))
            cv2.imwrite(save_path, skeleton)

    def evaluateDT(self,):
        '''
        Run the standard COCO-style keypoint evaluation and print its AP/AR summary.

        Raises ValueError when the class was created without `dt_path`.
        '''
        if not self.dt_path:
            # raising a plain string is itself a TypeError in Python 3
            raise ValueError('please initialize the class with the name of the detection json.')
        self.cocoeval.params.useSegm = None
        self.cocoeval.evaluate()
        self.cocoeval.accumulate()
        self.cocoeval.summarize()

    def evaluateDTScore(self,):
        '''
        return score_acc: how often confidence and OKS agree on the ordering of detections.

        For each image with dts and at least two gts, every gt takes its best-OKS detection.
        The image's accuracy is the fraction of ordered gt pairs on which "confidence is higher"
        agrees with "OKS is higher" (see _pairwise_order_agreement). The result is the mean over
        those images, or nan when no image qualifies.

        Raises ValueError when the class was created without `dt_path`.
        '''
        if not self.dt_path:
            raise ValueError('please initialize the class with the name of the detection json.')
        self.cocoeval.evaluate()  # populates the per-image gts/dts that computeOks reads

        score_acc = []
        for image_id in tqdm(self.img_ids):
            dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(image_id))
            gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(image_id))
            if len(dts) == 0 or len(gts) == 0:
                continue
            # same stable descending order as COCOeval.computeOks, so rows line up
            inds = np.argsort([-d['score'] for d in dts], kind='mergesort')
            conf_scores = np.array([dts[i]['score'] for i in inds])
            oks_scores = np.array(self.cocoeval.computeOks(image_id, 1))  # DT x GT
            if oks_scores.size == 0:
                continue

            match_list = self._get_gt_score_list(oks_scores, conf_scores)
            if len(match_list) <= 1:
                continue  # a single gt gives no pair to compare
            matched = np.array(match_list)  # rows: [det_id, oks, confidence]
            score_acc.append(self._pairwise_order_agreement(matched[:, 1], matched[:, -1]))

        if not score_acc:
            return np.nan
        return np.array(score_acc).mean()

    @staticmethod
    def _pairwise_order_agreement(a, b):
        '''
        Fraction of ordered pairs (i, j), i != j, on which `a` and `b` agree about whether
        element j is strictly greater than element i.

        a, b: 1-D sequences of equal length n >= 2. Compares values directly, so it is
        also correct for zero and negative scores.
        '''
        a = np.asarray(a, dtype=np.float64)
        b = np.asarray(b, dtype=np.float64)
        # entry [i, j] is True when element j is strictly greater than element i
        a_greater = a[None, :] > a[:, None]
        b_greater = b[None, :] > b[:, None]
        agree = (a_greater == b_greater).astype(np.float32)
        n = agree.shape[0]
        # the diagonal always agrees (False == False) and is excluded
        return (agree.sum() - n) / (n ** 2 - n)
    
    def evaluateDTScoreV2(self,):
        '''
        Measure how well detection confidences reflect their OKS with the ground truth.

        Every detection (of images with both dts and gts) gets a gt score: the OKS of the first
        gt it is the best match for, or 0.0 when it is nobody's best match.

        Returns (acc, r, ause):
            acc: pairwise rank agreement between confidences and gt scores over the first
                20000 detections (0-d numpy array; computed on GPU when available);
            r: Pearson correlation between confidence and OKS on up to 1000 randomly sampled
                detections with 0.5 < OKS < 1 and 0.65 < confidence < 1 (nan if fewer than 2);
            ause: ause_oks on the same sample (nan if fewer than 2).
        Also saves a confidence-vs-OKS scatter plot to 'test/<model_name>_coco.png'.

        Raises ValueError when the class was created without `dt_path`.
        '''
        import random

        if not self.dt_path:
            raise ValueError('please initialize the class with the name of the detection json.')
        self.cocoeval.evaluate()  # populates the per-image gts/dts that computeOks reads

        all_dts_with_gtscore = []
        for image_id in tqdm(self.img_ids):
            dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(image_id))
            gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(image_id))
            if len(dts) == 0 or len(gts) == 0:
                continue
            # same stable descending order as COCOeval.computeOks, so rows line up
            inds = np.argsort([-d['score'] for d in dts], kind='mergesort')
            dts = [dts[i] for i in inds]

            conf_scores = np.array([dt['score'] for dt in dts])
            oks_scores = np.array(self.cocoeval.computeOks(image_id, 1))  # DT x GT
            if oks_scores.size == 0:
                continue
            match_list = self._get_gt_score_list(oks_scores, conf_scores)

            gt_score_of = {}
            for det_id, oks, _ in match_list:
                gt_score_of.setdefault(int(det_id), oks)
            for det_id, dt in enumerate(dts):
                all_dts_with_gtscore.append([gt_score_of.get(det_id, 0.0), dt])

        score_gts = np.array([gt_s for gt_s, _ in all_dts_with_gtscore])
        score_dts = np.array([dt['score'] for _, dt in all_dts_with_gtscore])

        # sample up to 1000 poses for the calibration metrics
        candidates = [[s_dt, s_gt] for s_dt, s_gt in zip(score_dts, score_gts)
                      if 0.5 < s_gt < 1.0 and 0.65 < s_dt < 1.0]
        # random.sample raises ValueError when asked for more items than exist
        sampled = random.sample(candidates, min(1000, len(candidates)))
        sampled_dt = np.array([s_dt for s_dt, _ in sampled])
        sampled_gt = np.array([s_gt for _, s_gt in sampled])

        import matplotlib.pyplot as plt
        fig, ax = plt.subplots(figsize=(6, 5))
        ax.set_title(f'{self.model_name} conf. vs. oks')
        ax.set_xlabel('Confidence Score', size=10)
        ax.set_ylabel('OKS Score', size=10)
        ax.scatter(sampled_dt, sampled_gt)
        os.makedirs('test', exist_ok=True)
        fig.savefig(os.path.join('test', f'{self.model_name}_coco.png'), dpi=fig.dpi, bbox_inches='tight')
        plt.close(fig)  # release the figure; this method may be called for many models

        if len(sampled) >= 2:
            r, _ = pearsonr(sampled_dt, sampled_gt)
            # AUSE: errors = 1 - OKS, uncertainty = 1 - confidence
            ause = ause_oks(1.0 - sampled_gt, 1.0 - sampled_dt)
        else:
            r, ause = np.nan, np.nan

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        rank_gt = torch.tensor(rankdata(score_gts, method='average'), device=device)[:20000]
        rank_dt = torch.tensor(rankdata(score_dts, method='average'), device=device)[:20000]
        with torch.no_grad():
            # entry [i, j] is True when rank j is higher than rank i (ranks are >= 1)
            pairwise_rank_gt = rank_gt[None, :] > rank_gt[:, None]
            pairwise_rank_dt = rank_dt[None, :] > rank_dt[:, None]
            acc_matrix = (pairwise_rank_dt == pairwise_rank_gt).float()
            n = acc_matrix.shape[0]
            # the diagonal always agrees and is excluded
            acc = (acc_matrix.sum() - n) / (n ** 2 - n)
        return acc.cpu().numpy(), r, ause
       
    def replaceWithGTscore(self, outdir=None):
        '''
        Write a copy of the detections whose scores are replaced by their OKS with the ground truth.

        Per image with both dts and gts, a detection that is the best OKS match of some gt gets
        that OKS (of the first such gt) as its score; every other detection gets 0.0. Images
        whose OKS matrix is empty are skipped. The result is saved as
        '<outdir>/new_<model_name>.json'. The loaded detections (`self.cocoDt`) are not modified.

        outdir (str, optional): defaults to './visualize/<gt json file name up to its first dot>'.
        '''
        if not outdir:
            gt_name = os.path.basename(self.gt_path)
            outdir = './visualize/{}'.format(gt_name.split('.')[0])
        os.makedirs(outdir, exist_ok=True)

        self.cocoeval.evaluate()  # populates the per-image gts/dts that computeOks reads

        new_json = []
        for image_id in tqdm(self.img_ids):
            dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(image_id))
            gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(image_id))
            if len(dts) == 0 or len(gts) == 0:
                continue
            # same stable descending order as COCOeval.computeOks, so rows line up
            inds = np.argsort([-d['score'] for d in dts], kind='mergesort')
            dts = [dts[i] for i in inds]

            conf_scores = np.array([dt['score'] for dt in dts])
            oks_scores = np.array(self.cocoeval.computeOks(image_id, 1))  # DT x GT
            if oks_scores.size == 0:
                continue
            match_list = self._get_gt_score_list(oks_scores, conf_scores)

            gt_score_of = {}
            for det_id, oks, _ in match_list:
                gt_score_of.setdefault(int(det_id), oks)
            for det_id, dt in enumerate(dts):
                # copy: loadAnns returns the dicts stored inside cocoDt (shared with cocoeval)
                new_dt = dict(dt)
                new_dt['score'] = gt_score_of.get(det_id, 0.0)
                new_json.append(new_dt)

        save_path = os.path.join(outdir, f'new_{self.model_name}.json')
        with open(save_path, 'w') as f:
            json.dump(new_json, f, sort_keys=True, indent=4)

    def replaceWithRanks(self, outdir=None):
        '''
        Write a copy of the detections whose scores are the normalized rank of their OKS.

        Every detection first gets a gt score: the OKS of the first gt it is the best match for,
        or 0.0. Detections are ordered by that gt score and scored with
        rankdata(gt scores, 'average') / max(n - 1, 1), n being the number of detections.
        Images without dts, gts or OKS values are dropped. The result is saved as
        '<outdir>/new_<model_name>.json'. The loaded detections (`self.cocoDt`) are not modified.

        outdir (str, optional): defaults to './visualize/<gt json file name up to its first dot>'.
        '''
        if not outdir:
            gt_name = os.path.basename(self.gt_path)
            outdir = './visualize/{}'.format(gt_name.split('.')[0])
        os.makedirs(outdir, exist_ok=True)

        self.cocoeval.evaluate()  # populates the per-image gts/dts that computeOks reads

        all_dts_with_gtscore = []
        for image_id in tqdm(self.img_ids):
            dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(image_id))
            gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(image_id))
            if len(dts) == 0 or len(gts) == 0:
                continue
            # same stable descending order as COCOeval.computeOks, so rows line up
            inds = np.argsort([-d['score'] for d in dts], kind='mergesort')
            dts = [dts[i] for i in inds]

            conf_scores = np.array([dt['score'] for dt in dts])
            oks_scores = np.array(self.cocoeval.computeOks(image_id, 1))  # DT x GT
            if oks_scores.size == 0:
                continue
            match_list = self._get_gt_score_list(oks_scores, conf_scores)

            gt_score_of = {}
            for det_id, oks, _ in match_list:
                gt_score_of.setdefault(int(det_id), oks)
            for det_id, dt in enumerate(dts):
                all_dts_with_gtscore.append([gt_score_of.get(det_id, 0.0), dt])

        ####### sort all scores #######
        all_dts_with_gtscore.sort(key=lambda pair: pair[0])
        ranked_gts = np.array([gt_s for gt_s, _ in all_dts_with_gtscore])
        # n == 1 would divide by zero and write Infinity (invalid JSON)
        rank_scores = rankdata(ranked_gts, method='average') / max(len(ranked_gts) - 1, 1)

        new_json = []
        for rank_score, (_, dt) in zip(rank_scores, all_dts_with_gtscore):
            # copy: loadAnns returns the dicts stored inside cocoDt (shared with cocoeval)
            new_dt = dict(dt)
            new_dt['score'] = rank_score
            new_json.append(new_dt)

        save_path = os.path.join(outdir, f'new_{self.model_name}.json')
        with open(save_path, 'w') as f:
            json.dump(new_json, f, sort_keys=True, indent=4)

    def visualizeMatchedDT(self, oks_thres=None, score_thres=None, num_sample=None, outdir=None, concat_preds=False, brighten_up=False):
        '''
        Find the best matching detection of every gt and draw, per image, the matched
        detections next to the ground truth.

        Matching (greedyMatching): every detection is assigned to the gt it has the highest OKS
        with; each gt then keeps the assigned detection with the highest OKS. Gts without any
        labelled keypoint are skipped.

        oks_thres, score_thres: accepted for symmetry with failure_cases; not used.
        num_sample (int, optional): stop after visualizing this many images.
        outdir (str, optional): output folder, defaults to
            './visualize/<dt json file name up to its first dot>_matched'. Every image gets its
            own '<outdir>/<img_id>/' sub-folder.
        concat_preds (bool): if True, save one '<file>_conf_<c>_oks_<o>.jpg' showing
            image | dts | gts side by side, where c and o are the mean confidence and mean OKS of
            that image's matched detections; otherwise save '<img_id>.jpg', 'gt.jpg' and
            '<method>.jpg' separately.
        brighten_up (bool): rescale pixel intensities so the image mean becomes 128.

        Raises ValueError when the class was created without `dt_path`, FileNotFoundError
        when an image cannot be read.
        '''
        def greedyMatching(oks_matrix):
            # oks_matrix DT x GT
            gt_match_ids = oks_matrix.argmax(1)
            gt_match_scores = oks_matrix.max(1)
            match_dict = defaultdict(list)
            for det_id, (gt_id, oks) in enumerate(zip(gt_match_ids, gt_match_scores)):
                match_dict[gt_id].append([det_id, oks])

            match_res = []
            for gt_id, candidates in match_dict.items():
                det_candidates = np.array(candidates)
                det_idx = det_candidates[:, 1].argmax()
                match_res.append([det_candidates[det_idx, 0], gt_id, det_candidates[det_idx, 1]])

            ignored_gts = set(range(oks_matrix.shape[1])) - set(match_dict.keys())
            return match_res, list(ignored_gts)

        if not self.dt_path:
            raise ValueError('please initialize the class with the name of the detection json.')
        if not outdir:
            dt_name = os.path.basename(self.dt_path)
            outdir = './visualize/{}_matched'.format(dt_name.split('.')[0])

        self.cocoeval.evaluate()  # populates the per-image gts/dts that computeOks reads

        # per image: matched detection rows (score-sorted order) and their (confidence, oks)
        dt_dict = defaultdict(list)
        match_scores = defaultdict(list)
        for image_id in tqdm(self.img_ids):
            dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(image_id))
            gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(image_id))
            if len(dts) == 0 or len(gts) == 0:
                continue
            # OKS rows follow COCOeval's stable descending-score order, so the confidences
            # must be sorted the same way before they are indexed by row
            inds = np.argsort([-d['score'] for d in dts], kind='mergesort')
            conf_scores = np.array([dts[i]['score'] for i in inds])
            oks_scores = np.array(self.cocoeval.computeOks(image_id, 1))  # DT x GT
            if oks_scores.size == 0:
                continue

            match_res, _ = greedyMatching(oks_scores)
            for dt_id, gt_id, oks_s in match_res:
                gt_kpts = np.array(gts[gt_id]['keypoints']).reshape(self.num_joints, 3)
                if gt_kpts[:, 2].sum() == 0:
                    continue
                dt_dict[image_id].append(int(dt_id))
                match_scores[image_id].append((conf_scores[int(dt_id)], oks_s))

        os.makedirs(outdir, exist_ok=True)

        num_vis_samples = 0
        for img_id in tqdm(dt_dict):
            if num_sample and num_vis_samples == num_sample:
                break
            f_name = self.cocoDt.loadImgs(img_id)[0]['file_name']
            img_path = os.path.join(self.data_dir, f_name)
            img = cv2.imread(img_path)
            if img is None:
                raise FileNotFoundError('cannot read image {}'.format(img_path))
            h, w = img.shape[:2]

            # Scale up low-light image
            if brighten_up:
                img = img.astype(np.float32)
                average_intensity = np.mean(img)
                # an all-black image has nothing to rescale; avoid dividing by zero
                scale_factor = 128 / average_intensity if average_intensity > 0 else 1.0
                img = np.clip(img * scale_factor, 0, 255).astype(np.uint8)

            # matched dt kpts; sort once instead of once per detection
            dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(img_id))
            inds = np.argsort([-d['score'] for d in dts], kind='mergesort')
            sorted_dts = [dts[i] for i in inds]
            dt_kpts = np.array([sorted_dts[j]['keypoints'] for j in dt_dict[img_id]]).reshape((-1, self.num_joints, 3))
            dt_img = self.draw_skeleton(img.copy(), dt_kpts)

            # all gt kpts
            gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(img_id))
            gt_kpts = np.array([gt['keypoints'] for gt in gts]).reshape((-1, self.num_joints, 3))
            gt_img = self.draw_skeleton(img.copy(), gt_kpts)

            # 2-px black border around every panel
            border = ((2, 2), (2, 2), (0, 0))
            img = np.pad(img, border, 'constant')
            dt_img = np.pad(dt_img, border, 'constant')
            gt_img = np.pad(gt_img, border, 'constant')

            outpath = os.path.join(outdir, f'{img_id}')
            os.makedirs(outpath, exist_ok=True)

            if img.shape[0] >= 1024:
                # rescale if the image is too large
                half_size = (int(w / 2), int(h / 2))
                img = cv2.resize(img, half_size)
                gt_img = cv2.resize(gt_img, half_size)
                dt_img = cv2.resize(dt_img, half_size)

            if concat_preds:
                confs, okss = zip(*match_scores[img_id])
                output_img = np.concatenate([img, dt_img, gt_img], axis=1)
                save_name = '{}_conf_{}_oks_{}.jpg'.format(f_name[:-4], np.mean(confs), np.mean(okss))
                cv2.imwrite(os.path.join(outpath, save_name), output_img)
            else:
                cv2.imwrite(os.path.join(outpath, f'{img_id}.jpg'), img)
                cv2.imwrite(os.path.join(outpath, 'gt.jpg'), gt_img)
                cv2.imwrite(os.path.join(outpath, f'{self.method}.jpg'), dt_img)
            num_vis_samples += 1

    def failure_cases(self, oks_thres=None, score_thres=None, num_sample=None, outdir=None):
        '''
        Find matched predictions with low OKS and save image | dt | gt side by side for each.

        Matching (greedyMatching): every detection is assigned to the gt it has the highest OKS
        with; each gt keeps the assigned detection with the highest OKS. Gts without labelled
        keypoints are skipped. A histogram of the OKS of all matched predictions is shown.

        score_thres (float, optional): keep predictions with confidence <= score_thres.
        oks_thres (float, optional): keep predictions with OKS <= oks_thres.
        num_sample (int, optional): keep only the num_sample lowest-OKS predictions.
        Predictions with OKS <= 1e-10 are always dropped.
        outdir (str, optional): defaults to
            './visualize/<dt json file name without ".json">_failure_cases/'.
        Images are brightened to mean intensity 128 when gt_path contains 'LL'.
        Files are named '<file name>_conf_<confidence>_oks_<oks>.jpg'.

        Raises ValueError when the class was created without `dt_path`.
        '''
        # assert not (oks_thres != None and num_sample !=None)

        def greedyMatching(oks_matrix):
            # oks_matrix DT x GT
            gt_match_ids = oks_matrix.argmax(1)
            gt_match_scores = oks_matrix.max(1)
            match_dict = defaultdict(list)
            for det_id, (gt_id, oks) in enumerate(zip(gt_match_ids, gt_match_scores)):
                match_dict[gt_id].append([det_id, oks])

            match_res = []
            for gt_id, candidates in match_dict.items():
                det_candidates = np.array(candidates)
                det_idx = det_candidates[:, 1].argmax()
                match_res.append([det_candidates[det_idx, 0], gt_id, det_candidates[det_idx, 1]])

            ignored_gts = set(range(oks_matrix.shape[1])) - set(match_dict.keys())
            return match_res, list(ignored_gts)

        if not self.dt_path:
            raise ValueError('please initialize the class with the name of the detection json.')
        if not outdir:
            # [:-5] strips the '.json' extension
            outdir = f'./visualize/{os.path.basename(self.dt_path)[:-5]}_failure_cases/'
        os.makedirs(outdir, exist_ok=True)

        self.cocoeval.evaluate()  # populates the per-image gts/dts that computeOks reads

        all_preds = []
        miss_detections = []
        for image_id in tqdm(self.img_ids):
            dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(image_id))
            gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(image_id))
            if len(dts) == 0 or len(gts) == 0:
                continue
            # OKS rows follow COCOeval's stable descending-score order, so the confidences
            # must be sorted the same way before they are indexed by row
            inds = np.argsort([-d['score'] for d in dts], kind='mergesort')
            conf_scores = np.array([dts[i]['score'] for i in inds])
            oks_scores = np.array(self.cocoeval.computeOks(image_id, 1))  # DT x GT
            if oks_scores.shape[0] == 0:
                continue
            match_res, ignored_gts = greedyMatching(oks_scores)

            for det_id, gt_id, oks_s in match_res:
                gt_kpts = np.array(gts[gt_id]['keypoints']).reshape(self.num_joints, 3)
                if gt_kpts[:, 2].sum() == 0:
                    continue
                all_preds.append([image_id, det_id, gt_id, conf_scores[int(det_id)], oks_s])
            miss_detections.extend([image_id, ig_gt] for ig_gt in ignored_gts)

        if not all_preds:
            print('no matched predictions are found ({} missed detections)'.format(len(miss_detections)))
            return

        fail_preds = np.array(all_preds)  # n x 5 (image_id, det_id, gt_id, confidence, oks)

        # print distribution of oks
        import matplotlib.pyplot as plt
        plt.figure()
        plt.title('oks distribution in {}_{}_{}'.format(self.model_name, self.dataset_name, self.dataset_type))
        plt.hist(fail_preds[:, -1])
        plt.show()

        # `is not None` so that a threshold of 0 is honoured
        if score_thres is not None:
            fail_preds = fail_preds[fail_preds[:, 3] <= score_thres]
        if oks_thres is not None:
            fail_preds = fail_preds[fail_preds[:, 4] <= oks_thres]  # select the failures with oks <= oks_thres
        if num_sample:
            fail_preds = fail_preds[fail_preds[:, 4].argsort()[:num_sample]]

        fail_preds = fail_preds[fail_preds[:, 4] > 1e-10]
        print('{} failures are found w.r.t oks_thres={} and num_sample={} in {} '
              'matched predictions and {} missed detections'.format(
              fail_preds.shape[0], oks_thres, num_sample,
              len(all_preds), len(miss_detections)))

        # visualize
        for img_id, det_id, gt_id, conf_s, oks_s in tqdm(fail_preds):
            img_id, det_id, gt_id = int(img_id), int(det_id), int(gt_id)

            f_name = self.cocoDt.loadImgs(img_id)[0]['file_name']
            img = cv2.imread(os.path.join(self.data_dir, f_name))

            # Scale up low-light image
            if 'LL' in self.gt_path:
                average_intensity = np.mean(img)
                # an all-black image has nothing to rescale; avoid dividing by zero
                scale_factor = 128 / average_intensity if average_intensity > 0 else 1.0
                # kept in BGR: cv2.imwrite below expects BGR, like every other writer here
                img = np.clip(img * scale_factor, 0, 255).astype(np.uint8)

            # we should sort the score from high to low because the cocoeval.evaluate() will sort the score
            dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(img_id))
            inds = np.argsort([-d['score'] for d in dts], kind='mergesort')
            dt = dts[inds[det_id]]
            dt_kpts = np.array(dt['keypoints']).reshape((1, self.num_joints, 3))
            dt_img = self.draw_skeleton(img.copy(), dt_kpts)

            gt = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(img_id))[gt_id]
            gt_kpts = np.array(gt['keypoints']).reshape((1, self.num_joints, 3))
            gt_img = self.draw_skeleton(img.copy(), gt_kpts)

            # concat figure with a 2-px black border around every panel
            border = ((2, 2), (2, 2), (0, 0))
            output_img = np.concatenate([np.pad(img, border, 'constant'),
                                         np.pad(dt_img, border, 'constant'),
                                         np.pad(gt_img, border, 'constant')], axis=1)

            save_name = '{}_conf_{}_oks_{}.jpg'.format(f_name[:-4], conf_s, oks_s)
            cv2.imwrite(os.path.join(outdir, save_name), output_img)

    def statistics(self, is_GT=True, return_dict=False):
        '''
        Print statistics of the ground truth (is_GT=True) or of the detections.

        return_dict (bool): when True, return the statistics computed by the helper
            (the detection helper is not implemented and returns None).
        '''
        # the helper's result used to be discarded, so return_dict=True never returned anything
        if is_GT:
            return self._statisticsGT(return_dict)
        return self._statisticsDT(return_dict)

    def _statisticsGT(self, return_dict=False):
        '''
        Print (and optionally return) statistics of the ground-truth persons.

        Only annotations with num_keypoints > 0 are counted.
            return:
             dict{
                  'metrics': ['min', 'avg', 'max'],
                  'num_persons': [min, avg, max] persons per image, over all images,
                  'area': [min, avg, max] annotation 'area', over all counted persons,
                  }
        Entries are nan when there is nothing to measure. Returned only when return_dict is True.
        '''
        coco = self.cocoGt

        areas = []
        num_persons = []
        for img_id in tqdm(list(coco.imgs.keys())):
            anns = coco.loadAnns(coco.getAnnIds(img_id))
            persons = [a for a in anns if a['num_keypoints'] > 0]
            num_persons.append(len(persons))
            areas.extend(a['area'] for a in persons)

        def min_avg_max(values):
            values = np.array(values).reshape(-1)
            if values.size == 0:
                # min()/max() raise on empty arrays
                return [np.nan, np.nan, np.nan]
            return [values.min(), values.mean(), values.max()]

        stats = {
            'metrics': ['min', 'avg', 'max'],
            'num_persons': min_avg_max(num_persons),
            'area': min_avg_max(areas),
        }

        print(stats)

        if return_dict:
            return stats

    def _statisticsDT(self, return_dict=False):
        '''Placeholder for detection statistics: not implemented, always returns None.'''
        return None

    def interacting_statistics_bbox_view(self, iou_limits=0.0, scale_limits=1.5):
        '''
        Count interacting bounding-box pairs over all gt images and print the ratios.

        Boxes per image: the gt 'bbox'es of category 1 with positive width and height plus,
        when detections are loaded, boxes derived by `pose2box` from detected poses with
        score > 1e-3. Pairs are counted by `get_interacting_number(bboxes, iou_limits, scale_limits)`.
        Prints the total boxes, interacting pairs and 2 * pairs / boxes, then the total images,
        images with at least one interacting pair and their ratio (nan when a total is 0).
        '''
        # this is a statistics particular make for no the interacting joints

        print('Warning! To evaluate the interacting dataset in the view of bbox, '
              'you should load the annotations of instance. i.e instances_train2017.json')

        image_ids = list(self.cocoGt.imgs.keys())

        counter_bbox = []
        counter_interacting_bbox_pairs = []
        counter_interacting_imgs = 0

        from interacting_joints_utils import get_interacting_number, pose2box

        # cocoDt only exists when the class was built with dt_path
        coco_dt = getattr(self, 'cocoDt', None)
        for image_id in image_ids:
            gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=image_id, catIds=1))
            # reshape keeps a (0, 4) array for images without boxes so np.concatenate works
            bboxes = np.array([gt['bbox'] for gt in gts
                               if gt['bbox'][2] > 0. and gt['bbox'][3] > 0.]).reshape((-1, 4))

            if coco_dt is not None:
                dts = coco_dt.loadAnns(coco_dt.getAnnIds(imgIds=image_id, catIds=1))
                poses = np.array([dt['keypoints'] for dt in dts if dt['score'] > 1e-3]).reshape((-1, self.num_joints, 3))
                if poses.shape[0] != 0:
                    # assumes pose2box returns a (num_poses, 4) array like the gt bboxes — TODO confirm
                    bboxes = np.concatenate([bboxes, pose2box(poses)], axis=0)

            if len(bboxes) == 0:
                continue

            counter_bbox.append(len(bboxes))
            num_interacting_pairs = get_interacting_number(bboxes, iou_limits, scale_limits)
            counter_interacting_bbox_pairs.append(num_interacting_pairs)

            if num_interacting_pairs != 0:
                counter_interacting_imgs += 1

        total_bbox = np.array(counter_bbox).sum()
        total_interacting_pairs = np.array(counter_interacting_bbox_pairs).sum()
        total_interacting_imgs = counter_interacting_imgs
        total_imgs = len(image_ids)
        bbox_ratio = total_interacting_pairs * 2 / total_bbox if total_bbox else float('nan')
        img_ratio = total_interacting_imgs / total_imgs if total_imgs else float('nan')

        print('total bbox: {}, total interacting bboxes pairs: {}, ratio of bboxes: {}'.format(total_bbox,
                                            total_interacting_pairs, bbox_ratio))
        print('total imgs: {}, total interacting imgs: {}, ratio of imgs: {}'.format(total_imgs,
                                            total_interacting_imgs, img_ratio))

    def interacting_statistics_pose_view(self, iou_limits=0.0, scale_limits=1.5):
        '''
        Print statistics about "interacting" pose boxes, counted per image.

        Only images with at least one annotation id are considered. For each image:
        - Collect the GT boxes with positive width and height whose keypoint array sums to > 0.
        - If detections are loaded (``self.cocoDt``), append the boxes that ``pose2box``
          derives from the category-1 detected poses with score > 1e-3.
        - Count interacting pairs with ``get_interacting_number``.

        The method prints:
        - the total number of boxes,
        - the number of interacting pairs and the ratio 2 * pairs / boxes,
        - the fraction of considered images with at least one interacting pair.

        param iou_limits: forwarded to ``get_interacting_number``.
        param scale_limits: forwarded to ``get_interacting_number``.
        '''
        print('Warning! To evaluate the interacting dataset in the view of pose, '
              'you should load the annotations of keypoints. i.e person_keypoints_train2017.json')

        from interacting_joints_utils import get_interacting_number, pose2box

        image_ids = [img_id for img_id in self.cocoGt.imgs.keys()
                     if len(self.cocoGt.getAnnIds(img_id, iscrowd=None)) > 0]
        coco_dt = getattr(self, 'cocoDt', None)

        counter_bbox = []
        counter_interacting_bbox_pairs = []
        counter_interacting_imgs = 0

        for image_id in image_ids:
            gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=image_id))
            # reshape(-1, 4): an image whose GTs are all filtered out would otherwise give a
            # 1-D empty array that cannot be concatenated with the (m, 4) detection boxes below.
            bboxes = np.array([gt['bbox'] for gt in gts
                               if gt['bbox'][2] > 0. and gt['bbox'][3] > 0.
                               and np.array(gt['keypoints']).sum() > 0.0]).reshape(-1, 4)

            if coco_dt is not None:
                dts = coco_dt.loadAnns(coco_dt.getAnnIds(imgIds=image_id, catIds=1))
                poses = np.array([dt['keypoints'] for dt in dts if dt['score'] > 1e-3]).reshape(
                    (-1, self.num_joints, 3))
                if poses.shape[0] != 0:
                    bboxes = np.concatenate([bboxes, pose2box(poses)], axis=0)

            if len(bboxes) == 0:
                continue

            counter_bbox.append(len(bboxes))
            num_interacting_pairs = get_interacting_number(bboxes, iou_limits, scale_limits)
            counter_interacting_bbox_pairs.append(num_interacting_pairs)

            if num_interacting_pairs != 0:
                counter_interacting_imgs += 1

        total_bbox = np.array(counter_bbox).sum()
        total_interacting_pairs = np.array(counter_interacting_bbox_pairs).sum()
        total_interacting_imgs = counter_interacting_imgs
        total_imgs = len(image_ids)
        # Guard empty datasets: avoid printing nan or raising ZeroDivisionError.
        bbox_ratio = total_interacting_pairs * 2 / total_bbox if total_bbox else 0.0
        img_ratio = total_interacting_imgs / total_imgs if total_imgs else 0.0

        print('total bbox: {}, total interacting bboxes pairs: {}, ratio of bboxes: {}'.format(
            total_bbox, total_interacting_pairs, bbox_ratio))
        print('total imgs: {}, total interacting imgs: {}, ratio of imgs: {}'.format(
            total_imgs, total_interacting_imgs, img_ratio))

    def _pts2area(self, pts):
        '''
        param pts: numpy array pts (n, 17, 3)
        return: numpy array area (n, 1)
        '''
        x_max, x_min = pts[:, :, 0].max(1), pts[:, :, 0].min(1)
        y_max, y_min = pts[:, :, 1].max(1), pts[:, :, 1].min(1)

        area = (x_max - x_min) * (y_max - y_min)
        return np.sqrt(area)

    def compare_res(self, comp_dt_path, comp_model_name='unknown', score_thre=0.0, outdir=None):
        '''
        Save side-by-side images comparing this model's detections with another model's.

        For every image, the detections of both models with score >= ``score_thre`` are kept.
        Each detection of the comparison model (dt2) is matched to a detection of this model
        (dt1) with ``oks_one``. For every matched pair, one image is written to ``outdir``:
        three panels ``[GT | dt1 | dt2]``. GT is the ground-truth pose that dt1 was matched to,
        or the plain image when dt1 matched no GT.

        param comp_dt_path: result json of the comparison model, loaded via ``cocoGt.loadRes``.
        param comp_model_name: name of the comparison model, used in the output names.
        param score_thre: minimum detection score kept for both models.
        param outdir: output directory. Defaults to
                      ./visualize/{dataset_name}_{dataset_type}_{model_name}_vs_{comp_model_name}
        '''
        if not outdir:
            outdir = './visualize/{}_{}_{}_vs_{}'.format(self.dataset_name, self.dataset_type,
                                                         self.model_name, comp_model_name)
        # makedirs: the default location is nested under ./visualize, which may not exist yet.
        os.makedirs(outdir, exist_ok=True)

        # NOTE(review): the visible constructor only sets ``img_dir``. Use ``data_dir`` if some
        # other code path set it; otherwise fall back to ``img_dir``.
        img_root = getattr(self, 'data_dir', self.img_dir)
        pad_width = ((2, 2), (2, 2), (0, 0))

        cocoDt1 = self.cocoDt
        cocoDt2 = self.cocoGt.loadRes(comp_dt_path)

        def _scored_kpts(coco_res, img_id):
            '''Return (scores, float32 keypoints of shape (n, num_joints, 3)) for dts with score >= score_thre.'''
            dts = [dt for dt in coco_res.loadAnns(coco_res.getAnnIds(img_id)) if dt['score'] >= score_thre]
            scores = [dt['score'] for dt in dts]
            kpts = np.array([dt['keypoints'] for dt in dts]).astype(np.float32).reshape((-1, self.num_joints, 3))
            return scores, kpts

        for img_id in tqdm(self.img_ids):
            img_info = self.cocoGt.loadImgs(img_id)[0]
            f_name = img_info['file_name']
            img_w, img_h = img_info['width'], img_info['height']
            img_path = os.path.join(img_root, f_name)
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread returns None (instead of raising) for missing/unreadable files.
                print('Warning! cannot read image {}, skipped.'.format(img_path))
                continue

            gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(img_id))
            gt_kpts = np.array([g['keypoints'] for g in gts]).astype(np.float32).reshape((-1, self.num_joints, 3))
            gt_areas = np.array([g['area'] for g in gts]).astype(np.float32)

            dt1_scores, dt1s = _scored_kpts(cocoDt1, img_id)
            dt2_scores, dt2s = _scored_kpts(cocoDt2, img_id)
            dt1_areas = self._pts2area(dt1s)

            # Index of the GT each dt1 pose is matched to. Presumably -1 when unmatched, the same
            # sentinel checked for dt1_matched_ids below.
            _, dt1_matched_gt_id = oks_one(gt_kpts, dt1s, gt_areas, [img_w, img_h])
            dt1_matched_gt_id = np.array(dt1_matched_gt_id).astype(np.int32)

            # Match every dt2 pose to a dt1 pose, using dt1 as the reference set.
            okss, dt1_matched_ids = oks_one(dt1s, dt2s, dt1_areas, [img_w, img_h])

            for i, (_, dt1_id, dt2, dt2_s) in enumerate(zip(okss, dt1_matched_ids, dt2s, dt2_scores)):
                if dt1_id < 0:
                    continue

                dt1 = dt1s[dt1_id]
                dt1_s = dt1_scores[dt1_id]

                gt_id = int(dt1_matched_gt_id[dt1_id])
                if gt_id >= 0:
                    gt_img = self.draw_skeleton(img.copy(), gt_kpts[gt_id][None, :, :])
                else:
                    # dt1 matched no GT: show the raw image rather than indexing gt_kpts[-1].
                    gt_img = img.copy()
                dt1_img = self.draw_skeleton(img.copy(), dt1[None, :, :])
                dt2_img = self.draw_skeleton(img.copy(), dt2[None, :, :])

                output_img = np.concatenate([np.pad(panel, pad_width, 'constant')
                                             for panel in (gt_img, dt1_img, dt2_img)], axis=1)

                save_name = '{}_ori_{}_{}_vs_{}_{}_{}.jpg'.format(os.path.splitext(f_name)[0], self.model_name,
                                                                  dt1_s, comp_model_name, dt2_s, i)
                cv2.imwrite(os.path.join(outdir, save_name), output_img)

    @staticmethod
    def help():
        msg = "please initialize the visualkpts with four params:" \
              "1. dataset_name (str): the name of the dataset. One should be choosen from coco, ochuman and crowdpose" \

        print(msg)

