import numpy as np
import torch
import torch.nn as nn

from sde.evaluation import AttributionPercentage


def test_reset_cache():
    """``reset_cache`` must leave a previously populated ``_cache`` empty."""
    metric = AttributionPercentage()

    # Seed the cache by hand so the test does not depend on ``evaluate``.
    metric._cache = [1, 2, 3]
    assert len(metric._cache) == 3

    metric.reset_cache()
    assert len(metric._cache) == 0


def test_evaluate():
    """``evaluate`` caches the share of attribution inside the positive mask.

    The attribution map holds 0.00, 0.01, ..., 0.99 laid out on a 10x10 grid.
    The ground-truth mask marks the bottom-right 5x5 quadrant as positive
    (``pos_val=1``). The cached value is therefore expected to equal the
    attribution summed over that quadrant divided by the total attribution.
    """
    evaluator = AttributionPercentage(pos_val=1)
    dummy_model = nn.ReLU()
    img = torch.ones((1, 10, 10))
    attr_map = torch.arange(0, 1.0, 0.01).reshape(10, 10)
    # 1/4 of the gt_mask is positive (a 5x5 quadrant out of 10x10), and 1.0
    # marks a positive pixel in the gt_mask.
    gt_mask = torch.zeros_like(attr_map)
    gt_mask[5:, 5:] = 1.0

    evaluator.evaluate(dummy_model, img, label=torch.tensor(0), attr_map=attr_map, gt_mask=gt_mask)
    assert len(evaluator._cache) == 1
    attr_percentage = evaluator._cache[0]

    expected = (attr_map[5:, 5:].sum() / attr_map.sum()).item()
    # Compare with a tolerance rather than exact ``==``. The evaluator may
    # reduce in a different order or dtype than this reference expression,
    # so bit-exact float equality is not guaranteed.
    assert np.isclose(attr_percentage, expected)


def test_summarize_step():
    """After ``summarize_step`` the cached values are reflected in ``_results``."""
    metric = AttributionPercentage()
    metric._cache = [1, 2, 3]

    metric.summarize_step()

    expected_results = np.array([1, 2, 3])
    actual_results = np.asarray(metric._results)
    assert np.allclose(actual_results, expected_results)


def test_summarize_total(tmp_path):
    """``summarize_total`` reports step results under ``'percentages'``.

    The summary is then passed to ``visualize_result`` together with an
    output path inside pytest's per-test temporary directory. The test only
    checks that this call completes without raising.
    """
    metric = AttributionPercentage()
    metric._cache = [1, 2, 3]
    metric.summarize_step()

    summary = metric.summarize_total()
    percentages = np.asarray(summary['percentages'])
    assert np.allclose(percentages, np.array([1, 2, 3]))

    out_file = tmp_path / 'vis_result.png'
    metric.visualize_result(summary, str(out_file))
