import math
import torch
import numpy as np
from typing import Tuple
from torch.utils.data import Dataset, DataLoader


class TensorDataset(Dataset):
    """Map-style dataset pairing an input tensor with a target tensor.

    Item ``i`` is ``(inputs[i], targets[i])`` and the dataset length is
    ``len(inputs)``. The index is forwarded unchanged to both tensors, so
    anything tensor indexing accepts (an int, a slice, an index tensor)
    works here too.

    NOTE(review): ``targets`` is assumed to have at least as many rows as
    ``inputs``; this is not checked.
    """

    def __init__(self, inputs: torch.Tensor, targets: torch.Tensor) -> None:
        super(TensorDataset, self).__init__()
        self.inputs, self.targets = inputs, targets

    def __len__(self) -> int:
        n_samples = len(self.inputs)
        return n_samples

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        sample = self.inputs[index]
        label = self.targets[index]
        return sample, label


class UpsampleDataset(Dataset):
    """Expose ``dataset`` padded with extra randomly drawn samples.

    The first ``len(dataset)`` positions map one-to-one onto the wrapped
    dataset, so every original sample appears exactly once. Any positions
    beyond that (``upsample - len(dataset)`` of them, when positive) hold
    indices drawn uniformly at random, with replacement, from the wrapped
    dataset.

    Args:
        dataset: Map-style dataset to wrap; must support ``len`` and indexing.
        upsample: Target number of samples. A value smaller than
            ``len(dataset)`` never drops data; the dataset is exposed in full.
        seed: Seed for the NumPy generator that draws the extra indices.

    Raises:
        ValueError: If ``dataset`` is empty (there is nothing to sample from).
    """

    def __init__(self, dataset: Dataset, upsample: int, seed: int = 42) -> None:
        super().__init__()

        self.dataset: Dataset = dataset
        self.upsample = upsample

        size = len(self.dataset)
        if size == 0:
            raise ValueError("cannot upsample an empty dataset")

        rng = np.random.default_rng(seed=seed)
        self.indices = rng.integers(low=0, high=size, size=(max(upsample, size),))
        # Overwrite the head with the identity mapping so no original sample is lost.
        self.indices[:size] = np.arange(size)

    def __len__(self) -> int:
        # Report the real size of the index table. Returning ``self.upsample``
        # would hide the tail of the dataset whenever upsample < len(dataset).
        return len(self.indices)

    def __getitem__(self, index) -> Tuple:
        # ``index`` may be an int or an index array/tensor; it is used to look
        # up the index table and the result is forwarded to the wrapped dataset.
        return self.dataset[self.indices[index]]


class InMemoryDataloader(DataLoader):
    """Minimal batch iterator for datasets that accept tensor indices.

    Each batch is produced by a single ``dataset[index_tensor]`` call, where
    ``index_tensor`` is a 1-D slice of ``torch.arange``/``torch.randperm``
    output. No workers, samplers or collate function are involved, so the
    dataset itself must handle tensor indexing.

    The loader is its own iterator: ``iter(loader)`` resets it (drawing a new
    permutation when ``shuffle`` is true) and returns ``self``. Nested or
    concurrent iteration over the same loader is therefore not supported.

    Args:
        dataset: Dataset supporting ``len`` and indexing with a long tensor.
        batch_size: Samples per batch; ``0`` (or negative) means a single
            batch containing the whole dataset.
        shuffle: Whether to iterate in a random order, reshuffled per ``iter()``.
        seed: Seed of the private ``torch.Generator`` used for shuffling.
    """

    def __init__(self, dataset, batch_size: int = 0, shuffle: bool = False, seed: int = 2357):
        # DataLoader.__init__ is intentionally not called; none of its
        # worker/sampler machinery is used.
        self.dataset = dataset
        # Clamp to >= 1 so an empty dataset does not yield batch_size == 0,
        # which would make __len__ divide by zero.
        self.batch_size = batch_size if batch_size > 0 else max(len(dataset), 1)
        self.shuffle = shuffle

        self.rng = torch.Generator()
        self.rng.manual_seed(seed)

        # Prepare for the first iteration so next(loader) works without iter().
        self.__iter__()

    def __iter__(self) -> "InMemoryDataloader":
        """Reset the cursor (reshuffling if enabled) and return ``self``."""
        self._iterated = False
        self._last_index = 0
        if self.shuffle:
            self._indices = torch.randperm(len(self.dataset), generator=self.rng)
        else:
            self._indices = torch.arange(len(self.dataset))
        return self

    def __len__(self) -> int:
        """Number of batches per pass (the last batch may be smaller)."""
        # Integer ceil division avoids float rounding for very large datasets.
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

    def __next__(self):
        """Return the next batch, or raise ``StopIteration`` when exhausted."""
        n_samples = len(self.dataset)
        # The cursor check also covers an empty dataset, which must yield no batches.
        if self._iterated or self._last_index >= n_samples:
            self._iterated = True
            raise StopIteration

        from_index = self._last_index
        to_index = min(from_index + self.batch_size, n_samples)
        self._last_index = to_index
        if to_index >= n_samples:
            self._iterated = True

        return self.dataset[self._indices[from_index:to_index]]
