import numpy as np
import pytest
import torch
import torch.nn as nn
from tqdm import tqdm

from sde.evaluation import InsertionDeletion
from sde.models import CNNColorDetector, CnnMultiColorAccumulator, IdentityMLP, ModuloModel, SyntheticModel


class MockModel(nn.Module):
    """Mock model that maps each image to its integer pixel sum modulo ``modulo_number``.

    The module carries two heads: ``decision_head`` (a ``ModuloModel`` that
    ``forward`` never calls) and ``real_decision_head`` (the callable that
    ``forward`` actually applies).
    """

    def __init__(self, modulo_number: int) -> None:
        super().__init__()
        self.modulo_number = modulo_number

        def _remainder(value):
            # closes over the constructor argument, not over self.modulo_number
            return value % modulo_number

        # fake head: present as an attribute, but unused by the forward method
        self.decision_head = ModuloModel(modulo_number, 1000)
        # real head: the one the forward method goes through
        self.real_decision_head = _remainder

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``(sum of all non-batch entries) % modulo_number`` as float32.

        Args:
            x: A 4-dimensional tensor; the first dimension is kept and the
                remaining three are summed away.

        Returns:
            A float32 tensor of shape (batch_size, 1).
        """
        assert x.dim() == 4
        # collapse everything except the leading dimension -> shape (batch_size,)
        per_sample_sum = x.sum(dim=(1, 2, 3))
        # cast to integers before taking the remainder, then add a trailing axis
        int_sum = per_sample_sum.to(torch.int64).unsqueeze(-1)
        remainder = self.real_decision_head(int_sum)
        return remainder.to(torch.float32)


def test_num_steps():
    """Check ``num_steps`` for a tuple-valued and an integer-valued ``ins_del_levels``."""
    # (ins_del_levels argument, expected num_steps)
    cases = (
        ((0.1, 1.0, 0.1), 9),
        (1000, 1000),
    )
    for levels, expected_steps in cases:
        evaluator = InsertionDeletion(ins_del_levels=levels)
        assert evaluator.num_steps == expected_steps


def test_reset_cache():
    """Check that ``reset_cache`` leaves ``_cache`` with zero entries."""
    evaluator = InsertionDeletion()

    # populate the private cache by hand so that there is something to clear
    stale_entries = [0.1, 0.2, 0.3]
    evaluator._cache = stale_entries

    evaluator.reset_cache()
    assert len(evaluator._cache) == 0


@pytest.mark.parametrize('mode', ['deletion', 'insertion'])
@pytest.mark.parametrize('num_channels', [1, 3])
@pytest.mark.parametrize('ins_del_level', [0.0, 0.5, 1.0, 0, 8, 16])
def test_perturb(mode, num_channels, ins_del_level):
    """Check ``perturb`` on a 4x4 image for none, half and all of the 16 pixels.

    The levels 0.0/0, 0.5/8 and 1.0/16 are each given once as a float and once
    as an integer, and each pair is expected to yield the same perturbed image.
    """
    deleting = mode == 'deletion'

    single_channel = torch.linspace(0.0, 1.0, 16, dtype=torch.float32).reshape(1, 4, 4)
    img = single_channel.repeat([3, 1, 1]) if num_channels == 3 else single_channel
    # attribution decreases from the first pixel (most important) to the last
    # pixel (least important)
    attr_map = torch.linspace(1.0, 0.0, 16).reshape(4, 4)

    evaluator = InsertionDeletion(mode=mode, replace_val=0.0)
    ptb_img = evaluator.perturb(img, attr_map, ins_del_level)

    blank = torch.zeros_like(img)
    if ins_del_level == 0:
        # nothing deleted / nothing inserted (covers both 0.0 and 0)
        expected = img if deleting else blank
    elif ins_del_level in (1.0, 16):
        # everything deleted / everything inserted
        expected = blank if deleting else img
    else:
        flat = img.clone().view(num_channels, -1)
        # deletion zeroes the top-8 important pixels; insertion keeps only
        # those and zeroes the remaining 8
        zeroed = slice(None, 8) if deleting else slice(8, None)
        flat[:, zeroed] = 0.0
        expected = flat.view(img.shape)

    torch.testing.assert_close(ptb_img, expected)


@pytest.mark.parametrize('mode', ['deletion', 'insertion'])
@pytest.mark.parametrize('ins_del_levels', [(0.1, 1.0, 0.1), 3])
def test_summarize_total(mode, ins_del_levels):
    """Check ``summarize_total`` against hand-filled per-step results.

    With tuple-valued levels each step result is a scalar and is expected back
    unchanged. With integer levels each step holds one value per instance, and
    the expected summary is computed from those values divided by [4, 5] (the
    values of the final step).
    """
    evaluator = InsertionDeletion(mode, ins_del_levels)
    num_steps = evaluator.num_steps
    scalar_steps = isinstance(ins_del_levels, tuple)

    if scalar_steps:
        step_results = np.linspace(0., 1., num_steps)
    else:
        # three steps, two instances
        per_instance = np.arange(6).reshape(3, 2)
        step_results = list(per_instance)

    # bypass evaluate(): write the (level, result) pairs straight into the evaluator
    evaluator._results = [(evaluator.ins_del_levels[idx], step_results[idx]) for idx in range(num_steps)]

    total_result = evaluator.summarize_total()
    assert total_result['mode'] == evaluator.mode
    assert np.allclose(np.asarray(total_result['ins_del_levels']), evaluator.ins_del_levels)

    if scalar_steps:
        expected = step_results
    else:
        normalized = np.arange(6).reshape(3, 2) / np.array([4, 5])
        # deletion is expected as the complement of the normalized values
        expected = (1 - normalized).mean(1) if mode == 'deletion' else normalized.mean(1)
    assert np.allclose(np.asarray(total_result['ins_del_results']), expected)


@pytest.mark.parametrize('ins_del_levels', [(0.1, 1.0, 0.1), 3])
def test_summarize_step(ins_del_levels):
    """Check that ``summarize_step`` appends one ``(level, result)`` entry per step."""
    evaluator = InsertionDeletion('deletion', ins_del_levels=ins_del_levels)
    # conventional ins/del (tuple levels): the stored result is the average
    # across all the instances; for the modulo model (integer levels) the
    # per-instance values are stored without averaging
    averaged = isinstance(ins_del_levels, tuple)

    def feed(cached_values):
        # fill the cache by hand, then let the evaluator summarize it
        evaluator._cache = list(cached_values)
        evaluator.summarize_step()

    def check(index, cached_values, expected_mean):
        recorded = evaluator._results[index]
        # recorded[0] is the ins_del_level of that step
        assert recorded[0] == evaluator.ins_del_levels[index]
        if averaged:
            assert recorded[1] == expected_mean
        else:
            assert np.allclose(recorded[1], np.asarray(cached_values))

    feed((0, 1, 2))
    check(0, (0, 1, 2), 1)

    evaluator.increment_step()
    evaluator.reset_cache()

    feed((2, 3, 10))
    assert len(evaluator._results) == 2
    check(1, (2, 3, 10), 5)


@pytest.mark.parametrize('model_type', ['modulo', 'multi_color_sum'])
def test_evaluate(model_type, tmp_path):
    """Run ``evaluate`` with the mock modulo model and with a synthetic color model.

    The modulo case steps through every deletion level, summarizes the run and
    writes a visualization into ``tmp_path``. The other case performs a single
    ``evaluate`` call and only inspects the caches.
    """
    img_file = '0/001.png'

    if model_type != 'modulo':
        evaluator = InsertionDeletion('deletion', ins_del_levels=(0.1, 1.0, 0.1))
        target_colors = [
            [180, 10, 10],
            [80, 0, 200],
            [255, 255, 255],
            [255, 255, 0],
        ]
        color_detector = CNNColorDetector(target_colors)
        accumulator = CnnMultiColorAccumulator(4)
        head = IdentityMLP(4)
        model = SyntheticModel(accumulator, head, color_detector=color_detector)
        # random image and attribution map; only the cache bookkeeping is asserted
        img = torch.randint(0, 255, (3, 224, 224), dtype=torch.float32)
        attr_map = torch.rand((224, 224), dtype=torch.float32)

        evaluator.evaluate(
            model,
            img,
            label=torch.tensor(0),
            attr_map=attr_map,
            gt_mask=torch.tensor([]),
        )
        assert len(evaluator._cache) == 1
        assert evaluator._extra_cache is None
        return

    side = 224
    num_pixels = side * side
    modulo_number = 3

    evaluator = InsertionDeletion('deletion', ins_del_levels=num_pixels)
    model = MockModel(modulo_number)
    # The attribution map equals the image, so the deletion curve is expected
    # to be a descending straight line followed by a horizontal line (at zero).
    img = torch.zeros((1, side, side), dtype=torch.float32)
    img[:, :50, :50] = 1.0
    attr_map = img.squeeze(0).clone()

    # the 50x50 patch of ones
    num_pos_pixels = int(img.sum().item())
    label = num_pos_pixels % modulo_number

    for _ in tqdm(range(evaluator.num_steps)):
        evaluator.reset_cache()
        evaluator.evaluate(model, img, label=label, attr_map=attr_map, gt_mask=torch.tensor([]), img_file=img_file)
        evaluator.summarize_step()
        evaluator.increment_step()
    total_result = evaluator.summarize_total()

    # state left behind by the last step
    assert len(evaluator._cache) == 1
    assert evaluator._cache[0] == num_pos_pixels
    assert len(evaluator._extra_cache) == 1

    assert np.allclose(np.asarray(total_result['ins_del_levels']), np.arange(num_pixels))
    ins_del_results = total_result['ins_del_results']
    assert len(ins_del_results) == num_pixels
    # starting from num_pos_pixels, the deletion result will stagnate at 0
    flat_tail = ins_del_results[num_pos_pixels:]
    assert np.allclose(flat_tail, np.zeros(num_pixels - num_pos_pixels))
    evaluator.visualize_result(total_result, str(tmp_path / 'vis_results.jpeg'))
