import Config
import Plotter
import DirFinder
import DataLoader
import ConfigReader
import SavingEngine
import SVMParameterReader
import NetworkStrategy.OnlyOneNetwork.OnlyOneNetworkStrategy as OnlyOneNetworkStrategy
import NetworkStrategy.ParallelNetworks.AbstractDataSeparator as AbstractDataSeparator
import NetworkStrategy.ParallelNetworks.ParallelNetworksStrategy as ParallelNetworksStrategy


def _runStrategy(strategy):
    """Train, evaluate and plot a single network strategy.

    Calls ``train()``, ``test()``, ``plotSamples()`` and ``plotPrediction()``
    on *strategy*, in that order.

    Returns:
        Whatever ``strategy.test()`` returned (``main`` treats it as the loss).
    """
    strategy.train()
    testLoss = strategy.test()
    strategy.plotSamples()
    strategy.plotPrediction()
    return testLoss


def main():
    """Compare parallel networks against a single network and save results.

    Pipeline:
      1. Build the default config and let ``ConfigReader`` update it from the
         directory reported by ``dirFinder.getDir()``.
      2. If ``dirFinder.checkForSVM(config)`` is truthy, run the parallel
         networks strategy and collect its per-network hyperparameters.
      3. Run the single-network strategy, passing it the summed parameter
         count of the parallel networks (0 when step 2 was skipped) and the
         number of partitions (1 when step 2 was skipped).
      4. Save config, hyperparameters, weights and losses via
         ``SavingEngine``; in every saved list the single network comes first.
    """
    config = configurate()
    dirFinder = DirFinder.DirFinder(config)
    ConfigReader.ConfigReader(dirFinder.getDir()).readConfig(config)

    # NOTE(review): getDir() is queried once per consumer. If it is a pure
    # getter it could be cached in a local; confirm in DirFinder first.
    dataLoader = DataLoader.DataLoader(config, dirFinder.getDir())
    svmParameterReader = SVMParameterReader.SVMParameterReader(dirFinder.getDir())
    plotter = Plotter.Plotter(config, dirFinder.getDir())
    svmParameters = svmParameterReader.readSVMParameters()

    # Fallbacks for a run without SVM partitioning: no parallel networks
    # contribute hyperparameters and the data counts as a single partition.
    parallelNetworksLoss = None
    parallelParams = []
    numPartitions = 1
    parallelLearningRates = []
    parallelLayers = []
    parallelNeurons = []
    parallelWeights = []

    partitioningAvailable = dirFinder.checkForSVM(config)
    if partitioningAvailable:
        parallelStrategy = ParallelNetworksStrategy.ParallelNetworksStrategy(
            dataLoader, plotter, config, dirFinder.getDir())
        parallelNetworksLoss = _runStrategy(parallelStrategy)
        parallelParams = parallelStrategy.getNumParameters()
        numPartitions = parallelStrategy.getNumPartitions()
        parallelLearningRates = parallelStrategy.getLearningRates()
        parallelLayers = parallelStrategy.getLayers()
        parallelNeurons = parallelStrategy.getNeurons()
        parallelWeights = parallelStrategy.getFirstWeights()

    # The single network receives the parallel ensemble's total parameter
    # count and partition count (presumably to size itself comparably --
    # confirm in OnlyOneNetworkStrategy).
    onlyOneStrategy = OnlyOneNetworkStrategy.OnlyOneNetworkStrategy(
        dataLoader, plotter, config, svmParameters,
        sum(parallelParams), numPartitions)
    onlyOneNetworkLoss = _runStrategy(onlyOneStrategy)
    oneParams = onlyOneStrategy.getNumParameters()
    oneAllowedParams = onlyOneStrategy.getNumAllowedParameters()
    oneLearningRates = onlyOneStrategy.getLearningRate()
    oneLayers = onlyOneStrategy.getLayers()
    oneNeurons = onlyOneStrategy.getNeurons()
    oneWeights = onlyOneStrategy.getFirstWeights()

    savingEngine = SavingEngine.SavingEngine(dirFinder.getDir())
    savingEngine.saveConfig(config)
    # `[single] + parallel` puts the single network first; the parallel lists
    # stay empty when partitioning was unavailable.
    savingEngine.saveHyperparameters(
        [oneLearningRates] + parallelLearningRates,
        [oneLayers] + parallelLayers,
        [oneNeurons] + parallelNeurons,
        [oneParams] + parallelParams,
        oneAllowedParams,
    )
    savingEngine.saveWeights([oneWeights] + parallelWeights)

    losses = [onlyOneNetworkLoss]
    lossLabels = ['Only one network']
    if partitioningAvailable:
        losses.append(parallelNetworksLoss)
        lossLabels.append('Parallel networks')
    savingEngine.saveLosses(losses, lossLabels)


def configurate():
    """Return the default ``Config`` object for a pipeline run.

    All values below are hard-coded defaults. ``main`` subsequently hands the
    returned object to ``ConfigReader.readConfig``, which may update it.
    """
    settings = dict(
        pipelineDir='../pipeline/',
        trainingDataFile='training_data.csv',
        validationDataFile='validation_data.csv',
        testDataFile='test_data.csv',
        partitioningFile='svm.pkl',
        partitioningSettingFile='svmSettings.pkl',
        dataSeparationKnownBounds=None,
        dataSeparationType=AbstractDataSeparator.DataSeparatorType.PARTITIONING,
        totalSamples=5000,
        # presumably the training/validation/test fractions, matching the
        # order of the three data files above -- confirm in Config/DataLoader
        sampleSplit=[0.7, 0.15, 0.15],
        epochs=500,
        sweeps=100,
    )
    # Keyword order is preserved by the dict, so Config receives exactly the
    # same keyword arguments in the same order as a direct call would pass.
    return Config.Config(**settings)


# Run the pipeline only when this file is executed directly, not on import.
if __name__ == '__main__':
    main()
