import os.path as osp
from collections import defaultdict, namedtuple
from typing import Any, DefaultDict, Dict, Optional, Set, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

from ..models import IdentityMLP, ModuloModel
from ..utils import mkdir_or_exist
from .base import BaseEvaluation

# Per-image state kept across steps: the previous prediction and the running
# count of prediction changes. Every field has a default, so the type can be
# called without arguments (e.g. as a `defaultdict` factory).
PredictAndAccumChange = namedtuple(
    'PredictAndAccumChange',
    ('prev_pred', 'accum_change'),
    defaults=(None, 0),
)


class InsertionDeletion(BaseEvaluation):
    """Insertion/Deletion evaluation metric.

    Args:
        mode: Evaluation mode. Can be "insertion" or "deletion".
        ins_del_levels: Insertion/Deletion levels. If it is an int, then the
            progressive insertion/deletion will be performed on one pixel per step.
            Otherwise, it should be a tuple of three floats, in format
            (start, end, step), then it will be converted to
            `numpy.arange(start, end, step)`. The float array indicates the ratios of
            pixels. For instance, 0.2 means inserting/deleting top-20% pixels.
        replace_val: The value for filling in the perturbed pixels.

    """

    def __init__(
        self,
        mode: str = 'deletion',
        ins_del_levels: Union[Tuple[float, float, float], int] = 50176,
        replace_val: float = 0.0,
        debug: Optional[Dict] = None,
    ) -> None:
        """Set up the evaluation mode, the perturbation levels and the buffers.

        Args:
            mode: Either ``'insertion'`` or ``'deletion'``.
            ins_del_levels: An integer ``n`` (Python or NumPy integer) yields the
                levels ``0, 1, ..., n - 1``; a ``(start, end, step)`` tuple is
                expanded with ``numpy.arange(start, end, step)``.
            replace_val: Value written into the perturbed pixels.
            debug: Optional debug configuration. When given, ``evaluate`` reads
                its ``'num_samples'`` and ``'save_dir'`` entries to save the
                perturbed images of a few samples.
        """
        assert mode in ('insertion', 'deletion')
        super().__init__()
        self.mode = mode
        # NumPy integers are accepted as well, matching the check in `perturb`.
        if isinstance(ins_del_levels, (int, np.integer)):
            # call tolist() to convert np.int64 to python int for serialization later
            self.ins_del_levels = np.arange(ins_del_levels).tolist()
            # key is the img_file, value is the tuple (prev_pred, accum_change)
            # only applicable to the case where the model has the modulo decision head
            self._extra_cache: Optional[DefaultDict[str, PredictAndAccumChange]] = defaultdict(
                PredictAndAccumChange)
        else:
            self.ins_del_levels = np.arange(*ins_del_levels).tolist()
            # None marks the ratio-based setting for the summarize methods
            self._extra_cache = None
        # one (level, step_result) entry per summarized step
        self._results = []
        # per-image values collected during the current step
        self._cache = []

        self.replace_val = replace_val
        self.debug_cfg = debug
        # img_file(s) selected for debug visualization
        self._debug_samples: Set[str] = set()

    @property
    def num_steps(self) -> int:
        return len(self.ins_del_levels)  # type: ignore

    def reset_cache(self) -> None:
        self._cache.clear()

    @torch.no_grad()
    def evaluate(
            self,
            model: nn.Module,
            img: torch.Tensor,
            label: torch.Tensor,
            attr_map: torch.Tensor,
            gt_mask: torch.Tensor,
            **kwargs: Any) -> None:
        """Perturb one image at the current level and cache the model response.

        For a ``ModuloModel`` decision head the cached value is the accumulated
        number of prediction changes of this image up to the current step. For
        an ``IdentityMLP`` decision head it is the softmax probability of
        ``label``.

        Args:
            model: Model with a ``decision_head`` attribute and without a
                built-in softmax.
            img: Single image without batch dimension; 3-D with 1 or 3 channels
                in the first dimension.
            label: Target of ``img``. Used as the reference prediction at step 0
                (modulo head) or as the class index (identity head).
            attr_map: 2-D attribution map used to rank the pixels.
            gt_mask: Not used by this metric.
            **kwargs: Must contain ``img_file`` for the modulo head, and for the
                identity head whenever debugging is enabled.

        Raises:
            ValueError: If ``model.softmax`` is set, or if a modulo head is
                evaluated although the levels were given as ratios.
            NotImplementedError: If the decision head type is not supported.
        """
        assert hasattr(model, 'decision_head')
        if hasattr(model, 'softmax') and model.softmax is not None:
            raise ValueError(
                'model.softmax should be None, the softmax operation will '
                'be performed inside the evaluate method. Please set the '
                'softmax flag to False when building the model.')
        # single image without batch dimension.
        assert img.dim() == 3 and img.shape[0] in (1, 3)

        ins_del_level = self.ins_del_levels[self.current_step]
        ptb_img = self.perturb(img, attr_map, ins_del_level)
        # model only accepts input with shape (batch_size, num_channels, height, width)
        # pred shape: (1,) or (1, num_classes)
        pred = model(ptb_img.unsqueeze(0))

        if isinstance(model.decision_head, ModuloModel):
            if self._extra_cache is None:
                # the per-image state only exists for integer (per-pixel) levels
                raise ValueError(
                    'ModuloModel decision head requires ins_del_levels to be an '
                    'int (per-pixel insertion/deletion).')
            img_file = kwargs['img_file']
            prev_pred, accum_change = self._extra_cache[img_file]
            # if the modulo result is different from the label or the previous pred,
            # then the ins/del single pixel is informative.
            if self.current_step == 0:
                prev_pred = label

            # accumulate the change for each image
            accum_change += (pred != prev_pred).to(torch.int64).item()
            # the current prediction becomes the reference of the next step
            self._extra_cache[img_file] = PredictAndAccumChange(pred, accum_change)
            self._cache.append(accum_change)

            if self.debug_cfg is not None:
                title = f'ins/del level: {ins_del_level}, accum_change: {accum_change}'
                # in black-white-modulo, the input tensor value is either 0. or 1.
                self._save_debug_img(img_file, ptb_img * 255.0, title, f'{ins_del_level}')

        elif isinstance(model.decision_head, IdentityMLP):
            # pred are the logits, apply softmax to convert them to probabilities.
            prob = F.softmax(pred, -1)[0, label].item()
            self._cache.append(prob)

            if self.debug_cfg is not None:
                title = f'ins/del level: {ins_del_level:.2f}, probability: {prob:.4f}'
                self._save_debug_img(kwargs['img_file'], ptb_img, title, f'{ins_del_level:.2f}')
        else:
            raise NotImplementedError(
                f'Got unsupported type of model.decision_head: '
                f'{model.decision_head.__class__.__name__}')

    def _save_debug_img(self, img_file: str, ptb_img: torch.Tensor, title: str, level_tag: str) -> None:
        """Save the perturbed image if ``img_file`` is a debug sample.

        The caller must ensure that ``self.debug_cfg`` is not None.

        Args:
            img_file: Image file name; it determines the sub-directory and the
                extension of the saved file.
            ptb_img: Perturbed image to be saved.
            title: Title drawn above the image.
            level_tag: Formatted level used in the file name ``ins_del_<tag>``.
        """
        # register new debug samples only until `num_samples` have been collected
        if len(self._debug_samples) < self.debug_cfg['num_samples']:
            self._debug_samples.add(img_file)
        # self._debug_samples contain the img_file(s) to be visualized
        if img_file not in self._debug_samples:
            return

        save_dir = self.debug_cfg['save_dir']
        img_file_wo_extension, extension = osp.splitext(img_file)
        mkdir_or_exist(osp.join(save_dir, img_file_wo_extension))
        save_path = osp.join(save_dir, img_file_wo_extension, f'ins_del_{level_tag}{extension}')
        self.save_ptb_img(ptb_img, title, save_path)

    def perturb(self, img: torch.Tensor, attr_map: torch.Tensor, ins_del_level: Union[float, int]) -> torch.Tensor:
        assert attr_map.dim() == 2 and img.dim() == 3
        img_shape = img.shape
        # flatten to shape (num_channel, height * width)
        result = torch.clone(img).reshape(img_shape[0], -1)

        num_pixels = attr_map.shape[0] * attr_map.shape[1]
        attr_map = attr_map.view(-1)

        if self.mode == 'deletion':
            if isinstance(ins_del_level, (int, np.integer)):
                # perturb top-k pixels
                num_ptb_pixels = ins_del_level
            else:
                # perturb top-ratio pixels
                num_ptb_pixels = int(ins_del_level * num_pixels)
            # descending: 1st is the most important pixel,
            # last one is the least important
            ptb_inds = torch.argsort(attr_map, descending=True)[:num_ptb_pixels]
        else:
            if isinstance(ins_del_level, int):
                # perturb last-(N-k) pixels
                num_ptb_pixels = num_pixels - ins_del_level
            else:
                # perturb (1 - top-ratio) pixels
                num_ptb_pixels = int((1 - ins_del_level) * num_pixels)
            # ascending: 1st is the least important pixel,
            # last one is the most important
            ptb_inds = torch.argsort(attr_map, descending=False)[:num_ptb_pixels]
        result[:, ptb_inds] = self.replace_val
        result = result.reshape(img_shape)
        return result

    def summarize_step(self) -> None:
        ins_del_level = self.ins_del_levels[self.current_step]
        if self._extra_cache is None:
            # step_result is a scalar that indicates the average predicted probability
            step_result = np.mean(self._cache)
        else:
            # step_result is an array with shape (num_instances, )
            step_result = np.asarray(self._cache)
        self._results.append((ins_del_level, step_result))

    def summarize_total(self) -> Dict:
        ins_del_levels, step_results = list(map(list, zip(*self._results)))
        if self._extra_cache is not None:
            # accum_changes shape: (num_steps, num_instances)
            accum_changes = np.asarray(step_results)
            # the total number of positive pixels is the accum_change at the last step
            total_change = accum_changes[-1]
            if self.mode == 'deletion':
                step_results = 1 - (accum_changes / (total_change + 1e-8))
            else:
                step_results = accum_changes / (total_change + 1e-8)
            step_results = step_results.mean(1).tolist()

        total_result = {'mode': self.mode, 'ins_del_levels': ins_del_levels, 'ins_del_results': step_results}

        return total_result

    def visualize_result(self, total_result: Dict, save_path: str) -> None:
        """Plot the results over the insertion/deletion levels and save the figure.

        Args:
            total_result: Dict with the keys ``'mode'``, ``'ins_del_levels'`` and
                ``'ins_del_results'``, as returned by ``summarize_total``.
            save_path: Destination path of the saved figure.
        """
        mode = total_result['mode']
        prefix = 'Deletion' if mode == 'deletion' else 'Insertion'
        ins_del_levels = total_result['ins_del_levels']
        ins_del_results = total_result['ins_del_results']

        fig, ax = plt.subplots(1, 1, figsize=(8, 6))
        try:
            ax.set_title(f'Mode: {mode}')
            ax.set_xlabel(f'{prefix} Level')
            ax.set_ylabel(f'{prefix} Result')
            ax.plot(ins_del_levels, ins_del_results)
            # operate on this figure explicitly instead of pyplot's current figure
            fig.tight_layout()
            fig.savefig(save_path, bbox_inches='tight')
        finally:
            # pyplot keeps figures alive until they are closed explicitly
            plt.close(fig)

    @staticmethod
    def save_ptb_img(ptb_img: torch.Tensor, title: str, save_path: str) -> None:
        """Render a perturbed image with a title and save it to ``save_path``.

        Args:
            ptb_img: Image tensor with the channels in the first dimension.
                A single channel is saved as grayscale ('L'), anything else is
                interpreted as 'RGB'. Values are clamped to [0, 255] and cast
                to uint8.
            title: Title drawn above the image.
            save_path: Destination path of the saved figure.
        """
        num_channels = ptb_img.shape[0]
        # clamp first: casting out-of-range floats to uint8 is not well defined
        ptb_img = ptb_img.detach().cpu().clamp(0, 255).permute((1, 2, 0)).to(torch.uint8).numpy()
        if num_channels == 1:
            # after the permute the channel axis is the last one: (H, W, 1) -> (H, W)
            ptb_img = np.squeeze(ptb_img, -1)
            ptb_img = Image.fromarray(ptb_img, mode='L')
        else:
            ptb_img = Image.fromarray(ptb_img, mode='RGB')

        fig, ax = plt.subplots(1, 1, figsize=(10, 12))
        try:
            ax.imshow(ptb_img)
            ax.set_title(title)
            ax.axis('off')
            fig.tight_layout()
            fig.savefig(save_path, bbox_inches='tight')
        finally:
            # release the figure even if saving fails
            plt.close(fig)
