import torch

from byzantine_robust_fl.fed_nets.mlp import MLP
from byzantine_robust_fl.utils.reproducable import set_seed

# --- Test Suite for MLP Model ---


class TestMLP:
    """Unit tests for ``MLP``: layer wiring, output shape, and seed-driven determinism."""

    def test_model_initialization(self):
        """Entries 0 and 3 of ``model.layers`` must be Linear layers wired in->hidden and hidden->out."""
        n_in, n_hidden, n_out = 784, 128, 10  # flattened 28x28 image, 10 classes (MNIST-style sizes)

        net = MLP(input_dim=n_in, hidden_dim=n_hidden, output_dim=n_out)

        # Input projection: input_dim -> hidden_dim.
        first = net.layers[0]
        assert isinstance(first, torch.nn.Linear)
        assert first.in_features == n_in
        assert first.out_features == n_hidden

        # Output projection at index 3: hidden_dim -> output_dim.
        last = net.layers[3]
        assert isinstance(last, torch.nn.Linear)
        assert last.in_features == n_hidden
        assert last.out_features == n_out

    def test_forward_pass_shape(self):
        """A forward pass over an image-shaped batch must yield a ``(batch, output_dim)`` tensor."""
        n_in, n_hidden, n_out, batch = 784, 128, 10, 64

        net = MLP(input_dim=n_in, hidden_dim=n_hidden, output_dim=n_out)

        # The batch is 4-D (batch, 1, 28, 28); 1*28*28 == n_in, so MLP is expected to
        # flatten it internally before the first Linear layer.
        images = torch.randn(batch, 1, 28, 28)
        logits = net(images)

        assert logits.shape == (batch, n_out)

    def test_reproducibility_with_set_seed(self):
        """Same seed -> matching outputs; a different seed -> outputs that differ."""
        n_in, n_hidden, n_out, batch = 50, 20, 5, 4

        def run_with_seed(seed):
            # Seed first so both the weight init and the random input draw from the seeded RNG.
            set_seed(seed)
            net = MLP(n_in, n_hidden, n_out)
            x = torch.randn(batch, n_in)
            return net(x)

        baseline = run_with_seed(42)
        repeat = run_with_seed(42)
        assert torch.allclose(baseline, repeat, atol=1e-6)

        # A different seed should change the weights and/or the input, and hence the output.
        other = run_with_seed(99)
        assert not torch.allclose(baseline, other)
