import logging
import random

import torch
import numpy as np

from models import relu_net
from models import conv_net
from models import resnet
from models.base_model import SequentialModel


def get_model(model_name: str, dims: list, seed=42, **kwargs) -> SequentialModel:
    """Get model object from its name.

    The architecture and its width(s) are encoded in ``model_name`` as
    ``"_"``-separated tokens:

    - ``FF_ReLU_<w1>_<w2>...``: ``relu_net.ReLU_Net`` with ``bias=False``
    - ``Bias_FF_ReLU_<w1>_<w2>...``: ``relu_net.ReLU_Net`` with ``bias=True``
    - ``LeakyOutput_FF_ReLU_<w1>_<w2>...``: ``relu_net.ReLU_Net`` with
      ``leaky_output=True, bias=False``
    - ``Bias_LeakyOutput_FF_ReLU_<w1>_<w2>...``: ``relu_net.ReLU_Net`` with
      ``leaky_output=True, bias=True``
    - ``CNN_<w>``: ``conv_net.Conv_Net`` with ``bias=False``
    - ``ResNet_<w>``: ``resnet.ResNet``

    Side effects: before building the model this function seeds ``torch``
    (CPU and the current CUDA device), ``numpy`` and ``random`` with ``seed``.
    It also sets ``torch.backends.cudnn.deterministic = True`` and
    ``torch.backends.cudnn.benchmark = False``. These are process-wide
    settings that remain in effect after the call returns, even when the
    name turns out to be invalid.

    Parameters
    ----------
    model_name
        Name of the model with the width specification (see above).
    dims
        Dimensions of the input and the output in the form: [input_shape, output_shape]
    seed, optional
        Random seed used to initialise the model, by default 42
    **kwargs
        Extra keyword arguments forwarded to ``resnet.ResNet`` only; they are
        ignored for every other architecture.

    Returns
    -------
    SequentialModel
        Model that is derived from SequentialModel class

    Raises
    ------
    NameError
        If ``model_name`` does not start with any known architecture prefix.
        The message and the ``name`` attribute contain the offending name.
    ValueError
        If a width token in ``model_name`` is not an integer.
    """
    # Fix the seeds for model initialisation and force deterministic cuDNN
    # kernels, so repeated calls with the same seed are reproducible.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # The prefixes below are mutually exclusive (none is a prefix of another),
    # so the order of the checks does not matter. The slice index skips the
    # tokens that make up the prefix; the remaining tokens are the widths.
    if model_name.startswith("FF_ReLU"):
        widths = list(map(int, model_name.split("_")[2:]))
        return relu_net.ReLU_Net(widths, dims, bias=False)

    if model_name.startswith("Bias_FF_ReLU"):
        widths = list(map(int, model_name.split("_")[3:]))
        return relu_net.ReLU_Net(widths, dims, bias=True)

    if model_name.startswith("LeakyOutput_FF_ReLU"):
        widths = list(map(int, model_name.split("_")[3:]))
        return relu_net.ReLU_Net(widths, dims, leaky_output=True, bias=False)

    if model_name.startswith("Bias_LeakyOutput_FF_ReLU"):
        widths = list(map(int, model_name.split("_")[4:]))
        return relu_net.ReLU_Net(widths, dims, leaky_output=True, bias=True)

    if model_name.startswith("CNN_"):
        width = int(model_name.split("_")[1])
        return conv_net.Conv_Net(width, dims, bias=False)

    if model_name.startswith("ResNet_"):
        width = int(model_name.split("_")[1])
        return resnet.ResNet(width, dims, **kwargs)

    logging.error("Model does not exist: %s", model_name)
    # NameError only accepts the ``name=`` keyword on Python >= 3.10 (older
    # versions raise TypeError). Attach the attribute manually instead so it
    # works everywhere and reports the actual unknown name.
    error = NameError(f"Model does not exist: {model_name!r}")
    error.name = model_name
    raise error
