﻿import torch
import numpy
import random

import Networks.EarlyStopper


class Network(torch.nn.Module):
    """Fully connected regression network with randomly sampled hyperparameters.

    The number of Linear layers, the width of the hidden layers and the Adam
    learning rate are drawn from the bounds given in the config. Every Linear
    layer except the last is followed by a Tanh activation.

    Config keys read by this class:
        network.layersBounds               -- [lower, upper] inclusive number of Linear layers (lower >= 1)
        network.neuronsPerLayerBounds      -- [lower, upper] inclusive width of the hidden layers
        network.learningRateBounds         -- [lower, upper] range for the Adam learning rate
        input.xDimensions                  -- number of input features
        input.yDimensions                  -- number of output values
        test.batchSize                     -- mini-batch size used in trainForOneEpoch
        test.addingCatchUpTrainingPatience -- patience handed to the EarlyStopper
    """

    def __init__(self, config: dict) -> None:
        """Sample the hyperparameters and build the model, losses and optimizer.

        Args:
            config: nested configuration dictionary (see class docstring).

        Raises:
            ValueError: if the lower bound of ``network.layersBounds`` is below 1
                (``random.randint`` also raises ValueError if lower > upper).
        """
        super().__init__()

        self._config = config
        networkConfig = self._config['network']

        # Choose random number of Linear layers from the given inclusive range
        lowerLayersBound = networkConfig['layersBounds'][0]
        upperLayersBound = networkConfig['layersBounds'][1]
        if lowerLayersBound < 1:
            # A network needs at least one Linear layer; without this check a count <= 0
            # would silently build a two-layer model that disagrees with getLayers().
            raise ValueError(f"network.layersBounds lower bound must be >= 1, got {lowerLayersBound}")
        self._layers = random.randint(lowerLayersBound, upperLayersBound)

        # Choose random number of neurons per hidden layer from the given inclusive range
        lowerNeuronsPerLayerBound = networkConfig['neuronsPerLayerBounds'][0]
        upperNeuronsPerLayerBound = networkConfig['neuronsPerLayerBounds'][1]
        self._neuronsPerLayer = random.randint(lowerNeuronsPerLayerBound, upperNeuronsPerLayerBound)

        # Choose random learning rate uniformly from the given range
        lowerLearningRateBound = networkConfig['learningRateBounds'][0]
        upperLearningRateBound = networkConfig['learningRateBounds'][1]
        self._learningRate = random.uniform(lowerLearningRateBound, upperLearningRateBound)

        # Layer widths from input to output, e.g. 3 layers -> [x, n, n, y], 1 layer -> [x, y]
        widths = ([self._config['input']['xDimensions']]
                  + [self._neuronsPerLayer] * (self._layers - 1)
                  + [self._config['input']['yDimensions']])

        self._model = torch.nn.Sequential()
        for index, (inFeatures, outFeatures) in enumerate(zip(widths[:-1], widths[1:])):
            # Tanh between consecutive Linear layers; the output layer stays linear
            if index > 0:
                self._model.append(torch.nn.Tanh())
            self._model.append(torch.nn.Linear(inFeatures, outFeatures))

        self._optimizationLoss   = torch.nn.MSELoss(reduction='mean') # Mean of the squared error of all predictions of one batch
        self._predictionRankLoss = torch.nn.MSELoss(reduction='none') # Mean is not formed but all the squared errors are output directly
        self._optimizer = torch.optim.Adam(self._model.parameters(), lr=self._learningRate)


    def getLearningRate(self) -> float:
        """Return the sampled Adam learning rate."""
        return self._learningRate


    def getLayers(self) -> int:
        """Return the sampled number of Linear layers."""
        return self._layers


    def getNeuronsPerLayer(self) -> int:
        """Return the sampled width of the hidden layers."""
        return self._neuronsPerLayer


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the underlying Sequential model on ``x``."""
        return self._model(x)


    def trainForOneEpoch(self, dataSet) -> None:
        """Run one pass over ``dataSet`` in shuffled mini-batches with Adam on the mean MSE.

        Args:
            dataSet: a sized dataset usable by ``torch.utils.data.DataLoader`` whose
                batches unpack into ``(x, y)``. An empty data set is a no-op.
        """
        if len(dataSet) == 0:
            return

        # Wrap dataSet with dataLoader -> Creates shuffled batches of samples
        dataLoader = torch.utils.data.DataLoader(dataSet, batch_size=self._config['test']['batchSize'], shuffle=True)

        self._model.train()

        for dataX, dataY in dataLoader:
            self._optimizer.zero_grad()

            # Forward-, backward-propagation and optimization
            loss = self._optimizationLoss(self._model(dataX), dataY)
            loss.backward()
            self._optimizer.step()


    def trainUntilOverfitting(self, trainDataSet, valDataSet, zeroBasedMaxEpochs):
        """Train epoch by epoch until the EarlyStopper signals overfitting or the epoch limit is hit.

        After every epoch the mean MSE on the whole validation set is handed to
        ``EarlyStopper.overfittingReached``. The weights are left as they are after
        the last trained epoch; this method does not roll them back.

        Args:
            trainDataSet: training data; must work with trainForOneEpoch and provide
                ``getAllItems()`` returning ``(X, Y)`` for the whole set.
            valDataSet: validation data providing ``getAllItems()`` as above.
            zeroBasedMaxEpochs: index of the last allowed epoch, i.e. at most
                ``zeroBasedMaxEpochs + 1`` epochs are trained.

        Returns:
            ``(epochsTrained, trainLoss)``: the number of epochs actually trained and the
            mean MSE (0-dim tensor) on the whole training set with the final weights.
        """
        earlyStopper = Networks.EarlyStopper.EarlyStopper(self._config['test']['addingCatchUpTrainingPatience'])

        # Training does not touch the validation set, so fetch it once instead of every epoch
        # (assumes getAllItems() returns the same data on each call).
        [valDataX, valDataY] = valDataSet.getAllItems()

        epochsTrained = 0
        for zeroBasedEpoch in range(zeroBasedMaxEpochs + 1):
            self.trainForOneEpoch(trainDataSet)
            epochsTrained = zeroBasedEpoch + 1

            # Evaluate on the validation set without tracking gradients
            self._model.eval()
            with torch.no_grad():
                valLoss = self._optimizationLoss(self._model(valDataX), valDataY)

            if earlyStopper.overfittingReached(valLoss):
                break

        # Return training loss on the full training set
        self._model.eval()
        [trainDataX, trainDataY] = trainDataSet.getAllItems()
        with torch.no_grad():
            trainLoss = self._optimizationLoss(self._model(trainDataX), trainDataY)
        return epochsTrained, trainLoss


    def predict(self, data):
        """Predict ``data[0]`` and compute a per-sample loss against ``data[1]``.

        Args:
            data: pair ``(x, y)`` of numpy arrays; both are cast to float32. The
                per-sample reduction uses ``axis=1``, so 2-D arrays
                ``(samples, dimensions)`` are expected.

        Returns:
            ``[predictions, loss]`` as numpy arrays, where ``loss[i]`` is the mean of
            the squared errors of sample ``i`` over all output dimensions.
        """
        dataTensorX = torch.from_numpy(data[0].astype('float32', casting='same_kind'))
        dataTensorY = torch.from_numpy(data[1].astype('float32', casting='same_kind'))

        self._model.eval()

        with torch.no_grad():
            prediction = self._model(dataTensorX)
            # In case of multiple output dimensions, the loss of one prediction is the mean of the
            # losses in all output dimensions, not its euclidean norm. Found in literature, applied here.
            loss = numpy.mean(self._predictionRankLoss(prediction, dataTensorY).numpy(), axis=1)
            return [prediction.detach().numpy(), loss]
