### Experiment: recover planted-partition communities with node2vec embeddings + k-means

import argparse
import networkx as nx
from node2vec import Node2Vec
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score, confusion_matrix
from sympy.utilities.iterables import multiset_permutations
import numpy.random as nprand
import numpy as np
import pandas as pd
from pathlib import Path
import os
from datetime import date



### Parse arguments from the command line

parser = argparse.ArgumentParser(
    description="Embed planted-partition graphs with node2vec, cluster the "
                "embeddings with k-means, and score recovery of the planted communities."
)
parser.add_argument('--nodes', dest='nodes', type=int, default=50,
                    help='number of nodes in each community')
parser.add_argument('--K', dest='K', type=int, default=2,
                    help='number of communities')
# Within-community edge probability. Previously this was only a commented-out
# constant, so `p` was undefined and the script crashed with NameError.
# Typical choices: 0.05 (dense), 5/n (sparse), log(n)/n (relatively sparse).
parser.add_argument('--p', dest='p', type=float, default=0.05,
                    help='within-community edge probability (default 0.05)')
parser.add_argument('--ratio', dest='ratio', type=float, default=0.2,
                    help='between-community probability as a fraction of p (q = ratio * p)')
parser.add_argument('--sims', dest='sims', type=int, default=1,
                    help='number of independent simulations to run')
parser.add_argument('--notes', dest='notes', type=str, default='',
                    help='free-text suffix appended to the output file name')

args = parser.parse_args()

# Reject parameter values the graph generator cannot use; parser.error prints
# the usage message and exits instead of failing deep inside the simulation.
if args.nodes < 1:
    parser.error('--nodes must be a positive integer')
if args.K < 1:
    parser.error('--K must be a positive integer')
if not 0.0 <= args.p <= 1.0:
    parser.error('--p must be a probability in [0, 1]')
if args.ratio < 0.0 or args.ratio * args.p > 1.0:
    parser.error('--ratio must be >= 0 and satisfy ratio * p <= 1')


### Simulation parameters derived from the command line

n = args.nodes
K = args.K
p = args.p
q = args.ratio * p  # between-community edge probability
notes = args.notes
nsims = args.sims

# node2vec hyper-parameters are set inside the simulation loop.

all_results = []

today = date.today()
curr_date = today.strftime("%d_%m_%Y")
# round() rather than int(): int(0.29 * 100) truncates to 28 because of
# floating-point error, which would mislabel the output file.
file_name = "node2vec_n{}_K{}_p{}_ratio{}_{}{}".format(
    n, K, round(p * 100), round(args.ratio * 100), curr_date, notes
)

# Create a simple SBM
# nodes = [n, n]
# p = 0.5
# q = 0.005
# probs = [ [p, q], [q, p]]

# graph = nx.stochastic_block_model(nodes, probs)

# node2vec / word2vec hyper-parameters shared by every simulation.
EMBED_DIM = 64
WALK_LENGTH = 30
NUM_WALKS = 200
N_WORKERS = 4
WINDOW = 10

communities = list(range(K))

for i in range(nsims):
    # K communities of n nodes each; edges appear with probability p inside a
    # community and q between communities.
    graph = nx.planted_partition_graph(K, n, p, q)

    # networkx labels each block's nodes consecutively (0..n-1, n..2n-1, ...),
    # so node // n is the planted community. Building truth in graph.nodes()
    # order keeps it aligned row-for-row with the embedding matrix below.
    truth = [node // n for node in graph.nodes()]

    node2vec = Node2Vec(graph, dimensions=EMBED_DIM, walk_length=WALK_LENGTH,
                        num_walks=NUM_WALKS, workers=N_WORKERS)  # Use temp_folder for big graphs

    # Embed nodes
    model = node2vec.fit(window=WINDOW, min_count=1, batch_words=4)

    # Vectors are looked up by the string form of each node label.
    emb_df = pd.DataFrame(
        [model.wv.get_vector(str(node)) for node in graph.nodes()],
        index=graph.nodes,
    )
    X = emb_df.to_numpy()

    # Cluster the embeddings into K groups.
    kmeans = KMeans(n_clusters=K, n_init=10).fit(X)
    labels = kmeans.labels_

    # Compare the clustering against the planted partition.
    ari_score = adjusted_rand_score(labels, truth)
    # NOTE: this is *adjusted* mutual information; the 'NMI' column name below
    # is kept so existing output files stay comparable.
    nmi_score = adjusted_mutual_info_score(labels, truth)
    print(nmi_score)

    # Proportion correctly recovered: the best accuracy over all matchings of
    # cluster ids to community ids. Rows = k-means cluster, columns = true
    # community; passing labels= guarantees a K x K matrix. Enumerating all K!
    # column permutations is fine for small K (consider a linear assignment
    # solver if K grows large).
    conf_matrix = confusion_matrix(labels, truth, labels=communities)
    best_correct = max(
        np.trace(conf_matrix[:, perm]) for perm in multiset_permutations(communities)
    ) / (n * K)

    print(best_correct)

    # Store scalars (not one-element lists) so each simulation becomes one
    # clean row when all_results is turned into a DataFrame.
    results = {'n': n, 'K': K, 'p': p, 'q': q, 'ARI': ari_score, 'NMI': nmi_score,
               'prop_recovered': best_correct, 'sim': i}
    all_results.append(results)





# One row per simulation; columns come from the keys of each results dict.
df = pd.DataFrame(data=all_results)

# Results are written to <cwd>/subfolder/<dd_mm_YYYY>/<file_name>_out.csv;
# pathlib builds the path portably instead of concatenating strings.
filepath = Path.cwd() / "subfolder" / curr_date / f"{file_name}_out.csv"
filepath.parent.mkdir(parents=True, exist_ok=True)

df.to_csv(filepath)