import main
import time
import torch_geometric as ptg
import torch as pt
from ray import tune

#from models import GIN_Module, layered_GIN_Module, layered_GNN_Module, GAT_Module, GATv2_Module, GCN_Module, GraphSAGE_Module, sGIN_Module, GNN_Module
from torch.utils.data import random_split
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, ReLU,Sigmoid, LeakyReLU, ELU, Tanh
from torch_geometric.data import DataLoader
from torch_geometric.transforms import BaseTransform
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor

from ray.tune.integration.pytorch_lightning import TuneReportCallback

from ogb.nodeproppred import PygNodePropPredDataset

from GNN_models import GNN_Module, AppendEVs



import sys
import shutil

if __name__ == "__main__":
    # Launch a hyperparameter search + cross-validation run for a GNN.
    #
    # Usage: python <script> DATASET {GCN,GIN,GAT} NORM
    #   DATASET: Cora_ML, Citeseer or ogbn-arxiv (node-level tasks);
    #            any other name is loaded as a TUDataset (graph-level task).
    #   NORM:    graph2, graph, batch_wo_mean, batch, none, pair, or
    #            lin_graph2 (graph2 normalization with the 'id' activation).
    print(sys.argv)
    if len(sys.argv) < 4:
        raise SystemExit(f"Usage: {sys.argv[0]} DATASET {{GCN,GIN,GAT}} NORM")

    number_of_workers = 20
    num_evs = 4
    dataset_name = sys.argv[1]
    gnn_conv = sys.argv[2]
    norm_arg = sys.argv[3]

    if gnn_conv in ["GCN", "GIN", "GAT"]:
        # All three convolutions share one module class; the convolution name is
        # forwarded to it through the 'gnn_conv' entry of the tune config below.
        model_class = GNN_Module
    else:
        raise ValueError("Model name unknown.")

    if norm_arg in ['graph2', 'graph', 'batch_wo_mean', 'batch', 'none', 'pair']:
        norm = norm_arg
        activation_func = 'leakyrelu'
    elif norm_arg == 'lin_graph2':
        norm = 'graph2'
        activation_func = 'id'
    else:
        raise ValueError("Normalization unknown.")
    is_linear = activation_func == 'id'

    print(dataset_name, model_class, norm)

    is_test = -1
    if dataset_name in ("Cora_ML", "Citeseer", "ogbn-arxiv"):
        # Node-level tasks.
        if dataset_name in ("Cora_ML", "Citeseer"):
            dataset_full = ptg.datasets.CitationFull(root='./datasets', name=dataset_name, pre_transform=AppendEVs(num_evs, resize_y=True))
            has_predetermined_split = False
            cv_num_folds = 5
        else:
            dataset_full = PygNodePropPredDataset(name="ogbn-arxiv", root='./datasets', pre_transform=AppendEVs(num_evs, make_undirected=True, resize_y=True))
            # Indexing a PyG dataset returns a *copy* of the stored graph, so
            # assigning to dataset_full[0].y / .edge_index does not change the
            # dataset (earlier revisions did exactly that, to no effect).
            # Flattening y and making the graph undirected is presumably handled
            # by the AppendEVs pre_transform (resize_y=True, make_undirected=True);
            # this print lets you confirm what is actually stored.
            graph = dataset_full[0]
            print(graph.y.shape, ptg.utils.is_undirected(graph.edge_index))
            has_predetermined_split = True
            number_of_workers = 5
            cv_num_folds = 1
        embedding_size = 128
        num_layers = [15]
        default_activation = 'tanh'
        graph_level = False
    else:
        # Graph-level tasks on TU datasets.
        dataset_full = ptg.datasets.TUDataset(root='./datasets', name=dataset_name, use_node_attr=True, pre_transform=AppendEVs(num_evs))
        has_predetermined_split = False
        cv_num_folds = 5
        embedding_size = 64
        num_layers = [3, 5]
        default_activation = 'leakyrelu'
        graph_level = True

    # The 'lin_' variant must actually run with the identity activation.
    # Previously the hard-coded activation was always used, so 'lin_graph2'
    # runs were labelled linear while using a nonlinearity. Non-linear
    # variants keep their per-task default.
    # NOTE(review): assumes GNN_Module accepts 'id' as an activation name — confirm.
    model_activation = 'id' if is_linear else default_activation

    experiment_config = {'experiment_name': gnn_conv + "/" + ('lin_' if is_linear else '') + norm + '_' + time.strftime("%Y%m%d-%H%M%S"),
                         'dataset_name': dataset_name,
                         'dataset': dataset_full,
                         'test': is_test,
                         'is_cv': False,
                         'init': 'uniform',
                         'model_class': model_class,
                         'num_classes': dataset_full.num_classes,
                         'dataset_num_classes': dataset_full.num_classes,
                         'num_features': dataset_full.num_node_features,
                         'training_procedure': main.tmh,
                         'training_procedure_cv': main.train_model,
                         'grace_period': 50,
                         'num_epochs': 200,
                         'cv_num_epochs': 200,
                         'cv_num_folds': cv_num_folds,
                         'cv_iterations': 10,
                         'cv_save_top_k': 1,
                         'cv_save_last': True,
                         'tune_num_samples': 4,
                         'batch_size': 32,
                         'val_size': 0.20,
                         'has_predetermined_split': has_predetermined_split,
                         'callbacks': [EarlyStopping("val_score", patience=50)],
                         'cv_callbacks': [],
                         'skip_hyperparameter_search': False,
                         'hyperparameter_search_only_first_fold': False,
                         # NOTE(review): fixed GIN path for every model; presumably only
                         # read when skip_hyperparameter_search is True — confirm.
                         'hyperparameters': 'ray_results/tune_GIN/graph2_20240116-135655/hyperparameters.json',}

    tune_model_config = {'embedding_size': tune.grid_search([embedding_size]),
                         # NOTE(review): lr=0 is presumably a deliberate "untrained" baseline.
                         'learning_rate': tune.grid_search([0, 0.1, 0.01, 0.001, 0.0001]),
                         'num_layers': tune.grid_search(num_layers),
                         # Repeats every grid point tune_num_samples times.
                         'a_iteration': tune.grid_search(list(range(experiment_config['tune_num_samples']))),
                         'dropout': tune.grid_search([0.0,]),
                         'regularization': tune.grid_search([0, 0.01, 0.0001]),
                         'num_classes': experiment_config['num_classes'],
                         'dataset_num_classes': experiment_config['dataset_num_classes'],
                         'num_features': experiment_config['num_features'],
                         'initialization': tune.choice(['o']),
                         'jk': None,
                         'num_evs': num_evs,
                         'activation_func': tune.grid_search([model_activation]),
                         'graph_level': graph_level,
                         'normalization': norm,
                         'gnn_conv': gnn_conv}

    start = time.time()
    main.crossvalidation_with_hyperparameter_search(tune_model_config, experiment_config, number_of_workers)

    print("Hyperparameter search took: ", time.time()-start, " seconds")