﻿import numpy
import sklearn

import Networks.Network
import Database.DataSet


class NetworkManager:
    """Own a pool of networks keyed by integer id and manage their lifecycle.

    Every net is a ``Networks.Network.Network`` built from the same config.
    Besides training and prediction over the whole pool, the manager can
    periodically drop nets (``dropNets``) and try to add a freshly trained net
    (``addNet``); both steps are gated by settings in
    ``config.attributes['test']``.
    """

    # Minimum number of high-loss samples needed before trying to add a net.
    _MIN_NUM_BAD_SAMPLES = 5
    # Fraction of the high-loss samples held out to validate a candidate net.
    _VALIDATION_FRACTION = 0.2

    def __init__(self, config):
        """Create ``initialNumOfNets`` nets with ids ``0 .. n-1``.

        Args:
            config: object whose ``attributes`` mapping holds the settings; the
                mapping is also passed to every ``Network`` constructor.
        """
        self._config = config.attributes

        numInitialNets = self._config['test']['initialNumOfNets']
        self._nets = {
            netId: Networks.Network.Network(self._config)
            for netId in range(numInitialNets)
        }

        # Ids are handed out sequentially. Deriving the last one from the pool
        # size (rather than from a leaked loop variable) also works when no net
        # is created: it is then -1 and the next added net gets id 0.
        self._highestAssignedNetId = len(self._nets) - 1

    # --- Basic functionality ---

    def getActiveNetIds(self):
        """Return the ids of the nets currently in the pool, as a new list."""
        # A copy, so callers may add/remove nets while iterating over it.
        return list(self._nets.keys())

    def getNetLayers(self, netId):
        """Return ``getLayers()`` of net *netId* (``KeyError`` if unknown)."""
        return self._nets[netId].getLayers()

    def getNetNeuronsPerLayer(self, netId):
        """Return ``getNeuronsPerLayer()`` of net *netId* (``KeyError`` if unknown)."""
        return self._nets[netId].getNeuronsPerLayer()

    def getNetLearningRate(self, netId):
        """Return ``getLearningRate()`` of net *netId* (``KeyError`` if unknown)."""
        return self._nets[netId].getLearningRate()

    def train(self, database):
        """Train every net for one epoch on the data set *database* maps to it."""
        for netId, net in self._nets.items():
            net.trainForOneEpoch(database.getDataSetMappedToNet(netId))

    def predict(self, database):
        """Run every net on all training samples of *database*.

        Returns:
            dict mapping each net id to whatever that net's ``predict`` returns.
        """
        # The input is identical for every net, so fetch it only once.
        samples = database.getAllTrainingSamples()
        return {netId: net.predict(samples) for netId, net in self._nets.items()}

    # --- Basic functionality ---

    # --- Dropping and Adding ---

    def dropNets(self, database, zeroBasedEpoch):
        """Drop nets whose replacability is below ``droppingReplacability``.

        Runs only when ``droppingActive`` is set, *zeroBasedEpoch* has reached
        ``startDroppingAdding`` and is a multiple of ``droppingInterv``. The
        last remaining net is never dropped.

        Returns:
            dict with ``active`` (whether the step ran), ``replacabilities``
            (one entry per inspected net, in inspection order) and
            ``droppedNets`` (number of nets actually removed).
        """
        stats = {'active': False, 'replacabilities': [], 'droppedNets': 0}

        if not self._isScheduledEpoch(zeroBasedEpoch, 'droppingActive', 'droppingInterv'):
            return stats

        # Iterate over a snapshot of the ids because nets are removed in the loop.
        for netId in self.getActiveNetIds():
            replacability = database.getReplacabilityOfNet(netId)
            # removeOneNet may refuse (last net), so only count real removals.
            if (replacability < self._config['test']['droppingReplacability']
                    and self.removeOneNet(database, netId)):
                stats['droppedNets'] += 1
            stats['replacabilities'].append(replacability)

        stats['active'] = True
        return stats

    def addNet(self, database, zeroBasedEpoch):
        """Try to add a new net trained on the samples with the highest loss.

        Runs only when ``addingActive`` is set, *zeroBasedEpoch* has reached
        ``startDroppingAdding`` and is a multiple of ``addingInterv``. The
        samples selected by ``database.getTrainingSamplesAboveLossBound`` with
        bound ``mean(loss) + std(loss)`` are split into training/validation
        sets, a new net is trained on them via ``trainUntilOverfitting``, and
        it is kept if ``oldLoss / newLoss`` exceeds ``addingImprovement``.

        Returns:
            dict with ``active`` (whether the step ran), ``trainedEpochs``,
            ``improvement`` and ``addedNet``.
        """
        stats = {'active': False, 'trainedEpochs': 0, 'improvement': 0, 'addedNet': False}

        if not self._isScheduledEpoch(zeroBasedEpoch, 'addingActive', 'addingInterv'):
            return stats

        # Define bad samples as those whose loss lies above mean + stddev.
        losses = database.getLoss()
        lossBound = numpy.mean(losses) + numpy.std(losses)
        samplesX, samplesY = database.getTrainingSamplesAboveLossBound(lossBound)

        # Too few bad samples to train and validate a new net: skip adding.
        if samplesX.shape[0] < self._MIN_NUM_BAD_SAMPLES:
            stats['active'] = True
            return stats

        # Imported explicitly: ``import sklearn`` alone does not guarantee that
        # the ``model_selection`` submodule has been loaded.
        from sklearn.model_selection import train_test_split
        trainX, valX, trainY, valY = train_test_split(
            samplesX, samplesY, test_size=self._VALIDATION_FRACTION)
        trainDataSet = self._makeDataSet(trainX, trainY)
        valDataSet = self._makeDataSet(valX, valY)

        # Catch-up training of one new net on the bad samples only.
        newNet = Networks.Network.Network(self._config)
        stats['trainedEpochs'], newLoss = newNet.trainUntilOverfitting(
            trainDataSet, valDataSet, zeroBasedEpoch)

        # Keep the new net only if it improves the loss on those samples enough.
        oldLoss = numpy.mean(database.getLossAboveLossBound(lossBound))
        # A perfect fit (zero loss) would otherwise divide by zero.
        improvement = numpy.inf if newLoss == 0 else oldLoss / newLoss
        if improvement > self._config['test']['addingImprovement']:
            self.addOneNet(newNet, database)
            stats['addedNet'] = True

        stats['active'] = True
        stats['improvement'] = improvement
        return stats

    def removeOneNet(self, database, netId):
        """Remove net *netId* from the pool and from *database*.

        The last remaining net is never removed.

        Returns:
            True if the net was removed, False if it was kept.
        """
        # We want to keep at least one net.
        if len(self._nets) <= 1:
            return False
        database.removeNet(netId)
        del self._nets[netId]
        return True

    def addOneNet(self, net, database):
        """Register *net* in the pool under the next unused id.

        *database* is not used here.
        """
        # NOTE(review): unlike removeOneNet, the database is not notified of
        # the new net — confirm this is intended.
        self._highestAssignedNetId += 1
        self._nets[self._highestAssignedNetId] = net

    def _isScheduledEpoch(self, zeroBasedEpoch, activeKey, intervalKey):
        """Return whether the step gated by *activeKey*/*intervalKey* runs now."""
        testConfig = self._config['test']
        # Kept as ``!= True``: any value equal to True (e.g. 1) enables the step.
        if testConfig[activeKey] != True:  # noqa: E712
            return False
        if zeroBasedEpoch < testConfig['startDroppingAdding']:
            return False
        return zeroBasedEpoch % testConfig[intervalKey] == 0

    def _makeDataSet(self, samplesX, samplesY):
        """Join inputs and targets column-wise into a ``Database.DataSet.DataSet``."""
        return Database.DataSet.DataSet(
            self._config, numpy.concatenate((samplesX, samplesY), axis=1))

    # --- Dropping and Adding ---
