import pytest
import torch
from torch import nn

from byzantine_robust_fl.fed_nets.cnn_mnist import CNNMnist


def test_cnnmnist_initialization():
    """Check the output width of ``CNNMnist`` for default and explicit class counts.

    The last element of ``model.classifier`` is inspected directly, so this test
    relies on that layer exposing ``out_features`` (e.g. an ``nn.Linear``).
    """
    # Default construction: no class count supplied, 10 outputs expected.
    default_net = CNNMnist()
    assert isinstance(default_net, nn.Module), "Model should be a PyTorch Module."
    default_head = default_net.classifier[-1]
    assert default_head.out_features == 10, "Default model should have 10 output classes."

    # Explicit construction with a non-default class count.
    requested = 5
    custom_net = CNNMnist(num_classes=requested)
    assert isinstance(custom_net, nn.Module)
    custom_head = custom_net.classifier[-1]
    assert custom_head.out_features == requested, (
        "Model should initialize with the specified number of output classes."
    )


def test_cnnmnist_forward_pass_shape():
    """Push a random MNIST-sized batch through ``CNNMnist`` and check the logits shape."""
    n_samples, n_classes = 4, 10
    net = CNNMnist(num_classes=n_classes)

    # Random stand-in for a batch of single-channel 28x28 images laid out as
    # (batch, channels, height, width).
    images = torch.randn(n_samples, 1, 28, 28)
    logits = net(images)

    # Expect one row per sample and one column per class.
    want = (n_samples, n_classes)
    assert logits.shape == want, f"Output shape should be {want}, but got {logits.shape}"


@pytest.mark.parametrize(
    "batch_size, num_classes",
    [
        (1, 10),  # Test with a single item batch
        (32, 5),  # A common batch size with different classes
        (64, 100),  # A larger batch size and more classes
    ],
)
def test_cnnmnist_forward_pass_parametrized(batch_size, num_classes):
    """Check forward-pass shapes across varying batch sizes and class counts.

    The model is exercised in inference mode so the check depends only on the
    architecture's output shape, not on training-time layer behaviour.

    Args:
        batch_size: Number of random 1x28x28 samples fed through the model.
        num_classes: Class count passed to ``CNNMnist``; expected output width.
    """
    model = CNNMnist(num_classes=num_classes)
    # Switch to eval mode: this makes any dropout deterministic. If the model
    # contains batch-statistics layers (e.g. batch norm), it also stops the
    # single-sample case from failing for reasons unrelated to output shape.
    model.eval()
    input_tensor = torch.randn(batch_size, 1, 28, 28)

    # Gradients are never used by this shape check, so skip building the graph.
    with torch.no_grad():
        output = model(input_tensor)

    # Check if the output shape is correct
    expected_shape = (batch_size, num_classes)
    assert output.shape == expected_shape, (
        f"For batch={batch_size} and classes={num_classes}, expected shape {expected_shape}, but got {output.shape}"
    )
