import numpy as np
import torch
from torch.utils.data import Dataset


class EnvDataset(Dataset):
    """Map-style dataset pairing input rows with target rows.

    When neither ``inputs`` nor ``targets`` is given, the dataset starts empty
    with placeholder tensors of shape ``(0, input_dim)`` and
    ``(0, output_dim)`` allocated on ``device``. Otherwise the supplied
    containers are stored as-is; they are NOT moved to ``device``.

    Args:
        input_dim: Feature width of the empty ``inputs`` placeholder.
        output_dim: Feature width of the empty ``targets`` placeholder.
        device: Device on which the empty placeholders are allocated.
        inputs: Indexable container with a ``shape`` attribute whose first
            dimension is the sample count (e.g. a tensor). Must be given
            together with ``targets``.
        targets: Indexable container aligned with ``inputs`` along the first
            dimension. Must be given together with ``inputs``.
        transform: Optional callable applied to each input sample (not to
            the target) in ``__getitem__``.

    Raises:
        ValueError: If exactly one of ``inputs`` / ``targets`` is provided.
    """

    def __init__(self, input_dim, output_dim, device, inputs=None, targets=None, transform=None):
        if inputs is None and targets is None:
            # Allocate directly on the requested device rather than creating
            # on the default device and copying with ``.to(device)``.
            self.inputs = torch.empty((0, input_dim), device=device)
            self.targets = torch.empty((0, output_dim), device=device)
        elif inputs is None or targets is None:
            # A half-specified dataset used to be accepted and only failed
            # later (``None`` is not subscriptable / has no ``shape``).
            raise ValueError(
                "inputs and targets must be provided together (or both omitted)"
            )
        else:
            self.inputs = inputs
            self.targets = targets
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair at ``index``.

        The configured ``transform`` (if any) is applied to the input only.
        """
        x = self.inputs[index]
        # Compare against None so a callable that happens to be falsy
        # (e.g. one defining ``__len__`` returning 0) is still applied.
        if self.transform is not None:
            x = self.transform(x)
        y = self.targets[index]
        return x, y

    def __len__(self):
        """Return the number of samples (first dimension of ``inputs``)."""
        return self.inputs.shape[0]