import copy

import torch.nn.functional as F
from torch import nn

from data.utils import set_seed

# Supported hidden-layer non-linearities, keyed by the name accepted through the
# ``activation`` argument of ArtificialNeuralNetwork. Each value is a single module
# instance created once, when this module is imported.
activation_functions = dict(
    relu=nn.ReLU(),
    leaky_relu=nn.LeakyReLU(),
    sigmoid=nn.Sigmoid(),
    tanh=nn.Tanh(),
)

def ann_l():
    """
    Build the 'ann_l' preset: an ArtificialNeuralNetwork with 784 input features and
    three hidden layers of 200 units each (all other settings left at their defaults).
    :returns: a freshly constructed ArtificialNeuralNetwork
    """
    hidden_sizes = [200] * 3
    return ArtificialNeuralNetwork(input_dim=784, hidden_layers=hidden_sizes)

# Registry mapping a model name to the factory function that builds that model.
model_dict = dict(ann_l=ann_l)

class ArtificialNeuralNetwork(nn.Module):
    """
    Fully connected feed-forward classifier.

    The network is ``[Linear -> activation]`` repeated once per entry of
    ``hidden_layers``, followed by a final ``Linear`` projection to ``n_class``
    outputs; ``forward`` applies a softmax on top of that projection.
    """

    def __init__(self, input_dim, hidden_layers, n_class = 10,
                 activation = 'relu', init_seed=None):
        """
        Initializes the artificial neural network model
        :param input_dim: int, number of features (size of one flattened input sample)
        :param hidden_layers: list of int, number of neurons in each hidden layer
        :param n_class: int, number of classes
        :param activation: str, key of ``activation_functions`` selecting the
            non-linearity placed after every hidden layer
        :param init_seed: value forwarded to ``set_seed`` before any layer is created
            (presumably to make weight initialisation reproducible — confirm in data.utils)
        :raises KeyError: if ``activation`` is not a key of ``activation_functions``
        """
        super().__init__()
        # Seed before creating any nn.Linear so their random initialisation follows it.
        set_seed(init_seed)

        # Resolve the activation up front so a typo fails with a readable message.
        try:
            activation_template = activation_functions[activation]
        except KeyError:
            raise KeyError(
                f"Unknown activation {activation!r}; "
                f"expected one of {sorted(activation_functions)}"
            ) from None

        # Construct layers
        model_layers = []
        previous_layer = input_dim
        for layer in hidden_layers:
            model_layers.append(nn.Linear(previous_layer, layer))
            # Give each layer its own activation module rather than re-inserting the shared
            # module-level instance, so hooks or train/eval flags set through one model do
            # not leak into other layers/models. deepcopy consumes no RNG, so seeded
            # weights are unaffected.
            model_layers.append(copy.deepcopy(activation_template))
            previous_layer = layer
        model_layers.append(nn.Linear(previous_layer, n_class))
        self.network = nn.Sequential(*model_layers)

    def forward(self, x):
        """
        Classify a batch.
        :param x: tensor whose first dimension is the batch; all remaining dimensions are
            flattened and must multiply out to ``input_dim``
        :returns: tensor of shape (batch, n_class) with softmax probabilities per sample
        """
        # reshape (not view) so non-contiguous inputs such as transposed batches also work.
        x = x.reshape(x.size(0), -1)
        # NOTE(review): the output is already softmax-normalised. If training pairs this with
        # nn.CrossEntropyLoss (which expects raw logits), softmax is effectively applied twice.
        return F.softmax(self.network(x), dim=-1)

    def get_parameters(self):
        """
        Returns the parameters of the model.
        :returns: iterator over the parameters of ``self.network`` (the only submodule,
            so this is every parameter of the model).
        """
        return self.network.parameters()

    def get_data_representation(self, x):
        """
        Returns the hidden representation of ``x``: the output of every layer except the
        final classification ``Linear``. This is the activated output of the last hidden
        layer, or the flattened input itself when ``hidden_layers`` is empty.
        :param x: tensor whose first dimension is the batch (flattened like in ``forward``)
        """
        x = x.reshape(x.size(0), -1)
        return self.network[:-1](x)