﻿import torch


class DataSet(torch.utils.data.Dataset):
    """Map-style dataset that splits every sample row into an ``(x, y)`` pair.

    The full sample set is converted to a float32 tensor once, at
    construction.

    Columns before ``config['input']['xDimensions']`` form ``x``; the
    remaining columns form ``y``.

    The split point is looked up in ``config`` on every access rather than
    cached. Later edits to the config therefore affect subsequent reads.
    """

    def __init__(self, config, sampleSet):
        """Keep a reference to ``config`` and tensorise ``sampleSet``.

        Args:
            config: nested mapping; ``config['input']['xDimensions']`` is the
                column index where ``x`` ends and ``y`` begins.
            sampleSet: numpy array indexed as ``[row, column]``. It is
                presumably 2-D (samples x features); confirm with callers.
                It is cast to float32 under numpy's ``'same_kind'`` casting
                rule.
        """
        super().__init__()
        self._config = config
        # astype() returns a fresh array, so the tensor built by from_numpy
        # aliases that copy rather than the caller's array.
        asFloat32 = sampleSet.astype('float32', casting='same_kind')
        self._sampleSet = torch.from_numpy(asFloat32)

    def __len__(self):
        """Return the number of rows (the size of the first axis)."""
        rowCount = self._sampleSet.shape[0]
        return rowCount

    def _splitColumns(self, rows):
        """Return ``(x, y)`` for the selected ``rows``, split at xDimensions."""
        # Deliberately re-read on each call instead of caching in __init__.
        splitAt = self._config['input']['xDimensions']
        # Two-axis indexing keeps the column slice on axis 1 even when
        # ``rows`` is a list/tensor of indices or a slice.
        x = self._sampleSet[rows, :splitAt]
        y = self._sampleSet[rows, splitAt:]
        return x, y

    def __getitem__(self, idx):
        """Return ``(x, y)`` tensors for the row(s) selected by ``idx``."""
        return self._splitColumns(idx)

    def getAllItems(self):
        """Return ``(x, y)`` tensors covering every row of the sample set."""
        return self._splitColumns(slice(None))
