from typing import List, Optional, Tuple
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from .normalizations import NORM2FN

# pylint:disable=no-member


class CrossAttention(nn.Module):
    """Multi-head cross-attention from a query sequence over a second sequence.

    Queries are projected from ``hidden_states``; keys and values are produced
    by a single fused projection of ``cross_hidden_states``. The inner width is
    ``num_heads * dim_head`` and does not have to equal ``dim_model``.
    """

    def __init__(
        self,
        num_heads: int,
        dim_model: int,
        dim_head: int,
        dropout: float = 0.0,
        dropattn: float = 0.0,
        layer_norm_type: str = "layer_norm"
    ):
        """
        Args:
            num_heads: number of attention heads
            dim_model: feature size of the inputs and of the output
            dim_head: feature size of each head
            dropout: dropout probability applied to the projected output
            dropattn: dropout probability applied to the attention probabilities
            layer_norm_type: key into NORM2FN selecting the normalization class
        """
        super().__init__()

        self.num_heads = num_heads
        self.dim_model = dim_model
        self.dim_head = dim_head
        self.dim_inner = num_heads * dim_head
        self.dropout = dropout
        self.dropattn = dropattn

        self.q_proj = nn.Linear(dim_model, self.dim_inner, bias=True)
        # Fused key/value projection: first dim_inner features are keys,
        # the remaining dim_inner features are values.
        self.kv_proj = nn.Linear(dim_model, 2 * self.dim_inner, bias=True)
        self.out_proj = nn.Linear(self.dim_inner, dim_model, bias=True)
        # NOTE(review): this norm is constructed but never applied in forward(),
        # although the forward docstring asks callers not to pre-normalize
        # hidden_states. Confirm whether it should be applied before q_proj.
        self.layer_norm = NORM2FN[layer_norm_type](dim_model)
        # Scale logits by 1/sqrt(dim_head) to keep their variance around 1.
        self.scale = 1.0 / (dim_head**0.5)

        self.reset_parameters()

    def _compute_k_v(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project ``hidden_states`` into per-head keys and values.

        Args:
            hidden_states: shape (batch, seq_len, dim_model)

        Returns:
            key: shape (batch, num_heads, dim_head, seq_len), i.e. already
                transposed so that ``query @ key`` yields attention logits
            value: shape (batch, num_heads, seq_len, dim_head)
        """
        bsz = hidden_states.shape[0]
        seq_len = hidden_states.shape[1]

        # shape: (batch, seq_len, 2 * dim_inner)
        fused = self.kv_proj(hidden_states)
        # each of key / value has shape: (batch, seq_len, dim_inner)
        key, value = fused.split(self.dim_inner, dim=2)
        # key shape: (batch, num_heads, dim_head, seq_len)
        key = key.view(bsz, seq_len, self.num_heads, self.dim_head).permute(0, 2, 3, 1)
        # value shape: (batch, num_heads, seq_len, dim_head)
        value = value.view(bsz, seq_len, self.num_heads, self.dim_head).transpose(1, 2)
        return key, value

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_hidden_states: torch.Tensor,
        decoder_cache=None,
        extended_attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: shape (batch, query_len, dim_model), source of the queries
            cross_hidden_states: shape (batch, cross_len, dim_model), source of
                the keys and values
            decoder_cache: currently unused (reserved for fast decoding)
            extended_attn_mask: optional boolean mask broadcastable to
                (batch, num_heads, query_len, cross_len); positions where it is
                False are excluded from attention

        Returns:
            out_hidden_states: shape (batch, query_len, dim_model)
            attn_probs: shape (batch, num_heads, query_len, cross_len); if
                attention dropout is active these are the post-dropout values
        """
        # TODO: Implement fast decoding
        batch_size = hidden_states.size(0)
        query_len = hidden_states.size(1)

        # query shape: (batch, num_heads, query_len, dim_head)
        query = self.q_proj(hidden_states)
        query = query.view(batch_size, query_len, self.num_heads, self.dim_head).transpose(1, 2)

        # key shape: (batch, num_heads, dim_head, cross_len)
        # value shape: (batch, num_heads, cross_len, dim_head)
        key, value = self._compute_k_v(cross_hidden_states)

        # shape: (batch, num_heads, query_len, cross_len)
        attn_logits = torch.matmul(query, key) * self.scale

        if extended_attn_mask is not None:
            # `~` requires a boolean mask. -6e4 is a large negative fill that
            # is still representable in float16.
            attn_logits = attn_logits.masked_fill(~extended_attn_mask, -6e4)

        # shape: (batch, num_heads, query_len, cross_len)
        attn_probs = torch.softmax(attn_logits, dim=-1)

        if self.dropattn > 0.0:
            attn_probs = F.dropout(attn_probs, p=self.dropattn, training=self.training)

        # shape: (batch, num_heads, query_len, dim_head)
        context = torch.matmul(attn_probs, value)
        # Merge heads. The merged width is dim_inner (num_heads * dim_head),
        # which is what out_proj expects; it need not equal dim_model.
        context = context.transpose(1, 2).reshape(batch_size, query_len, self.dim_inner)

        # Output projection back to dim_model
        out_hidden_states = self.out_proj(context)

        if self.dropout > 0.0:
            out_hidden_states = F.dropout(out_hidden_states, p=self.dropout, training=self.training)

        return out_hidden_states, attn_probs

    def reset_parameters(self):
        """Re-initialize all projections: Xavier-uniform weights (gain 1/sqrt(2)), zero biases."""
        gain = 1 / math.sqrt(2)
        for proj in (self.q_proj, self.kv_proj, self.out_proj):
            nn.init.xavier_uniform_(proj.weight.data, gain)
            nn.init.constant_(proj.bias.data, 0.)