"""Implementations multi-layer perceptrons."""

import numpy as np
import torch

from torch.nn import functional as F
from torch import nn


class MLP(nn.Module):
    """A standard multi-layer perceptron."""

    def __init__(self,
                 in_shape,
                 out_shape,
                 hidden_sizes,
                 activation=F.relu,
                 activate_output=False):
        """
        Args:
            in_shape: tuple, list or torch.Size, the shape of the input.
            out_shape: tuple, list or torch.Size, the shape of the output.
            hidden_sizes: iterable of ints, the hidden-layer sizes.
            activation: callable, the activation function.
            activate_output: bool, whether to apply the activation to the output.
        """
        super().__init__()
        self._in_shape = torch.Size(in_shape)
        self._out_shape = torch.Size(out_shape)
        self._hidden_sizes = hidden_sizes
        self._activation = activation
        self._activate_output = activate_output

        if len(hidden_sizes) == 0:
            raise ValueError('List of hidden sizes can\'t be empty.')

        self._input_layer = nn.Linear(np.prod(in_shape), hidden_sizes[0])
        self._hidden_layers = nn.ModuleList([
            nn.Linear(in_size, out_size)
            for in_size, out_size in zip(hidden_sizes[:-1], hidden_sizes[1:])
        ])
        self._output_layer = nn.Linear(hidden_sizes[-1], np.prod(out_shape))

    def forward(self, inputs):
        if inputs.shape[1:] != self._in_shape:
            raise ValueError('Expected inputs of shape {}, got {}.'.format(
                self._in_shape, inputs.shape[1:]))

        inputs = inputs.reshape(-1, np.prod(self._in_shape))
        outputs = self._input_layer(inputs)
        outputs = self._activation(outputs)

        for hidden_layer in self._hidden_layers:
            outputs = hidden_layer(outputs)
            outputs = self._activation(outputs)

        outputs = self._output_layer(outputs)
        if self._activate_output:
            outputs = self._activation(outputs)
        outputs = outputs.reshape(-1, *self._out_shape)

        return outputs
