import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import logging
import argparse

from disent.trainer.jcgel_betavae import CGEL_BetaVAE_Trainer
# Force synchronous CUDA kernel launches so device-side errors are reported at
# the failing call (a debugging aid that slows GPU execution). This is assigned
# before the first CUDA query below so it is in place before CUDA initializes.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# NOTE(review): `device` is not referenced anywhere else in this script; the
# trainer is handed the parsed args (which carry --device_idx / --no_cuda)
# instead — confirm whether this module-level device is still needed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Script-wide INFO-level logging for this entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def build_arguments(argv=None):
    """Define and parse the command-line options for a CGEL beta-VAE run.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which is how the script
            entry point calls this; passing a list enables programmatic use
            and testing.

    Returns:
        argparse.Namespace with one attribute per option defined below.
        ``--dataset``, ``--project_name`` and ``--entity`` are mandatory;
        every other option has a default.

    Raises:
        SystemExit: raised by argparse for invalid/missing arguments and
            after printing ``--help``.
    """
    parser = argparse.ArgumentParser()

    # SET DEVICE
    # Optional so the "cuda:0" default can actually apply; argparse never
    # uses the default of a required option.
    parser.add_argument(
        "--device_idx",
        type=str,
        default="cuda:0",
        required=False,
        help="set GPU index, e.g. cuda:0, cuda:1, ...",
    )
    parser.add_argument(
        "--no_cuda", action="store_true", help="Avoid using CUDA when available"
    )
    parser.add_argument(
        "--n_gpu",
        type=int,
        default=0,
        required=False,
        help="number of available gpu",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank",
    )

    # DATASETS
    parser.add_argument(
        "--dataset",
        type=str,
        choices=[
            "shapes3d",
            "mpi3d_toy",
        ],
        required=True,
        help="Choose Dataset",
    )

    # SET MODEL
    parser.add_argument(
        "--model_type",
        type=str,
        choices=[
            'cgebetavae'
        ],
        default='cgebetavae',
        required=False,
        help="choose vae type",
    )

    parser.add_argument(
        "--dense_dim",
        nargs="*",
        default=[256, 256],
        type=int,
        required=False,
        help="set CNN hidden FC layers",
    )

    parser.add_argument(
        "--latent_dim",
        type=int,
        default=10,
        required=False,
        help="set prior dimension z",
    )

    # OPTIMIZATION SCHEDULE
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=128,
        required=False,
        help="Set number of training mini-batch size",
    )
    parser.add_argument(
        "--per_gpu_train_batch_size",
        type=int,
        default=128,
        required=False,
        help="Set number of training mini-batch size for multi GPU training",
    )
    parser.add_argument(
        "--test_batch_size",
        type=int,
        default=128,
        required=False,
        help="Set number of evaluation mini-batch size",
    )
    parser.add_argument(
        "--num_epoch",
        type=int,
        default=60,
        required=False,
        help="Set number of epoch size",
    )
    # NOTE(review): how the trainer interprets the default of 0 (e.g. "no
    # step limit, use --num_epoch") is not visible here — confirm.
    parser.add_argument(
        "--max_steps",
        type=int,
        default=0,
        required=False,
        help="Set maximum number of training steps",
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=100000000000,
        required=False,
        help="Save model checkpoint iteration interval",
    )
    parser.add_argument(
        "--logging_steps",
        type=int,
        default=1000,
        required=False,
        help="Update tb_writer iteration interval",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        required=False,
        help="set seed",
    )

    parser.add_argument(
        "--lr_rate", default=1e-4, type=float, required=False, help="Set learning rate"
    )

    parser.add_argument(
        "--weight_decay",
        default=0.0,
        type=float,
        required=False,
        help="Set weight decay",
    )
    parser.add_argument(
        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
    )
    parser.add_argument(
        "--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps."
    )
    parser.add_argument(
        "--num_sampling",
        type=int,
        default=1,
        required=False,
        help="Set samples for reparameterization trick",
    )
    parser.add_argument(
        "--beta",
        type=float,
        default=1.0,
        required=False,
        help="Set hyper-parameter beta",
    )

    # MODEL TRAINING AND EVALUATION
    parser.add_argument("--do_train", action="store_true", help="Do training")
    parser.add_argument("--do_eval", action="store_true", help="Do evaluation")

    # NOTE(review): the options below have no help text and their meaning is
    # defined by the trainer, not this script (names suggest rotation/flip
    # group settings and a softmax temperature) — confirm and document.
    parser.add_argument("--c_rot",
                        type=int,
                        default=3,
                        required=False,)
    parser.add_argument("--g_rot",
                        type=int,
                        default=4,
                        required=False,)
    parser.add_argument("--n_flip",
                        type=int,
                        default=0,
                        required=False,)
    parser.add_argument("--temperature",
                        type=float,
                        default=0.01,
                        required=False,)
    parser.add_argument("--normalization", action="store_true")
    parser.add_argument("--soft", action="store_true")

    # DISENTANGLEMENT QUALITATIVE ANALYSIS
    parser.add_argument(
        "--num_disen_train",
        type=int,
        default=10,
        required=False,
        help="set number of disentanglement evaluation task",
    )
    parser.add_argument(
        "--num_disen_test",
        type=int,
        default=10,
        required=False,
        help="set number of disentanglement evaluation task",
    )
    parser.add_argument(
        "--batch_disen",
        type=int,
        default=100,
        required=False,
        help="set batch for Factor VAE disentanglement learning",
    )

    # qualitative analysis
    parser.add_argument(
        "--interval",
        type=int,
        default=10,
        required=False,
        help="Choose the interval for latent vector values",
    )
    parser.add_argument(
        "--quali_sampling",
        type=int,
        default=10,
        required=False,
        help="Set hyper-parameter for samplings on TC-Beta-VAE",
    )

    # SET WANDB
    parser.add_argument(
        "--project_name",
        type=str,
        required=True,
        help="set project name for Weights & Biases writer",
    )
    parser.add_argument(
        "--entity",
        type=str,
        required=True,
        help="set entity (user or team) for Weights & Biases writer",
    )

    # argv=None falls back to sys.argv[1:], preserving the script behavior.
    return parser.parse_args(argv)


def main():
    """Script entry point: parse the CLI options and run the trainer.

    The parsed namespace from ``build_arguments()`` is passed to
    ``CGEL_BetaVAE_Trainer``, whose ``run()`` method is then invoked.
    Returns ``None``.
    """
    parsed_args = build_arguments()
    CGEL_BetaVAE_Trainer(parsed_args).run()


if __name__ == "__main__":
    main()


