import pytest
import torch
from torch import nn

from byzantine_robust_fl.fed_nets.cnn_cifar import CNNCifar


def test_cnncifar_initialization():
    """Check that ``CNNCifar`` builds an ``nn.Module`` with the requested output width.

    Two construction paths are exercised, in order:

    * no arguments, where the final classifier layer is expected to emit 10 classes;
    * an explicit ``num_classes=20``, which must be reflected in that layer.
    """
    # (constructor kwargs, expected out_features of the last classifier layer, failure message)
    cases = (
        ({}, 10, "Default model should have 10 output classes."),
        ({"num_classes": 20}, 20, "Custom model should have the specified number of output classes."),
    )
    for ctor_kwargs, expected_classes, failure_msg in cases:
        net = CNNCifar(**ctor_kwargs)
        assert isinstance(net, nn.Module)
        # The last entry of ``classifier`` is treated as the output projection layer.
        head = net.classifier[-1]
        assert head.out_features == expected_classes, failure_msg


def test_cnncifar_forward_pass_shape():
    """Push one random CIFAR-sized batch through ``CNNCifar`` and check the output shape.

    A batch of 8 images, each 3x32x32, fed to a 10-class model must produce an
    output of shape ``(8, 10)``.
    """
    n_images, n_classes = 8, 10
    net = CNNCifar(num_classes=n_classes)

    # CIFAR images are 3x32x32 (channels, height, width).
    images = torch.randn(n_images, 3, 32, 32)
    logits = net(images)

    want = (n_images, n_classes)
    got = logits.shape
    assert got == want, f"Output shape should be {want}, but got {got}"


@pytest.mark.parametrize(
    "batch_size, num_classes",
    [
        (1, 10),  # Single item batch
        (16, 2),  # Common batch size, binary classification
        (32, 100),  # Larger batch size, more classes
    ],
)
def test_cnncifar_forward_pass_parametrized(batch_size, num_classes):
    """Check forward-pass shapes across varying batch sizes and class counts.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    This test only checks the output shape, so no autograd graph is needed.
    Train-mode behaviour is also not what is being tested here. For example,
    if the model contains BatchNorm layers (not visible from this file), they
    reject single-sample batches while training. That would make the
    ``batch_size=1`` case fail for a reason unrelated to the shape.

    Args:
        batch_size: Number of random input images in the batch.
        num_classes: Number of output classes the model is built with.
    """
    model = CNNCifar(num_classes=num_classes)
    model.eval()

    # CIFAR inputs: 3 colour channels, 32x32 pixels.
    input_tensor = torch.randn(batch_size, 3, 32, 32)

    # Inference-only forward pass: no gradients are needed for a shape check.
    with torch.no_grad():
        output = model(input_tensor)

    expected_shape = (batch_size, num_classes)
    assert output.shape == expected_shape, (
        f"For batch={batch_size} and classes={num_classes}, expected shape {expected_shape}, but got {output.shape}"
    )
