

import itertools
import os
from sys import argv

import numpy as np
import pandas as pd
#import pylab as pl
#import pickle
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_mutual_info_score, rand_score, accuracy_score, precision_score, recall_score, roc_auc_score
from tqdm import tqdm

from algo_graphdataset import GWh_datasets_graph
import algo_relaxedGW as algo
#from GW_kmeans import Kmeans_Kernelmatrix #GW_Kmeans_extendedDL,


#%%

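# Usage (each "(s)" argument is either a scalar or a bracketed list such as [10,20]):
#   python run_learning_graphdataset_full.py dataset Ntarget(s) lambda_reg(s) gamma_entropy(s)
#       init_mode_graph lr(s) batch_size(s) epochs seed(s) mode [degrees 0/1] [use_optimizer 0/1]
# Example (the final "0" is an extra argument that this script does not read):
#python run_learning_graphdataset_full.py "imdb-b" [10,20,30,40,50] 0 0 "random" [0.01] [32] 100 0 "ADJ" 0 1 0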

def _parse_list(arg, cast):
    """Parse a command-line value that is either a scalar or a bracketed list.

    ``"[10,20]"`` becomes ``[cast('10'), cast('20')]``; ``"10"`` becomes
    ``[cast('10')]``.
    """
    if '[' in arg:
        return [cast(x) for x in arg[1:-1].split(',')]
    return [cast(arg)]


def _parse_flag(index, default):
    """Return ``bool(int(argv[index]))``, or ``default`` when that argument is
    missing or is not an integer."""
    try:
        return bool(int(argv[index]))
    except (IndexError, ValueError):
        return default


# Positional command-line arguments (see the usage example above).
dataset_name = str(argv[1])
# Explicit check rather than an assert, which is stripped under `python -O`.
if dataset_name not in ('imdb-b', 'imdb-m'):
    raise ValueError("dataset_name must be 'imdb-b' or 'imdb-m', got %r" % dataset_name)
list_Ntarget = _parse_list(argv[2], int)
# Regularization weights are parsed as floats (0 -> 0.0 still compares == 0).
list_lambda_reg = _parse_list(argv[3], float)
list_gamma_entropy = _parse_list(argv[4], float)
init_mode_graph = str(argv[5])
list_lr = _parse_list(argv[6], float)
list_batch = _parse_list(argv[7], int)
epochs = int(argv[8])
list_seeds = _parse_list(argv[9], int)
mode = argv[10]
# Optional 0/1 flags; fall back to the defaults when absent or malformed.
degrees = _parse_flag(11, default=False)
use_optimizer = _parse_flag(12, default=True)


#%%

# Interactive defaults, e.g. for running this file cell by cell in an IDE.
# They are applied only when the script was NOT launched with a supported
# dataset name as its first argument, so values parsed from a real command
# line are kept instead of being silently overwritten.
use_checkpoint = True
if len(argv) < 2 or argv[1] not in ('imdb-b', 'imdb-m'):
    dataset_name = 'imdb-m'
    list_Ntarget = [10]
    # list_Ntarget = [10, 20, 30, 40, 50]
    list_lambda_reg = [0.]
    list_seeds = [0]
    list_gamma_entropy = [0]
    init_mode_graph = 'random'
    list_lr = [0.01]
    # list_lr = [0.001, 0.01]
    list_batch = [32]
    # list_batch = [16, 32]
    epochs = 20
    use_optimizer = True
    mode = 'ADJ'
    degrees = False

# Maps the short codes accepted for `mode` to their long method names.
str_to_method = {'ADJ': 'adjacency', 'SP': 'shortest_path', 'LAP': 'laplacian',
                 'fullADJ': 'augmented_adjacency', 'normADJ': 'normalized_adjacency',
                 'SIF': 'sif_distance', 'SLAP': 'signed_laplacian'}

# Settings shared by every run of the hyper-parameter grid below.
use_checkpoint = True
use_warmstart_MM = True

# Results are written under <parent of the working directory>/results/<dataset_name>/.
abspath = os.path.abspath('../')
experiment_repo = '/results/{}/'.format(dataset_name)

# The 'ADJ' mode gets no prefix in experiment folder names; others get "<mode>_".
str_mode = '' if mode == 'ADJ' else mode + '_'

# Initialization passed as init_GW to both learning and unmixing.
init_GW = 'product'
# Location of the datasets, relative to the working directory.
data_path = './real_datasets/'
# Inner-solver iteration cap and tolerance.
max_iter_inner = 1000
eps_inner = 10**(-6)
# Short tags used in experiment folder names for the init mode and the optimizer.
init_str = {'random': 'rand', 'fixed_cluster': 'cls', 'randomV2': 'randV2'}
optim_str = {True: 'Adam', False: 'SGD'}


def _regularization_settings(lambda_reg):
    """Return the settings derived from the regularization weight ``lambda_reg``.

    Returns:
        tuple ``(reg_str, max_iter_MM, eps_inner_MM, rerun_clustering)``:
        the folder-name tag and the MM parameters. ``lambda_reg == 0`` sets
        ``max_iter_MM = 0``; a positive value sets it to 50 with a 1e-6
        tolerance.

    Raises:
        ValueError: if ``lambda_reg`` is neither zero nor positive.
    """
    if lambda_reg == 0:
        return 'reg0.0', 0, 0, False
    if lambda_reg > 0:
        return 'MMreg%s' % lambda_reg, 50, 10**(-6), True
    raise ValueError('negative lambda_reg is not supported - promoting density '
                     'of the OT is the goal of this regularization')


def _local_projection(mode):
    """Return the ``proj`` argument given to ``Learn_Ctarget`` for graph ``mode``.

    Raises:
        ValueError: for a mode with no known projection.
    """
    if mode in ('ADJ', 'SP', 'fullADJ', 'SLAP', 'normADJ', 'SIF'):
        return 'nsym'
    if mode in ('LAP', 'normLAP'):
        return 'sym'
    raise ValueError('unknown projection for mode : %s' % mode)


# Run every combination of the hyper-parameter grids (seed varies fastest).
# For each combination: learn the target graphs, or reload them when the
# experiment folder already exists. Then, unless res_clustering.csv is already
# present, cluster the per-checkpoint unmixings with KMeans and save RI / AMI
# scores together with unmixing-loss statistics.
for gamma_entropy, lambda_reg, batch_size, lr, Ntarget, seed in itertools.product(
        list_gamma_entropy, list_lambda_reg, list_batch, list_lr, list_Ntarget, list_seeds):
    entropic_reg_str = '' if gamma_entropy == 0 else 'ENTreg%s_' % gamma_entropy
    reg_str, max_iter_MM, eps_inner_MM, rerun_clustering = _regularization_settings(lambda_reg)

    # The 'degrees' variant records the epoch count in the name too, followed
    # by a '_degrees' tag.
    epochs_str = ('epochs%s_degrees' if degrees else 'epochs%s') % epochs
    experiment_name = '/%sV2initproduct_Ntarget%s_%s%s_%s_init%s_lr%s_batch%s_%s_seed%s/' % (
        str_mode, Ntarget, entropic_reg_str, reg_str, optim_str[use_optimizer],
        init_str[init_mode_graph], lr, batch_size, epochs_str, seed)

    method = GWh_datasets_graph(graphs=None, masses=None,
                                dataset_name=dataset_name,
                                mode=mode, Ntarget=Ntarget,
                                experiment_repo=experiment_repo,
                                experiment_name=experiment_name, degrees=degrees,
                                data_path=data_path)

    full_path = abspath + experiment_repo + experiment_name
    print('full_path:', full_path)
    dictionary_to_learn = False
    if os.path.exists(full_path):
        method.load_elements(use_checkpoint=use_checkpoint)
        # NOTE(review): log_loss is not used below; this load only makes the
        # run fail if loss.npy is missing from an existing experiment folder.
        log_loss = np.load(full_path + '/loss.npy')
    else:
        print('- start learning -')
        local_proj = _local_projection(mode)
        method.Learn_Ctarget(lambda_reg=lambda_reg, epochs=epochs,
                             max_iter_inner=max_iter_inner, eps_inner=eps_inner,
                             lr=lr, batch_size=batch_size, checkpoint_freq=5,
                             max_iter_MM=max_iter_MM, eps_inner_MM=eps_inner_MM,
                             gamma_entropy=gamma_entropy, algo_seed=seed,
                             beta_1=0.9, beta_2=0.99,
                             init_mode_graph=init_mode_graph,
                             use_warmstart_MM=use_warmstart_MM,
                             use_optimizer=use_optimizer,
                             use_checkpoint=use_checkpoint, proj=local_proj,
                             init_GW=init_GW)

    # One KMeans cluster per distinct label in method.y.
    n_clusters = len(np.unique(method.y))
    if os.path.exists(full_path + '/res_clustering.csv'):
        continue  # clustering results already saved for this setting

    km_embeddings_res = {'checkpoint': [], 'RI': [], 'ami': [], 'loss_mean': [], 'loss_std': []}
    print('computing unmixing - 1 shot')
    list_best_T, list_best_losses = method.compute_unmixing(
        lambda_reg=lambda_reg, gamma_entropy=gamma_entropy, use_checkpoint=use_checkpoint,
        init_GW=init_GW, eps_inner=eps_inner, max_iter_inner=max_iter_inner,
        eps_inner_MM=eps_inner_MM, max_iter_MM=max_iter_MM,
        use_warmstart_MM=False, algo_seed=0)
    # Each unmixing is the column sum (axis 0) of one transport matrix T
    # returned by compute_unmixing, grouped per checkpoint.
    list_unmixings = np.array([[np.sum(T, axis=0) for T in list_OT] for list_OT in list_best_T])
    np.save(full_path + 'unmixings.npy', list_unmixings)
    np.save(full_path + 'losses_unmixings.npy', np.array(list_best_losses))

    for checkpoint in range(len(method.checkpoint_Ctarget)):
        unmixings = list_unmixings[checkpoint]
        km = KMeans(n_clusters=n_clusters, n_init=100, random_state=0).fit(unmixings)
        pred = km.labels_
        ami = adjusted_mutual_info_score(method.y, pred, average_method='max')
        rand_index = rand_score(method.y, pred)
        km_embeddings_res['checkpoint'].append(checkpoint)
        km_embeddings_res['RI'].append(rand_index)
        km_embeddings_res['ami'].append(ami)
        km_embeddings_res['loss_mean'].append(np.mean(list_best_losses[checkpoint]))
        km_embeddings_res['loss_std'].append(np.std(list_best_losses[checkpoint]))
    pd.DataFrame(km_embeddings_res).to_csv(full_path + 'res_clustering.csv')