import csv
import numpy
import sklearn.model_selection


class DataLoader:
    """Load training/validation/test CSV files and draw a random subset of each.

    The configuration object must expose an ``attributes`` mapping with:

    * ``attributes['input']`` holding ``'trainingDataFile'``,
      ``'validationDataFile'`` and ``'testDataFile'`` (file names that are
      appended directly to ``inputDir``).
    * ``attributes['test']`` holding ``'totalSamples'`` and ``'sampleSplit'``
      (a sequence with at least three entries: training, validation, test).

    Constructing the loader reads the header of the training file, loads all
    three files and immediately subsamples them.
    """

    def __init__(self, config, inputDir):
        """Read the dimensions header, load the three data files and sample them.

        Args:
            config: Object whose ``attributes`` mapping is described in the
                class docstring. The mapping is stored by reference and is
                updated in place with ``xDimensions`` / ``yDimensions``.
            inputDir: Path prefix for the data files. File names are appended
                by plain string concatenation, so it must already end with a
                path separator.
        """
        self._config = config.attributes

        self.readDimensions(inputDir)

        self.loadData(inputDir)
        self.extractData()

    def readDimensions(self, inputDir):
        """Parse the first line of the training file into the configuration.

        The first two comma-separated fields of the header are expected to
        look like ``<name>=<int>``; the integers are stored as
        ``config['input']['xDimensions']`` and ``config['input']['yDimensions']``.
        """
        headerPath = inputDir + self._config['input']['trainingDataFile']
        # newline='' lets the csv module do its own line-ending handling.
        with open(headerPath, newline='') as inputFile:
            firstLine = next(csv.reader(inputFile))

        self._config['input']['xDimensions'] = int(firstLine[0].split('=')[1])
        self._config['input']['yDimensions'] = int(firstLine[1].split('=')[1])

    def loadData(self, inputDir):
        """Load the three CSV files (header row skipped) as 2-D float arrays."""
        self._trainingData = self._loadFile(inputDir, 'trainingDataFile')
        self._validationData = self._loadFile(inputDir, 'validationDataFile')
        self._testData = self._loadFile(inputDir, 'testDataFile')

    def extractData(self):
        """Replace each data set with a random subset of its rows.

        The subset size for split ``i`` is
        ``round(totalSamples * sampleSplit[i])``; rows are drawn without
        replacement using ``numpy.random.choice``.

        Raises:
            ValueError: If a split asks for more rows than its file provides.
        """
        totalSamples = self._config['test']['totalSamples']
        sampleSplit = self._config['test']['sampleSplit']

        # Order matters: keeping training -> validation -> test preserves the
        # sequence of draws from numpy's global random state.
        self._trainingData = self._sampleRows(
            self._trainingData, round(totalSamples * sampleSplit[0]), 'training')
        self._validationData = self._sampleRows(
            self._validationData, round(totalSamples * sampleSplit[1]), 'validation')
        self._testData = self._sampleRows(
            self._testData, round(totalSamples * sampleSplit[2]), 'test')

    def getTrainingData(self):
        """Return the sampled training rows."""
        return self._trainingData

    def getValidationData(self):
        """Return the sampled validation rows."""
        return self._validationData

    def getTestData(self):
        """Return the sampled test rows."""
        return self._testData

    def _loadFile(self, inputDir, configKey):
        """Load the file named by ``config['input'][configKey]`` under ``inputDir``."""
        path = inputDir + self._config['input'][configKey]
        # ndmin=2 keeps a file with a single data row as shape (1, columns);
        # without it loadtxt returns a 1-D array and row sampling would index
        # columns instead of rows.
        return numpy.loadtxt(fname=path, delimiter=',', skiprows=1, ndmin=2)

    @staticmethod
    def _sampleRows(data, count, label):
        """Return ``count`` distinct, randomly chosen rows of ``data``."""
        available = data.shape[0]
        if count > available:
            raise ValueError(
                'Cannot sample {} rows from the {} data: only {} rows available'
                .format(count, label, available))
        chosen = numpy.random.choice(a=available, size=count, replace=False)
        return data[chosen]
        