import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from IPython import embed

class MeanAggregator(nn.Module):
    """Aggregate each node's neighbourhood by averaging neighbour embeddings.

    Args:
        features: callable that takes a 1-D tensor of node ids and returns
            their embeddings (assumed 2-D ``(num_ids, dim)``, since the
            result is the right operand of a matrix product).
        device: device on which the mask and the id tensor are created.
        num_sample: maximum number of neighbours kept per node when
            ``sample_nodes`` is true.
        sample_nodes: if true, keep only the ``num_sample`` neighbours with
            the highest hitting probability; otherwise use ``to_neighs``
            unchanged.
        dropout: if true, apply ``nn.Dropout(0.5)`` to the neighbour
            embeddings before averaging.
        gcn: if true, add the node itself to its sampled neighbour set.
            NOTE(review): this is only honoured when ``sample_nodes`` is
            true; confirm that is intended.
    """

    def __init__(self, features, device, num_sample=15,
                 sample_nodes=False, dropout=False,
                 gcn=False):
        super().__init__()

        self.features = features
        self.device = device
        self.num_sample = num_sample
        self.sample_nodes = sample_nodes
        self.gcn = gcn
        self.dropout = nn.Dropout(0.5) if dropout else None

    def forward(self, nodes, to_neighs):
        """Return the mean neighbour embedding for every node in the batch.

        Args:
            nodes: list of nodes in the batch (only read when ``gcn`` is
                set and sampling is enabled).
            to_neighs: one neighbour collection per batch node.  With
                ``sample_nodes`` enabled each entry holds
                ``(node, rel, hitting_prob)`` triples; otherwise each entry
                holds node ids directly.

        Returns:
            Tensor with one row per entry of ``to_neighs``; a row is all
            zeros when the corresponding node has no neighbours.
        """
        if self.sample_nodes:
            # Keep the neighbours with the highest hitting probability and
            # reduce each triple to its node id.
            samp_neighs = []
            for i, to_neigh in enumerate(to_neighs):
                if len(to_neigh) > self.num_sample:
                    to_neigh = sorted(to_neigh, key=lambda t: t[2],
                                      reverse=True)[:self.num_sample]
                picked = {node for node, _rel, _hp in to_neigh}
                if self.gcn:
                    picked.add(nodes[i])
                samp_neighs.append(picked)
        else:
            # no sampling
            samp_neighs = to_neighs

        # ``set().union(...)`` (rather than ``set.union(...)``) also copes
        # with an empty batch and with non-set neighbour collections.
        unique_nodes_list = sorted(set().union(*samp_neighs))
        unique_nodes = {n: i for i, n in enumerate(unique_nodes_list)}

        # Build the (batch, unique-neighbour) membership mask.
        row_indices = []
        column_indices = []
        for row, samp_neigh in enumerate(samp_neighs):
            for n in samp_neigh:
                row_indices.append(row)
                column_indices.append(unique_nodes[n])

        mask = torch.zeros(len(samp_neighs), len(unique_nodes_list),
                           device=self.device)
        if row_indices:
            mask[row_indices, column_indices] = 1

        # Row-normalise so the matrix product below is a mean; the clamp
        # avoids dividing by zero for nodes without neighbours.
        num_neigh = mask.sum(1, keepdim=True)
        mask = mask.div(num_neigh.clamp(min=1e-8))

        # Explicit integer dtype: an empty id list would otherwise become a
        # float tensor, which an embedding lookup rejects.
        node_tensor = torch.tensor(unique_nodes_list, dtype=torch.long,
                                   device=self.device)
        embed_matrix = self.features(node_tensor)

        if self.dropout is not None:
            embed_matrix = self.dropout(embed_matrix)

        return mask.mm(embed_matrix)
