
import argparse
from  dl.src.constants import PRIOR_LIST

def _add_device_arguments(parser):
    """Register device / distributed-training options on *parser*."""
    parser.add_argument(
        "--device_idx",
        type=str,
        default="cuda:0",
        # Was ``required=True`` together with a default, which made the
        # default unreachable; the option is optional and falls back to it.
        required=False,
        help="set GPU index, e.g. cuda:0, cuda:1, ... (default: cuda:0)",
    )
    parser.add_argument(
        "--no_cuda", action="store_true", help="Avoid using CUDA when available"
    )
    parser.add_argument(
        "--n_gpu",
        type=int,
        default=0,
        required=False,
        help="number of available gpu",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank",
    )


def _add_dataset_arguments(parser):
    """Register the dataset selection option on *parser*."""
    parser.add_argument(
        "--dataset",
        type=str,
        choices=[
            "dsprites",
            "shapes3d",
            "car",
            "smallnorb",
            "celeba",
            "cdsprites",
            "mpi3d_toy",
            "mpi3d_real",
            "mpi3d_complex",
            "mmnist",
        ],
        required=True,
        help="Choose Dataset",
    )


def _add_model_arguments(parser):
    """Register model-type and architecture-size options on *parser*."""
    parser.add_argument(
        "--model_type",
        type=str,
        choices=[
            "betavae",
            "factorvae",
            "betatcvae",
            "clgvae",
            "cmcs_gt",
            "cmcs_super",
            "cmcs_semisuper",
            "cmcs_unsuper",
        ],
        required=True,
        help="choose vae type",
    )
    parser.add_argument(
        "--dense_dim",
        nargs="*",
        default=[256, 128],
        type=int,
        required=False,
        help="set CNN hidden FC layers",
    )
    parser.add_argument(
        "--latent_dim",
        type=int,
        default=10,
        required=False,
        help="set prior dimension z",
    )


def _add_training_arguments(parser):
    """Register batch-size, schedule, optimizer and logging options on *parser*."""
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=128,
        required=False,
        help="Set number of training mini-batch size",
    )
    parser.add_argument(
        "--per_gpu_train_batch_size",
        type=int,
        default=128,
        required=False,
        help="Set number of training mini-batch size for multi GPU training",
    )
    parser.add_argument(
        "--test_batch_size",
        type=int,
        default=128,
        required=False,
        help="Set number of evaluation mini-batch size",
    )
    parser.add_argument(
        "--num_epoch",
        type=int,
        default=60,
        required=False,
        help="Set number of epoch size",
    )
    parser.add_argument(
        "--max_steps",
        type=int,
        default=0,
        required=False,
        help="Set maximum number of training steps",
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=100000000000,
        required=False,
        help="Save model checkpoint iteration interval",
    )
    parser.add_argument(
        "--logging_steps",
        type=int,
        default=1000,
        required=False,
        help="Update tb_writer iteration interval",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        required=False,
        help="set seed",
    )
    parser.add_argument(
        "--optimizer",
        choices=["sgd", "adam"],
        default="adam",
        type=str,
        help="Choose optimizer",
        required=False,
    )
    parser.add_argument(
        "--scheduler",
        choices=["const", "linear"],
        default="const",
        type=str,
        help="Whether using scheduler during training or not",
        required=False,
    )
    parser.add_argument(
        "--lr_rate", default=1e-4, type=float, required=False, help="Set learning rate"
    )
    parser.add_argument(
        "--weight_decay",
        default=0.0,
        type=float,
        required=False,
        help="Set weight decay",
    )
    parser.add_argument(
        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
    )
    parser.add_argument(
        "--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps."
    )
    parser.add_argument(
        "--num_sampling",
        type=int,
        default=1,
        required=False,
        help="Set samples for reparameterization trick",
    )


def _add_hyperparameter_arguments(parser):
    """Register model-specific hyper-parameters (all optional, default None) on *parser*."""
    # Generic loss weights shared by several VAE variants.
    parser.add_argument(
        "--alpha",
        type=float,
        required=False,
        help="Set hyper-parameter alpha",
    )
    parser.add_argument(
        "--beta",
        type=float,
        required=False,
        help="Set hyper-parameter beta",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        required=False,
        help="Set hyper-parameter gamma",
    )
    parser.add_argument(
        "--lamb",
        type=float,
        required=False,
        help="Set hyper-parameter lambda",
    )
    parser.add_argument(
        "--epsilon",
        type=float,
        required=False,
        help="Set hyper-parameter epsilon",
    )

    # Factor VAE
    parser.add_argument(
        "--lr_rate_disc",
        type=float,
        # default=1e-4,
        required=False,
        help="Set discriminator learning rate",
    )

    # CLG VAE
    parser.add_argument(
        "--hy_hes",
        type=float,
        # default=40.0,
        required=False,
        help="Set hyper-parameter for commutative-VAE",
    )
    parser.add_argument(
        "--hy_rec",
        type=float,
        # default=0.1,
        required=False,
        help="Set hyper-parameter for commutative-VAE",
    )
    parser.add_argument(
        "--hy_commute",
        type=float,
        # default=20.0,
        required=False,
        help="Set hyper-parameter for commutative-VAE",
    )
    parser.add_argument(
        "--forward_eq_prob",
        type=float,
        # default=0.2,
        required=False,
        help="Set hyper-parameter for commutative-VAE",
    )
    parser.add_argument(
        "--subgroup_sizes_ls",
        nargs="*",
        # default=[100],
        type=int,
        required=False,
        help="Set hyper-parameter for commutative-VAE",
    )
    parser.add_argument(
        "--subspace_sizes_ls",
        nargs="*",
        # default=[10],
        type=int,
        required=False,
        help="Set hyper-parameter for commutative-VAE",
    )
    parser.add_argument(
        "--no_exp",
        action="store_true",
        help="Set flag for commutative-VAE",
    )

    # CMCS
    parser.add_argument(
        "--nth_root",
        type=int,
        # default=64,
        required=False,
        help="Set hyper-parameter nth_root for CMCS",
    )
    parser.add_argument(
        "--prior_list",
        nargs="*",
        # default=[3, 6, 40, 32, 32, 10, 10, 10, 10, 10],
        type=int,
        required=False,
        help="Set prior list for CMCS (ignored for cmcs_unsuper, "
        "which uses the dataset's predefined list)",
    )


def _add_run_mode_arguments(parser):
    """Register the train / eval / analysis switches on *parser*."""
    parser.add_argument("--do_train", action="store_true", help="Do training")
    parser.add_argument("--do_eval", action="store_true", help="Do evaluation")
    parser.add_argument("--do_analysis", action="store_true", help="Do analysis")
    parser.add_argument("--evaluate_during_training", action="store_true")


def _add_analysis_arguments(parser):
    """Register disentanglement-evaluation and qualitative-analysis options on *parser*."""
    # DISENTANGLEMENT QUALITATIVE ANALYSIS
    parser.add_argument(
        "--num_disen_train",
        type=int,
        default=10,
        required=False,
        help="set number of disentanglement evaluation task",
    )
    parser.add_argument(
        "--num_disen_test",
        type=int,
        default=10,
        required=False,
        help="set number of disentanglement evaluation task",
    )
    parser.add_argument(
        "--batch_disen",
        type=int,
        default=100,
        required=False,
        help="set batch for Factor VAE disentanglement learning",
    )

    # qualitative analysis
    parser.add_argument(
        "--interval",
        type=int,
        default=10,
        required=False,
        help="Choose the interval for latent vector values",
    )
    parser.add_argument(
        "--quali_sampling",
        type=int,
        default=10,
        required=False,
        help="Set hyper-parameter for samplings on TC-Beta-VAE",
    )


def _add_wandb_arguments(parser):
    """Register the (required) Weights & Biases logging options on *parser*."""
    parser.add_argument(
        "--project_name",
        type=str,
        required=True,
        help="set project name for Weights & Biases writer",
    )
    parser.add_argument(
        "--entity",
        type=str,
        required=True,
        help="set entity for Weights & Biases writer",
    )


def build_arguments(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Build the command-line parser and parse the experiment configuration.

    Args:
        argv: Argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before this parameter existed.

    Returns:
        The parsed ``argparse.Namespace``. When ``--model_type`` is
        ``cmcs_unsuper``, ``prior_list`` is replaced by
        ``PRIOR_LIST[args.dataset]`` regardless of any ``--prior_list`` given.

    Raises:
        SystemExit: via argparse on invalid/missing arguments, or when
            ``cmcs_unsuper`` is requested for a dataset with no entry in
            ``PRIOR_LIST``.
    """
    parser = argparse.ArgumentParser()

    # Groups are registered in the original order so --help output is stable.
    _add_device_arguments(parser)
    _add_dataset_arguments(parser)
    _add_model_arguments(parser)
    _add_training_arguments(parser)
    _add_hyperparameter_arguments(parser)
    _add_run_mode_arguments(parser)
    _add_analysis_arguments(parser)
    _add_wandb_arguments(parser)

    args = parser.parse_args(argv)

    if args.model_type == "cmcs_unsuper":
        # Unsupervised CMCS takes its prior list from the project constants,
        # overriding any user-supplied --prior_list.
        # NOTE(review): assumes PRIOR_LIST is a mapping keyed by dataset name.
        try:
            args.prior_list = PRIOR_LIST[args.dataset]
        except KeyError:
            # Report a clean CLI error instead of a raw KeyError traceback.
            parser.error(
                f"no predefined prior list for dataset {args.dataset!r} "
                "(required by --model_type cmcs_unsuper)"
            )

    return args