import os
import sys
import importlib
from collections import OrderedDict

import pytest
import torch

# Put the project root (the parent of this test file's directory) first on
# sys.path so the project-local `models` and `utils` packages can be imported.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


@pytest.mark.parametrize(
    "net_name, expected_len, expected_last",
    [
        ("ResNet18_Pruning", 17, "layer4.1.conv2.weight"),
        ("ResNet50_Pruning", 49, "layer4.2.conv3.weight"),
    ],
)
def test_dynamic_resnet_layer_mapping(net_name, expected_len, expected_last):
    """Check ResNet conv-layer discovery and TF-IDF pruning of the last conv.

    ``get_resnet_conv_layers`` must report ``expected_len`` conv weight names,
    the last of which is ``expected_last``. Pruning that last layer at
    ``prune_percent`` must then zero out exactly that share of its output
    channels (first dimension of the weight tensor).
    """
    prune_percent = 50

    # Reload so the model classes come from a fresh import of models.Resnet
    # rather than a copy an earlier test may have patched.
    res_mod = importlib.reload(importlib.import_module("models.Resnet"))
    net_cls = getattr(res_mod, net_name)
    net = net_cls(num_classes=10)

    # Ensure real modules are available (some tests monkeypatch them).
    # utils.pruning_utils is evicted as well: if it stayed cached, the import
    # below would hand back the copy bound to whatever (possibly stubbed)
    # dependencies were present when it was first imported, making the
    # cleanup of its dependencies ineffective.
    for mod in (
        "matplotlib",
        "matplotlib.pyplot",
        "models.vggmodule",
        "utils.pruning_utils",
    ):
        sys.modules.pop(mod, None)
    import matplotlib
    matplotlib.use('Agg')  # non-interactive backend: nothing opens a window
    from utils.pruning_utils import get_resnet_conv_layers, prune_weight_resnet

    layers = get_resnet_conv_layers(net)
    assert len(layers) == expected_len
    assert layers[-1] == expected_last

    # Prepare weights and TF-IDF entries for the last conv layer. Cloning keeps
    # the model's own parameters untouched if pruning edits the dict in place.
    weights = OrderedDict((k, v.clone()) for k, v in net.state_dict().items())
    num_channels = weights[expected_last].shape[0]
    # NOTE(review): assumes prune_weight_resnet numbers convs 1-based, so
    # conv{expected_len} addresses the last layer — confirm in pruning_utils.
    tf_key = f"TF_0_conv{expected_len}_IDF_conv{expected_len}"
    # One score per output channel; the ordering semantics are up to
    # prune_weight_resnet.
    tf_idf = {tf_key: list(range(num_channels))}

    pruned = prune_weight_resnet(prune_percent, net, weights, tf_idf)
    pruned_layer = pruned[expected_last]

    # A channel counts as pruned only if every one of its weights is exactly
    # zero; checking the signed sum would also count channels whose positive
    # and negative weights happen to cancel.
    zeroed_channels = pruned_layer.flatten(start_dim=1).eq(0).all(dim=1)
    zeros = zeroed_channels.sum().item()
    assert zeros == num_channels * prune_percent // 100
