import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

# pylint: disable=no-member


class SelfAttention(nn.Module):
    """Pre-norm multi-head self-attention.

    The input is layer-normalised, projected to query/key/value with one fused
    linear layer, attended per head, merged, and projected back to
    ``dim_model``.

    Args:
        num_heads: Number of attention heads.
        dim_model: Feature size of the input and of the output.
        dim_head: Feature size of each head. The inner attention width is
            ``num_heads * dim_head`` and does not have to equal ``dim_model``.
        dropout: Dropout probability applied to the projected output.
        dropattn: Dropout probability applied to the attention weights.
    """

    def __init__(
        self,
        num_heads: int,
        dim_model: int,
        dim_head: int,
        dropout: float = 0.0,
        dropattn: float = 0.0,
    ):
        super().__init__()

        self.num_heads = num_heads
        self.head_size = dim_head
        self.dim_inner = num_heads * dim_head
        self.dim_model = dim_model
        self.split_size = self.dim_inner
        # Scores are divided by sqrt(head_size) (scaled dot-product attention).
        self.scale = math.sqrt(self.head_size)

        # Fused projection: a single matmul yields query, key and value.
        self.qkv_proj = nn.Linear(self.dim_model, self.dim_inner * 3)
        # Project the merged heads (dim_inner wide) back to the model width so
        # the output matches the input feature size even when
        # num_heads * dim_head != dim_model.
        self.out_proj = nn.Linear(self.dim_inner, self.dim_model)
        self.layer_norm = nn.LayerNorm(dim_model, eps=1e-5)

        self.dropout = dropout
        self.dropattn = dropattn

        # Dispatch through the legacy name so an override of either
        # reset_paramters or reset_parameters in a subclass is honoured.
        self.reset_paramters()

    def _compute_attention(self, query, key, value, extended_attention_mask):
        """Scaled dot-product attention over already-split heads.

        Args:
            query: ``(batch, num_heads, query_len, head_size)``.
            key: ``(batch, num_heads, head_size, key_len)``. It is already
                transposed, so ``query @ key`` gives the score matrix directly.
            value: ``(batch, num_heads, key_len, head_size)``.
            extended_attention_mask: Boolean tensor broadcastable to
                ``(batch, num_heads, query_len, key_len)``. ``True`` marks
                positions that may be attended to; ``None`` disables masking.

        Returns:
            ``(attn_output, attn_weights)`` of shapes
            ``(batch, num_heads, query_len, head_size)`` and
            ``(batch, num_heads, query_len, key_len)``.
        """
        attn_weights = torch.matmul(query, key) / self.scale
        if extended_attention_mask is not None:
            # Blocked positions get a large negative score so softmax drives
            # them to ~0; -1e4 is still representable in float16.
            attn_weights = attn_weights.masked_fill(~extended_attention_mask, -1.0e4)
        attn_weights = F.softmax(attn_weights, dim=-1)
        if self.dropattn > 0.0:
            attn_weights = F.dropout(attn_weights, p=self.dropattn, training=self.training)
        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights

    def _compute_q_k_v(self, hidden_states: torch.Tensor):
        """Project ``hidden_states`` and split the result into per-head tensors.

        Returns:
            ``(query, key, value)`` with shapes
            ``(batch, num_heads, seq_len, head_size)``,
            ``(batch, num_heads, head_size, seq_len)`` (pre-transposed) and
            ``(batch, num_heads, seq_len, head_size)``.
        """
        bsz = hidden_states.shape[0]
        seq_len = hidden_states.shape[1]

        # shape: (batch_size, seq_len, dim_inner * 3)
        hidden_states = self.qkv_proj(hidden_states)
        # each of q/k/v: (batch_size, seq_len, dim_inner)
        query, key, value = hidden_states.split(self.split_size, dim=2)
        # query: (batch_size, num_heads, seq_len, head_size)
        query = query.view(bsz, seq_len, self.num_heads, self.head_size).transpose(1, 2)
        # key: (batch_size, num_heads, head_size, seq_len), ready for query @ key
        key = key.view(bsz, seq_len, self.num_heads, self.head_size).permute(0, 2, 3, 1)
        # value: (batch_size, num_heads, seq_len, head_size)
        value = value.view(bsz, seq_len, self.num_heads, self.head_size).transpose(1, 2)
        return query, key, value

    def forward(self, hidden_states, rel, past, extended_attention_mask):
        """Apply self-attention to ``hidden_states``.

        Args:
            hidden_states: ``(batch, seq_len, dim_model)`` input features.
            rel: Not used by this layer.
            past: Currently ignored; see the note in the body.
            extended_attention_mask: Boolean mask broadcastable to
                ``(batch, num_heads, seq_len, seq_len)``, with ``True`` for
                allowed positions, or ``None`` for no masking.

        Returns:
            ``(hidden_states, present)``. ``hidden_states`` has shape
            ``(batch, seq_len, dim_model)``. ``present`` stacks the keys and
            values as ``(2, batch, num_heads, seq_len, head_size)``.
        """
        bsz = hidden_states.shape[0]
        tgt_len = hidden_states.shape[1]

        # NOTE(review): the incoming cache is deliberately discarded, so the
        # concatenation below never runs. Re-enabling it would also require a
        # format change: `present` holds stacked key/value heads, not the
        # (batch, len, dim_model) hidden states this branch concatenates.
        past = None
        if past is not None:
            hidden_states = torch.cat([past, hidden_states], dim=1)

        # Pre-norm: normalise before projecting to query/key/value.
        query, key, value = self._compute_q_k_v(self.layer_norm(hidden_states))
        # Only the newest tgt_len positions issue queries.
        query = query[:, :, -tgt_len:]

        hidden_states, _ = self._compute_attention(query, key, value, extended_attention_mask)

        # Merge heads: (batch, num_heads, tgt_len, head_size) -> (batch, tgt_len, dim_inner).
        # The merged width is dim_inner, which may differ from the input width.
        hidden_states = hidden_states.transpose(1, 2).reshape(bsz, tgt_len, self.dim_inner)
        hidden_states = self.out_proj(hidden_states)
        if self.dropout > 0.0:
            hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)

        # Undo the key pre-transpose so key and value share a layout:
        # (2, batch_size, num_heads, current_len, head_size)
        present = torch.stack((key.transpose(-1, -2), value))

        return hidden_states, present

    def reset_parameters(self):
        """Re-initialise the projections: scaled Xavier-uniform weights, zero biases."""
        nn.init.xavier_uniform_(self.qkv_proj.weight, 1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.out_proj.weight, 1 / math.sqrt(2))
        nn.init.constant_(self.qkv_proj.bias, 0.)
        nn.init.constant_(self.out_proj.bias, 0.)

    def reset_paramters(self):
        """Misspelled legacy alias of :meth:`reset_parameters`, kept for existing callers."""
        self.reset_parameters()