from typing import Optional, Tuple, Dict

import numpy as np
import torch
import torch.nn as nn
from captum.attr import Lime
from skimage.segmentation import quickshift, felzenszwalb

from ..utils import check_shape, norm_and_to_uint8


def compute_lime(
        model: nn.Module,
        preprocessed_image: torch.Tensor,
        label: torch.Tensor,
        segmentation: Optional[str] = 'quickshift',
        denorm_stats: Optional[Dict] = None,
        use_label: bool = True,
        baseline: Optional[Tuple[float, ...]] = None,
        need_check_shape: bool = False,
        n_samples: int = 200) -> Tuple[np.ndarray, np.ndarray]:
    """Compute a LIME saliency map for a single preprocessed image with captum.

    Each superpixel (or each pixel when ``segmentation`` is None) is treated as one
    interpretable feature. The resulting per-channel attributions are averaged over
    channels to give a 2-D map.

    Args:
        model: Model passed to ``captum.attr.Lime``.
        preprocessed_image: Model input with a leading batch dimension; only
            ``preprocessed_image[0]`` is used to build the segmentation. Assumed to be
            channel-first (1, C, H, W), given the CHW -> HWC permute below.
        label: Target passed to Lime when ``use_label`` is True.
        segmentation: ``'quickshift'`` (kernel_size=3), ``'felzenszwalb'`` (scale=5.)
            or None, in which case every pixel is its own feature.
        denorm_stats: Optional dict with ``'mean'`` and ``'std'`` entries. Before
            segmentation the image is mapped back with ``x * std + mean``,
            presumably undoing the preprocessing normalization so pixels land in
            [0, 1] (TODO confirm against the preprocessing code).
        use_label: If False, Lime is called with ``target=None``.
        baseline: Optional per-channel baseline values, broadcast to the full image
            size and passed to Lime as ``baselines``.
        need_check_shape: If True, validate the input shape against ``'1,c,n,n'``
            and the output shape against ``'n,n'`` via ``check_shape``.
        n_samples: Number of perturbed samples Lime evaluates (default 200).

    Returns:
        A tuple ``(saliency, normed_saliency)``. ``saliency`` is the channel-averaged
        attribution map as a float16 array of shape (H, W). ``normed_saliency`` is
        ``norm_and_to_uint8(saliency, norm='abs_zero_max')``.

    Raises:
        ValueError: If ``segmentation`` is not one of the supported values, or if a
            shape check fails while ``need_check_shape`` is True.
    """
    if need_check_shape:
        if not check_shape(preprocessed_image.shape, '1,c,n,n'):
            raise ValueError(f'image has invalid shape: {preprocessed_image.shape}')

    h, w = preprocessed_image.shape[-2:]
    device = preprocessed_image.device
    if segmentation is not None:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if segmentation not in ('quickshift', 'felzenszwalb'):
            raise ValueError(
                "segmentation must be 'quickshift', 'felzenszwalb' or None, "
                f"got {segmentation!r}")
        # CHW -> HWC for skimage. Detach so `.numpy()` works even if the input requires grad.
        img_for_making_mask = torch.permute(preprocessed_image[0].detach(), [1, 2, 0])
        if denorm_stats is not None:
            denorm_mean = torch.tensor(denorm_stats['mean']).to(preprocessed_image)
            denorm_std = torch.tensor(denorm_stats['std']).to(preprocessed_image)
            img_for_making_mask = img_for_making_mask * denorm_std + denorm_mean

        # Clamp before the uint8 cast: out-of-range floats (e.g. a still-normalized image
        # when denorm_stats is None) would otherwise wrap around or be undefined.
        img_for_making_mask = (255.0 * img_for_making_mask).clamp(0.0, 255.0)
        img_for_making_mask = img_for_making_mask.cpu().numpy().astype(np.uint8)
        if img_for_making_mask.shape[-1] == 1:
            # Single-channel input: replicate to 3 channels for the color segmenters.
            img_for_making_mask = np.tile(img_for_making_mask, [1, 1, 3])

        if segmentation == 'quickshift':
            mask = torch.tensor(quickshift(img_for_making_mask, kernel_size=3), device=device)
        else:
            mask = torch.tensor(felzenszwalb(img_for_making_mask, scale=5.), device=device)
    else:
        # No segmentation: every pixel gets its own feature id.
        mask = torch.arange(0, h * w, 1, device=device).reshape(h, w)
    target = label if use_label else None

    if baseline is not None:
        # Expand per-channel baseline values to a (1, C, H, W) image.
        baseline = preprocessed_image.new_tensor(baseline).reshape(1, -1, 1, 1).repeat(1, 1, h, w)

    saliency = Lime(model).attribute(
        preprocessed_image, target=target, baselines=baseline, feature_mask=mask,
        n_samples=n_samples)

    # Lime returns attributions shaped like the input: drop the batch axis -> (C, H, W).
    saliency = saliency.detach().cpu().numpy().squeeze(axis=0)
    # Average over channels -> (H, W).
    saliency = saliency.mean(0)
    normed_saliency = norm_and_to_uint8(saliency, norm='abs_zero_max')

    if need_check_shape:
        if not check_shape(saliency.shape, 'n,n'):
            raise ValueError(f'grad has invalid shape: {saliency.shape}')
    return saliency.astype(np.float16), normed_saliency
