import os
import gc
import pickle
import logging
import argparse
import numpy as np
from utils.misc import *
from utils.dataset import get_graph_dataset
from utils.pipeline import get_pipeline
from distutils.util import strtobool

def _int_list(text):
    """Parse a comma-separated CLI value such as '10, 25' into a list of ints."""
    return [int(item) for item in text.replace(' ', '').split(',')]


def _float_list(text):
    """Parse a comma-separated CLI value such as '0.1,1' into a list of floats."""
    return [float(item) for item in text.replace(' ', '').split(',')]


if __name__ == '__main__':
    # Entry point: parse the experiment configuration, grid-search the hyperparameters on
    # the validation stream, then re-run the full stream `repeats` times with the best
    # combination and pickle the resulting performance matrices.

    # Explicit import: torch is used below to release cached CUDA memory between runs and
    # should not depend on being re-exported by `from utils.misc import *`.
    import torch

    parser = argparse.ArgumentParser(description='OCGL')
    parser.add_argument('--dataset', type=str, default='CoraFull', help='AmazonComputer, Reddit, Arxiv, CoraFull, RomanEmpire, Elliptic')
    parser.add_argument('--gpu', type=int, default=0, help='which GPU to use.')
    parser.add_argument('--seed', type=int, default=1, help='seed for exp')
    parser.add_argument('--epochs', type=_int_list, default=[1, 5], help='number of training passes')
    parser.add_argument('--lr', type=_float_list, default=[0.01, 0.001, 0.0001, 0.00001], help='learning rate')
    parser.add_argument('--gain', type=_float_list, default=[0.1, 1, 10], help='gain for weight initialization')
    parser.add_argument('--weight-decay', type=float, default=0, help='weight decay')
    parser.add_argument('--backbone', type=str, choices=['GCN', 'SGC', 'UGCN', 'GRNF', 'UMIXED'], default='GCN',
                        help='backbone model/feature extractor, [GCN, SGC, UGCN, GRNF, UMIXED]')
    parser.add_argument('--method', type=str, choices=['bare', 'lwf', 'agem', 'ewc', 'mas', 'twp', 'er', 'joint', 'slda', 'ssmer', 'ssmagem', 'pdgnn'], default='bare',
                        help='baseline continual learning method')
    parser.add_argument('--ratio_valid_test', type=_float_list, default=[0.2, 0.2], help='ratio of nodes used for valid and test')
    parser.add_argument('--n_cls_per_task', type=int, default=2, help='how many classes does each task  contain')
    parser.add_argument('--IL_stream', type=str, default='classIL', help='the type of incremental stream, [classIL, timeIL]')
    parser.add_argument('--n_time_tasks', type=int, default=10, help='how many tasks to create for the time-incremental stream')
    parser.add_argument('--n_nodes_per_batch', type=int, default=10, help='how many nodes does each streaming batch contain')
    parser.add_argument('--n_validation_tasks', type=int, default=4, help='how many tasks on which to validate the hyperparameters')
    parser.add_argument('--anytime_eval', type=strtobool, default=True, help='whether to perform anytime evaluation (at the end of each batch)')
    parser.add_argument('--anytime_eval_freq', type=int, default=1, help='how many batches to wait before performing anytime evaluation')
    # Backbone / method option dicts. All of them go through str2dict so that a value given
    # on the command line becomes a dict (GCN/SGC previously stayed a raw string).
    parser.add_argument('--GCN_args', type=str2dict, default={'h_dims': [256], 'dropout': 0.0, 'batch_norm': False})
    parser.add_argument('--SGC_args', type=str2dict, default={'h_dims': [256], 'k': 2})
    parser.add_argument('--UGCN_args', type=str2dict, default={'h_dims': [1024, 1024]})
    parser.add_argument('--GRNF_args', type=str2dict, default={'h_dims': [1024, 1024], 'order_2_prc': 0.8})
    parser.add_argument('--memory_budget', type=int, default=1000)
    parser.add_argument('--agem_args', type=str2dict, default={'memory_proportion': [1, 2, 3]})
    parser.add_argument('--er_args', type=str2dict, default={'memory_proportion': [1, 2, 3]})
    parser.add_argument('--pdgnn_args', type=str2dict, default={'memory_proportion': [1, 2, 3]})
    parser.add_argument('--ssmer_args', type=str2dict, default={'nei_budget': [(5, 5), (10, 10)], 'memory_proportion': [1, 2, 3]})
    parser.add_argument('--ssmagem_args', type=str2dict, default={'nei_budget': [(5, 5), (10, 10)], 'memory_proportion': [1, 2, 3]})
    parser.add_argument('--lwf_args', type=str2dict, default={'lambda_dist': [1.0, 10.0], 'T': [0.2, 2.0], 'save_every': [10, 100]})
    parser.add_argument('--twp_args', type=str2dict, default={'lambda_l': [100., 10000., 1000000.], 'lambda_t': [100., 10000., 1000000.], 'beta': [0.001, 0.01, 0.1]})
    parser.add_argument('--ewc_args', type=str2dict, default={'memory_strength': [1., 100., 10000., 1000000., 100000000., 10000000000.]})
    parser.add_argument('--mas_args', type=str2dict, default={'memory_strength': [1., 100., 10000., 1000000., 100000000., 10000000000.]})
    parser.add_argument('--slda_args', type=str2dict, default={'Na': [None]})
    parser.add_argument('--bare_args', type=str2dict, default={'Na': [None]})
    parser.add_argument('--joint_args', type=str2dict, default={'Na': [None]})
    parser.add_argument('--center_features', type=strtobool, default=True, help='whether to center online the features extracted by the feature extractor')
    parser.add_argument('--repeats', type=int, default=5, help='how many times to repeat the experiments for the mean and std')
    parser.add_argument('--batch_size', type=int, default=50000, help='batch size for testing')
    parser.add_argument('--sample_nbs', type=strtobool, default=True, help='whether to sample neighbors instead of using all')
    parser.add_argument('--n_nbs_sample', type=_int_list, default=[10, 10],
                        help='number of neighbors to sample per hop, use comma to separate the numbers when using the command line, e.g. 10,25 or 10, 25')
    parser.add_argument('--data_path', type=str, default='./data', help='the path to data')
    parser.add_argument('--result_path', type=str, default='./results', help='the path for saving results')
    parser.add_argument('--save_models', type=strtobool, default=False, help='whether to save models')

    args = parser.parse_args()

    set_seed(args)
    os.makedirs(f'{args.data_path}', exist_ok=True)
    for sub_dir in ('', '/logs', '/models', '/hyp_params', '/batch_perf', '/perf_matrices'):
        os.makedirs(f'{args.result_path}{sub_dir}', exist_ok=True)

    # Larger graphs use a smaller order-2 share for the GRNF feature extractor.
    if args.dataset in ['Reddit', 'Arxiv']:
        args.GRNF_args['order_2_prc'] = 0.4

    method_args = {'er': args.er_args, 'lwf': args.lwf_args, 'twp': args.twp_args, 'ewc': args.ewc_args, 'pdgnn': args.pdgnn_args, 'slda': args.slda_args,
                   'bare': args.bare_args, 'agem': args.agem_args, 'mas': args.mas_args, 'joint': args.joint_args, 'ssmer': args.ssmer_args, 'ssmagem': args.ssmagem_args}
    # UMIXED intentionally shares the GRNF option dict.
    backbone_args = {'GCN': args.GCN_args, 'SGC': args.SGC_args, 'UGCN': args.UGCN_args, 'GRNF': args.GRNF_args, 'UMIXED': args.GRNF_args}
    args.backbone_args = backbone_args[args.backbone]

    experiment_name = f'{format_args(args)}_{list(args.backbone_args.values())}'
    logging.basicConfig(
        filename=f'{args.result_path}/logs/{experiment_name}.log',
        level=logging.INFO,
        format='%(asctime)s - %(message)s'
    )

    graph_dataset = get_graph_dataset(args.dataset, ratio_valid_test=args.ratio_valid_test, args=args)

    # DEFINE THE HYPERPARAMETER SEARCH SPACE
    # Some methods override the user-provided epoch / lr grids.
    if args.method == 'joint':
        args.epochs = [25, 50, 100, 250, 500, 1000]
    if args.method == 'slda':
        args.epochs = [1]
        args.lr = [0]
    # The gain grid is only searched for these backbones.
    if args.backbone in ['UGCN', 'GRNF', 'UMIXED']:
        hyp_param_list = compose_hyper_params(method_args[args.method], args.lr, args.epochs, args.gain)
    else:
        hyp_param_list = compose_hyper_params(method_args[args.method], args.lr, args.epochs)

    # HYPERPARAMETER SEARCH
    # Each combination is scored by its validation AP averaged over `repeats` runs.
    logging.info(f'Validating {len(hyp_param_list)} hyperparameter combinations')
    AP_best = -np.inf
    best_hyp_params = None
    for i, hyp_params in enumerate(hyp_param_list):
        assign_hyp_param(args, hyp_params)
        name = f'val_{format_args(args)}_{list(args.backbone_args.values())}_{format_hyp_params(hyp_params)}'
        logging.info(f'hp {i+1}/{len(hyp_param_list)} - {name}')
        AP_hp = []
        for ite in range(args.repeats):
            pipeline = get_pipeline(args)(args, graph_dataset, valid=True)
            AP, _, _ = pipeline.run()
            AP_hp.append(AP)
            # Free the model and cached GPU memory before the next run.
            del pipeline
            gc.collect()
            torch.cuda.empty_cache()
        AP_hp = np.round(np.mean(AP_hp), 2)
        logging.info(f'AP: {AP_hp}\n')
        # NaN (zero repeats or a diverged run) compares False against everything; skipping it
        # explicitly keeps it from silently blocking every later candidate.
        if not np.isnan(AP_hp) and AP_hp > AP_best:
            AP_best = AP_hp
            best_hyp_params = hyp_params
        logging.info(f'best params are {best_hyp_params}, best AP is {AP_best}')
    if best_hyp_params is None:
        raise RuntimeError(
            'Hyperparameter search produced no valid result '
            '(empty search space, --repeats 0, or every candidate yielded a NaN AP)'
        )
    with open(f'{args.result_path}/hyp_params/{experiment_name}.pkl', 'wb') as f:
        pickle.dump(best_hyp_params, f)

    # TRAINING ON FULL STREAM WITH BEST HYPERPARAMETERS
    logging.info('------ Now observing full stream ------')
    assign_hyp_param(args, best_hyp_params)
    name = f'te_{format_args(args)}_{list(args.backbone_args.values())}_{format_hyp_params(best_hyp_params)}'
    perf_matrices = []
    for ite in range(args.repeats):
        logging.info(f'{name} {ite}')
        args.current_model_save_path = f'{args.result_path}/models/{name}_{ite}'
        pipeline = get_pipeline(args)(args, graph_dataset, valid=False)
        AP_test, AF_test, perf_matrix_test = pipeline.run()
        perf_matrices.append(perf_matrix_test)
        del pipeline
        gc.collect()
        torch.cuda.empty_cache()
    # The legacy 'val' -> 'te' substitution is applied to the file name only, so file names
    # match earlier runs while a result_path containing 'val' (e.g. './eval_results') is no
    # longer rewritten into a directory that was never created.
    perf_matrix_file = name.replace('val', 'te')
    with open(f'{args.result_path}/perf_matrices/{perf_matrix_file}.pkl', 'wb') as f:
        pickle.dump(perf_matrices, f)
