import os
import pickle
import json
import traceback
import argparse
from training.utils import set_seed, mkdir_if_missing, remove_illegal_characters,str2dict,compose_hyper_params,assign_hyp_param
from distutils.util import strtobool
from metrics import *
from pipeline import *
import sys
# Directory the script was launched from; a user-local site-packages folder
# beneath it is appended to the import path (kept for HPC usage).
dir_home = os.getcwd()
sys.path.extend([os.path.join(dir_home, '.local/lib/python3.7/site-packages')])
def hyp_params_to_key(hyp_params):
    """Return a filesystem-friendly string key for one hyper-parameter combination.

    The dict's ``str()`` has quotes and spaces removed, and commas/colons are
    turned into underscores, e.g. ``{'lr': 0.01, 'epochs': 5}`` becomes
    ``'{lr_0.01_epochs_5}'``.
    """
    return str(hyp_params).replace("'", '').replace(' ', '').replace(',', '_').replace(':', '_')


def validation_to_test_name(name):
    """Map a validation result name (``.../val_<...>``) to its test counterpart (``.../te_<...>``).

    Only the ``val_`` prefix of the last path component is rewritten, so other
    occurrences of ``val`` (in a dataset name, a hyper-parameter value, ...) are
    left intact. If that prefix is absent, fall back to replacing every ``val``.
    """
    head, sep, tail = name.rpartition('/')
    if tail.startswith('val_'):
        return head + sep + 'te_' + tail[len('val_'):]
    return name.replace('val', 'te')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='CGLB')
    parser.add_argument("--dataset", type=str, default='CoraFull-CL', help='AmazonComputer-CL, Reddit-CL, Arxiv-CL, CoraFull-CL')
    parser.add_argument("--gpu", type=int, default=0, help="which GPU to use.")
    parser.add_argument("--seed", type=int, default=1, help="seed for exp")
    parser.add_argument("--epochs", default=[1,5], help="number of training epochs (passes on each batch)")
    parser.add_argument("--lr", default=[0.01, 0.001, 0.0001, 0.00001], help="learning rate")
    parser.add_argument('--weight-decay', type=float, default=0, help="weight decay")
    parser.add_argument('--backbone', type=str, default='GCN', help="backbone GNN, [GAT, GCN, GIN]")
    parser.add_argument('--method', type=str,
                        choices=["bare", 'lwf', 'agem', 'ewc', 'mas', 'twp', 'jointtrain', 'er', 'joint','Joint'], default="bare",
                        help="baseline continual learning method")
    parser.add_argument('--inter-task-edges', type=strtobool, default=True,
                        help='whether to keep the edges connecting nodes from different tasks')
    parser.add_argument('--d_dtat', default=None, help='will be assigned during running')
    parser.add_argument('--n_cls', default=None, help='will be assigned during running')
    parser.add_argument('--ratio_valid_test', nargs='+', default=[0.2, 0.2], help='ratio of nodes used for valid and test')
    parser.add_argument('--transductive', type=strtobool, default=True, help='using transductive or inductive')
    parser.add_argument('--default_split', type=strtobool, default=False, help='whether to  use the data split provided by the dataset')
    parser.add_argument('--task_seq', default=[])
    parser.add_argument('--n-task', default=0, help='will be assigned during running')
    parser.add_argument('--n_cls_per_task', type=int, default=2, help='how many classes does each task  contain')
    parser.add_argument('--n_nodes_per_batch', type=int, default=10, help='how many nodes does each streaming batch contain')
    parser.add_argument('--n_validation_tasks', type=int, default=4, help='how many tasks on which to validate the hyperparameters')
    parser.add_argument('--end_batch_test', type=strtobool, default=True, help='whether to perform anytime evaluation')
    parser.add_argument('--GAT_args',
                        default={'num_layers': 1, 'num_hidden': 32, 'heads': 8, 'out_heads': 1, 'feat_drop': .6,
                                 'attn_drop': .6, 'negative_slope': 0.2, 'residual': False})
    parser.add_argument('--GCN_args', default={'h_dims': [256], 'dropout': 0.0, 'batch_norm': False})
    parser.add_argument('--GIN_args', default={'h_dims': [256], 'dropout': 0.0})
    parser.add_argument('--er_args', type=str2dict, default={'budget': [100,1000], 'd': [0.5], 'sampler': ['random'], 'memory_proportion': [1., 2., 3.]},
                        help='sampler options: CM, CM_plus, MF, MF_plus, random')
    parser.add_argument('--lwf_args', type=str2dict, default={'lambda_dist': [0.1, 1.0, 10.0], 'T': [0.2, 2.0, 20.0], 'save_every': [1, 10, 100]})
    parser.add_argument('--twp_args', type=str2dict, default={'lambda_l': [100., 10000., 1000000.], 'lambda_t': [100., 10000., 1000000.], 'beta': [0.001, 0.01, 0.1]})
    parser.add_argument('--ewc_args', type=str2dict, default={'memory_strength': [1., 100., 10000., 1000000., 100000000., 10000000000.]})
    parser.add_argument('--mas_args', type=str2dict, default={'memory_strength': [1., 100., 10000., 1000000., 100000000., 10000000000.]})
    parser.add_argument('--agem_args', type=str2dict, default={'n_memories': [100,1000], 'memory_proportion': [1., 2., 3.]})
    parser.add_argument('--bare_args', type=str2dict, default={'Na': None})
    parser.add_argument('--joint_args', type=str2dict, default={'Na': None})
    parser.add_argument('--cls-balance', type=strtobool, default=False, help='whether to balance the cls when training and testing')
    parser.add_argument('--repeats', type=int, default=1, help='how many times to repeat the experiments for the mean and std')
    parser.add_argument('--batch_size', type=int, default=50000, help='batch size for testing only, for training use n_nodes_per_batch')
    parser.add_argument('--sample_nbs', type=strtobool, default=False, help='whether to sample neighbors instead of using all')
    parser.add_argument('--n_nbs_sample', type=lambda x: [int(i) for i in x.replace(' ', '').split(',')], default=[5, 5], help='number of neighbors to sample per hop, use comma to separate the numbers when using the command line, e.g. 10,25 or 10, 25')
    parser.add_argument('--replace_illegal_char', type=strtobool, default=False)
    parser.add_argument('--ori_data_path', type=str, default='./store/data', help='the root path to raw data')
    parser.add_argument('--data_path', type=str, default='./data', help='the path to processed data (splitted into tasks)')
    parser.add_argument('--result_path', type=str, default='./results', help='the path for saving results')
    parser.add_argument('--load_check', type=strtobool, default=False, help='whether to check the existence of processed data by loading')
    parser.add_argument('--perform_testing', type=strtobool, default=True, help='whether to to test the models with validated hyperparameters')
    args = parser.parse_args()
    args.ratio_valid_test = [float(i) for i in args.ratio_valid_test]
    set_seed(args)
    mkdir_if_missing(f'{args.data_path}')
    # --lr / --epochs default to Python lists; values given on the command line
    # arrive as strings and are decoded as JSON (e.g. --lr "[0.01, 0.001]").
    if isinstance(args.lr, str):
        args.lr = json.loads(args.lr)
    if isinstance(args.epochs, str):
        args.epochs = json.loads(args.epochs)

    # Per-method hyper-parameter search spaces. 'jointtrain' and 'Joint' are valid
    # --method choices, so they share the 'joint' space instead of raising KeyError.
    method_args = {'er': args.er_args, 'lwf': args.lwf_args, 'twp': args.twp_args, 'ewc': args.ewc_args,
                   'bare': args.bare_args, 'agem': args.agem_args, 'mas': args.mas_args, 'joint': args.joint_args,
                   'jointtrain': args.joint_args, 'Joint': args.joint_args}
    backbone_args = {'GCN': args.GCN_args, 'GAT': args.GAT_args, 'GIN': args.GIN_args}
    hyp_param_list = compose_hyper_params(method_args[args.method], args.lr, args.epochs)

    # Best validation result so far. Starting at -inf (not 0) lets any configuration
    # that completes be selected, even one whose AP is exactly 0.
    AP_best, name_best, hyp_best, hyp_best_str = float('-inf'), None, None, None
    AP_dict = {hyp_params_to_key(hyp_params): [] for hyp_params in hyp_param_list}
    AF_dict = {hyp_params_to_key(hyp_params): [] for hyp_params in hyp_param_list}
    PM_dict = {hyp_params_to_key(hyp_params): [] for hyp_params in hyp_param_list}
    for hyp_params in hyp_param_list:
        # Validate each candidate combination; the best one (highest mean AP) is kept.
        hyp_params_str = hyp_params_to_key(hyp_params)
        print(hyp_params_str)
        assign_hyp_param(args, hyp_params)
        main = get_pipeline(args)
        train_ratio = round(1 - args.ratio_valid_test[0] - args.ratio_valid_test[1], 2)
        subfolder = f'onl_IL/train_ratio_{train_ratio}/'

        name = f'{subfolder}val_{args.dataset}_{args.n_cls_per_task}_{args.method}_{args.lr}_{args.n_nodes_per_batch}_{list(hyp_params.values())}_{args.backbone}_{backbone_args[args.backbone]}_{args.cls_balance}_{args.epochs}_{args.repeats}'
        mkdir_if_missing(f'{args.result_path}/' + subfolder)
        if args.replace_illegal_char:
            name = remove_illegal_characters(name)

        acc_matrices = []
        failed = False
        print('method args are', hyp_params)
        for ite in range(args.repeats):
            print(name, ite)
            args.current_model_save_path = [name, ite]
            try:
                AP, AF, acc_matrix = main(args, valid=True)
                AP_dict[hyp_params_str].append(AP)
                AF_dict[hyp_params_str].append(AF)
                acc_matrices.append(acc_matrix)
                torch.cuda.empty_cache()
                if ite == 0:
                    with open(f'{args.result_path}/log.txt', 'a') as f:
                        f.write(name)
                        f.write('\nAP:{},AF:{}\n'.format(AP, AF))
            except Exception as e:
                failed = True
                mkdir_if_missing(f'{args.result_path}/errors/' + subfolder)
                if ite > 0:
                    # Keep the results of the repeats that finished before the failure.
                    name_ = f'{subfolder}val_{args.dataset}_{args.n_cls_per_task}_{args.method}_{list(hyp_params.values())}_{args.backbone}_{backbone_args[args.backbone]}_{args.cls_balance}_{args.epochs}_{ite}'
                    with open(f'{args.result_path}/{name_}.pkl', 'wb') as f:
                        pickle.dump(acc_matrices, f)
                print('error', e)
                name = 'errors/{}'.format(name)
                # The traceback text replaces the results, so the .pkl written below holds the error.
                acc_matrices = traceback.format_exc()
                print(acc_matrices)
                print('error happens on \n', name)
                torch.cuda.empty_cache()
                break

        # Select by mean validation AP. Failed configurations are excluded: testing
        # replays every repeat under `name_best`, while a failed run has fewer
        # results and its name has been rewritten to 'errors/...'.
        if not failed and AP_dict[hyp_params_str]:
            AP_mean = np.mean(AP_dict[hyp_params_str])
            if AP_mean > AP_best:
                AP_best = AP_mean
                hyp_best = hyp_params
                hyp_best_str = hyp_params_str
                name_best = name
        print(f'best params is {hyp_best_str}, best AP is {AP_best}')
        with open(f'{args.result_path}/{name}.pkl', 'wb') as f:
            pickle.dump(acc_matrices, f)

    if args.perform_testing:
        if name_best is None:
            print('no hyper-parameter configuration finished validation successfully, skipping testing')
        else:
            print('----------Now in testing--------')
            # The search loop leaves `args` and `main` configured for the LAST
            # candidate; re-apply the selected one before testing.
            assign_hyp_param(args, hyp_best)
            main = get_pipeline(args)
            acc_matrices = []
            for ite in range(args.repeats):
                args.current_model_save_path = [name_best, ite]
                AP_test, AF_test, acc_matrix_test = main(args, valid=False)
                acc_matrices.append(acc_matrix_test)
            # Only the 'val_' file prefix becomes 'te_'; the rest of the path is untouched.
            with open(f'{args.result_path}/{validation_to_test_name(name_best)}.pkl', 'wb') as f:
                pickle.dump(acc_matrices, f)
