# -*- coding: utf-8 -*-
import typing
from collections import OrderedDict

import torch


class MLPModel(torch.nn.Module):
    """Multi-layer perceptron built from ``torch.nn.Linear`` layers.

    Each hidden layer is ``Linear`` optionally followed by an activation and a
    ``Dropout``; a final ``Linear`` (``self.output``) maps to the output width.

    Args:
        features: Layer widths ``[in, hidden_1, ..., hidden_k, out]``. The
            first entry is the input width, the last is the output width and
            the entries in between are hidden-layer widths. Must contain at
            least two entries. The sequence is copied and never modified.
        dropout: Dropout probability added after every hidden layer. ``None``
            (the default) adds no dropout layers.
        activation: Activation inserted after every hidden ``Linear``. Either
            a callable returning a module (e.g. ``torch.nn.ReLU``) or the name
            of an attribute of ``torch.nn`` (e.g. ``"ReLU"``). ``None`` means
            no activation.
        activation_args: Positional arguments passed to ``activation`` for
            each hidden layer.
        activation_kwargs: Keyword arguments passed to ``activation`` for
            each hidden layer.

    Raises:
        ValueError: if ``features`` has fewer than two entries.
        AttributeError: if ``activation`` is a string that does not name an
            attribute of ``torch.nn``.
    """

    def __init__(self, features: typing.Sequence[int],
                 dropout: typing.Optional[float] = None,
                 activation: typing.Union[typing.Callable, str, None] = None,
                 activation_args: typing.Optional[typing.Sequence] = None,
                 activation_kwargs: typing.Optional[typing.Mapping] = None):
        # Work on a copy: popping from the caller's list would silently
        # shrink it and break reuse of the same list for another model.
        features = list(features)
        if len(features) < 2:
            raise ValueError(
                "features must contain at least an input and an output width, "
                f"got {features!r}")
        super().__init__()

        # None sentinels instead of shared mutable defaults.
        activation_args = tuple(activation_args) if activation_args is not None else ()
        activation_kwargs = dict(activation_kwargs) if activation_kwargs is not None else {}

        self.in_features = features[0]
        self.out_features = features[-1]
        hidden_features = features[1:-1]

        if isinstance(activation, str):
            activation = getattr(torch.nn, activation)

        layers = OrderedDict()
        in_features = self.in_features
        for i, out_features in enumerate(hidden_features):
            layers[f"linear{i}"] = torch.nn.Linear(in_features, out_features)
            if activation is not None:
                layers[f"activation{i}"] = activation(*activation_args,
                                                      **activation_kwargs)
            # torch.nn.Dropout cannot take p=None, so the layer is only
            # created when a probability was actually supplied.
            if dropout is not None:
                layers[f"dropout{i}"] = torch.nn.Dropout(p=dropout)
            in_features = out_features

        # With no hidden layers this Sequential is empty and acts as identity.
        self.layers = torch.nn.Sequential(layers)
        self.output = torch.nn.Linear(in_features, self.out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the hidden layers and the output projection to ``x``.

        Inputs with more than two dimensions are flattened from dimension 1
        onward (dimension 0 is presumably the batch axis), so the flattened
        size must equal ``in_features``.
        """
        if x.ndim > 2:
            x = x.flatten(start_dim=1)
        hidden = self.layers(x)
        return self.output(hidden)
