""" Parser Basic Arguments for Training """
import warnings
import os
import shutil
import argparse

import kvikio
import kvikio.defaults

def _framework_parser(parser):
    """ Training type parser """
    parser.add_argument('--use_cache', '--use-cache', action='store_true',
                        help='enable host caching')
    parser.set_defaults(use_cache=False)
    parser.add_argument('--layer_wise_cache', '--layer-wise-cache', action='store_true',
                        help='enable layer-wise cache')
    parser.set_defaults(layer_wise_cache=False)
    parser.add_argument('--checkpointing_strategy', '--checkpointing-strategy', choices=['scattered', 'cpu', 'storage'], default='scattered',
                        help='select checkpointing strategy')
    parser.add_argument('--optimize_dataloader', '--optimize-dataloader', action='store_true',
                        help='optimize dataloader')
    parser.set_defaults(optimize_dataloader=False)
    parser.add_argument('--pre_partitioned', '--pre-partitioned', action='store_true')
    parser.set_defaults(pre_partitioned=False)
    parser.add_argument('--cache_portion', '--cache-portion', type=float, default=0.5,
                        help='gpu cache portion')
    parser.add_argument('--host_mem_size', '--host-mem-size', type=float, default=128.0,
                        help='host memory size (GB)')
    parser.add_argument('--storage_offload', '--storage-offload', action='store_true',
                        help='enable storage offloading')
    parser.set_defaults(storage_offload=False)
    
    # compression engine (act/grad)
    # we support ZFP
    parser.add_argument('--compression', action='store_true', help='enable compression')
    parser.set_defaults(compression=False)
    parser.add_argument('--compression-engine', '--compression_engine', type=str, default='zfp',
                        help='compression engine')

    parser.add_argument('--storage_path', '--storage-path', type=str,
                         default='/pci5_nvme/grinnder_storage',
                         help='storage path')
    parser.add_argument('--kvikio_n_threads', '--kvikio-n-threads', type=int, default=32,
                        help='the number of kvikio threads')
    parser.add_argument('--kvikio_compat', '--kvikio-compat', action='store_true',
                        help='enable kvikio compatibility')
    parser.set_defaults(kvikio_compat=True)
    parser.add_argument('--n_partitions', '--n-partitions', type=int, default=4,
                        help='the number of vertex split (it should be change to automatic method)')
    parser.add_argument('--gpus', type=int, default=[0], nargs='+',
                        help='gpus for training (list)')
    parser.add_argument('--backend', type=str, default='nccl',
                        help='distributed implementation backend')
    parser.add_argument('--master_addr', '--master-addr', type=str, default='127.0.0.1',
                        help='master address for torch.distributed')
    parser.add_argument('--master_port', '--master-port', type=str, default='7524',
                        help='master port for torch.distributed')
    # parser.add_argument('--socket_ifname', '--socket-ifname', type=str, default='lo',
    #                     help='gloo socket name... (e.g., ib0, eno1, lo)')
    return parser

def _train_parser(parser):
    """ Model/Train parser """
    parser.add_argument('--model', type=str, default='gcn', help='Model type')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout')
    parser.add_argument('--lr', type=float, default=0.01, help='lr')
    parser.add_argument('--n-epochs', '--n_epochs', type=int, default=1000, help='# epochs')
    parser.add_argument('--n-hidden', '--n_hidden', type=int, default=256, help='hidden size')
    # parser.add_argument('--n-hidden', '--n_hidden', type=int, default=64, help='hidden size')
    parser.add_argument('--n-conv-layers', '--n_conv_layers', type=int, default=3,
                        help='# conv layers')
    parser.add_argument('--n-linear-layers', '--n_linear_layers', type=int, default=1,
                        help='# lin. layers')
    parser.add_argument("--norm", choices=['layer', 'batch'], default='layer',
                        help="normalization method")
    parser.add_argument("--weight-decay", "--weight_decay", type=float, default=5e-4,
                        help="weight for L2 loss")
    return parser

def _eval_test_timer_parser(parser):
    """ Eval/Test or Time-calc parser """
    # eval related
    parser.add_argument('--eval', action='store_true',
                        help='enable evaluation')
    parser.set_defaults(eval=True)
    parser.add_argument('--no-eval', '--no_eval', action='store_false', dest='eval',
                        help='disable evaluation')
    # test related
    parser.add_argument('--extreme', action='store_true',
                        help='use DGL-based mini-batch testloader for test. (hyper-scale dataset)')
    return parser

def _logging_parser(parser):
    """ Logging-related parser """
    # debug related
    parser.add_argument('--debug', action='store_true', help='print debug msgs.')
    # print related
    parser.add_argument('--verbose', action='store_true', help='print verbose msgs.')
    # log related
    parser.add_argument('--log-every', '--log_every', type=int, default=10)
    # json related
    parser.add_argument('--save-json', '--save_json', action='store_true',
                        help='store log json to path')
    parser.add_argument('--json-path', '--json_path', type=str, default='./json_logs')
    parser.add_argument('--loguru-path', '--loguru_path', type=str, default='')
    # db related
    parser.add_argument('--send-db', '--send_db', action='store_true',
                        help='send log to db')
    parser.add_argument('--db-host', '--db_host', type=str,
                        default='my_host.host', help='db url')
    parser.add_argument('--db-name', '--db_name', type=str, default='grinnder', help='db name')
    parser.add_argument('--db-project', '--db_project', required=False, default='test', help='db pj name')
    parser.add_argument('--db-user', '--db_user', help='db username')
    parser.add_argument('--db-passwd', '--db_passwd', help='db password')
    return parser

def _seed_parser(parser):
    """ seedings """
    parser.add_argument("--fix-seed", "--fix_seed", action='store_true',
                        help="fix random seed")
    parser.add_argument("--seed", type=int, default=7524)
    return parser

def _dataset_partition_parser(parser):
    """ Dataset parser"""
    # dataset-related
    parser.add_argument("--dataset", type=str, default='reddit',
                        help="the input dataset")
    parser.add_argument('--dataset-path', '--dataset_path', default='/small_data/grinnder/',
                        type=str, help='dataset path (plz change this path to your own!)')
    parser.add_argument('--ckpt-path', '--ckpt_path', default='/small_data/grinnder_ckpt/',
                        type=str, help='ckpt path (plz change this path to your own!)')
    parser.add_argument('--use_preprocessed', '--use-preprocessed', default=True, action=argparse.BooleanOptionalAction,
                        help='use preprocessed dataset')
    parser.set_defaults(use_preprocessed=False)
    parser.add_argument("--n-feat", "--n_feat", type=int, default=0)
    parser.add_argument("--n-class", "--n_class", type=int, default=0)
    parser.add_argument("--n-train", "--n_train", type=int, default=0)
    parser.add_argument("--inductive", action='store_true', help="inductive learning setting")
    # partition-related
    parser.add_argument('--partition-path', '--partition_path', type=str,
                        default='./grinnder_partitions', help='partition path')
    parser.add_argument('--partition-obj', '--partition_obj', choices=['vol', 'cut'], default='vol',
                        help="partition objective function ('vol' or 'cut')")
    parser.add_argument('--partition-method', '--partition_method',
                        choices=['metis', 'random', 'spinner', 'grinnder'], default='metis',
                        help="the method for graph partition ('metis', 'spinner', 'grinnder', 'random')")
    parser.add_argument('--partition-refine', '--partition_refine', action='store_true',
                        help='refine partition using grinnder')
    return parser

# Reference: Megatron-LM github
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/arguments.py
def _stringify_args(title, args):
    """Print arguments."""
    argstr = f'\n------------------------ {title} ------------------------\n'
    str_list = []
    for arg in vars(args):
        dots = '.' * (48 - len(arg))
        str_list.append('  {} {} {}\n'.format(arg, dots, getattr(args, arg)))
    for arg in sorted(str_list, key=lambda x: x.lower()):
        argstr += arg
    argstr += f'-------------------- end of {title} ---------------------\n'
    return argstr

def _args_checker(args):
    """ Validate option combinations, apply kvikio settings, and derive fields.

    All consistency checks run before any side effect, so an invalid
    combination aborts before ``storage_path`` is wiped.

    Args:
        args: namespace produced by the GriNNder argument parser.

    Returns:
        The same namespace with ``cache_size`` (number of cached elements,
        0 when caching is off) and ``graph_name`` added.

    Raises:
        AssertionError: if an unsupported option combination is requested.
            Raised explicitly rather than via ``assert`` so the checks are
            not stripped under ``python -O``.
    """
    # ---- validation (nothing has been modified yet) ----
    if not args.kvikio_compat:
        raise AssertionError('compatibility mode is required')
    if args.layer_wise_cache and not args.storage_offload:
        raise AssertionError('layer-wise cache requires storage offloading')
    if args.checkpointing_strategy == 'storage' and not args.storage_offload:
        raise AssertionError('storage checkpointing strategy requires storage offloading')

    # ---- kvikio runtime settings ----
    kvikio.defaults.num_threads_reset(args.kvikio_n_threads)
    kvikio.defaults.compat_mode_reset(args.kvikio_compat)  # enable compatibility mode

    # ---- filesystem preparation ----
    # exist_ok=True avoids the check-then-create race of exists() + makedirs()
    os.makedirs(os.path.join(args.dataset_path, 'adj_t'), exist_ok=True)

    if args.storage_offload or args.optimize_dataloader:
        # start from an empty storage directory: any previous contents are deleted
        if os.path.exists(args.storage_path):
            shutil.rmtree(args.storage_path)
        os.makedirs(args.storage_path)

    # ---- derived fields ----
    if args.use_cache:
        # host_mem_size is in GB; 4 bytes per float element, 1 for input-only (layer-wise)
        args.cache_size = int((args.host_mem_size * 1024**3) * args.cache_portion / 4)
        args.cache_size = int(args.cache_size / 2)  # split between acts and grads
    else:
        args.cache_size = 0

    setting = 'induc' if args.inductive else 'trans'
    args.graph_name = f'{args.dataset}-{args.partition_method}-{args.partition_obj}-{setting}'
    return args

def create_parser(argv=None):
    """ Build the GriNNder argument parser, then parse, validate and summarize.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        ``(args, argstr)``: the namespace returned by ``_args_checker`` and
        the summary string built by ``_stringify_args``.
    """
    parser = argparse.ArgumentParser(description='GriNNder')

    # each helper registers one option group and returns the same parser
    for register in (_seed_parser, _framework_parser, _dataset_partition_parser,
                     _train_parser, _eval_test_timer_parser, _logging_parser):
        parser = register(parser)

    args = _args_checker(parser.parse_args(argv))
    argstr = _stringify_args('GriNNder Arguments', args)
    return args, argstr