from collections import OrderedDict
from tkinter.messagebox import NO
import os.path as osp
from pycocotools import mask
import pdb
from mmcv.utils import print_log

from mmdet.core import eval_map, eval_recalls
from mmdet.core import eval_map_segm
from mmdet.datasets import pipelines
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.xml_style import XMLDataset
from chainercv.datasets.voc import voc_utils
from chainercv.utils import read_image
from chainercv.utils import read_label
import numpy as np

@DATASETS.register_module()
class VOCDatasetInstance(XMLDataset):

    CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
               'tvmonitor')

    def __init__(self, **kwargs):
        """Build the dataset and derive the VOC release year from ``img_prefix``.

        Args:
            **kwargs: Forwarded unchanged to the parent dataset constructor.

        Raises:
            ValueError: If ``img_prefix`` contains neither ``'VOC2007'`` nor
                ``'VOC2012'``.
        """
        super().__init__(**kwargs)
        # Candidates are probed in order, so 'VOC2007' wins if both appear.
        for candidate_year in (2007, 2012):
            if f'VOC{candidate_year}' in self.img_prefix:
                self.year = candidate_year
                break
        else:
            raise ValueError('Cannot infer dataset year from img_prefix')
            
#     def get_ann_info(self, idx):
#         """Get annotation from XML file by index.

#         Args:
#             idx (int): Index of data.

#         Returns:
#             dict: Annotation info of specified index.
#         """

#         img_id = self.data_infos[idx]['id']
#         xml_path = osp.join(self.img_prefix, self.ann_subdir, f'{img_id}.xml')
#         tree = ET.parse(xml_path)
#         root = tree.getroot()
#         points = []
#         bboxes = []
#         labels = []
#         points_ignore = []
#         bboxes_ignore = []
#         labels_ignore = []
#         for obj in root.findall('object'):
#             name = obj.find('name').text
#             if name not in self.CLASSES:
#                 continue
#             label = self.cat2label[name]
#             difficult = obj.find('difficult')
#             difficult = 0 if difficult is None else int(difficult.text)
#             bnd_box = obj.find('bndbox')
#             # TODO: check whether it is necessary to use int
#             # Coordinates may be float type
#             bbox = [
#                 int(float(bnd_box.find('xmin').text)),
#                 int(float(bnd_box.find('ymin').text)),
#                 int(float(bnd_box.find('xmax').text)),
#                 int(float(bnd_box.find('ymax').text))
#             ]
#             ignore = False
#             if self.min_size:
#                 assert not self.test_mode
#                 w = bbox[2] - bbox[0]
#                 h = bbox[3] - bbox[1]
#                 if w < self.min_size or h < self.min_size:
#                     ignore = True
#             # find point annotations
#             point_ann = obj.find('point')
#             point = [
#                 int(float(point_ann.find('x').text)),
#                 int(float(point_ann.find('y').text)),
#             ]
#             if difficult or ignore:
#                 points_ignore.append(point)
#                 bboxes_ignore.append(bbox)
#                 labels_ignore.append(label)
#             else:
#                 points.append(point)
#                 bboxes.append(bbox)
#                 labels.append(label)
#         if not bboxes:
#             points = np.zeros((0, 2))
#             bboxes = np.zeros((0, 4))
#             labels = np.zeros((0, ))
#         else:
#             points = np.array(points, ndmin=2) - 1
#             bboxes = np.array(bboxes, ndmin=2) - 1
#             labels = np.array(labels)
#         if not bboxes_ignore:
#             points_ignore = np.zeros((0, 2))
#             bboxes_ignore = np.zeros((0, 4))
#             labels_ignore = np.zeros((0, ))
#         else:
#             points_ignore = np.array(points_ignore, ndmin=2) - 1
#             bboxes_ignore = np.array(bboxes_ignore, ndmin=2) - 1
#             labels_ignore = np.array(labels_ignore)
#         ann = dict(
#             points=points.astype(np.float32),
#             bboxes=bboxes.astype(np.float32),
#             labels=labels.astype(np.int64),
#             points_ignore=points_ignore.astype(np.float32),
#             bboxes_ignore=bboxes_ignore.astype(np.float32),
#             labels_ignore=labels_ignore.astype(np.int64))
#         return ann 

    def pre_pipeline(self, results):
        """Prepare results dict for pipeline."""
        results['img_prefix'] = self.img_prefix
        results['seg_prefix'] = self.seg_prefix
        results['proposal_file'] = self.proposal_file
        results['bbox_fields'] = []
        results['point_fields'] = []
        results['mask_fields'] = []
        results['seg_fields'] = []
        
    def evaluate(self,
                 results,
                 metric='mAP',
                 logger=None,
                 proposal_nums=(100, 300, 1000),
                 iou_thr=[0.25, 0.5, 0.75],
                 scale_ranges=None):
        """Evaluate in VOC protocol.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Options are
                'mAP', 'recall'.
            logger (logging.Logger | str, optional): Logger used for printing
                related information during evaluation. Default: None.
            proposal_nums (Sequence[int]): Proposal number used for evaluating
                recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thr (float | list[float]): IoU threshold. Default: 0.5.
            scale_ranges (list[tuple], optional): Scale ranges for evaluating
                mAP. If not specified, all bounding boxes would be included in
                evaluation. Default: None.

        Returns:
            dict[str, float]: AP/recall metrics.
        """

        if not isinstance(metric, str):
            assert len(metric) == 1
            metric = metric[0]
        allowed_metrics = ['mAP', 'recall', 'mAP_Segm']
        if metric not in allowed_metrics:
            raise KeyError(f'metric {metric} is not supported')
        annotations = [self.get_ann_info_test(i) for i in range(len(self))]
        eval_results = OrderedDict()
        iou_thrs = [iou_thr] if isinstance(iou_thr, float) else iou_thr
        if metric == 'mAP':
            assert isinstance(iou_thrs, list)
            if self.year == 2007:
                ds_name = 'voc07'
            else:
                ds_name = self.CLASSES
            mean_aps = []
            for iou_thr in iou_thrs:
                print_log(f'\n{"-" * 15}iou_thr: {iou_thr}{"-" * 15}')
                mean_ap, _ = eval_map(
                    results,
                    annotations,
                    scale_ranges=None,
                    iou_thr=iou_thr,
                    dataset=ds_name,
                    logger=logger)
                mean_aps.append(mean_ap)
                eval_results[f'AP{int(iou_thr * 100):02d}'] = round(mean_ap, 3)
            eval_results['mAP'] = sum(mean_aps) / len(mean_aps)
        elif metric == 'mAP_Segm':
            det_results = [res[0] for res in results]
            segm_results= [res[1] for res in results]
            for iou_thr in iou_thrs:
                det_eval_result = self.evaluate(det_results, metric='mAP', logger=logger, 
                            proposal_nums=proposal_nums, iou_thr=iou_thr,
                            scale_ranges=scale_ranges)
                segm_map, _ = eval_map_segm(det_results=det_results, 
                        segm_results=segm_results, 
                        annotations=annotations,
                        scale_ranges=None,
                        iou_thr=iou_thr,
                        dataset='voc',
                        logger=logger)
            eval_results['mAP_Segm'] = segm_map
            eval_results.update(det_eval_result)
        elif metric == 'recall':
            gt_bboxes = [ann['bboxes'] for ann in annotations]
            recalls = eval_recalls(
                gt_bboxes, results, proposal_nums, iou_thr, logger=logger)
            for i, num in enumerate(proposal_nums):
                for j, iou in enumerate(iou_thr):
                    eval_results[f'recall@{num}@{iou}'] = recalls[i, j]
            if recalls.shape[1] > 1:
                ar = recalls.mean(axis=1)
                for i, num in enumerate(proposal_nums):
                    eval_results[f'AR@{num}'] = ar[i]
        
        return eval_results
    
    def get_ann_info(self, idx):
        """Get instance annotation with RLE-encoded masks by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info with keys ``bboxes`` (float32, xyxy),
                ``labels`` (int64), empty ``bboxes_ignore`` /
                ``labels_ignore`` and ``masks`` as returned by
                ``pycocotools.mask.encode`` (an empty list when the image
                has no instance).
        """
        return self._build_instance_ann(idx, encode_masks=True)

    def get_ann_info_test(self, idx):
        """Get instance annotation with raw binary masks by index.

        Identical to :meth:`get_ann_info` except that ``masks`` holds the
        per-instance mask array produced by
        ``voc_utils.image_wise_to_instance_wise`` instead of its encoding.
        This is the variant ``evaluate`` reads its ground truth from.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """
        return self._build_instance_ann(idx, encode_masks=False)

    def _build_instance_ann(self, idx, encode_masks):
        """Load the segmentation PNGs of one image and build its annotation.

        Args:
            idx (int): Index of data.
            encode_masks (bool): If True, ``masks`` is the
                ``pycocotools.mask.encode`` output; otherwise it is the
                instance-wise binary mask array.

        Returns:
            dict: Annotation info of specified index.
        """
        img_id = self.data_infos[idx]['id']
        label_path = osp.join(self.img_prefix, 'SegmentationClass',
                              f'{img_id}.png')
        inst_path = osp.join(self.img_prefix, 'SegmentationObject',
                             f'{img_id}.png')
        label_img = read_label(label_path, dtype=np.int32)
        # 255 is remapped to -1 in both images, and so is instance id 0;
        # -1 is the value the chainercv helper skips.
        # NOTE(review): presumably 255 = VOC "void" border and 0 =
        # background -- confirm against the dataset documentation.
        label_img[label_img == 255] = -1
        inst_img = read_label(inst_path, dtype=np.int32)
        inst_img[inst_img == 0] = -1
        inst_img[inst_img == 255] = -1
        mask_bin, labels = voc_utils.image_wise_to_instance_wise(
            label_img, inst_img)

        if len(labels) == 0:
            # No instance in this image: build well-shaped empty outputs
            # instead of failing in transpose() on an empty array.
            mask_bin = np.zeros((0, ) + tuple(label_img.shape[-2:]),
                                dtype=bool)
            mask_encode = []
            bbox_mask = np.zeros((0, 4))
        else:
            # pycocotools expects a Fortran-ordered uint8 array with the
            # instance axis last.
            mask_bin_f = np.asfortranarray(
                mask_bin.transpose(1, 2, 0).astype(np.uint8))
            mask_encode = mask.encode(mask_bin_f)
            bbox_mask = mask.toBbox(mask_encode)
            # toBbox yields (x, y, w, h); convert to (x1, y1, x2, y2).
            bbox_mask[:, 2:] += bbox_mask[:, :2]

        ann = dict(
            bboxes=bbox_mask.astype(np.float32),
            labels=labels.astype(np.int64),
            bboxes_ignore=np.zeros((0, 4)).astype(np.float32),
            labels_ignore=np.zeros((0, )).astype(np.int64),
            masks=mask_encode if encode_masks else mask_bin)
        return ann
    
