# To run the experiment, simply run this python script from the command line:
#
# python approximateEP_experiment2.py
#
# This replicates the experiment as reported in the paper. The results are saved under approximateEP_experiment/experiment2/ (the output path additionally encodes the dataset name, the experiment sizes and the start time).


import main
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
import time
from karateclub import Role2Vec, Node2Vec, GraphWave
from sklearn.cluster import KMeans
from multiprocessing import Pool, set_start_method
from datetime import datetime
import time
import os
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_networkx
from main import evaluate_algorithm, node2vec, role2vec

from GNN_models import get_GNN_embedding

def experiment_iteration(it, experiment_dict):
    """Run one trial: evaluate every algorithm for a sweep of k values.

    Parameters
    ----------
    it
        Trial index. Only used for progress printing and returned unchanged
        so the caller can match the result to its trial.
    experiment_dict
        Experiment configuration. Reads the keys 'dataset', 'num_dif_k',
        'num_groups', 'num_algo' and the precomputed embeddings 'ev_emb',
        'r2v_emb' and 'n2v_emb'.

    Returns
    -------
    tuple
        ``(it, (evals, times, Hs, graphs))`` where ``evals`` has shape
        (num_dif_k, num_algo, 2), ``times`` has shape (num_dif_k, num_algo),
        ``Hs`` has shape (num_dif_k, num_algo, n, num_groups) and ``graphs``
        is a copy of the dataset matrix.
    """
    adjacency = experiment_dict['dataset']
    num_k_values = experiment_dict['num_dif_k']
    max_groups = experiment_dict['num_groups']
    num_nodes = adjacency.shape[1]
    num_algo = experiment_dict['num_algo']
    graphs = adjacency.copy()

    evals = np.zeros((num_k_values, num_algo, 2))
    Hs = np.zeros((num_k_values, num_algo, num_nodes, max_groups))
    times = np.zeros((num_k_values, num_algo))

    # Re-seed NumPy's global RNG inside this call (no fixed seed is given).
    np.random.seed()

    # The position in this tuple is the algorithm's index in the result arrays.
    algorithms = (
        lambda x, y: main.appr_EP_by_dom_EV(x, y, precomputed_embedding=experiment_dict['ev_emb']),
        lambda x, y: main.frac_WL(x, y, fit_function="average_linkage"),
        lambda x, y: main.frac_WL(x, y, fit_function="soft_kmeans"),
        lambda x, y: role2vec(x, y, precomputed_embedding=experiment_dict['r2v_emb']),
        lambda x, y: node2vec(x, y, precomputed_embedding=experiment_dict['n2v_emb']),
        get_GNN_embedding,
    )

    # num_dif_k integer values of k, evenly spaced from 2 to num_groups.
    k_values = np.linspace(2, max_groups, num_k_values, dtype=int)

    for row, k in enumerate(k_values):
        for col, algorithm in enumerate(algorithms):
            print(it, k, col)
            # Slot 0 of evals holds the frac_ep_cost score reported by evaluate_algorithm ...
            evals[row, col, 0], times[row, col], Hs[row, col] = evaluate_algorithm(
                adjacency,
                k,
                algorithm,
                iterations=1,
                eval_function=main.frac_ep_cost,
                size_H=max_groups,
            )
            # ... slot 1 holds the deep_ep_cost_function score of the same H.
            evals[row, col, 1] = main.deep_ep_cost_function(adjacency, Hs[row, col])

    print(it, 'done')
    return it, (evals, times, Hs, graphs)



if __name__ == '__main__':
    # Script entry point: load the protein network, precompute the embeddings
    # shared by all trials, run the trials in parallel worker processes and
    # save the collected results.

    # With the 'spawn' start method each worker re-imports this module, so
    # everything below has to stay under the __main__ guard.
    set_start_method('spawn')
    experiment_dict = {}
    experiment_dict['dataset_name'] = "protein"
    G = nx.readwrite.read_edgelist("networks/protein.edgelist.txt", nodetype=int)
    A = nx.to_numpy_array(G)
    experiment_dict['dataset'] = A

    experiment_dict['num_samples'] = 10   # number of independent trials
    experiment_dict['num_groups'] = 20    # largest k, also the width of each H
    experiment_dict['num_dif_k'] = 10     # number of different k values per trial
    experiment_dict['size_groups'] = experiment_dict['dataset'].shape[1]
    experiment_dict['num_algo'] = 6       # number of algorithms run by experiment_iteration

    # Timestamp used in the output path so separate runs do not collide.
    str_now = datetime.now().strftime("%Y-%m-%d_%H-%M")

    k = experiment_dict['num_dif_k']
    max_k = experiment_dict['num_groups']
    n = experiment_dict['size_groups']
    num_algo = experiment_dict['num_algo']
    num_trials = experiment_dict['num_samples']

    num_workers = 10
    chunks = int(np.ceil(num_trials / num_workers))

    # One graph copy is returned per trial, so the leading axis must be the
    # number of trials (it was num_dif_k before, which only worked because
    # both values happen to be 10).
    graphs = np.zeros((num_trials, *experiment_dict['dataset'].shape))
    evals = np.zeros((num_trials, k, num_algo, 2))
    colors = np.zeros((num_trials, k, num_algo, n, max_k))
    times = np.zeros((num_trials, k, num_algo))

    print('precomputing embeddings')
    # Precompute the embeddings once; they are passed to every trial through
    # experiment_dict.
    experiment_dict['n2v_emb'] = node2vec(A, 2, return_embedding=True)
    experiment_dict['r2v_emb'] = role2vec(A, 2, return_embedding=True)
    experiment_dict['ev_emb'] = main.appr_EP_by_dom_EV(A, 2, return_X=True)[1]

    # Same path as before, built once instead of on every save.
    out_dir = (
        f"approximateEP_experiment/experiment2/"
        f"{experiment_dict['dataset_name']}{num_trials}_{n}_{k}_{max_k}_{str_now}"
    )

    print('starting experiments')
    for i in range(chunks):
        # Trial indices handled by this chunk (at most num_workers of them).
        chunk_trials = range(i * num_workers, min((i + 1) * num_workers, num_trials))
        with Pool(num_workers) as p:
            map_res = p.starmap(experiment_iteration, [(it, experiment_dict) for it in chunk_trials])

        for it, res_it in map_res:
            evals[it], times[it], colors[it], graphs[it] = res_it
            # NOTE(review): this saves the full result arrays after every
            # trial of the chunk, always to the same per-chunk path. Saving
            # once per chunk is probably enough - confirm how
            # main.save_experiment treats an existing target before moving it.
            main.save_experiment(out_dir + '/' + str(i), evals, times, colors, graphs)
