import numpy as np
import math


class FlattenDataset:
    """Iterate a dataset as shuffled mini-batches for a number of epochs.

    The wrapped ``dataset`` must expose the attributes read below:
    ``batch_size``, ``data``, ``labels``, ``labels2indices``, ``rng``,
    ``transform`` and ``reset()``.

    Each epoch draws a fresh permutation of ``range(size)`` from ``rng`` and
    yields ``(inputs, targets)`` pairs:

    - ``inputs`` is ``dataset.transform`` applied to the selected rows of
      ``dataset.data``.
    - ``targets`` holds the matching entries of :attr:`labels`.

    The last batch of an epoch holds the remainder and may be smaller than
    ``batch_size``.
    """

    def __init__(self, dataset, num_epochs=1):
        """Wrap ``dataset`` and reset it so ``rng`` is ready for iteration.

        Args:
            dataset: Source object providing the attributes listed on the
                class.
            num_epochs: Number of full shuffled passes made by one
                iteration over this object.
        """
        self.dataset = dataset
        self.num_epochs = num_epochs
        self.batch_size = dataset.batch_size
        self.size = len(dataset.data)
        # Per-item targets are built lazily by the ``labels`` property.
        self._labels = None
        self.reset()

    @property
    def labels(self):
        """Integer target for every item, computed once and then cached.

        An item's target is the position of its label in ``dataset.labels``.
        ``dataset.labels2indices[label]`` gives the item indices that carry
        that label. Items not listed under any label keep the default 0.

        Raises:
            Whatever indexing ``dataset.labels2indices`` raises (for example
            ``KeyError`` for a missing label). In that case nothing is
            cached, so a later access rebuilds the array.
        """
        if self._labels is None:
            targets = np.zeros((self.size,), dtype=np.int_)
            for i, label in enumerate(self.dataset.labels):
                targets[self.dataset.labels2indices[label]] = i
            # Publish only once fully built, so a failure part-way through
            # never leaves a half-filled array in the cache.
            self._labels = targets
        return self._labels

    def reset(self):
        """Reset the wrapped dataset and re-bind ``rng`` to its generator.

        ``self.rng`` is the dataset's own generator object, not a copy.
        Shuffling here therefore advances the dataset's RNG state.
        """
        self.dataset.reset()
        self.rng = self.dataset.rng

    def __len__(self):
        """Return the total number of batches yielded across all epochs."""
        # A trailing partial batch still counts, hence the ceiling.
        return self.num_epochs * math.ceil(self.size / self.batch_size)

    def __iter__(self):
        """Yield ``(inputs, targets)`` batches, reshuffling every epoch."""
        for _ in range(self.num_epochs):
            indices = self.rng.permutation(self.size)

            for i in range(0, self.size, self.batch_size):
                # Slicing past the end just truncates, so the final batch
                # of an epoch may be shorter than ``batch_size``.
                slice_ = indices[i:i + self.batch_size]
                inputs = self.dataset.transform(self.dataset.data[slice_])
                targets = self.labels[slice_]
                yield inputs, targets
