import torch
import torch.utils
import torch.utils.checkpoint

import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention with separate q/k/v inputs.

    Despite the name, the module accepts distinct query, key and value
    tensors, so it can be used for both self- and cross-attention.

    Args:
        embed_dim: Size of the last dimension of the inputs and the output.
            Must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self,
                 embed_dim,
                 n_heads,
                 dropout):
        super().__init__()
        assert embed_dim % n_heads == 0, "embed_dim must be divisible by n_heads"

        self.num_attention_heads = n_heads
        # Per-head width (d_h); integer division is exact thanks to the check above.
        self.attention_head_dim = embed_dim // self.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_dim

        # key, query, value projections for all heads
        self.key = nn.Linear(embed_dim, self.all_head_size)
        self.query = nn.Linear(embed_dim, self.all_head_size)
        self.value = nn.Linear(embed_dim, self.all_head_size)

        # regularization
        self.attn_drop = nn.Dropout(dropout)

        # output projection
        self.proj = nn.Linear(embed_dim, embed_dim)

    def transpose_for_scores(self, x):
        """Split the last dim into heads: (B, L, D) -> (B, h, L, d_h)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_dim)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, q, k, v, attention_mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query input of shape (B, L_q, embed_dim).
            k: Key input of shape (B, L_k, embed_dim).
            v: Value input of shape (B, L_k, embed_dim).
            attention_mask: Optional mask indexed as ``mask[:, None, None, :]``
                and broadcast over heads and query positions, so its last
                dimension addresses key positions. Entries equal to 1 are
                kept; entries equal to 0 receive a large negative bias
                before the softmax. ``None`` disables masking.

        Returns:
            Tensor of shape (B, L_q, embed_dim).
        """
        B, L, H = q.shape

        # Project, then move the head axis in front of the sequence axis.
        q = self.transpose_for_scores(self.query(q))  # (B, h, L_q, d_h)
        k = self.transpose_for_scores(self.key(k))    # (B, h, L_k, d_h)
        v = self.transpose_for_scores(self.value(v))  # (B, h, L_k, d_h)

        scale = 1.0 / math.sqrt(self.attention_head_dim)
        att = torch.matmul(q, k.transpose(-1, -2) * scale)  # (B, h, L_q, L_k)

        # Only touch the mask when one was supplied; the default is None.
        if attention_mask is not None:
            extended_attention_mask = attention_mask[:, None, None, :]
            extended_attention_mask = extended_attention_mask.to(dtype=att.dtype)  # fp16 compatibility
            # Additive bias: 0 where mask == 1, most-negative finite value where mask == 0.
            extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(att.dtype).min
            att = att + extended_attention_mask

        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)

        y = torch.matmul(att, v)  # (B, h, L_q, L_k) x (B, h, L_k, d_h) -> (B, h, L_q, d_h)
        y = y.permute(0, 2, 1, 3).contiguous().view(B, L, H)  # merge heads -> (B, L_q, D)

        # output projection
        return self.proj(y)

class PhysicsSelfAttentionLayer(nn.Module):
    """Post-norm block: self-attention followed by a position-wise MLP.

    ``ln2`` and ``dropout2`` are registered but never used by ``forward``;
    they are kept so the module's attribute set stays the same.

    Args:
        num_hidden: Feature width of the input and output.
        n_heads: Number of attention heads.
        dropout: Dropout probability used in attention, the MLP and the
            residual branches.
    """

    def __init__(self, num_hidden, n_heads=4, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(num_hidden)
        self.ln2 = nn.LayerNorm(num_hidden)
        self.ln3 = nn.LayerNorm(num_hidden)
        self.embed_dim = num_hidden
        self.dropout1 = nn.Dropout(p=dropout)
        self.dropout2 = nn.Dropout(p=dropout)
        self.dropout3 = nn.Dropout(p=dropout)

        # Attention over the target sequence itself.
        self.self_attn = MultiHeadSelfAttention(
            embed_dim=num_hidden, n_heads=n_heads, dropout=dropout
        )

        # Feed-forward network with a 4x wide hidden layer.
        inner_dim = 4 * num_hidden
        self.mlp = nn.Sequential(
            nn.Linear(num_hidden, inner_dim),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(inner_dim, num_hidden),
        )

    def forward(self, tgt, memory, attention_mask):
        """Run the block on ``tgt``.

        ``memory`` is accepted but not read by this layer.
        ``attention_mask`` is forwarded unchanged to the attention module.
        """
        # Self-attention sub-layer: residual add, then LayerNorm.
        attended = self.self_attn(tgt, tgt, tgt, attention_mask=attention_mask)
        hidden = self.ln1(tgt + self.dropout1(attended))

        # Feed-forward sub-layer: residual add, then LayerNorm.
        ff_out = self.mlp(hidden)
        return self.ln3(hidden + self.dropout3(ff_out))



class PhysicsAttentionLayer(nn.Module):
    """Post-norm block: self-attention, cross-attention, then an MLP.

    Args:
        num_hidden: Feature width of the input and output.
        n_heads: Number of attention heads in both attention modules.
        dropout: Dropout probability used in attention, the MLP and the
            residual branches.
    """

    def __init__(self, num_hidden, n_heads=4, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(num_hidden)
        self.ln2 = nn.LayerNorm(num_hidden)
        self.ln3 = nn.LayerNorm(num_hidden)
        self.embed_dim = num_hidden
        self.dropout1 = nn.Dropout(p=dropout)
        self.dropout2 = nn.Dropout(p=dropout)
        self.dropout3 = nn.Dropout(p=dropout)

        # Attention over the target sequence itself.
        self.self_attn = MultiHeadSelfAttention(
            embed_dim=num_hidden, n_heads=n_heads, dropout=dropout
        )
        # Attention from the target sequence over ``memory``.
        self.cross_attn = MultiHeadSelfAttention(
            embed_dim=num_hidden, n_heads=n_heads, dropout=dropout
        )

        # Feed-forward network with a 4x wide hidden layer.
        inner_dim = 4 * num_hidden
        self.mlp = nn.Sequential(
            nn.Linear(num_hidden, inner_dim),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(inner_dim, num_hidden),
        )

    def forward(self, tgt, memory, attention_mask):
        """Run the block on ``tgt``, attending over ``memory`` in the middle step.

        NOTE(review): the same ``attention_mask`` is passed to both the
        self-attention and the cross-attention call; this presumably requires
        ``memory`` to share ``tgt``'s sequence length and padding -- confirm
        against callers.
        """
        # Self-attention sub-layer: residual add, then LayerNorm.
        self_out = self.self_attn(tgt, tgt, tgt, attention_mask=attention_mask)
        hidden = self.ln1(tgt + self.dropout1(self_out))

        # Cross-attention sub-layer: queries from the target, keys/values from memory.
        cross_out = self.cross_attn(hidden, memory, memory, attention_mask=attention_mask)
        hidden = self.ln2(hidden + self.dropout2(cross_out))

        # Feed-forward sub-layer: residual add, then LayerNorm.
        ff_out = self.mlp(hidden)
        return self.ln3(hidden + self.dropout3(ff_out))

class ContextAttentionLayer(nn.Module):
    """Post-norm block: cross-attention over ``memory`` followed by an MLP.

    There is no self-attention step in this layer. ``ln1`` and ``dropout1``
    are registered but never used by ``forward``; they are kept so the
    module's attribute set stays the same.

    Args:
        num_hidden: Feature width of the input and output.
        n_heads: Number of attention heads.
        dropout: Dropout probability used in attention, the MLP and the
            residual branches.
    """

    def __init__(self, num_hidden, n_heads=4, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(num_hidden)
        self.ln2 = nn.LayerNorm(num_hidden)
        self.ln3 = nn.LayerNorm(num_hidden)
        self.embed_dim = num_hidden
        self.dropout1 = nn.Dropout(p=dropout)
        self.dropout2 = nn.Dropout(p=dropout)
        self.dropout3 = nn.Dropout(p=dropout)

        # Attention from the target sequence over ``memory``.
        self.cross_attn = MultiHeadSelfAttention(
            embed_dim=num_hidden, n_heads=n_heads, dropout=dropout
        )

        # Feed-forward network with a 4x wide hidden layer.
        inner_dim = 4 * num_hidden
        self.mlp = nn.Sequential(
            nn.Linear(num_hidden, inner_dim),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(inner_dim, num_hidden),
        )

    def forward(self, tgt, memory, attention_mask):
        """Update ``tgt`` by attending over ``memory``, then apply the MLP.

        ``attention_mask`` is forwarded unchanged to the cross-attention module.
        """
        # Cross-attention sub-layer: queries from the target, keys/values from memory.
        cross_out = self.cross_attn(tgt, memory, memory, attention_mask=attention_mask)
        hidden = self.ln2(tgt + self.dropout2(cross_out))

        # Feed-forward sub-layer: residual add, then LayerNorm.
        ff_out = self.mlp(hidden)
        return self.ln3(hidden + self.dropout3(ff_out))


class GlobalForceLayer(nn.Module):
    """Gate node and edge features with a sigmoid gate derived from ``force``.

    ``force`` is averaged over dimension 1, passed through a small MLP ending
    in a sigmoid, and the result multiplies the node / edge features
    element-wise. ``ln1``-``ln3`` are registered but not used by ``forward``.

    Args:
        num_hidden: Feature width of nodes, edges and ``force``.
        dropout: Dropout probability inside the gating MLPs.
    """

    def __init__(self, num_hidden, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(num_hidden)
        self.ln2 = nn.LayerNorm(num_hidden)
        self.ln3 = nn.LayerNorm(num_hidden)
        self.embed_dim = num_hidden

        # Separate gates for node features and edge features.
        self.V_MLP_g = self._make_gate(num_hidden, dropout)
        self.E_MLP_g = self._make_gate(num_hidden, dropout)

    @staticmethod
    def _make_gate(width, p_drop):
        """Build a three-layer MLP whose output lies in (0, 1)."""
        return nn.Sequential(
            nn.Linear(width, width),
            nn.GELU(),
            nn.Dropout(p_drop),
            nn.Linear(width, width),
            nn.GELU(),
            nn.Dropout(p_drop),
            nn.Linear(width, width),
            nn.Sigmoid(),
        )

    def forward(self, h_V, h_E, force, mask):
        """Apply the force-derived gates.

        Args:
            h_V: Node features, unpacked as (B, L, D).
            h_E: Edge features, unpacked as (B, L, K, D).
            force: Tensor averaged over dim 1 to form the global context;
                presumably (B, N, D) -- TODO confirm against callers.
            mask: Multiplied into the outputs after trailing singleton axes
                are appended; presumably (B, L) with 0 at padded positions.

        Returns:
            Tuple ``(h_V_new, h_E_new)`` with the same shapes as the inputs.
        """
        # Global context shared by both updates: mean of ``force`` over dim 1.
        pooled = force.mean(dim=1, keepdim=True)  # (B, 1, D)

        # node update
        _, num_nodes, _ = h_V.shape
        node_gate = self.V_MLP_g(pooled.repeat(1, num_nodes, 1))
        node_mask = mask.unsqueeze(-1)
        h_V_new = node_mask * (h_V * node_gate)

        # edge update
        _, num_nodes, num_neighbors, _ = h_E.shape
        edge_gate = self.E_MLP_g(pooled.unsqueeze(1).repeat(1, num_nodes, num_neighbors, 1))
        edge_mask = mask.unsqueeze(-1).unsqueeze(-1)
        h_E_new = edge_mask * (h_E * edge_gate)

        return h_V_new, h_E_new




class GlobalContextLayer(nn.Module):
    """Gate node features with a sigmoid gate derived from their own mean.

    The node features are averaged over dimension 1, passed through a small
    MLP ending in a sigmoid, and the result multiplies the node features
    element-wise. ``ln1``-``ln3`` are registered but not used by ``forward``.

    Args:
        num_hidden: Feature width of the node features.
        dropout: Dropout probability inside the gating MLP.
    """

    def __init__(self, num_hidden,  dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(num_hidden)
        self.ln2 = nn.LayerNorm(num_hidden)
        self.ln3 = nn.LayerNorm(num_hidden)
        self.embed_dim = num_hidden

        # Three-layer MLP whose sigmoid output acts as a per-feature gate.
        gate_layers = [
            nn.Linear(num_hidden, num_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(num_hidden, num_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(num_hidden, num_hidden),
            nn.Sigmoid(),
        ]
        self.V_MLP_g = nn.Sequential(*gate_layers)

    def forward(self, h_V, mask):
        """Apply the context gate.

        Args:
            h_V: Node features, unpacked as (B, L, D).
            mask: Multiplied into the output after a trailing singleton axis
                is appended; presumably (B, L) with 0 at padded positions.

        Returns:
            Gated node features with the same shape as ``h_V``.

        NOTE(review): the mean over dim 1 is taken before masking, so any
        padded positions in ``h_V`` contribute to the context -- confirm this
        is intended.
        """
        _, num_nodes, _ = h_V.shape

        # node update
        pooled = h_V.mean(dim=1, keepdim=True)  # (B, 1, D)
        gate = self.V_MLP_g(pooled.repeat(1, num_nodes, 1))
        node_mask = mask.unsqueeze(-1)
        return node_mask * (h_V * gate)
