# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence

from mmpretrain.registry import METRICS
from mmpretrain.structures import label_to_onehot
from .multi_label import AveragePrecision, MultiLabelMetric


class VOCMetricMixin:
    """A mixin class for VOC dataset metrics.

    VOC annotations carry an extra ``difficult`` attribute for each object,
    so an extra option is needed to decide how difficult labels are treated
    in the one-hot ground truth used for evaluation.

    The mixin is meant to be listed before the concrete metric class in the
    bases (e.g. ``class Foo(VOCMetricMixin, MultiLabelMetric)``), so that its
    ``__init__`` consumes ``difficult_as_positive`` and forwards everything
    else via ``super().__init__``, and its ``process`` overrides the base one.

    Args:
        difficult_as_positive (Optional[bool]): Whether to map the difficult
            labels as positive in one-hot ground truth for evaluation. If it
            is set to True, map difficult gt labels to positive ones (1). If
            it is set to False, map difficult gt labels to negative ones (0).
            Defaults to None, in which case the difficult labels will be set
            to -1.
    """

    def __init__(self,
                 *arg,
                 difficult_as_positive: Optional[bool] = None,
                 **kwarg):
        self.difficult_as_positive = difficult_as_positive
        super().__init__(*arg, **kwarg)

    @staticmethod
    def _to_label_set(labels) -> set:
        """Convert a label container to a set of plain Python ints.

        Accepts lists/tuples of ints as well as objects exposing ``tolist``
        (e.g. tensors or arrays), including 0-d ones that yield a scalar.
        """
        if hasattr(labels, 'tolist'):
            # Tensor elements hash by identity, so set arithmetic on them
            # would never match equal label values; convert to numbers first.
            labels = labels.tolist()
        if isinstance(labels, int):
            labels = [labels]
        return {int(label) for label in labels}

    def process(self, data_batch, data_samples: Sequence[dict]):
        """Process one batch of data samples.

        The processed results are stored in ``self.results``, which will be
        used to compute the metrics when all batches have been processed.

        Args:
            data_batch: A batch of data from the dataloader. Not used here.
            data_samples (Sequence[dict]): A batch of outputs from the model.
                Each item must provide ``pred_score``, ``gt_label`` and
                ``gt_label_difficult``; ``gt_score`` is used when present.
        """
        for data_sample in data_samples:
            result = dict()
            gt_label = data_sample['gt_label']
            gt_label_difficult = data_sample['gt_label_difficult']

            result['pred_score'] = data_sample['pred_score'].clone()
            num_classes = result['pred_score'].size()[-1]

            if 'gt_score' in data_sample:
                result['gt_score'] = data_sample['gt_score'].clone()
            else:
                result['gt_score'] = label_to_onehot(gt_label, num_classes)

            # VOC annotation labels all the objects in a single image,
            # therefore some categories appear both in difficult objects and
            # non-difficult objects. Only the labels that exist exclusively in
            # difficult objects are treated as difficult labels.
            difficult_label = (
                self._to_label_set(gt_label_difficult) -
                self._to_label_set(gt_label))

            # None -> mark difficult labels with -1; True -> set them to 1;
            # False -> leave ``gt_score`` untouched for those labels.
            if difficult_label:
                indices = sorted(difficult_label)
                if self.difficult_as_positive is None:
                    result['gt_score'][indices] = -1
                elif self.difficult_as_positive:
                    result['gt_score'][indices] = 1

            # Save the result to `self.results`.
            self.results.append(result)


@METRICS.register_module()
class VOCMultiLabelMetric(VOCMetricMixin, MultiLabelMetric):
    """A collection of metrics for multi-label multi-class classification task
    based on confusion matrix for VOC dataset.

    It includes precision, recall, f1-score and support.

    Args:
        difficult_as_positive (Optional[bool]): Whether to map the difficult
            labels as positive in one-hot ground truth for evaluation. If it
            is set to True, map difficult gt labels to positive ones (1). If
            it is set to False, map difficult gt labels to negative ones (0).
            Defaults to None, in which case the difficult labels will be set
            to -1.
        **kwarg: Refers to `MultiLabelMetric` for detailed docstrings.
    """


@METRICS.register_module()
class VOCAveragePrecision(VOCMetricMixin, AveragePrecision):
    """Calculate the average precision with respect of classes for VOC dataset.

    Args:
        difficult_as_positive (Optional[bool]): Whether to map the difficult
            labels as positive in one-hot ground truth for evaluation. If it
            is set to True, map difficult gt labels to positive ones (1). If
            it is set to False, map difficult gt labels to negative ones (0).
            Defaults to None, in which case the difficult labels will be set
            to -1.
        **kwarg: Refers to `AveragePrecision` for detailed docstrings.
    """
