import numpy as np
import torch
from torch.utils.data import TensorDataset


class LatentDataset(TensorDataset):
    """TensorDataset built from the ``'x'`` and ``'y'`` entries of a NumPy file.

    Each item is a pair ``(x, y)`` where ``x`` is a float tensor of shape
    ``(512, 8, 8)`` and ``y`` is the matching entry of ``'y'`` as a long
    tensor with one extra leading dimension (shape ``(1,)`` for scalar
    labels).
    """

    def __init__(self, path):
        """Load ``path`` with ``np.load`` and wrap its contents as tensors.

        Args:
            path: File accepted by ``np.load`` whose result can be indexed
                with ``'x'`` and ``'y'``. Every row of ``'x'`` must hold
                exactly ``512 * 8 * 8`` values.

        Raises:
            RuntimeError: If a row of ``'x'`` cannot be laid out as
                ``(512, 8, 8)``.
        """
        archive = np.load(path)
        try:
            # Materialize both arrays while the file is still open.
            latents = np.asarray(archive['x'])
            labels = np.asarray(archive['y'])
        finally:
            # An .npz load returns an NpzFile that owns an open file handle;
            # a plain ndarray result has no close(), hence the lookup.
            close = getattr(archive, 'close', None)
            if close is not None:
                close()

        # Convert whole arrays at once instead of one tensor per sample.
        # Dimension 0 is pinned to the sample count (not -1) so a row with
        # the wrong number of values still fails, and an empty file yields an
        # empty dataset. reshape also copes with non-contiguous input.
        x = torch.as_tensor(latents, dtype=torch.float).reshape(
            latents.shape[0], 512, 8, 8
        )
        # unsqueeze(1) reproduces the per-sample wrapping of each label in a
        # one-element list: (N,) -> (N, 1), (N, k) -> (N, 1, k).
        y = torch.as_tensor(labels, dtype=torch.long).unsqueeze(1)
        super().__init__(x, y)
