"""
Standard MLP model.

Author:
Date: October 28, 2023
"""
from typing import Optional

import torch
from torch import Tensor
import torch.nn as nn

from krt.utils import get_activation


class MLP(nn.Module):
    """Fully connected feed-forward network (multi-layer perceptron).

    The network consists of ``hidden_layer_depth + 1`` linear layers,
    registered as submodules ``linear_0`` ... ``linear_{hidden_layer_depth}``.
    The hidden activation is applied after every linear layer except the
    last, so the output is the raw result of the final linear layer.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_layer_width: int,
        hidden_layer_depth: int,
        hidden_activation: str = 'relu',
        bias: bool = True,
    ):
        """Constructor.

        Args:
            input_dim: Dimension of the input data.
            output_dim: Dimension of the output data.
            hidden_layer_width: Number of hidden units in a hidden layer.
                Unused when ``hidden_layer_depth`` is 0.
            hidden_layer_depth: Number of hidden layers. With 0 the network
                is a single linear layer from input to output.
            hidden_activation: The name of the hidden activation to use,
                resolved through ``get_activation``.
            bias: Whether every linear layer has an additive bias term.

        Raises:
            ValueError: If ``hidden_layer_depth`` is negative.
        """
        # Explicit check rather than ``assert`` so the validation is not
        # stripped when Python runs with ``-O``.
        if hidden_layer_depth < 0:
            raise ValueError(
                'hidden_layer_depth must be non-negative, '
                f'got {hidden_layer_depth}'
            )
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.depth = hidden_layer_depth
        self.bias = bias
        # Layer widths from input to output, e.g. [in, h, h, out] for depth 2.
        # Each consecutive pair is one linear layer's (in, out) dimensions, so
        # layers are registered in order as linear_0 ... linear_{depth}.
        dims = [input_dim] + [hidden_layer_width] * self.depth + [output_dim]
        for layer_num, (lin_in, lin_out) in enumerate(zip(dims[:-1], dims[1:])):
            self._add_linear_layer(lin_in, lin_out, layer_num)
        self.activation = get_activation(hidden_activation)

    def forward(self, net_in: Tensor) -> Tensor:
        """Forward pass through network.

        Args:
            net_in: The input to the network. Its last dimension must equal
                ``input_dim`` (``nn.Linear`` operates on the last dimension).

        Returns:
            The output of the network: same leading dimensions as ``net_in``
            with the last dimension replaced by ``output_dim``.
        """
        curr = net_in
        # Hidden layers: linear map followed by the hidden activation.
        for layer_num in range(self.depth):
            curr = getattr(self, f'linear_{layer_num}')(curr)
            curr = self.activation(curr)
        # Output layer: linear map only, no activation.
        return getattr(self, f'linear_{self.depth}')(curr)

    def _add_linear_layer(
        self,
        lin_in: int,
        lin_out: int,
        layer_num: Optional[int] = None,
        layer_name: Optional[str] = None,
    ) -> None:
        """Add a linear layer to the network as a named submodule.

        Args:
            lin_in: Input dimension to the layer.
            lin_out: Output dimension of the layer.
            layer_num: The number of the layer being added; used to build the
                name ``linear_{layer_num}`` when ``layer_name`` is not given.
            layer_name: The name of the layer. Takes precedence over
                ``layer_num``.

        Raises:
            ValueError: If neither ``layer_num`` nor ``layer_name`` is given.
        """
        # Resolve the name first so no layer is allocated on the error path.
        if layer_name is None:
            if layer_num is None:
                raise ValueError('Either layer_num or layer_name must be provided')
            layer_name = f'linear_{layer_num}'
        layer = torch.nn.Linear(lin_in, lin_out, bias=self.bias)
        self.add_module(layer_name, layer)
