import pytest
import torch
import torchvision.models as models
from nesim.utils.single_conv_filter_output_extractor import (
    SingleConvFilterOutputExtractor,
    FilterLocation,
)
from nesim.utils.getting_modules import get_module_by_name

# Dotted module paths of the ResNet-18 conv layers the test is parametrized over.
layer_names = [
    "layer1.0.conv1",
    "layer2.0.conv1",
    "layer3.1.conv2",
    "layer4.1.conv2",
]

# Channel indices to probe: the first channel, the last channel (-1) and an
# arbitrary interior one.
input_channel_indices = [0, -1, 9]
output_channel_indices = [0, -1, 9]

# Batch sizes fed through the model.
batch_sizes = [1, 2, 4]

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


@pytest.mark.parametrize("layer_name", layer_names)
@pytest.mark.parametrize("input_channel_idx", input_channel_indices)
@pytest.mark.parametrize("output_channel_idx", output_channel_indices)
@pytest.mark.parametrize("batch_size", batch_sizes)
def test_extractor_zeroed_filter(
    layer_name, input_channel_idx, output_channel_idx, batch_size
):
    """Check that a zeroed conv kernel produces an all-zero extracted output.

    A randomly initialised ResNet-18 is built, the 2-D kernel at
    ``weight[output_channel_idx, input_channel_idx]`` of the conv layer named
    ``layer_name`` is overwritten with zeros, and the output for that same
    filter location is requested from ``SingleConvFilterOutputExtractor``.
    The result must be 4-D, have ``batch_size`` entries along dim 0, a single
    entry along dim 1, and contain only zeros.

    The test runs for every combination of the parametrized layer names,
    channel indices and batch sizes.
    """
    # Fixed seed so the random weights and input of a failing case can be
    # reproduced exactly; the property under test does not depend on them.
    torch.manual_seed(0)

    model = models.resnet18(weights=None).to(device)
    conv_layer = get_module_by_name(module=model, name=layer_name)

    # Overwrite the selected kernel in place. Done under no_grad so autograd
    # does not track the write to the parameter.
    with torch.no_grad():
        conv_layer.weight[output_channel_idx, input_channel_idx, :, :] = 0.0

    input_tensor = torch.randn(batch_size, 3, 224, 224).to(device)
    extractor = SingleConvFilterOutputExtractor(
        model=model,
        conv_layer_name=layer_name,
        input_tensor=input_tensor,
    )

    filter_location = FilterLocation(
        input_channel_idx=input_channel_idx,
        output_channel_idx=output_channel_idx,
    )
    single_filter_output = extractor.extract(filter_location=filter_location)

    assert (
        single_filter_output.ndim == 4
    ), f"Expected 4d output but got: {single_filter_output.ndim}"

    output_shape = tuple(single_filter_output.shape)
    assert (
        output_shape[0] == batch_size
    ), f"Expected dim 0 to equal batch size {batch_size} but got shape: {output_shape}"
    assert (
        output_shape[1] == 1
    ), f"Expected dim 1 to be 1 but got shape: {output_shape}"

    assert (
        single_filter_output.abs().sum() == 0.0
    ), "The output of a zeroed conv filter should be all zeros, but the abs sum seems to be nonzero :("
