import torch 
import numpy as np
from torch import nn


def scaled_dot_product(q, k, v, mask=None, temperature=1, dropout=0.0, training=True, attention="softmax"):
    """Compute scaled dot-product attention with a selectable normalisation.

    Args:
        q, k, v: query/key/value tensors. Logits are ``q @ k^T`` scaled by
            ``1/sqrt(q.size(-1))``; leading dimensions broadcast through ``@``.
        mask: optional additive mask, broadcastable against the logits
            ``[..., query_len, key_len]`` (e.g. ``-inf`` at blocked positions).
        temperature: divisor applied to the logits; only the ``"softmax"``
            variant uses it.
        dropout: dropout probability applied to the attention weights.
        training: whether dropout is active.
        attention: ``"softmax"``, ``"sigmoid"`` or ``"relu"``.

    Returns:
        ``(values, weights)`` where ``values = weights @ v`` and ``weights`` are
        the (post-dropout) attention weights.

    Raises:
        ValueError: if ``attention`` is not a supported variant.
    """
    factor = 1 / np.sqrt(q.size(-1))
    attn_logits = q @ k.transpose(-2, -1) * factor

    if mask is not None:
        # Out-of-place so the mask may broadcast the logits to a larger shape.
        attn_logits = attn_logits + mask

    num_keys = attn_logits.shape[-1]
    if attention == "softmax":
        weights = nn.functional.softmax(attn_logits / temperature, dim=-1)
    elif attention == "sigmoid":  # based on https://arxiv.org/pdf/2409.04431
        # Bias b = -log(n). n is a Python int, so np.log is used (torch.log needs a Tensor).
        weights = torch.sigmoid(attn_logits - float(np.log(num_keys)))
    elif attention == "relu":  # based on https://arxiv.org/pdf/2309.08586
        weights = nn.functional.relu(attn_logits) / num_keys
    else:
        raise ValueError(
            f"Unknown attention variant {attention!r}; expected 'softmax', 'sigmoid' or 'relu'."
        )

    weights = torch.dropout(weights, dropout, train=training)
    values = weights @ v

    return values, weights

class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention built on ``scaled_dot_product``.

    One linear layer projects the input to per-head Q, K and V. The per-head
    outputs are concatenated and passed through an output projection.

    Args:
        input_dim: feature size of the input ``x``.
        embed_dim: total attention width; must be divisible by ``num_heads``.
            Also the output feature size.
        num_heads: number of attention heads.
        temperature: default temperature forwarded to ``scaled_dot_product``
            (only the softmax variant uses it).
        dropout: dropout probability on the attention weights (active only in
            training mode).
        activation_attention: normalisation forwarded to ``scaled_dot_product``
            (``"softmax"``, ``"sigmoid"`` or ``"relu"``).
    """

    def __init__(self, input_dim, embed_dim, num_heads, temperature=1, dropout=0.0, activation_attention="softmax"):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by the number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.temperature = temperature
        self.dropout = dropout
        self.activation_attention = activation_attention
        self.qkv_proj = nn.Linear(input_dim, 3 * embed_dim, bias=True)
        self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self._init_params()

    def _init_params(self):
        """Xavier-uniform weights and zero biases (original Transformer init)."""
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        nn.init.zeros_(self.qkv_proj.bias)
        nn.init.xavier_uniform_(self.o_proj.weight)
        nn.init.zeros_(self.o_proj.bias)

    def forward(self, x, src_key_padding_mask=None, matrix_mask=None, temperature=None):
        """Apply self-attention to ``x``.

        Args:
            x: input of shape ``[Batch, SeqLen, input_dim]``.
            src_key_padding_mask: accepted for interface compatibility but not
                used. NOTE(review): no padding mask is applied here; confirm
                that callers fold padding into ``matrix_mask``.
            matrix_mask: optional additive attention mask. A 3-D mask
                (presumably ``[Batch, SeqLen, SeqLen]``) gets a head axis so it
                is shared by all heads. Other shapes are passed through and
                must broadcast against ``[Batch, Head, SeqLen, SeqLen]``.
            temperature: overrides the module's temperature for this call.

        Returns:
            ``(o, attention)``. ``o`` has shape ``[Batch, SeqLen, embed_dim]``.
            ``attention`` holds the attention weights averaged over heads.
        """
        batch_size, seq_length, _ = x.size()
        if matrix_mask is not None and matrix_mask.dim() == 3:
            # Insert the head axis so one mask applies to every head.
            matrix_mask = matrix_mask[:, None, :, :]
        if temperature is None:
            temperature = self.temperature

        qkv = self.qkv_proj(x)
        # Each head owns a contiguous [q | k | v] slice of width 3*head_dim.
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)  # [Batch, Head, SeqLen, 3*HeadDim]
        q, k, v = qkv.chunk(3, dim=-1)

        values, attention = scaled_dot_product(
            q,
            k,
            v,
            mask=matrix_mask,
            temperature=temperature,
            dropout=self.dropout,
            training=self.training,
            attention=self.activation_attention,
        )

        attention = attention.mean(dim=1)  # average over heads
        values = values.permute(0, 2, 1, 3)  # [Batch, SeqLen, Head, HeadDim]
        values = values.reshape(batch_size, seq_length, self.embed_dim)
        o = self.o_proj(values)

        return o, attention
    