import logging
import os
import os.path as osp
from copy import deepcopy
from functools import partial

import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import torch
from skimage.metrics import structural_similarity

from sde.attribution_methods import AttributionGenerator


def apply_weight_init(mode):
    """Build a callable suitable for ``torch.nn.Module.apply``.

    The returned callable re-initialises every visited module through
    ``weights_init`` using the given ``mode``.
    """
    init_fn = partial(weights_init, mode=mode)
    return init_fn


def weights_init(m, mode):
    """Re-initialise the parameters of ``m`` in place.

    Only ``Conv2d`` and ``Linear`` modules are touched; any other module is
    left unchanged.

    * ``Conv2d`` weights use Xavier-uniform initialisation, whatever ``mode`` is.
    * ``Linear`` weights depend on ``mode``:
      ``"normal"`` -> normal(mean=0, std=0.05);
      ``"sparse"`` -> ``torch.nn.init.sparse_`` with sparsity 0.5, std 0.01;
      ``"biased"`` -> ``torch.nn.init.uniform_`` with its default bounds.
    * Biases, when the layer has one, are set to zero.

    Args:
        m (torch.nn.Module): module to re-initialise.
        mode (str): initialisation scheme for ``Linear`` weights.

    Raises:
        NotImplementedError: if ``m`` is a ``Linear`` layer and ``mode`` is
            not one of the supported values.
    """
    if isinstance(m, torch.nn.Conv2d):
        torch.nn.init.xavier_uniform_(m.weight)
    elif isinstance(m, torch.nn.Linear):
        if mode == "normal":
            torch.nn.init.normal_(m.weight, mean=0, std=0.05)
        elif mode == "sparse":
            torch.nn.init.sparse_(m.weight, sparsity=0.5, std=0.01)
        elif mode == "biased":
            torch.nn.init.uniform_(m.weight)
        else:
            raise NotImplementedError(f"Unsupported weight init mode: {mode!r}")
    else:
        # Any other module type is left untouched.
        return
    # Layers created with bias=False have ``m.bias is None``.
    if m.bias is not None:
        torch.nn.init.zeros_(m.bias)


def get_module(model, module):
    r"""Return a specific layer of a model, looked up by name.

    Shameless (shameful?) copy from `TorchRay`.

    :attr:`module` is either the name of a module (as given by the
    :func:`named_modules` function for :class:`torch.nn.Module` objects) or
    a :class:`torch.nn.Module` object. A :class:`torch.nn.Module` is returned
    unchanged. For a str, the empty string selects :attr:`model` itself;
    any other name is searched among the named sub-modules of :attr:`model`.

    Args:
        model (:class:`torch.nn.Module`): model in which to search for layer.
        module (str or :class:`torch.nn.Module`): name of layer (str) or the
            layer itself (:class:`torch.nn.Module`).

    Returns:
        :class:`torch.nn.Module`: the requested layer, or ``None`` when no
            sub-module carries that name.
    """
    if isinstance(module, torch.nn.Module):
        return module

    assert isinstance(module, str)
    if not module:
        return model

    matches = (sub for sub_name, sub in model.named_modules() if sub_name == module)
    return next(matches, None)


def perturb_model(model, layers, mode="sparse"):
    """Re-initialise the given layers of ``model`` in place.

    Every entry of ``layers`` is resolved with ``get_module`` before any
    weight is modified, so an unknown name leaves the model untouched.
    Each resolved module (and all of its sub-modules) is then passed to
    ``weights_init`` via ``torch.nn.Module.apply``.

    Args:
        model (torch.nn.Module): model to perturb.
        layers (iterable): layer names (str) or module objects.
        mode (str): initialisation mode forwarded to ``weights_init``.

    Raises:
        ValueError: if a layer name does not match any module of ``model``.
    """
    modules = []
    for layer in layers:
        module = get_module(model, layer)
        if module is None:
            raise ValueError(f"Layer {layer!r} not found in model")
        modules.append(module)

    init_fn = apply_weight_init(mode=mode)
    for module in modules:
        module.apply(init_fn)


class SanityCheck:
    """Model-parameter randomisation sanity check for attribution methods.

    Leaf layers of ``model`` are re-initialised, progressively from the
    output side towards the input. The attribution map of the perturbed
    model is then compared with the original one via SSIM.
    """

    def __init__(self, model, device, perturb_mode='biased'):
        self.model = model
        self.device = device
        # Snapshot of the original weights, restored by `reload()`.
        self.ori_state_dict = deepcopy(self.model.state_dict())
        # Leaf layer names of the classifier, in `named_modules()` order.
        self.model_layers = self.filter_names([name for name, _ in self.model.named_modules()])
        self.logger = logging.getLogger()
        # Initialisation mode forwarded to `perturb_model`.
        self.mode = perturb_mode

    def reload(self):
        """Restore the original weights and prepare the model for inference.

        Moves the model to `self.device`, switches it to eval mode and
        disables gradients for all of its parameters.
        """
        self.model.load_state_dict(self.ori_state_dict)
        self.model.to(self.device)
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad = False

    def evaluate(  # noqa
            self,
            attr_map,
            input_tensor,
            target,
            attribution_method_cfg,
            perturb_layers,
            verbose=False,
            save_dir=None,
            save_heatmaps=False):
        """Apply the sanity check to the attribution method with a single image.

        Given `perturb_layers = ['a', 'b', 'c']` there are
        `len(perturb_layers)` perturbation settings. In the first, 'a' and
        every leaf layer after it (towards the output) are perturbed. In the
        next, 'b' and every leaf layer after it are perturbed, and likewise
        for 'c'. Each setting starts from the original weights and yields one
        SSIM value. The original weights are restored before returning.

        Args:
            attr_map (np.ndarray): attr_map generated by the unperturbed model.
                It has `dtype` of `np.uint8` and shape of (h, w).
            input_tensor (torch.Tensor): input image with shape (3, h, w).
            target (int): class label of the image.
            attribution_method_cfg (dict): attribution configurations.
            perturb_layers (list): layers denoting the perturbed range of the
                model.
            verbose (bool, optional): if True, log the messages during
                perturbation and re-attribution process. The messages contain
                e.g. which layers are perturbed, etc.
            save_dir (str, optional): directory for saving the results.
                Only used when `save_heatmaps` is True; created if missing.
            save_heatmaps (bool, optional): if True, save the heatmaps
                produced by the perturbed models along with
                the original attr_map.

        Returns:
            dict: key 'ssim_all' maps to the list of SSIM values, one per
            entry of `perturb_layers`, in the same order.

        Raises:
            ValueError: if `save_heatmaps` is True but `save_dir` is None.
        """
        if save_heatmaps:
            if save_dir is None:
                raise ValueError("if save_heatmaps, save_dir must not be None")
            os.makedirs(save_dir, exist_ok=True)

        # Leaf layers ordered from the model output back towards the input.
        layers_from_output = self.model_layers[::-1]
        ssim_all = []
        perturbed = False
        try:
            for layer in perturb_layers:
                # Every setting starts from the original, unperturbed weights.
                self.reload()
                perturbed = True
                if verbose:
                    self.logger.info(f'Perturb {layer} and subsequent layers')
                p_layers = self._layers_to_perturb(layer, layers_from_output)
                if verbose:
                    self.logger.info(f"Following layers will be perturbed: "
                                     f"[{', '.join(p_layers)}]")

                # start sanity check using perturbed layers
                ssim_val = self.sanity_check_single(
                    input_tensor=input_tensor,
                    target=target,
                    attr_method_cfg=attribution_method_cfg,
                    perturb_layers=p_layers,
                    ori_attribution_map=attr_map,
                    save_dir=save_dir,
                    save_heatmaps=save_heatmaps)
                if verbose:
                    self.logger.info(f'ssim_val: {ssim_val}')
                ssim_all.append(ssim_val)
        finally:
            if perturbed:
                # Undo the last perturbation so the caller's model is intact.
                self.reload()
        if save_heatmaps:
            self.show_mask(attr_map, out_file=osp.join(save_dir, 'ori_attribution_map'))
        return dict(ssim_all=ssim_all)

    def _layers_to_perturb(self, layer, layers_from_output):
        """Return the leaf layers from the output side up to and including `layer`.

        If `layer` is not a leaf layer name, every leaf layer is returned,
        followed by `layer` itself, and a warning is logged.
        """
        if layer in layers_from_output:
            return layers_from_output[:layers_from_output.index(layer) + 1]
        self.logger.warning(
            f"{layer} is not a leaf layer of the model; all leaf layers "
            f"will be perturbed in addition to it")
        return layers_from_output + [layer]

    def sanity_check_single(
            self,
            input_tensor,
            target,
            attr_method_cfg,
            perturb_layers,
            ori_attribution_map,
            save_dir=None,
            save_heatmaps=False):
        """Perturb `perturb_layers` in place and compare attributions by SSIM.

        The model is NOT restored afterwards; call `reload()` for that
        (as `evaluate` does).

        Args:
            input_tensor (torch.Tensor): single input; a batch dimension is
                added before attribution.
            target (int): class label passed to the attribution generator.
            attr_method_cfg (dict): attribution configurations.
            perturb_layers (list): names of the layers to re-initialise.
            ori_attribution_map: attribution map of the unperturbed model.
            save_dir (str, optional): output directory, required when
                `save_heatmaps` is True.
            save_heatmaps (bool, optional): if True, save the perturbed
                heatmap as `<save_dir>/<last perturbed layer>.JPEG`.

        Returns:
            The SSIM value computed by `ssim`.
        """
        perturb_model(self.model, perturb_layers, mode=self.mode)
        attr_generator = AttributionGenerator(attr_method_cfg)
        attribution_map = attr_generator.generate_attribution(
            self.model, input_tensor.unsqueeze(0), target)
        ssim_val = self.ssim(ori_attribution_map, attribution_map)
        if save_heatmaps:
            out_file = osp.join(save_dir, f"{perturb_layers[-1]}")
            # show_mask converts floats to uint8 via convert_mask, which
            # normalises out-of-range maps instead of letting uint8 overflow.
            self.show_mask(attribution_map, out_file=out_file)
        return ssim_val

    @staticmethod
    def _prepare_mask(mask):
        """Return `mask` as a numpy array in (h, w) or (h, w, c) layout.

        Tensors are detached and moved to CPU. 3-D masks with 3 leading
        channels are transposed to channels-last. Float masks are converted
        to uint8 by `convert_mask`.
        """
        if not isinstance(mask, np.ndarray):
            mask = mask.detach().cpu().numpy()
        if mask.ndim == 3 and mask.shape[0] == 3:
            mask = mask.transpose([1, 2, 0])
        return SanityCheck.convert_mask(mask)

    @staticmethod
    def ssim(mask_1, mask_2):
        """Compute SSIM between two attribution maps.

        Both masks go through `_prepare_mask`. The comparison uses
        `win_size=5` and `data_range=255`. The channel axis is only passed
        for 3-D (h, w, c) masks, so 2-D (h, w) maps are supported.
        """
        mask_1 = SanityCheck._prepare_mask(mask_1)
        mask_2 = SanityCheck._prepare_mask(mask_2)
        channel_axis = 2 if mask_1.ndim == 3 else None
        return structural_similarity(
            mask_1, mask_2, win_size=5, data_range=255, channel_axis=channel_axis)

    @staticmethod
    def convert_mask(m):
        """Convert a float mask to uint8 in [0, 255].

        Floats already inside [0, 1] are scaled by 255. Other float masks are
        min-max normalised first. Non-float masks are returned unchanged.
        """
        if np.issubdtype(m.dtype, np.floating):
            if m.max() > 1.0 or m.min() < 0:
                m = SanityCheck.normalize_mask(m)
            m = (m * 255).astype(np.uint8)
        return m

    @staticmethod
    def normalize_mask(mask):
        """Return a min-max normalised copy of `mask` with values in [0, 1].

        The input array is not modified. A constant mask yields all zeros
        instead of dividing by zero.
        """
        shifted = mask - mask.min()
        value_range = shifted.max()
        if value_range > 0:
            shifted = shifted / value_range
        return shifted

    @staticmethod
    def filter_names(names):
        """Keep only leaf module names from a `named_modules()` name list.

        A name is dropped when the name right after it is one of its
        children (e.g. 'features' followed by 'features.0'). This relies on
        `named_modules()` listing a parent directly before its first child.
        The root name '' is dropped whenever other names follow it. Sibling
        names that merely share a prefix (e.g. 'fc' and 'fc2') are both kept.
        """
        if not names:
            return []
        res = []
        for cur, nxt in zip(names, names[1:]):
            is_parent = cur == '' or nxt.startswith(cur + '.')
            if not is_parent:
                res.append(cur)
        res.append(names[-1])
        return res

    @staticmethod
    def show_mask(mask, show=False, out_file=None):
        """Save and/or display an attribution map.

        Args:
            mask (np.ndarray or torch.Tensor): (h, w) map, or a 3-channel map
                in (3, h, w) or (h, w, 3) layout. Float maps are converted to
                uint8 by `convert_mask`.
            show (bool): if True, display the map with `plt.show()`.
            out_file (str, optional): output path without extension; '.JPEG'
                is appended and the parent directory is created if missing.
        """
        mask_to_show = SanityCheck._prepare_mask(mask)

        if out_file is not None:
            os.makedirs(osp.abspath(osp.dirname(out_file)), exist_ok=True)
            save_kwargs = {}
            if mask_to_show.ndim == 2:
                # plt.imsave does not use imshow's colormap. Colour
                # single-channel maps with 'bwr' centred on 0, matching
                # the on-screen rendering below.
                halfrange = float(np.abs(mask_to_show.astype(np.float64)).max(initial=0.0))
                if halfrange == 0:
                    halfrange = 1.0
                save_kwargs = dict(cmap='bwr', vmin=-halfrange, vmax=halfrange)
            plt.imsave(out_file + '.JPEG', mask_to_show, **save_kwargs)

        if show:
            plt.imshow(mask_to_show, cmap='bwr', norm=colors.CenteredNorm(0))
            plt.axis('off')
            plt.show()
