
import numpy as np
from torch.utils.data import Dataset


class FromFileWrapper(Dataset):
    """
    Wraps a dataset so that each item is produced by a loader callable instead
    of being held in memory.

    Each item of ``dataset`` is unpacked as positional arguments to
    ``load_data`` (presumably file paths or similar handles -- confirm against
    callers), and the loader's return value is what indexing returns.

    :param dataset: Sized, indexable collection whose items are iterables of
        arguments for ``load_data``.
    :param files_list: Stored as ``self.files_list``; not used by this class
        itself.
    :param load_data: Callable invoked as ``load_data(*dataset[item])``.
    """

    def __init__(self, dataset, files_list, load_data):
        self.dataset = dataset
        self.files_list = files_list
        self.load_data = load_data

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        """
        Load and return the data for index ``item``.

        :param item: Index into the wrapped dataset.
        :return: Whatever ``load_data(*dataset[item])`` returns.
        """
        args = self.dataset[item]
        # Return the loaded value directly: ``Dataset.__getitem__`` is abstract
        # in torch (it raises NotImplementedError), so delegating to super()
        # here would fail and discard the loaded data.
        return self.load_data(*args)


class WrapTorchVisionTransform(object):
    """
    Adapter that applies a torchvision-style transform to the ``'image'``
    entry of a sample dict, for use with the pytorch_wrapper.

    The ``'target'`` entry is passed through unchanged. Any other keys in the
    incoming sample are not copied into the returned dict.
    """

    def __init__(self, transform):
        # Callable applied to the image value of each sample.
        self.transform = transform

    def __call__(self, sample):
        """
        :param sample: Mapping with at least ``'image'`` and ``'target'`` keys.
        :return: New dict ``{'image': transform(image), 'target': target}``.
        """
        img = sample['image']
        tgt = sample['target']
        transformed = self.transform(img)
        result = {'image': transformed, 'target': tgt}
        return result


class TorchVisionWrapper(Dataset):
    """
    Adapts a torchvision-style dataset, whose items unpack into ``(X, y)``,
    to the nested-dict layout expected by the pytorch_wrapper.

    Each item is returned as ``{'inputs': {'X': X}, 'targets': {'y': y}}``.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        features, label = self.dataset[item]
        inputs = {'X': features}
        targets = {'y': label}
        return {'inputs': inputs, 'targets': targets}


class TorchVisionAutoEncoderWrapper(Dataset):
    """
    Adapts a torchvision-style dataset, whose items unpack into ``(X, y)``,
    for autoencoder training.

    The input ``X`` is exposed both as the model input and as a reconstruction
    target, alongside the original label ``y``:
    ``{'inputs': {'X': X}, 'targets': {'X': X, 'y': y}}``.
    The same ``X`` object is referenced in both places (no copy is made).
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        sample, label = self.dataset[item]
        targets = {'X': sample, 'y': label}
        inputs = {'X': sample}
        return {'inputs': inputs, 'targets': targets}


class TorchVisionGanWrapper(Dataset):
    """
    Adapts a torchvision-style dataset for training a generative model such
    as a GAN.

    The dataset's label is discarded. Each item is returned as
    ``{'inputs': {'X_val': X}, 'targets': {'y_hat': zeros, 'y_val': ones}}``,
    where ``zeros`` and ``ones`` are float32 NumPy vectors of length
    ``X.shape[0]``.

    NOTE(review): ``X`` is a single sample, so ``X.shape[0]`` is its first
    dimension (presumably the channel count for (C, H, W) images), not a
    batch size -- confirm this label length is intended.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        sample, _label = self.dataset[item]

        length = sample.shape[0]
        zeros_target = np.zeros(length, dtype=np.float32)
        ones_target = np.ones(length, dtype=np.float32)

        targets = {'y_hat': zeros_target, 'y_val': ones_target}
        return {'inputs': {'X_val': sample}, 'targets': targets}
