import numpy as np
import os
import itertools
from tqdm import tqdm
import pandas as pd

#%%


def _kernel_experiment_dirname(method, dataset_name, exp):
    """Return the result sub-directory name for one hyperparameter setting."""
    if method == 'FGW':
        # These datasets carry a features mode in the directory name.
        # 'colab' is kept as a legacy spelling of 'collab'.
        if dataset_name in ('imdb-b', 'imdb-m', 'collab', 'colab'):
            features_mode = exp['features_mode']
            if features_mode not in ('degree', 'ones', 'onehot'):
                raise ValueError(
                    "unsupported features_mode %r; expected 'degree', "
                    "'ones' or 'onehot'" % (features_mode,))
            return 'FGWkernel_%s%s_alpha%s_dist%s' % (
                exp['graph_mode'], features_mode,
                exp['alpha'], exp['dist_features'])
        return 'FGWkernel_%s_alpha%s_dist%s' % (
            exp['graph_mode'], exp['alpha'], exp['dist_features'])
    if method in ('WL', 'WWL'):
        str_method = 'WLkernel' if method == 'WL' else 'WassWLkernel'
        if dataset_name in ('imdb-b', 'imdb-m', 'collab'):
            str_features_mode = '_degree'
        else:
            str_features_mode = ''
        return '%s%s_wl%s' % (str_method, str_features_mode, exp['wl'])
    raise ValueError(
        "unsupported method %r; expected 'FGW', 'WL' or 'WWL'" % (method,))


def _summarise_svc_results(csv_path):
    """Read one res_SVC.csv and return its accuracy summary as a dict."""
    res = pd.read_csv(csv_path)
    print('res:', res.keys())
    val_acc = res['val_mean_acc'].values
    test_acc = res['test_acc'].values
    # Keys are listed in the column order of the returned DataFrame.
    return {
        'mean_val_mean_acc': np.mean(val_acc),
        'std_val_mean_acc': np.std(val_acc),
        'mean_test_acc': np.mean(test_acc),
        'std_test_acc': np.std(test_acc),
        '0_test': test_acc[0],
    }


def fullgrid_benchmarked_kernel_results(
        method:str,
        dataset_name:str, 
        method_hyperparameters:dict,
        metrics:list,
        aggregators:list=None,
        version:str=''):
    """Collect SVC accuracy summaries over a full hyperparameter grid.

    Every combination of the values in ``method_hyperparameters`` is mapped
    to a directory ``../kernel_results/<dataset_name>/<experiment>/`` (the
    ``../`` is resolved against the current working directory) and its
    ``res_SVC.csv`` is read. Experiments whose file is missing or cannot be
    summarised are skipped, so the result may have fewer rows than the grid.

    Parameters
    ----------
    method : str
        One of ``'FGW'``, ``'WL'`` or ``'WWL'``.
    dataset_name : str
        Name of the dataset sub-directory under ``kernel_results``.
    method_hyperparameters : dict
        Maps each hyperparameter name to an iterable of values. ``'FGW'``
        reads ``graph_mode``, ``alpha`` and ``dist_features`` (plus
        ``features_mode`` for imdb-b / imdb-m / collab); ``'WL'`` and
        ``'WWL'`` read ``wl``.
    metrics : list
        Accepted for interface compatibility but currently ignored: only
        the ``val_mean_acc`` and ``test_acc`` columns are aggregated.
    aggregators : list, optional
        Currently unused.
    version : str, optional
        Currently unused.

    Returns
    -------
    pandas.DataFrame
        One row per experiment that was read, with columns
        ``mean_val_mean_acc``, ``std_val_mean_acc``, ``mean_test_acc``,
        ``std_test_acc``, ``0_test`` (first ``test_acc`` entry) and one
        column per hyperparameter.

    Raises
    ------
    ValueError
        If ``method`` or a ``features_mode`` value is not supported.
    KeyError
        If a hyperparameter required to build the directory name is missing.
    """
    if method not in ('FGW', 'WL', 'WWL'):
        raise ValueError(
            "unsupported method %r; expected 'FGW', 'WL' or 'WWL'" % (method,))

    res_repo = os.path.join(
        os.path.abspath('../'), 'kernel_results', dataset_name)

    stacked_results = {}
    for metric in ['acc']:
        for s in ['val_mean', 'test']:
            stacked_results['mean_%s_%s' % (s, metric)] = []
            stacked_results['std_%s_%s' % (s, metric)] = []
    stacked_results['0_test'] = []

    full_dict = dict(method_hyperparameters)
    print('full_dict:', full_dict)
    for param in full_dict:
        stacked_results[param] = []
    full_keys = list(full_dict)
    full_experiments = [
        dict(zip(full_keys, values))
        for values in itertools.product(*full_dict.values())]

    # Resolve every directory name up front so that invalid settings fail
    # before any file is read.
    planned = [
        (exp, _kernel_experiment_dirname(method, dataset_name, exp))
        for exp in full_experiments]

    for exp, experiment_name in tqdm(planned, desc='stacked experiments:'):
        if method == 'FGW':
            print('exp:', experiment_name)
        experiment_repo = os.path.join(res_repo, experiment_name)
        print('experiment_repo:', experiment_repo)
        try:
            summary = _summarise_svc_results(
                os.path.join(experiment_repo, 'res_SVC.csv'))
        except (OSError, KeyError, IndexError, TypeError, ValueError) as err:
            # Best effort: experiments that were not run (or whose file is
            # unreadable / incomplete) are skipped rather than aborting.
            print('skipping %s: %r' % (experiment_repo, err))
            continue
        # Append only once the whole summary is available, so that all
        # columns keep the same length.
        for column, value in summary.items():
            stacked_results[column].append(value)
        # Add corresponding method hyperparameters
        for param, value in exp.items():
            stacked_results[param].append(value)

    for key, values in stacked_results.items():
        print('key: %s / len res: %s' % (key, len(values)))
    return pd.DataFrame(stacked_results)

#%%

# FGW kernel: stack the results of a grid over alpha for one dataset.
method = 'FGW'
dataset_name = 'mutag'
graph_mode = 'ADJ'
dist_features = 'euclidean'

hyperparameters = dict(
    alpha=[0.9997, 0.995, 0.9, 0.75, 0.5, 0.25, 0.1, 0.005, 0.0003],
    dist_features=[dist_features],
    graph_mode=[graph_mode],
)
# The features mode is only added to the grid for 'imdb-b'.
if dataset_name == 'imdb-b':
    hyperparameters['features_mode'] = ['degree', 'onehot']

df_res = fullgrid_benchmarked_kernel_results(
    method=method,
    dataset_name=dataset_name,
    method_hyperparameters=hyperparameters,
    metrics=['acc'],
)
#%% WL kernels


# WL kernel: stack the results for wl = 1..10 on one dataset.
method = 'WL'
dataset_name = 'imdb-b'
hyperparameters = {'wl': list(range(1, 11))}

df_res = fullgrid_benchmarked_kernel_results(
    method=method,
    dataset_name=dataset_name,
    method_hyperparameters=hyperparameters,
    metrics=['acc'],
)