# Inherited from open-mmlab/mmclassification
# Modified by Hangyu LIN

import copy
import os.path as osp
from abc import ABCMeta, abstractmethod
from os import PathLike
from typing import List

import numpy as np
from torch.utils.data import Dataset

from .transforms import Compose
from .transition_function import ClassMatrixTransistion

def expanduser(path):
    """Expand a leading ``~`` in ``path`` if it is a string or path-like.

    Values of any other type (for example ``None``) are returned unchanged,
    so optional path arguments can be passed through safely.
    """
    is_path_like = isinstance(path, (str, PathLike))
    return osp.expanduser(path) if is_path_like else path

class BaseDataset(Dataset, metaclass=ABCMeta):
    """Abstract base class for datasets.

    Subclasses implement :meth:`load_annotations`; it is called once at the
    end of ``__init__`` and its result is stored in ``self.data_infos``.
    The accessors of this class read the ``'img'`` and ``'gt_label'`` keys
    of each entry of ``data_infos``.

    Args:
        root (str | PathLike): prefix of the data path (``~`` is expanded).
        transforms (callable | None): object called on every image in
            ``__getitem__``. ``None`` selects ``DEFAULT_TRNASFORMS``.
        target_transforms (callable | None): object called on every label
            in ``__getitem__``; skipped when None.
        classes (Sequence[str] | None): class names overriding ``CLASSES``,
            see :meth:`get_classes`.
        ann_file (str | None): the annotation file (``~`` is expanded). How
            it is read is left to the subclass.
        test_mode (bool): in train mode or test mode; stored as-is.
    """

    CLASSES = None
    # NOTE: the misspelled attribute name is kept because other classes
    # reference it.
    DEFAULT_TRNASFORMS = Compose([])

    def __init__(self,
                 root,
                 transforms,
                 target_transforms=None,
                 classes=None,
                 ann_file=None,
                 test_mode=False):
        super().__init__()
        self.root = expanduser(root)
        self.transforms = (
            self.DEFAULT_TRNASFORMS if transforms is None else transforms)
        self.target_transforms = target_transforms
        self.CLASSES = self.get_classes(classes)
        self.ann_file = expanduser(ann_file)
        self.test_mode = test_mode
        # Must run last: subclasses may rely on the attributes set above.
        self.data_infos = self.load_annotations()

    @abstractmethod
    def load_annotations(self):
        """Load the samples; the return value becomes ``self.data_infos``."""

    @property
    def class_to_idx(self):
        """Mapping from class name to class index.

        Returns:
            dict: class name -> position of that name in ``CLASSES``.
        """
        return {name: index for index, name in enumerate(self.CLASSES)}

    def get_gt_labels(self):
        """Collect the ground-truth label of every sample.

        Returns:
            np.ndarray: the ``'gt_label'`` values of all entries of
            ``data_infos``, in order.
        """
        labels = [info['gt_label'] for info in self.data_infos]
        return np.array(labels)

    def get_cat_ids(self, idx: int) -> List[int]:
        """Get the category id of one sample.

        Args:
            idx (int): Index of data.

        Returns:
            cat_ids (List[int]): single-element list holding the ground-truth
            label of the sample, converted to ``int``.
        """
        gt_label = self.data_infos[idx]['gt_label']
        return [int(gt_label)]

    def __len__(self):
        """Number of loaded samples."""
        return len(self.data_infos)

    def __getitem__(self, idx):
        """Return the ``(img, label)`` pair at ``idx`` after transformation."""
        img = self.data_infos[idx]['img']
        label = self.data_infos[idx]['gt_label']
        if self.transforms is not None:
            img = self.transforms(img)
        if self.target_transforms is not None:
            label = self.target_transforms(label)
        return img, label

    @classmethod
    def get_classes(cls, classes=None):
        """Get class names of current dataset.

        Args:
            classes (Sequence[str] | None): If classes is None, use the
                default ``CLASSES`` defined by the dataset class. If classes
                is a tuple or list, it overrides ``CLASSES``.

        Returns:
            tuple[str] or list[str]: Names of categories of the dataset.

        Raises:
            ValueError: if ``classes`` is neither None, a tuple nor a list.
                Loading class names from a file path (a ``str``) is not
                supported here.
        """
        if classes is None:
            return cls.CLASSES
        if not isinstance(classes, (tuple, list)):
            raise ValueError(f'Unsupported type {type(classes)} of classes.')
        return classes


class NoisyDataset(BaseDataset):
    """Dataset whose labels are corrupted through a class transition function.

    Mainly for classification. After the annotations are loaded, every entry
    of ``data_infos`` gets a ``'trans_label'`` sampled from the distribution
    returned by a ``ClassMatrixTransistion`` built from the number of classes,
    ``trans_rate`` and ``trans_type``.

    Args:
        root (str | PathLike): prefix of the data path (``~`` is expanded).
        transforms (list | None): passed to ``Compose``. ``None`` selects
            ``DEFAULT_TRNASFORMS``.
        trans_rate: forwarded to ``ClassMatrixTransistion``.
        trans_type: forwarded to ``ClassMatrixTransistion``.
        target_transforms (list | None): passed to ``Compose`` when given.
        classes (Sequence[str] | None): class names, see ``get_classes``.
        ann_file (str | None): the annotation file (``~`` is expanded).
        test_mode (bool): in train mode or test mode; stored as-is.
    """

    def __init__(self,
                 root,
                 transforms,
                 trans_rate,
                 trans_type,
                 target_transforms=None,
                 classes=None,
                 ann_file=None,
                 test_mode=False):
        # Deliberately skips BaseDataset.__init__ (note the first argument of
        # super): the attributes are set up below in a different way, and the
        # annotations must only be loaded after trans_rate/trans_type exist.
        super(BaseDataset, self).__init__()
        self.root = expanduser(root)
        if transforms is None:
            transforms = self.DEFAULT_TRNASFORMS
        self.trans_rate = trans_rate
        self.trans_type = trans_type
        # DEFAULT_TRNASFORMS is already a Compose instance; do not wrap a
        # pipeline inside another pipeline.
        if not isinstance(transforms, Compose):
            transforms = Compose(transforms)
        self.transforms = transforms
        if target_transforms is not None:
            target_transforms = Compose(target_transforms)
        self.target_transforms = target_transforms
        self.CLASSES = self.get_classes(classes)
        self.num_classes = len(self.CLASSES)
        self.ann_file = expanduser(ann_file)
        self.test_mode = test_mode
        self.data_infos = self.load_annotations()
        self.post_process()
        # Only set here; not read anywhere inside this class.
        self._use_corrected = True

    def post_process(self):
        """Build the transition function and attach noisy labels to the data."""
        self.transition_func = ClassMatrixTransistion(
            len(self.CLASSES), self.trans_rate, self.trans_type)
        self.data_infos = self._transit_labels(
            self.data_infos, self.transition_func)

    def _sample_label(self, transit_prob):
        """Draw one class index from the distribution ``transit_prob``.

        ``transit_prob`` is expected to be a tensor-like object supporting
        ``detach().cpu().numpy()``; moving to CPU first keeps this working
        when the transition function returns a tensor on another device.
        """
        probs = transit_prob.detach().cpu().numpy()
        return np.random.choice(self.num_classes, p=probs)

    def _transit_labels(self, data_infos, transition_func):
        """Transit the labels by the noise transition function.

        Only classification is supported for now.

        Args:
            data_infos: iterable of per-sample dicts with ``'img'`` and
                ``'gt_label'`` keys.
            transition_func: callable ``(img, label) -> probabilities`` over
                the ``num_classes`` classes.

        Returns:
            list: shallow copies of the input entries, each with an added
            ``'trans_label'`` key holding the sampled label. The input
            entries themselves are not modified.
        """
        transited_data_infos = []
        for info in data_infos:
            # Shallow copy so the original entry does not gain the new key.
            term = copy.copy(info)
            transit_prob = transition_func(term['img'], term['gt_label'])
            term['trans_label'] = self._sample_label(transit_prob)
            transited_data_infos.append(term)
        return transited_data_infos

    def update_corrected_label(self, idx, label):
        """Store ``label`` as the ``'corrected_trans_label'`` of sample ``idx``."""
        self.data_infos[idx]['corrected_trans_label'] = label

    def retransit_label(self, idx, x, label, transition_func):
        """Re-sample the noisy label of sample ``idx`` and store it.

        A new label is drawn from ``transition_func(x, label)`` and written to
        the ``'corrected_trans_label'`` key of the sample.

        Args:
            idx (int): index of the sample in ``data_infos``.
            x: input passed to ``transition_func`` together with ``label``.
            label: label the transition is conditioned on.
            transition_func: callable ``(x, label) -> probabilities`` over
                the ``num_classes`` classes.
        """
        transit_prob = transition_func(x, label)
        # Store the freshly sampled label; storing the input ``label`` would
        # make this method identical to ``update_corrected_label``.
        self.data_infos[idx]['corrected_trans_label'] = (
            self._sample_label(transit_prob))
