##############################################################################################################################################################
##############################################################################################################################################################
"""
Generic model definition. Takes as input the config file and create the MLP according to the specified hyperparameters.
"""
##############################################################################################################################################################
##############################################################################################################################################################

from torch import nn

##############################################################################################################################################################
##############################################################################################################################################################

class MLP(nn.Module):

    def __init__(self, model_config, data_config):
        super().__init__()

        # dimension of the input data
        self.input_size = data_config["dimension"]

        # sizes of the different fully connected layers
        # for ease of use, we add the input dimension and output dimension as well
        self.layers_sizes = [self.input_size] + model_config["hidden_layers"] + [1]

        # from the torch.nn module, get the activation function specified 
        # in the config. make sure the naming is correctly spelled 
        # according to the torch name
        self.activation = getattr(nn, model_config["activation"])

        # placeholder for all layers
        # the first layer will flatten the input  to a vector
        self.layers = [nn.Flatten()]

        # for each fully connected layer
        # we iterate twice over the layers sizes with an offset
        # this way we simultaneously can use the values as input and output dimension
        for idx, (in_features, out_features) in enumerate(zip(self.layers_sizes, self.layers_sizes[1:])):

            # add the fully connected layer
            self.layers.append(nn.Linear(in_features=in_features, out_features=out_features, bias=True))

            # add the activation function, but only if not the last pair
            # -2 because the first element is input dimension
            if idx != len(self.layers_sizes)-2:
                self.layers.append(self.activation())

        # make a sequential model using all layers defined above  
        self.regression = nn.Sequential(*self.layers)

        # print the model
        self._print_model()

        # define the loss function to use
        self.criterion = nn.MSELoss(reduction="mean")


    def _print_model(self):

        print("=" * 57)
        print("The model is defined as: ")
        print("=" * 57)
        print(self.regression)
        print("=" * 57)
        print("Parameters of the model to learn:")
        print("=" * 57)
        for name, param in self.named_parameters():
            if param.requires_grad == True:
                print(name)
        print('=' * 57)

    def loss(self, prediction, target):

        # use the defined criterion
        return self.criterion(prediction.squeeze(), target)

    def forward(self, x):

        # simple regression using the sequential model
        x = self.regression(x)

        return x

##############################################################################################################################################################