import torch
import torch.nn as nn
from torch.nn import init

from allennlp.nn.util import masked_max, masked_mean, masked_softmax

from IPython import embed

class AttnAggregator(nn.Module):
    def __init__(self, features, input_dim, output_dim, device,
                 num_sample=15, sample_nodes=True, dropout=False,
                 gcn=False):
        """
        Initializes an attention-based neighbor aggregator.

        features -- callable mapping a LongTensor of node ids to a FloatTensor
                    whose last dimension is `input_dim` (it is fed to a
                    Linear(input_dim, output_dim)).
        input_dim -- size of the feature vectors returned by `features`.
        output_dim -- size of the projected features and of the aggregated output.
        device -- torch device on which all tensors built in forward() live.
        num_sample -- when sample_nodes is True, keep at most this many neighbors
                      per node (highest hitting probability first); None keeps all.
        sample_nodes -- if True, each neighbor entry is a (node, rel, hitting_prob)
                        tuple and neighbors are filtered by hitting probability;
                        otherwise neighbor entries are node ids used as-is.
        dropout -- if True, apply Dropout(p=0.5) to the looked-up features.
        gcn -- if True, add each batch node to its own neighbor set (GCN-style
               self-loop) so it also attends to itself.
        """

        super(AttnAggregator, self).__init__()

        self.features = features
        self.device = device
        self.num_sample = num_sample
        self.sample_nodes = sample_nodes
        self.shuffle = True

        self.input_dim = input_dim
        if dropout:
            self.dropout = nn.Dropout(0.5)
        else:
            self.dropout = None

        # shared projection W applied to every node's features
        self.proj = nn.Linear(input_dim, output_dim, bias=False)
        init.xavier_uniform_(self.proj.weight)

        # attention vector split into its destination and source halves:
        # a^T [W h_i || W h_j] == a_dst^T W h_i + a_src^T W h_j
        self.attn_src = nn.Linear(output_dim, 1, bias=False)
        self.attn_dst = nn.Linear(output_dim, 1, bias=False)

        self.leaky_relu = nn.LeakyReLU(0.2)

        self.gcn = gcn

    def forward(self, nodes, to_neighs):
        """
        Aggregate neighbor features for a batch of nodes with additive attention.

        nodes --- iterable of node ids in the batch (ints or 0-d tensors; each
                  is converted with int()).
        to_neighs --- one neighbor collection per entry of `nodes`. With
                      sample_nodes=True each entry is an iterable of
                      (node, rel, hitting_prob) tuples; otherwise an iterable
                      of node ids.

        Returns a (len(nodes), output_dim) tensor: for every batch node, the
        attention-weighted sum of its neighbors' projected features.
        """
        # plain ints so they hash/compare equal to the neighbor ids
        batch_ids = [int(node) for node in nodes]

        if self.sample_nodes:
            # keep the neighbors with the highest hitting probability (x[2]);
            # slicing with num_sample=None keeps all of them
            samp_neighs = [
                {neigh for neigh, _rel, _hp in
                 sorted(to_neigh, key=lambda x: x[2], reverse=True)[:self.num_sample]}
                for to_neigh in to_neighs
            ]
        else:
            # no sampling; copy into fresh sets so the caller's data is never mutated
            samp_neighs = [set(to_neigh) for to_neigh in to_neighs]

        if self.gcn:
            for neighs, node_id in zip(samp_neighs, batch_ids):
                neighs.add(node_id)

        # fetch features for every neighbor AND every batch node, so the
        # destination lookup below works even when a node is not its own
        # neighbor (gcn=False)
        unique_nodes = list(set().union(*samp_neighs, batch_ids))
        node_to_emb_idx = {n: i for i, n in enumerate(unique_nodes)}
        unique_nodes_tensor = torch.tensor(unique_nodes, dtype=torch.long,
                                           device=self.device)

        embed_matrix = self.features(unique_nodes_tensor)
        if self.dropout is not None:
            embed_matrix = self.dropout(embed_matrix)

        # W h for every unique node: (num_unique, output_dim)
        embed_matrix_prime = self.proj(embed_matrix)

        # neighbor positions in embed_matrix_prime, right-padded to the longest
        # neighbor list; mask is 1 for real neighbors and 0 for padding
        padded_tensor, mask = pad_tensor(
            base_modified_neighbours(samp_neighs, node_to_emb_idx), mask=True)
        padded_tensor = padded_tensor.to(self.device)
        mask = mask.to(self.device)

        dst_idx = torch.tensor([node_to_emb_idx[n] for n in batch_ids],
                               dtype=torch.long, device=self.device)

        neigh_feats = embed_matrix_prime[padded_tensor]  # (batch, max_len, output_dim)
        dst_feats = embed_matrix_prime[dst_idx]          # (batch, output_dim)

        # GAT score e_ij = LeakyReLU(a_dst . W h_i + a_src . W h_j). The
        # nonlinearity must wrap the sum: a per-destination term added after
        # it is constant along the neighbor axis and cancels in the softmax.
        edge_attn = self.leaky_relu(
            self.attn_dst(dst_feats).unsqueeze(1) + self.attn_src(neigh_feats))

        # normalize over real neighbors only (padding excluded by the mask)
        attn = masked_softmax(edge_attn, mask.unsqueeze(-1), dim=1)

        return torch.sum(attn * neigh_feats, dim=1)


def pad_tensor(adj_nodes_list, mask=False):
    """Right-pad variable-length index lists with 0 into one LongTensor.

    adj_nodes_list -- sequence of sized iterables of integer indices.
    mask -- if True, also return a mask tensor of the same shape holding 1 for
            real entries and 0 for padding.

    Returns a (len(adj_nodes_list), max_len) int64 tensor, or a
    (padded, mask) pair when `mask` is True. The dtype is pinned to int64 so
    the result is a valid index tensor even when every list is empty (where
    torch would otherwise infer float32).
    """
    num_rows = len(adj_nodes_list)
    max_len = max((len(adj_nodes) for adj_nodes in adj_nodes_list), default=0)

    padded_nodes = []
    _mask = []
    for adj_nodes in adj_nodes_list:
        pad = max_len - len(adj_nodes)
        padded_nodes.append(list(adj_nodes) + [0] * pad)
        _mask.append([1] * len(adj_nodes) + [0] * pad)

    # reshape keeps a 2-D shape for empty inputs ((0, 0) / (n, 0))
    padded = torch.tensor(padded_nodes, dtype=torch.long).reshape(num_rows, max_len)
    if not mask:
        return padded

    # returning the mask as well
    return padded, torch.tensor(_mask, dtype=torch.long).reshape(num_rows, max_len)


def base_modified_neighbours(adj_nodes_list, idx_mapping):
    """Translate every node id in each adjacency collection via `idx_mapping`.

    adj_nodes_list -- iterable of iterables of node ids.
    idx_mapping -- dict from node id to its new index.

    Returns a list of lists with the mapped indices, in iteration order of
    each inner collection. Raises KeyError for an id missing from the mapping.
    """
    return [
        [idx_mapping[neighbour] for neighbour in neighbours]
        for neighbours in adj_nodes_list
    ]