import pdb
import os
import torch
import torch.nn as nn
import logging
import argparse
import csv
from dl.arguments import build_arguments
from dl.src.trainers.betavae import BetaVAE_Trainer
from dl.src.trainers.betatcvae import BetaTCVAE_Trainer
from dl.src.trainers.factorvae import FactorVAE_Trainer
from dl.src.trainers.clgvae import CLGVAE_Trainer
from dl.src.trainers.cmcs.cmcs_gt import CMCS_GT_Trainger
from dl.src.trainers.cmcs.cmcs_super import CMCS_Super_Trainger
from dl.src.trainers.cmcs.cmcs_semisuper import CMCS_SemiSuper_Trainer

# Force synchronous CUDA kernel launches so that CUDA errors are reported at
# the call that caused them (a debugging aid that costs throughput). The CUDA
# runtime reads this variable when it initialises, so it must be set BEFORE
# any torch.cuda call below (previously it was set after
# torch.cuda.is_available() had already queried the runtime).
# NOTE(review): if any module imported above initialises CUDA at import time,
# this is still too late — TODO confirm, or move it ahead of `import torch`.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def main():
    """Build the run arguments and execute the trainer selected by ``model_type``.

    The trainer class is looked up from ``args.model_type``, constructed with
    the full parsed ``args`` object, and its ``run()`` method is invoked.

    Raises:
        ValueError: if ``args.model_type`` does not name a supported trainer.
    """
    # set arguments
    args = build_arguments()

    # Supported ``model_type`` values mapped to their trainer classes.
    trainers = {
        'betavae': BetaVAE_Trainer,
        'betatcvae': BetaTCVAE_Trainer,
        'factorvae': FactorVAE_Trainer,
        'clgvae': CLGVAE_Trainer,
        'cmcs_gt': CMCS_GT_Trainger,
        'cmcs_super': CMCS_Super_Trainger,
        'cmcs_semisuper': CMCS_SemiSuper_Trainer,
    }

    # Fail with an actionable message instead of a bare KeyError when the
    # requested model type is not registered above.
    try:
        trainer_cls = trainers[args.model_type]
    except KeyError:
        raise ValueError(
            f"Unknown model_type {args.model_type!r}; "
            f"expected one of: {', '.join(sorted(trainers))}"
        ) from None

    # set and run trainer
    trainer = trainer_cls(args)
    trainer.run()

# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()