import argparse


def build_parser():
    """Build the command-line parser for SVRG/SGD training runs.

    Returns:
        argparse.ArgumentParser: parser with every training option registered.
            Defaults are applied when an option is omitted.
    """
    parser = argparse.ArgumentParser(description="Train SVRG/SGD on MNIST data.")
    # only sgd and svrg are supported
    parser.add_argument('--optimizer', type=str, default="SGD", choices=["SGD", "SVRG"],
                        help="optimizer.")
    parser.add_argument('--nn_model', type=str, default="MNIST_one_layer",
                        help="neural network model.")
    parser.add_argument('--dataset', type=str, default="MNIST",
                        help="name of the dataset.")
    parser.add_argument('--n_epoch', type=int, default=1000,
                        help="number of training epochs.")
    parser.add_argument('--lr', type=float, default=0.001,
                        help="learning rate.")
    parser.add_argument('--batch_size', type=int, default=64,
                        help="batch size.")
    parser.add_argument('--weight_decay', type=float, default=0.0,
                        help="regularization strength.")
    parser.add_argument('--exp_name', type=str, default="",
                        help="name of the experiment.")
    parser.add_argument('--print_every', type=int, default=1,
                        help="how often to print the loss.")
    parser.add_argument('--ratio', type=float, default=1.0,
                        help="how much of the data to use.")
    parser.add_argument('--temperature', type=float, default=0.5,
                        help="temperature for softmax.")
    parser.add_argument('--loss_type', type=str, default="cross_entropy",
                        help="loss function.")
    parser.add_argument('--output_dir', type=str, default="outputs",
                        help="output directory.")
    parser.add_argument('--device', type=str, default="cpu",
                        help="device to run training on.")
    # store_true flags already default to False; the explicit default is kept for clarity.
    parser.add_argument('--log', action='store_true',
                        default=False,
                        help="whether log the results.")
    parser.add_argument('--wandb', action='store_true',
                        default=False,
                        help="whether to use wandb for logging.")
    # NOTE(review): default project name mentions cifar while the description and
    # --dataset default say MNIST; confirm this is intended.
    parser.add_argument('--wandb_project', type=str, default="svrg_sgd_cifar",
                        help="wandb project name.")
    parser.add_argument('--wandb_entity', type=str, default=None,
                        help="wandb entity (team) name. If None, it will use the default entity.")

    # TODO: add more arguments
    parser.add_argument('--min_lr', type=float, default=0.0001,
                        help="minimum learning rate.")
    return parser


def get_args(argv=None):
    """Parse the training command-line options.

    Args:
        argv: optional list of argument strings to parse. When None (the
            default), argparse reads ``sys.argv[1:]``, matching the previous
            zero-argument behavior.

    Returns:
        argparse.Namespace: parsed options, one attribute per ``--option``.

    Raises:
        SystemExit: raised by argparse on invalid arguments (e.g. an
            ``--optimizer`` value outside ``{"SGD", "SVRG"}``) or on ``--help``.
    """
    return build_parser().parse_args(argv)