import argparse
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

from omegaconf import OmegaConf
import wandb

from trainer import T2V_DMD_BASELINE_Trainer,  \
    T2V_DMD_GRPO_Trainer, T2V_DMD_DPO_Trainer, \
    T2V_FLOW_DPO_Trainer, T2V_DANCE_GRPO_Trainer

def main(argv=None):
    """Parse command-line arguments, build the configured trainer, and train.

    The YAML config at ``--config_path`` is loaded with OmegaConf, augmented
    with the CLI options below, and used to select a trainer class by the
    config's ``trainer`` key.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, matching a normal CLI run.

    Raises:
        ValueError: If ``config.trainer`` does not name a known trainer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, required=True)
    parser.add_argument("--no_save", action="store_true")
    parser.add_argument("--no_visualize", action="store_true")
    parser.add_argument("--logdir", type=str, default="", help="Path to the directory to save logs")
    parser.add_argument("--wandb-save-dir", type=str, default="", help="Path to the directory to save wandb logs")
    parser.add_argument("--disable-wandb", action="store_true")
    # Previously missing: without this, every reference to `args` raised NameError.
    args = parser.parse_args(argv)

    config = OmegaConf.load(args.config_path)
    config.no_save = args.no_save
    config.no_visualize = args.no_visualize

    # splitext strips only the final extension, so "exp.v2.yaml" -> "exp.v2"
    # rather than being truncated to "exp".
    config.config_name = os.path.splitext(os.path.basename(args.config_path))[0]
    config.logdir = args.logdir
    # Honor --wandb-save-dir when given; otherwise fall back to <logdir>/wandb.
    # os.path.join avoids producing "/wandb" (the filesystem root) when
    # --logdir is left at its empty default.
    config.wandb_save_dir = args.wandb_save_dir or os.path.join(args.logdir, "wandb")
    config.disable_wandb = args.disable_wandb

    # Map of `trainer` config values to their trainer classes.
    trainer_classes = {
        "t2v_dmd_baseline": T2V_DMD_BASELINE_Trainer,
        "t2v_dmd_grpo": T2V_DMD_GRPO_Trainer,
        "t2v_dmd_dpo": T2V_DMD_DPO_Trainer,
        "t2v_flow_dpo": T2V_FLOW_DPO_Trainer,
        "t2v_dance_grpo": T2V_DANCE_GRPO_Trainer,
    }
    trainer_name = config.trainer
    try:
        trainer_cls = trainer_classes[trainer_name]
    except KeyError:
        # Fail with an explicit message instead of an UnboundLocalError later.
        raise ValueError(
            f"Unknown trainer {trainer_name!r}; expected one of: "
            f"{', '.join(sorted(trainer_classes))}"
        ) from None

    trainer = trainer_cls(config)
    trainer.train()

    wandb.finish()


if __name__ == "__main__":
    main()
