import numpy
import random

import DataSet
import Network
import NetworkStrategy.AbstractNetworkStrategy as AbstractNetworkStrategy
import NetworkStrategy.ParallelNetworks.AbstractDataSeparator as AbstractDataSeparator
import NetworkStrategy.ParallelNetworks.KnownBoundsDataSeparator as KnownBoundsDataSeparator
import NetworkStrategy.ParallelNetworks.PartitioningDataSeparator as PartitioningDataSeparator


class ParallelNetworksStrategy(AbstractNetworkStrategy.AbstractNetworkStrategy):
    """Network strategy that trains one independent network per data partition.

    The training, validation and test data are each split by a data separator
    (selected through ``config.attributes['test']['dataSeparationType']``).
    The i-th training, validation and test parts are bundled into one
    ``DataSet``; ``train`` then selects one network for every such data set.
    """

    def __init__(self, dataLoader, plotter, config, inputDir):
        """Separate the loaded data and build one ``DataSet`` per partition.

        Args:
            dataLoader: Provides ``getTrainingData()``, ``getValidationData()``
                and ``getTestData()``.
            plotter: Object used by ``plotSamples`` and ``plotPrediction``.
            config: Configuration object; its ``attributes`` mapping is read here.
            inputDir: Passed on to the partitioning data separator.

        Raises:
            ValueError: If the configured data separation type is not supported.
        """
        self._config  = config.attributes
        self._plotter = plotter

        separationType = self._config['test']['dataSeparationType']
        if separationType == AbstractDataSeparator.DataSeparatorType.KNOWN_BOUNDS:
            dataSeparator = KnownBoundsDataSeparator.KnownBoundsDataSeparator(config)
        elif separationType == AbstractDataSeparator.DataSeparatorType.PARTITIONING:
            dataSeparator = PartitioningDataSeparator.PartitioningDataSeparator(config, inputDir)
        else:
            # Fail here with a clear message instead of hitting an
            # AttributeError on None a few lines further down.
            raise ValueError('Unsupported dataSeparationType: ' + repr(separationType))

        separatedTrainingData   = dataSeparator.separateData(dataLoader.getTrainingData())
        separatedValidationData = dataSeparator.separateData(dataLoader.getValidationData())
        separatedTestData       = dataSeparator.separateData(dataLoader.getTestData())

        # zip() pairs the i-th training/validation/test parts; if the separator
        # returned sequences of different lengths, surplus parts are dropped.
        self._dataSets = [
            DataSet.DataSet(config, trainingData, validationData, testData)
            for trainingData, validationData, testData
            in zip(separatedTrainingData, separatedValidationData, separatedTestData)
        ]

        # Filled by train(): one selected network per data set, in the same order.
        self._networks = []


    def train(self):
        """Run a random hyperparameter search per data set and keep the best network.

        For every data set, ``sweeps`` networks are trained with randomly drawn
        layer count, neurons per layer and learning rate (all within the bounds
        given in the ``test`` section of the configuration). The network with
        the lowest validation loss is appended to ``self._networks``.
        """
        testConfig         = self._config['test']
        sweeps             = testConfig['sweeps']
        layersBounds       = testConfig['layersBounds']
        neuronsBounds      = testConfig['neuronsPerLayerBounds']
        learningRateBounds = testConfig['learningRateBounds']

        for dataSet in self._dataSets:

            # Stays None when sweeps is 0 or no sweep yields a loss below
            # infinity (e.g. every loss is NaN).
            bestNetwork        = None
            bestValidationLoss = float('inf')

            for _ in range(sweeps):

                # Draw order (layers, neurons, learning rate) is kept stable so
                # that seeded runs consume the random stream identically.
                layers       = random.randint(layersBounds[0],  layersBounds[1])
                neurons      = random.randint(neuronsBounds[0], neuronsBounds[1])
                learningRate = random.uniform(learningRateBounds[0], learningRateBounds[1])

                network = Network.Network(self._config, layers, neurons, learningRate)
                network.train(dataSet)
                samplesX, samplesY = dataSet.getAllValidationItems()

                # Without validation samples the candidates cannot be compared,
                # so the first trained network is accepted as it is.
                if samplesX.numel() == 0:
                    bestNetwork = network
                    break

                validationLoss = network.determineLoss(samplesX, samplesY, lambda scaledY: dataSet.unscaleY(scaledY))

                if validationLoss < bestValidationLoss:
                    bestNetwork        = network
                    # Remember the loss to beat; otherwise every later sweep
                    # would replace the best network regardless of its quality.
                    bestValidationLoss = validationLoss

            self._networks.append(bestNetwork)


    def test(self):
        """Return the test loss over all partitions, weighted by sample count.

        Partitions without test samples are skipped. If no partition has any
        test sample the loss is undefined and NaN is returned.
        """
        summedLen  = 0
        summedLoss = 0

        for network, dataSet in zip(self._networks, self._dataSets):
            samplesX, samplesY = dataSet.getAllTestItems()
            if samplesX.numel() == 0:
                continue
            loss       = network.determineLoss(samplesX, samplesY, lambda scaledY: dataSet.unscaleY(scaledY))
            numSamples = samplesX.shape[0]
            summedLen  = summedLen  + numSamples
            summedLoss = summedLoss + numSamples * loss

        if summedLen == 0:
            # Nothing was evaluated; avoid a ZeroDivisionError.
            return float('nan')

        return summedLoss / summedLen


    def plotSamples(self):
        """Plot the unscaled test samples of every partition in its own plot."""
        for number, dataSet in enumerate(self._dataSets, start=1):
            torchSamplesX, torchSamplesY = dataSet.getAllTestItems()

            numpyUnscaledSamplesX = dataSet.unscaleX(torchSamplesX).numpy()
            numpyUnscaledSamplesY = dataSet.unscaleY(torchSamplesY).numpy()

            self._plotter.plotSamples(numpyUnscaledSamplesX, numpyUnscaledSamplesY, 'parallel network no ' + str(number), 'ParallelNetworkNo' + str(number))


    def plotPrediction(self):
        """Plot test samples and predictions of all partitions in one combined plot."""
        # Each list starts with an empty float64 array of the configured width,
        # so the result keeps its shape and dtype even when there are no
        # partitions at all.
        samplesXParts    = [numpy.zeros((0, self._config['input']['xDimensions']))]
        samplesYParts    = [numpy.zeros((0, self._config['input']['yDimensions']))]
        predictionYParts = [numpy.zeros((0, self._config['input']['yDimensions']))]

        for network, dataSet in zip(self._networks, self._dataSets):
            torchSamplesX, torchSamplesY = dataSet.getAllTestItems()
            torchPredictionY = network.predict(torchSamplesX)

            samplesXParts.append(dataSet.unscaleX(torchSamplesX).numpy())
            samplesYParts.append(dataSet.unscaleY(torchSamplesY).numpy())
            predictionYParts.append(dataSet.unscaleY(torchPredictionY).detach().numpy())

        # Concatenate once instead of growing the arrays inside the loop, which
        # would copy all previously collected rows on every iteration.
        numpyUnscaledSamplesX    = numpy.concatenate(samplesXParts,    axis=0)
        numpyUnscaledSamplesY    = numpy.concatenate(samplesYParts,    axis=0)
        numpyUnscaledPredictionY = numpy.concatenate(predictionYParts, axis=0)

        self._plotter.plotPrediction(numpyUnscaledSamplesX, numpyUnscaledSamplesY, numpyUnscaledPredictionY, 'parallel networks', 'ParallelNetworks')


    def getNumParameters(self):
        """Return a list with the parameter count of each selected network."""
        return [network.countParameters() for network in self._networks]


    def getNumPartitions(self):
        """Return the number of data sets (partitions)."""
        return len(self._dataSets)


    def getLearningRates(self):
        """Return a list with the learning rate of each selected network."""
        return [network.getLearningRate() for network in self._networks]


    def getLayers(self):
        """Return a list with the ``getLayers()`` result of each selected network."""
        return [network.getLayers() for network in self._networks]


    def getNeurons(self):
        """Return a list with the ``getNeurons()`` result of each selected network."""
        return [network.getNeurons() for network in self._networks]


    def getFirstWeights(self):
        """Return a list with the ``getFirstWeights()`` result of each selected network."""
        return [network.getFirstWeights() for network in self._networks]