import json
import copy

import numpy as np
import torch

from opencood.utils import eval_utils

class MetricUtil:
    """Class-level accumulator for detection metrics and communication volume.

    All state lives in class attributes and every method is a ``staticmethod``,
    so the class is used as a process-wide namespace (``MetricUtil.eval_iter(...)``).
    The mutable list attributes are therefore shared by every caller.
    """

    # True once load_from_json() successfully read the per-agent point counts.
    is_load_from_json: bool = False
    # Per-box "every agent sees >= thr points" flags for the most recent sample
    # handled by get_points_in_corrd(); no default, it is set on first use.
    all_in_box: list
    # Running totals over all samples handled by get_points_in_corrd().
    in_all_agent_bbox_number: int = 0
    all_agent_bbox_number: int = 0
    # Optional value passed to ``bboxes.enlarged_box_hw`` before counting points.
    extra_width: int = None
    # Minimum per-agent point count for a box to count as seen by that agent.
    thr: int = 20

    # Selects the JSON file-name suffix ('_dair' vs '').
    is_dair: bool = True
    # Set by eval_iter_with_stage() when 'pred_box' is a per-stage dict.
    is_two_stage: bool = False
    # When True, eval_iter()/eval_mAp() also evaluate the point-count bands.
    is_use_mask: bool = False
    # One entry per sample. get_points_in_corrd() appends boolean flags, but
    # load_json() replaces the whole list with per-agent point counts from disk.
    all_in_box_list: list = []
    # One entry per sample: for each agent, a point count per box plus a trailing
    # slot for box index -1 (presumably points outside every box).
    inbox_inds_count_list: list = []
    # 40 accumulator slots addressed by get_result_idx(mask, stage, filter); each
    # maps an IoU threshold to its tp/fp/gt/score accumulators.
    result_stat: list = [{
        0.3: {'tp': [], 'fp': [], 'gt': 0, 'score': []},
        0.5: {'tp': [], 'fp': [], 'gt': 0, 'score': []},
        0.7: {'tp': [], 'fp': [], 'gt': 0, 'score': []}
        } for _ in range(40)]
    # Point-count bands used when is_use_mask is True:
    # (filter id, label, exclusive lower bound, inclusive upper bound).
    # An upper bound of -1 means "no upper bound".
    _point_count_bands: tuple = (
        (1, 'hard', -1, 30),
        (2, 'mid', 30, 150),
        (3, 'easy', 150, -1),
    )

    @staticmethod
    def _count_points_per_box(bboxes, points, bboxes_num):
        """Return ``bboxes_num + 1`` point counts; the last slot is for index -1."""
        inbox_inds = bboxes.points_in_boxes(points).long()
        unique_inds, unique_counts = torch.unique(inbox_inds, return_counts=True)
        count_of = dict(zip(unique_inds.tolist(), unique_counts.tolist()))
        counts = [0] * (bboxes_num + 1)
        for b in range(-1, bboxes_num):
            if b in count_of:
                counts[b] = count_of[b]  # b == -1 lands in the trailing slot
        return counts

    @staticmethod
    def get_points_in_corrd(points_list, bboxes_list, bbox_labels_list, img_metas):
        """Count, per sample and per agent, how many points fall inside each box.

        For sample ``i`` the agents' clouds are the ``record_len[i]`` entries of
        ``points_list`` starting at ``sum(record_len[:i])``. Per-agent counts are
        appended to ``inbox_inds_count_list``, the "seen by every agent" flags to
        ``all_in_box_list`` (and stored in ``all_in_box``), and the running totals
        are updated. Does nothing when counts were loaded from JSON.

        Args:
            points_list: per-agent point tensors; only the first 3 columns are used.
            bboxes_list: per-sample box objects supporting ``len()``,
                ``points_in_boxes(points)`` and ``enlarged_box_hw(width)``.
            bbox_labels_list: unused.
            img_metas: sequence whose first element has ``'record_len'``, the
                number of agents per sample.
        """
        if MetricUtil.is_load_from_json:
            return
        record_len = img_metas[0]['record_len']

        for i, agent_num in enumerate(record_len):
            agent_num = int(agent_num)
            base_idx = int(sum(record_len[:i]))
            bboxes = bboxes_list[i]
            bboxes_num = len(bboxes)
            # Enlarge once per sample: enlarging inside the agent loop compounded
            # the enlargement, so agent j used boxes grown j + 1 times.
            if MetricUtil.extra_width is not None:
                bboxes = bboxes.enlarged_box_hw(MetricUtil.extra_width)

            per_agent_counts = [
                MetricUtil._count_points_per_box(
                    bboxes, points_list[base_idx + j][:, :3], bboxes_num)
                for j in range(agent_num)
            ]
            # A box qualifies only if every agent sees at least ``thr`` points in
            # it; the trailing (index -1) slot is always False.
            all_in_box = [
                all(counts[b] >= MetricUtil.thr for counts in per_agent_counts)
                for b in range(bboxes_num)
            ]
            all_in_box.append(False)

            MetricUtil.inbox_inds_count_list.append(per_agent_counts)
            MetricUtil.all_in_box = all_in_box
            MetricUtil.all_in_box_list.append(all_in_box)
            MetricUtil.all_agent_bbox_number += bboxes_num
            MetricUtil.in_all_agent_bbox_number += sum(all_in_box[:-1])

    @staticmethod
    def load_from_json():
        """Try load_json(); set ``is_load_from_json`` to whether it succeeded."""
        try:
            MetricUtil.load_json()
        except (OSError, ValueError):  # missing/unreadable file or invalid JSON
            MetricUtil.is_load_from_json = False
        else:
            MetricUtil.is_load_from_json = True

    @staticmethod
    def set_threshold(threshold):
        """Set the per-agent minimum point count ``thr``."""
        MetricUtil.thr = threshold

    @staticmethod
    def _counts_json_name():
        """File name used for the per-agent point counts (save and load)."""
        prefix = '_dair' if MetricUtil.is_dair else ''
        return f'agents_points_in_bbox{prefix}.json'

    @staticmethod
    def save_json():
        """Write the boolean flags and the per-agent counts to the working directory."""
        with open(f'all_in_box_{MetricUtil.thr}.json', 'w') as f:
            json.dump(MetricUtil.all_in_box_list, f, indent=4)
        with open(MetricUtil._counts_json_name(), 'w') as f:
            json.dump(MetricUtil.inbox_inds_count_list, f, indent=4)

    @staticmethod
    def load_json():
        """Load the per-agent point counts written by save_json().

        Note: the counts file is loaded into ``all_in_box_list``, so afterwards
        ``all_in_box_list[idx][agent]`` holds point counts (not booleans);
        calucalate_tp_fp() thresholds them into a mask.

        Raises:
            OSError: if the file cannot be opened.
            ValueError: if it does not contain valid JSON.
        """
        with open(MetricUtil._counts_json_name(), 'r') as f:
            MetricUtil.all_in_box_list = json.load(f)

    @staticmethod
    def get_gt_mask(idx):
        """Return ``all_in_box_list[idx]`` (flags or loaded counts, see load_json)."""
        return MetricUtil.all_in_box_list[idx]

    @staticmethod
    def set_args(**kargs):
        """Overwrite arbitrary class attributes by keyword."""
        for k, v in kargs.items():
            setattr(MetricUtil, k, v)

    @staticmethod
    def calucalate_tp_fp(idx, pred_dicts, result_stat, use_mask=False, mask=0, down_thr=-1, up_thr=-1):
        """Accumulate TP/FP for one sample at IoU 0.3, 0.5 and 0.7.

        Args:
            idx: sample index, used to look up the loaded counts.
            pred_dicts: dict with 'pred_box', 'pred_score' and 'gt_box'.
            result_stat: accumulator dict keyed by IoU threshold.
            use_mask: if True, only GT boxes selected by a mask are evaluated.
            mask: agent index into the loaded counts (JSON mode only).
            down_thr: exclusive lower bound on the point count (JSON mode).
            up_thr: inclusive upper bound on the point count; -1 means none.
        """
        # Deep copy so the eval helpers cannot alter the caller's tensors
        # (NOTE(review): assumed intent — eval_utils is outside this view).
        pred_dicts = copy.deepcopy(pred_dicts)
        pred_box_tensor = pred_dicts['pred_box']
        pred_score = pred_dicts['pred_score']
        gt_box_tensor = pred_dicts['gt_box']
        iou_thresholds = (0.3, 0.5, 0.7)

        if not use_mask:
            for iou_thr in iou_thresholds:
                eval_utils.caluclate_tp_fp(pred_box_tensor, pred_score, gt_box_tensor, result_stat, iou_thr)
            return

        if MetricUtil.is_load_from_json:
            agent_counts = MetricUtil.get_gt_mask(idx)[mask]
            upper = float('inf') if up_thr == -1 else up_thr
            used_mask = [down_thr < count <= upper for count in agent_counts]
        else:
            used_mask = MetricUtil.all_in_box
        for iou_thr in iou_thresholds:
            eval_utils.caluclate_tp_fp_with_mask(pred_box_tensor, pred_score, gt_box_tensor, result_stat, iou_thr, used_mask)

    @staticmethod
    def eval_iter_with_stage(idx, pred_dicts, use_mask=False, mask=-1, down_thr=-1, up_thr=-1, filter=0):
        """Accumulate one sample into the result_stat slot(s) it maps to.

        If ``pred_dicts['pred_box']`` is a dict, it is treated as per-stage output
        keyed by 'one_stage' / 'two_stage' (other keys are ignored), and each
        stage goes to its own slot.

        Raises:
            ValueError: if ``mask`` is given while ``use_mask`` is False.
        """
        if not use_mask and mask != -1:
            raise ValueError(f"mask={mask} requires use_mask=True")

        mask_kwargs = dict(use_mask=use_mask, mask=mask, down_thr=down_thr, up_thr=up_thr)
        MetricUtil.is_two_stage = isinstance(pred_dicts['pred_box'], dict)
        if not MetricUtil.is_two_stage:
            slot = MetricUtil.result_stat[get_result_idx(mask, 0, filter)]
            MetricUtil.calucalate_tp_fp(idx, pred_dicts, slot, **mask_kwargs)
            return

        stage_index = {'one_stage': 0, 'two_stage': 1}
        for key, pred_box in pred_dicts['pred_box'].items():
            stage = stage_index.get(key)
            if stage is None:
                continue
            stage_dicts = dict(
                pred_box=pred_box,
                pred_score=pred_dicts['pred_score'][key],
                gt_box=pred_dicts['gt_box'],
            )
            slot = MetricUtil.result_stat[get_result_idx(mask, stage, filter)]
            MetricUtil.calucalate_tp_fp(idx, stage_dicts, slot, **mask_kwargs)

    @staticmethod
    def eval_iter(idx, pred_dicts):
        """Accumulate one sample unmasked, plus each point-count band if ``is_use_mask``."""
        MetricUtil.eval_iter_with_stage(idx, pred_dicts, use_mask=False)
        if not MetricUtil.is_use_mask:
            return
        for filter_id, _label, down_thr, up_thr in MetricUtil._point_count_bands:
            MetricUtil.eval_iter_with_stage(idx, pred_dicts, use_mask=True, mask=0,
                                            down_thr=down_thr, up_thr=up_thr, filter=filter_id)

    @staticmethod
    def eval_mAp_with_stage(output_dir, use_mask=False, mask=-1, filter=0):
        """Report final results for the slot(s) selected by mask/filter.

        Raises:
            ValueError: if ``mask`` is given while ``use_mask`` is False.
        """
        if not use_mask and mask != -1:
            raise ValueError(f"mask={mask} requires use_mask=True")

        if MetricUtil.is_two_stage:
            print("\tone stage:", end='\t')
            eval_utils.eval_final_results(MetricUtil.result_stat[get_result_idx(mask, 0, filter)], output_dir)
            print("\ttwo stage:", end='\t')
            eval_utils.eval_final_results(MetricUtil.result_stat[get_result_idx(mask, 1, filter)], output_dir)
        else:
            eval_utils.eval_final_results(MetricUtil.result_stat[get_result_idx(mask, 0, filter)], output_dir)

    @staticmethod
    def eval_mAp(output_dir):
        """Report unmasked results, plus each point-count band if ``is_use_mask``."""
        print("norm label:")
        MetricUtil.eval_mAp_with_stage(output_dir, use_mask=False)
        if not MetricUtil.is_use_mask:
            return
        for filter_id, label, _down_thr, _up_thr in MetricUtil._point_count_bands:
            print(f"{label} label:")
            MetricUtil.eval_mAp_with_stage(output_dir, use_mask=True, mask=0, filter=filter_id)

    # communication volume
    comunication_volume: list = []
    point_num: list = []
    cluster_num: list = []

    iter_num: int = 0
    point_feature_dim: int = 3
    cluster_feature_dim: int = 768

    @staticmethod
    def get_volume(point_num: int, cluster_num: int,
                   point_feature_dim: int = 3, cluster_feature_dim: int = 768):
        """Return ``4 * (points * point_dim + clusters * cluster_dim)``.

        The factor 4 is presumably bytes per float32 value.
        """
        return 4 * (point_num*point_feature_dim + cluster_num*cluster_feature_dim)

    @staticmethod
    def record_volume(point_num: int, cluster_num: int,
                      point_feature_dim: int = 3, cluster_feature_dim: int = 768):
        """Append one iteration's point/cluster counts and its volume."""
        MetricUtil.iter_num += 1
        MetricUtil.point_num.append(point_num)
        MetricUtil.cluster_num.append(cluster_num)
        MetricUtil.comunication_volume.append(MetricUtil.get_volume(point_num, cluster_num,
                                                                    point_feature_dim, cluster_feature_dim))

    @staticmethod
    def print_volume():
        """Intentionally a no-op: the volume summary below is disabled."""
        # iter_num = MetricUtil.iter_num
        # point_num = (np.array(MetricUtil.point_num)).mean()
        # cluster_num = (np.array(MetricUtil.cluster_num)).mean()
        # comunication_volume = np.log2((np.array(MetricUtil.comunication_volume)).mean())
        # print(f'iter num: {iter_num}')
        # print(f'point num / iter: {point_num:.4f}')
        # print(f'cluster num / iter: {cluster_num:.4f}')
        # print(f'comunication volume (log2) / iter: {comunication_volume:.4f}')
        pass


def get_result_idx(mask, stage, filter=0):
    """Map (mask, stage, filter) to a slot index in ``MetricUtil.result_stat``.

    Each filter band occupies a block of 10 slots; inside a block, every mask
    value (starting at -1) gets two consecutive slots, one per stage.
    """
    band_offset = 10 * filter
    mask_offset = 2 * (mask + 1)
    return band_offset + mask_offset + stage