import logging
from PIL import Image
from timm.data import ImageDataset

# Module-level logger for this file.
_logger = logging.getLogger(__name__)

# PickleTaskDataset.__getitem__ skips samples whose image cannot be converted by
# moving on to the next index; once the count of consecutive failures reaches
# this value, the last error is re-raised instead of skipping again.
_ERROR_RETRY = 50


class PickleTaskDataset(ImageDataset):
    """Task dataset whose reader yields in-memory image arrays plus labels.

    ``reader[index]`` must return ``(img, target)`` where ``img`` is accepted
    by ``PIL.Image.fromarray`` (presumably a uint8 numpy array — TODO confirm
    against the pickle reader). Every valid label returned by ``__getitem__``
    is shifted by a per-task class offset (see ``set_class_offset``) so that
    several tasks can share one label space.

    Attributes:
        num_classes: number of distinct labels found in ``reader.samples``.
    """

    def __init__(self, reader, root=None, class_map=None, load_bytes=False, **kwargs):
        super().__init__(root, reader=reader, class_map=class_map, load_bytes=load_bytes, **kwargs)

        # Number of distinct labels in this task. ClassIncrementalDataset adds
        # this up across tasks to compute each task's offset.
        self.num_classes = len({sample[1] for sample in self.reader.samples})
        # Value added to every valid label returned by __getitem__.
        self._offset = 0

    def set_class_offset(self, offset):
        """Set the value added to every valid (non-sentinel) label."""
        self._offset = offset

    def __getitem__(self, index):
        """Return ``(img, target)`` for ``index``.

        If the image cannot be converted to an RGB PIL image, a warning is
        logged and the next index (wrapping around) is tried instead. Once
        ``_ERROR_RETRY`` consecutive failures occur, the exception is
        re-raised.

        Unlabeled samples (``target is None``) get the label ``-1``. The class
        offset is deliberately NOT applied to this sentinel; otherwise it
        would become ``offset - 1`` and alias the previous task's last class.
        """
        img, target = self.reader[index]
        try:
            img = Image.fromarray(img).convert('RGB')
        except Exception as e:
            _logger.warning(
                'Skipped sample (index %s, file %s). %s',
                index, self.reader.filename(index), e)
            # _consecutive_errors is not set in this class; it is presumably
            # initialised by ImageDataset.__init__ — confirm for the timm version.
            self._consecutive_errors += 1
            if self._consecutive_errors < _ERROR_RETRY:
                return self.__getitem__((index + 1) % len(self.reader))
            # Bare raise keeps the original traceback intact.
            raise
        self._consecutive_errors = 0
        if self.transform is not None:
            img = self.transform(img)
        if target is None:
            # "No label" sentinel; intentionally left without the class offset.
            return img, -1
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target + self._offset


class FolderTaskDataset(ImageDataset):
    """Task dataset that defers loading to timm's ``ImageDataset``.

    Image loading, transforms and error skipping are inherited unchanged from
    ``ImageDataset.__getitem__``. This class only shifts each valid label by a
    per-task class offset (see ``set_class_offset``) so that several tasks can
    share one label space.

    Attributes:
        num_classes: number of distinct labels found in ``reader.samples``.
    """

    def __init__(self, reader, root=None, class_map=None, load_bytes=False, **kwargs):
        super().__init__(root, reader=reader, class_map=class_map, load_bytes=load_bytes, **kwargs)

        # Number of distinct labels in this task. ClassIncrementalDataset adds
        # this up across tasks to compute each task's offset.
        self.num_classes = len({sample[1] for sample in self.reader.samples})
        # Value added to every valid label returned by __getitem__.
        self._offset = 0

    def set_class_offset(self, offset):
        """Set the value added to every valid (non-sentinel) label."""
        self._offset = offset

    def __getitem__(self, index):
        """Return ``(img, target)`` with ``target`` shifted by the class offset.

        The base class is expected to return the int ``-1`` for unlabeled
        samples (the same convention PickleTaskDataset uses; confirm for the
        pinned timm version). That sentinel is passed through unchanged,
        because offsetting it would alias the previous task's last class.
        """
        img, target = super().__getitem__(index)
        # Only the plain-int sentinel is matched, so tensor/array targets
        # produced by a target_transform are never compared here.
        if isinstance(target, int) and target == -1:
            return img, target
        return img, target + self._offset


class ClassIncrementalDataset:
    """Ordered container of per-task datasets for class-incremental training.

    Tasks keep the iteration order of ``task_datasets``. When
    ``offset_task_labels`` is true, each task dataset is given, through its
    ``set_class_offset`` method, the sum of ``num_classes`` over all the
    tasks that precede it. As a result, consecutive tasks occupy disjoint
    label ranges. This assumes each task's labels lie in
    ``[0, num_classes)``, which is not checked here.

    Args:
        task_datasets: mapping from task name to task dataset.
        offset_task_labels: if true, assign cumulative class offsets to the
            task datasets; if false, their offsets are left untouched.

    ``ds[i]`` returns the i-th task dataset and ``len(ds)`` is the number of
    tasks.
    """

    def __init__(self, task_datasets, offset_task_labels=False):
        self.offset_task_labels = offset_task_labels

        named_tasks = list(task_datasets.items())
        self.task_order = [name for name, _ in named_tasks]
        self.task_datasets = [dataset for _, dataset in named_tasks]

        if not self.offset_task_labels:
            return

        classes_seen = 0
        for dataset in self.task_datasets:
            dataset.set_class_offset(classes_seen)
            classes_seen += dataset.num_classes

    def __getitem__(self, index):
        """Return the task dataset at position ``index`` in task order."""
        task_list = self.task_datasets
        return task_list[index]

    def __len__(self):
        """Return the number of tasks."""
        return len(self.task_datasets)
