import argparse
import pickle
import numpy as np
import os
import multiprocessing as mp
import networkx as nx
import tqdm
from collections import deque

import graphs


# Command-line interface: --graph_id selects which graph this script builds
# and processes.
parser = argparse.ArgumentParser()
parser.add_argument('--graph_id', choices=['3d', '3dd', 'taxi', 'octagon', 'traffic'])
# NOTE: arguments are parsed at import time, so merely importing this module
# reads sys.argv.
args = parser.parse_args()

# True for the graph ids that are treated as undirected ('3d' and 'taxi').
# The functions below consult this flag to decide whether to call
# to_undirected() and whether to read the '...connected.edgelist' file.
SYMMETRIC = (args.graph_id in ['3d', 'taxi'])


def make_connected_edgelist(graph_id):
    """Write the largest connected component of a symmetric graph to disk.

    For symmetric graphs, reads ``./data/<graph_id>.edgelist``, converts it to
    an undirected graph, keeps the connected component with the most nodes and
    writes it to ``./data/<graph_id>connected.edgelist``.

    For asymmetric graphs nothing is written, so the function returns
    immediately instead of parsing the edge list for no purpose.

    Args:
        graph_id: Name of the graph; used to build the input and output paths.
    """
    if not SYMMETRIC:
        return

    G = nx.read_edgelist('./data/{}.edgelist'.format(graph_id), nodetype=int,
                         data=(('weight', float),), create_using=nx.DiGraph())
    G = G.to_undirected()

    # Only the biggest component is needed, so pick it directly rather than
    # sorting every component. Ties go to the first component encountered,
    # matching a stable descending sort.
    largest = max(nx.connected_components(G), key=len)
    G = G.subgraph(largest)

    print('num nodes: {}; num edges: {}'.format(G.number_of_nodes(), G.number_of_edges()))
    nx.write_weighted_edgelist(G, './data/{}connected.edgelist'.format(graph_id))


def keywithmaxval(d):
    """Return the key of ``d`` whose associated value is the largest.

    When several keys share the maximum value, the one that comes first in
    the dict's iteration order is returned. An empty dict raises
    ``ValueError``.
    """
    return max(d, key=d.__getitem__)


def process_edgelist(graph_id, N_LANDMARKS=15, OPTIMIZATION_TRIES=45):
    """Load the graph for ``graph_id`` and compute landmark distance maps.

    The edge list is read from ``./data/<graph_id>connected.edgelist`` for
    symmetric graphs and ``./data/<graph_id>.edgelist`` otherwise. Landmarks
    are then picked iteratively: starting from one random node, each step
    appends the node with the largest distance from the current landmark set.
    The landmarks live in a bounded deque, so only the most recent
    ``N_LANDMARKS`` picks are kept.

    Args:
        graph_id: Name of the graph; used to build the input path.
        N_LANDMARKS: Maximum number of landmarks retained.
        OPTIMIZATION_TRIES: Number of selection iterations to run.

    Returns:
        A tuple ``(G, landmark_dists)`` where ``G`` is the loaded graph and
        ``landmark_dists`` is a list with one dict per retained landmark,
        mapping each node reached from that landmark to its shortest-path
        distance.
    """
    suffix = 'connected' if SYMMETRIC else ''
    G = nx.read_edgelist('./data/{}{}.edgelist'.format(graph_id, suffix),
                         nodetype=int, data=(('weight', float),), create_using=nx.DiGraph())
    if SYMMETRIC:
        G = G.to_undirected()

    nodes = np.array(G.nodes())

    print('V, E: ')
    print(G.number_of_nodes(), G.number_of_edges())

    # maxlen=N_LANDMARKS: once full, every append evicts the oldest landmark.
    landmarks = deque(np.random.choice(nodes, (1,), replace=False), N_LANDMARKS)
    for _ in tqdm.tqdm(range(OPTIMIZATION_TRIES)):
        # Distance of every reachable node to its nearest current landmark.
        dists = nx.multi_source_dijkstra_path_length(G, landmarks)
        landmarks.append(keywithmaxval(dists))
    landmark_dists = [nx.single_source_dijkstra_path_length(G, l) for l in landmarks]

    return G, landmark_dists


def make_XYD_pickles(graph_id):
    """Compute target distances for the sampled node pairs and pickle them.

    Relies on two module-level names that the ``__main__`` section defines
    before calling this function: ``data_pairs`` (an array with one
    ``(source, target)`` row per sample) and ``mp_func`` (maps one pair to
    its distance).

    The distances are rescaled so that their mean is 50, and the tuple
    ``(X, Y, D)`` (sources, targets, scaled distances) is written to
    ``./data/<graph_id>_150k.pickle``.

    Args:
        graph_id: Name of the graph; used to build the output path.
    """
    EPOCHS = 250
    # array_split (unlike split) accepts a row count that is not a multiple
    # of EPOCHS, which happens whenever fewer than 150000 pairs survive
    # de-duplication.
    data_pair_arrays = np.array_split(data_pairs, EPOCHS)

    dists = []
    # One pool for the whole run; the chunks only pace the progress bar.
    with mp.Pool(36) as pool:
        for chunk in tqdm.tqdm(data_pair_arrays):
            dists += pool.map(mp_func, chunk)

    X = data_pairs[:,0]
    Y = data_pairs[:,1]
    D = np.array(dists)

    print('max: ', D.max())
    # Normalise so that the mean distance becomes 50.
    D = D * (50. / np.mean(D))
    print('max after norm: ', D.max())

    with open('./data/{}_150k.pickle'.format(graph_id), 'wb') as f:
        pickle.dump((X, Y, D), f)


def make_lm_embeddings(graph_id, NOISE_FEATS=96, N_LANDMARKS=32, LM_NOISE=0.2, OPTIMIZATION_TRIES=64):
    """Build landmark-distance node embeddings and pickle them.

    Landmarks are chosen iteratively (each step appends the node with the
    largest distance from the current landmark set, keeping only the most
    recent ``N_LANDMARKS``). Every node then gets one feature per landmark
    distance map: its distance from the landmark, and for asymmetric graphs
    also its distance on the reversed graph. These features are standardised
    with the global mean/std of all landmark distances, optionally perturbed
    with Gaussian noise, and followed by ``NOISE_FEATS`` pure-noise features.

    The resulting ``{node: feature_vector}`` dict is written to
    ``./data/<graph_id>_lm_<N_LANDMARKS>n<LM_NOISE>-<NOISE_FEATS>_emb_dict.pickle``.

    Args:
        graph_id: Name of the graph; used to build input and output paths.
        NOISE_FEATS: Number of standard-normal noise features per node.
        N_LANDMARKS: Maximum number of landmarks retained.
        LM_NOISE: Std-dev of the noise added to landmark features; a falsy
            value disables the noise.
        OPTIMIZATION_TRIES: Number of landmark-selection iterations.

    Raises:
        KeyError: If some node has no entry in a landmark's distance map
            (the maps only contain nodes reached by the Dijkstra search).
    """
    suffix = 'connected' if SYMMETRIC else ''
    G = nx.read_edgelist('./data/{}{}.edgelist'.format(graph_id, suffix),
                         nodetype=int, data=(('weight', float),), create_using=nx.DiGraph())
    if SYMMETRIC:
        G = G.to_undirected()
    nodes = np.array(G.nodes())

    # Collect landmarks
    landmarks = deque(np.random.choice(nodes, (1,), replace=False), N_LANDMARKS)
    for _ in tqdm.tqdm(range(OPTIMIZATION_TRIES)):
        dists = nx.multi_source_dijkstra_path_length(G, landmarks)
        landmarks.append(keywithmaxval(dists))
    landmark_dists = [nx.single_source_dijkstra_path_length(G, l) for l in landmarks]
    if not SYMMETRIC:
        # Directed graphs also get distances computed on the reversed graph.
        Grev = G.reverse()
        landmark_dists += [nx.single_source_dijkstra_path_length(Grev, l) for l in landmarks]

    # Number of landmark features actually available. Derived from the data
    # instead of from N_LANDMARKS, because the deque may hold fewer landmarks
    # (when OPTIMIZATION_TRIES + 1 < N_LANDMARKS) and directed graphs
    # contribute two maps per landmark.
    n_lm_feats = len(landmark_dists)

    # Get landmark stats for normalization
    lm_dists = np.array([list(d.values()) for d in landmark_dists])
    lm_mean = np.mean(lm_dists)
    lm_std = np.std(lm_dists)

    # Collect embeddings based on landmarks, normalizing each
    embs = {}
    for i in nodes:
        if LM_NOISE:
            lm_noise = np.random.normal(scale=LM_NOISE, size=[n_lm_feats])
        else:
            lm_noise = 0.
        lm_feats = (np.array([d[i] for d in landmark_dists]) - lm_mean) / lm_std + lm_noise
        noise_feats = np.random.normal(size=[NOISE_FEATS])
        embs[i] = np.concatenate([lm_feats, noise_feats])

    # Save embs to disk
    with open('./data/{}_lm_{}n{}-{}_emb_dict.pickle'.format(graph_id, N_LANDMARKS, LM_NOISE, NOISE_FEATS), 'wb') as f:
        pickle.dump(embs, f)


if __name__ == '__main__':
    # Pipeline: generate the graph, extract the connected edge list, pick
    # landmarks, sample node pairs, compute their distances with A*, and
    # finally build the landmark embeddings.

    print(args.graph_id)

    print('generating graph...')
    method = getattr(graphs, f'make_{args.graph_id}')
    method()

    print('making connected edgelist...')
    make_connected_edgelist(args.graph_id)

    print('finding landmarks...')
    G, landmark_dists = process_edgelist(args.graph_id)


    def heuristic(u, v):
        """Landmark lower bound on the distance from ``u`` to ``v`` for A*.

        Each ``l`` in ``landmark_dists`` maps a node to its distance from one
        landmark L. The triangle inequality d(L, v) <= d(L, u) + d(u, v)
        gives d(u, v) >= l[v] - l[u] on any graph. The mirrored bound
        l[u] - l[v] only bounds d(v, u), so it is used solely when the graph
        is undirected (d(u, v) == d(v, u)); using it on a directed graph
        could overestimate and make A* return a non-shortest distance.
        """
        if SYMMETRIC:
            return max(abs(l[u] - l[v]) for l in landmark_dists)
        # Clamp at zero: a negative bound carries no information.
        return max(0., max(l[v] - l[u] for l in landmark_dists))

    # mp_func and data_pairs are read as module-level names by
    # make_XYD_pickles, so they have to be defined here at module scope.
    def mp_func(pair):
        """Return the A* shortest-path length for one (source, target) pair."""
        return nx.astar_path_length(G, source=pair[0], target=pair[1], heuristic=heuristic)

    # We will make a total of up to 160K data points, then pick 150K (since we might prune some)
    N = 160000
    nodes = np.array(G.nodes())
    data_pairs = np.random.choice(nodes, size=(int(N), 2))
    data_pairs = data_pairs[data_pairs[:,0] != data_pairs[:,1]] # prevent self connections
    data_pairs = np.array(list({tuple(d) for d in data_pairs})) # prevent duplicates
    data_pairs = data_pairs[:150000]

    print(data_pairs.shape)
    print(data_pairs[:3])

    print('computing target distances...')
    make_XYD_pickles(args.graph_id)

    print('making embeddings...')
    make_lm_embeddings(args.graph_id)
