import math
import numpy as np
import torch as T
import torch.nn as nn
import torch.nn.functional as F


class Multiheaded_Attention(nn.Module):
    """Multi-head attention with learned relative-position scores.

    Attention logits are the sum of two terms, each divided by sqrt(head_dim):

    * content:  (Q + u) @ K^T
    * position: (Q + v) @ R^T, where R is the linear projection of the
      relative-position embedding for every (query, key) pair.

    ``u`` (``content_bias``) and ``v`` (``position_bias``) hold one learned
    scalar per head, broadcast across the head dimension.

    Args:
        config: dict providing ``"heads"``, ``"head_dim"``, ``"dropout"``,
            ``"attn_dropout"`` and ``"position_max_len"``.
        query_dim: feature size of the query input and of the output.
        value_dim: feature size of the key / value inputs.
    """

    def __init__(self, config, query_dim: int = 512, value_dim: int = 512):
        super().__init__()

        self.qD = query_dim
        self.vD = value_dim

        self.config = config
        self.heads = config["heads"]
        self.d = config["head_dim"]
        # Read from config but not used anywhere in this module.
        self.dropout = config["dropout"]
        self.attn_dropout = config["attn_dropout"]
        # NOTE(review): 1e-32 rounds to 0 in float16, so the guard in
        # sum_normalize has no effect for float16 inputs.
        self.eps = 1e-32
        # Relative distances are clipped to +-position_max_len.
        self.position_max_len = config["position_max_len"]
        # sqrt(head_dim) scaling applied to both score terms.
        self.attn_scalar = T.tensor(math.sqrt(self.d)).float()

        # initialize params
        self.init_QKV()
        self.init_position()
        self.init_head_compose()

    # ------------------------------------------------------------------
    # Parameter initializers
    # ------------------------------------------------------------------

    def init_QKV(self):
        """Create the bias-free query/key/value projections to heads * head_dim."""
        self.query_linear = nn.Linear(self.qD, self.heads * self.d, bias=False)
        self.key_linear = nn.Linear(self.vD, self.heads * self.d, bias=False)
        self.value_linear = nn.Linear(self.vD, self.heads * self.d, bias=False)

    def init_position(self):
        """Create per-head content/position biases and the position projection."""
        # One scalar per head (not a head_dim vector); zero-initialized.
        self.content_bias = nn.Parameter(T.zeros(self.heads))
        self.position_bias = nn.Parameter(T.zeros(self.heads))
        # Projects query_dim-sized position embeddings to heads * head_dim.
        self.position_linear = nn.Linear(self.qD, self.heads * self.d, bias=False)

    def init_head_compose(self):
        """Create the output projection that merges all heads back to query_dim."""
        self.head_compose_linear = nn.Linear(self.heads * self.d, self.qD, bias=False)

    # ------------------------------------------------------------------
    # Scoring helpers
    # ------------------------------------------------------------------

    def score_positions(self, Q, vS, relative_position_embeddings):
        """Return scaled relative-position scores of shape (N, heads, qS, vS).

        Args:
            Q: projected queries of shape (N, heads, qS, head_dim).
            vS: number of keys.
            relative_position_embeddings: callable mapping a (qS, vS) LongTensor
                of offsets to a (qS, vS, query_dim) tensor, e.g. an
                ``nn.Embedding``.  Offsets lie in ``[0, 2 * position_max_len]``,
                so a lookup table needs ``2 * position_max_len + 1`` rows.
        """
        N, H, qS, d = Q.size()

        # Signed distance (key position - query position), clipped so that
        # sequences longer than position_max_len reuse the boundary embedding
        # instead of producing negative / too-large lookup indices.
        query_pos = T.arange(qS, device=Q.device).unsqueeze(1)
        key_pos = T.arange(vS, device=Q.device).unsqueeze(0)
        distance = (key_pos - query_pos).clamp(-self.position_max_len, self.position_max_len)
        relative_mat_idx = distance + self.position_max_len

        RE = relative_position_embeddings(relative_mat_idx)
        assert RE.size() == (qS, vS, self.qD)
        RE = self.position_linear(RE)
        assert RE.size() == (qS, vS, self.heads * self.d)

        # Split heads: (qS, vS, H*d) -> (1, H, qS, vS, d).
        RE = RE.view(qS, vS, self.heads, self.d)
        RE = RE.permute(2, 0, 1, 3).contiguous().unsqueeze(0)
        assert RE.size() == (1, self.heads, qS, vS, self.d)

        REt = RE.permute(0, 1, 2, 4, 3)

        # Each query row attends against its own (d, vS) slice of embeddings.
        Q = Q.unsqueeze(-2)
        assert Q.size() == (N, H, qS, 1, d)

        v = self.position_bias.view(1, self.heads, 1, 1, 1)
        position_scores = T.matmul(Q + v, REt)

        assert position_scores.size() == (N, H, qS, 1, vS)
        position_scores = position_scores.squeeze(-2)

        return position_scores / self.attn_scalar.to(position_scores.device)

    def sum_normalize(self, logits, dim=-1):
        """Divide ``logits`` by their sum along ``dim``.

        ``eps`` is added to every element before summing, so an all-zero row
        yields zeros rather than NaN (except in float16, where eps rounds to 0).
        """
        return logits / T.sum(logits + self.eps, keepdim=True, dim=dim)

    def score_contents(self, Q, K):
        """Return scaled content scores (Q + u) @ K^T of shape (N, heads, qS, vS).

        Args:
            Q: projected queries of shape (N, heads, qS, head_dim).
            K: projected keys of shape (N, heads, vS, head_dim).
        """
        N, _, qS, _ = Q.size()
        vS = K.size(2)

        assert Q.size() == (N, self.heads, qS, self.d)
        assert K.size() == (N, self.heads, vS, self.d)

        u = self.content_bias.view(1, self.heads, 1, 1)

        Kt = K.permute(0, 1, 3, 2).contiguous()
        content_scores = T.matmul(Q + u, Kt)
        assert content_scores.size() == (N, self.heads, qS, vS)

        return content_scores / self.attn_scalar.to(content_scores.device)

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------

    def forward(self, Q, K, V,
                relative_position_embeddings,
                attention_mask,
                special_mask):
        """Attend from queries ``Q`` over keys ``K`` / values ``V``.

        Args:
            Q: queries of shape (N, qS, query_dim).
            K: keys of shape (N, vS, value_dim); must match ``V``'s shape.
            V: values of shape (N, vS, value_dim).
            relative_position_embeddings: see ``score_positions``.
            attention_mask: tensor of shape (N, qS, vS).  NOTE(review): it is
                only shape-checked below and never applied to the scores --
                confirm that ``special_mask`` already covers padding.
            special_mask: tensor of shape (N, qS, vS) multiplied into the
                softmax probabilities before renormalization.

        Returns:
            dict with ``"attended_values"`` of shape (N, qS, query_dim).
        """
        N, qS, _ = Q.size()
        _, vS, _ = V.size()

        assert K.size() == V.size()

        Q = self.query_linear(Q)
        K = self.key_linear(K)
        V = self.value_linear(V)

        assert Q.size() == (N, qS, self.heads * self.d)
        assert V.size() == (N, vS, self.heads * self.d)
        assert K.size() == V.size()

        # Split heads: (N, S, H*d) -> (N, H, S, d).
        Q = Q.view(N, qS, self.heads, self.d).permute(0, 2, 1, 3).contiguous()
        K = K.view(N, vS, self.heads, self.d).permute(0, 2, 1, 3).contiguous()
        V = V.view(N, vS, self.heads, self.d).permute(0, 2, 1, 3).contiguous()

        # Add a broadcast head axis to the masks.
        attention_mask = attention_mask.unsqueeze(1)
        assert attention_mask.size() == (N, 1, qS, vS)
        special_mask = special_mask.unsqueeze(1)
        assert special_mask.size() == (N, 1, qS, vS)

        content_scores = self.score_contents(Q, K)
        position_scores = self.score_positions(Q, vS, relative_position_embeddings)

        edge_scores = content_scores + position_scores
        # For a 0/1 mask, softmax -> mask -> renormalize equals a softmax
        # restricted to the unmasked keys.
        attention_scores = self.sum_normalize(special_mask * F.softmax(edge_scores, dim=-1))

        attention_scores = F.dropout(attention_scores, p=self.attn_dropout, training=self.training)
        attended_values = T.matmul(attention_scores, V)

        assert attended_values.size() == (N, self.heads, qS, self.d)

        # Merge heads: (N, H, qS, d) -> (N, qS, H*d) -> (N, qS, query_dim).
        attended_values = attended_values.permute(0, 2, 1, 3).contiguous()
        attended_values = attended_values.view(N, qS, self.heads * self.d)
        attended_values = self.head_compose_linear(attended_values)

        return {"attended_values": attended_values}
