"""Implementations of multi-layer perceptrons."""

import numpy as np
import torch
from torch.nn import functional as F
from torch import nn


class MLP(nn.Module):
    """A standard multi-layer perceptron.

    Inputs of shape ``(batch, *in_shape)`` are flattened, optionally
    concatenated with a flattened context tensor, passed through a stack of
    fully-connected layers with ``activation`` applied after each hidden
    layer, and finally reshaped to ``(batch, *out_shape)``.
    """

    def __init__(self, in_shape, out_shape, hidden_sizes, activation=F.relu, activate_output=False, context_features=None):
        """
        Args:
            in_shape: tuple, list or torch.Size, the shape of a single input
                (excluding the batch dimension).
            out_shape: tuple, list or torch.Size, the shape of a single output
                (excluding the batch dimension).
            hidden_sizes: iterable of ints, the hidden-layer sizes. Must be
                non-empty.
            activation: callable, the activation function applied after the
                input layer and every hidden layer.
            activate_output: bool, whether to apply the activation to the output.
            context_features: int or None, the number of (flattened) context
                features concatenated to the flattened input. ``None`` (or 0)
                means the network takes no context.

        Raises:
            ValueError: if ``hidden_sizes`` is empty.
        """
        super().__init__()
        # Materialize so generators and other one-shot iterables support
        # len() and slicing below.
        hidden_sizes = list(hidden_sizes)
        if not hidden_sizes:
            raise ValueError("List of hidden sizes can't be empty.")

        self._in_shape = torch.Size(in_shape)
        self._out_shape = torch.Size(out_shape)
        self._context_features = 0 if context_features is None else context_features
        self._hidden_sizes = hidden_sizes
        self._activation = activation
        self._activate_output = activate_output

        # int(): np.prod returns a NumPy scalar (and the float 1.0 for an
        # empty shape), while nn.Linear expects Python ints.
        self._in_features = int(np.prod(self._in_shape))
        self._out_features = int(np.prod(self._out_shape))

        self._input_layer = nn.Linear(self._in_features + self._context_features, hidden_sizes[0])
        self._hidden_layers = nn.ModuleList(
            [nn.Linear(in_size, out_size) for in_size, out_size in zip(hidden_sizes[:-1], hidden_sizes[1:])]
        )
        self._output_layer = nn.Linear(hidden_sizes[-1], self._out_features)

    def forward(self, inputs, context=None):
        """Apply the network.

        Args:
            inputs: tensor of shape ``(batch, *in_shape)``.
            context: optional tensor with leading dimension ``batch`` whose
                remaining dimensions hold ``context_features`` values in total.
                Required exactly when the network was built with context
                features.

        Returns:
            Tensor of shape ``(batch, *out_shape)``.

        Raises:
            ValueError: if ``inputs`` has the wrong shape, or ``context`` is
                missing, unexpected, or of the wrong size.
        """
        if inputs.shape[1:] != self._in_shape:
            raise ValueError("Expected inputs of shape {}, got {}.".format(self._in_shape, inputs.shape[1:]))

        batch_size = inputs.shape[0]
        # Flatten *before* concatenating so a multi-dimensional in_shape works
        # with a context. Use the explicit batch size instead of -1 so empty
        # batches remain reshapeable and the batch dimension cannot be
        # silently regrouped.
        outputs = inputs.reshape(batch_size, self._in_features)

        if context is not None:
            if self._context_features == 0:
                raise ValueError("Got a context, but the network was built without context features.")
            context_size = int(np.prod(context.shape[1:]))
            if context.shape[0] != batch_size or context_size != self._context_features:
                raise ValueError(
                    "Expected context with batch size {} and {} features, got shape {}.".format(
                        batch_size, self._context_features, tuple(context.shape)
                    )
                )
            outputs = torch.cat((outputs, context.reshape(batch_size, self._context_features)), dim=1)
        elif self._context_features:
            raise ValueError("Expected a context with {} features, got None.".format(self._context_features))

        outputs = self._activation(self._input_layer(outputs))
        for hidden_layer in self._hidden_layers:
            outputs = self._activation(hidden_layer(outputs))

        outputs = self._output_layer(outputs)
        if self._activate_output:
            outputs = self._activation(outputs)
        return outputs.reshape(batch_size, *self._out_shape)
