from typing import Optional
import math
import torch
from pado.core import PadoModule

from pado.nn.transformer.utils import apply_attn_mask, apply_weak_attention_suppression

__all__ = ["ScaledDotProduct"]


class ScaledDotProduct(PadoModule):
    """Attention-weight module computing W = softmax(Q K^T / sqrt(scaling_factor)).

    Optionally applies weak-attention suppression (WAS,
    https://arxiv.org/abs/2005.09137) before the softmax. With
    ``normalize=False`` the softmax is skipped and the masked, scaled logits
    are returned instead.
    """

    def __init__(self,
                 scaling_factor: int, *,
                 normalize: bool = True,
                 was: bool = False,
                 was_gamma: float = 0.5) -> None:
        """
        :param scaling_factor:  positive value d; logits are multiplied by 1 / sqrt(d)
                                (the forward docstring uses the head dimension).
        :param normalize:       apply softmax over the key axis.
        :param was:             apply weak-attention suppression.
        :param was_gamma:       gamma forwarded to weak-attention suppression.
        :raises ValueError:     if scaling_factor is not a positive number.
        """
        super().__init__()
        self.scaling_factor = scaling_factor
        self.scaling = self._inverse_sqrt(scaling_factor)
        self.normalize = normalize  # apply softmax

        # apply weak-attention suppression
        # https://arxiv.org/abs/2005.09137
        self.was = was
        self.was_gamma = was_gamma

    @staticmethod
    def _inverse_sqrt(scaling_factor) -> float:
        """Return 1 / sqrt(scaling_factor), rejecting non-positive (and NaN) factors."""
        # `not x > 0` is also true for NaN, which would otherwise poison every logit.
        if not scaling_factor > 0:
            raise ValueError(f"scaling_factor must be positive, got {scaling_factor}.")
        return 1.0 / math.sqrt(scaling_factor)

    def reset_scaling(self, new_scaling_factor: int) -> None:
        """Replace the scaling factor and recompute the cached multiplier.

        The multiplier is computed before any attribute is written, so an
        invalid factor raises ValueError and leaves the module unchanged.
        """
        scaling = self._inverse_sqrt(new_scaling_factor)
        self.scaling_factor = new_scaling_factor
        self.scaling = scaling

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Scaled dot-product W = softmax(QK^T / sqrt(dim))
        :param query:   (batch_size, num_heads, query_length, head_dim)
        :param key:     (batch_size, num_heads, key_length, head_dim)
        :param mask:    (batch_size, query_length, key_length)  bool, T: valid, F: pad
        :return:        (batch_size, num_heads, query_length, key_length)
        :raises ValueError: if query/key have fewer than 2 dims, different ranks,
                            or different feature (last) dimensions.
        """
        # Unpacking the last two axes also raises ValueError for rank < 2 inputs.
        _, dq = query.shape[-2:]
        _, dk = key.shape[-2:]
        if (dq != dk) or (query.ndim != key.ndim):
            raise ValueError(f"Query {query.shape}, Key {key.shape} shape mismatch.")

        # Swap only the last two axes: this is K^T for any number of leading
        # batch-like axes (identical to transpose(2, 3) for the documented 4-D input).
        attn = torch.matmul(query, key.transpose(-2, -1))
        attn *= self.scaling  # in-place OK: `attn` is a fresh tensor produced by matmul
        attn = apply_attn_mask(attn, mask)
        if self.was:
            attn = apply_weak_attention_suppression(attn, mask, gamma=self.was_gamma)
        if self.normalize:
            attn = torch.softmax(attn, dim=-1)
            if mask is not None:
                # unsqueeze(1) broadcasts the (batch, q, k) mask over the head axis;
                # padded positions are forced to exactly zero weight after softmax.
                attn = attn.masked_fill(torch.logical_not(mask.unsqueeze(1)), 0.0)
        return attn

    def extra_repr(self) -> str:
        """Summarize non-default configuration for the module's repr."""
        s = f"scaling_factor={self.scaling_factor}"
        if not self.normalize:
            s += ", normalize=False"
        if self.was:
            s += f", was=True, was_gamma={self.was_gamma}"
        return s
