import pytest
import torch
from torch import nn

from byzantine_robust_fl.fed_nets.vgg import VGG, _make_layers


def test_make_layers_structure():
    """Validate that ``_make_layers`` constructs the expected module sequence.

    For the config ``[64, "M", 128]`` each integer entry is expected to expand
    to a ``Conv2d -> BatchNorm2d -> ReLU`` triple and each ``"M"`` to a single
    ``MaxPool2d``, giving seven modules in total. Every slot's type is checked,
    not just the convolutions, so a dropped or reordered normalisation /
    activation layer is caught.
    """
    config = [64, "M", 128]
    layers = _make_layers(config)

    # Expected structure: [Conv2d, BatchNorm2d, ReLU], [MaxPool2d], [Conv2d, BatchNorm2d, ReLU]
    expected_types = [
        nn.Conv2d, nn.BatchNorm2d, nn.ReLU,  # entry 64
        nn.MaxPool2d,                        # entry "M"
        nn.Conv2d, nn.BatchNorm2d, nn.ReLU,  # entry 128
    ]
    assert len(layers) == len(expected_types), "Should create 7 modules for the given config."
    for index, (layer, expected_type) in enumerate(zip(layers, expected_types)):
        assert isinstance(layer, expected_type), (
            f"Layer {index} should be {expected_type.__name__}, got {type(layer).__name__}."
        )

    # First convolutional block: the network consumes 3-channel images.
    assert layers[0].in_channels == 3, "Initial input channels should be 3."
    assert layers[0].out_channels == 64
    assert layers[1].num_features == 64, "BatchNorm should match the first conv's output channels."

    # Second convolutional block: input channels chain from the previous conv's output.
    assert layers[4].in_channels == 64, "Input channels for the second conv layer should be 64."
    assert layers[4].out_channels == 128
    assert layers[5].num_features == 128, "BatchNorm should match the second conv's output channels."


def test_vgg_initialization_and_errors():
    """Check VGG construction for a known variant and rejection of an unknown one.

    A recognised name (``"VGG11"``) must yield an ``nn.Module`` whose last
    classifier layer produces ``num_classes`` outputs. An unrecognised name
    must raise ``ValueError`` with "not recognized" in its message.
    """
    class_count = 100
    net = VGG(vgg_name="VGG11", num_classes=class_count)
    assert isinstance(net, nn.Module)

    final_layer = net.classifier[-1]
    assert final_layer.out_features == class_count, "Classifier output features should match num_classes."

    # Unknown variant names must be rejected with a descriptive ValueError.
    with pytest.raises(ValueError) as raised:
        VGG(vgg_name="VGG_INVALID")
    error_text = str(raised.value)
    assert "not recognized" in error_text, "Should raise a ValueError for an unrecognized VGG name."


@pytest.mark.parametrize(
    "vgg_name, batch_size, num_classes",
    [
        ("VGG11", 2, 10),
        ("VGG13", 4, 20),
        ("VGG16", 1, 100),
    ],
)
def test_vgg_forward_pass_parametrized(vgg_name, batch_size, num_classes):
    """Check forward-pass shapes across supported VGG variants.

    Builds each variant with ``num_classes`` outputs, feeds a random
    CIFAR-10-sized batch of shape ``(batch_size, 3, 32, 32)``, and asserts the
    output has shape ``(batch_size, num_classes)``.
    """
    model = VGG(vgg_name=vgg_name, num_classes=num_classes)
    # eval() switches layers such as BatchNorm (and any dropout) to inference
    # mode; BatchNorm then uses running statistics, so the batch_size=1 case
    # does not rely on per-batch statistics.
    model.eval()

    # Input tensor for CIFAR-10 (3x32x32)
    input_tensor = torch.randn(batch_size, 3, 32, 32)

    # Only shapes are checked, so skip building the autograd graph.
    with torch.no_grad():
        output = model(input_tensor)

    expected_shape = (batch_size, num_classes)
    assert output.shape == expected_shape, f"For {vgg_name}, expected shape {expected_shape}, but got {output.shape}"
