# To run the experiment, simply run this python script from the command line:
#
# python approximateEP_experiment1.py
#
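# This replicates the experiment as reported in the paper (up to randomness: every trial reseeds NumPy from OS entropy, so exact numbers differ between runs). The results are saved in timestamped subfolders of approximateEP_experiment/experiment1/.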

import main
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
import time
from karateclub import Role2Vec, Node2Vec, GraphWave
from sklearn.cluster import KMeans
from multiprocessing import Pool, set_start_method
from datetime import datetime
import time
import os
from main import evaluate_algorithm, role2vec,node2vec, save_experiment
from GNN_models import get_GNN_embedding


def experiment_iteration(it, experiment_dict):
    """Run one independent trial of the averaged-graph experiment.

    A random k x k matrix ``omega_1`` is drawn, and ``num_samples_A`` graphs are
    sampled from ``main.sample_planted_role_model`` using it. For every
    ``i`` in ``1 .. num_samples_A - 1`` the element-wise mean of the first ``i``
    sampled matrices is formed. Each of the six algorithms listed below is then
    run on that mean via ``evaluate_algorithm``.

    Row ``i == 0`` of every returned array is never written and stays zero.
    NOTE(review): because ``i`` stops at ``num_samples_A - 1``, the last sampled
    graph ``A_s[-1]`` never enters any average.

    Parameters
    ----------
    it : int
        Trial index; echoed back in the result and used in progress prints.
    experiment_dict : dict
        Reads 'num_groups' (k), 'size_groups' (n), 'num_algo' and
        'num_samples_A' (s).

    Returns
    -------
    tuple
        ``(it, (evals, times, Hs, graphs))`` with shapes
        ``evals`` (s, num_algo, 2), ``times`` (s, num_algo),
        ``Hs`` (s, num_algo, n*k**2, k), ``graphs`` (s, n*k**2, n*k**2).
        ``evals[..., 0]``, ``times`` and ``Hs`` receive the three values
        returned by ``evaluate_algorithm`` (presumably score under
        ``main.frac_ep_cost``, runtime and the soft assignment H -- confirm in
        ``main``). ``evals[..., 1]`` is ``main.deep_ep_cost_function(A, H)``.
        The buffers assume the sampled graphs have n*k**2 nodes.
    """
    k = experiment_dict['num_groups']
    n = experiment_dict['size_groups']
    num_algo = experiment_dict['num_algo']
    s = experiment_dict['num_samples_A']
    num_nodes = n * k ** 2

    graphs = np.zeros((s, num_nodes, num_nodes))
    evals = np.zeros((s, num_algo, 2))
    Hs = np.zeros((s, num_algo, num_nodes, k))
    times = np.zeros((s, num_algo))

    # Reseed from OS entropy -- presumably so that parallel trials draw
    # independent samples.
    np.random.seed()
    omega_1 = np.random.rand(k, k)
    A_s = [main.sample_planted_role_model(k, [n] * k, omega_1=omega_1) for _ in range(s)]

    # Index j in this tuple is the algorithm's slot along axis 1 of the outputs.
    algorithms = (
        main.appr_EP_by_dom_EV,
        lambda x, y: main.frac_WL(x, y, fit_function="average_linkage"),
        lambda x, y: main.frac_WL(x, y, fit_function="soft_kmeans"),
        role2vec,
        node2vec,
        get_GNN_embedding,
    )

    for i in range(1, s):
        A = np.mean(A_s[:i], axis=0)
        graphs[i] = A.copy()
        for j, algorithm in enumerate(algorithms):
            print(it, i, j)
            evals[i, j, 0], times[i, j], Hs[i, j] = evaluate_algorithm(
                A, k, algorithm, eval_function=main.frac_ep_cost, size_H=k)
            evals[i, j, 1] = main.deep_ep_cost_function(A, Hs[i, j])

    print(it, 'done')
    return it, (evals, times, Hs, graphs)



if __name__ == '__main__':

    # Start worker processes with 'spawn' (a fresh interpreter per worker).
    set_start_method('spawn')

    experiment_dict = {
        'num_samples': 100,       # number of independent trials
        'num_groups': 5,          # k
        'size_groups': 10,        # n
        'num_algo': 6,            # algorithms evaluated per averaged graph
        'num_samples_A': 10 + 1,  # row 0 stays empty; rows 1..10 average 1..10 samples
    }

    str_now = datetime.now().strftime("%Y-%m-%d_%H-%M")

    k = experiment_dict['num_groups']
    n = experiment_dict['size_groups']
    num_algo = experiment_dict['num_algo']
    num_trials = experiment_dict['num_samples']
    s = experiment_dict['num_samples_A']

    # Trials are dispatched in chunks of num_workers; each chunk gets its own pool.
    num_workers = 10
    chunks = int(np.ceil(num_trials / num_workers))

    # Result buffers for all trials; trial `it` fills index `it` along axis 0.
    num_nodes = n * k ** 2
    graphs = np.zeros((num_trials, s, num_nodes, num_nodes))
    evals = np.zeros((num_trials, s, num_algo, 2))
    colors = np.zeros((num_trials, s, num_algo, num_nodes, k))
    times = np.zeros((num_trials, s, num_algo))

    out_dir = f"approximateEP_experiment/experiment1/{num_trials}_{n}_{k}_{str_now}/"

    for i in range(chunks):
        # Trial indices [i*num_workers, (i+1)*num_workers), clipped to num_trials.
        chunk_trials = range(i * num_workers, min((i + 1) * num_workers, num_trials))
        with Pool(num_workers) as p:
            map_res = p.starmap(experiment_iteration, [(it, experiment_dict) for it in chunk_trials])

        for it, res_it in map_res:
            evals[it], times[it], colors[it], graphs[it] = res_it

        # Checkpoint once per chunk. Saving inside the loop above rewrote the
        # full accumulated arrays to this same path once per result.
        save_experiment(out_dir + str(i), evals, times, colors, graphs)
