import torch


class Network(torch.nn.Module):
    """Fully connected regression network with Tanh activations.

    The model is a stack of ``layers`` Linear layers:

    * ``layers == 1``: a single ``Linear(xDimensions, yDimensions)``.
    * ``layers > 1``: ``Linear(xDimensions, neurons)``, then ``layers - 2``
      hidden ``Linear(neurons, neurons)`` layers, then
      ``Linear(neurons, yDimensions)``. A Tanh follows every Linear except
      the last one.

    Training minimises ``torch.nn.MSELoss`` with ``torch.optim.Adam``.

    Config keys read by this class:
        config['input']['xDimensions']  -- number of input features
        config['input']['yDimensions']  -- number of output features
        config['test']['batchSize']     -- mini-batch size used by train()
        config['test']['epochs']        -- number of passes over the data in train()
    """

    def __init__(self, config, layers, neurons, learningRate):
        """Build the layer stack, the loss function and the optimizer.

        Args:
            config: nested dict; see the class docstring for the keys used.
            layers: number of Linear layers; must be >= 1.
            neurons: width of the hidden layers (unused when ``layers == 1``).
            learningRate: learning rate passed to ``torch.optim.Adam``.

        Raises:
            ValueError: if ``layers`` is smaller than 1.
        """
        super().__init__()

        if layers < 1:
            # Without this check a non-positive value fell into the
            # multi-layer branch and silently built a two-layer network.
            raise ValueError(f"layers must be >= 1, got {layers}")

        self._config       = config
        self._layers       = layers
        self._neurons      = neurons
        self._learningRate = learningRate

        xDimensions = self._config['input']['xDimensions']
        yDimensions = self._config['input']['yDimensions']

        self._model = torch.nn.Sequential()

        if self._layers == 1:
            self._model.append(torch.nn.Linear(xDimensions, yDimensions))

        else: # self._layers > 1
            # First layer
            self._model.append(torch.nn.Linear(xDimensions, self._neurons))
            self._model.append(torch.nn.Tanh())

            # Layers in between
            for _ in range(self._layers - 2):
                self._model.append(torch.nn.Linear(self._neurons, self._neurons))
                self._model.append(torch.nn.Tanh())

            # Last layer (no activation -> unbounded regression output)
            self._model.append(torch.nn.Linear(self._neurons, yDimensions))

        self._criterion = torch.nn.MSELoss()
        self._optimizer = torch.optim.Adam(self._model.parameters(), lr=self._learningRate)


    def forward(self, x):
        """Run the layer stack on ``x`` (gradients are tracked as usual)."""
        return self._model(x)


    def train(self, dataSet=True):
        """Fit the model on ``dataSet`` with mini-batch Adam and MSE loss.

        This method shares its name with ``torch.nn.Module.train(mode)``,
        which PyTorch itself calls with a bool: ``Module.eval()`` calls
        ``self.train(False)``, and a parent module's ``train(mode)`` calls it
        on every child. A bool argument (including the default) is therefore
        forwarded to the base implementation so that ``eval()`` / ``train()``
        keep working. Any other argument is treated as a dataset.

        Args:
            dataSet: a dataset accepted by ``torch.utils.data.DataLoader``
                whose batches unpack into ``(inputs, targets)``, or a bool to
                set the training mode exactly like ``nn.Module.train``.

        Returns:
            ``self`` when called with a bool (mirroring ``nn.Module.train``),
            otherwise ``None``.
        """
        if isinstance(dataSet, bool):
            return super().train(dataSet)

        # NOTE(review): the training hyperparameters come from the 'test'
        # section of the config -- confirm this is intended.
        dataLoader = torch.utils.data.DataLoader(
            dataSet,
            batch_size=self._config['test']['batchSize'],
            shuffle=True,
        )

        # Put model in training mode
        self._model.train()

        for _ in range(self._config['test']['epochs']):
            for dataX, dataY in dataLoader:
                # Reset all gradients
                self._optimizer.zero_grad()

                # Forward-, backward-propagation and optimization
                loss = self._criterion(self._model(dataX), dataY)
                loss.backward()
                self._optimizer.step()

        return None


    def predict(self, x):
        """Return the model output for ``x`` without tracking gradients.

        The inner model is left in eval mode afterwards. The stack contains
        only Linear and Tanh layers, so the mode has no numerical effect.
        """
        # Put model in test mode
        self._model.eval()

        # Determine prediction
        with torch.no_grad():
            return self._model(x)


    def determineLoss(self, x, y, lambdaUnscaleY):
        """Return the MSE between prediction and target in unscaled units.

        Args:
            x: model input.
            y: target tensor, in the same (scaled) space as the model output.
            lambdaUnscaleY: callable applied to both prediction and target
                before the loss, so the loss is independent of the scaling.

        Returns:
            The loss as a 0-d numpy array.
        """
        # Get prediction
        prediction = self.predict(x)

        # Evaluate without autograd so .numpy() also works when ``y`` (or the
        # unscale function's output) requires grad; .cpu() handles tensors
        # that live on another device.
        with torch.no_grad():
            loss = self._criterion(lambdaUnscaleY(prediction), lambdaUnscaleY(y))
        return loss.cpu().numpy()


    def countParameters(self):
        """Return the number of trainable scalar parameters of the model."""
        return sum(p.numel() for p in self._model.parameters() if p.requires_grad)

    def getLayers(self):
        """Return the number of Linear layers."""
        return self._layers

    def getNeurons(self):
        """Return the hidden-layer width passed to the constructor."""
        return self._neurons

    def getLearningRate(self):
        """Return the Adam learning rate passed to the constructor."""
        return self._learningRate

    def getFirstWeights(self):
        """Return the weight matrix of the first Linear layer as numpy.

        The shape is ``(out_features, xDimensions)``. For CPU parameters the
        returned array shares memory with the parameter, so it reflects later
        training steps. Call ``.copy()`` on it to keep a snapshot.
        """
        return self._model[0].weight.detach().cpu().numpy()
    