import argparse
import torch


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options and attach the torch device to use.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``, with one extra attribute,
        ``device``: ``torch.device("cuda:<cuda_ind>")`` when CUDA is
        available, otherwise ``torch.device("cpu")``.
    """
    parser = argparse.ArgumentParser(description='--')

    # dataloader related
    # NOTE: the two default paths below are machine-specific; override them
    # on the command line or edit them for your environment.
    parser.add_argument("--data_dir", type=str, default="../../HDD/dataset/")
    parser.add_argument("--save_dir", type=str, default="../../HDD2/raqvae/")
    parser.add_argument("--dataset", type=str, default="CelebA",
                        choices=['cifar10', 'CelebA', 'CelebA_128', 'ImageNet'])
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--batch_size_test", type=int, default=64)
    parser.add_argument("--num_workers", type=int, default=8)

    # model size
    parser.add_argument("--raq_type", type=str, default="mb", choices=['mb', 'dd'],
                        help="mb: model-based RAQ, dd: RAQ")
    parser.add_argument("--model_type", type=str, default="vqvae",
                        choices=['vqvae', 'vqvae2', 'vqgan', 'sqvae', 'simvq', 'rqvae'])
    parser.add_argument("--num_embeddings", type=int, default=256,
                        help="base vocabulary size; number of possible discrete states")
    parser.add_argument("--embedding_dim", type=int, default=64,
                        help="size of the vector of the embedding of each discrete token")
    parser.add_argument("--n_hid", type=int, default=64,
                        help="number of channels controlling the size of the model")

    # Training options
    parser.add_argument('--n_epochs', type=int, default=300, help='number of training epochs')
    # float, not int: argparse applies `type` to command-line strings, so
    # with type=int a value such as "--lr 1e-3" was rejected.
    parser.add_argument('--lr', type=float, default=5e-4, help='learning rate')
    parser.add_argument('--seed', type=int, default=0, help='training seed: we use 10, 42, 170, 682')
    parser.add_argument('--cuda_ind', type=int, default=0, help='index for cuda device')

    # Model-based options
    parser.add_argument('--cluster_target', type=int, default=512, help='Codebook clustering target')
    parser.add_argument('--max_iter', type=int, default=200, help='number of dkm iterations')
    # epsilon and temp are fractional as well, hence type=float.
    parser.add_argument('--epsilon', type=float, default=1e-8, help='epsilon for softmax function')
    parser.add_argument('--temp', type=float, default=1e-2, help='Softmax temperature of DKM')

    # Data-driven options
    parser.add_argument("--num_embeddings_min", type=int, default=32,
                        help="minimum vocabulary size; number of possible discrete states")
    parser.add_argument("--num_embeddings_max", type=int, default=2048,
                        help="maximum vocabulary size; number of possible discrete states")
    parser.add_argument("--num_embeddings_test", type=int, default=512,
                        help="Test vocabulary size; number of possible discrete states")

    # directory for FID
    parser.add_argument('--img_dir', type=str, default='imgs/')
    args = parser.parse_args(argv)

    # Use the requested GPU when CUDA is available, otherwise fall back to CPU.
    if torch.cuda.is_available():
        args.device = torch.device(f"cuda:{args.cuda_ind}")
    else:
        args.device = torch.device("cpu")

    print("Running on device:", args.device, "...")

    return args