import argparse

def args_parser(argv=None):
    """Build the command-line parser for the ensembling experiments and parse it.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default), ``argparse`` falls back to ``sys.argv[1:]``, matching the
            previous behaviour. Passing an explicit list allows programmatic
            use (notebooks, tests) without touching ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed arguments. The ``*_list`` options hold
        candidate hyperparameter values; ``--max_array`` and ``--slurm_id``
        are the array-job settings that accompany them.
    """
    parser = argparse.ArgumentParser()

    # Wandb project / run selection
    parser.add_argument('--project_name', type=str, default='A_GLD23K_GLD23K_Lp_vit_25Epk', help='project name on wandb')
    parser.add_argument('--sort_metric', type=str, default='val_accuracy', help='metric to sort wandb runs (choose top n ingredients)')
    parser.add_argument('--greedy', type=int, default=0, help='0 (does standard) and 1 (greedy) adaptive ensembling')
    parser.add_argument('--epoch_idx', type=int, default=10, help='for the Soup VS Adap ENS file')

    # can change from slurm script
    parser.add_argument('--reload', type=int, default=0, help='reload pickled wandb data for project')

    # standard hyperparameters
    parser.add_argument('--p', type=float, default=1.0, help='controls LR schedule via (1+epoch)**p')
    parser.add_argument('--z', type=float, default=2.0, help='controls custom loss via abs(w - x) ** z')
    parser.add_argument('--lr', type=float, default=1.0, help='learning rate')
    parser.add_argument('--num_ingredients', type=int, default=5, help='number of ensemble ingredients')
    parser.add_argument('--dataset', type=str, default='CIFAR10', help="name of dataset")
    parser.add_argument('--eval_dataset', type=str, default='CIFAR10_Lp', help="name of evaluation (e.g. noised) dataset")
    parser.add_argument('--batch_size', type=int, default=1, help="ensemble batch size")
    parser.add_argument('--epochs', type=int, default=10, help="total epochs of ensembling")
    parser.add_argument('--base_project_name', type=str, default='', help='project name on wandb')
    parser.add_argument('--run_name', type=str, default='', help="Base Run Name for Wandb")
    # Integer flag restricted to 0/1 (argparse's type=bool would treat any non-empty string as True).
    parser.add_argument('--shuffletrue', type=int, choices=[0, 1], default=0, help='Enable shuffling of the dataset (1 for True, 0 for False)')

    # Learning-rate warm-up / decay (currently optional)
    parser.add_argument('--warmup_epochs', type=int, default=0, help='number of warm-up epochs')
    parser.add_argument('--decay_epochs', type=int, default=0, help='number of learning rate decay epochs at the end')
    parser.add_argument('--cold_lr_warmup', type=float, default=0.0, help='initial lr before warm-up')
    parser.add_argument('--cold_lr_decay', type=float, default=0.0, help='final lr after learning rate decay')

    # Optimizer specification
    parser.add_argument('--opt', type=int, default=0, help="choice of optimizer")
    parser.add_argument('--eps', type=float, default=0.001, help='epsilon smoothing term for optimizer')
    parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for Adam/AdaGrad/sgd, first moment EMA')
    parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam, second moment EMA')
    parser.add_argument('--weight_decay', type=float, default=0.0, help='weight decay')

    # Evaluation specification
    parser.add_argument('--eval_every', type=int, default=3, help="Evaluate metrics every k epochs. Overwritten in code")

    # Hyperparameter candidate lists (each accepts one or more values)
    parser.add_argument('--lr_list', nargs='+', type=float, default=[1e-4, 1e-3, 1e-2], help='learning rate array')
    parser.add_argument('--eps_list', nargs='+', type=float, default=[1e-9, 1e-7, 1e-5], help='epsilon array')
    parser.add_argument('--beta_1_list', nargs='+', type=float, default=[0.9], help='beta1 array for optimizer')
    parser.add_argument('--beta_2_list', nargs='+', type=float, default=[0.999], help='beta2 array for optimizer')
    parser.add_argument('--warmup_epochs_list', nargs='+', type=int, default=[3], help='warm-up epochs array')
    # The original flag name is misspelled; keep it (and its dest) for existing
    # scripts, and also accept the correctly spelled alias.
    parser.add_argument('--optimzer_choice_list', '--optimizer_choice_list', dest='optimzer_choice_list',
                        nargs='+', type=int, default=[0], help='optimizer choice array')
    parser.add_argument('--weight_decay_list', nargs='+', type=float, default=[0.0, 1e-1], help='weight decay array')
    parser.add_argument('--p_list', nargs='+', type=float, default=[0.0, 0.5, 1.0], help='p array for LR schedule')
    parser.add_argument('--num_ingredients_list', nargs='+', type=int, default=[8, 16], help='num_ingredients array')
    parser.add_argument('--epochs_list', nargs='+', type=int, default=[1, 2], help='epochs array')
    parser.add_argument('--z_list', nargs='+', type=float, default=[1.2, 2.0, 2.5], help='z array for custom loss')
    parser.add_argument('--shuffletrue_list', nargs='+', type=int, default=[0, 1], help='shuffle true array')
    parser.add_argument('--batch_size_list', nargs='+', type=int, default=[1, 2], help='batch size array')
    parser.add_argument('--max_array', type=int, default=300, help='maximum array index')
    parser.add_argument('--slurm_id', type=int, default=0, help='Slurm array job ID')

    # argv=None makes argparse read sys.argv[1:], preserving the original behaviour.
    args = parser.parse_args(argv)

    return args