### Simulation study: community recovery on degree-corrected SBM graphs using node2vec embeddings + k-means

import argparse
import networkx as nx
from node2vec import Node2Vec
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score, confusion_matrix
from sympy.utilities.iterables import multiset_permutations
import numpy.random as nprand
import numpy as np
import pandas as pd
from pathlib import Path
import os
from datetime import date
from math import pi



### Parse arguments from the command line

parser = argparse.ArgumentParser(
    description="Simulate degree-corrected SBM graphs, embed them with node2vec, "
                "cluster the embeddings with k-means and record recovery scores."
)
parser.add_argument('--nodes', dest='nodes', type=int, default=50,
                    help='number of nodes per community (n); total nodes N = K * n')
parser.add_argument('--K', dest='K', type=int, default=2,
                    help='number of communities')
parser.add_argument('--ratio', dest='ratio', type=float, default=0.2,
                    help='between-community edge probability as a fraction of '
                         'the within-community one (q = ratio * p)')
parser.add_argument('--sims', dest='sims', type=int, default=1,
                    help='number of independently simulated graphs')
parser.add_argument('--alpha', dest='alpha', type=float, default=0.75,
                    help='negative-sampling exponent passed to node2vec fit '
                         '(ns_exponent)')
parser.add_argument('--notes', dest='notes', type=str, default='',
                    help='free-text suffix appended to the output file name')

args = parser.parse_args()


### Simulation parameters derived from the command-line arguments.

n = args.nodes  # nodes per community
K = args.K  # number of communities
# Within-community edge probability on the log(n)/n scale.
# (Earlier runs used p = 0.05 and p = 5/n.)
p = np.log(n) / n
# Between-community edge probability is a fixed fraction of p.
q = args.ratio * p
notes = args.notes
nsims = args.sims
alpha = args.alpha

# Planted labels: the first n nodes are in community 0, the next n in
# community 1, and so on, so node v belongs to community truth[v].
truth = sorted(list(range(K)) * n)
N = K * n

## Degree-correction weights, intended to match the Harrison Zhou et al. paper (AoS):
##   theta_i = |Z_i| + 1 - (2*pi)^(-1/2),  Z_i ~ N(0, 0.25).
## The 0.25 is read here as the variance. numpy's `scale` is the standard
## deviation, so it must be 0.5: only then is
## E|Z_i| = 0.5 * sqrt(2/pi) = (2*pi)^(-1/2), making the weights average 1.
## (With scale=0.25 they averaged about 0.80, shrinking every edge probability.)
## The weights are drawn once and shared by all simulated graphs.
z = np.random.normal(loc=0, scale=0.5, size=N)
weights = np.absolute(z) + 1 - pow(2 * pi, -0.5)

# K x K block probability matrix: p on the diagonal, q off the diagonal.
Prob_matrix = np.full((K, K), q)
np.fill_diagonal(Prob_matrix, p)


def sample_dc_sbm_edges(N, truth, Prob_matrix, weights):
    """Draw the undirected edge list of a degree-corrected SBM.

    Each unordered pair {i, j} with i != j is connected independently with
    probability ``Prob_matrix[truth[i], truth[j]] * weights[i] * weights[j]``,
    clipped to [0, 1]. The product can exceed 1 for large weights or small n,
    and an unclipped value would make the Bernoulli draw raise ValueError.

    Parameters
    ----------
    N : int
        Total number of nodes; nodes are labelled 0..N-1.
    truth : sequence of int, length N
        Community label of each node, used to index ``Prob_matrix``.
    Prob_matrix : 2-D numpy array
        Block connection probabilities.
    weights : sequence of float, length N
        Degree-correction weight of each node.

    Returns
    -------
    list of (int, int)
        Sampled edges ``(i, j)`` with ``i > j``, in row-major order of the
        strict lower triangle. Each unordered pair appears at most once and
        there are no self-loops.
    """
    labels = np.asarray(truth, dtype=int)
    theta = np.asarray(weights, dtype=float)
    # Every pair (i, j) with i > j, i.e. each unordered pair exactly once.
    rows, cols = np.tril_indices(N, k=-1)
    probs = Prob_matrix[labels[rows], labels[cols]] * theta[rows] * theta[cols]
    probs = np.clip(probs, 0.0, 1.0)
    keep = np.random.binomial(1, probs).astype(bool)
    return list(zip(rows[keep].tolist(), cols[keep].tolist()))


def sim_dc_sbm(N, K, truth, Prob_matrix, weights):
    """Simulate one undirected degree-corrected SBM graph.

    Parameters
    ----------
    N : int
        Total number of nodes (labelled 0..N-1).
    K : int
        Number of communities. Kept for interface compatibility; the block
        structure actually used comes from ``Prob_matrix`` and ``truth``.
    truth, Prob_matrix, weights
        See ``sample_dc_sbm_edges``.

    Returns
    -------
    networkx.Graph
        Graph whose nodes are exactly 0..N-1, inserted in index order, so
        ``graph.nodes()`` lines up position-for-position with ``truth``.
        Isolated nodes are included.
    """
    G = nx.Graph()
    # Insert all nodes up front, in index order. Otherwise node order would
    # follow edge-insertion order and isolated nodes would be missing.
    G.add_nodes_from(range(N))
    G.add_edges_from(sample_dc_sbm_edges(N, truth, Prob_matrix, weights))
    return G


all_results = []

# The date stamp is used both in the output file name and as the output
# sub-directory.
today = date.today()
curr_date = today.strftime("%d_%m_%Y")
file_name = "node2vec_n{}_K{}_p{}_ratio{}_alpha{}_{}{}".format(
    n, K, int(p * 100), int(args.ratio * 100), int(args.alpha * 100), curr_date, notes
)

for i in range(nsims):
    graph = sim_dc_sbm(N, K, truth, Prob_matrix, weights)
    node2vec = Node2Vec(graph, dimensions=64, walk_length=30, num_walks=200, workers=4)  # Use temp_folder for big graphs

    # Embed nodes; alpha is passed through as the negative-sampling exponent.
    model = node2vec.fit(window=10, min_count=1, batch_words=4, ns_exponent=alpha)

    # Build the embedding matrix in node-index order 0..N-1 so that row v
    # corresponds to truth[v]. graph.nodes() follows insertion order, which
    # need not be index order, so it is not used for alignment here.
    node_ids = range(N)
    emb_df = pd.DataFrame(
        [model.wv.get_vector(str(v)) for v in node_ids],
        index=list(node_ids),
    )
    X = emb_df.values

    # Cluster the embeddings into K groups.
    kmeans = KMeans(n_clusters=K, n_init=10).fit(X)
    labels = kmeans.labels_

    # Compare the recovered clustering with the planted communities.
    ari_score = adjusted_rand_score(labels, truth)
    # NOTE: this is the *adjusted* mutual information. It is stored under the
    # historical column name 'NMI' so existing result files stay comparable.
    nmi_score = adjusted_mutual_info_score(labels, truth)
    print(nmi_score)

    # Proportion correctly recovered: the best accuracy over all relabellings
    # of the K clusters. This enumerates K! permutations, so it is only
    # practical for small K.
    conf_matrix = confusion_matrix(labels, truth)
    best_correct = 0
    for perm in multiset_permutations(list(range(K))):
        correct = np.trace(conf_matrix[:, perm]) / N
        best_correct = max(correct, best_correct)
    print(best_correct)

    # Store scalars (not one-element lists) so each CSV cell is a plain value.
    all_results.append({
        'n': n, 'K': K, 'p': p, 'q': q,
        'ARI': ari_score, 'NMI': nmi_score, 'alpha': alpha,
        'prop_recovered': best_correct, 'sim': i,
    })


df = pd.DataFrame(data=all_results)

# Write to ./subfolder/<date>/<file_name>_out.csv, creating directories as needed.
save_path = Path.cwd() / "subfolder" / curr_date
filepath = save_path / (file_name + '_out.csv')
filepath.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(filepath)
