import os
import sys
from collections import OrderedDict
from types import SimpleNamespace

import torch

# Prepend the parent of this file's directory to sys.path so the `utils`
# package imported below can be resolved (presumably the repository root --
# confirm if the test layout changes).
_PARENT_DIR = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
sys.path.insert(0, _PARENT_DIR)

from utils.pruning_utils import prune_weight_vit


def test_vit_pruning_layers():
    """Check that ``prune_weight_vit`` zeroes the expected number of rows.

    Builds a minimal stand-in for a one-layer ViT: a ``config`` namespace with
    ``num_hidden_layers=1`` and a weight dict holding the patch-embedding
    projection and the first encoder layer's attention query matrix. It then
    prunes at ``percent`` and asserts that exactly ``rows * percent // 100``
    rows of each tensor are entirely zero.

    A TF_IDF entry for a layer the model does not have is also supplied, so
    the call must complete without raising ``KeyError``.
    """
    percent = 50
    # Local, seeded generator: reproducible fixtures without touching the
    # global torch RNG state that other tests may rely on.
    gen = torch.Generator().manual_seed(0)

    patch_key = 'vit.embeddings.patch_embeddings.projection.weight'
    attn_key = 'vit.encoder.layer.0.attention.attention.query.weight'

    net = SimpleNamespace(config=SimpleNamespace(num_hidden_layers=1))
    weights = OrderedDict({
        patch_key: torch.randn(8, 4, generator=gen),
        attn_key: torch.randn(8, 8, generator=gen),
    })

    num_patch = weights[patch_key].shape[0]
    num_attn = weights[attn_key].shape[0]

    # Each entry lists every row index of the corresponding tensor.
    # NOTE(review): assumes prune_weight_vit maps TF layer 0 to the patch
    # embedding and TF layer 1 to encoder layer 0 -- confirm in pruning_utils.
    tf_idf = {
        'TF_0_layer_0_IDF_layer_0': list(range(num_patch)),
        'TF_0_layer_1_IDF_layer_1': list(range(num_attn)),
        # Nonexistent layer to ensure guards prevent KeyErrors
        'TF_0_layer_2_IDF_layer_2': list(range(num_attn)),
    }

    pruned = prune_weight_vit(percent, net, weights, tf_idf)

    # Count rows whose elements are *all* exactly zero. Comparing the row sum
    # to zero would also count a surviving row whose values cancel out.
    zeros_patch = (pruned[patch_key] == 0).all(dim=1).sum().item()
    zeros_attn = (pruned[attn_key] == 0).all(dim=1).sum().item()

    assert zeros_patch == num_patch * percent // 100
    assert zeros_attn == num_attn * percent // 100
