import json
import contextlib
import io
import logging
import warnings
from collections import OrderedDict

import mmcv
import numpy as np
from mmcv.utils import print_log
from mmdet.datasets.api_wrappers import COCO, COCOeval
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.coco import CocoDataset


@DATASETS.register_module()
class CocoDatasetOV(CocoDataset):
    """COCO-style dataset with a seen/unseen (base/novel) class split.

    On top of :class:`CocoDataset`, three JSON files provide the seen,
    unseen and full class-name collections. ``CLASSES`` is set to the full
    collection and ``PALETTE`` colours unseen classes red and every other
    class blue. Class-wise bbox evaluation additionally reports the AP and
    AR at the first IoU threshold separately for seen ("base"), unseen
    ("novel") and all evaluated categories.

    Args:
        ann_file (str): Annotation file, forwarded to the parent class.
        pipeline (list): Processing pipeline, forwarded to the parent class.
        seen_classes (str): Path to a JSON file with the seen class names.
        unseen_classes (str): Path to a JSON file with the unseen class
            names.
        all_classes (str): Path to a JSON file with all class names.
        **kwargs: Extra keyword arguments forwarded to the parent class.
    """

    def __init__(self,
                 ann_file,
                 pipeline,
                 seen_classes='datasets/mscoco_seen_classes.json',
                 unseen_classes='datasets/mscoco_unseen_classes.json',
                 all_classes='datasets/mscoco_65_classes.json',
                 **kwargs):
        super().__init__(ann_file, pipeline, **kwargs)
        self.seen_classes = self._load_class_names(seen_classes)
        self.unseen_classes = self._load_class_names(unseen_classes)
        self.all_classes = self._load_class_names(all_classes)
        # Unseen classes are drawn in red, everything else in blue.
        self.PALETTE = [
            (255, 0, 0) if cls in self.unseen_classes else (0, 0, 255)
            for cls in self.all_classes
        ]
        # NOTE(review): CLASSES is replaced only after the parent constructor
        # has run, so whatever the parent derived from the previous CLASSES
        # value is not recomputed here -- confirm this is intended.
        self.CLASSES = self.all_classes

    @staticmethod
    def _load_class_names(path):
        """Load a JSON file of class names and close the file afterwards."""
        with open(path) as f:
            return json.load(f)

    @staticmethod
    def _as_proposals(result):
        """Flatten the detection result of one image into one box array.

        Accepts a ``(bbox_results, segm_results)`` tuple, a per-class list of
        box arrays, or an array that is already a flat set of proposals.
        """
        if isinstance(result, tuple):
            # (bbox_results, segm_results): only the boxes are needed.
            result = result[0]
        if isinstance(result, np.ndarray):
            return result
        return np.concatenate(result, axis=0)

    def get_ann_info(self, idx):
        """Get COCO annotation by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """
        img_info = self.data_infos[idx]
        ann_ids = self.coco.get_ann_ids(img_ids=[img_info['id']])
        ann_info = self.coco.load_anns(ann_ids)
        return self._parse_ann_info(img_info, ann_info)

    def _parse_ann_info(self, img_info, ann_info):
        """Parse bbox and mask annotation.

        Annotations are skipped when they are flagged ``ignore``, do not
        overlap the image, have a non-positive area or a side shorter than
        one pixel, or belong to a category outside ``self.cat_ids``. Crowd
        annotations are collected as ignored boxes.

        Args:
            img_info (dict): Image info; ``width``, ``height`` and
                ``filename`` are read.
            ann_info (list[dict]): Annotation info of an image.

        Returns:
            dict: A dict containing the following keys: bboxes, labels,
                bboxes_ignore, masks, seg_map. "masks" are raw annotations
                and not decoded into binary masks.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        gt_masks_ann = []
        for ann in ann_info:
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            # Size of the part of the box that lies inside the image.
            inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
            inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            if ann['category_id'] not in self.cat_ids:
                continue
            # Convert from (x, y, w, h) to (x1, y1, x2, y2).
            bbox = [x1, y1, x1 + w, y1 + h]
            if ann.get('iscrowd', False):
                gt_bboxes_ignore.append(bbox)
            else:
                gt_bboxes.append(bbox)
                gt_labels.append(self.cat2label[ann['category_id']])
                gt_masks_ann.append(ann.get('segmentation', None))

        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        # Same file stem as the image, with the segmentation-map suffix.
        seg_map = img_info['filename'].rsplit('.', 1)[0] + self.seg_suffix

        return dict(
            bboxes=gt_bboxes,
            labels=gt_labels,
            bboxes_ignore=gt_bboxes_ignore,
            masks=gt_masks_ann,
            seg_map=seg_map)

    def evaluate(self,
                 results,
                 metric='bbox',
                 logger=None,
                 jsonfile_prefix=None,
                 classwise=True,
                 proposal_nums=(100, 300, 1000),
                 iou_thrs=None,
                 metric_items=None):
        """Evaluation in COCO protocol.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Options are
                'bbox', 'segm', 'proposal', 'proposal_fast'.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            jsonfile_prefix (str | None): The prefix of json files. It includes
                the file path and the prefix of filename, e.g., "a/b/prefix".
                If not specified, a temp file will be created. Default: None.
            classwise (bool): Whether to evaluate the seen/unseen split
                metrics for the 'bbox' metric. Default: True.
            proposal_nums (Sequence[int]): Proposal number used for evaluating
                recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thrs (Sequence[float], optional): IoU threshold used for
                evaluating recalls/mAPs. If set to a list, the average of all
                IoUs will also be computed. If not specified, [0.50, 0.55,
                0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used.
                Default: None.
            metric_items (list[str] | str, optional): Metric items that will
                be returned. If not specified, ``['AR@100', 'AR@300',
                'AR@1000', 'AR_s@1000', 'AR_m@1000', 'AR_l@1000' ]`` will be
                used when ``metric=='proposal'``, ``['mAP', 'mAP_50', 'mAP_75',
                'mAP_s', 'mAP_m', 'mAP_l']`` will be used when
                ``metric=='bbox' or metric=='segm'``.

        Returns:
            dict[str, float]: COCO style evaluation metric.

        Raises:
            KeyError: If a requested metric is not supported.
        """
        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
        for requested in metrics:
            if requested not in allowed_metrics:
                raise KeyError(f'metric {requested} is not supported')

        # self.cat_ids is deliberately left untouched here; the evaluation
        # uses whatever category ids the dataset already holds.
        coco_gt = self.coco
        result_files, tmp_dir = self.format_results(results, jsonfile_prefix)
        try:
            eval_results = self.evaluate_det_segm(results, result_files,
                                                  coco_gt, metrics, logger,
                                                  classwise, proposal_nums,
                                                  iou_thrs, metric_items)
        finally:
            # Remove the temporary result files even if evaluation fails.
            if tmp_dir is not None:
                tmp_dir.cleanup()
        return eval_results

    def _split_ap50_ar50(self, coco_eval):
        """Average per-category AP/AR at the first IoU threshold per split.

        Args:
            coco_eval (COCOeval): Evaluator on which ``evaluate()`` and
                ``accumulate()`` have already been called.

        Returns:
            OrderedDict[str, float]: Unrounded percentages under the keys
                ``novel_ap50``, ``base_ap50``, ``all_ap50``, ``novel_AR50``,
                ``base_AR50`` and ``all_AR50``. A split without any valid
                category yields ``nan``.
        """
        # precision: (iou, recall, cls, area range, max dets)
        precisions = coco_eval.eval['precision']
        # recall: indexed below as (iou, cls, area range, max dets)
        recalls = coco_eval.eval['recall']
        assert len(self.cat_ids) == precisions.shape[2]

        ap50 = dict(all=[], base=[], novel=[])
        ar50 = dict(all=[], base=[], novel=[])
        for idx, cat_id in enumerate(self.cat_ids):
            name = self.coco.loadCats(cat_id)[0]['name']
            splits = ['all']
            if name in self.seen_classes:
                splits.append('base')
            if name in self.unseen_classes:
                splits.append('novel')

            # IoU index 0 is the first threshold (0.5 with the default
            # thresholds); area range index 0: all area ranges;
            # max dets index -1: the largest detection budget.
            # Entries equal to -1 mark "no data" and are excluded.
            precision50 = precisions[0, :, idx, 0, -1]
            precision50 = precision50[precision50 > -1]
            if precision50.size:
                cat_ap50 = float(np.mean(precision50) * 100)
            else:
                cat_ap50 = float('nan')

            recall50 = recalls[0, idx, 0, -1]
            if recall50 > -1:
                cat_ar50 = float(recall50 * 100)
            else:
                cat_ar50 = float('nan')

            for split in splits:
                ap50[split].append(cat_ap50)
                ar50[split].append(cat_ar50)

        stats = OrderedDict()
        stats['novel_ap50'] = np.nanmean(ap50['novel'])
        stats['base_ap50'] = np.nanmean(ap50['base'])
        stats['all_ap50'] = np.nanmean(ap50['all'])
        stats['novel_AR50'] = np.nanmean(ar50['novel'])
        stats['base_AR50'] = np.nanmean(ar50['base'])
        stats['all_AR50'] = np.nanmean(ar50['all'])
        return stats

    def evaluate_det_segm(self,
                          results,
                          result_files,
                          coco_gt,
                          metrics,
                          logger=None,
                          classwise=False,
                          proposal_nums=(100, 300, 1000),
                          iou_thrs=None,
                          metric_items=None):
        """Instance segmentation and object detection evaluation in COCO
        protocol.

        Args:
            results (list[list | tuple | dict]): Testing results of the
                dataset.
            result_files (dict[str, str]): a dict contains json file path.
            coco_gt (COCO): COCO API object with ground truth annotation.
            metrics (list[str]): Metrics to be evaluated. Options are
                'bbox', 'segm', 'proposal', 'proposal_fast'.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            classwise (bool): Whether to evaluate the seen/unseen split
                AP50/AR50 for the 'bbox' metric. Default: False.
            proposal_nums (Sequence[int]): Proposal number used for evaluating
                recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thrs (Sequence[float], optional): IoU threshold used for
                evaluating recalls/mAPs. If set to a list, the average of all
                IoUs will also be computed. If not specified, [0.50, 0.55,
                0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used.
                Default: None.
            metric_items (list[str] | str, optional): Metric items that will
                be returned. If not specified, ``['AR@100', 'AR@300',
                'AR@1000', 'AR_s@1000', 'AR_m@1000', 'AR_l@1000' ]`` will be
                used when ``metric=='proposal'``, ``['mAP', 'mAP_50', 'mAP_75',
                'mAP_s', 'mAP_m', 'mAP_l']`` will be used when
                ``metric=='bbox' or metric=='segm'``.

        Returns:
            dict[str, float]: COCO style evaluation metric.

        Raises:
            KeyError: If a metric has no entry in ``result_files`` or a
                requested metric item is not supported.
        """
        if iou_thrs is None:
            iou_thrs = np.linspace(
                .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
        if metric_items is not None and not isinstance(metric_items, list):
            metric_items = [metric_items]

        # mapping of cocoEval.stats
        coco_metric_names = {
            'mAP': 0,
            'mAP_50': 1,
            'mAP_75': 2,
            'mAP_s': 3,
            'mAP_m': 4,
            'mAP_l': 5,
            'AR@100': 6,
            'AR@300': 7,
            'AR@1000': 8,
            'AR_s@1000': 9,
            'AR_m@1000': 10,
            'AR_l@1000': 11
        }

        eval_results = OrderedDict()
        for metric in metrics:
            msg = f'Evaluating {metric}...'
            if logger is None:
                msg = '\n' + msg
            print_log(msg, logger=logger)

            if metric == 'proposal_fast':
                # Class-agnostic recall: merge the boxes of every class.
                proposal_list = [self._as_proposals(res) for res in results]
                ar = self.fast_eval_recall(
                    proposal_list, proposal_nums, iou_thrs, logger='silent')
                log_msg = []
                for i, num in enumerate(proposal_nums):
                    eval_results[f'AR@{num}'] = float(f'{ar[i]:.4f}')
                    log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}')
                print_log(''.join(log_msg), logger=logger)
                continue

            iou_type = 'bbox' if metric == 'proposal' else metric
            if metric not in result_files:
                raise KeyError(f'{metric} is not in results')
            try:
                predictions = mmcv.load(result_files[metric])
                if iou_type == 'segm':
                    # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331  # noqa
                    # When evaluating mask AP, if the results contain bbox,
                    # cocoapi will use the box area instead of the mask area
                    # for calculating the instance area. Though the overall AP
                    # is not affected, this leads to different
                    # small/medium/large mask AP results.
                    for x in predictions:
                        x.pop('bbox')
                    warnings.simplefilter('once')
                    warnings.warn(
                        'The key "bbox" is deleted for more accurate mask AP '
                        'of small/medium/large instances since v2.12.0. This '
                        'does not change the overall mAP calculation.',
                        UserWarning)
                coco_det = coco_gt.loadRes(predictions)
            except IndexError:
                print_log(
                    'The testing results of the whole dataset is empty.',
                    logger=logger,
                    level=logging.ERROR)
                break

            cocoEval = COCOeval(coco_gt, coco_det, iou_type)
            cocoEval.params.catIds = self.cat_ids
            cocoEval.params.imgIds = self.img_ids
            cocoEval.params.maxDets = list(proposal_nums)
            cocoEval.params.iouThrs = iou_thrs

            if metric_items is not None:
                for metric_item in metric_items:
                    if metric_item not in coco_metric_names:
                        raise KeyError(
                            f'metric item {metric_item} is not supported')

            if metric == 'proposal':
                # Proposals are evaluated class-agnostically.
                cocoEval.params.useCats = 0
            cocoEval.evaluate()
            cocoEval.accumulate()

            # Save coco summarize print information to logger
            redirect_string = io.StringIO()
            with contextlib.redirect_stdout(redirect_string):
                cocoEval.summarize()
            print_log('\n' + redirect_string.getvalue(), logger=logger)

            if metric == 'proposal':
                # Resolve the defaults into a local so that they do not leak
                # into the following metrics of this loop.
                items = metric_items
                if items is None:
                    items = [
                        'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',
                        'AR_m@1000', 'AR_l@1000'
                    ]
                for item in items:
                    val = float(
                        f'{cocoEval.stats[coco_metric_names[item]]:.3f}')
                    eval_results[item] = val
                continue

            if classwise and metric == 'bbox':
                split_stats = self._split_ap50_ar50(cocoEval)
                for key, value in split_stats.items():
                    eval_results[key] = float(f'{value:.3f}')

                # Logged only here: the split statistics do not exist for
                # other metrics or when classwise evaluation is disabled.
                shown = {
                    key: float(f'{value:.1f}')
                    for key, value in split_stats.items()
                }
                print_log(
                    f"{iou_type} AP50 novel: {shown['novel_ap50']}",
                    logger=logger)
                print_log(
                    f"{iou_type} AP50 base : {shown['base_ap50']}",
                    logger=logger)
                print_log(
                    f"{iou_type} AP50 all  : {shown['all_ap50']}",
                    logger=logger)
                print_log(
                    f"{iou_type} AR50 novel: {shown['novel_AR50']}",
                    logger=logger)
                print_log(
                    f"{iou_type} AR50 base : {shown['base_AR50']}",
                    logger=logger)
                print_log(
                    f"{iou_type} AR50 all  : {shown['all_AR50']}\n",
                    logger=logger)

            items = metric_items
            if items is None:
                items = [
                    'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
                ]
            for metric_item in items:
                key = f'{metric}_{metric_item}'
                val = float(
                    f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}'
                )
                eval_results[key] = val
            ap = cocoEval.stats[:6]
            eval_results[f'{metric}_mAP_copypaste'] = (
                f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
                f'{ap[4]:.3f} {ap[5]:.3f}')

        return eval_results

    def __repr__(self):
        """Return a short summary string.

        Overrides the inherited representation; a string must be returned,
        otherwise ``repr()``/``print()`` raise ``TypeError``.
        """
        num_images = len(getattr(self, 'data_infos', None) or ())
        num_classes = len(getattr(self, 'CLASSES', None) or ())
        return (f'{self.__class__.__name__}('
                f'num_images={num_images}, num_classes={num_classes})')
