import torch
from torch import nn

from tgnnexplainer.xgraph.models.ext.tgn.utils.utils import MergeLayer

# from utils.utils import MergeLayer


class TemporalAttentionLayer(torch.nn.Module):
    """
    Temporal attention layer. Return the temporal embedding of a node given the node itself,
     its neighbors and the edge timestamps.
    """

    def __init__(self, n_node_features, n_neighbors_features, n_edge_features, time_dim,
                             output_dimension, n_head=2,
                             dropout=0.1):
        super(TemporalAttentionLayer, self).__init__()

        self.n_head = n_head

        self.feat_dim = n_node_features
        self.time_dim = time_dim

        self.query_dim = n_node_features + time_dim
        self.key_dim = n_neighbors_features + time_dim + n_edge_features

        self.merger = MergeLayer(self.query_dim, n_node_features, n_node_features, output_dimension)

        self.multi_head_target = nn.MultiheadAttention(embed_dim=self.query_dim,
                                                        kdim=self.key_dim,
                                                        vdim=self.key_dim,
                                                        num_heads=n_head,
                                                        dropout=dropout)

    def forward(self, src_node_features, src_time_features, neighbors_features,
                            neighbors_time_features, edge_features, neighbors_padding_mask):
        """
        "Temporal attention model
        :param src_node_features: float Tensor of shape [batch_size, n_node_features]
        :param src_time_features: float Tensor of shape [batch_size, 1, time_dim]
        :param neighbors_features: float Tensor of shape [batch_size, n_neighbors, n_node_features]
        :param neighbors_time_features: float Tensor of shape [batch_size, n_neighbors,
        time_dim]
        :param edge_features: float Tensor of shape [batch_size, n_neighbors, n_edge_features]
        :param neighbors_padding_mask: float Tensor of shape [batch_size, n_neighbors]
        :return:
        attn_output: float Tensor of shape [1, batch_size, n_node_features]
        attn_output_weights: [batch_size, 1, n_neighbors]
        """

        src_node_features_unrolled = torch.unsqueeze(src_node_features, dim=1)

        query = torch.cat([src_node_features_unrolled, src_time_features], dim=2)
        key = torch.cat([neighbors_features, edge_features, neighbors_time_features], dim=2)

        # print(neighbors_features.shape, edge_features.shape, neighbors_time_features.shape)
        # Reshape tensors so to expected shape by multi head attention
        query = query.permute([1, 0, 2])  # [1, batch_size, num_of_features]
        key = key.permute([1, 0, 2])  # [n_neighbors, batch_size, num_of_features]

        # Compute mask of which source nodes have no valid neighbors
        # invalid_neighborhood_mask = neighbors_padding_mask.all(dim=1, keepdim=True)
        # If a source node has no valid neighbor, set it's first neighbor to be valid. This will
        # force the attention to just 'attend' on this neighbor (which has the same features as all
        # the others since they are fake neighbors) and will produce an equivalent result to the
        # original tgat paper which was forcing fake neighbors to all have same attention of 1e-10
        # neighbors_padding_mask[invalid_neighborhood_mask.squeeze(), 0] = False

        # print(query.shape, key.shape)

        # attn_output, attn_output_weights = self.multi_head_target(query=query, key=key, value=key,
        #                                                             key_padding_mask=neighbors_padding_mask)

        
        attn_mask = neighbors_padding_mask
        attn_mask = attn_mask.unsqueeze(1).repeat(self.n_head, 1, 1) # important, https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html?highlight=multiheadattention#torch.nn.MultiheadAttention
        # attn_mask.shape: [N(600)*num_heads, L=1, S=10], N is batch size, L is fixed as 1, S is num of neighbor edges, set to 10 generally.
        # import ipdb; ipdb.set_trace()
        attn_output, attn_output_weights = self.multi_head_target(query=query, key=key, value=key,
                                                                    attn_mask=attn_mask)

        # mask = torch.unsqueeze(neighbors_padding_mask, dim=2)  # mask [B, N, 1]
        # mask = mask.permute([0, 2, 1])
        # attn_output, attn_output_weights = self.multi_head_target(q=query, k=key, v=key,
        #                                                           mask=mask)

        attn_output = attn_output.squeeze()
        attn_output_weights = attn_output_weights.squeeze()

        # Source nodes with no neighbors have an all zero attention output. The attention output is
        # then added or concatenated to the original source node features and then fed into an MLP.
        # This means that an all zero vector is not used.
        # attn_output = attn_output.masked_fill(invalid_neighborhood_mask, 0)
        # attn_output_weights = attn_output_weights.masked_fill(invalid_neighborhood_mask, 0)

        # Skip connection with temporal attention over neighborhood and the features of the node itself
        attn_output = self.merger(attn_output, src_node_features)

        return attn_output, attn_output_weights
