import dgl
import dgl.function as fn
from dgl.nn.pytorch import SAGEConv

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from itertools import accumulate
import random
import multiprocessing as mp
import networkx as nx

from torch.nn import init

from graph_learning.module import ModuleConfig, get_module

def single_source_shortest_path_length_range(graph, node_range, cutoff):
    '''
    Run ``nx.single_source_shortest_path_length`` from every node in ``node_range``.

    :param graph: graph handed to networkx.
    :param node_range: iterable of source nodes.
    :param cutoff: maximum search depth, or None for no limit.
    :return: dict mapping each source node to its ``{target: hop_count}`` dict.
    '''
    return {
        source: nx.single_source_shortest_path_length(graph, source, cutoff)
        for source in node_range
    }

def merge_dicts(dicts):
    '''
    Merge an iterable of dicts into a fresh dict.

    Later dicts win on key collisions; the inputs are not modified.
    '''
    merged = dict()
    for mapping in dicts:
        merged.update(mapping)
    return merged

def all_pairs_shortest_path_length_parallel(graph,cutoff=None,num_workers=4):
    '''
    Compute shortest-path lengths from every node, split across worker processes.

    Nodes are shuffled with the ``random`` module's global RNG and then split into
    ``num_workers`` contiguous chunks, one ``apply_async`` task per chunk. Small graphs
    use fewer workers: a quarter of ``num_workers`` below 50 nodes, half below 400,
    but never fewer than one.

    :param graph: graph accepted by ``nx.single_source_shortest_path_length``; it is
        pickled and sent to every worker task.
    :param cutoff: maximum path length to explore, or None for no limit.
    :param num_workers: upper bound on the number of worker processes.
    :return: dict mapping each node to a ``{reachable_node: hop_count}`` dict.
    '''
    nodes = list(graph.nodes)
    random.shuffle(nodes)
    if len(nodes) < 50:
        num_workers = int(num_workers / 4)
    elif len(nodes) < 400:
        num_workers = int(num_workers / 2)
    # The reductions above can reach 0 (e.g. num_workers < 4 on a small graph);
    # mp.Pool rejects processes=0 and the chunking below would divide by zero.
    num_workers = max(1, int(num_workers))

    # Exact integer bounds: chunks are contiguous, disjoint, and the last one
    # always ends at len(nodes), so no node can be dropped by float rounding.
    total = len(nodes)
    chunks = [nodes[total * i // num_workers: total * (i + 1) // num_workers]
              for i in range(num_workers)]

    pool = mp.Pool(processes=num_workers)
    try:
        pending = [pool.apply_async(single_source_shortest_path_length_range,
                                    args=(graph, chunk, cutoff))
                   for chunk in chunks]
        output = [task.get() for task in pending]
    finally:
        # Release the worker processes even if a task raised.
        pool.close()
        pool.join()
    return merge_dicts(output)

def precompute_dist_data(graph, num_nodes, approximate=0, raw=False):
    '''
    Build a dense (num_nodes x num_nodes) matrix of pairwise shortest-path data.

    By default entry [u, v] is 1 / (d(u, v) + 1): higher means closer, 1 on the
    diagonal, and 0 when v is unreachable from u (or beyond the cutoff).
    With ``raw=True`` the hop count d(u, v) is stored instead; unreachable pairs
    still hold 0, so in raw mode they are indistinguishable from u == v.

    :param graph: graph passed to the parallel networkx shortest-path helper; its
        node labels are used directly as array indices, so they must be integers
        in [0, num_nodes).
    :param num_nodes: size of each side of the returned matrix.
    :param approximate: if > 0, used as the search cutoff (maximum hop count).
    :param raw: store hop counts instead of 1 / (hops + 1).
    :return: float numpy array of shape (num_nodes, num_nodes).
    '''
    dists_array = np.zeros((num_nodes, num_nodes))
    cutoff = approximate if approximate > 0 else None
    dists_dict = all_pairs_shortest_path_length_parallel(graph, cutoff=cutoff)
    for node_i in graph.nodes():
        # Only reachable targets appear in the BFS result, so iterating it fills
        # exactly the entries that need a value (the rest stay 0) instead of
        # probing all num_nodes targets for every source.
        for node_j, dist in dists_dict[node_i].items():
            dists_array[node_i, node_j] = dist if raw else 1 / (dist + 1)
    return dists_array

def get_random_anchorset(n,c=0.5):
    '''
    Sample random anchor sets whose sizes halve at each scale.

    With m = floor(log2 n), for each scale i in range(m) this draws int(c * m)
    sets of int(n / 2**(i+1)) distinct ids from range(n), using numpy's global RNG.

    :return: list of 1-D numpy integer arrays, largest sets first.
    '''
    num_scales = int(np.log2(n))
    copies_per_scale = int(c * num_scales)
    return [
        np.random.choice(n, size=int(n / np.exp2(scale + 1)), replace=False)
        for scale in range(num_scales)
        for _ in range(copies_per_scale)
    ]

def get_dist_max(anchorset_id, dist, device):
    '''
    For every row of ``dist``, find the best-scoring member of each anchor set.

    :param anchorset_id: sequence of 1-D index arrays (e.g. from get_random_anchorset);
        each lists the column ids of ``dist`` forming one anchor set.
    :param dist: 2-D matrix (torch tensor or array-like) where larger means closer —
        presumably the output of precompute_dist_data(raw=False); confirm at call site.
    :param device: torch device for the returned tensors.
    :return: ``(dist_max, dist_argmax)``, both of shape (dist.shape[0], len(anchorset_id)).
        ``dist_max[v, k]`` (float32) is the largest ``dist[v, a]`` over ``a`` in set k;
        ``dist_argmax[v, k]`` (int64) is that ``a`` as a global column/node id, so it can
        be used directly to index node features.
    '''
    dist = torch.as_tensor(dist)
    num_sets = len(anchorset_id)
    dist_max = torch.zeros((dist.shape[0], num_sets)).to(device)
    dist_argmax = torch.zeros((dist.shape[0], num_sets)).long().to(device)
    for k, member_ids in enumerate(anchorset_id):
        member_ids = torch.as_tensor(member_ids, dtype=torch.long, device=dist.device)
        best_val, best_pos = torch.max(dist[:, member_ids], dim=-1)
        dist_max[:, k] = best_val
        # torch.max returns positions inside member_ids; map them back to global ids.
        dist_argmax[:, k] = member_ids[best_pos]
    return dist_max, dist_argmax

def pgnn_preselect_anchor(graph, layer_num=1, anchor_num=32, anchor_size_num=4, device='cpu'):
    '''
    Sample random anchor sets and find, for each node, its closest anchor in each set.

    :param graph: object exposing ``number_of_nodes()`` and ``gdata['_pgnn_dist']``
        (presumably the matrix from precompute_dist_data — confirm where it is stored).
    :param layer_num, anchor_num, anchor_size_num: only control the number of legacy
        random draws described below; they affect the result solely through numpy's
        global RNG state.
    :param device: torch device for the returned tensors.
    :return: ``(dists_max, dists_argmax)`` as returned by get_dist_max, using
        ``get_random_anchorset(num_nodes, c=1)`` as the anchor sets.
    '''
    num_nodes = graph.number_of_nodes()

    # Legacy: per-layer anchor sets used to be sampled (and an unused
    # layer_num x anchor_num x num_nodes indicator array allocated) here, but the
    # results were never used. The allocation is gone; the draws are kept because
    # they advance numpy's global RNG, and dropping them would change every seeded
    # anchor set sampled below.
    anchor_num_per_size = anchor_num // anchor_size_num
    for i in range(anchor_size_num):
        anchor_size = 2**(i+1)-1
        np.random.choice(num_nodes, size=(layer_num, anchor_num_per_size, anchor_size), replace=True)

    anchorset_id = get_random_anchorset(num_nodes, c=1)
    return get_dist_max(anchorset_id, graph.gdata['_pgnn_dist'], device)

@ModuleConfig.register('pgnn',
                       help='[Encoder] PGNN')
class AGNNModuleConfig(ModuleConfig):
    """Command-line configuration for the PGNN encoder (registered as 'pgnn')."""

    # NOTE(review): the class is named AGNN* although it configures PGNN; the name
    # is left unchanged because other code may refer to it.

    def __init__(self, args, context):
        super().__init__(args, context)

    @property
    def builder(self):
        """The model class this config builds: PGNN."""
        return PGNN

    @classmethod
    def define_parser(cls, parser):
        """Add PGNN's options to ``parser`` after the base ModuleConfig options."""
        super().define_parser(parser)
        int_options = ('--input-dim', '--feature-dim', '--hidden-dim',
                       '--output-dim', '--layer-num')
        for option in int_options:
            parser.add_argument(option, type=int)
        parser.add_argument('--bn', action='store_true', help='use batchnorm')
        for option in ('--feature-pre', '--dropout'):
            parser.add_argument(option, action='store_true')

# PGNN layer: each node receives messages only from its selected (closest) node in each anchor set
class PGNN_layer(nn.Module):
    """Position-aware message-passing layer.

    For every node v and anchor set k, the feature of node ``dists_argmax[v, k]`` is
    scaled by ``dists_max[v, k]``, concatenated with v's own feature, and passed
    through a linear layer + ReLU, giving one message per (node, anchor set).

    ``forward`` returns ``(out_position, out_structure)``:
      * out_position  -- (n, m): one scalar per anchor set (position-aware output);
      * out_structure -- (n, output_dim): mean of the messages over the m anchor sets.
    """

    def __init__(self, input_dim, output_dim,dist_trainable=True):
        super(PGNN_layer, self).__init__()
        self.input_dim = input_dim
        self.dist_trainable = dist_trainable

        if self.dist_trainable:
            # learnable scalar -> scalar transform applied to every distance value
            self.dist_compute = Nonlinear(1, output_dim, 1)

        self.linear_hidden = nn.Linear(input_dim*2, output_dim)
        self.linear_out_position = nn.Linear(output_dim,1)
        self.act = nn.ReLU()

        # self.modules() also visits dist_compute's linear layers, so they are
        # re-initialised here as well.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform_(m.weight.data, gain=nn.init.calculate_gain('relu'))
                if m.bias is not None:
                    m.bias.data = init.constant_(m.bias.data, 0.0)

    def forward(self, feature, dists_max, dists_argmax):
        """
        :param feature: (n, input_dim) node features.
        :param dists_max: (n, m) weights for the selected node of each anchor set
            (presumably closeness values from get_dist_max).
        :param dists_argmax: (n, m) integer tensor of node ids indexing ``feature``.
        :return: ``(out_position, out_structure)`` of shapes (n, m) and (n, output_dim).
        """
        if self.dist_trainable:
            # Squeeze only the trailing size-1 dim: a bare .squeeze() would also
            # drop the n or m dimension when either equals 1.
            dists_max = self.dist_compute(dists_max.unsqueeze(-1)).squeeze(-1)

        num_sets = dists_argmax.shape[1]
        anchor_features = feature[dists_argmax]  # (n, m, input_dim)
        messages = anchor_features * dists_max.unsqueeze(-1)

        self_feature = feature.unsqueeze(1).expand(-1, num_sets, -1)  # (n, m, input_dim)
        messages = torch.cat((messages, self_feature), dim=-1)  # (n, m, 2*input_dim)

        # No squeeze here: all three dims of (n, m, output_dim) are meaningful even
        # when some of them are 1.
        messages = self.act(self.linear_hidden(messages))  # (n, m, output_dim)

        out_position = self.linear_out_position(messages).squeeze(-1)  # (n, m)
        out_structure = torch.mean(messages, dim=1)  # (n, output_dim)

        return out_position, out_structure

class Nonlinear(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Both linear layers are re-initialised with Xavier-uniform weights (ReLU gain)
    and zero biases.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(Nonlinear, self).__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, output_dim)
        self.act = nn.ReLU()

        relu_gain = nn.init.calculate_gain('relu')
        for layer in (self.linear1, self.linear2):
            layer.weight.data = init.xavier_uniform_(layer.weight.data, gain=relu_gain)
            if layer.bias is not None:
                layer.bias.data = init.constant_(layer.bias.data, 0.0)

    def forward(self, x):
        return self.linear2(self.act(self.linear1(x)))

class PGNN(torch.nn.Module):
    """Position-aware GNN encoder: a stack of ``layer_num`` PGNN_layer modules.

    Anchor information is read from ``ndata['dists_max']`` / ``ndata['dists_argmax']``
    of the graph returned by ``data.graph()``. Between layers the structure-aware
    output is propagated (optionally through batch norm and dropout); the
    position-aware output of the last layer is returned under ``'hidden'``.
    """

    def __init__(self, input_dim, feature_dim, hidden_dim, output_dim,
                 feature_pre, layer_num, dropout, bn,
                 name):
        super(PGNN, self).__init__()
        if layer_num < 1:
            # layer_num == 0 would otherwise only fail later inside forward().
            raise ValueError(f'layer_num must be >= 1, got {layer_num}')
        self.name = name
        self.use_bn = bn

        self.feature_pre = feature_pre
        self.layer_num = layer_num
        self.dropout = dropout  # bool flag; F.dropout is applied with its default probability
        if layer_num == 1:
            # single layer: conv_first is also the output layer
            hidden_dim = output_dim
        if feature_pre:
            # optional linear projection input_dim -> feature_dim before the first layer
            self.linear_pre = nn.Linear(input_dim, feature_dim)
            self.conv_first = PGNN_layer(feature_dim, hidden_dim)
        else:
            self.conv_first = PGNN_layer(input_dim, hidden_dim)

        if self.use_bn:
            # one BatchNorm after every PGNN layer except the last
            self.bn_layers = nn.ModuleList([nn.BatchNorm1d(hidden_dim, track_running_stats=False)
                                            for i in range(layer_num-1)])

        if layer_num>1:
            self.conv_hidden = nn.ModuleList([PGNN_layer(hidden_dim, hidden_dim) for i in range(layer_num - 2)])
            self.conv_out = PGNN_layer(hidden_dim, output_dim)

    def forward(self, data, x):
        """Encode node features ``x``.

        :param data: object whose ``graph()`` returns the graph carrying the
            precomputed ``dists_max`` / ``dists_argmax`` node data.
        :param x: node feature matrix, one row per node.
        :return: ``{'hidden': x_position, 'outputs': {}}`` where ``x_position`` is the
            position-aware output of the last PGNN layer (one column per anchor set),
            for every ``layer_num``.
        """
        graph = data.graph()
        graph = graph.adapt(graph.local_var())

        dists_max = graph.ndata['dists_max']
        dists_argmax = graph.ndata['dists_argmax']

        if self.feature_pre:
            x = self.linear_pre(x)
        x_position, x = self.conv_first(x, dists_max, dists_argmax)
        if self.layer_num == 1:
            # Same return structure as the multi-layer path below.
            return {'hidden': x_position,
                    'outputs': {}}

        # A ReLU between layers is optional and intentionally not applied.
        if self.use_bn:
            x = self.bn_layers[0](x)
        if self.dropout:
            x = F.dropout(x, training=self.training)

        for i in range(self.layer_num-2):
            _, x = self.conv_hidden[i](x, dists_max, dists_argmax)
            if self.use_bn:
                x = self.bn_layers[i+1](x)
            if self.dropout:
                x = F.dropout(x, training=self.training)
        x_position, _ = self.conv_out(x, dists_max, dists_argmax)

        return {'hidden': x_position,
                'outputs': {}}
