import os
from argparse import ArgumentParser
import yaml


def _str2bool(value):
    """Convert a command-line token such as 'true'/'false'/'1'/'0' to a bool.

    ``type=bool`` cannot be used with argparse because ``bool("False")`` is
    ``True`` (any non-empty string is truthy). Raising ``ValueError`` lets
    argparse report a clean "invalid value" error for unrecognized tokens.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise ValueError(f'cannot interpret {value!r} as a boolean')


def parse_args(argv=None):
    """Build and parse the command-line arguments for training.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests / programmatic use). ``None``
            keeps the default behaviour of reading the real command line.

    Returns:
        argparse.Namespace with every option below. If ``--configs`` names an
        existing *file*, keys from that YAML file that are not already CLI
        options are added to the namespace via ``load_cfg``.
    """

    parser = ArgumentParser(
        description='Official PyTorch implementation of GradNCP'
    )

    parser.add_argument("--seed", type=int,
                        default=0, help='random seed')
    parser.add_argument("--rank", type=int,
                        default=0, help='Local rank for distributed learning')
    # _str2bool instead of type=bool: bool("False") would evaluate to True.
    parser.add_argument('--distributed', help='automatically change to True for GPUs > 1',
                        default=False, type=_str2bool)
    parser.add_argument('--resume_path', help='Path to the resume checkpoint',
                        default=None, type=str)
    parser.add_argument('--load_path', help='Path to the loading checkpoint',
                        default=None, type=str)
    parser.add_argument('--configs', help='Path to the loading configs',
                        default='./logs', type=str)
    parser.add_argument("--no_strict", help='Do not strictly load state_dicts',
                        action='store_true')
    parser.add_argument('--suffix', help='Suffix for the log dir',
                        default=None, type=str)
    parser.add_argument('--eval_step', help='Epoch steps to compute accuracy/error',
                        default=1000, type=int)
    parser.add_argument('--save_step', help='Epoch steps to save checkpoint',
                        default=50000, type=int)
    parser.add_argument('--print_step', help='Epoch steps to print/track training stat',
                        default=200, type=int)
    parser.add_argument('--tto', help='Test time optimization',
                        default=1, type=int)
    # NOTE(review): store_true with default=True means this is always True and
    # the flag has no effect; kept as-is to preserve current behaviour.
    parser.add_argument("--no_date", help='do not save the date',
                        action='store_true', default=True)
    parser.add_argument("--fname", help='filename for the log',
                        default=None, type=str)

    # Training configurations
    parser.add_argument('--incremental', help='Incremental training',
                        action='store_true')
    parser.add_argument('--ewc', help='EWC',
                        action='store_true')
    parser.add_argument('--replay', help='Replay',
                        action='store_true')
    parser.add_argument('--oml', help='OML',
                        action='store_true')
    parser.add_argument('--mask', help='Mask',
                        action='store_true')
    parser.add_argument('--prog', help='Progressive Expansion',
                        action='store_true')
    parser.add_argument('--expansion', help='Expansion',
                        action='store_true')
    parser.add_argument('--expansion_type', help='Expansion type',
                        default='diff', type=str)
    parser.add_argument('--expansion_layer', help='Expansion layer',
                        default=0, type=int)
    parser.add_argument('--order', help='Order of the model',
                        default='colwise', type=str)
    parser.add_argument('--inner_step', help='meta-learning inner-step',
                        default=4, type=int)
    parser.add_argument('--inner_iter', help='inner-step iteration',
                        default=1, type=int)
    parser.add_argument('--inner_lr', help='inner-step learning rate',
                        default=1e-2, type=float)
    parser.add_argument('--lr', help='meta-learning learning rate',
                        default=1e-5, type=float)
    parser.add_argument('--resolution', help='Resolution of the image',
                        default=178, type=int)
    parser.add_argument('--outer_steps', help='meta-learning outer-step',
                        default=500000, type=int)
    parser.add_argument('--max_test_task', help='Max number of task for inference',
                        default=100, type=int)
    parser.add_argument('--lam', type=float, default=1.)

    # Decoder configurations
    parser.add_argument('--w0', type=float, default=30.)

    # Sampling configuration
    parser.add_argument("--data_ratio", help='sampling ratio',
                        default=0.25, type=float)
    parser.add_argument("--sample_type", help='sampling method',
                        default='none', type=str)

    args = parser.parse_args(argv)
    # isfile rather than exists: the default './logs' is typically a directory,
    # and open()-ing a directory in load_cfg would raise IsADirectoryError.
    if args.configs is not None and os.path.isfile(args.configs):
        load_cfg(args)

    return args


def load_cfg(args):
    """Merge keys from the YAML file at ``args.configs`` into ``args``.

    Only keys that are NOT already attributes of ``args`` are added, so values
    coming from the command line (including argparse defaults) always take
    precedence over the config file.

    Args:
        args: argparse.Namespace whose ``configs`` attribute is a path to a
            YAML file.

    Returns:
        The same ``args`` object, mutated in place.

    Raises:
        TypeError: if the YAML document is not a mapping.
    """
    with open(args.configs, "rb") as f:
        # safe_load returns None for an empty document; treat that as no keys.
        cfg = yaml.safe_load(f) or {}

    if not isinstance(cfg, dict):
        raise TypeError(
            f'config file {args.configs!r} must contain a mapping, '
            f'got {type(cfg).__name__}'
        )

    # vars() returns the namespace's own __dict__, so setdefault adds only
    # keys that are missing and leaves existing (CLI) values untouched.
    namespace = vars(args)
    for key, value in cfg.items():
        namespace.setdefault(key, value)

    return args
