import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttentionLayer(nn.Module):
    """Scaled dot-product attention followed by an output projection and ReLU.

    NOTE(review): despite the name, ``forward`` does not project the inputs
    or split them into ``h`` heads. It adds one singleton head axis and
    attends over the raw ``query``/``key``/``value`` tensors. ``h`` is stored
    but not used by the computation in this class.

    Args:
        d_model: Size of the last dimension that the attention output is
            reshaped to before ``out_fc`` is applied.
        h: Number of heads (stored on the instance only).
        out_fc: Module applied to the ``(n_batch, seq_len, d_model)``
            attention output.
        dr_rate: Dropout probability applied to the attention probabilities.
    """

    def __init__(self, d_model, h, out_fc, dr_rate=0):
        super().__init__()
        self.d_model = d_model
        self.h = h
        self.out_fc = out_fc
        self.dropout = nn.Dropout(p=dr_rate)
        self.relu = nn.ReLU(inplace=True)

    def calculate_attention(self, query, key, value, mask):
        """Compute ``softmax(Q K^T / sqrt(d_k)) V`` with optional masking.

        Args:
            query, key, value: Tensors whose last two dimensions are
                ``(seq_len, d_k)``; leading dimensions must broadcast under
                ``torch.matmul``.
            mask: ``None`` or a tensor broadcastable to the attention score
                shape ``(..., seq_len_q, seq_len_k)``. Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor with the same leading dimensions as the scores and the
            last dimension of ``value``.
        """
        d_k = key.shape[-1]

        # Q x K^T, scaled by sqrt(d_k).
        attention_score = torch.matmul(query, key.transpose(-2, -1))
        attention_score = attention_score / math.sqrt(d_k)

        if mask is not None:
            # A large negative score makes the softmax weight effectively 0.
            attention_score = attention_score.masked_fill(mask == 0, -1e9)

        attention_prob = F.softmax(attention_score, dim=-1)
        attention_prob = self.dropout(attention_prob)

        return torch.matmul(attention_prob, value)

    def forward(self, query, key, value, mask=None):
        """Attend ``query`` over ``key``/``value`` and project the result.

        Args:
            query, key, value: ``(n_batch, seq_len, d_embed)`` tensors. The
                reshape below only keeps ``seq_len`` intact when ``d_embed``
                equals ``d_model`` -- presumably the intended usage; verify
                against callers.
            mask: Optional ``(n_batch, seq_len, seq_len)`` tensor; positions
                equal to 0 are masked out. ``None`` disables masking.

        Returns:
            ``relu(out_fc(attention))`` where ``attention`` has shape
            ``(n_batch, seq_len, d_model)``.
        """
        n_batch = query.size(0)

        # Add a singleton head axis: (n_batch, 1, seq_len, d_embed).
        query = query.unsqueeze(1)
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)
        if mask is not None:
            # (n_batch, 1, seq_len, seq_len) so it lines up with the scores.
            # The guard matters: mask defaults to None, which has no
            # ``unsqueeze``.
            mask = mask.unsqueeze(1)

        out = self.calculate_attention(query, key, value, mask)

        # Move the head axis next to the feature axis and merge them.
        out = out.transpose(1, 2)
        out = out.contiguous().view(n_batch, -1, self.d_model)

        out = self.out_fc(out)
        out = self.relu(out)

        return out
