import os
from argparse import ArgumentParser
import yaml


def _str2bool(value):
    """Convert a command-line token such as ``"False"`` or ``"1"`` to a bool.

    ``argparse`` with ``type=bool`` applies ``bool()`` to the raw string, which is
    ``True`` for every non-empty string (including ``"False"``), so boolean options
    that take a value need an explicit converter.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Build and parse the command-line arguments for training.

    If ``--configs`` names an existing file, ``load_cfg`` is called afterwards.
    Its values are written onto the parsed namespace, so they override both the
    defaults below and anything passed on the command line. If ``--configs`` is
    given but the file does not exist, a warning is issued and only the
    command-line values are used.

    Returns:
        argparse.Namespace: the parsed (and possibly config-overridden) arguments.
    """
    parser = ArgumentParser(
        description='PyTorch implementation'
    )

    parser.add_argument("--subsampling", action="store_true")
    parser.add_argument("--subset", action="store_true")
    parser.add_argument("--use_sampler", action="store_true")

    parser.add_argument("--full_batch", type=int, default=200)
    parser.add_argument("--first_order", action="store_true")
    parser.add_argument("--selfsup", action="store_true")
    parser.add_argument('--pr_max', action="store_true")

    parser.add_argument('--param_attack', action="store_true")
    parser.add_argument('--ablation', action="store_true")
    parser.add_argument('--class_attack', help='class_attack', action='store_true')
    parser.add_argument('--img_aug_only', help='img_aug_only', action='store_true')
    parser.add_argument('--no_aug', help='inner_no_aug', action='store_true')
    parser.add_argument('--inner_ablation', help='inner_adapt', action='store_true')

    parser.add_argument('--observe', type=str, default=None, help='observe type')

    parser.add_argument('--no_wandb', action="store_true")
    parser.add_argument('--entity', type=str, default='wandb', help='your wandb entity')
    parser.add_argument('--wandb_project_name', type=str, default='marvl')
    parser.add_argument('--folder_name', type=str, default='marvl')
    parser.add_argument('--pr_type', type=str, default='mse')

    parser.add_argument('--dataset', help='Dataset',
                        type=str)
    parser.add_argument('--configs', help='YAML config file (its keys override CLI values)',
                        type=str)
    parser.add_argument('--mode', help='Training mode',
                        default='maml', type=str)
    parser.add_argument("--seed", type=int,
                        default=0, help='random seed')
    parser.add_argument("--rank", type=int,
                        default=0, help='Local rank for distributed learning')
    # _str2bool instead of type=bool: bool("False") is True.
    parser.add_argument('--distributed', help='automatically change to True for GPUs > 1',
                        default=False, type=_str2bool)
    parser.add_argument('--resume_path', help='Path to the resume checkpoint',
                        default=None, type=str)
    parser.add_argument('--load_path', help='Path to the loading checkpoint',
                        default=None, type=str)
    parser.add_argument("--no_strict", help='Do not strictly load state_dicts',
                        action='store_true')
    parser.add_argument('--suffix', help='Suffix for the log dir',
                        default=None, type=str)
    parser.add_argument('--limit_train_batches', type=int, default=200)
    parser.add_argument('--limit_val_batches', type=int, default=100)
    parser.add_argument('--eval_step', help='Epoch steps to compute accuracy/error',
                        default=2500, type=int)
    parser.add_argument('--save_step', help='Epoch steps to save checkpoint',
                        default=2500, type=int)
    parser.add_argument('--print_step', help='Epoch steps to print/track training stat',
                        default=100, type=int)
    parser.add_argument("--regression", help='Use MSE loss (automatically turns to true for regression tasks)',
                        action='store_true')
    parser.add_argument("--baseline", help='do not save the date',
                        action='store_true')
    parser.add_argument("--ema", help='apply ema',
                        action='store_true')
    parser.add_argument('--eta', help='weight momentum',
                        default=0.995, type=float)
    parser.add_argument('--barlow', help='apply barlow loss', action='store_true')
    parser.add_argument('--adml', help='apply adml', action='store_true')
    parser.add_argument('--adv', help='robust meta-training', action='store_true')

    # --- Training configurations ---
    parser.add_argument('--inner_steps', help='meta-learning inner-step',
                        default=1, type=int)
    parser.add_argument('--inner_steps_test', help='meta-learning inner-step at test time',
                        default=1, type=int)
    parser.add_argument('--outer_steps', help='meta-learning outer-step',
                        default=60000, type=int)
    parser.add_argument('--max_epochs', help='max_epochs for orig setting', default=4000, type=int)
    parser.add_argument('--lr', type=float, default=1.e-3, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--inner_lr', type=float, default=1.e-2, metavar='LR',
                        help='learning rate of inner gradients')
    parser.add_argument('--batch_size', help='Batch size',
                        default=32, type=int)
    parser.add_argument('--test_batch_size', help='Batch size for test loader',
                        default=32, type=int)
    parser.add_argument('--max_test_task', help='Max number of task for inference',
                        default=400, type=int)

    # --- Meta-learning configurations ---
    parser.add_argument('--num_ways', help='N ways',
                        default=5, type=int)
    parser.add_argument('--num_shots', help='K (support) shot',
                        default=5, type=int)
    parser.add_argument('--num_shots_test', help='query shot',
                        default=15, type=int)
    parser.add_argument('--num_shots_global', help='global (or distill) shot',
                        default=0, type=int)
    parser.add_argument('--train_num_ways', help='N ways for training',
                        default=20, type=int)

    # --- Classifier configurations ---
    parser.add_argument('--model', help='model type',
                        type=str, default='conv4')

    parser.add_argument('--dynamic_attack', help='attack with sprt params', action='store_true')
    parser.add_argument('--r2d2', help='attack_like_aqr2d2', action='store_true')
    parser.add_argument('--trades_only', help='trades_only', action='store_true')

    parser.add_argument('--qry_attack', help='query_image_attack', action='store_true')
    parser.add_argument('--loss_type', help='outer_loss adv loss type', default='cec1cec2kl1kl2_cosadvadv', type=str)
    parser.add_argument('--sprt_attack', help='support_image_attack', action='store_true')
    parser.add_argument('--inner_update_type', help='inner_update_type', type=str, default='encoder_only')

    # --- Self-supervision configurations ---
    parser.add_argument("--selfsup_w", help='weight on selfsup loss', type=float, default=1.0)

    parser.add_argument("--aug_type", type=str, default='selfsup')
    # color jitter
    parser.add_argument("--brightness", type=float, default=0.8)
    parser.add_argument("--contrast", type=float, default=0.8)
    parser.add_argument("--saturation", type=float, default=0.8)
    parser.add_argument("--hue", type=float, default=0.2)
    parser.add_argument("--color_jitter_prob", type=float, default=0.8)

    # other augmentation probabilities
    parser.add_argument("--gray_scale_prob", type=float, default=0.4)
    parser.add_argument("--horizontal_flip_prob", type=float, default=0.5)
    parser.add_argument("--gaussian_prob", type=float, default=0.0)
    parser.add_argument("--solarization_prob", type=float, default=0.0)

    # cropping
    parser.add_argument("--crop_size", type=int, default=32)
    parser.add_argument("--min_scale", type=float, default=0.08)
    parser.add_argument("--max_scale", type=float, default=1.0)

    # --- Attack configuration ---
    parser.add_argument("--attack_img_num", type=int, default=2)

    parser.add_argument("--auto_attack", action="store_true")
    parser.add_argument("--min_val", type=float, default=0.0, help="min for clipping image")
    parser.add_argument("--max_val", type=float, default=1.0, help="max for clipping image")
    parser.add_argument("--attack_type", type=str, default="linf", help="adversarial l_p")
    parser.add_argument("--epsilon", type=float, default=8.0/255.0, help="maximum perturbation of adversaries (8/255 for cifar-10)")
    parser.add_argument("--alpha", type=float, default=2.0/255.0, help="movement multiplier per iteration when generating adversarial examples (2/255=0.00784)")
    parser.add_argument("--advw", type=float, default=6.0, help="weight on TRADES loss")
    parser.add_argument("--max_iters", type=int, default=7, help="maximum iteration when generating adversarial examples")
    # _str2bool instead of type=bool so that "--random_start False" actually disables it.
    parser.add_argument("--random_start", type=_str2bool, default=True, help="True for PGD")
    parser.add_argument("--attack_loss_type", type=str, default="sim", help="loss type for Rep attack")

    args = parser.parse_args()
    if args.configs is not None:
        if os.path.exists(args.configs):
            load_cfg(args)
        else:
            # Previously ignored silently; surface it so a typo in the path is noticed.
            import warnings
            warnings.warn(
                f"config file {args.configs!r} not found; using command-line values only"
            )

    return args


def load_cfg(args):
    """Overlay the YAML file at ``args.configs`` onto ``args`` in place.

    Every top-level key in the file becomes (or overwrites) an attribute of
    ``args``. Values from the config file therefore take precedence over both
    argparse defaults and values given on the command line.

    Args:
        args: the parsed argument namespace; ``args.configs`` must be a path to
            a readable YAML file.

    Returns:
        The same ``args`` object, mutated in place.

    Raises:
        TypeError: if the YAML document's top level is not a mapping.
    """
    with open(args.configs, "rb") as f:
        cfg = yaml.safe_load(f)

    # An empty YAML document parses to None; treat it as "no overrides"
    # instead of crashing on None.items().
    if cfg is None:
        return args
    if not isinstance(cfg, dict):
        raise TypeError(
            f"config file {args.configs!r} must contain a mapping at the top level, "
            f"got {type(cfg).__name__}"
        )

    vars(args).update(cfg)
    return args
