import torch
import torch.nn as nn

from allennlp.nn.util import masked_max, masked_mean, masked_softmax
from allennlp.modules.seq2seq_encoders.stacked_self_attention import StackedSelfAttentionEncoder

from IPython import embed

class TransformerAggregator(nn.Module):
    """Aggregate neighbour embeddings with a stack of self-attention layers.

    For each node in a batch, the embeddings of its (optionally sampled)
    neighbours are gathered, run through an allennlp
    ``StackedSelfAttentionEncoder`` and pooled into one vector per node
    (masked mean by default, masked max when ``maxpool`` is set).
    """

    def __init__(self, features, input_dim, device,
                 num_sample=15, sample_nodes=False, dropout=False,
                 num_heads=1, pd=None, hd=None, fh=None,
                 maxpool=False, dp=0.1, num_layer=1, full_dim=False,
                 self_loop=False):
        """
        Initializes the aggregator for a specific graph.

        features -- callable mapping a tensor of node ids (created on
                    ``device``) to their feature vectors; presumably of
                    width ``input_dim`` (TODO confirm against caller).
        input_dim -- input feature size given to the attention encoder.
        device -- torch device on which index tensors and masks are placed.
        num_sample -- max neighbours kept per node when ``sample_nodes`` is
                      True; ``None`` keeps all of them.
        sample_nodes -- if True, ``forward`` expects (node, rel, score)
                        triples and keeps the highest-scoring ones.
        dropout -- if True, apply Dropout(0.5) to the gathered embeddings.
        num_heads -- number of attention heads.
        pd -- attention projection size (default: half of input_dim).
        hd -- encoder hidden size (default: input_dim).
        fh -- feed-forward hidden size (default: half of input_dim).
        maxpool -- pool with a masked max instead of a masked mean.
        dp -- attention dropout probability.
        num_layer -- number of stacked self-attention layers.
        full_dim -- if True, projection and feed-forward sizes are input_dim.
        self_loop -- when sampling, also include each node among its own
                     neighbours.
        """

        super(TransformerAggregator, self).__init__()

        self.features = features
        self.device = device
        self.num_sample = num_sample
        self.sample_nodes = sample_nodes
        # Not read anywhere in this class; kept for external users.
        self.shuffle = True
        self.num_heads = num_heads
        self.proj_dim = pd or int(input_dim / 2)
        self.hidden_dim = hd or input_dim
        self.ff_hidden = fh or int(input_dim / 2)
        self.maxpool = maxpool
        self.self_loop = self_loop
        self.num_layers = num_layer

        if full_dim:
            self.proj_dim = input_dim
            self.ff_hidden = input_dim

        self.input_dim = input_dim
        self.dropout = nn.Dropout(0.5) if dropout else None

        self.attention = StackedSelfAttentionEncoder(input_dim,
                                                     hidden_dim=self.hidden_dim,
                                                     projection_dim=self.proj_dim,
                                                     feedforward_hidden_dim=self.ff_hidden,
                                                     num_layers=self.num_layers,
                                                     num_attention_heads=self.num_heads,
                                                     use_positional_encoding=False,
                                                     attention_dropout_prob=dp)

    def _sample_neighbours(self, nodes, to_neighs):
        """Return the collection of neighbour ids to aggregate for each node.

        Without ``sample_nodes`` the input collections are returned as-is.
        With sampling, each entry of ``to_neighs`` is an iterable of
        (node, rel, score) triples. Only the ``num_sample`` triples with the
        largest score are kept, or all of them when ``num_sample`` is None or
        not exceeded. They are reduced to a set of node ids, and
        ``nodes[i]`` is added when ``self_loop`` is set.
        """
        if not self.sample_nodes:
            # NOTE(review): self_loop is not applied on this path (doing so
            # would mutate the caller's collections) -- confirm intended.
            return to_neighs

        samp_neighs = []
        for i, to_neigh in enumerate(to_neighs):
            if self.num_sample is not None and len(to_neigh) > self.num_sample:
                # Sample by score: keep the top-`num_sample` triples.
                to_neigh = sorted(to_neigh, key=lambda triple: triple[2],
                                  reverse=True)[:self.num_sample]
            neigh_ids = {node for node, _rel, _score in to_neigh}
            if self.self_loop:
                neigh_ids.add(nodes[i])
            samp_neighs.append(neigh_ids)
        return samp_neighs

    def forward(self, nodes, to_neighs):
        """
        Aggregate the neighbourhood of every node in the batch.

        nodes --- list of nodes in the batch (only read when sampling with
                  ``self_loop``).
        to_neighs --- one neighbour collection per node: node ids without
                      sampling, (node, rel, score) triples with sampling.

        Returns one pooled row per entry of ``to_neighs``.
        """
        samp_neighs = self._sample_neighbours(nodes, to_neighs)

        # set().union accepts any iterables (lists, frozensets, ...), whereas
        # set.union(*...) requires the first collection to be a real `set`.
        unique_nodes = list(set().union(*samp_neighs))
        node_to_emb_idx = {n: i for i, n in enumerate(unique_nodes)}
        unique_nodes_tensor = torch.tensor(unique_nodes, device=self.device)

        # Each unique node is embedded once; rows are then gathered per batch entry.
        embed_matrix = self.features(unique_nodes_tensor)
        if self.dropout is not None:
            embed_matrix = self.dropout(embed_matrix)

        modified_adj_nodes = base_modified_neighbours(samp_neighs, node_to_emb_idx)

        # Padding slots point at row 0; the mask keeps them out of the
        # attention and of the pooling below.
        padded_tensor, mask = pad_tensor(modified_adj_nodes, mask=True)
        padded_tensor = padded_tensor.to(self.device)
        mask = mask.to(self.device)

        # Presumably (batch, max_neigh, feat_dim) if embed_matrix is 2-D.
        neigh_feats = embed_matrix[padded_tensor]

        attn_feats = self.attention(neigh_feats, mask)
        if self.maxpool:
            return masked_max(attn_feats, mask.unsqueeze(-1), dim=1)
        return masked_mean(attn_feats, mask.unsqueeze(-1), dim=1)


def pad_tensor(adj_nodes_list, mask=False):
    """Right-pad integer index lists with 0 into a rectangular LongTensor.

    adj_nodes_list -- list of sized iterables of integer indices.
    mask -- if True, also return a LongTensor of the same shape holding 1 at
            real positions and 0 at padded ones.

    Returns ``padded`` of shape (len(adj_nodes_list), max_len), or
    ``(padded, mask)`` when ``mask`` is True. The dtype is set explicitly to
    ``torch.long``. Letting torch infer it yields float32 when every list is
    empty, and such a tensor cannot be used for indexing.
    """
    # default=0: an empty batch gives a (0, 0) tensor instead of ValueError.
    max_len = max((len(adj_nodes) for adj_nodes in adj_nodes_list), default=0)

    padded_nodes = []
    _mask = []
    for adj_nodes in adj_nodes_list:
        n_real = len(adj_nodes)
        n_pad = max_len - n_real
        padded_nodes.append(list(adj_nodes) + [0] * n_pad)
        _mask.append([1] * n_real + [0] * n_pad)

    # reshape pins the 2-D shape even when there are no rows or no columns.
    num_rows = len(padded_nodes)
    padded = torch.tensor(padded_nodes, dtype=torch.long).reshape(num_rows, max_len)
    if not mask:
        return padded

    # returning the mask as well
    return padded, torch.tensor(_mask, dtype=torch.long).reshape(num_rows, max_len)


def base_modified_neighbours(adj_nodes_list, idx_mapping):
    """Translate every node id in each adjacency collection through ``idx_mapping``.

    adj_nodes_list -- iterable of iterables of node ids.
    idx_mapping -- mapping from node id to its new index (in this file, the
                   row of the embedding matrix built in ``forward``). An id
                   missing from the mapping raises KeyError.

    Returns a list of lists. Each inner list keeps the iteration order of the
    corresponding input collection.
    """
    return [[idx_mapping[node_id] for node_id in neighbours]
            for neighbours in adj_nodes_list]