import argparse

def args_parser(argv=None) -> argparse.Namespace:
    """Build the command-line parser and return the parsed options.

    Args:
        argv: Optional sequence of argument strings to parse (without the
            program name). When ``None`` (the default), ``sys.argv[1:]`` is
            parsed, which is the behaviour of the original zero-argument
            call. Passing an explicit list lets callers and tests obtain a
            configuration without touching the process-wide ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with one attribute per option declared
        below; options not present in ``argv`` take their declared defaults.

    Raises:
        SystemExit: Raised by ``argparse`` when an argument is unknown or
            cannot be converted to its declared type, or after printing
            ``--help``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--train_class', type=int, default=0, help='class of training dataset')
    parser.add_argument('--eval_class', type=int, default=1, help='class of OOD evaluation dataset')

    # Learning rate warm-up and decay
    parser.add_argument('--lr', type=float, default=0.0, help='learning rate')
    parser.add_argument('--warmup_epochs', type=int, default=0, help='number of warm-up epochs')
    parser.add_argument('--decay_epochs', type=int, default=0, help='number of learning rate decay epochs at the end')
    parser.add_argument('--cold_lr_warmup', type=float, default=0.0, help='initial lr before warm-up')
    parser.add_argument('--cold_lr_decay', type=float, default=0.0, help='final lr after learning rate decay')

    # Model and dataset specification
    parser.add_argument('--model', type=str, default='vit', help='model name')
    parser.add_argument('--fine_tune_ver', type=int, default=2, help='Fine-tune 1 (classifier), 2 (NA), 3 (entire))')

    # Other arguments
    parser.add_argument('--device', type=str, default='cpu', help='device to train on (overwritten in main.py)')
    parser.add_argument('--batch_size', type=int, default=32, help="batch size")
    parser.add_argument('--epochs', type=int, default=1, help="total epochs of training")
    parser.add_argument('--base_project_name', type=str, default='', help="Base Project Name for Wandb")
    parser.add_argument('--run_name', type=str, default='', help="Base Run Name for Wandb")
    parser.add_argument('--run_name_for_wandb_API', type=str, default='', help="to be used to avoid stability issues in API (when loading model)")
    parser.add_argument('--eval_every', type=int, default=3, help="Evaluate metrics every k epochs")
    parser.add_argument('--dataset', type=str, default='CIFAR10', help="name of dataset")
    parser.add_argument('--eval_dataset', type=str, default='CIFAR10_Lp', help="name of evaluation (e.g. noised) dataset")
    parser.add_argument('--PySeed', type=int, default=0, help="choice of Pytorch random seed")
    parser.add_argument('--DataSeed', type=int, default=0, help="choice of Data generator random seed")
    parser.add_argument('--pretrain', type=int, default=0, help="load pretrained model if 1, not if 0")

    # Optimizer specification
    parser.add_argument('--opt', type=int, default=0, help="choice of optimizer")
    parser.add_argument('--eps', type=float, default=0.001, help='epsilon smoothing term for optimizer')
    parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for Adam/AdaGrad/sgd, first moment EMA')
    parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam, second moment EMA')
    parser.add_argument('--weight_decay', type=float, default=0.0, help='weight decay')

    # Options that are not used by this file itself but by other files.
    # They are declared here as well because this parser is shared (the
    # datasets code imports this file) and parse_args() exits with an error
    # on any flag it does not know.
    parser.add_argument('--lr_list', nargs='+', type=float, default=[1e-4, 1e-3, 1e-2], help='learning rate array')
    parser.add_argument('--eps_list', nargs='+', type=float, default=[1e-9, 1e-7, 1e-5], help='epsilon array')
    parser.add_argument('--beta_1_list', nargs='+', type=float, default=[0.8, 0.9], help='beta1 array for optimizer')
    parser.add_argument('--beta_2_list', nargs='+', type=float, default=[0.9, 0.999], help='beta2 array for optimizer')
    parser.add_argument('--warmup_epochs_list', nargs='+', type=int, default=[3], help='warm-up epochs array')
    # NOTE: the spelling "optimzer" is kept on purpose; the flag and the
    # resulting attribute name are part of this function's interface.
    parser.add_argument('--optimzer_choice_list', nargs='+', type=int, default=[0], help='optimizer choice array')
    parser.add_argument('--weight_decay_list', nargs='+', type=float, default=[0.0, 1e-1], help='weight decay array')
    parser.add_argument('--p_list', nargs='+', type=float, default=[0, 0.5, 1], help='p array for LR schedule')
    parser.add_argument('--num_ingredients_list', nargs='+', type=int, default=[2], help='num_ingredients array')
    parser.add_argument('--epochs_list', nargs='+', type=int, default=[1, 2], help='epochs array')
    parser.add_argument('--z_list', nargs='+', type=float, default=[1.2, 2.0, 2.5], help='z array for custom loss')
    parser.add_argument('--shuffletrue_list', nargs='+', type=int, default=[0, 1], help='shuffle true array')
    parser.add_argument('--batch_size_list', nargs='+', type=int, default=[2, 4], help='batch size array')
    parser.add_argument('--max_array', type=int, default=600, help='maximum array index')
    parser.add_argument('--slurm_id', type=int, default=0, help='Slurm array job ID')
    parser.add_argument('--reload', type=int, default=0, help='reload pickled wandb data for project')

    # argv=None makes argparse fall back to sys.argv[1:], so existing
    # zero-argument callers behave exactly as before.
    return parser.parse_args(argv)
