import torch
import torch.nn as nn


class MLPLayer(nn.Module):
    """A one hidden layer MLP module with optional batch normalization.

    Computes ``fco(relu(bn(fci(x))))``. ``bn`` is a ``BatchNorm1d`` over the
    hidden features when enabled, and an identity mapping otherwise.

    Args:
        idim: (int) Input dimension.
        hdim: (int) Hidden dimension.
        odim: (int) Output dimension.
        bn: (bool) Whether to batch-normalize the hidden activations before
            the ReLU. Defaults to True.
    """
    def __init__(self, idim, hdim, odim, bn=True):
        super(MLPLayer, self).__init__()
        self.fci = nn.Linear(idim, hdim)
        self.fco = nn.Linear(hdim, odim)
        # Honor the ``bn`` flag. The attribute is named ``bn`` in both cases,
        # so the default (bn=True) configuration keeps its state_dict keys.
        self.bn = nn.BatchNorm1d(hdim) if bn else nn.Identity()

    def forward(self, x):
        """Forward function.

        Args:
            x: (torch.Tensor) Input feature. NOTE(review): BatchNorm1d treats
                dim 1 as the feature/channel axis, so with ``bn`` enabled this
                presumably expects a 2-D (batch, idim) input. Confirm against
                callers.

        Returns:
            torch.Tensor: Output of the final linear layer (last dim ``odim``).
        """
        x = torch.relu(self.bn(self.fci(x)))
        return self.fco(x)


class LabelerMLP(nn.Module):
    """Two-layer perceptron with a single ReLU hidden layer.

    Maps an input of feature size ``idim`` through a hidden layer of size
    ``hdim`` to an output of feature size ``odim``: ``fco(relu(fci(x)))``.

    Args:
        idim: (int) Input dimension.
        hdim: (int) Hidden dimension.
        odim: (int) Output dimension.
    """
    def __init__(self, idim, hdim, odim):
        super().__init__()
        # Keep construction order (fci, then fco) so parameter init order
        # and state_dict ordering stay stable.
        self.fci = nn.Linear(idim, hdim)
        self.fco = nn.Linear(hdim, odim)

    def forward(self, x):
        """Apply the hidden linear layer, a ReLU, then the output layer.

        Args:
            x: (torch.Tensor) Input feature whose last dimension is ``idim``.

        Returns:
            torch.Tensor: Tensor whose last dimension is ``odim``.
        """
        return self.fco(torch.relu(self.fci(x)))
