"""Definitions of neural network layers."""

#from this import d
import torch
from torch import nn


class DropoutActivation(nn.Module):
    """Combines an activation function and dropout into a single module.

    This is useful for adding dropout to a Stable Baselines3 policy, which takes an
    activation function as input.

    The activation layer is built by calling the ``activation_fn`` class attribute
    with no arguments. The attribute is looked up on the instance's concrete class,
    so a subclass that overrides ``activation_fn`` gets its own activation.
    """

    # Callable (normally a module class) invoked with no arguments to build the
    # activation layer.
    activation_fn = nn.ReLU

    def __init__(self, p: float = 0.1):
        """Instantiate the activation and dropout layers.

        Args:
            p: Dropout probability passed to ``nn.Dropout``.
        """
        super().__init__()
        # Use type(self) instead of the hard-coded base class so subclass overrides
        # of ``activation_fn`` are honoured. Looking it up on the class (not the
        # instance) also keeps a plain-function factory from being bound as a method.
        self.activation = type(self).activation_fn()
        self.dropout = nn.Dropout(p=p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute a forward pass: activation function first, then dropout."""
        return self.dropout(self.activation(x))


class ResiduleMLPBlock(nn.Module):
    """Residual MLP block.

    Computes ``x + fc2(dropout(activation(fc1(x))))``. It is used in the
    ResiduleMLP class.

    The hidden activation is built from this class's ``activation_fn`` attribute,
    looked up on the concrete class so that subclasses may override it. Because the
    output is added to the input, ``output_size`` must equal ``input_size``.
    """

    # Callable (normally a module class) invoked with no arguments to build the
    # hidden activation.
    activation_fn = nn.ReLU
    # Default dropout probability, matching the default of the ``p`` argument.
    # The ``p`` argument is what actually configures the dropout layer.
    p = 0.1

    def __init__(self, input_size, output_size, hidden_size, p=0.1):
        """Instantiate the layers of the MLP block.

        Args:
            input_size: Number of input features.
            output_size: Number of output features. Must equal ``input_size``
                because of the residual connection.
            hidden_size: Number of features in the hidden layer.
            p: Dropout probability applied after the hidden activation.

        Raises:
            ValueError: If ``output_size`` differs from ``input_size``.
        """
        super().__init__()
        if input_size != output_size:
            # Without this check, ``x + y`` in forward would either fail or silently
            # broadcast (e.g. output_size=1).
            raise ValueError(
                "Residual block requires output_size == input_size, got "
                f"input_size={input_size} and output_size={output_size}"
            )
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.activation = DropoutActivation(p=p)
        # DropoutActivation builds its activation in __init__, so setting an
        # ``activation_fn`` attribute afterwards would have no effect. Replace the
        # already-built activation submodule instead.
        self.activation.activation = type(self).activation_fn()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute a forward pass: ``x + fc2(dropout(activation(fc1(x))))``."""
        y = self.activation(self.fc1(x))
        y = self.fc2(y)
        return x + y


class ResiduleMLP(nn.Module):
    """Multilayer perceptron built from a stack of residual blocks.

    The layers, applied in order, are:

    1. an input projection ``nn.Linear(input_size, hidden_size)``,
    2. a ``DropoutActivation``,
    3. ``num_blocks`` ``ResiduleMLPBlock`` layers operating at ``hidden_size``,
    4. an output projection ``nn.Linear(hidden_size, output_size)``.
    """

    def __init__(self, input_size, output_size, hidden_size, num_blocks, dropout_p=0.1):
        """Build the layer stack.

        Args:
            input_size: Number of input features.
            output_size: Number of output features.
            hidden_size: Width of the input projection's output and of every
                residual block.
            num_blocks: Number of ``ResiduleMLPBlock`` layers.
            dropout_p: Dropout probability passed to the input
                ``DropoutActivation`` and to each ``ResiduleMLPBlock``.
        """
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.num_blocks = num_blocks

        # Layers are created in the same order they run, so parameter
        # initialisation consumes the RNG in that order.
        width = hidden_size
        input_projection = nn.Linear(input_size, width)
        input_activation = DropoutActivation(p=dropout_p)
        residual_stack = [
            ResiduleMLPBlock(width, width, width, dropout_p)
            for _ in range(num_blocks)
        ]
        output_projection = nn.Linear(width, output_size)

        self.blocks = nn.ModuleList(
            [input_projection, input_activation, *residual_stack, output_projection]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Feed ``x`` through every layer in ``self.blocks`` in sequence."""
        out = x
        for layer in self.blocks:
            out = layer(out)
        return out
