from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

from ..utils import check_shape, norm_and_to_uint8


def compute_constant_mask(
        model: nn.Module,
        preprocessed_image: torch.Tensor,
        label: torch.Tensor,
        baseline: Optional[str] = None,
        need_check_shape: bool = False) -> Tuple[np.ndarray, np.ndarray]:
    """Return a uniform (all-ones) mask sized to the image's last two dims.

    The mask does not depend on ``model``, ``label`` or ``baseline``. They
    are accepted but never read here. Presumably this keeps the signature
    aligned with other mask functions (NOTE(review): confirm with callers).

    Args:
        model: Unused.
        preprocessed_image: Tensor whose trailing two dimensions set the
            mask shape. When ``need_check_shape`` is true it must match the
            ``'1,c,n,n'`` pattern accepted by ``check_shape``.
        label: Unused.
        baseline: Unused.
        need_check_shape: If true, validate the input image shape and the
            resulting mask shape via ``check_shape``.

    Returns:
        A pair ``(mask, normed_mask)``:

        * ``mask`` is the ones array cast to ``np.float16``.
        * ``normed_mask`` is the result of
          ``norm_and_to_uint8(mask, norm='abs_zero_max')``, computed on the
          un-cast mask.

    Raises:
        ValueError: If ``need_check_shape`` is true and either the image or
            the mask fails its shape check.
    """
    image_shape = preprocessed_image.shape
    if need_check_shape and not check_shape(image_shape, '1,c,n,n'):
        raise ValueError(f'image has invalid shape: {image_shape}')

    # Every spatial position gets the same weight of 1.
    constant = np.ones(image_shape[-2:])
    uint8_view = norm_and_to_uint8(constant, norm='abs_zero_max')

    if need_check_shape and not check_shape(constant.shape, 'n,n'):
        raise ValueError(f'mask has invalid shape: {constant.shape}')

    return constant.astype(np.float16), uint8_view
