import math

import torch
from torch import nn as nn
from torch.nn import functional as F


class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention with a selectable score activation."""

    def forward(self, query, key, value, mask=None, return_attention_and_scores=False, att_sig="sigmoid"):
        """Apply attention weights derived from ``query`` and ``key`` to ``value``.

        The raw scores are ``query @ key^T`` (``key`` transposed over its last
        two dimensions) divided by ``sqrt(dk)``, where ``dk`` is the size of
        the last dimension of ``query``.

        Args:
            query: Query tensor; its last dimension defines ``dk``.
            key: Key tensor; must be matmul-compatible with ``query`` once its
                last two dimensions are transposed.
            value: Value tensor; must be matmul-compatible with the attention
                weights.
            mask: Optional tensor compatible with ``masked_fill`` on the
                scores. Positions where ``mask == 0`` have their score
                replaced by ``-1e9`` before the activation is applied.
            return_attention_and_scores: If true, also return the attention
                weights and the (masked) scores.
            att_sig: ``"sigmoid"`` for an element-wise sigmoid over the
                scores, or ``"softmax"`` for a softmax over the last
                dimension.

        Returns:
            ``attention @ value``, or the tuple
            ``(attention @ value, attention, scores)`` when
            ``return_attention_and_scores`` is true.

        Raises:
            ValueError: If ``att_sig`` is not ``"sigmoid"`` or ``"softmax"``.
        """
        dk = query.size(-1)
        scores = query.matmul(key.transpose(-2, -1)) / math.sqrt(dk)

        if mask is not None:
            # A large negative score drives the activation output towards
            # zero for the masked positions.
            scores = scores.masked_fill(mask == 0, -1e9)

        if att_sig == "sigmoid":
            attention = torch.sigmoid(scores)
        elif att_sig == "softmax":
            attention = F.softmax(scores, dim=-1)
        else:
            # Without this branch ``attention`` would be unbound and the
            # caller would see an opaque UnboundLocalError further down.
            raise ValueError(
                f"att_sig must be 'sigmoid' or 'softmax', got {att_sig!r}"
            )

        output = attention.matmul(value)
        if return_attention_and_scores:
            return output, attention, scores
        return output
