import torch
from captum.attr import visualization as viz
from src.explainers.utils import normalize_attr

def test_normalize_attr_equal():
    """Check ``normalize_attr`` against Captum's reference ``_normalize_attr``.

    A random ``(2, 3, 16, 16)`` batch clipped to ``[-1, 1]`` is normalized
    once with the project's batched ``normalize_attr`` and then sample by
    sample with Captum's private ``visualization._normalize_attr``. Each
    Captum result must be close (``torch.allclose``) to the matching entry
    of the batched result.
    """
    # Local, seeded generator: failures are reproducible and the global
    # torch RNG state seen by other tests is left untouched.
    gen = torch.Generator().manual_seed(0)
    batch = torch.randn(2, 3, 16, 16, generator=gen).clip(-1, 1)
    batch_t = normalize_attr(batch, 'absolute_value')
    for i, sample in enumerate(batch):
        # Captum takes a single channels-last NumPy image: permute the
        # (C, H, W) sample to (H, W, C) so reduction_axis=2 is the channel axis.
        batch_n = viz._normalize_attr(
            sample.permute(1, 2, 0).numpy(force=True),
            sign='absolute_value',
            reduction_axis=2)
        # pytest ignores a test's return value, so the comparison must be an
        # assert, or a mismatch would silently pass.
        assert torch.allclose(torch.from_numpy(batch_n), batch_t[i]), (
            f"normalize_attr disagrees with captum _normalize_attr "
            f"for sample {i}")
