import copy

import torch
from torch import nn
import torch.nn.functional as F

from data.utils import set_seed

# Lookup table from an activation name to a ready-made torch module.
# The modules are instantiated once, when this file is imported.
activation_functions = dict(
    relu=nn.ReLU(),
    leaky_relu=nn.LeakyReLU(),
    sigmoid=nn.Sigmoid(),
    tanh=nn.Tanh(),
)

def ann_s(input_dim=23):
    """
    Builds the small neural-network preset: two hidden layers of 50 neurons each.
    :param input_dim: int, number of input features (defaults to 23)
    :returns: a new ArtificialNeuralNetwork; all other constructor arguments
        keep that class's defaults
    """
    return ArtificialNeuralNetwork(input_dim, [50, 50])

def ann_m(input_dim=23):
    """
    Builds the medium neural-network preset: hidden layers of 200 and 50 neurons.
    :param input_dim: int, number of input features (defaults to 23)
    :returns: a new ArtificialNeuralNetwork; all other constructor arguments
        keep that class's defaults
    """
    return ArtificialNeuralNetwork(input_dim, [200, 50])

def ann_l(input_dim=23):
    """
    Builds the large neural-network preset: three hidden layers of 150 neurons each.
    :param input_dim: int, number of input features (defaults to 23)
    :returns: a new ArtificialNeuralNetwork; all other constructor arguments
        keep that class's defaults
    """
    return ArtificialNeuralNetwork(input_dim, [150, 150, 150])

def lr(input_dim=23):
    """
    Builds the logistic-regression preset.
    :param input_dim: int, number of input features (defaults to 23)
    :returns: a new LogisticRegression model
    """
    return LogisticRegression(input_dim)

# Registry mapping a model name to the factory function that builds it.
model_dict = dict(
    ann_s=ann_s,
    ann_m=ann_m,
    ann_l=ann_l,
    lr=lr,
)

class ArtificialNeuralNetwork(nn.Module):
    """
    Feed-forward classifier built from fully connected layers.
    Every hidden nn.Linear layer is followed by the selected activation, and a
    final nn.Linear layer maps to n_class outputs. forward() returns the
    softmax of those outputs.
    """
    def __init__(self, input_dim, hidden_layers, n_class = 2,
                 activation = 'relu', init_seed=None):
        """
        Initializes the artificial neural network model
        :param input_dim: int, number of features
        :param hidden_layers: list of int, number of neurons in each hidden layer
        :param n_class: int, number of classes
        :param activation: str, key of activation_functions selecting the
            non-linearity placed after every hidden layer
        :param init_seed: value handed unchanged to set_seed before any layer
            is created (how None is treated depends on set_seed - TODO confirm)
        :raises KeyError: if hidden_layers is non-empty and activation is not a
            key of activation_functions
        """
        super().__init__()
        set_seed(init_seed)

        # Construct layers
        model_layers = []
        previous_layer = input_dim
        for layer in hidden_layers:
            model_layers.append(nn.Linear(previous_layer, layer))
            # activation_functions holds a single module object per name.
            # Inserting that object directly would make every layer of every
            # model share it, and nn.Module traversal (modules(),
            # named_modules()) reports a repeated object only once.
            # Each layer therefore gets its own copy.
            model_layers.append(copy.deepcopy(activation_functions[activation]))
            previous_layer = layer
        model_layers.append(nn.Linear(previous_layer, n_class))
        self.network = nn.Sequential(*model_layers)

    def forward(self, x):
        """
        Runs the network and converts its outputs to probabilities.
        :param x: input tensor passed to the first linear layer
        :returns: softmax over the last dimension of the network output
        """
        return F.softmax(self.network(x), dim=-1)

    def predict(self, test_dataset):
        """
        Predict outcomes using the model on the provided test dataset.
        Gradient tracking and train/eval mode are left as they currently are.
        :param test_dataset: object whose .data attribute is fed to the model.
        :returns: predictions made by the model (the output of forward).
        """
        return self(test_dataset.data)

    def get_parameters(self):
        """
        Returns the parameters of the model.
        :returns: the result of self.network.parameters().
        """
        return self.network.parameters()

    def get_data_representation(self, x):
        """
        Computes the representation fed into the final linear layer.
        :param x: input tensor; it is reshaped to (x.size(0), -1) first
        :returns: output of every layer of the network except the last one
        """
        x = x.view(x.size(0), -1)
        return self.network[:-1](x)
    
class LogisticRegression(nn.Module):
    """
    Binary logistic regression: a single linear unit followed by a sigmoid.
    Offers the same predict / get_parameters / get_data_representation methods
    as ArtificialNeuralNetwork so both models can be used interchangeably.
    """
    def __init__(self, input_dim):
        """
        Initializes the logistic regression model for binary classification (2 outputs)
        :param input_dim: int, number of features
        """
        super().__init__()

        # Construct layers
        self.input_dim = input_dim
        self.linear = nn.Linear(self.input_dim, 1)

    def forward(self, x):
        """
        Computes the probabilities of both classes.
        :param x: input tensor whose last dimension holds the features
        :returns: tensor with two entries along the last dimension,
            (1 - p, p), where p is the sigmoid of the linear output
        """
        y_pred = torch.sigmoid(self.linear(x))
        return torch.cat((1-y_pred, y_pred), dim=-1)

    def predict(self, test_dataset):
        """
        Predict outcomes using the model on the provided test dataset.
        :param test_dataset: object whose .data attribute is fed to the model.
        :returns: predictions made by the model (the output of forward).
        """
        return self(test_dataset.data)

    def get_parameters(self):
        """
        Returns the parameters of the model.
        :returns: the result of self.linear.parameters().
        """
        return self.linear.parameters()

    def get_data_representation(self, x):
        """
        Returns the output of the linear layer, before the sigmoid is applied.
        :param x: input tensor whose last dimension holds the features
        :returns: tensor with a single entry along the last dimension
        """
        return self.linear(x)