from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
from skimage import filters
from skimage.draw import disk

from ..utils import check_shape, norm_and_to_uint8


def compute_constant_mask_circle(
    model: nn.Module,
    preprocessed_image: torch.Tensor,
    label: torch.Tensor,
    size: float = 0.7,
    need_check_shape: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Build an input-independent, centred, blurred circular attribution map.

    The map depends only on the spatial size of ``preprocessed_image``: a
    filled disk centred on the image is drawn with value 1 on a zero
    background, then blurred with ``skimage.filters.gaussian`` using its
    default parameters. ``model`` and ``label`` are accepted but not used.

    Args:
        model: Unused.
        preprocessed_image: Image tensor; only its last two dimensions
            (height, width) are read. When ``need_check_shape`` is True it
            must satisfy the ``'1,c,n,n'`` pattern of ``check_shape``.
        label: Unused.
        size: Disk radius as a fraction of half the image side length
            (radius = ``int(height * 0.5 * size)`` pixels). Values above 1
            give a disk that is clipped at the image border.
        need_check_shape: If True, validate the input shape and the shape of
            the produced map with ``check_shape``.

    Returns:
        A tuple ``(attr_map, normed_attr_map)``: the blurred mask cast to
        ``np.float16``, and the output of
        ``norm_and_to_uint8(attr_map, norm='abs_min_max')`` computed on the
        un-cast mask.

    Raises:
        ValueError: If the last two dimensions of ``preprocessed_image`` are
            not equal, or if ``need_check_shape`` is True and a shape check
            fails.
    """
    if need_check_shape:
        if not check_shape(preprocessed_image.shape, '1,c,n,n'):
            raise ValueError(f'image has invalid shape: {preprocessed_image.shape}')

    attr_map = np.zeros(preprocessed_image.shape[-2:])
    height, width = attr_map.shape
    # Explicit check rather than ``assert`` so it is not stripped under
    # ``python -O``; matches the ValueError style used elsewhere here.
    if height != width:
        raise ValueError(
            f'image must be square, got height={height}, width={width}'
        )
    # ``shape`` clips the disk to the array bounds. Without it, a radius that
    # exceeds the image (size > 1) yields out-of-range coordinates and the
    # assignment below fails.
    rr, cc = disk(
        (height // 2, width // 2),
        int(height * 0.5 * size),
        shape=attr_map.shape,
    )
    attr_map[rr, cc] = 1
    # Soften the hard disk edge (skimage's default Gaussian parameters).
    attr_map = filters.gaussian(attr_map)
    normed_attr_map = norm_and_to_uint8(attr_map, norm='abs_min_max')
    if need_check_shape:
        if not check_shape(attr_map.shape, 'n,n'):
            raise ValueError(f'attr_map has invalid shape: {attr_map.shape}')
    return attr_map.astype(np.float16), normed_attr_map
