import random

import DataSet
import Network
import NetworkStrategy.AbstractNetworkStrategy as AbstractNetworkStrategy


class OnlyOneNetworkStrategy(AbstractNetworkStrategy.AbstractNetworkStrategy):
    """Strategy that fits a single network selected by random hyperparameter search.

    Candidate networks with randomly drawn layer count, neurons per layer and
    learning rate are trained; the candidate with the lowest validation loss is
    kept in ``self._network`` and used by ``test``, the plotting methods and the
    getters.
    """

    def __init__(self, dataLoader, plotter, config, svmParameters, parallelParameters, numPartitions):
        """Store the search settings and build the data set.

        Args:
            dataLoader: provides ``getTrainingData()``, ``getValidationData()``
                and ``getTestData()``, which are wrapped in a ``DataSet``.
            plotter: receives the data from ``plotSamples`` / ``plotPrediction``.
            config: configuration object; its ``attributes`` mapping is read for
                the ``'test'`` search settings, and the object itself is passed
                to ``DataSet``.
            svmParameters: added to ``parallelParameters`` to form the maximum
                number of parameters a candidate network may have.
            parallelParameters: see ``svmParameters``; a value of 0 disables the
                parameter budget entirely (unlimited).
            numPartitions: multiplies the number of sweeps and the upper bounds
                of the layer and neuron search ranges.
        """
        self._config            = config.attributes
        # parallelParameters == 0 means there is no budget to compare against.
        if parallelParameters == 0:
            self._allowedParameters = float('inf')
        else:
            self._allowedParameters = svmParameters + parallelParameters
        self._numPartitions     = numPartitions
        self._plotter           = plotter
        self._dataSet           = DataSet.DataSet(config, dataLoader.getTrainingData(), dataLoader.getValidationData(), dataLoader.getTestData())
        self._network           = None  # set by train()


    def train(self):
        """Random-search hyperparameters and keep the best network.

        Runs ``numPartitions * config['test']['sweeps']`` accepted sweeps. Each
        sweep draws:

        * layers in ``[layersBounds[0], numPartitions * layersBounds[1]]``,
        * neurons per layer in
          ``[neuronsPerLayerBounds[0], numPartitions * neuronsPerLayerBounds[1]]``,
        * a learning rate uniformly from ``learningRateBounds``.

        Candidates whose parameter count exceeds the allowed budget are
        discarded and do not count as a sweep. Every accepted candidate is
        trained and evaluated on the validation set. The candidate with the
        lowest validation loss is stored in ``self._network``; it stays
        ``None`` if no sweep is run.
        """
        testConfig      = self._config['test']
        layersBounds    = testConfig['layersBounds']
        neuronsBounds   = testConfig['neuronsPerLayerBounds']
        learningBounds  = testConfig['learningRateBounds']
        learningLow     = learningBounds[0]
        learningSpan    = learningBounds[1] - learningBounds[0]

        bestNetwork        = None
        bestValidationLoss = float('inf')

        sweep  = 0
        sweeps = self._numPartitions * testConfig['sweeps']

        # NOTE(review): rejected candidates do not advance `sweep`; if no sampled
        # architecture can fit within the parameter budget this loop never ends.
        while sweep < sweeps:
            layers       = random.randint(layersBounds[0],  self._numPartitions * layersBounds[1])
            neurons      = random.randint(neuronsBounds[0], self._numPartitions * neuronsBounds[1])
            learningRate = random.random() * learningSpan + learningLow

            network = Network.Network(self._config, layers, neurons, learningRate)

            if network.countParameters() > self._allowedParameters:
                continue

            network.train(self._dataSet)
            samplesX, samplesY = self._dataSet.getAllValidationItems()
            validationLoss = network.determineLoss(samplesX, samplesY, self._dataSet.unscaleY)

            # Track the best loss too; otherwise every candidate beats `inf`
            # and the last trained network would be kept instead of the best.
            if validationLoss < bestValidationLoss:
                bestNetwork        = network
                bestValidationLoss = validationLoss

            sweep += 1

        self._network = bestNetwork


    def test(self):
        """Return the selected network's loss on the test set.

        The loss is computed by ``Network.determineLoss`` with
        ``DataSet.unscaleY`` as the unscaling callback. Requires ``train()`` to
        have selected a network.
        """
        samplesX, samplesY = self._dataSet.getAllTestItems()
        return self._network.determineLoss(samplesX, samplesY, self._dataSet.unscaleY)


    def plotSamples(self):
        """Plot the unscaled test samples through the plotter."""
        torchSamplesX, torchSamplesY = self._dataSet.getAllTestItems()

        numpyUnscaledSamplesX    = self._dataSet.unscaleX(torchSamplesX).numpy()
        numpyUnscaledSamplesY    = self._dataSet.unscaleY(torchSamplesY).numpy()

        self._plotter.plotSamples(numpyUnscaledSamplesX, numpyUnscaledSamplesY, 'only one network', 'OnlyOneNetwork')


    def plotPrediction(self):
        """Plot the selected network's predictions on the test set against the unscaled samples."""
        torchSamplesX, torchSamplesY = self._dataSet.getAllTestItems()
        torchPredictionY = self._network.predict(torchSamplesX)

        numpyUnscaledSamplesX    = self._dataSet.unscaleX(torchSamplesX).numpy()
        numpyUnscaledSamplesY    = self._dataSet.unscaleY(torchSamplesY).numpy()
        # The prediction is detached before .numpy(); presumably it still
        # carries autograd history from predict() — confirm against Network.
        numpyUnscaledPredictionY = self._dataSet.unscaleY(torchPredictionY).detach().numpy()

        self._plotter.plotPrediction(numpyUnscaledSamplesX, numpyUnscaledSamplesY, numpyUnscaledPredictionY, 'only one network', 'OnlyOneNetwork')


    def getNumParameters(self):
        """Return the parameter count of the selected network."""
        return self._network.countParameters()

    def getNumAllowedParameters(self):
        """Return the parameter budget (``float('inf')`` when unlimited)."""
        return self._allowedParameters

    def getLearningRate(self):
        """Return the learning rate of the selected network."""
        return self._network.getLearningRate()

    def getLayers(self):
        """Return the layer value reported by the selected network."""
        return self._network.getLayers()

    def getNeurons(self):
        """Return the neuron value reported by the selected network."""
        return self._network.getNeurons()

    def getFirstWeights(self):
        """Return what ``Network.getFirstWeights`` reports for the selected network."""
        return self._network.getFirstWeights()
