import numpy as np
import pytest
import torch
import torch.nn as nn
from tqdm import tqdm

from sde.evaluation import SensitivityN
from sde.models import CNNColorDetector, CnnMultiColorAccumulator, IdentityMLP, ModuloModel, SyntheticModel


class MockModel(nn.Module):
    """Toy model whose output is each sample's pixel sum modulo ``modulo_number``.

    ``decision_head`` is a ``ModuloModel`` that is attached to the module but never
    called by :meth:`forward` (presumably the evaluator inspects it — confirm in
    ``SensitivityN``). The modulo actually applied in :meth:`forward` comes from
    ``real_decision_head``.
    """

    def __init__(self, modulo_number: int) -> None:
        super().__init__()
        self.modulo_number = modulo_number
        # placeholder head: stored on the module, never invoked by forward()
        self.decision_head = ModuloModel(modulo_number, 1000)
        # the head forward() really applies
        self.real_decision_head = lambda values: values % modulo_number

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``int64(sum over dims 1..3) % modulo_number`` as a float32 tensor.

        Args:
            x: a 4-D tensor (asserted below); dims 1-3 are summed per leading index.

        Returns:
            Tensor of shape ``(x.shape[0], 1)`` and dtype ``float32``.
        """
        assert x.dim() == 4
        pixel_sums = torch.sum(x, dim=(1, 2, 3))
        # .to(int64) truncates fractional sums before the modulo is taken
        counts = pixel_sums.to(torch.int64)[:, None]
        remainders = self.real_decision_head(counts)
        return remainders.to(torch.float32)


def test_num_steps():
    """A (start=0, stop=10, step=1) schedule must produce ten evaluation steps."""
    sensitivity_n = SensitivityN((0, 10, 1), num_repeats=10)
    expected_steps = 10
    assert sensitivity_n.num_steps == expected_steps


def test_reset_cache():
    """reset_cache() must discard everything accumulated in ``_cache``."""
    sensitivity_n = SensitivityN((0, 10, 1), num_repeats=10)
    sensitivity_n._cache = list(range(1, 4))
    sensitivity_n.reset_cache()
    remaining = len(sensitivity_n._cache)
    assert remaining == 0


@pytest.mark.parametrize('incremental', [True, False])
@pytest.mark.parametrize('num_channels', [1, 3])
def test_perturb(incremental, num_channels):
    """Check perturb() output shapes and its reported sum of perturbed attribution.

    The image holds the values 0..24 on a 5x5 grid (copied across channels), and the
    attribution map is that same grid divided by 25. Any pixel mass removed from the
    image, divided by 25 and by the channel count, must therefore equal the
    attribution sum that perturb() reports.
    """
    side = 5
    num_pixels = side * side
    num_repeats = 2
    evaluator = SensitivityN((0, num_pixels, 1), num_repeats=num_repeats, incremental=incremental)

    ramp = torch.arange(num_pixels, dtype=torch.float32).reshape(side, side)
    img = ramp.unsqueeze(0).repeat(num_channels, 1, 1)
    attr_map = ramp / float(num_pixels)

    for step in range(evaluator.num_steps):
        img_ptb, sum_ptb_attr = evaluator.perturb(img, attr_map, num_ptb_pixels=evaluator.n_array[step])
        assert img_ptb.shape == (num_repeats, num_channels, side, side)
        assert sum_ptb_attr.shape == (num_repeats,)

        # replace_val is 0.0, so (img - img_ptb) summed over C/H/W is
        # 25 * num_channels * (sum of perturbed attribution)
        removed_mass = (img - img_ptb).sum([1, 2, 3])
        torch.testing.assert_close(sum_ptb_attr, removed_mass / num_pixels / num_channels)


def _run_all_steps(evaluator, model, img, label, attr_map, img_file):
    """Evaluate a single sample at every step of ``evaluator``'s schedule; return summarize_total()."""
    for _ in tqdm(range(evaluator.num_steps)):
        evaluator.reset_cache()
        evaluator.evaluate(model, img, label=label, attr_map=attr_map, gt_mask=torch.tensor([]), img_file=img_file)
        evaluator.summarize_step()
        evaluator.increment_step()
    return evaluator.summarize_total()


def _check_evaluate_modulo(img_file, tmp_path):
    """MockModel with ``attr_map == img``: correlations must be 0.0 at n=0 and 1.0 afterwards."""
    evaluator = SensitivityN((0, 10, 1), num_repeats=100, incremental=True)
    modulo_number = 3
    model = MockModel(modulo_number)
    # set attr_map the same as img, so that the correlation between the output
    # change and the sum of perturbed pixel attribution should be 1.
    img = torch.randint(0, 2, (1, 224, 224), dtype=torch.float32)
    attr_map = img.clone().squeeze(0)

    num_pos_pixels = img.sum().to(torch.int64)
    label = num_pos_pixels % modulo_number

    total_result = _run_all_steps(evaluator, model, img, label, attr_map, img_file)

    # _cache should only store the pcc of all dataset samples at the latest step,
    # and the dataset only contains one sample in this test.
    assert len(evaluator._cache) == 1
    assert np.isclose(evaluator._cache[0], 1.0, atol=1e-6)
    # _extra_cache should contain only one key, which is the img_file
    assert len(evaluator._extra_cache) == 1

    assert np.allclose(np.asarray(total_result['num_ptb_pixels']), np.arange(10))
    # the first correlation should be 0.0 since the accum_changes of all repeatedly
    # perturbed images are 0, so that their variance is 0. In this case, the pcc
    # computed by pytorch is nan, and we manually define the pcc as 0.0.
    assert total_result['correlations'][0] == 0.0
    # since attr_map is the same as img, the correlation should be 1.
    assert np.allclose(np.asarray(total_result['correlations'][1:]), np.ones(9), atol=1e-5)
    evaluator.visualize_result(total_result, str(tmp_path / 'vis_results.jpeg'))


def _check_evaluate_multi_color_sum(img_file):
    """SyntheticModel on a random RGB image: only cache sizes and result lengths are checked."""
    evaluator = SensitivityN((0, 10, 1), num_repeats=100, incremental=False)
    color_detector = CNNColorDetector([
        [180, 10, 10],
        [80, 0, 200],
        [255, 255, 255],
        [255, 255, 0],
    ])
    accumulator = CnnMultiColorAccumulator(4)
    head = IdentityMLP(4)
    model = SyntheticModel(accumulator, head, color_detector=color_detector)
    img = torch.randint(0, 255, (3, 224, 224), dtype=torch.float32)
    attr_map = torch.rand((224, 224), dtype=torch.float32)

    total_result = _run_all_steps(evaluator, model, img, torch.tensor(0), attr_map, img_file)

    # _extra_cache stores the original probability of each dataset sample, and the
    # dataset in this test only contains one sample
    assert len(evaluator._extra_cache) == 1
    assert len(evaluator._cache) == 1
    assert len(total_result['num_ptb_pixels']) == 10
    assert len(total_result['correlations']) == 10


@pytest.mark.parametrize('model_type', ['modulo', 'multi_color_sum'])
def test_evaluate(model_type, tmp_path):
    """Run SensitivityN end to end on a single synthetic sample for each model type."""
    # Both branches draw random inputs, and the evaluator draws random perturbations
    # (library not visible here — seed both torch and numpy). Seeding makes any
    # failure reproducible instead of flaky.
    torch.manual_seed(0)
    np.random.seed(0)
    img_file = '0/001.png'

    if model_type == 'modulo':
        _check_evaluate_modulo(img_file, tmp_path)
    else:
        _check_evaluate_multi_color_sum(img_file)


def test_summarize_step():
    """summarize_step() appends one (num_ptb_pixels, mean pcc) record for the current step."""
    evaluator = SensitivityN((1, 10, 1), num_repeats=10, incremental=False)
    evaluator._cache = list(range(1, 6))
    evaluator.summarize_step()

    assert len(evaluator._results) == 1
    first_step = evaluator._results[0]
    # index 0: perturbed-pixel count of the first step (schedule starts at 1)
    assert first_step[0] == 1
    # index 1: mean of the cached pcc values 1..5
    assert first_step[1] == 3.0


def test_summarize_total():
    """summarize_total() splits per-step (n, pcc) records into aligned result sequences."""
    evaluator = SensitivityN((1, 4, 1), num_repeats=10, incremental=False)
    evaluator._results = [(n, float(n)) for n in (1, 2, 3)]
    total_result = evaluator.summarize_total()

    assert total_result['incremental'] == evaluator.incremental
    expected = np.arange(1, 4, 1)
    for key in ('num_ptb_pixels', 'correlations'):
        assert np.allclose(np.asarray(total_result[key]), expected)
