import pytest
import torch
from torch import nn

from byzantine_robust_fl.fed_nets.resnet import ResNet


def test_resnet_initialization_and_errors():
    """Check that ResNet builds with the requested head size and rejects unknown variant names.

    A valid ``resnet_name`` must produce an ``nn.Module`` whose ``fc`` layer has
    ``num_classes`` output features. An unrecognized name must raise ``ValueError``
    whose message contains "not recognized".
    """
    model = ResNet(resnet_name="ResNet18", num_classes=100)
    assert isinstance(model, nn.Module)
    assert model.fc.out_features == 100, "FC's output dimension should be the same as that of num_classes."

    # `match` is applied with re.search to str(exception); the pattern contains no regex
    # metacharacters, so this is a plain substring check with a clear pytest failure message.
    with pytest.raises(ValueError, match="not recognized"):
        ResNet(resnet_name="ResNet_INVALID")


@pytest.mark.parametrize(
    ("resnet_name", "batch_size", "num_classes"),
    [
        pytest.param("ResNet18", 2, 10),
        pytest.param("ResNet34", 4, 20),
        pytest.param("ResNet50", 1, 100),
    ],
)
def test_resnet_forward_pass_parametrized(resnet_name, batch_size, num_classes):
    """Run a 3x32x32 batch through each ResNet variant and check the logits shape.

    The network output must have shape ``(batch_size, num_classes)``.
    """
    # NOTE(review): relies on ResNet's default constructor options for 32x32 input
    # (previously noted as "small_input=True by default") — confirm against ResNet.
    net = ResNet(resnet_name=resnet_name, num_classes=num_classes)
    net.eval()

    inputs = torch.randn(batch_size, 3, 32, 32)
    with torch.inference_mode():
        logits = net(inputs)

    expected_shape = (batch_size, num_classes)
    assert logits.shape == expected_shape, f"{resnet_name}: expected output shape {expected_shape}, {logits.shape} in reality"


def test_resnet_feature_map_shapes_cifar():
    """Check the output shape of each residual stage of ResNet18 on 32x32 input.

    Forward hooks on ``layer1``..``layer4`` record each stage's output shape during
    one forward pass. The expected table shows that ``layer1`` keeps the 32x32
    spatial size and each later stage halves it while doubling the channel count.
    """
    batch_size = 2
    model = ResNet(resnet_name="ResNet18", num_classes=10, use_bn=True)
    model.eval()

    expected = {
        "layer1": (batch_size, 64, 32, 32),
        "layer2": (batch_size, 128, 16, 16),
        "layer3": (batch_size, 256, 8, 8),
        "layer4": (batch_size, 512, 4, 4),
    }
    feats = {}

    def save_shape(name):
        def hook(_m, _inp, out):
            feats[name] = out.shape

        return hook

    handles = [getattr(model, name).register_forward_hook(save_shape(name)) for name in expected]
    try:
        x = torch.randn(batch_size, 3, 32, 32)
        with torch.inference_mode():
            _ = model(x)
    finally:
        # Remove the hooks even if the forward pass raises, so they never outlive this test.
        for handle in handles:
            handle.remove()

    for name, shape in expected.items():
        assert name in feats, f"{name}'s forward hook never fired"
        assert feats[name] == shape, f"{name}'s shape should be {shape}, getting {feats[name]}"


def test_resnet_no_batchnorm_path():
    """Check that ResNet34 built with ``use_bn=False`` still yields ``(batch, num_classes)`` logits."""
    batch, classes = 3, 10
    net = ResNet(resnet_name="ResNet34", num_classes=classes, use_bn=False)
    net.eval()

    with torch.inference_mode():
        logits = net(torch.randn(batch, 3, 32, 32))

    assert logits.shape == (batch, classes)
