# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Dict, List, Optional, Sequence, Union

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset

try:
    from dsdl.dataset import DSDLDataset
except ImportError:
    DSDLDataset = None


@DATASETS.register_module()
class DSDLSegDataset(BaseSegDataset):
    """Dataset for dsdl segmentation.

    The samples are read through a ``DSDLDataset`` built from the yaml file
    given as ``ann_file``; each sample must provide the ``Image`` and
    ``LabelMap`` fields.

    Args:
        specific_key_path(dict, optional): Path of specific key which can
            not be loaded by its field name. Defaults to None, which is
            treated as an empty dict.
        pre_transform(dict, optional): pre-transform functions before
            loading. Defaults to None, which is treated as an empty dict.
        used_labels(sequence, optional): list of actual used classes in
            train steps, this must be subset of class domain. Defaults to
            None, in which case every class of the class domain (plus
            ``'background'``) is used.
        **kwargs: Remaining keyword arguments, forwarded to
            ``BaseSegDataset.__init__``. ``ann_file`` is required.

    Raises:
        RuntimeError: If the ``dsdl`` package could not be imported.
    """

    METAINFO = {}

    def __init__(self,
                 specific_key_path: Optional[Dict] = None,
                 pre_transform: Optional[Dict] = None,
                 used_labels: Optional[Sequence] = None,
                 **kwargs) -> None:

        if DSDLDataset is None:
            raise RuntimeError(
                'Package dsdl is not installed. Please run "pip install dsdl".'
            )
        # Use a fresh dict per instance instead of a shared mutable default,
        # so nothing mutated downstream can leak between datasets.
        if specific_key_path is None:
            specific_key_path = {}
        if pre_transform is None:
            pre_transform = {}

        self.used_labels = used_labels

        loc_config = dict(type='LocalFileReader', working_dir='')
        if kwargs.get('data_root'):
            # The yaml is resolved against data_root here because the dsdl
            # reader below needs the path before the base class is set up.
            kwargs['ann_file'] = os.path.join(kwargs['data_root'],
                                              kwargs['ann_file'])
        required_fields = ['Image', 'LabelMap']

        # NOTE(review): ``dsdldataset`` is created before the base
        # initializer runs, presumably because the base class may invoke
        # ``load_data_list`` / ``get_label_map`` during construction --
        # confirm before reordering.
        self.dsdldataset = DSDLDataset(
            dsdl_yaml=kwargs['ann_file'],
            location_config=loc_config,
            required_fields=required_fields,
            specific_key_path=specific_key_path,
            transform=pre_transform,
        )
        super().__init__(**kwargs)

    def load_data_list(self) -> List[Dict]:
        """Load data info from a dsdl yaml file named as ``self.ann_file``

        Besides building the list, this sets ``self._metainfo['classes']``
        (and ``self.label_map`` when ``used_labels`` is given).

        Returns:
            List[dict]: One dict per sample with the keys ``img_path``,
            ``seg_map_path``, ``label_map``, ``reduce_zero_label`` and
            ``seg_fields``.
        """

        if self.used_labels:
            self._metainfo['classes'] = tuple(self.used_labels)
            self.label_map = self.get_label_map(self.used_labels)
        else:
            self._metainfo['classes'] = tuple(['background'] +
                                              self.dsdldataset.class_names)

        img_prefix = self.data_prefix['img_path']
        seg_prefix = self.data_prefix['seg_map_path']
        data_list = []

        for data in self.dsdldataset:
            # Only the first Image / LabelMap entry of each sample is used.
            datainfo = dict(
                img_path=os.path.join(img_prefix, data['Image'][0].location),
                seg_map_path=os.path.join(seg_prefix,
                                          data['LabelMap'][0].location),
                label_map=self.label_map,
                reduce_zero_label=self.reduce_zero_label,
                seg_fields=[],
            )
            data_list.append(datainfo)

        return data_list

    def get_label_map(self,
                      new_classes: Optional[Sequence] = None
                      ) -> Union[Dict, None]:
        """Require label mapping.

        The ``label_map`` is a dictionary, its keys are the old label ids and
        its values are the new label ids, and is used for changing pixel
        labels in load_annotations. ``label_map`` is not None if and only if
        ``new_classes`` is not None and differs from the old classes of the
        class domain (``'background'`` followed by the dsdl class names).
        Old classes missing from ``new_classes`` are mapped to 255.

        Args:
            new_classes (list, tuple, optional): The new classes name from
                metainfo. Default to None.

        Returns:
            dict, optional: The mapping from old classes to new classes.

        Raises:
            ValueError: If ``new_classes`` is not a subset of the old
                classes.
        """
        old_classes = ['background'] + self.dsdldataset.class_names
        if new_classes is None or list(new_classes) == old_classes:
            return None

        if not set(new_classes).issubset(old_classes):
            raise ValueError(
                f'new classes {new_classes} is not a '
                f'subset of classes {old_classes} in class_dom.')

        # Build the name -> new id table once instead of scanning
        # ``new_classes`` for every old class. ``setdefault`` keeps the
        # first occurrence, matching ``Sequence.index`` on duplicates.
        new_ids = {}
        for new_id, name in enumerate(new_classes):
            new_ids.setdefault(name, new_id)

        return {
            old_id: new_ids.get(name, 255)
            for old_id, name in enumerate(old_classes)
        }
