from typing import Dict
import torch.nn as nn
from pydantic import BaseModel


def init_model(config: Dict) -> nn.Module:
    """
    Initializes the specified model.

    Args:
        config (dict): A dictionary containing the model configuration. It must
            contain the key 'model_name', which selects the model class; the
            whole dictionary is forwarded to that class's constructor.

    Returns:
        The initialized model object (an instance of `torch.nn.Module`).

    Raises:
        KeyError: If `config` has no 'model_name' entry.
        NotImplementedError: If the specified model name in the config is not supported.
    """
    model_name = config['model_name']

    if model_name == 'mlp':
        return MLP(config)

    # Build the message from the same key used for dispatch. Looking up a
    # different key here would raise KeyError and mask the intended
    # NotImplementedError.
    raise NotImplementedError(f"model_name ({model_name}) not implemented.")


class MLP(nn.Module):
    """
    Multi-layer perceptron built from a configuration dictionary.

    The network is a stack of `nn.Linear` + `nn.ReLU` pairs, one pair per entry
    in `hidden_sizes`, followed by a final `nn.Linear` output layer with no
    activation applied after it.

    Attributes:
        default_config (dict): The default configuration, i.e. `Config()` with
            no overrides, as a plain dictionary.
        config (dict): The validated configuration used by this instance
            (user-supplied values merged with the `Config` defaults).
    """

    class Config(BaseModel):
        # Schema and default values for the MLP configuration.
        model_name: str = 'mlp'
        input_size: int = 73
        output_size: int = 72
        hidden_sizes: tuple = (512, 512, 512, 512, 512)
    default_config = Config().dict()

    def __init__(self, config):
        """
        Builds the layers described by `config`.

        Args:
            config (dict): Model configuration. Any key declared on `Config`
                may be omitted, in which case its default value is used.
        """
        super().__init__()
        self.config = self.Config(**config).dict()

        # parameters
        # Read from the validated config rather than the raw argument so that
        # the defaults declared on `Config` apply to omitted keys.
        self.model_name = self.config['model_name']
        self.input_size = self.config['input_size']
        self.output_size = self.config['output_size']
        self.hidden_sizes = self.config['hidden_sizes']

        layers = []
        prev_size = self.input_size
        for size in self.hidden_sizes:
            layers.append(nn.Linear(prev_size, size))
            layers.append(nn.ReLU())
            # The next layer's input width is this layer's output width.
            prev_size = size

        self.hidden_layers = nn.Sequential(*layers)
        self.output_layer = nn.Linear(prev_size, self.output_size)

    def forward(self, x):
        """
        Runs `x` through the hidden layers and then the output layer.

        Args:
            x: Input passed to the first layer.
                # assumes a tensor whose last dimension equals `input_size` — TODO confirm against callers

        Returns:
            The result of the output layer.
        """
        x = self.hidden_layers(x)
        x = self.output_layer(x)
        return x
