import math
import numpy as np
import torch as T
import torch.nn as nn
import torch.nn.functional as F


class StructuredAttention(nn.Module):
    """Multi-head scaled dot-product attention gated by a structure mask.

    Queries, keys and values are linearly projected into ``heads`` subspaces of
    size ``head_dim``. Attention weights are a softmax over the key axis in which
    each exponentiated score is multiplied by ``structured_attention_mask``.
    Positions where the mask is 0 get zero weight; non-zero mask values reweight
    the remaining positions. The per-head outputs are concatenated and projected
    back to ``query_dim``.

    Args:
        config: mapping with keys ``"heads"``, ``"head_dim"``, ``"dropout"`` and
            ``"attn_dropout"``.
        query_dim: feature size of the query input and of the output.
        value_dim: feature size of the key/value inputs.
    """

    def __init__(self, config, query_dim: int = 512, value_dim: int = 512):
        super(StructuredAttention, self).__init__()

        self.qD = query_dim
        self.vD = value_dim

        self.config = config
        self.heads = config["heads"]
        self.d = config["head_dim"]
        # NOTE(review): stored but never applied inside this class; only
        # attn_dropout is used in forward(). Possibly read by callers - confirm.
        self.dropout = config["dropout"]
        self.attn_dropout = config["attn_dropout"]
        # Keeps the normalising denominator non-zero for fully masked rows.
        # NOTE: 1e-32 is below float16's range and rounds to 0 there, so a fully
        # masked row can still produce 0/0 = NaN when running in float16.
        self.eps = 1e-32
        # sqrt(head_dim): the usual scaled dot-product temperature.
        self.attn_scalar = T.tensor(math.sqrt(self.d)).float()

        # initialize params
        self.init_QKV()
        self.init_head_compose()

    # --- Parameter initializers ---------------------------------------------

    def init_QKV(self):
        """Create the bias-free query/key/value projections to heads * head_dim."""
        self.query_linear = nn.Linear(self.qD, self.heads * self.d, bias=False)
        self.key_linear = nn.Linear(self.vD, self.heads * self.d, bias=False)
        self.value_linear = nn.Linear(self.vD, self.heads * self.d, bias=False)

    def init_head_compose(self):
        """Create the bias-free projection from concatenated heads back to query_dim."""
        self.head_compose_linear = nn.Linear(self.heads * self.d, self.qD, bias=False)

    # --- Attention helpers --------------------------------------------------

    def sum_normalize(self, logits, dim=-1):
        """Divide ``logits`` by their sum along ``dim`` (``eps`` added per element)."""
        return logits / T.sum(logits + self.eps, keepdim=True, dim=dim)

    def score_contents(self, Q, K):
        """Return scaled dot-product scores of shape (N, heads, qS, vS).

        Args:
            Q: tensor of shape (N, heads, qS, head_dim).
            K: tensor of shape (N, heads, vS, head_dim).
        """
        N, _, qS, _ = Q.size()
        vS = K.size(2)

        assert Q.size() == (N, self.heads, qS, self.d)
        assert K.size() == (N, self.heads, vS, self.d)

        content_scores = T.matmul(Q, K.transpose(-2, -1))
        assert content_scores.size() == (N, self.heads, qS, vS)

        return content_scores / self.attn_scalar.to(content_scores.device)

    def masked_softmax(self, scores, mask, dim=-1):
        """Mask-weighted softmax of ``scores`` along ``dim``.

        Entries where ``mask == 0`` get zero weight. Other entries are multiplied
        by their mask value before normalisation. The stabilising shift is the
        maximum per row over unmasked entries only, for two reasons:
          * rows are stabilised independently, so a row whose scores are far
            below another row's does not underflow to all zeros;
          * masked-out scores can neither dominate the shift nor overflow.
        A row with no unmasked entry yields zeros (given a representable eps).

        Args:
            scores: attention logits.
            mask: bool or numeric tensor broadcastable to ``scores``.
            dim: axis to normalise over.

        Returns:
            Tensor with the shape of ``scores`` (broadcast with ``mask``).
        """
        excluded = mask == 0
        # exp(-inf) == 0, so excluded entries contribute nothing and never
        # overflow (which would otherwise give inf * 0 = NaN).
        masked_scores = scores.masked_fill(excluded, float("-inf"))
        # Softmax is invariant to a per-row shift, so the shift is detached.
        row_max = masked_scores.max(dim=dim, keepdim=True).values.detach()
        # Fully masked rows have max -inf; shift them by 0 to avoid -inf - -inf.
        row_max = row_max.masked_fill(row_max == float("-inf"), 0.0)
        exp_scores = mask * T.exp(masked_scores - row_max)
        return self.sum_normalize(exp_scores, dim=dim)

    # --- Forward function ---------------------------------------------------

    def forward(self, Q, K, V,
                structured_attention_mask):
        """Attend from Q over (K, V), restricted/weighted by the structure mask.

        Args:
            Q: tensor of shape (N, qS, query_dim).
            K: tensor of shape (N, vS, value_dim); must match V's size.
            V: tensor of shape (N, vS, value_dim).
            structured_attention_mask: tensor of shape (N, qS, vS), shared by
                all heads. Zeros exclude a key for a query; non-zero values
                scale that key's weight before normalisation.

        Returns:
            ``{"attended_values": tensor of shape (N, qS, query_dim)}``.
        """
        N, qS, _ = Q.size()
        _, vS, _ = V.size()

        assert K.size() == V.size()

        Q = self.query_linear(Q)
        K = self.key_linear(K)
        V = self.value_linear(V)

        assert Q.size() == (N, qS, self.heads * self.d)
        assert V.size() == (N, vS, self.heads * self.d)
        assert K.size() == V.size()

        # (N, S, heads * d) -> (N, heads, S, d)
        Q = Q.view(N, qS, self.heads, self.d).permute(0, 2, 1, 3).contiguous()
        K = K.view(N, vS, self.heads, self.d).permute(0, 2, 1, 3).contiguous()
        V = V.view(N, vS, self.heads, self.d).permute(0, 2, 1, 3).contiguous()

        # Add a head axis so the same mask broadcasts over every head.
        structured_attention_mask = structured_attention_mask.unsqueeze(1)
        assert structured_attention_mask.size() == (N, 1, qS, vS)

        edge_scores = self.score_contents(Q, K)
        attention_scores = self.masked_softmax(edge_scores, structured_attention_mask, dim=-1)

        attention_scores = F.dropout(attention_scores, p=self.attn_dropout, training=self.training)
        attended_values = T.matmul(attention_scores, V)

        assert attended_values.size() == (N, self.heads, qS, self.d)

        # (N, heads, qS, d) -> (N, qS, heads * d), then project back to query_dim.
        attended_values = attended_values.permute(0, 2, 1, 3).contiguous()
        attended_values = attended_values.view(N, qS, self.heads * self.d)
        attended_values = self.head_compose_linear(attended_values)

        return {"attended_values": attended_values}
