﻿import Config
import Database.Database
import Networks.NetworkManager
import Evaluation.Evaluation
import Evaluation.EvaluationPlotter
import Evaluation.SamplePlotter
import Evaluation.SavingEngine
import Evaluation.OutputDirCreator
import Evaluation.SVMPlotter
import SVM.SVM


def main():
    """Run the complete train / evolve / evaluate pipeline.

    Steps:
      1. Build the configuration, database, network manager, evaluation
         recorder, output-directory creator and sample plotter.
      2. For each of ``config.attributes['test']['epochs']`` epochs: predict
         with the active nets, submit predictions/losses to the database,
         plot the mapping, train, then drop/add nets and record the active
         net ids plus the dropping/adding statistics in ``evaluation``.
      3. Run one final predict/submit/plot pass, classify and export with
         the SVM, plot the evaluation and SVM results, and save training
         samples, other samples, SVM parameters and per-net settings to the
         general output directory.
    """
    config = configurate()
    database = Database.Database.DataBase(config)
    networkManager = Networks.NetworkManager.NetworkManager(config)
    evaluation = Evaluation.Evaluation.Evaluation()
    outputDirCreator = Evaluation.OutputDirCreator.OutputDirCreator(config)
    samplePlotter = Evaluation.SamplePlotter.SamplePlotter(
        database, config, outputDirCreator.getMappingOutputDir())

    numEpochs = config.attributes['test']['epochs']
    for epoch in range(numEpochs):
        _predictAndPlot(networkManager, database, samplePlotter, epoch)

        networkManager.train(database)

        dropStats = networkManager.dropNets(database, epoch)
        addStats = networkManager.addNet(database, epoch)

        # Re-read the ids after dropNets/addNet, which presumably change the
        # active set.
        evaluation.submitActiveNetIds(networkManager.getActiveNetIds())
        evaluation.submitDroppingStats(dropStats)
        evaluation.submitAddingStats(addStats)

    # Final pass after training. Labelling it with the epoch count makes it
    # directly follow the last in-loop plot (numEpochs - 1) and keeps it at 0
    # when no epochs ran.
    activeNetIds = _predictAndPlot(networkManager, database, samplePlotter, numEpochs)

    generalOutputDir = outputDirCreator.getGeneralOutputDir()

    svm = SVM.SVM.SVM(database, generalOutputDir, config)
    svm.classify()
    svm.export()

    Evaluation.EvaluationPlotter.EvaluationPlotter(evaluation, generalOutputDir, config).plotEvaluation()
    Evaluation.SVMPlotter.SVMPlotter(config, database, generalOutputDir).plotSVM(svm, activeNetIds)

    savingEngine = Evaluation.SavingEngine.SavingEngine(config, database, generalOutputDir)
    savingEngine.saveTrainingSamples()
    savingEngine.copyOtherSamples()
    # (sic) 'getParamters' is kept as called; presumably the SVM class's
    # actual method name -- confirm before renaming.
    savingEngine.saveSVMParameters(svm.getParamters())
    # Each keyword must be paired with the getter for the same quantity.
    savingEngine.saveConfig(
        netIds=activeNetIds,
        learningRates=[networkManager.getNetLearningRate(netId) for netId in activeNetIds],
        layers=[networkManager.getNetLayers(netId) for netId in activeNetIds],
        neuronsPerLayer=[networkManager.getNetNeuronsPerLayer(netId) for netId in activeNetIds],
    )


def _predictAndPlot(networkManager, database, samplePlotter, epoch):
    """Predict with the active nets, store the results and plot the mapping.

    Args:
        networkManager: Manager whose active nets produce the predictions.
        database: Receives the predictions/losses and is passed to the plotter.
        samplePlotter: Plotter used to draw the mapping for ``epoch``.
        epoch: Epoch label passed to ``plotMapping``.

    Returns:
        The active net ids (read after prediction) used for this pass.
    """
    predictionsAndLosses = networkManager.predict(database)
    activeNetIds = networkManager.getActiveNetIds()
    database.submitPredictionsAndLosses(predictionsAndLosses, activeNetIds)
    samplePlotter.plotMapping(database, activeNetIds, epoch)
    return activeNetIds


def configurate():
    """Build the experiment configuration from fixed settings.

    Returns:
        A ``Config.Config`` instance created from the keyword settings below
        (input/output locations, sample count, the various ``*Bounds``
        ranges, training length/batch size, and the dropping/adding options).
    """
    settings = {
        # Input / output locations.
        'inputDir': '../input/',
        'outputDir': '../output/',
        'inputSamplesFile': 'training_data.csv',
        'samples': 5000,
        # Presumably [lower, upper] ranges -- confirm against Config.
        'layersBounds': [2, 5],
        'neuronsPerLayerBounds': [3, 8],
        'learningRateBounds': [0.0001, 0.005],
        'batchSize': 16,
        'epochs': 1000,
        'initialNumOfNets': 10,
        'featureRangeX': (-1, 1),
        'featureRangeY': (-1, 1),
        # Dropping / adding of nets.
        'startDroppingAdding': 30,
        'droppingActive': True,
        'droppingInterv': 1,
        'droppingReplacability': 1.8,
        'addingActive': True,
        'addingInterv': 1,
        'addingCatchUpTrainingPatience': 20,
        'addingImprovement': 1,
    }
    return Config.Config(**settings)


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()