import copy
import torch
import numpy as np
import random
from torch.utils.data import Dataset, ConcatDataset

from .pin_memory import pin_memory


class SetIter(object):
    """Forward iterator over an indexable, sized container.

    Yields ``dataset[0]``, ``dataset[1]``, ..., ``dataset[len(dataset) - 1]``
    and then raises ``StopIteration``. The length is captured once, at
    construction time.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.length = len(dataset)
        self._i = 0

    def __iter__(self):
        # Iterator protocol: an iterator must itself be iterable (returning
        # self), so it can be handed to list(), zip(), another for-loop, etc.
        return self

    def __next__(self):
        if self._i >= self.length:
            raise StopIteration
        item = self.dataset[self._i]
        self._i += 1
        return item


class BaseShell(Dataset):
    r"""Base class for index-based dataset wrappers ("shells").

    A shell owns an index array ``_data_ids`` into some underlying data. The
    helpers here (``copy``, ``shuffle``, ``select``, ``cut``, ...) only
    manipulate that index array. The underlying data is shared between
    shells rather than duplicated. Subclasses implement ``__getitem__``;
    ``DatasetShell`` resolves item ``i`` as ``dataset[self._data_ids[i]]``.
    """

    def __init__(self, length: int):
        r""" A shell for Packaging Dataset.

        Args:
            length: dataset length

        Extra features:
            support cut dataset, can be used to divide the training set and validation set
            support for loop
        """
        self._data_ids = np.arange(length)

    def copy(self, data_ids=None):
        """Return a shallow copy of this shell that has its own index array.

        The shell object is copied with ``copy.copy``, so the wrapped data is
        shared (saves memory). Only the index array is duplicated, so
        shuffling or slicing the copy never affects ``self``.

        Args:
            data_ids: indices for the new shell. Defaults to this shell's
                current indices. Must provide a ``.copy()`` method
                (e.g. a numpy array or a list).

        Returns:
            a new Shell.
        """
        new = copy.copy(self)
        if data_ids is None:
            data_ids = self._data_ids
        new._data_ids = data_ids.copy()
        return new

    def shuffle(self):
        """Shuffle the index array in place and return ``self``.

        Note: this uses numpy's global RNG, while ``random_select`` and
        ``random_cut`` use the stdlib ``random`` module. Seed both for
        reproducibility.
        """
        np.random.shuffle(self._data_ids)
        return self

    def select(self, start_percent: float, end_percent: float):
        """ Select a continuous segment of the dataset. The args is a percentage of the data length

        Bounds are turned into positions with ``int(len(self) * percent)``
        and applied as a slice, so values are expected to satisfy
        ``0 <= start_percent <= end_percent <= 1`` (not enforced).

        Args:
            start_percent: the start percentage of the data length
            end_percent: the end percentage of the data length

        Returns:
            a new Shell with selected data.

        """
        out = self.copy()
        out_start_id = int(len(self) * start_percent)
        out_end_id = int(len(self) * end_percent)
        out._data_ids = out._data_ids[out_start_id:out_end_id]
        return out

    def random_select(self, percent):
        """Randomly keep ``int(len(self) * percent)`` samples, without replacement.

        Uses the stdlib ``random`` module.

        Args:
            percent: fraction of the current length to keep.

        Returns:
            a new Shell whose indices are the sampled ones, in ascending order
            (stored as a list).
        """
        nums = int(len(self) * percent)
        out = self.copy()
        # random.sample requires a Sequence. numpy arrays are not registered as
        # collections.abc.Sequence and would raise TypeError, so convert first.
        # For an index array that is already a list this is a no-op copy, so
        # the RNG stream is consumed exactly as before.
        out._data_ids = sorted(random.sample(list(out._data_ids), nums))
        return out

    def cut(self, *cut_percents: float):
        r"""cut the dataset. for dividing the training set or validation set

        Args:
            *cut_percents: cutting point.

        Returns:
            cut Shells.

        Examples:
            ```python
            >>> dataset = MNIST()
            >>> my_dataset = DatasetShell(dataset)
            >>> train_set, valid_set = my_dataset.cut(0.8)
            >>> train_set, eval_set, valid_set = my_dataset.cut(0.8, 0.9)
            ```
        """
        assert all([0 < p < 1 for p in cut_percents])
        cut_percents = sorted(cut_percents)

        start_percent = 0
        result = []
        for cut_percent in cut_percents:
            result.append(self.select(start_percent, cut_percent))
            start_percent = cut_percent
        if start_percent < 1:
            result.append(self.select(start_percent, 1))
        return tuple(result)

    def random_cut(self, *percents):
        """Randomly partition the shell into disjoint parts.

        Every part except the last receives ``int(percent * len(self))``
        randomly sampled indices. The last part receives all remaining
        indices, so together the parts cover every index exactly once.
        Uses the stdlib ``random`` module.

        Args:
            *percents: fractions for each part. They must sum to 1
                (checked with ``np.isclose``).

        Returns:
            a list of new Shells, each with ascending (list) indices.
        """
        assert np.isclose(np.sum(percents), 1)
        data_ids = list(self._data_ids.copy())
        select_nums = [int(percent * len(self)) for percent in percents]

        selects = []
        for num in select_nums[:-1]:
            select = random.sample(data_ids, num)
            data_ids = list(set(data_ids) - set(select))
            selects.append(select)
        selects.append(data_ids)

        return [self.copy(sorted(select)) for select in selects]

    def __iter__(self):
        return SetIter(self)

    def __getitem__(self, i: int):
        raise NotImplementedError

    def __len__(self):
        return len(self._data_ids)


class DatasetShell(BaseShell):
    """Shell around one indexable dataset.

    Position ``i`` of the shell maps to ``dataset[self._data_ids[i]]``.
    Reordering or slicing the shell therefore changes which samples are
    visible without modifying the wrapped dataset.
    """

    def __init__(self, dataset: Dataset):
        self.dataset = dataset
        super().__init__(len(self.dataset))

    def __getitem__(self, i: int):
        # Guard against running off the end of the (possibly cut) index array.
        size = len(self)
        if i < size:
            source_index = self._data_ids[i]
            return self.dataset[source_index]
        raise IndexError('index {} is bigger than the max length {}.'.format(i, size))

    def __repr__(self):
        wrapped_name = self.dataset.__class__.__name__
        return "Dataset({}): {}".format(len(self), wrapped_name)


class SubShell(BaseShell):
    """Shell over several datasets chained end-to-end.

    The datasets are joined with ``torch.utils.data.ConcatDataset``. Shell
    position ``i`` is the concatenated item at ``self._data_ids[i]``.
    """

    def __init__(self, *datasets: Dataset):
        self.datasets = ConcatDataset(datasets)
        super().__init__(len(self.datasets))

    def __getitem__(self, i: int):
        # Resolve through the index array, as DatasetShell does. Indexing the
        # ConcatDataset with ``i`` directly would turn shuffle/select/cut/
        # random_select into no-ops: every cut part would start at item 0.
        # An out-of-range ``i`` raises IndexError from the index array.
        return self.datasets[self._data_ids[i]]

    def __repr__(self):
        parts = [dataset.__class__.__name__ + "(" + str(len(dataset)) + ")"
                 for dataset in self.datasets.datasets]
        return "Dataset({}): [{}]".format(len(self), ', '.join(parts))


class CatShell(BaseShell):
    """Shell that pairs up several equal-length datasets item by item.

    Shell position ``i`` returns
    ``tuple(d[self._data_ids[i]] for d in datasets)``.
    """

    def __init__(self, *datasets: Dataset):
        self.datasets = datasets

        for dataset in self.datasets:
            assert len(dataset) == len(self.datasets[0]), \
                "all datasets passed to CatShell must have the same length"
        super().__init__(len(self.datasets[0]))

    def __getitem__(self, i: int):
        # Resolve the shell-local index once, so every dataset is read at the
        # same underlying position. Going through _data_ids (as DatasetShell
        # does) is what makes shuffle/select/cut take effect.
        idx = self._data_ids[i]
        return tuple(dataset[idx] for dataset in self.datasets)

    def __repr__(self):
        parts = [dataset.__class__.__name__ + "(" + str(len(dataset)) + ")"
                 for dataset in self.datasets]
        return "CatDataset({}): [{}]".format(len(self), ', '.join(parts))


class SequenceShell(BaseShell):
    """Shell yielding windows of ``seq_len`` consecutive dataset items.

    There are ``len(dataset) - seq_len + 1`` possible window starts. If
    ``seq_len`` exceeds the dataset length, the shell is empty.
    Shell position ``i`` is the window starting at ``self._data_ids[i]``,
    i.e. ``[dataset[s], ..., dataset[s + seq_len - 1]]``. The window is
    passed through the project's ``pin_memory`` helper before it is returned.
    """

    def __init__(self, dataset: Dataset, seq_len: int):
        self.dataset = dataset
        self.seq_len = seq_len
        super().__init__(len(dataset) - seq_len + 1)

    def __getitem__(self, item: int):
        # Map the shell-local position to a window start through _data_ids,
        # as DatasetShell does, so shuffle/select/cut choose which windows
        # are visible.
        start = self._data_ids[item]
        data = [self.dataset[i] for i in range(start, start + self.seq_len)]
        return pin_memory(data)
