import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

# pylint: disable=no-member


class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a key/value cache.

    A single fused linear layer produces queries, keys and values, which are
    split into ``num_heads`` heads of ``head_size`` features each.  Keys and
    values from previous calls can be supplied through ``past`` and are
    concatenated in front of the freshly computed ones.

    Args:
        config: Object exposing ``hidden_size``, ``num_attention_heads``,
            ``attention_dropout_prob`` and ``hidden_dropout_prob``.
            ``hidden_size`` must be divisible by ``num_attention_heads``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.hidden_size % config.num_attention_heads == 0, (
            "hidden_size must be divisible by num_attention_heads"
        )

        self.num_heads = config.num_attention_heads
        self.head_size = config.hidden_size // config.num_attention_heads
        # Width of each of the three chunks (q, k, v) in the fused projection.
        self.split_size = config.hidden_size
        # Attention scores are divided by sqrt(head_size).
        self.scale = math.sqrt(self.head_size)

        # Fused projection: one matmul yields query, key and value together.
        self.qkv_proj = nn.Linear(config.hidden_size, config.hidden_size * 3)

        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)

        self.attn_dropout = nn.Dropout(config.attention_dropout_prob)
        self.hidden_dropout = nn.Dropout(config.hidden_dropout_prob)

        # Keep calling the historical (misspelled) name so that subclasses
        # overriding it continue to be honoured.
        self.reset_paramters()

    def _compute_attention(self, query, key, value, extended_attention_mask=None):
        """Apply scaled dot-product attention.

        Args:
            query: Tensor of shape (batch, num_heads, query_len, head_size).
            key: Tensor of shape (batch, num_heads, head_size, key_len),
                i.e. already transposed for the matmul.
            value: Tensor of shape (batch, num_heads, key_len, head_size).
            extended_attention_mask: Optional boolean tensor broadcastable to
                (batch, num_heads, query_len, key_len).  ``True`` marks
                positions that may be attended to; ``False`` positions have
                their score replaced by ``-1.0e4`` before the softmax.
                ``None`` disables masking.

        Returns:
            Tuple ``(attn_output, attn_weights)`` where ``attn_output`` has
            shape (batch, num_heads, query_len, head_size) and
            ``attn_weights`` holds the post-dropout attention probabilities
            of shape (batch, num_heads, query_len, key_len).
        """
        attn_weights = torch.matmul(query, key) / self.scale
        if extended_attention_mask is not None:
            # In-place fill on the freshly created score tensor.
            attn_weights = attn_weights.masked_fill_(~extended_attention_mask, -1.0e4)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights

    def _compute_q_k_v(self, hidden_states: torch.Tensor):
        """Project ``hidden_states`` and split the result into per-head q, k, v.

        Args:
            hidden_states: Tensor of shape (batch, seq_len, hidden_size).

        Returns:
            Tuple ``(query, key, value)`` with shapes
            (batch, num_heads, seq_len, head_size),
            (batch, num_heads, head_size, seq_len) and
            (batch, num_heads, seq_len, head_size) respectively.  The key is
            returned pre-transposed so it can be right-multiplied directly.
        """
        bsz = hidden_states.shape[0]
        seq_len = hidden_states.shape[1]

        # shape: (batch, seq_len, hidden_size * 3)
        hidden_states = self.qkv_proj(hidden_states)
        # each chunk: (batch, seq_len, hidden_size)
        query, key, value = hidden_states.split(self.split_size, dim=2)

        head_shape = (bsz, seq_len, self.num_heads, self.head_size)
        # (batch, num_heads, seq_len, head_size)
        query = query.view(*head_shape).transpose(1, 2)
        # (batch, num_heads, head_size, seq_len)
        key = key.view(*head_shape).permute(0, 2, 3, 1)
        # (batch, num_heads, seq_len, head_size)
        value = value.view(*head_shape).transpose(1, 2)
        return query, key, value

    def forward(self, hidden_states, past=None, extended_attention_mask=None):
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Tensor of shape (batch, tgt_len, hidden_size).
            past: Optional cache from a previous call.  ``past[0]`` and
                ``past[1]`` are the cached keys and values, each of shape
                (batch, num_heads, past_len, head_size); the ``present``
                value returned by an earlier call has this layout.
            extended_attention_mask: Optional boolean mask broadcastable to
                (batch, num_heads, tgt_len, past_len + tgt_len); ``True``
                means "may attend".  ``None`` disables masking.

        Returns:
            Tuple ``(hidden_states, present, attention)``:
            ``hidden_states`` of shape (batch, tgt_len, hidden_size),
            ``present`` of shape
            (2, batch, num_heads, past_len + tgt_len, head_size) stacking the
            full keys and values, and ``attention`` holding the attention
            probabilities.
        """
        bsz = hidden_states.shape[0]
        tgt_len = hidden_states.shape[1]
        embed_dim = hidden_states.shape[2]

        # compute query, key, value
        query, key, value = self._compute_q_k_v(hidden_states)

        if past is not None:
            # Cached keys are stored as (batch, num_heads, past_len, head_size);
            # transpose to match the pre-transposed layout of ``key``.
            past_key, past_value = past[0].transpose(-1, -2), past[1]
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)

        # Compute Self Attention
        hidden_states, attention = self._compute_attention(
            query, key, value, extended_attention_mask
        )

        # merge heads: (batch, num_heads, tgt_len, head_size) -> (batch, tgt_len, hidden_size)
        hidden_states = hidden_states.transpose(1, 2).reshape(bsz, tgt_len, embed_dim)
        hidden_states = self.out_proj(hidden_states)
        hidden_states = self.hidden_dropout(hidden_states)

        # shape: (2, batch, num_heads, current_len, head_size)
        present = torch.stack((key.transpose(-1, -2), value))

        return hidden_states, present, attention

    def reset_paramters(self):
        """Re-initialise both projections.

        Weights use Xavier-uniform initialisation with gain ``1 / sqrt(2)``;
        biases are set to zero.  The misspelled name is kept for backward
        compatibility; ``reset_parameters`` is the correctly spelled alias.
        """
        gain = 1 / math.sqrt(2)
        nn.init.xavier_uniform_(self.qkv_proj.weight, gain=gain)
        nn.init.xavier_uniform_(self.out_proj.weight, gain=gain)
        nn.init.constant_(self.qkv_proj.bias, 0.)
        nn.init.constant_(self.out_proj.bias, 0.)

    def reset_parameters(self):
        """Correctly spelled alias that delegates to ``reset_paramters``."""
        self.reset_paramters()