from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from captum.attr import Occlusion

from ..utils import check_shape, norm_and_to_uint8


def compute_occlusion(
        model: nn.Module,
        preprocessed_image: torch.Tensor,
        label: torch.Tensor,
        sliding_window_shapes: Tuple[int, ...],
        strides: Optional[Union[Tuple[int, ...], int]] = None,
        baselines: Optional[float] = None,
        need_check_shape: bool = False) -> Tuple[np.ndarray, np.ndarray]:
    """Compute an occlusion saliency map for ``preprocessed_image`` using Captum.

    Args:
        model: Model whose output is attributed.
        preprocessed_image: Input batch passed to ``Occlusion.attribute``. When
            ``need_check_shape`` is true it must match the ``'1,c,n,n'`` pattern.
        label: Forwarded to Captum as ``target``.
        sliding_window_shapes: Shape of the occluding patch. May arrive as a
            list (e.g. when parsed from YAML) and is converted to a tuple.
            Captum expects one entry per non-batch dimension of the input.
        strides: Step of the sliding window, either an int or a sequence
            (converted to a tuple). ``None`` is forwarded unchanged so Captum
            applies its default stride.
        baselines: Value used to fill occluded patches; forwarded to Captum.
        need_check_shape: If true, validate the input and output shapes.

    Returns:
        ``(saliency, normed_saliency)``: ``saliency`` is the channel-averaged
        attribution map cast to ``float16``; ``normed_saliency`` is
        ``norm_and_to_uint8(saliency, norm='abs_zero_max')``.

    Raises:
        ValueError: If ``need_check_shape`` is true and the input image or the
            resulting saliency map has an unexpected shape.
    """
    if need_check_shape and not check_shape(preprocessed_image.shape, '1,c,n,n'):
        raise ValueError(f'image has invalid shape: {preprocessed_image.shape}')

    # the parameter parsed from yaml file is a list rather than a tuple
    sliding_window_shapes = tuple(sliding_window_shapes)
    # Forward None untouched: ``tuple(None)`` raises TypeError, which previously
    # made the default value of ``strides`` unusable.
    if strides is not None and not isinstance(strides, int):
        strides = tuple(strides)

    # Captum returns attributions with the same shape as the input,
    # e.g. (1, c, height, width) for a '1,c,n,n' image.
    attributions = Occlusion(model).attribute(
        preprocessed_image,
        target=label,
        sliding_window_shapes=sliding_window_shapes,
        strides=strides,
        baselines=baselines)
    # Drop the batch axis, then average over the channel axis -> (height, width).
    saliency = attributions.detach().cpu().numpy().squeeze(0).mean(0)

    # Validate before normalizing so a malformed map fails fast.
    if need_check_shape and not check_shape(saliency.shape, 'n,n'):
        raise ValueError(f'saliency has invalid shape: {saliency.shape}')

    normed_saliency = norm_and_to_uint8(saliency, norm='abs_zero_max')
    return saliency.astype(np.float16), normed_saliency