import math
import numpy as np
import torch as T
import torch.nn as nn
import torch.nn.functional as F


class TruncatedStructuredAttention(nn.Module):
    """Multi-head attention in which every query attends over its own window of keys.

    Each query position comes with ``w`` candidate keys/values. The first
    ``heads // 2`` heads may only attend to candidates whose position index is
    strictly greater than the query's position; the remaining heads may only
    attend to candidates whose position index is strictly smaller. A candidate
    at the same position as the query is therefore never attended to. The
    caller-supplied ``structured_attention_mask`` is multiplied with this
    directional mask.

    ``config`` must provide the keys ``"heads"``, ``"head_dim"``, ``"dropout"``,
    ``"attn_dropout"`` and ``"position_max_len"``. ``dropout`` and
    ``position_max_len`` are stored on the module but are not used by the
    methods of this class.
    """

    def __init__(self, config, query_dim: int = 512, value_dim: int = 512):
        super().__init__()

        self.qD = query_dim
        self.vD = value_dim

        self.config = config
        self.heads = config["heads"]
        self.d = config["head_dim"]
        self.dropout = config["dropout"]
        self.attn_dropout = config["attn_dropout"]
        # Added to every weight before summing in sum_normalize so that an
        # all-zero (fully masked) row does not divide by zero.
        self.eps = 1e-32
        self.position_max_len = config["position_max_len"]
        # Dot-product scores are divided by sqrt(head_dim).
        self.attn_scalar = T.tensor(math.sqrt(self.d)).float()

        # initialize params
        self.init_QKV()
        self.init_head_compose()

    # ------------------------------------------------------------------
    # Parameter Initializers
    # ------------------------------------------------------------------

    def init_QKV(self):
        """Create the bias-free query/key/value projections to ``heads * head_dim``."""
        self.query_linear = nn.Linear(self.qD, self.heads * self.d, bias=False)
        self.key_linear = nn.Linear(self.vD, self.heads * self.d, bias=False)
        self.value_linear = nn.Linear(self.vD, self.heads * self.d, bias=False)

    def init_head_compose(self):
        """Create the bias-free projection from concatenated heads back to ``query_dim``."""
        self.head_compose_linear = nn.Linear(self.heads * self.d, self.qD, bias=False)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def sum_normalize(self, logits, dim=-1):
        """Divide ``logits`` by their sum along ``dim``.

        ``self.eps`` is added to every element before summing, so the
        denominator is offset by ``eps * logits.size(dim)`` and an all-zero
        slice normalizes to zeros instead of NaN.
        """
        return logits / T.sum(logits + self.eps, keepdim=True, dim=dim)

    def score_contents(self, Q, K):
        """Return scaled dot-product scores between each query and its window of keys.

        Args:
            Q: projected queries of shape ``(N, heads, qS, head_dim)``.
            K: projected keys of shape ``(N, heads, vS, w, head_dim)`` with
                ``vS == qS``.

        Returns:
            Tensor of shape ``(N, heads, qS, w)`` holding ``Q . K / sqrt(head_dim)``.
        """
        N, _, qS, _ = Q.size()
        _, _, vS, w, _ = K.size()

        assert qS == vS

        assert Q.size() == (N, self.heads, qS, self.d)
        assert K.size() == (N, self.heads, vS, w, self.d)

        # (N, heads, qS, 1, d) @ (N, heads, qS, d, w) -> (N, heads, qS, 1, w)
        Q = Q.unsqueeze(-2)
        Kt = K.permute(0, 1, 2, 4, 3).contiguous()

        assert Q.size() == (N, self.heads, qS, 1, self.d)
        assert Kt.size() == (N, self.heads, vS, self.d, w)

        content_scores = T.matmul(Q, Kt)
        assert content_scores.size() == (N, self.heads, qS, 1, w)
        content_scores = content_scores.squeeze(-2)

        return content_scores / self.attn_scalar.to(content_scores.device)

    # ------------------------------------------------------------------
    # Forward Function
    # ------------------------------------------------------------------

    def forward(self, Q, K, V,
                query_position_idx,
                key_position_idx,
                structured_attention_mask):
        """Attend from every query to its own window of ``w`` candidates.

        Args:
            Q: queries of shape ``(N, qS, query_dim)``.
            K: keys of shape ``(N, qS, w, value_dim)``.
            V: values with the same shape as ``K``.
            query_position_idx: position of each query, shape ``(N, qS, 1)``.
            key_position_idx: position of each candidate, shape ``(N, qS, w)``.
            structured_attention_mask: multiplicative mask of shape
                ``(N, qS, w)``; zero entries disable a candidate.

        Returns:
            dict with key ``"attended_values"`` mapping to a tensor of shape
            ``(N, qS, query_dim)``. Rows with no enabled candidate receive
            zero attention weights.
        """
        N, qS, _ = Q.size()
        _, vS, w, _ = V.size()

        assert query_position_idx.size() == (N, qS, 1)
        assert key_position_idx.size() == (N, qS, w)
        relative_mat_idx = key_position_idx - query_position_idx

        # 1.0 where the candidate lies strictly to the right / left of the query.
        right_mask = (relative_mat_idx > 0).float().to(Q.device)
        left_mask = (relative_mat_idx < 0).float().to(Q.device)

        assert right_mask.size() == (N, qS, w)
        assert left_mask.size() == (N, qS, w)

        # First heads // 2 heads look right, the remaining heads look left.
        n_right_heads = self.heads // 2
        n_left_heads = self.heads - n_right_heads
        position_mask = T.cat([right_mask.unsqueeze(1).repeat(1, n_right_heads, 1, 1),
                               left_mask.unsqueeze(1).repeat(1, n_left_heads, 1, 1)],
                              dim=1)
        assert position_mask.size() == (N, self.heads, qS, w)

        assert K.size() == V.size()

        Q = self.query_linear(Q)
        K = self.key_linear(K)
        V = self.value_linear(V)

        assert Q.size() == (N, qS, self.heads * self.d)
        assert V.size() == (N, vS, w, self.heads * self.d)
        assert K.size() == V.size()

        # Split the projection into heads and move the head axis next to batch.
        Q = Q.view(N, qS, self.heads, self.d).permute(0, 2, 1, 3).contiguous()
        K = K.view(N, vS, w, self.heads, self.d).permute(0, 3, 1, 2, 4).contiguous()
        V = V.view(N, vS, w, self.heads, self.d).permute(0, 3, 1, 2, 4).contiguous()

        structured_attention_mask = structured_attention_mask.unsqueeze(1)
        assert structured_attention_mask.size() == (N, 1, qS, w)
        structured_attention_mask = structured_attention_mask * position_mask

        content_scores = self.score_contents(Q, K)

        assert content_scores.size() == (N, self.heads, qS, w)

        # Masked softmax. The stabilising shift is the max over the *enabled*
        # candidates of each row. Shifting by a batch-wide max (or by a max
        # that includes disabled candidates) lets one large score push the
        # exponentials of unrelated rows to zero, which silently erases their
        # attention and makes an example's output depend on its batch-mates.
        neg_inf = float("-inf")
        masked_scores = content_scores.masked_fill(structured_attention_mask == 0, neg_inf)
        row_max = T.max(masked_scores, dim=-1, keepdim=True)[0].detach()
        # Fully masked rows have max == -inf; shift by 0 there so that
        # exp(-inf - 0) == 0 rather than exp(-inf + inf) == NaN.
        row_max = row_max.masked_fill(row_max == neg_inf, 0.0)

        exp_edge_scores = structured_attention_mask * T.exp(masked_scores - row_max)
        attention_scores = self.sum_normalize(exp_edge_scores, dim=-1)

        attention_scores = F.dropout(attention_scores, p=self.attn_dropout, training=self.training)
        attention_scores = attention_scores.unsqueeze(-2)

        assert V.size() == (N, self.heads, qS, w, self.d)
        assert attention_scores.size() == (N, self.heads, qS, 1, w)

        # (N, heads, qS, 1, w) @ (N, heads, qS, w, d) -> (N, heads, qS, d)
        attended_values = T.matmul(attention_scores, V).squeeze(-2)

        assert attended_values.size() == (N, self.heads, qS, self.d)

        # Concatenate heads and project back to the query dimension.
        attended_values = attended_values.permute(0, 2, 1, 3).contiguous()
        attended_values = attended_values.view(N, qS, self.heads * self.d)

        attended_values = self.head_compose_linear(attended_values)
        attended_values = attended_values.view(N, qS, self.qD)

        return {"attended_values": attended_values}
