import numpy as np
import os

from numpy.random import default_rng
from collections import namedtuple

from comln.datasets.flatten import FlattenDataset


# Lightweight immutable record for one split ('train' or 'test') of a batch of
# tasks: the stacked `inputs`, their per-example `targets`, and `infos`
# holding auxiliary bookkeeping for that split.
Dataset = namedtuple('Dataset', 'inputs targets infos')


class MetaDataset:
    """Base class for episodic (N-way, K-shot) few-shot datasets.

    Iterating yields batches of ``batch_size`` tasks. For each task, ``ways``
    distinct classes are drawn, and ``shots + test_shots`` distinct examples
    are drawn from each class. The first ``shots`` examples per class form the
    ``'train'`` split. The remaining ``test_shots`` form the ``'test'`` split.

    Subclasses must provide:

    * ``folder`` -- sub-directory name appended to ``root``;
    * ``data`` -- an array-like that supports NumPy fancy indexing with an
      integer array of shape ``(batch_size, ways, shots + test_shots)``;
    * ``labels2indices`` -- a mapping from each label to the indices into
      ``data`` of the examples carrying that label.

    NOTE(review): ``_data`` and ``_labels2indices`` are initialised to ``None``
    here, presumably as lazy caches backing the subclasses' ``data`` /
    ``labels2indices`` properties -- confirm against the subclasses.
    """

    def __init__(
            self,
            root,
            batch_size,
            shots=5,
            ways=5,
            test_shots=None,
            size=None,
            split='train',
            seed=0,
            download=False
        ):
        """Configure the sampler.

        Args:
            root: Base directory. ``~`` is expanded, and ``self.folder`` is
                appended to give ``self.root``.
            batch_size: Number of tasks per yielded batch.
            shots: Examples per class in the ``'train'`` split of each task.
            ways: Number of distinct classes per task.
            test_shots: Examples per class in the ``'test'`` split of each
                task. Defaults to ``shots``.
            size: Number of batches produced by iteration. ``None`` means
                iteration never stops and ``len()`` raises.
            split: ``'+'``-separated combination of ``'train'``, ``'val'`` and
                ``'test'``. It is stored as the list ``self.splits``. This base
                class only validates it.
            seed: Seed for the NumPy ``Generator``. ``reset()`` re-seeds with it.
            download: If true, call the ``download()`` hook during construction.

        Raises:
            ValueError: If ``split`` contains an unknown split name.
        """
        self.root = os.path.join(os.path.expanduser(root), self.folder)
        self.batch_size = batch_size
        self.shots = shots
        self.ways = ways
        self.test_shots = shots if (test_shots is None) else test_shots
        self.size = size

        # Validate explicitly: an `assert` would be stripped under `python -O`.
        self.splits = split.split('+')
        invalid = [name for name in self.splits if name not in ('train', 'val', 'test')]
        if invalid:
            raise ValueError(
                f'Unknown split(s) {invalid!r} in {split!r}; expected a '
                "'+'-separated combination of 'train', 'val' and 'test'."
            )

        # Initialise the lazy caches and the seed *before* the overridable
        # `download()` hook runs. The hook can then use `labels` / `num_classes`
        # without an AttributeError, and anything it caches is not clobbered.
        self._data = None
        self._labels2indices = None
        self._labels = None
        self._num_classes = None
        self.seed = seed

        if download:
            self.download()

        self.reset()

    def reset(self):
        """Re-seed the random generator and restart the batch counter."""
        self.rng = default_rng(self.seed)
        self.num_samples = 0

    @property
    def labels(self):
        """Sorted list of the keys of ``labels2indices`` (cached after first use)."""
        if self._labels is None:
            self._labels = sorted(self.labels2indices.keys())
        return self._labels

    @property
    def num_classes(self):
        """Number of entries in ``labels`` (cached after first use)."""
        if self._num_classes is None:
            self._num_classes = len(self.labels)
        return self._num_classes

    def get_indices(self):
        """Sample the class and example indices for one batch of tasks.

        Returns:
            A tuple ``(class_indices, indices, targets)``:

            * ``class_indices`` -- int array ``(batch_size, ways)``. Positions
              into ``self.labels``, distinct within each task.
            * ``indices`` -- int array ``(batch_size, ways, shots + test_shots)``.
              Indices into ``data``, drawn without replacement from each
              sampled class.
            * ``targets`` -- int array ``(batch_size, ways)``. Each row is a
              random permutation of ``range(ways)`` assigning a task-local
              target to each sampled class.
        """
        total_shots = self.shots + self.test_shots

        class_indices = np.zeros((self.batch_size, self.ways), dtype=np.int_)
        indices = np.zeros((self.batch_size, self.ways, total_shots), dtype=np.int_)
        targets = np.zeros((self.batch_size, self.ways), dtype=np.int_)

        for idx in range(self.batch_size):
            class_indices[idx] = self.rng.choice(self.num_classes, size=(self.ways,), replace=False)
            targets[idx] = self.rng.permutation(self.ways)

            for way in range(self.ways):
                label = self.labels[class_indices[idx, way]]
                indices[idx, way] = self.rng.choice(self.labels2indices[label], size=(total_shots,), replace=False)

        return class_indices, indices, targets

    def transform(self, data):
        """Hook applied to each split's stacked inputs; identity by default."""
        return data

    def __len__(self):
        """Return ``size``.

        Raises:
            RuntimeError: If ``size`` is ``None`` (infinite dataset).
        """
        if self.size is None:
            raise RuntimeError('The dataset has no length because it is infinite.')
        return self.size

    def __iter__(self):
        """Yield ``{'train': Dataset, 'test': Dataset}`` batches.

        ``inputs`` are ``transform`` applied to an array of shape
        ``(batch_size, ways * shots_in_split, *item_shape)``, grouped class by
        class. ``targets`` has shape ``(batch_size, ways * shots_in_split)`` in
        the same order. ``infos['labels']`` holds positions into ``self.labels``.

        The counter ``num_samples`` persists across iterations, so a finite
        dataset is exhausted until ``reset()`` is called.
        """
        while (self.size is None) or (self.num_samples < self.size):
            class_indices, indices, targets = self.get_indices()
            data = self.data[indices]
            # Shape after fancy indexing: (batch, ways, total_shots, *item_shape).
            item_shape = data.shape[3:]

            train = Dataset(
                inputs=self.transform(data[:, :, :self.shots].reshape((self.batch_size, -1) + item_shape)),
                # repeat() along the ways axis matches the way-major reshape order.
                targets=targets.repeat(self.shots, axis=1),
                infos={'labels': class_indices, 'indices': indices[..., :self.shots]}
            )
            test = Dataset(
                inputs=self.transform(data[:, :, self.shots:].reshape((self.batch_size, -1) + item_shape)),
                targets=targets.repeat(self.test_shots, axis=1),
                infos={'labels': class_indices, 'indices': indices[..., self.shots:]}
            )

            self.num_samples += 1
            yield {'train': train, 'test': test}

    def download(self):
        """Hook for fetching data; a no-op in the base class."""
        pass

    def flatten(self, num_epochs=1):
        """Wrap this dataset in a ``FlattenDataset`` with ``num_epochs``."""
        return FlattenDataset(self, num_epochs=num_epochs)
