from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
from pytorch_grad_cam import FullGrad, GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torch.nn import functional as F
from torchray.attribution.grad_cam import grad_cam

from ..utils import check_shape, norm_and_to_uint8


def compute_gradcam(
        model: nn.Module,
        preprocessed_image: torch.Tensor,
        label: torch.Tensor,
        saliency_layer: str,
        need_check_shape: bool = False) -> Tuple[np.ndarray, np.ndarray]:
    """Compute a Grad-CAM saliency map with TorchRay's ``grad_cam``.

    The raw map is bilinearly upsampled to the spatial size of the input and
    averaged over its channel axis to obtain a single 2-D heatmap.

    Args:
        model: Classifier to explain.
        preprocessed_image: Model input; expected to have shape ``(1, c, n, n)``
            (only enforced when ``need_check_shape`` is true).
        label: Target passed straight through to TorchRay's ``grad_cam``.
        saliency_layer: Name of the layer whose activations ``grad_cam`` uses.
        need_check_shape: Validate the input and output shapes.

    Returns:
        ``(grad, normed_grad)``: the 2-D heatmap cast to ``float16`` and its
        ``norm_and_to_uint8(..., norm='abs_zero_max')`` version.

    Raises:
        ValueError: If ``need_check_shape`` is true and a shape check fails.
    """
    input_shape = preprocessed_image.shape
    if need_check_shape and not check_shape(input_shape, '1,c,n,n'):
        raise ValueError(f'image has invalid shape: {input_shape}')

    raw_map = grad_cam(model, preprocessed_image, label, saliency_layer=saliency_layer)
    upsampled = F.interpolate(raw_map, input_shape[-2:], mode="bilinear", align_corners=False)

    # Drop the batch axis (squeeze raises if it is not of size 1), then
    # collapse the channel axis to get a (height, width) heatmap.
    per_channel = upsampled.detach().cpu().numpy().squeeze(0)
    heatmap = per_channel.mean(axis=0)
    heatmap_uint8 = norm_and_to_uint8(heatmap, norm='abs_zero_max')

    if need_check_shape and not check_shape(heatmap.shape, 'n,n'):
        raise ValueError(f'grad has invalid shape: {heatmap.shape}')
    return heatmap.astype(np.float16), heatmap_uint8


def compute_gradcam_pp(
        model: nn.Module,
        preprocessed_image: torch.Tensor,
        label: torch.Tensor,
        saliency_layer: str,
        multi_class: bool = True,
        need_check_shape: bool = False,
        use_cuda: bool = True) -> Tuple[np.ndarray, np.ndarray]:
    """Compute a Grad-CAM++ saliency map with ``pytorch_grad_cam``.

    Args:
        model: Classifier to explain.
        preprocessed_image: Model input; expected to have shape ``(1, c, n, n)``
            (only enforced when ``need_check_shape`` is true).
        label: Class index to explain. Must hold exactly one element (a 0-d or
            one-element tensor, or a plain int); it is converted with ``int()``.
            Ignored when ``multi_class`` is false.
        saliency_layer: Submodule name, resolved with ``model.get_submodule``,
            whose activations Grad-CAM++ uses.
        multi_class: If true, explain ``label``; otherwise ``targets=None`` is
            passed and ``pytorch_grad_cam`` chooses the target class itself.
        need_check_shape: Validate the input and output shapes.
        use_cuda: Forwarded to ``GradCAMPlusPlus``. Defaults to ``True`` (the
            previous hard-coded value); pass ``False`` on CPU-only machines.

    Returns:
        ``(grad, normed_grad)``: the ``(n, n)`` saliency map cast to ``float16``
        and its ``norm_and_to_uint8(..., norm='abs_zero_max')`` version.

    Raises:
        ValueError: If ``need_check_shape`` is true and a shape check fails.
    """
    if need_check_shape:
        if not check_shape(preprocessed_image.shape, '1,c,n,n'):
            raise ValueError(f'image has invalid shape: {preprocessed_image.shape}')
    target_layers = [model.get_submodule(saliency_layer)]

    # Use a plain int index: ClassifierOutputTarget indexes the model output
    # with it, and a one-element 1-d tensor would select a (1,)-shaped score
    # instead of a scalar, which breaks the scalar loss that gets
    # back-propagated.
    targets = [ClassifierOutputTarget(int(label))] if multi_class else None

    with GradCAMPlusPlus(model=model, target_layers=target_layers, use_cuda=use_cuda) as cam:
        # saliency has shape (1, height, width) and is already interpolated to input size
        saliency = cam(preprocessed_image, targets=targets)

    # Post-process after the context exits so the CAM hooks are released first.
    grad = saliency[0]
    if need_check_shape:
        if not check_shape(grad.shape, 'n,n'):
            raise ValueError(f'grad has invalid shape: {grad.shape}')
    normed_grad = norm_and_to_uint8(grad, norm='abs_zero_max')
    return grad.astype(np.float16), normed_grad


def compute_fullgrad(
        model: nn.Module,
        preprocessed_image: torch.Tensor,
        label: torch.Tensor,
        saliency_layer: str,
        multi_class: bool = True,
        need_check_shape: bool = False,
        use_cuda: bool = True) -> Tuple[np.ndarray, np.ndarray]:
    """Compute a FullGrad saliency map with ``pytorch_grad_cam``.

    Args:
        model: Classifier to explain.
        preprocessed_image: Model input; expected to have shape ``(1, c, n, n)``
            (only enforced when ``need_check_shape`` is true).
        label: Class index to explain. Must hold exactly one element (a 0-d or
            one-element tensor, or a plain int); it is converted with ``int()``.
            Ignored when ``multi_class`` is false.
        saliency_layer: Unused. FullGrad ignores target layers; the parameter is
            kept so the signature matches the other ``compute_*`` functions.
        multi_class: If true, explain ``label``; otherwise ``targets=None`` is
            passed and ``pytorch_grad_cam`` chooses the target class itself.
        need_check_shape: Validate the input and output shapes.
        use_cuda: Forwarded to ``FullGrad``. Defaults to ``True`` (the previous
            hard-coded value); pass ``False`` on CPU-only machines.

    Returns:
        ``(grad, normed_grad)``: the ``(n, n)`` saliency map cast to ``float16``
        and its ``norm_and_to_uint8(..., norm='abs_zero_max')`` version.

    Raises:
        ValueError: If ``need_check_shape`` is true and a shape check fails.
    """
    if need_check_shape:
        if not check_shape(preprocessed_image.shape, '1,c,n,n'):
            raise ValueError(f'image has invalid shape: {preprocessed_image.shape}')

    # Use a plain int index: ClassifierOutputTarget indexes the model output
    # with it, and a one-element 1-d tensor would select a (1,)-shaped score
    # instead of a scalar, which breaks the scalar loss that gets
    # back-propagated.
    targets = [ClassifierOutputTarget(int(label))] if multi_class else None

    # target layers will be ignored by FullGrad
    with FullGrad(model=model, target_layers=[], use_cuda=use_cuda) as cam:
        # saliency has shape (1, height, width) and is already interpolated to input size
        saliency = cam(preprocessed_image, targets=targets)

    # Post-process after the context exits so the CAM hooks are released first.
    grad = saliency[0]
    if need_check_shape:
        if not check_shape(grad.shape, 'n,n'):
            raise ValueError(f'grad has invalid shape: {grad.shape}')
    normed_grad = norm_and_to_uint8(grad, norm='abs_zero_max')
    return grad.astype(np.float16), normed_grad
