#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Exercise the streamlined MLP1 model."""

import pytest
import torch
from torch import nn

from byzantine_robust_fl.fed_nets import MLP1


@pytest.fixture
def model_dims():
    """Provide the layer sizes shared by the MLP1 tests.

    Returns a dict with ``input_dim``, ``hidden_dim`` and ``output_dim``.
    The input size 784 equals a flattened 28x28 single-channel image, the
    same size the flattening test builds explicitly.
    """
    return dict(input_dim=784, hidden_dim=128, output_dim=10)


@pytest.fixture
def model(model_dims):
    """Build a fresh MLP1 sized by the ``model_dims`` fixture."""
    # Forward exactly the three size entries as keyword arguments.
    size_keys = ("input_dim", "hidden_dim", "output_dim")
    return MLP1(**{key: model_dims[key] for key in size_keys})


def test_initialization(model: MLP1, model_dims: dict):
    """Check the layer container and the feature sizes of each linear layer.

    ``model.layers`` must be an ``nn.Sequential`` holding exactly two
    ``nn.Linear`` modules: input -> hidden, then hidden -> output.
    """
    layers = model.layers
    assert isinstance(layers, nn.Sequential), "Layers should be a Sequential container."
    assert len(layers) == 2, "Model should have two linear layers."

    # Length is already pinned to two, so unpacking cannot fail here.
    hidden_proj, output_proj = layers

    assert isinstance(hidden_proj, nn.Linear), "First layer should be a Linear layer."
    assert hidden_proj.in_features == model_dims["input_dim"], "First layer input dimension is incorrect."
    assert hidden_proj.out_features == model_dims["hidden_dim"], "First layer output dimension is incorrect."

    assert isinstance(output_proj, nn.Linear), "Second layer should be a Linear layer."
    assert output_proj.in_features == model_dims["hidden_dim"], "Second layer input dimension is incorrect."
    assert output_proj.out_features == model_dims["output_dim"], "Second layer output dimension is incorrect."


def test_forward_pass(model: MLP1, model_dims: dict):
    """Feed a batch of flat feature vectors and check output shape and dtype.

    The result must be ``(batch, output_dim)`` and remain ``torch.float32``,
    the dtype ``torch.randn`` produces under PyTorch's default settings.
    """
    n_samples = 64

    result = model(torch.randn(n_samples, model_dims["input_dim"]))

    expected_shape = (n_samples, model_dims["output_dim"])
    assert result.shape == expected_shape, f"Output shape should be {expected_shape}, but got {result.shape}."
    assert result.dtype == torch.float32, "Output tensor should have dtype float32."


def test_forward_pass_with_flattening():
    """Ensure the forward pass correctly flattens multi-dimensional inputs.

    Feeds an image-shaped batch ``(batch, channels, height, width)`` and checks
    two things:

    * the output has shape ``(batch, output_dim)``;
    * the output equals what the model produces for the same data already
      flattened to ``(batch, channels * height * width)`` in row-major order.

    A shape check alone cannot detect a flatten that scrambles feature order
    (e.g. permuting axes before reshaping). The value comparison pins down the
    actual flattening behaviour.
    """
    batch_size = 32
    channels = 1
    height = 28
    width = 28

    input_dim_flat = channels * height * width
    hidden_dim = 64
    output_dim = 10

    image_model = MLP1(input_dim=input_dim_flat, hidden_dim=hidden_dim, output_dim=output_dim)
    # Eval mode guards the value comparison below against any train-time
    # stochasticity; assumes MLP1 is an nn.Module.
    image_model.eval()

    input_tensor = torch.randn(batch_size, channels, height, width)

    with torch.no_grad():
        output = image_model(input_tensor)
        # Reference: the same samples pre-flattened with batch preserved.
        reference = image_model(input_tensor.reshape(batch_size, input_dim_flat))

    expected_shape = (batch_size, output_dim)
    assert output.shape == expected_shape, (
        "Model failed to correctly flatten the input tensor: "
        f"expected shape {expected_shape}, got {output.shape}."
    )
    torch.testing.assert_close(
        output,
        reference,
        msg="Flattened forward pass differs from feeding the pre-flattened input.",
    )
