#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Define a minimalist two-layer multilayer perceptron (MLP) architecture."""

import torch
from torch import nn


class MLP1(nn.Module):
    """Implement a linear two-layer MLP that omits intermediate activations.

    The architecture is composed of a linear input projection followed by a
    linear classifier head, producing raw logits suitable for calibration or
    testing scenarios.

    Because no non-linearity sits between the two ``nn.Linear`` layers, the
    composition is itself a single affine map from ``input_dim`` to
    ``output_dim`` features; the hidden layer changes the parameterization,
    not the expressive power.

    Attributes:
        layers: Sequential container holding the model layers.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        """Initializes the MLP1 model.

        Args:
            input_dim: The dimension of the input features (after flattening
                every dimension except the first).
            hidden_dim: The number of neurons in the hidden layer.
            output_dim: The dimension of the output (e.g., number of classes).
        """
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Execute the forward pass after flattening the input tensor.

        Args:
            x: Input tensor whose first dimension is treated as the batch
                dimension; all remaining dimensions are flattened together and
                must multiply to ``input_dim``.

        Returns:
            Tensor of shape ``[batch_size, output_dim]`` with the raw outputs
            of the final linear layer.
        """
        # Flatten to [batch_size, num_features]. ``reshape`` (unlike ``view``)
        # also accepts non-contiguous inputs, e.g. permuted/transposed batches,
        # copying only when a zero-copy view is impossible.
        x = x.reshape(x.size(0), -1)
        return self.layers(x)
