#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Define a multi-layer perceptron (MLP) architecture using PyTorch."""

import torch
from torch import nn


class MLP(nn.Module):
    """Implement a single-hidden-layer MLP geared towards classification tasks.

    The network comprises an input projection, a ReLU activation with dropout,
    and a linear output head that produces raw logits (no softmax is applied).

    Attributes:
        layers: Sequential container encapsulating the network layers, in order
            ``Linear -> ReLU -> Dropout -> Linear``.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float = 0.5):
        """Initializes the MLP model.

        Args:
            input_dim: The dimension of the input features after flattening
                every non-batch dimension.
            hidden_dim: The number of neurons in the hidden layer.
            output_dim: The dimension of the output (number of classes).
            dropout: Probability of zeroing a hidden activation during training.
                Defaults to 0.5. ``nn.Dropout`` raises ``ValueError`` if it is
                outside ``[0, 1]``.
        """
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Execute the forward pass after flattening the input tensor.

        Args:
            x: Input tensor whose first dimension is the batch. All remaining
                dimensions are flattened into one feature dimension, whose size
                must equal ``input_dim``. A 1-D tensor of shape ``(N,)`` is
                treated as ``N`` samples with one feature each.

        Returns:
            Raw logits of shape ``(batch_size, output_dim)``.
        """
        if x.dim() == 1:
            # flatten(start_dim=1) rejects 1-D tensors; keep the (N,) -> (N, 1) behavior.
            x = x.unsqueeze(1)
        else:
            # Unlike .view, flatten also handles non-contiguous tensors (e.g. permuted
            # or channels_last inputs) and an empty batch (size-0 first dimension).
            x = torch.flatten(x, start_dim=1)
        return self.layers(x)
