import torch
import torch.nn as nn
from collections import Counter

from IPython import embed

class RGCNAgg(nn.Module):
    """
    Aggregates a node's neighborhood with relation-specific (R-GCN style) weights.

    Each (neighbor, relation) pair produces the message
    ``x_neighbor @ sum_b rel_coef[rel, b] * rel_w[:, :, b]`` (basis
    decomposition). A batch node's output is the sum of its neighbors'
    messages, each scaled by ``1 / (number of that node's neighbors that share
    the same relation)``.
    """

    def __init__(self, features, rel_coef, rel_w, device,
                 num_sample=15, sample_nodes=False, dropout=False,
                 gcn=False, self_loop_rel=50):
        """
        Initializes the aggregator for a specific graph.

        features --- callable mapping a LongTensor of node ids to a 2-D tensor
                     of input embeddings (presumably an nn.Embedding; TODO
                     confirm against callers).
        rel_coef --- basis coefficients, shape (num_rel, num_bases).
        rel_w --- basis weights, shape (input_dim, output_dim, num_bases).
        device --- device on which index, mask and output tensors are created.
        num_sample --- max neighbors kept per node when sample_nodes is True;
                       None keeps every neighbor.
        sample_nodes --- if True, neighbors are (node, rel, hitting_prob)
                         triples and only the num_sample neighbors with the
                         highest hitting probability are kept.
        dropout --- if True, apply Dropout(0.5) to the gathered embeddings.
        gcn --- if True, add a self loop (node, self_loop_rel) to every batch
                node's neighborhood, whether or not sampling is enabled.
        self_loop_rel --- relation id used for the self loop (default 50); must
                          be a valid row index of rel_coef.
        """
        super().__init__()

        self.features = features
        self.rel_coef = rel_coef
        self.rel_w = rel_w

        self.input_dim = self.rel_w.shape[0]
        self.output_dim = self.rel_w.shape[1]
        self.num_rel = self.rel_coef.shape[0]
        self.num_bases = self.rel_coef.shape[1]

        self.device = device
        self.num_sample = num_sample
        self.sample_nodes = sample_nodes
        self.gcn = gcn
        self.self_loop_rel = self_loop_rel

        self.dropout = nn.Dropout(0.5) if dropout else None

    def forward(self, nodes, to_neighs):
        """
        Aggregate relational messages for a batch of nodes.

        nodes --- list of node ids in the batch (only read when gcn is True,
                  to build self loops).
        to_neighs --- list with one collection per batch node; each element is
                      a (node, rel) pair or a (node, rel, hitting_prob) triple.
                      Triples are required when sample_nodes is True.

        Returns a tensor of shape (len(to_neighs), output_dim). A batch node
        with no neighbors gets an all-zero row.
        """
        samp_neighs = self._select_neighbors(nodes, to_neighs)

        # Every distinct (node, rel) pair in the batch becomes one column of the
        # aggregation mask; dict.fromkeys keeps a deterministic first-seen order.
        unique_pairs = list(dict.fromkeys(
            pair for neighs in samp_neighs for pair in neighs))
        if not unique_pairs:
            # Nothing to aggregate anywhere in the batch.
            return torch.zeros(len(samp_neighs), self.output_dim,
                               dtype=self.rel_w.dtype, device=self.device)

        pair_to_col = {pair: col for col, pair in enumerate(unique_pairs)}
        unique_nodes = list(dict.fromkeys(node for node, _ in unique_pairs))
        node_to_emb_idx = {node: i for i, node in enumerate(unique_nodes)}

        embed_matrix = self.features(
            torch.tensor(unique_nodes, device=self.device))
        if self.dropout is not None:
            embed_matrix = self.dropout(embed_matrix)

        pair_emb_idx = [node_to_emb_idx[node] for node, _ in unique_pairs]
        pair_rels = [rel for _, rel in unique_pairs]
        node_rel_matrix = self._pair_messages(embed_matrix, pair_emb_idx, pair_rels)

        mask = self._norm_mask(samp_neighs, pair_to_col, node_rel_matrix.dtype)
        return mask.mm(node_rel_matrix)

    def _select_neighbors(self, nodes, to_neighs):
        """Return one new set of (node, rel) pairs per batch node (inputs untouched)."""
        selected = []
        for i, neighs in enumerate(to_neighs):
            if (self.sample_nodes and self.num_sample is not None
                    and len(neighs) >= self.num_sample):
                # Keep the neighbors with the highest hitting probability.
                neighs = sorted(neighs, key=lambda x: x[2],
                                reverse=True)[:self.num_sample]
            # Drop the hitting probability if present; building a new set also
            # ensures the caller's sets are never mutated by the self loop below.
            pairs = {(neigh[0], neigh[1]) for neigh in neighs}
            if self.gcn:
                pairs.add((nodes[i], self.self_loop_rel))
            selected.append(pairs)
        return selected

    def _pair_messages(self, embed_matrix, emb_idx, rels):
        """Project each (node, rel) pair through its relation's basis-combined weight."""
        # Apply the bases once per unique node rather than once per pair
        # (the product a(bc) = (ab)c lets the coefficients be applied afterwards):
        # result is (output_dim, num_unique_nodes, num_bases).
        basis_proj = torch.matmul(embed_matrix, self.rel_w.permute(1, 0, 2))
        # Gather per pair and put bases next to output: (num_pairs, num_bases, output_dim).
        pair_proj = basis_proj[:, emb_idx, :].permute(1, 2, 0)
        coef = self.rel_coef[rels].unsqueeze(1)  # (num_pairs, 1, num_bases)
        return torch.bmm(coef, pair_proj).squeeze(1)  # (num_pairs, output_dim)

    def _norm_mask(self, samp_neighs, pair_to_col, dtype):
        """Build the (batch, num_pairs) mask holding 1 / per-relation neighbor count."""
        rows, cols, vals = [], [], []
        for row, neighs in enumerate(samp_neighs):
            rel_counts = Counter(rel for _, rel in neighs)
            for pair in neighs:
                rows.append(row)
                cols.append(pair_to_col[pair])
                vals.append(1.0 / rel_counts[pair[1]])
        mask = torch.zeros(len(samp_neighs), len(pair_to_col),
                           dtype=dtype, device=self.device)
        # Pairs within one row are unique (sets), so no index is written twice.
        mask[rows, cols] = torch.tensor(vals, dtype=dtype, device=self.device)
        return mask
