from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from captum.attr import IntegratedGradients

from ..utils import check_shape, norm_and_to_uint8


def compute_integrated_gradients(
        model: nn.Module,
        preprocessed_image: torch.Tensor,
        label: torch.Tensor,
        remove_negative: bool = False,
        baseline: Optional[Tuple[float, ...]] = None,
        need_check_shape: bool = False,
        n_steps: int = 35) -> Tuple[np.ndarray, np.ndarray]:
    """Compute a channel-averaged Integrated Gradients saliency map for one image.

    Args:
        model: Model whose output is attributed by Captum.
        preprocessed_image: Input tensor; the code assumes shape ``(1, C, H, W)``
            (enforced as ``'1,c,n,n'`` only when ``need_check_shape`` is True).
        label: Target index passed to Captum's ``attribute`` as ``target``.
        remove_negative: If True, negative attributions are clamped to zero
            before averaging over channels.
        baseline: Optional baseline values, either one per channel or a single
            value shared by all channels. They are broadcast to
            ``(1, C, H, W)`` with the image's dtype and device. ``None`` lets
            Captum use its default (zero) baseline.
        need_check_shape: Validate the input and output shapes when True.
        n_steps: Number of interpolation steps used by Integrated Gradients.

    Returns:
        ``(grad, normed_grad)``:
        - ``grad``: the attribution map averaged over channels, shape
          ``(H, W)``, cast to ``float16``. The cast reduces precision.
        - ``normed_grad``: ``norm_and_to_uint8(grad, norm='abs_zero_max')``.

    Raises:
        ValueError: If a shape check fails, or if ``baseline`` has neither 1
            nor C values.
    """
    if need_check_shape and not check_shape(preprocessed_image.shape, '1,c,n,n'):
        raise ValueError(f'image has invalid shape: {preprocessed_image.shape}')

    baseline_tensor: Optional[torch.Tensor] = None
    if baseline is not None:
        c, h, w = preprocessed_image.shape[-3:]
        # new_tensor takes dtype and device from the image.
        values = preprocessed_image.new_tensor(baseline).reshape(1, -1, 1, 1)
        if values.shape[1] not in (1, c):
            raise ValueError(
                f'baseline must have 1 or {c} values, got {values.shape[1]}')
        # Captum requires the baseline's non-batch dims to match the input.
        # expand() broadcasts a single value across channels; clone()
        # materializes the broadcast view into real memory.
        baseline_tensor = values.expand(1, c, h, w).clone()

    saliency = IntegratedGradients(model).attribute(
        preprocessed_image, target=label, n_steps=n_steps, baselines=baseline_tensor)
    if remove_negative:
        saliency = torch.clamp(saliency, min=0.0)

    # (1, C, H, W) -> (C, H, W)
    grad = saliency.detach().cpu().numpy().squeeze(axis=0)
    # Average over channels: (C, H, W) -> (H, W)
    grad = grad.mean(0)

    # Check the shape before normalizing so a bad shape fails with a clear message.
    if need_check_shape and not check_shape(grad.shape, 'n,n'):
        raise ValueError(f'grad has invalid shape: {grad.shape}')

    normed_grad = norm_and_to_uint8(grad, norm='abs_zero_max')
    return grad.astype(np.float16), normed_grad
