from collections import defaultdict
from typing import Any, DefaultDict, Dict, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..models import IdentityMLP, ModuloModel
from .base import BaseEvaluation
from .insertion_deletion import PredictAndAccumChange


class SensitivityN(BaseEvaluation):
    """Sensitivity-N evaluation of attribution maps.

    For every N in ``n_array``, ``num_repeats`` random subsets of N pixels are
    replaced by ``replace_val``. The Pearson correlation between the summed
    attribution of the replaced pixels and the resulting change of the model
    output is appended to the cache on every ``evaluate`` call, and
    ``summarize_step`` averages the cached correlations.

    Two decision heads are supported:

    * ``ModuloModel`` (requires ``incremental=True``): the random pixel orders
      are drawn once and reused for every later call, so the perturbed pixel
      sets grow in a nested fashion across steps. The output change is the
      accumulated count of predictions that differ from the previous step's
      prediction (from the label at step 0).
    * ``IdentityMLP`` (requires ``incremental=False``): fresh random subsets are
      drawn on every call. The output change is the drop of the softmax
      probability of ``label`` relative to the unperturbed image.

    Args:
        n_range: ``(start, stop, step)`` forwarded to ``np.arange`` to build the
            list of N values.
        num_repeats: Number of random pixel subsets drawn per N.
        replace_val: Value written into every channel of a perturbed pixel.
        incremental: Whether pixel orders are fixed across steps so that the
            perturbations accumulate. Requires ``n_range[0] == 0``.

    Raises:
        ValueError: If ``incremental`` is True and ``n_range`` does not start
            with 0.
    """

    def __init__(
        self,
        n_range: Tuple[int, int, int],
        num_repeats: int = 100,
        replace_val: float = 0.0,
        incremental: bool = False,
    ) -> None:
        super().__init__()
        if incremental and n_range[0] != 0:
            raise ValueError(
                'When incremental is True, n_range must start with 0, '
                'since we need to record the initial model output. '
                f'But got incremental: {incremental}, and n_range: {n_range}.')

        # call tolist() to convert np.int64 to python int for serialization later
        self.n_array = np.arange(*n_range).tolist()
        self.num_repeats = num_repeats
        self.incremental = incremental
        # (num_ptb_pixels, mean correlation) pairs, one per summarized step.
        self._results = []
        # Per-call correlations of the current step.
        self._cache = []
        self.replace_val = replace_val

        if self.incremental:
            # key is image_file; the defaultdict supplies the initial entry at step 0.
            # PredictAndAccumChange.prev_pred is a tensor of shape (num_repeats, 1)
            # PredictAndAccumChange.accum_change is a tensor of shape (num_repeats, 1)
            self._extra_cache: DefaultDict[str, Union[torch.Tensor,
                                                      PredictAndAccumChange]] = defaultdict(PredictAndAccumChange)
        else:
            # key is image_file, value is the original probability
            self._extra_cache: Dict[str, Union[torch.Tensor, PredictAndAccumChange]] = dict()

        # Random pixel orders of shape (num_repeats, num_pixels). In incremental
        # mode they are created by the first call to perturb() and then reused.
        self.randperm_inds: Optional[torch.Tensor] = None

    @property
    def num_steps(self) -> int:
        """Number of N values, i.e. evaluation steps."""
        return len(self.n_array)  # type: ignore

    def reset_cache(self) -> None:
        """Drop the per-call correlations collected for the current step."""
        self._cache.clear()

    @staticmethod
    def _safe_pcc(x: torch.Tensor, y: torch.Tensor) -> float:
        """Return the Pearson correlation of two equally long 1-D tensors as a float.

        ``torch.corrcoef`` yields NaN when either input has zero variance (e.g.
        when no pixel is perturbed, so every summed attribution is 0). Such
        cases are reported as a correlation of 0.0.
        """
        pcc = torch.corrcoef(torch.stack([x, y], 0))[0, 1]
        if pcc.isnan():
            return 0.0
        return pcc.item()

    def _random_permutations(self, num_pixels: int, device: torch.device) -> torch.Tensor:
        """Return ``num_repeats`` independent permutations of ``range(num_pixels)``.

        The result has shape (num_repeats, num_pixels).
        """
        return torch.stack(
            [torch.randperm(num_pixels, device=device) for _ in range(self.num_repeats)], 0)

    @torch.no_grad()
    def evaluate(
            self,
            model: nn.Module,
            img: torch.Tensor,
            label: torch.Tensor,
            attr_map: torch.Tensor,
            gt_mask: torch.Tensor,
            **kwargs: Any) -> None:
        """Evaluate one image at the current step and cache its correlation.

        Args:
            model: Model exposing a ``decision_head`` that is a ``ModuloModel``
                or an ``IdentityMLP``. It must not apply softmax itself
                (``model.softmax`` must be None when present).
            img: Single image without batch dimension, shape
                (num_channels, height, width) with 1 or 3 channels.
            label: Label of ``img``. Used as the step-0 reference prediction
                (ModuloModel) or as the class index of the probability (IdentityMLP).
            attr_map: Attribution map of shape (height, width).
            gt_mask: Not used by this evaluation.
            **kwargs: Must contain ``img_file``, used as the per-image cache key.

        Raises:
            ValueError: If ``model.softmax`` is set.
            NotImplementedError: If ``model.decision_head`` has an unsupported type.
        """
        assert hasattr(model, 'decision_head')
        if hasattr(model, 'softmax') and model.softmax is not None:
            raise ValueError(
                'model.softmax should be None, the softmax operation will '
                'be performed inside the evaluate method. Please set the '
                'softmax flag to False when building the model.')
        # single image without batch dimension.
        assert img.dim() == 3 and img.shape[0] in (1, 3)
        img_file = kwargs['img_file']

        num_ptb_pixels = self.n_array[self.current_step]
        # ptb_img shape: (num_repeats, num_channels, height, width)
        # sum_ptb_attr: (num_repeats, )
        ptb_img, sum_ptb_attr = self.perturb(img, attr_map, num_ptb_pixels)
        # pred shape: (num_repeats, 1) or (num_repeats, num_classes)
        pred = model(ptb_img)

        if isinstance(model.decision_head, ModuloModel):
            assert self.incremental
            prev_pred, accum_change = self._extra_cache[img_file]
            # If the modulo result differs from the label (step 0) or from the
            # previous prediction, the newly perturbed pixels were informative.
            if self.current_step == 0:
                prev_pred = label

            # output_change shape: (num_repeats, 1)
            output_change = (pred != prev_pred).to(torch.int64)
            accum_change += output_change
            # The current prediction becomes the reference for the next step.
            self._extra_cache[img_file] = PredictAndAccumChange(pred, accum_change)
            self._cache.append(self._safe_pcc(sum_ptb_attr, accum_change.squeeze(1)))
        elif isinstance(model.decision_head, IdentityMLP):
            assert not self.incremental
            if self.current_step == 0:
                # Probability of the unperturbed image, reused by all later steps.
                ori_prob = F.softmax(model(img.unsqueeze(0)), -1)[0, label]
                self._extra_cache[img_file] = ori_prob
            else:
                ori_prob = self._extra_cache[img_file]

            # pred are the logits, apply softmax to convert them to probabilities.
            prob = F.softmax(pred, -1)[:, label]
            prob_change = ori_prob - prob
            self._cache.append(self._safe_pcc(sum_ptb_attr, prob_change))
        else:
            raise NotImplementedError(
                f'Got unsupported type of model.decision_head: '
                f'{model.decision_head.__class__.__name__}')

    def perturb(self, img: torch.Tensor, attr_map: torch.Tensor,
                num_ptb_pixels: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Replace ``num_ptb_pixels`` random pixels of ``img`` in ``num_repeats`` copies.

        Args:
            img: Image of shape (num_channels, height, width).
            attr_map: Attribution map of shape (height, width).
            num_ptb_pixels: Number of pixels to replace in each copy.

        Returns:
            A tuple ``(ptb_img, sum_ptb_attr)`` where ``ptb_img`` has shape
            (num_repeats, num_channels, height, width) and ``sum_ptb_attr`` of
            shape (num_repeats,) holds the summed attribution of the replaced
            pixels of each copy.

        Raises:
            ValueError: In incremental mode, if ``attr_map`` has a different
                number of pixels than the map used to create the cached orders.
        """
        assert attr_map.dim() == 2 and img.dim() == 3
        num_channels = img.shape[0]
        # repeat (which copies) and flatten to (num_repeats, num_channels, height * width)
        result = img.repeat(self.num_repeats, 1, 1, 1).reshape(self.num_repeats, num_channels, -1)
        # transpose to shape (num_repeats, height * width, num_channels)
        result = result.permute((0, 2, 1))

        num_pixels = attr_map.shape[0] * attr_map.shape[1]
        if self.incremental:
            # Draw the pixel orders once and reuse them so that the perturbed sets
            # of successive steps are nested.
            if self.randperm_inds is None:
                self.randperm_inds = self._random_permutations(num_pixels, img.device)
            elif self.randperm_inds.shape[1] != num_pixels:
                raise ValueError(
                    'In incremental mode all attribution maps must have the same number of '
                    f'pixels, but got {num_pixels} pixels while the cached pixel orders '
                    f'cover {self.randperm_inds.shape[1]} pixels.')
            randperm_inds = self.randperm_inds
        else:
            randperm_inds = self._random_permutations(num_pixels, img.device)
        # ptb_inds shape: (num_repeats, num_ptb_pixels)
        ptb_inds = randperm_inds[:, :num_ptb_pixels]

        # Indexing the flattened map with a 2-D index gives (num_repeats, num_ptb_pixels).
        sum_ptb_attr = attr_map.reshape(-1)[ptb_inds].sum(1)

        # ptb_inds shape: (num_repeats, num_ptb_pixels, num_channels)
        ptb_inds = ptb_inds.unsqueeze(-1).repeat(1, 1, num_channels)
        result.scatter_(1, ptb_inds, self.replace_val)
        # transpose to shape (num_repeats, num_channels, height * width)
        result = result.permute((0, 2, 1))
        # reshape to (num_repeats, num_channels, height, width)
        result = result.reshape(self.num_repeats, *img.shape)

        return result, sum_ptb_attr

    def summarize_step(self) -> None:
        """Average the cached correlations of the current step and store the result."""
        num_ptb_pixels = self.n_array[self.current_step]
        # convert np.float64 to float64 for serialization later
        avg_pcc = np.mean(self._cache).item()
        self._results.append((num_ptb_pixels, avg_pcc))

    def summarize_total(self) -> Dict:
        """Return all step results as a serializable dict.

        Returns:
            Dict with keys ``incremental``, ``num_ptb_pixels`` (list of N values)
            and ``correlations`` (matching list of mean correlations). Both lists
            are empty when no step has been summarized.
        """
        num_ptb_pixels = [n for n, _ in self._results]
        step_results = [pcc for _, pcc in self._results]
        total_result = {'incremental': self.incremental, 'num_ptb_pixels': num_ptb_pixels, 'correlations': step_results}
        return total_result

    def visualize_result(self, total_result: Dict, save_path: str) -> None:
        """Plot correlation versus N and save the figure to ``save_path``.

        The figure is closed after saving so repeated calls do not accumulate
        open matplotlib figures.
        """
        fig, ax = plt.subplots(1, 1, figsize=(8, 6))

        num_ptb_pixels = total_result['num_ptb_pixels']
        correlations = total_result['correlations']
        ax.set_title('Sensitivity-N')
        ax.set_xlabel('N')
        ax.set_ylabel('Correlation')
        ax.plot(num_ptb_pixels, correlations)
        fig.tight_layout()
        fig.savefig(save_path, bbox_inches='tight')
        plt.close(fig)
