# Copyright (c) OpenMMLab. All rights reserved.
from multiprocessing import get_context

import numpy as np
import torch
from mmcv.ops import box_iou_rotated
from mmcv.utils import print_log
from mmdet.core import average_precision
from terminaltables import AsciiTable
import os.path as osp
import math
import pickle

def tpfp_default(det_bboxes,
                 gt_bboxes,
                 gt_bboxes_ignore=None,
                 iou_thr=0.5,
                 area_ranges=None):
    """Check if detected bboxes are true positive or false positive.

    Detections are visited in descending score order and greedily matched to
    the gt box they overlap most. A detection whose best IoU reaches
    ``iou_thr`` is a TP if that gt is not ignored and not yet claimed, an FP
    if the gt is already claimed, and neither if the gt is ignored. A
    detection whose best IoU is below ``iou_thr`` is an FP.

    Args:
        det_bboxes (ndarray): Detected bboxes of this image, of shape (m, 6).
            The last column is used as the score.
        gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 5).
        gt_bboxes_ignore (ndarray | None): Ignored gt bboxes of this image,
            of shape (k, 5). ``None`` means there are no ignored boxes.
            Default: None
        iou_thr (float): IoU threshold to be considered as matched.
            Default: 0.5.
        area_ranges (list[tuple] | None): Range of bbox areas to be evaluated,
            in the format [(min1, max1), (min2, max2), ...]. Only ``None``
            (or ranges whose minimum is ``None``) is supported.
            Default: None.

    Returns:
        tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of
            each array is (num_scales, m).

    Raises:
        NotImplementedError: If a real area range is requested; area
            filtering is not implemented for rotated boxes here.
    """
    det_bboxes = np.array(det_bboxes)
    # The signature advertises ``gt_bboxes_ignore=None``; treat that as "no
    # ignored boxes" instead of failing on ``None.shape``.
    if gt_bboxes_ignore is None:
        gt_bboxes_ignore = np.zeros((0, gt_bboxes.shape[1]))
    # an indicator of ignored gts
    gt_ignore_inds = np.concatenate(
        (np.zeros(gt_bboxes.shape[0], dtype=bool),
         np.ones(gt_bboxes_ignore.shape[0], dtype=bool)))
    # stack gt_bboxes and gt_bboxes_ignore for convenience
    gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore))

    num_dets = det_bboxes.shape[0]
    num_gts = gt_bboxes.shape[0]
    if area_ranges is None:
        area_ranges = [(None, None)]
    num_scales = len(area_ranges)
    # tp and fp are of shape (num_scales, num_dets), each row is tp or fp of
    # a certain scale
    tp = np.zeros((num_scales, num_dets), dtype=np.float32)
    fp = np.zeros((num_scales, num_dets), dtype=np.float32)

    # if there is no gt bboxes in this image, then all det bboxes
    # within area range are false positives
    if gt_bboxes.shape[0] == 0:
        if area_ranges == [(None, None)]:
            fp[...] = 1
        else:
            raise NotImplementedError
        return tp, fp

    ious = box_iou_rotated(
        torch.from_numpy(det_bboxes).float(),
        torch.from_numpy(gt_bboxes).float()).numpy()
    # for each det, the max iou with all gts
    ious_max = ious.max(axis=1)
    # for each det, which gt overlaps most with it
    ious_argmax = ious.argmax(axis=1)
    # sort all dets in descending order by scores
    sort_inds = np.argsort(-det_bboxes[:, -1])
    for k, (min_area, max_area) in enumerate(area_ranges):
        gt_covered = np.zeros(num_gts, dtype=bool)
        # if no area range is specified, gt_area_ignore is all False
        if min_area is None:
            gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
        else:
            raise NotImplementedError
        for i in sort_inds:
            if ious_max[i] >= iou_thr:
                matched_gt = ious_argmax[i]
                if not (gt_ignore_inds[matched_gt]
                        or gt_area_ignore[matched_gt]):
                    if not gt_covered[matched_gt]:
                        gt_covered[matched_gt] = True
                        tp[k, i] = 1
                    else:
                        fp[k, i] = 1
                # otherwise ignore this detected bbox, tp = 0, fp = 0
            else:
                # ``min_area`` is always None at this point (any other value
                # raised above), so every unmatched detection is an FP.
                fp[k, i] = 1
    return tp, fp


def get_cls_results(det_results, annotations, class_id):
    """Get det results and gt information of a certain class.

    Args:
        det_results (list[list]): Same as `eval_map()`. Indexed as
            ``det_results[img][class_id]``.
        annotations (list[dict]): Same as `eval_map()`. Each dict must have
            ``bboxes`` and ``labels``; ``bboxes_ignore`` is read only when
            ``labels_ignore`` is present and not None.
        class_id (int): ID of a specific class.

    Returns:
        tuple[list[np.ndarray]]: detected bboxes, gt bboxes, ignored gt bboxes.
            Each list has one entry per image. Images without ignore
            annotations get an empty ``(0, 5)`` float64 array.
    """
    cls_dets = [img_res[class_id] for img_res in det_results]

    cls_gts = []
    cls_gts_ignore = []
    for ann in annotations:
        gt_inds = ann['labels'] == class_id
        cls_gts.append(ann['bboxes'][gt_inds, :])

        if ann.get('labels_ignore', None) is not None:
            ignore_inds = ann['labels_ignore'] == class_id
            cls_gts_ignore.append(ann['bboxes_ignore'][ignore_inds, :])
        else:
            # Use a numpy placeholder so every entry matches the documented
            # ndarray return type.
            cls_gts_ignore.append(np.zeros((0, 5), dtype=np.float64))

    return cls_dets, cls_gts, cls_gts_ignore


def _count_gts_per_scale(cls_gts, area_ranges, num_scales):
    """Count one class's gt boxes per scale (all of them if no ranges)."""
    num_gts = np.zeros(num_scales, dtype=int)
    for bbox in cls_gts:
        if area_ranges is None:
            num_gts[0] += bbox.shape[0]
        else:
            gt_areas = bbox[:, 2] * bbox[:, 3]
            for k, (min_area, max_area) in enumerate(area_ranges):
                num_gts[k] += np.sum((gt_areas >= min_area)
                                     & (gt_areas < max_area))
    return num_gts


def _mean_ap_over_classes(eval_results, num_scales, multi_scale):
    """Average per-class AP over the classes that have at least one gt."""
    if multi_scale:
        # shape (num_classes, num_scales)
        all_ap = np.vstack([res['ap'] for res in eval_results])
        all_num_gts = np.vstack([res['num_gts'] for res in eval_results])
        mean_ap = []
        for k in range(num_scales):
            has_gt = all_num_gts[:, k] > 0
            if np.any(has_gt):
                mean_ap.append(all_ap[has_gt, k].mean())
            else:
                mean_ap.append(0.0)
        return mean_ap
    aps = [res['ap'] for res in eval_results if res['num_gts'] > 0]
    return np.array(aps).mean().item() if aps else 0.0


def _find_iou_thr_index(iou_thrs, target):
    """Return the index of the first threshold close to ``target``, or None."""
    for idx, thr in enumerate(iou_thrs):
        if np.isclose(thr, target):
            return idx
    return None


def eval_rbbox_map(det_results,
                   annotations,
                   scale_ranges=None,
                   iou_thr=tuple(0.05 * i for i in range(10, 20)),
                   use_07_metric=False,
                   dataset=None,
                   logger=None,
                   nproc=2):
    """Evaluate mAP of a rotated dataset over one or more IoU thresholds.

    Per-class AP is computed at every threshold in ``iou_thr``. The summary
    table reports AP at IoU 0.5 and 0.75 together with the mean of the
    per-threshold mAPs.

    Args:
        det_results (list[list]): [[cls1_det, cls2_det, ...], ...].
            The outer list indicates images, and the inner list indicates
            per-class detected bboxes.
        annotations (list[dict]): Ground truth annotations where each item of
            the list indicates an image. Keys of annotations are:

            - `bboxes`: numpy array of shape (n, 5)
            - `labels`: numpy array of shape (n, )
            - `bboxes_ignore` (optional): numpy array of shape (k, 5)
            - `labels_ignore` (optional): numpy array of shape (k, )
        scale_ranges (list[tuple] | None): Range of scales to be evaluated,
            in the format [(min1, max1), (min2, max2), ...]. A range of
            (32, 64) means the area range between (32**2, 64**2).
            Default: None.
        iou_thr (float | Sequence[float]): IoU threshold(s) to be considered
            as matched. A single number is treated as a one-element sequence.
            Default: 0.50, 0.55, ..., 0.95.
        use_07_metric (bool): Whether to use the voc07 metric.
        dataset (list[str] | None): Class names, forwarded to
            `print_map_summary()` as row labels. Default: None.
        logger (logging.Logger | str | None): The way to print the mAP
            summary. See `mmcv.utils.print_log()` for details. Default: None.
        nproc (int): Processes used for computing TP and FP.
            Default: 2.

    Returns:
        tuple: (mAP, [dict, dict, ...]). The mAP and the per-class results
            both belong to the same threshold: IoU 0.5 when it is part of
            ``iou_thr``, otherwise the first threshold given.

    Raises:
        ValueError: If ``iou_thr`` is an empty sequence.
    """
    assert len(det_results) == len(annotations)

    # Keep the threshold objects untouched; only normalise the container.
    iou_thrs = [iou_thr] if np.isscalar(iou_thr) else list(iou_thr)
    if not iou_thrs:
        raise ValueError('iou_thr must contain at least one IoU threshold')

    num_imgs = len(det_results)
    multi_scale = scale_ranges is not None
    num_scales = len(scale_ranges) if multi_scale else 1
    num_classes = len(det_results[0])  # positive class num
    area_ranges = ([(rg[0]**2, rg[1]**2) for rg in scale_ranges]
                   if multi_scale else None)
    mode = 'area' if not use_07_metric else '11points'
    eps = np.finfo(np.float32).eps

    # results_per_thr[t][c] is the result dict of class c at iou_thrs[t].
    results_per_thr = [[] for _ in iou_thrs]

    pool = get_context('spawn').Pool(nproc)
    try:
        for class_id in range(num_classes):
            # get gt and det bboxes of this class
            cls_dets, cls_gts, cls_gts_ignore = get_cls_results(
                det_results, annotations, class_id)

            # Everything below up to the threshold loop does not depend on
            # the IoU threshold, so compute it once per class.
            # ignored gts or gts beyond the specific scale are not counted
            num_gts = _count_gts_per_scale(cls_gts, area_ranges, num_scales)
            stacked_dets = np.vstack(cls_dets)
            num_dets = stacked_dets.shape[0]
            # sort all det bboxes by score; tp and fp are sorted the same way
            sort_inds = np.argsort(-stacked_dets[:, -1])

            for thr_results, thr in zip(results_per_thr, iou_thrs):
                # compute tp and fp for each image with multiple processes
                tpfp = pool.starmap(
                    tpfp_default,
                    zip(cls_dets, cls_gts, cls_gts_ignore,
                        [thr] * num_imgs, [area_ranges] * num_imgs))
                tp, fp = zip(*tpfp)
                tp = np.cumsum(np.hstack(tp)[:, sort_inds], axis=1)
                fp = np.cumsum(np.hstack(fp)[:, sort_inds], axis=1)
                # calculate recall and precision with tp and fp
                recalls = tp / np.maximum(num_gts[:, np.newaxis], eps)
                precisions = tp / np.maximum((tp + fp), eps)
                if multi_scale:
                    # copy so the per-threshold dicts do not share one array
                    cls_num_gts = num_gts.copy()
                else:
                    recalls = recalls[0, :]
                    precisions = precisions[0, :]
                    cls_num_gts = num_gts.item()
                ap = average_precision(recalls, precisions, mode)
                thr_results.append({
                    'num_gts': cls_num_gts,
                    'num_dets': num_dets,
                    'recall': recalls,
                    'precision': precisions,
                    'ap': ap
                })
    finally:
        # Always release the worker processes, even if evaluation failed.
        pool.close()
        pool.join()

    mean_aps = [
        _mean_ap_over_classes(thr_results, num_scales, multi_scale)
        for thr_results in results_per_thr
    ]
    # Mean over thresholds (and over scales when scale ranges are used).
    overall_map = float(np.mean(mean_aps))

    # Locate the 0.5 / 0.75 thresholds with a tolerance instead of float
    # ``==``; fall back to the first threshold when 0.5 was not requested.
    idx_50 = _find_iou_thr_index(iou_thrs, 0.5)
    if idx_50 is None:
        idx_50 = 0
    mean_ap_50 = mean_aps[idx_50]
    results_50 = results_per_thr[idx_50]

    idx_75 = _find_iou_thr_index(iou_thrs, 0.75)
    if idx_75 is None:
        # IoU 0.75 was not evaluated: hand the summary an empty result list
        # and NaN means rather than crashing.
        results_75 = []
        mean_ap_75 = ([float('nan')] * num_scales
                      if multi_scale else float('nan'))
    else:
        mean_ap_75 = mean_aps[idx_75]
        results_75 = results_per_thr[idx_75]

    print_map_summary(mean_ap_50, mean_ap_75, results_50, results_75,
                      dataset, area_ranges, logger, overall_map)

    return mean_ap_50, results_50


def print_map_summary(mean_ap_50,
                      mean_ap_75,
                      mAP50eval_results,
                      mAP75eval_results,
                      dataset=None,
                      scale_ranges=None,
                      logger=None,
                      mAP=0):
    """Print mAP and results of each class.

    A table will be printed per scale showing the gts/dets/F1/AP50/AP75 of
    each class, followed by a ``mean`` row that also carries ``mAP``.

    Args:
        mean_ap_50 (float | list[float]): Mean AP at IoU 0.5, one value per
            scale when given as a list.
        mean_ap_75 (float | list[float] | None): Mean AP at IoU 0.75, one
            value per scale when given as a list. ``None`` is shown as nan.
        mAP50eval_results (list[dict]): Per-class results at IoU 0.5 with the
            keys ``num_gts``, ``num_dets``, ``recall``, ``precision`` and
            ``ap``.
        mAP75eval_results (list[dict] | None): Per-class results at IoU 0.75
            with the keys ``recall``, ``precision`` and ``ap``. ``None`` or
            an empty list leaves the ap75 column as nan.
        dataset (list[str] | None): Class names used as row labels; class
            indices are used when None. Default: None.
        scale_ranges (list[tuple] | None): Range of scales to be evaluated.
        logger (logging.Logger | str | None): The way to print the mAP
            summary. See `mmcv.utils.print_log()` for details. ``'silent'``
            suppresses all output. Default: None.
        mAP (float): Value printed in the ``mAP`` column of the mean row.
            Default: 0.
    """

    if logger == 'silent':
        return

    if isinstance(mAP50eval_results[0]['ap'], np.ndarray):
        num_scales = len(mAP50eval_results[0]['ap'])
    else:
        num_scales = 1

    if scale_ranges is not None:
        assert len(scale_ranges) == num_scales

    num_classes = len(mAP50eval_results)

    recalls = np.zeros((num_scales, num_classes), dtype=np.float32)
    precisions = np.zeros((num_scales, num_classes), dtype=np.float32)
    aps50 = np.zeros((num_scales, num_classes), dtype=np.float32)
    # nan marks "not evaluated"; every entry is overwritten below when
    # IoU-0.75 results are supplied for all classes.
    aps75 = np.full((num_scales, num_classes), np.nan, dtype=np.float32)
    num_gts = np.zeros((num_scales, num_classes), dtype=int)
    for i, cls_result in enumerate(mAP50eval_results):
        if cls_result['recall'].size > 0:
            recalls[:, i] = np.array(cls_result['recall'], ndmin=2)[:, -1]
        if cls_result['precision'].size > 0:
            precisions[:, i] = np.array(
                cls_result['precision'], ndmin=2)[:, -1]
        aps50[:, i] = cls_result['ap']
        num_gts[:, i] = cls_result['num_gts']

    # NOTE(review): this loop overwrites the recall/precision taken from the
    # IoU-0.5 results, so the F1 column reflects the IoU-0.75 results whenever
    # they are non-empty -- confirm this is the intended threshold for F1.
    for i, cls_result in enumerate(mAP75eval_results or []):
        if cls_result['recall'].size > 0:
            recalls[:, i] = np.array(cls_result['recall'], ndmin=2)[:, -1]
        if cls_result['precision'].size > 0:
            precisions[:, i] = np.array(
                cls_result['precision'], ndmin=2)[:, -1]
        aps75[:, i] = cls_result['ap']

    if dataset is None:
        label_names = [str(i) for i in range(num_classes)]
    else:
        label_names = dataset

    if not isinstance(mean_ap_50, list):
        mean_ap_50 = [mean_ap_50]

    if mean_ap_75 is None:
        mean_ap_75 = [float('nan')] * num_scales
    elif not isinstance(mean_ap_75, list):
        mean_ap_75 = [mean_ap_75]

    header = ['class', 'gts', 'dets', 'F1','mAP', 'ap50', 'ap75']
    for i in range(num_scales):
        if scale_ranges is not None:
            print_log(f'Scale range {scale_ranges[i]}', logger=logger)
        # F1 is nan where precision + recall == 0; compute it for the whole
        # row at once and silence the expected 0/0 warning.
        with np.errstate(divide='ignore', invalid='ignore'):
            f1s = ((2 * precisions[i] * recalls[i])
                   / (precisions[i] + recalls[i]))
        # nanmean over an all-nan row would warn; the result is nan anyway.
        if np.all(np.isnan(f1s)):
            mean_f1 = float('nan')
        else:
            mean_f1 = np.nanmean(f1s).item()

        table_data = [header]
        for j in range(num_classes):
            row_data = [
                label_names[j], num_gts[i, j],
                mAP50eval_results[j]['num_dets'],
                f'{f1s[j]:.4f}', '', f'{aps50[i, j]:.4f}',
                f'{aps75[i, j]:.4f}'
            ]
            table_data.append(row_data)
        table_data.append([
            'mean', '', '', f'{mean_f1:.4f}', f'{mAP:.4f}',
            f'{mean_ap_50[i]:.4f}', f'{mean_ap_75[i]:.4f}'
        ])
        table = AsciiTable(table_data)
        table.inner_footing_row_border = True
        print_log('\n' + table.table, logger=logger)
