from typing import Optional
import torch

__all__ = [
    "attn_mask_reshape",
    "apply_attn_mask",
    "apply_weak_attention_suppression",
]

# Large but finite negative values written into masked-out attention logits.
# Finite values are used instead of -inf on purpose (see pytorch/pytorch#41508).
_FP16_MIN_VALUE = -1.0e4  # fits comfortably inside float16 (lowest float16 is -65504)
_FP32_MIN_VALUE = -1.0e6  # far from the float32 limit (~ -3.4e38), effectively -inf after softmax


@torch.no_grad()
def attn_mask_reshape(attn: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """
    Validate an attention mask and reshape it so it broadcasts against ``attn``.

    :param attn:    (batch_size, num_heads, query_length, key_length); only its shape is read.
    :param mask:    mask where True (non-zero) marks a valid position. Accepted layouts:
                    2D (query_length, key_length),
                    3D (batch_size, query_length, key_length),
                    4D (batch_size, 1 or num_heads, query_length, key_length).
                    Any dtype is accepted; it is converted with ``.bool()``.
    :return:        4D bool mask: (1, 1, q, k), (b, 1, q, k) or (b, num_heads, q, k).
    :raises ValueError: if the mask rank is not 2/3/4 or its shape does not match ``attn``.
    """
    b, n_head, q_len, k_len = attn.shape
    mask = mask.bool()

    if mask.ndim == 2:
        if mask.shape != (q_len, k_len):
            raise ValueError(f"Attention mask 2D shape invalid. "
                             f"Should be ({q_len}, {k_len}) but got {mask.shape}.")
        return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, q_len, k_len)

    if mask.ndim == 3:
        if mask.shape != (b, q_len, k_len):
            raise ValueError(f"Attention mask 3D shape invalid. "
                             f"Should be ({b}, {q_len}, {k_len}) but got {mask.shape}.")
        return mask.unsqueeze(1)  # (b, 1, q_len, k_len)

    if mask.ndim == 4:
        # The mask is invalid only if it matches NEITHER allowed layout.
        if (mask.shape != (b, 1, q_len, k_len)) and (mask.shape != (b, n_head, q_len, k_len)):
            raise ValueError(f"Attention mask 4D shape invalid. "
                             f"Should be ({b}, {n_head} (or 1), {q_len}, {k_len}) but got {mask.shape}.")
        return mask

    raise ValueError(f"Unsupported mask dimension, should be either 2/3/4D, but got {mask.ndim}.")


def apply_attn_mask(attn: torch.Tensor,
                    mask: Optional[torch.Tensor]) -> torch.Tensor:
    """
    Apply attention masking, use BEFORE softmax.

    Masked-out logits are overwritten with a large but finite negative value. ``-inf`` is
    intentionally avoided because of a numerical instability issue:
    https://github.com/pytorch/pytorch/issues/41508

    :param attn:    (batch_size, num_heads, query_length, key_length) attention logits.
    :param mask:    bool mask, T: valid, F: pad, or None to skip masking. 2D, 3D and 4D
                    masks are accepted; shapes are validated by ``attn_mask_reshape``.
    :return:        (batch_size, num_heads, query_length, key_length); a new tensor
                    (``masked_fill`` is out-of-place), or ``attn`` itself when mask is None.
    """
    if mask is None:
        return attn

    mask = attn_mask_reshape(attn, mask)

    # float32/float64 can hold -1e6 safely; narrower types (float16, bfloat16, ...)
    # get -1e4, which is inside the float16 range.
    if attn.dtype in (torch.float32, torch.float64):
        fill_value = _FP32_MIN_VALUE
    else:
        fill_value = _FP16_MIN_VALUE
    return attn.masked_fill(torch.logical_not(mask), fill_value)


def apply_weak_attention_suppression(attn: torch.Tensor,
                                     mask: Optional[torch.Tensor],
                                     gamma: float = 0.5) -> torch.Tensor:
    """
    Apply weak attention suppression (WAS), use BEFORE softmax and AFTER masking.
    https://arxiv.org/pdf/2005.09137.pdf

    For each query row, keys whose softmax probability is below ``mean - gamma * std``
    are suppressed by overwriting their logits with a large negative value. Mean and std
    are computed only over keys whose probability exceeds 1e-6. Masked keys are zeroed
    first, so they never take part in the statistics.

    :param attn:    (batch_size, num_heads, query_length, key_length) attention logits.
    :param mask:    optional bool mask, T: valid, F: pad. 2D, 3D and 4D masks are accepted;
                    shapes are validated by ``attn_mask_reshape``. Positions masked here are
                    always masked in the output.
    :param gamma:   how many standard deviations below the mean the cut-off lies;
                    larger values suppress fewer keys.
    :return:        (batch_size, num_heads, query_length, key_length), a new tensor.
    """
    if mask is not None:
        mask = attn_mask_reshape(attn, mask)

    # The suppression decision is computed without autograd tracking;
    # gradients flow only through the final masked_fill.
    with torch.no_grad():
        score = torch.softmax(attn, dim=-1, dtype=torch.float32)
        if mask is not None:
            score = score.masked_fill(torch.logical_not(mask), 0.0)

        # Keys that actually carry probability mass; statistics use these only.
        valid = torch.greater(score, 1e-6)
        valid_count = valid.sum(dim=-1, keepdim=True).clamp_min_(1).to(score.dtype)  # at least 1

        # Probabilities sum to ~1, so their mean over the valid keys is ~1 / count.
        score_mean = torch.reciprocal(valid_count)
        # Only valid keys contribute to the variance. Zeroed or negligible keys would each
        # add mean**2, inflating std and lowering the threshold.
        sq_dev = torch.square(score - score_mean).masked_fill(torch.logical_not(valid), 0.0)
        score_std = torch.sqrt(sq_dev.sum(dim=-1, keepdim=True).div_(valid_count))

        threshold = score_mean - gamma * score_std
        keep = torch.greater_equal(score, threshold)
        if mask is not None:
            keep = torch.logical_and(mask, keep)

    # Fill (FairSeq style): float32/float64 can hold -1e6; narrower types use -1e4.
    if attn.dtype in (torch.float32, torch.float64):
        fill_value = _FP32_MIN_VALUE
    else:  # FP16, BFloat16, ...
        fill_value = _FP16_MIN_VALUE
    return attn.masked_fill(torch.logical_not(keep), fill_value)
