#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Exercise the MLP regression model."""

import pytest
import torch
from torch import nn

from byzantine_robust_fl.fed_nets import MLPRegression

# --- Pytest Fixtures ---
# pytest injects fixtures into test functions by matching parameter names. Both
# fixtures below use the default function scope, so every test gets a fresh model.


@pytest.fixture
def model_dims():
    """Provide the layer sizes used to build the regression MLP under test.

    Returns:
        dict: ``input_dim`` (number of input features per sample), ``hidden_dim``
        (width of the hidden layer) and ``output_dim`` (a single regression target).
    """
    dims = dict(
        input_dim=50,
        hidden_dim=100,
        output_dim=1,  # one scalar prediction per sample
    )
    return dims


@pytest.fixture
def model(model_dims):
    """Build a fresh MLPRegression sized according to the ``model_dims`` fixture."""
    n_in, n_hidden, n_out = (
        model_dims["input_dim"],
        model_dims["hidden_dim"],
        model_dims["output_dim"],
    )
    return MLPRegression(input_dim=n_in, hidden_dim=n_hidden, output_dim=n_out)


# --- Test Cases ---


def test_initialization(model: MLPRegression, model_dims: dict):
    """Check that ``model.layers`` is a two-element ``nn.Sequential`` of correctly sized linears.

    The first layer must map ``input_dim -> hidden_dim`` and the second
    ``hidden_dim -> output_dim``, using the sizes from the ``model_dims`` fixture.
    """
    stack = model.layers
    assert isinstance(stack, nn.Sequential), "Model layers should be wrapped in nn.Sequential."
    assert len(stack) == 2, "Model should have exactly two linear layers."

    first, second = stack[0], stack[1]

    # input -> hidden
    assert isinstance(first, nn.Linear), "The first layer should be of type nn.Linear."
    assert first.in_features == model_dims["input_dim"], "Input dimension of the first layer is incorrect."
    assert first.out_features == model_dims["hidden_dim"], "Output dimension of the first layer is incorrect."

    # hidden -> output
    assert isinstance(second, nn.Linear), "The second layer should be of type nn.Linear."
    assert second.in_features == model_dims["hidden_dim"], "Input dimension of the second layer is incorrect."
    assert second.out_features == model_dims["output_dim"], "Output dimension of the second layer is incorrect."


def test_forward_pass(model: MLPRegression, model_dims: dict):
    """Feed a flat ``(batch, input_dim)`` batch through the model and check the result's shape and dtype."""
    n_samples = 32
    features = torch.randn(n_samples, model_dims["input_dim"])
    output = model(features)

    expected_shape = (n_samples, model_dims["output_dim"])
    assert output.shape == expected_shape, f"Output shape should be {expected_shape}, but got {output.shape}."
    # torch.randn defaults to float32, so the model should not change precision.
    assert output.dtype == torch.float32, "Output tensor's dtype should be float32."


def test_forward_pass_with_input_flattening(model: MLPRegression, model_dims: dict):
    """Ensure the forward pass correctly flattens multi-dimensional inputs.

    Each sample is shaped as a 2D feature map whose element count equals
    ``model_dims["input_dim"]``. The model must accept this 3D batch, return a
    ``(batch_size, output_dim)`` tensor, and produce the same values it gives for
    the equivalent pre-flattened ``(batch_size, input_dim)`` batch.
    """
    import math

    batch_size = 16
    input_dim = model_dims["input_dim"]

    # Derive the per-sample feature-map shape from input_dim instead of
    # hard-coding 5x10, so the test stays valid if the fixture changes.
    # Use the largest divisor not exceeding sqrt(input_dim) (50 -> 5x10); a
    # prime input_dim falls back to a 1xN map, which is still a 3D batch.
    rows = next((r for r in range(math.isqrt(input_dim), 0, -1) if input_dim % r == 0), 1)
    cols = input_dim // rows
    multi_dim_input = torch.randn(batch_size, rows, cols)

    # eval() + no_grad() make the two forward passes below directly comparable.
    # The model fixture is function-scoped, so this mode change cannot leak.
    model.eval()
    with torch.no_grad():
        output = model(multi_dim_input)
        flat_output = model(multi_dim_input.reshape(batch_size, -1))

    expected_shape = (batch_size, model_dims["output_dim"])
    assert output.shape == expected_shape, (
        "Model failed to correctly handle and flatten the multi-dimensional input: "
        f"expected shape {expected_shape}, but got {output.shape}."
    )

    # A shape check alone cannot detect a wrong flattening order, so compare
    # against the explicitly row-major-flattened batch.
    torch.testing.assert_close(output, flat_output)
