from typing import List, Optional, Tuple
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from .normalizations import NORM2FN

# pylint:disable=no-member


class RelativeSelfAttention(nn.Module):
    """Multi-head self-attention with relative-position terms.

    A relative position embedding is added both to the attention logits (query x
    position) and to the attended output (attention probs x position). The input is
    normalized internally with ``NORM2FN[layer_norm_type]`` before the Q/K/V
    projection, so callers should pass un-normalized hidden states. No residual
    connection is applied here.

    Args:
        num_heads: Number of attention heads.
        dim_model: Feature size of the input and output hidden states.
        dim_head: Feature size of each head. ``num_heads * dim_head`` (``dim_inner``)
            does not have to equal ``dim_model``; ``out_proj`` maps back to ``dim_model``.
        dropout: Dropout probability applied to the projected output.
        dropattn: Dropout probability applied to the attention probabilities.
        layer_norm_type: Key into ``NORM2FN`` selecting the input normalization.
    """

    def __init__(
        self,
        num_heads: int,
        dim_model: int,
        dim_head: int,
        dropout: float = 0.0,
        dropattn: float = 0.0,
        layer_norm_type: str = "layer_norm"
    ):
        super().__init__()

        self.num_heads = num_heads
        self.dim_model = dim_model
        self.dim_head = dim_head
        # Width of the concatenated heads; may differ from dim_model.
        self.dim_inner = num_heads * dim_head
        self.dropout = dropout
        self.dropattn = dropattn

        self.qkv_proj = nn.Linear(dim_model, 3 * self.dim_inner, bias=True)
        self.out_proj = nn.Linear(self.dim_inner, dim_model, bias=True)
        self.layer_norm = NORM2FN[layer_norm_type](dim_model)
        # keep the variance around 1 (two logit terms, content + position, are summed)
        self.scale = 0.5 / (dim_head**0.5)

        self.reset_parameters()

    def _compute_q_k_v(
        self, hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project ``hidden_states`` into per-head queries, keys and values.

        Args:
            hidden_states: (batch, seq_len, dim_model)

        Returns:
            query: (batch, num_heads, seq_len, dim_head)
            key:   (batch, num_heads, dim_head, seq_len). It is already transposed, so
                   ``query @ key`` gives the logits directly.
            value: (batch, num_heads, seq_len, dim_head)
        """
        bsz, seq_len = hidden_states.shape[:2]

        # (batch, seq_len, 3 * dim_inner)
        qkv = self.qkv_proj(hidden_states)
        # each: (batch, seq_len, dim_inner)
        query, key, value = qkv.split(self.dim_inner, dim=2)

        heads_shape = (bsz, seq_len, self.num_heads, self.dim_head)
        query = query.view(heads_shape).transpose(1, 2)
        key = key.view(heads_shape).permute(0, 2, 3, 1)
        value = value.view(heads_shape).transpose(1, 2)
        return query, key, value

    def forward(
        self,
        hidden_states: torch.Tensor,
        rel_pos_embedding: torch.Tensor,
        decoder_cache: Optional[torch.Tensor] = None,
        decoding_cache=None,
        extended_attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: (batch, query_len, dim_model). Do not normalize it in
                advance; the layer norm is applied here.
            rel_pos_embedding: (query_len, key_len, dim_head), where key_len is
                cache_len + query_len when ``decoder_cache`` is given, else query_len.
            decoder_cache: optional (batch, cache_len, dim_model) un-normalized hidden
                states. They are prepended to ``hidden_states`` to form keys and values.
                Only the last ``query_len`` positions produce queries.
            decoding_cache: currently unused; kept for interface compatibility.
            extended_attn_mask: optional bool tensor broadcastable to
                (batch, num_heads, query_len, key_len). Logits where it is False are
                filled with -6e4 before the softmax.

        Returns:
            out_hidden_states: (batch, query_len, dim_model)
            attn_probs: (batch, num_heads, query_len, key_len), after attention dropout
        """
        # TODO: Implement fast decoding (``decoding_cache`` is not used yet).
        batch_size, query_len = hidden_states.shape[:2]

        if decoder_cache is not None:
            head_hidden = torch.cat([decoder_cache, hidden_states], dim=1)
        else:
            head_hidden = hidden_states
        key_len = head_hidden.size(1)

        # query/value: (batch, num_heads, key_len, dim_head)
        # key: (batch, num_heads, dim_head, key_len)
        query, key, value = self._compute_q_k_v(self.layer_norm(head_hidden))

        if decoder_cache is not None:
            # Cached positions only act as keys/values.
            query = query[:, :, -query_len:]

        # Content attention: (batch, num_heads, query_len, key_len)
        content_attn = torch.matmul(query, key)

        # Position attention. Each query row has its own relative embedding slice,
        # so the matmul is batched over query positions.
        # (query_len, batch * num_heads, dim_head)
        query_pos = query.permute(2, 0, 1, 3).reshape(
            query_len, batch_size * self.num_heads, self.dim_head
        )
        # (query_len, batch * num_heads, key_len)
        pos_attn = torch.matmul(query_pos, rel_pos_embedding.transpose(1, 2))
        # -> (batch, num_heads, query_len, key_len)
        pos_attn = pos_attn.view(query_len, batch_size, self.num_heads, key_len).permute(1, 2, 0, 3)

        attn_logits = (content_attn + pos_attn) * self.scale

        if extended_attn_mask is not None:
            # -6e4 is presumably chosen to stay inside the fp16 range (max ~65504).
            attn_logits = attn_logits.masked_fill(~extended_attn_mask, -6e4)

        # (batch, num_heads, query_len, key_len)
        attn_probs = torch.softmax(attn_logits, dim=-1)

        if self.dropattn > 0.0:
            attn_probs = F.dropout(attn_probs, p=self.dropattn, training=self.training)

        # Content values: (batch, num_heads, query_len, dim_head)
        content_hidden = torch.matmul(attn_probs, value)

        # Position values.
        # (query_len, batch * num_heads, key_len)
        attn_probs_pos = attn_probs.permute(2, 0, 1, 3).reshape(
            query_len, batch_size * self.num_heads, key_len
        )
        # (query_len, batch * num_heads, dim_head)
        pos_hidden = torch.matmul(attn_probs_pos, rel_pos_embedding)
        # -> (batch, num_heads, query_len, dim_head)
        pos_hidden = pos_hidden.view(
            query_len, batch_size, self.num_heads, self.dim_head
        ).permute(1, 2, 0, 3)

        # Merge heads: (batch, query_len, num_heads * dim_head). The merged width is
        # dim_inner, not dim_model; the two differ whenever num_heads * dim_head != dim_model.
        head_hidden = (content_hidden + pos_hidden).transpose(1, 2).reshape(
            batch_size, query_len, self.dim_inner
        )

        # Output projection back to dim_model.
        out_hidden_states = self.out_proj(head_hidden)

        if self.dropout > 0.0:
            out_hidden_states = F.dropout(out_hidden_states, p=self.dropout, training=self.training)

        return out_hidden_states, attn_probs

    def reset_parameters(self):
        """Xavier-uniform init (gain 1/sqrt(2)) for both projections; zero biases."""
        nn.init.xavier_uniform_(self.qkv_proj.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.out_proj.weight, gain=1 / math.sqrt(2))
        nn.init.zeros_(self.qkv_proj.bias)
        nn.init.zeros_(self.out_proj.bias)