import pytest
import torch
from torch import nn

from byzantine_robust_fl.fed_nets.alexnet_cifar import AlexNetCifar


def test_alexnetcifar_initialization():
    """Check that ``AlexNetCifar`` builds an ``nn.Module`` with the right head size.

    The final layer of ``model.classifier`` must expose ``out_features`` equal
    to the number of classes: 10 when the constructor default is used, and the
    explicit ``num_classes`` value otherwise.
    """
    # Each case: constructor keyword arguments, expected head width, failure message.
    cases = (
        ({}, 10, "Default model should have 10 output classes."),
        ({"num_classes": 100}, 100, "Custom model should have the specified number of output classes."),
    )
    for ctor_kwargs, expected_out, failure_msg in cases:
        net = AlexNetCifar(**ctor_kwargs)
        assert isinstance(net, nn.Module)
        assert net.classifier[-1].out_features == expected_out, failure_msg


def test_alexnetcifar_forward_pass_shape():
    """Validate that the forward pass yields outputs with the expected shape.

    The model is switched to evaluation mode and run under ``torch.no_grad()``.
    The test only inspects the output, so recording an autograd graph is wasted
    work. Eval mode also keeps any train-only layers (e.g. dropout, if the
    architecture has them) from making the pass nondeterministic.
    """
    batch_size = 4
    num_classes = 10
    model = AlexNetCifar(num_classes=num_classes)
    model.eval()

    # Dummy input for a batch of CIFAR-10 images (3x32x32)
    dummy_input = torch.randn(batch_size, 3, 32, 32)

    # Inference-only forward pass: no gradient tracking needed for a shape check.
    with torch.no_grad():
        output = model(dummy_input)

    # Expected shape: (batch_size, num_classes)
    expected_shape = (batch_size, num_classes)
    assert output.shape == expected_shape, f"Output shape should be {expected_shape}, but got {output.shape}"
    # A freshly initialised network fed Gaussian noise should never emit NaN/Inf logits.
    assert torch.isfinite(output).all(), "Output contains NaN or Inf values."


@pytest.mark.parametrize(
    "batch_size, num_classes",
    [
        (1, 10),  # Single item batch
        (16, 20),  # A standard batch size
        (32, 1000),  # Larger batch and ImageNet-like classes
    ],
)
def test_alexnetcifar_forward_pass_parametrized(batch_size, num_classes):
    """Check forward-pass shapes across varying batch sizes and class counts.

    The model runs in evaluation mode under ``torch.no_grad()``. This matters
    for the ``batch_size=1`` case: if the architecture contains batch-statistics
    layers (e.g. BatchNorm), training mode would reject a single-sample batch.
    Eval mode also avoids building an unused autograd graph.
    """
    model = AlexNetCifar(num_classes=num_classes)
    model.eval()
    input_tensor = torch.randn(batch_size, 3, 32, 32)

    # Inference-only forward pass: the test inspects the output, not gradients.
    with torch.no_grad():
        output = model(input_tensor)

    # Check if the output shape is correct
    expected_shape = (batch_size, num_classes)
    assert output.shape == expected_shape, (
        f"For batch={batch_size} and classes={num_classes}, expected shape {expected_shape}, but got {output.shape}"
    )
    # Sanity check: random weights on Gaussian noise should yield finite logits.
    assert torch.isfinite(output).all(), (
        f"For batch={batch_size} and classes={num_classes}, output contains NaN or Inf values."
    )
