import pdb
import os
import torch
import torch.nn as nn
import logging
import argparse
import csv
from cg.arguments import build_arguments
from cg.src.trainers.betavae import BetaVAE_Trainer
from cg.src.trainers.betatcvae import BetaTCVAE_Trainer
from cg.src.trainers.clgvae import CLGVAE_Trainer
from cg.src.trainers.cmcs.cmcs_gt import CMCS_GT_Trainger
from cg.src.trainers.cmcs.cmcs_super import CMCS_Super_Trainger
from cg.src.trainers.betavae_maganet import BetaMAGAVAE_Trainer
# from cg.src.trainers.cmcs.cmcs_unsuper import CMCS_UnSuper_Trainger

# Make CUDA kernel launches synchronous (useful for pinpointing the op that
# raised a CUDA error). Set this before any CUDA call below so it is in effect
# when CUDA initializes. setdefault lets a value exported in the shell
# (e.g. CUDA_LAUNCH_BLOCKING=0 for faster runs) take precedence.
# NOTE(review): torch and the trainer modules are imported above this line;
# if any of them initializes CUDA at import time, this assignment should move
# to the very top of the file to take effect.
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")

# Prefer the GPU when one is visible to torch, otherwise fall back to CPU.
# NOTE(review): `device` is not referenced in the visible part of this file.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Root logger at INFO so trainer progress messages are shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def main(args=None):
    """Build the trainer selected by ``args.model_type`` and run it.

    Args:
        args: Run configuration. When ``None`` (the default, used by the
            command-line entry point) it is obtained from
            ``build_arguments()`` (presumably an argparse namespace; confirm
            in ``cg.arguments``). It must expose a ``model_type`` attribute
            naming one of the keys of ``TRAINER`` below. The whole object is
            passed to the selected trainer's constructor.

    Raises:
        KeyError: if ``args.model_type`` is not a registered trainer. The
            valid choices are logged before the error propagates.
    """
    # set arguments; callers (tests, notebooks) may pass a pre-built config
    # instead of parsing the command line
    if args is None:
        args = build_arguments()

    # trainer list: model_type string -> trainer class
    TRAINER = {
        'betavae': BetaVAE_Trainer,
        'betatcvae': BetaTCVAE_Trainer,
        'clgvae': CLGVAE_Trainer,
        'betavae_maganet': BetaMAGAVAE_Trainer,
        'cmcs_gt': CMCS_GT_Trainger,
        'cmcs_super': CMCS_Super_Trainger,
        # disabled together with its commented-out import at the top of the file
        # 'cmcs_unsuper': CMCS_UnSuper_Trainger,
    }

    # Resolve the trainer class. On a typo, log the valid options before
    # re-raising the same KeyError so the failure is actionable.
    try:
        trainer_cls = TRAINER[args.model_type]
    except KeyError:
        logger.error(
            "Unknown model_type %r; expected one of: %s",
            args.model_type,
            ", ".join(sorted(TRAINER)),
        )
        raise

    # set and run trainer
    logger.info("Running %s (model_type=%r)", trainer_cls.__name__, args.model_type)
    trainer = trainer_cls(args)
    trainer.run()


# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()