import numpy as np
import torch as th
from data_handler import build_synthetic_dataset
import os
from sys import argv
from TFGW_raw import GWmachine
import pickle
from scipy.sparse.csgraph import shortest_path
import warnings
warnings.filterwarnings("ignore")
os.environ["CUDA_VISIBLE_DEVICES"]=""
# %% parser for GW layer without GNN preprocessing

# python run_synthetic_datasets.py '4cycles' 'ADJ' [10] 1000 [0,1,2,3,5,6] 1

"""
graph_mode: input representation of the graphs, e.g. 'ADJ' (adjacency) or 'SP' (shortest path).
init_mode_atoms: way to initialize the graph templates used in the classifier.
"""
"""
dataset_name = str(argv[1])
assert dataset_name in ['4cycles', 'skipcircles']
graph_mode =str(argv[2]) # ['ADJ','SP'...]
if '[' in argv[3]: # number of graph templates
    list_Katoms = [int(x) for x in argv[3][1:-1].split(',')] 
else:
    list_Katoms = [int(argv[3])]


epochs = int(argv[4])    
if '[' in argv[5]: # random seed for initializing the atoms.
    list_seeds = [int(x) for x in argv[5][1:-1].split(',')] 
else:
    list_seeds = [int(argv[5])]
fixed_templates = bool(int(argv[6]))
"""
# %%

# --- Experiment configuration ------------------------------------------------
# Hard-coded settings used in place of the command-line parser that is disabled
# in the string literal above. Edit the values here to change what is run.
dataset_name = '4cycles'
graph_mode = 'ADJ'
list_Katoms = [2, 4, 6, 8, 10, 12, 14, 16]  # numbers of graph templates to sweep over
# default for our method / For these toy datasets, all graphs have the same size
init_mode_atoms = 'sampling_supervised_median'
lr = 0.01
alpha = 0
epochs = 1000
fixed_templates = True
list_seeds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# Tag inserted in the experiment directory name when the templates are fixed.
str_fixed_templates = 'fixedtemplates' if fixed_templates else ''
# hbar is only learned when the templates themselves are learned.
learn_hbar = not fixed_templates

# Projection applied to the templates, selected from the input representation.
atoms_projection = 'nonnegative' if graph_mode == 'SP' else 'clipped'
str_to_method = {'ADJ': 'adjacency', 'SP': 'shortest_path', 'LAP': 'laplacian',
                 'fullADJ': 'augmented_adjacency', 'normADJ': 'normalized_adjacency'}

# Output / input locations, resolved relative to the parent of the current
# working directory.
abspath = os.path.abspath('../')
res_repo = abspath + '/results/%s/' % dataset_name
data_path = abspath + '/real_datasets/'

# Batch sampling strategy and the tag used for it in the experiment name.
supervised_sampler = False
str_batch = '' if supervised_sampler else 'random'

val_timestamp = 5
track_templates = True
dtype = th.float64
verbose = False
n_jobs = 3
dropout = 0.

device = 'cpu'
# Classifier head settings.
num_hidden_clf = 2
hidden_dim_clf = 128

# Per-dataset batch size. Fail immediately on an unsupported dataset instead of
# leaving `batch_size` undefined and hitting a NameError further down.
_batch_size_by_dataset = {'4cycles': 50, 'skipcircles': 10}
if dataset_name not in _batch_size_by_dataset:
    raise ValueError(
        'unsupported dataset_name %r, expected one of %s'
        % (dataset_name, sorted(_batch_size_by_dataset)))
batch_size = _batch_size_by_dataset[dataset_name]

# --- Sweep over random seeds and numbers of templates -------------------------
# For every (seed, Katoms) pair: derive the experiment directory, check whether
# a previous run already completed, then train and/or evaluate a GWmachine.

# Name fragments that do not depend on the seed or on Katoms are built once.
init_graph_str = {'random_median':'randmed', 'random_small':'randsmall', 'sampling_supervised_median':'samplabelmed'}
proj_str  = {'clipped':'', 'nonnegative':'_nonnegative_'}
learn_hbar_str = 'learnhbar' if learn_hbar else ''
clf_str = 'clfMLPh%sL%sdropout%s'%(hidden_dim_clf, num_hidden_clf, dropout)

for seed in list_seeds:
    for Katoms in list_Katoms:
        th.manual_seed(seed)
        np.random.seed(seed)
        # A fresh dict is handed to each model so that instances never share it.
        clf_net_dict = {'hidden_dim':hidden_dim_clf,
                        'num_hidden':num_hidden_clf,
                        'dropout':dropout}

        experiment_name = '/V5%s%s%s_%sK%s%s_%s_lr%s_%sbatch%s_ep%s_seed%s/'%(
                graph_mode, learn_hbar_str, str_fixed_templates, init_graph_str[init_mode_atoms],
                Katoms, proj_str[atoms_projection], clf_str, lr, str_batch, batch_size, epochs, seed)

        experiment_repo = res_repo + experiment_name
        model_name = 'model_seed%s'%seed
        print('experiment_repo:', experiment_repo)
        print('model_name:', model_name)

        # Unless a finished run is found on disk, both stages are executed.
        train_model = True
        test_model = True
        if os.path.exists(experiment_repo):
            print('already existing repository')
            training_log_file = experiment_repo+'/%s_training_log.pkl'%model_name
            try:
                with open(training_log_file, 'rb') as log_handle:
                    training_log = pickle.load(log_handle)
                saved_epochs = len(training_log['train_cumulated_batch_loss'])
            except Exception as err:
                # Best effort: a missing, truncated or otherwise unreadable log
                # simply means the experiment is run again from scratch.
                # Unpickling can raise many exception types, hence the broad
                # catch (KeyboardInterrupt/SystemExit are still propagated).
                print('training log not usable (%r) / running the experiment again' % (err,))
            else:
                print('training log found / saved epochs:', saved_epochs)
                # NOTE(review): completion is detected with `saved_epochs == epochs`
                # while the message below reports `saved_epochs * val_timestamp`
                # done epochs - confirm which unit the training log is stored in.
                if saved_epochs == epochs:
                    print('experiment already completed / done epochs:', saved_epochs * val_timestamp)
                    train_model = False
                    # Only evaluate when no test result has been stored yet.
                    test_model = not os.path.exists(experiment_repo+'/%s_test_result.pkl'%model_name)
        print('train model? ', train_model)

        if train_model or test_model:
            # Load graphs and corresponding labels for clustering benchmark
            if dataset_name == '4cycles':
                train_dataset = build_synthetic_dataset.FourCycles()
                test_dataset = build_synthetic_dataset.FourCycles()
                n_labels = 2
            elif dataset_name == 'skipcircles':
                train_dataset = build_synthetic_dataset.SkipCircles()
                test_dataset = build_synthetic_dataset.SkipCircles()
                n_labels = 10
            else:
                # Without this branch the names above would be left undefined.
                raise ValueError('unsupported dataset_name: %r' % (dataset_name,))
            n = train_dataset.num_nodes
            print(f'Number of nodes: {n}')
            train_graphs = train_dataset.makedata()
            test_graphs = test_dataset.makedata()

            # One adjacency matrix per graph, paired with uniform node weights.
            X_train = [build_synthetic_dataset.get_adjacency_matrix(g.edge_index, g.num_nodes, dtype).to(device=device, dtype=dtype) for g in train_graphs]
            X_test = [build_synthetic_dataset.get_adjacency_matrix(g.edge_index, g.num_nodes, dtype).to(device=device, dtype=dtype) for g in test_graphs]
            h_train = [th.ones(C.shape[0], device=device, dtype=dtype)/C.shape[0] for C in X_train]
            h_test = [th.ones(C.shape[0], device=device, dtype=dtype)/C.shape[0] for C in X_test]
            shapes = [C.shape[0] for C in X_train]
            dataset_size = len(X_train)

            if graph_mode == 'SP':
                # Replace adjacency matrices by all-pairs shortest-path matrices.
                X_train = [th.tensor(shortest_path(C.numpy()), device=device, dtype=dtype) for C in X_train]
                X_test = [th.tensor(shortest_path(C.numpy()), device=device, dtype=dtype) for C in X_test]

            y_train = th.Tensor([g.y for g in train_graphs]).to(device=device, dtype=th.long)
            y_test = th.Tensor([g.y for g in test_graphs]).to(device=device, dtype=th.long)
            # median shape
            Satoms = round(np.median(shapes))
            list_Satoms = [Satoms] * Katoms
            print('median atom shapes:', list_Satoms)

            # GW machines
            model = GWmachine(
                Katoms=Katoms,
                n_labels= n_labels,
                experiment_repo=experiment_repo,
                learn_hbar = learn_hbar,
                clf_net_dict=clf_net_dict,
                dtype=dtype,
                device=device)
            if train_model:
                # NOTE(review): the init mode is hard-coded to 'sampling_supervised'
                # here whereas the experiment name is derived from `init_mode_atoms`
                # - confirm that this mismatch is intended.
                model.init_parameters(list_Satoms=list_Satoms, init_mode_atoms='sampling_supervised', graphs=X_train, labels=y_train, atoms_projection=atoms_projection, verbose=verbose)

                # Both training entry points receive exactly the same arguments;
                # only the method differs depending on whether templates are fixed.
                fit_method = model.fit_fixedtemplates if fixed_templates else model.fit
                fit_method(model_name=model_name,
                           X_train=X_train, y_train=y_train,
                           X_val=None, y_val=None, X_test=None, y_test=None,
                           atoms_projection=atoms_projection, lr=lr, batch_size=batch_size,
                           supervised_sampler=supervised_sampler, epochs=epochs, val_timestamp=val_timestamp,
                           algo_seed=0, track_templates=track_templates, verbose=verbose, n_jobs=n_jobs)

            if test_model:
                # Evaluate the checkpoint stored under the best-train-accuracy name.
                str_file = model.experiment_repo+'/%s_best_train_accuracy.pkl'%model_name

                res_test_file = experiment_repo+'/%s_test_result.pkl'%model_name
                model.load(str_file)
                dist_features_test, pred_test, y_pred_test, loss_test, res_test = model.evaluate(X_test, h_test, y_test, n_jobs)
                print('res_test:', res_test)
                # The context manager guarantees the result is flushed and closed.
                with open(res_test_file, 'wb') as result_handle:
                    pickle.dump(res_test, result_handle)
                
                        
                                    