import copy
import torch

import ScalingEngine


class DataSet(torch.utils.data.Dataset):
    """Scaled training/validation/test data exposed as a torch ``Dataset``.

    Each data matrix is split column-wise. The first ``xDimensions`` columns
    (``config.attributes['input']['xDimensions']``) are the inputs X, and the
    remaining columns are the targets Y. X and Y each get their own
    ``ScalingEngine``. It is fitted on the training data only and then applied
    to the validation and test data. All three sets are stored as float32
    tensors. Indexing the dataset yields ``(x, y)`` pairs from the training set.
    """

    def __init__(self, config, trainingData, validationData, testData):
        """Copy, scale and store the three data sets.

        Args:
            config: object whose ``attributes`` mapping provides
                ``['input']['xDimensions']`` and
                ``['test']['featureRangeX']`` / ``['test']['featureRangeY']``
                (the latter two are passed to ``ScalingEngine``).
            trainingData, validationData, testData: 2-D numpy arrays with the
                X columns followed by the Y columns. They are deep-copied, so
                the caller's arrays are never modified.
        """
        super().__init__()
        self._config = config.attributes
        self._scalingEngineX = None
        self._scalingEngineY = None
        self._trainingData   = None
        self._validationData = None
        self._testData       = None

        self.initDataSet(copy.deepcopy(trainingData), copy.deepcopy(validationData), copy.deepcopy(testData))

    @property
    def _xDimensions(self):
        """Number of leading columns holding X (read from the config on each access)."""
        return self._config['input']['xDimensions']

    @staticmethod
    def _asFloatArray(data):
        """Return *data* with a dtype that can hold scaled values.

        Integer/boolean arrays would silently truncate the scaled values on the
        in-place slice assignment in ``initDataSet``, so they are upcast to
        float64. Any other array is returned unchanged, which keeps the
        original precision.
        """
        if data.dtype.kind in 'biu':
            return data.astype('float64')
        return data

    @staticmethod
    def _toTensor(data):
        """Convert a numpy array to a float32 torch tensor."""
        return torch.from_numpy(data.astype('float32', casting='same_kind'))

    def _splitXY(self, tensor, idx=slice(None)):
        """Return ``(X, Y)`` column slices of ``tensor[idx]``."""
        xDim = self._xDimensions
        return tensor[idx, :xDim], tensor[idx, xDim:]

    @staticmethod
    def _unscaleWith(engine, scaled):
        """Apply ``engine.unscale`` to a copy of *scaled* and return a float32 CPU tensor.

        The tensor is detached and moved to the CPU first, because ``.numpy()``
        refuses tensors that require grad or live on another device. The array
        is copied so the caller's tensor cannot be modified, in case
        ``unscale`` works in place.
        """
        numpyScaled = scaled.detach().cpu().numpy().copy()
        numpyUnscaled = engine.unscale(numpyScaled)
        return torch.from_numpy(numpyUnscaled.astype('float32', casting='same_kind'))

    def initDataSet(self, trainingData, validationData, testData):
        """Fit the X/Y scalers on *trainingData*, scale all sets and store them as tensors.

        Float arrays are scaled in place (``__init__`` passes copies). Integer
        or boolean arrays are first converted to float64 so the scaled values
        are not truncated.
        """
        xDim = self._xDimensions

        # Scale data: fit on the training set only, then reuse that fit for the others.
        self._scalingEngineX = ScalingEngine.ScalingEngine(self._config['test']['featureRangeX'])
        self._scalingEngineY = ScalingEngine.ScalingEngine(self._config['test']['featureRangeY'])
        trainingData   = self._asFloatArray(trainingData)
        validationData = self._asFloatArray(validationData)
        testData       = self._asFloatArray(testData)

        trainingData[:, :xDim] = self._scalingEngineX.fitAndScale(trainingData[:, :xDim])
        trainingData[:, xDim:] = self._scalingEngineY.fitAndScale(trainingData[:, xDim:])
        for data in (validationData, testData):
            data[:, :xDim] = self._scalingEngineX.scale(data[:, :xDim])
            data[:, xDim:] = self._scalingEngineY.scale(data[:, xDim:])

        # Convert to torch tensors
        self._trainingData   = self._toTensor(trainingData)
        self._validationData = self._toTensor(validationData)
        self._testData       = self._toTensor(testData)

    def __len__(self):
        """Number of training samples (rows)."""
        return self._trainingData.shape[0]

    def __getitem__(self, idx):
        """Return the ``(x, y)`` training sample(s) selected by *idx*."""
        return self._splitXY(self._trainingData, idx)

    def getAllTrainingItems(self):
        """Return ``(X, Y)`` for the whole (scaled) training set."""
        return self._splitXY(self._trainingData)

    def getAllValidationItems(self):
        """Return ``(X, Y)`` for the whole (scaled) validation set."""
        return self._splitXY(self._validationData)

    def getAllTestItems(self):
        """Return ``(X, Y)`` for the whole (scaled) test set."""
        return self._splitXY(self._testData)

    def unscaleX(self, scaledX):
        """Invert the X scaling of *scaledX* and return a new float32 CPU tensor.

        The input tensor is not modified. It may require grad or live on
        another device.
        """
        return self._unscaleWith(self._scalingEngineX, scaledX)

    def unscaleY(self, scaledY):
        """Invert the Y scaling of *scaledY* (e.g. model predictions) and return a new float32 CPU tensor.

        The input tensor is not modified. It may require grad or live on
        another device.
        """
        return self._unscaleWith(self._scalingEngineY, scaledY)

