import sys, os
import pytest

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from utils.pruning_utils import calculate_TF_values


def test_calculate_tf_values_vit():
    """Check the per-class, per-layer entries returned by calculate_TF_values.

    Two classes and two layers are supplied, each class/layer pair carrying
    a two-element vector. The expected values equal each input vector
    divided by the sum of its own entries: ``[1.0, 1.0]`` -> ``[0.5, 0.5]``
    and ``[1.0, 2.0]`` -> ``[1/3, 2/3]``. Results are looked up under keys of
    the form ``'TF_<class>_<layer name>'``.
    """
    class_ids = [0, 1]
    layer_averages = {
        'layer_0': {0: [1.0, 1.0], 1: [2.0, 2.0]},
        'layer_1': {0: [2.0, 1.0], 1: [1.0, 2.0]},
    }

    result = calculate_TF_values(class_ids, layer_averages)

    # Class 0 / layer 0: 0.5 is exactly representable, so compare exactly.
    assert result['TF_0_layer_0'] == [0.5, 0.5]

    # Class 1 / layer 1: thirds are not exact in binary floating point,
    # so compare with a tolerance.
    expected_class_1_layer_1 = (1.0 / 3, 2.0 / 3)
    for position, expected in enumerate(expected_class_1_layer_1):
        assert result['TF_1_layer_1'][position] == pytest.approx(expected)
