import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from IPython import embed

from allennlp.modules.seq2vec_encoders.pytorch_seq2vec_wrapper import PytorchSeq2VecWrapper

from allennlp.nn.util import masked_max, masked_mean
from allennlp.nn.util import get_text_field_mask


class LSTMAggregator(nn.Module):
    """
    Aggregates a node's neighbourhood by running an LSTM over the
    embeddings of its (optionally truncated and shuffled) neighbours.
    """
    def __init__(self, features, lstm_dim, device, num_sample=30, sample_nodes=False, dropout=True, self_loop=True):
        """
        Initializes the aggregator for a specific graph.
        features -- callable mapping a tensor of node ids to a tensor of embeddings,
                    one row per id; the rows are fed to an LSTM whose input size is lstm_dim.
        lstm_dim -- input size and hidden size of the LSTM.
        device -- device the index tensors and mask are moved to.
        num_sample -- maximum number of neighbours kept per node when sampling.
        sample_nodes -- currently ignored, see the note in the body.
        dropout -- if truthy, apply dropout (p=0.5) to the looked-up embeddings; if falsy, no dropout.
        self_loop -- whether to add each node to its own neighbour set (sampling path only).
        """

        super(LSTMAggregator, self).__init__()

        self.features = features
        self.lstm_dim = lstm_dim
        # NOTE(review): the `sample_nodes` argument is not honoured; sampling is always on.
        # The argument defaults to False, so wiring it through would change the behaviour
        # of every caller relying on the default -- confirm intent before changing this.
        self.sample_nodes = True

        self.lstm = PytorchSeq2VecWrapper(nn.LSTM(self.lstm_dim,
                                            self.lstm_dim,
                                            batch_first=True,
                                            bidirectional=False))
        self.num_sample = num_sample
        self.device = device
        self.self_loop = self_loop
        # Neighbours are always fed to the LSTM in random order.
        self.shuffle = True
        # Honour the `dropout` flag (it used to be ignored). p=0.0 makes the layer a
        # no-op while keeping `self.dropout` a module, so forward() needs no branch.
        self.dropout = nn.Dropout(0.5 if dropout else 0.0)

    def forward(self, nodes, to_neighs):
        """
        nodes --- list of nodes in a batch
        to_neighs --- list with one neighbour collection per node in the batch. On the
                      sampling path each neighbour is a (node, relation, hitting_prob)
                      triple; on the non-sampling path each neighbour is a plain node id.

        Returns whatever the wrapped LSTM encoder returns for the padded batch of
        neighbour embeddings (per AllenNLP's Seq2Vec contract, one vector per node).
        """
        if self.sample_nodes:
            # Not random sampling: keep the num_sample neighbours with the largest
            # third field (the hitting probability). Smaller lists are kept whole.
            _neighs = [sorted(to_neigh, key=lambda x: x[2], reverse=True)[:self.num_sample]
                       if len(to_neigh) >= self.num_sample else to_neigh for to_neigh in to_neighs]
            # Drop relation and hitting prob; keep only the distinct neighbour ids.
            samp_neighs = []
            for i, adj_list in enumerate(_neighs):
                neigh_ids = {node for node, _rel, _hp in adj_list}
                if self.self_loop:
                    neigh_ids.add(nodes[i])
                samp_neighs.append(neigh_ids)
        else:
            # no sampling: neighbours are used exactly as given
            samp_neighs = to_neighs

        adj_nodes_list = []
        unique_nodes_list = []
        for adj_list in samp_neighs:
            if self.shuffle:
                # Sets have no stable order; fix a random sequence order for the LSTM.
                adj_list = list(adj_list)
                random.shuffle(adj_list)

            # no relation
            unique_nodes_list.extend(adj_list)
            adj_nodes_list.append(adj_list)

        # Embed every distinct node once, then gather rows per neighbour list.
        unique_nodes = list(set(unique_nodes_list))

        idx_mapping = {n: i for i, n in enumerate(unique_nodes)}
        unique_nodes_tensor = torch.tensor(unique_nodes).to(self.device)
        embs_tensor = self.features(unique_nodes_tensor)

        embs_tensor = self.dropout(embs_tensor)

        # Translate node ids into row indices of embs_tensor.
        modified_adj_nodes = base_modified_neighbours(adj_nodes_list, idx_mapping)

        # Pad positions hold index 0, so they gather the first unique node's embedding
        # (not a zero vector); the mask tells the encoder which positions are real.
        padded_tensor, mask = pad_tensor(modified_adj_nodes, mask=True)
        padded_tensor = padded_tensor.to(self.device)
        mask = mask.to(self.device)
        padded_embs = embs_tensor[padded_tensor]

        hidden_states = self.lstm(padded_embs, mask)

        return hidden_states


def pad_tensor(adj_nodes_list, mask=False):
    """
    Right-pad a batch of index lists with 0 to a common length.

    adj_nodes_list -- iterable of collections of integer indices, one per batch row.
    mask -- if true, also return a 0/1 mask marking the real (non-pad) positions.

    Returns a LongTensor of shape (batch, max_len), or a (padded, mask) pair of
    LongTensors of that shape when mask is true. An empty batch yields shape (0, 0),
    and a batch of only-empty rows yields shape (batch, 0).
    """
    rows = [list(adj_nodes) for adj_nodes in adj_nodes_list]
    # default=0 keeps an empty batch from raising ValueError in max().
    max_len = max((len(row) for row in rows), default=0)
    shape = (len(rows), max_len)

    padded_nodes = [row + [0] * (max_len - len(row)) for row in rows]
    # dtype is pinned because torch infers float32 from empty lists, which cannot be
    # used for indexing; reshape fixes the rank when there is nothing to infer it from.
    padded = torch.tensor(padded_nodes, dtype=torch.long).reshape(shape)

    if not mask:
        return padded

    # returning the mask as well
    _mask = [[1] * len(row) + [0] * (max_len - len(row)) for row in rows]
    return padded, torch.tensor(_mask, dtype=torch.long).reshape(shape)


def base_modified_neighbours(adj_nodes_list, idx_mapping):
    """
    Re-index every neighbour list through idx_mapping.

    adj_nodes_list -- iterable of iterables of node ids.
    idx_mapping -- dict from node id to its new index.

    Returns a list of lists with the same layout as the input, each node id
    replaced by idx_mapping[node]. A node missing from the mapping raises KeyError.
    """
    return [[idx_mapping[node] for node in neighbours] for neighbours in adj_nodes_list]