import torch
from torch.utils.data import default_collate
import time
import argparse
import inspect
import os
from glob import glob
from copy import deepcopy
from tqdm import tqdm
import json
import wandb
from accelerate import Accelerator, DataLoaderConfiguration, DistributedDataParallelKwargs, DeepSpeedPlugin, utils

from models.sound_engine import GNSE_models
from utils.torch_utils import set_seeds
from utils.logging import create_logger
from data.extracted_gamegenx import ExtractedGameGenX_load_hdf5
from data.webdataset_ogamedata import build_dataloader

from modules.stable_audio_tools.models import create_model_from_config
from modules.stable_audio_tools.models.autoencoders import AudioAutoencoder
from modules.stable_audio_tools.models.utils import load_ckpt_state_dict
from modules.stable_audio_tools.utils.torch_common import copy_state_dict


from trainer.lr_scheduler import CosineLRScheduler, WarmupConstantLRScheduler
from utils.ema import update_ema
import torchaudio

from utils.evaluate_metrics import evaluate_generated

from utils.optim_utils import create_optimizer


# Global numerics / kernel-selection settings applied at import time.
# `enable_flash_sdp` is a setter *function*; it must be called. Assigning to it
# would replace the function with a bool and enable nothing.
torch.backends.cuda.enable_flash_sdp(True)
# Allow TF32 tensor-core math for matmuls and cuDNN (faster, slightly lower precision).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Favor throughput (autotuned, non-deterministic cuDNN kernels) over bitwise reproducibility.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
# `set_audio_backend` is deprecated in torchaudio 2.x and may be absent in newer
# releases; only call it when the running torchaudio still provides it.
if hasattr(torchaudio, "set_audio_backend"):
    torchaudio.set_audio_backend("sox_io")


def custom_collate(batch):
    """Collate a batch after discarding samples that failed to load.

    Every ``None`` entry is dropped and the remaining samples are handed to
    ``default_collate``.

    Raises:
        ValueError: if ``batch`` is empty or contains only ``None`` entries.
    """
    usable = list(filter(lambda item: item is not None, batch))
    if not usable:
        raise ValueError("No valid samples in batch")
    return default_collate(usable)


def build_ema_from_model(model, *, new_seq_len: int, device):
    """Create a frozen copy of ``model`` whose config uses a different ``seq_len``.

    The copy is built from a deep copy of ``model.config`` (so the source config is
    left untouched), moved to ``device``, initialized from ``model``'s weights, and
    has gradients disabled on all of its parameters.

    Weights are loaded with ``strict=False``: keys missing on either side are
    ignored, but tensors whose shapes differ still raise a RuntimeError.
    """
    ema_config = deepcopy(model.config)
    ema_config.seq_len = new_seq_len
    model_cls = type(model)
    ema_model = model_cls(ema_config)
    ema_model = ema_model.to(device)
    ema_model.load_state_dict(model.state_dict(), strict=False)
    return ema_model.requires_grad_(False)

def _unwrap_model_for_ema(model):
    """Return the bare module underneath (Distributed)DataParallel and torch.compile wrappers.

    Which wrappers are present depends on the run configuration (multi- vs
    single-process, ``--no-compile``), so they are peeled off in a loop until
    neither kind remains. A module with no wrappers is returned unchanged.
    """
    parallel_types = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)
    while True:
        if isinstance(model, parallel_types):
            model = model.module
        elif hasattr(model, "_orig_mod"):  # torch.compile's OptimizedModule
            model = model._orig_mod
        else:
            return model


def _resolve_resume_global_step(args, ema):
    """Return the global optimizer step of the checkpoint being resumed.

    When ``args.ema_ckpt_path`` is set, the step is read from that file (and its
    weights are loaded into ``ema`` if EMA is enabled). Otherwise the step is parsed
    from the ``accelerator.save_state`` directory name, which ``main`` writes as
    ``f"{train_steps:07d}.pt"``.

    Raises:
        ValueError: if EMA is enabled but no EMA checkpoint path was given, or the
            step cannot be recovered from the checkpoint path.
    """
    if args.ema_ckpt_path:
        # weights_only=False: the EMA checkpoint also pickles the argparse Namespace,
        # which weights-only loading rejects. Only load checkpoints you trust.
        ema_checkpoint = torch.load(args.ema_ckpt_path, map_location="cpu", weights_only=False)
        if ema is not None:
            ema.load_state_dict(ema_checkpoint["ema_model"])
        return int(ema_checkpoint["steps"])
    if ema is not None:
        raise ValueError("--ema-ckpt-path is required to resume training with --ema")
    checkpoint_name = os.path.basename(os.path.normpath(args.resume_from_checkpoint))
    try:
        return int(checkpoint_name.split(".")[0])
    except ValueError as err:
        raise ValueError(
            f"Cannot infer the training step from checkpoint path "
            f"{args.resume_from_checkpoint!r}; pass --ema-ckpt-path instead."
        ) from err


def main(args):
    """Train the GNSE prior on pre-extracted audio/video features with Accelerate.

    Sets up (optionally distributed) training, an experiment folder under
    ``args.results_dir``, W&B tracking, the webdataset dataloader, the GNSE model
    (plus an optional EMA copy), the frozen stage-1 autoencoder, the optimizer and
    LR scheduler, then runs the epoch loop with periodic logging and checkpointing.

    Args:
        args: parsed command-line namespace (see the argparse setup under ``__main__``).

    Raises:
        AssertionError: if no CUDA device is available.
        ValueError: if ``args.z_stats_path`` is missing, or the step of a resume
            checkpoint cannot be determined.
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
    # Setup DDP:
    amp_type = args.mixed_precision
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
    accelerator = Accelerator(
        mixed_precision=amp_type,
        kwargs_handlers=[ddp_kwargs],
        log_with="wandb",
    )
    # Tensor dtype used for the latent / video inputs ("fp8" and "no" fall back to fp32).
    dtype = torch.float16 if amp_type == "fp16" else (torch.bfloat16 if amp_type == "bf16" else torch.float32)

    device = accelerator.device
    global_rank = accelerator.process_index
    world_size = accelerator.num_processes
    # Distinct, reproducible seed per rank.
    seed = args.global_seed * world_size + global_rank
    set_seeds(seed)
    print(f'Device (rank {global_rank} / {world_size}): {device}')

    # Setup an experiment folder (main process only; paths are broadcast below):
    experiment_dir = None
    checkpoint_dir = None
    if accelerator.is_main_process:
        os.makedirs(args.results_dir, exist_ok=True)  # Make results folder (holds all experiment subfolders)

        experiment_index = len(glob(f"{args.results_dir}/*"))
        model_string_name = args.gnse_model.replace("/", "-")

        experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}"
        checkpoint_dir = f"{experiment_dir}/checkpoints"
        os.makedirs(checkpoint_dir, exist_ok=True)

    accelerator.wait_for_everyone()

    # Share the main process's folder paths with every rank.
    experiment_info = [experiment_dir, checkpoint_dir]
    utils.broadcast_object_list(experiment_info)
    experiment_dir, checkpoint_dir = experiment_info
    print("->->-> DDP Initialized.")
    print(f"->->-> World size (Number of GPUs): {world_size}")

    logger = create_logger(experiment_dir if accelerator.is_main_process else None, accelerator)
    if accelerator.is_main_process:
        logger.info(f"Experiment directory created at {experiment_dir}")

    # training args
    logger.info(f"{args}")

    # wandb
    model_string_name = args.gnse_model.replace("/", "-")
    timestamp = int(time.strftime("%Y%m%d%H%M%S", time.localtime()))
    accelerate_init_kwargs = {
        "wandb": {
            "name": f"{timestamp}-{model_string_name}",
            "dir": experiment_dir if experiment_dir else None,
        }
    }

    accelerator.init_trackers(
        project_name=args.wandb_project,
        config=vars(args),
        init_kwargs=accelerate_init_kwargs,
    )
    # training env
    logger.info(f"Starting rank={global_rank}, seed={seed}, world_size={world_size}.")
    # Setup data:
    batch_size = args.batch_per_gpu

    train_dataloader, grid_feature_length, n_samples = build_dataloader(
        csv_path=args.train_csv_path,
        shard_dir=args.shard_dir,
        shard_pattern=args.shard_pattern,
        batch_size=batch_size,
        num_workers=args.num_workers,
        seed=seed,
        audio_pretraining=args.audio_pretraining,
        shuffle_size=args.shuffle_size,
        drop_last=True,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=1,
        start_time=args.start_time,
        duration=args.duration,
        video_fps=args.video_fps,
        n_epochs=args.epochs,
        data=args.data,
        video_embed_dim=args.video_embed_dim,
    )
    logger.info(f"Training dataset contains ≈ {n_samples} samples")

    # Spatial downsampling in the vision aggregator shrinks the grid by rate**2.
    grid_feature_length = grid_feature_length // args.spatial_ds_rate**2

    # Setup model
    model = GNSE_models[args.gnse_model](
        seq_len=int(args.duration * args.video_fps),
        num_kv_heads=args.num_kv_heads,
        multiple_of=args.multiple_of,
        per_frame_len=args.per_frame_len,
        ffn_dim_multiplier=args.ffn_dim_multiplier,
        max_period=args.max_period,
        num_types=args.num_types,
        input_type=args.input_type,
        type_drop_p=args.type_dropout_p,
        token_dropout_p=args.token_dropout_p,
        attn_dropout_p=args.attn_dropout_p,
        resid_dropout_p=args.resid_dropout_p,
        ffn_dropout_p=args.ffn_dropout_p,
        drop_path_rate=args.drop_path_rate,
        audio_embed_dim=args.audio_embed_dim,
        video_embed_dim=args.video_embed_dim,
        diffusion_batch_mul=args.diffusion_batch_mul,
        P_mean=args.P_mean,
        P_std=args.P_std,
        sigma_data=args.sigma_data,
        label_drop_prob=args.label_drop_prob,
        label_balance=args.label_balance,
        trans_cfg_dropout_prob=args.trans_cfg_dropout_prob,
        audio_pretraining=args.audio_pretraining,
        vision_aggregation=args.vision_aggregation,
        num_aggregated_tokens=args.num_aggregated_tokens,
        grid_feature_length=grid_feature_length,
        d_aggregate=args.d_aggregate,
        num_heads_aggregate=args.num_heads_aggregate,
        num_layers_aggregate=args.num_layers_aggregate,
        aggregate_trans_architecture=args.aggregate_trans_architecture,
        aggregation_method=args.aggregation_method,
        norm_type=args.norm_type,
        norm_type_agg=args.norm_type_agg,
        naive_mar_mlp=args.naive_mar_mlp,
        condition_merge=args.condition_merge,
        head_dropout_p=args.head_dropout_p,
        diffusion_type=args.diffusion_type,
        no_subtract=args.no_subtract,
        spatial_ds_rate=args.spatial_ds_rate,
        disperse_loss=args.disperse_loss,
        disperse_lambda=args.disperse_lambda,
    ).to(device)

    logger.info(f"Prior model parameters: {sum(p.numel() for p in model.parameters()):,}")
    logger.info(f"Aggregator model parameters: {sum(p.numel() for p in model.aggregate_transformer.parameters()):,}")
    logger.info(f"diffusion head model parameters: {sum(p.numel() for p in model.diffloss.diffusion_head.parameters()):,}")
    ema = None
    if args.ema:
        # The EMA copy is built for the (possibly longer) evaluation duration.
        ema = build_ema_from_model(model, new_seq_len=int(args.eval_duration * args.video_fps), device=device)
        logger.info(f"EMA Parameters: {sum(p.numel() for p in ema.parameters()):,}")

    # create and load model
    # NOTE(review): stage1_model is loaded onto the device but never used in this
    # training loop; confirm it is still needed here.
    with open(args.stage1_model_config) as f:
        stage1_model_config = json.load(f)
    stage1_model: AudioAutoencoder = create_model_from_config(stage1_model_config)
    copy_state_dict(stage1_model, load_ckpt_state_dict(args.stage1_ckpt_path))
    stage1_model.to(device).eval().requires_grad_(False)

    # Setup optimizer
    optimizer = create_optimizer(model, args.weight_decay, args.lr, (args.beta1, args.beta2), logger, args.optimizer_type)

    if args.max_train_steps is not None and args.max_train_steps > 0:
        total_steps = args.max_train_steps
        logger.info(f"Training for a total of {total_steps} steps.")
    else:
        # NOTE(review): n_samples counts samples, not optimizer steps; verify whether
        # this should be divided by batch_size * world_size.
        train_updates_per_epoch = n_samples
        total_epochs = int(args.epochs)
        total_steps = train_updates_per_epoch * total_epochs
        logger.info(f"Training for {args.epochs} epochs, estimated {total_steps} steps.")

    if args.optimizer_type == "adamw":
        if args.lr_scheduler == "cosine":
            lr_scheduler = CosineLRScheduler(
                optimizer=optimizer,
                total_steps=total_steps,
                warmup_steps=args.warmup_steps,
                lr_min_ratio=args.lr_min_ratio,
                cycle_length=args.cycle_length,
            )
        elif args.lr_scheduler == "constant":
            lr_scheduler = WarmupConstantLRScheduler(
                optimizer=optimizer,
                total_steps=total_steps,
                warmup_steps=args.warmup_steps,
            )

    # Prepare models for training:
    if args.ema:
        update_ema(ema, model, decay=0)  # Ensure EMA is initialized with synced weights

    # Latent normalization statistics are mandatory; fail here rather than with an
    # undefined z_mean/z_std further down.
    if args.z_stats_path is None:
        raise ValueError("z_stats_path is required")
    stats = torch.load(args.z_stats_path, map_location="cpu")
    z_mean = stats["z_mean"]
    z_std = stats["z_std"]

    model.to(device)
    if not args.no_compile:
        logger.info("compiling the model... (may take several minutes)")
        model = torch.compile(model)  # requires PyTorch 2.0
    model.train()
    if args.ema:
        ema.eval()  # EMA model should always be in eval mode

    train_steps = 0
    start_epoch = 0

    (model, train_dataloader, optimizer, lr_scheduler) = accelerator.prepare(
        model, train_dataloader, optimizer, lr_scheduler)

    # Broadcastable (1, 1, audio_embed_dim)-style stats — assumes stats are 1-D per channel; TODO confirm.
    z_mean = z_mean.to(device).to(dtype).view(1, 1, -1)
    z_std = z_std.to(device).to(dtype).view(1, 1, -1)

    resume_step = None  # batches to skip in the first resumed epoch
    if args.resume_from_checkpoint:
        accelerator.load_state(args.resume_from_checkpoint)
        global_step = _resolve_resume_global_step(args, ema)
        # Optimizer steps per epoch across all ranks; max(1, ...) avoids a zero
        # division when the dataset is smaller than one global batch.
        steps_per_epoch = max(1, int(n_samples) // (batch_size * world_size))
        start_epoch = global_step // steps_per_epoch
        resume_step = global_step - start_epoch * steps_per_epoch
        # Keep the global step so wandb steps and checkpoint names keep increasing.
        train_steps = global_step
        logger.info(f"Resume training from checkpoint: {args.resume_from_checkpoint}")
        logger.info(f"Initial state: steps={train_steps}, epochs={start_epoch}")

    # Accelerate applies gradient scaling itself for fp16; no manual GradScaler needed.
    # Variables for monitoring/logging purposes:
    log_steps = 0
    running_logvar = 0
    running_total_loss = 0
    running_disperse_loss = 0
    running_weighted_dsm_loss = 0
    running_dsm_loss = 0
    start_time = time.time()

    optimizer.train()
    logger.info(f"Training for {args.epochs} epochs...")

    for epoch in range(start_epoch, args.epochs):
        logger.info(f"Beginning epoch {epoch}...")
        if epoch == start_epoch and resume_step:
            # Skip the batches already consumed in the resumed epoch.
            active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
        else:
            # After the first iteration though, we need to go back to the original dataloader
            active_dataloader = train_dataloader

        for batch in active_dataloader:
            z = batch["audio_feature"].to(device)  # (bs, seq_len, audio_embed_dim)
            # Standardize latents, then scale by 0.5 (presumably to match sigma_data=0.5 — confirm).
            z = (0.5 * ((z - z_mean) / z_std)).to(dtype)
            if not args.audio_pretraining:
                video = batch["video_feature"].to(device).to(dtype)
            else:
                video = None  # audio only pretraining model
            with accelerator.autocast():
                total_loss, weighted_dsm_loss, dsm_loss, logvar, disperse_loss = model(video=video, audio=z, target=z)

            # NOTE(review): args.max_grad_norm is parsed but no gradient clipping is
            # applied; consider accelerator.clip_grad_norm_ when sync_gradients.
            accelerator.backward(total_loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            if accelerator.is_main_process and args.ema:
                # EMA tracks the bare module underneath DDP / torch.compile wrappers,
                # whichever of them are present in this configuration.
                update_ema(ema, _unwrap_model_for_ema(model), decay=args.ema_rate)

            # Log loss values:
            running_logvar += logvar.item()
            running_weighted_dsm_loss += weighted_dsm_loss.item()
            running_dsm_loss += dsm_loss.item()
            running_total_loss += total_loss.item()
            running_disperse_loss += disperse_loss.item()

            log_steps += 1
            train_steps += 1

            if train_steps % args.log_every == 0:
                # Measure training speed:
                end_time = time.time()
                steps_per_sec = log_steps / (end_time - start_time)

                stats_gpu = torch.tensor(
                    [
                        running_logvar / log_steps,
                        running_weighted_dsm_loss / log_steps,
                        running_dsm_loss / log_steps,
                        running_total_loss / log_steps,
                        running_disperse_loss / log_steps,
                    ],
                    device=accelerator.device
                )
                # Collective op: every rank must reach this reduce.
                stats_all = accelerator.reduce(stats_gpu, reduction="mean")

                if accelerator.is_main_process:
                    avg_logvar, avg_weighted, avg_dsm, avg_total, avg_disperse = stats_all.cpu().tolist()
                    logger.info(
                        f"(step={train_steps:07d}) "
                        f"Train Loss: {avg_total:.4f}, "
                        f"Train Steps/Sec: {steps_per_sec:.2f}"
                    )
                    wandb.log({
                        "train_weighted_dsm_loss": avg_weighted,
                        "train_dsm_loss":        avg_dsm,
                        "train_logvar":          avg_logvar,
                        "train_total_loss":      avg_total,
                        "train_disperse_loss":   avg_disperse,
                        "train_steps_per_sec":   steps_per_sec,
                    }, step=train_steps)

                running_logvar = 0
                running_weighted_dsm_loss = 0
                running_dsm_loss = 0
                running_total_loss = 0
                running_disperse_loss = 0
                log_steps = 0
                start_time = time.time()

            # Save checkpoint:
            if train_steps % args.ckpt_every == 0 and train_steps > 0:
                optimizer.eval()
                accelerator.wait_for_everyone()
                checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
                accelerator.save_state(checkpoint_path)

                if accelerator.is_main_process:
                    logger.info(f"Saved checkpoint to {checkpoint_path}")
                    if args.ema:
                        checkpoint = {
                            "ema_model": ema.state_dict(),
                            "steps": train_steps,
                            "args": args,
                        }
                        ema_checkpoint_path = f"{checkpoint_dir}/ema_{train_steps:07d}.pt"
                        accelerator.save(checkpoint, ema_checkpoint_path)
                        logger.info(f"Saved ema checkpoint to {ema_checkpoint_path}")
                optimizer.train()
                accelerator.wait_for_everyone()

    # Only after all epochs: closing the trackers earlier would break later wandb.log calls.
    logger.info("Done!")
    accelerator.end_training()


if __name__ == "__main__":
    # Command-line entry point: parse training (and saved-with-checkpoint eval)
    # hyperparameters, then run main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--train-csv-path", type=str, required=True)
    parser.add_argument("--shard_dir", type=str, required=True)
    parser.add_argument("--shard_pattern", type=str, required=True)
    parser.add_argument("--shuffle_size", type=int, default=1_000)
    # Not read by main(); optional so launches need not supply a dummy path.
    parser.add_argument("--premade-feature-dir-audio", type=str, default=None)
    parser.add_argument("--z-stats-path", type=str, required=True, help="z_stats path for training")
    parser.add_argument("--max-train-steps", type=int, default=100000)
    parser.add_argument("--results-dir", type=str, default="results")
    parser.add_argument("--wandb-project", type=str, default="project")
    parser.add_argument("--error-log-path", type=str, default='error_deepspeed.txt', help="error log path in dataloder")
    parser.add_argument("--val-error-log-path", type=str, default='error_deepspeed.txt', help="error log path in dataloder")
    parser.add_argument("--reference-folder", type=str)
    parser.add_argument("--eval-folder", type=str)

    parser.add_argument("--resume-from-checkpoint", type=str, default=None, help="ckpt path for resume training")
    parser.add_argument("--ema-ckpt-path", type=str, default=None, help="EMA ckpt path for resume training")
    parser.add_argument("--log-every", type=int, default=1)
    parser.add_argument("--ckpt-every", type=int, default=5000)
    parser.add_argument("--mixed-precision", type=str, default='no', choices=["no", 'fp8', "fp16", "bf16"])
    parser.add_argument("--stage1-model-config", type=str, required=True, help="model config of stage1 model")
    parser.add_argument("--stage1-ckpt-path", type=str, required=True, help="ckpt path for stage1 model")
    # Not read by main(); optional so launches need not supply a dummy path.
    parser.add_argument("--clap-model-path", type=str, default=None, help="ckpt path for laion-clap")

    parser.add_argument("--audio-embed-dim", type=int, default=64, help="dimension for audio compression model")
    parser.add_argument("--video-embed-dim", type=int, default=384, help="dimension for video feature")
    parser.add_argument("--video-fps", type=int, default=30, help="frame rate of video")
    parser.add_argument("--channels", type=int, default=2, help="audio channels")
    parser.add_argument("--sample-rate", type=int, default=48000, help="sample rate of audio")
    parser.add_argument("--start-time", type=float, default=0, help="start time of dataset")
    parser.add_argument("--duration", type=float, default=8, help="duration of dataset")

    parser.add_argument("--batch-per-gpu", type=int, default=12)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--num-workers", type=int, default=8)
    # NOTE(review): main() currently applies no gradient clipping with this value.
    parser.add_argument("--max-grad-norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--optimizer_type", default="adamw", type=str, choices=["adamw"], help="Optimizer type.")
    parser.add_argument("--epochs", type=int, default=400)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--weight-decay", type=float, default=5e-2, help="Weight decay to use.")
    parser.add_argument("--beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--beta2", type=float, default=0.99, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--warmup-steps", default=4000, type=int, help="warmup steps on lr scheduler.")
    parser.add_argument("--lr-min-ratio", default=1e-6, type=float, help="minimum learning rate on lr scheduler.")
    parser.add_argument("--cycle-length", default=1.0, type=float, help="cycle length on lr scheduler.")
    parser.add_argument("--lr-scheduler", default="cosine", type=str, choices=["cosine", "constant"], help="lr scheduler type.")
    parser.add_argument("--no-compile", action='store_true')

    parser.add_argument("--gnse-model", type=str, choices=list(GNSE_models.keys()), default="GNSE-B")
    parser.add_argument("--per-frame-len", type=int, default=1)
    parser.add_argument("--num-types", type=int, default=2)
    parser.add_argument("--input-type", type=str, choices=['interleave', 'concat'], default="interleave")
    parser.add_argument("--max-period", type=float, default=10_000)
    parser.add_argument("--num-kv-heads", type=int, default=None)
    parser.add_argument("--drop-path-rate", type=float, default=0.1, help="drop_path_rate of attention and ffn")
    parser.add_argument("--token-dropout-p", type=float, default=0., help="dropout_p of token_dropout_p")
    parser.add_argument("--type-dropout-p", type=float, default=0., help="dropout_p of type_dropout_p")
    parser.add_argument("--resid-dropout-p", type=float, default=0.1, help="dropout_p of resid_dropout_p")
    parser.add_argument("--ffn-dropout-p", type=float, default=0.1, help="dropout_p of ffn_dropout_p")
    parser.add_argument("--attn-dropout-p", type=float, default=0.1, help="dropout_p of attn_dropout_p")
    parser.add_argument("--ffn-dim-multiplier", type=float, default=None, help="ffn_dim_multiplier")
    parser.add_argument("--multiple-of", type=int, default=256, help="multiple_of")
    parser.add_argument("--trans-cfg-dropout-prob", type=float, default=0.1, help="dropout probability on transformer for cfg sampling")
    parser.add_argument("--head-dropout-p", type=float, default=0.1)

    parser.add_argument("--diffusion-batch-mul", type=int, default=4)
    parser.add_argument("--label-balance", type=float, default=0.5)
    parser.add_argument("--P-mean", type=float, default=-0.4)
    parser.add_argument("--P-std", type=float, default=1.0)
    parser.add_argument("--sigma-data", type=float, default=0.5)
    parser.add_argument("--label-drop-prob", type=float, default=0.1, help="dropout probability on diffusion head for cfg sampling")
    parser.add_argument("--norm-type", type=str, default='rms', choices=['rms','dyt'], help="choose normalization type")
    parser.add_argument("--diffusion-type", type=str, default='diffusion', choices=['diffusion','trigflow', 'diffusion_v2'], help="choose diffusion model type")

    parser.add_argument("--ema", action='store_true')
    parser.add_argument("--no_subtract", action='store_true')
    parser.add_argument("--ema-rate", type=float, default=0.9999)
    parser.add_argument("--noise-augmentation", action='store_true')
    parser.add_argument("--audio-pretraining", action='store_true', help="flag for audio only pretraining")
    parser.add_argument("--vision-aggregation", action='store_true', help="flag for vision aggregation along with spatial axis.")
    parser.add_argument("--num-aggregated-tokens", type=int, default=1, help="number of aggregated tokens for vision aggregation")
    parser.add_argument("--d-aggregate", type=int, default=512, choices=[64, 128, 256, 384, 512, 768, 1024], help="number of feature dimension of attention blocks for vision aggregation")
    parser.add_argument("--num-heads-aggregate", type=int, default=8, help="number of heads of attention blocks for vision aggregation")
    parser.add_argument("--num-layers-aggregate", type=int, default=1, help="number of attention blocks for vision aggregation")
    parser.add_argument("--spatial-ds-rate", type=int, default=2, help="Downsampling rate of conv2d in vision aggregation")
    parser.add_argument("--aggregate-trans-architecture", type=str, default='sdpa', choices=['linear','sdpa'],help="Use linear transformer or sdpa for vision aggregation")
    parser.add_argument("--aggregation-method", type=str, default='self', choices=['self','cross'],help="select self or cross attention for vision aggregation")
    parser.add_argument("--norm-type-agg", type=str, default='rms', choices=['rms','dyt'], help="choose normalization type for aggregation transformer")
    parser.add_argument("--naive-mar-mlp", action='store_true', help="Use MAR's diffusion head architecture")
    parser.add_argument("--condition-merge", action='store_true', help="Maerge different conditon into diffusion head.")

    # Evaluation / sampling settings (main() only uses eval-duration, to size the EMA model).
    parser.add_argument("--eval-duration", type=float, default=10, help="duration of dataset")
    parser.add_argument('--val_csv_path', type=str, default=None)
    parser.add_argument("--eval-every-steps", type=int, default=5000)
    parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
    parser.add_argument("--trans-cfg-scale", type=float, default=1.0, help="CFG for transformer")
    parser.add_argument("--cfg-interval", type=int, default=-1, help="CFG interval for transformer")
    parser.add_argument("--cfg-schedule", type=str, default="constant", choices=["constant", "linear"], help="CFG scheduler for diffusion head with time axis")
    parser.add_argument("--diff-cfg-scale", type=float, default=1.0, help="CFG for diffusion head")
    parser.add_argument("--num-diffusion-steps", type=int, default=30, help="number of sampling steps for diffusion")
    parser.add_argument("--sigma-min", type=float, default=0.002, help="sigma_min")
    parser.add_argument("--sigma-max", type=float, default=80.0, help="sigma_max")
    parser.add_argument("--rho", type=float, default=7.0, help="rho")
    parser.add_argument("--S-churn", type=int, default=0, help="S_churn")
    parser.add_argument("--S-min", type=int, default=0, help="S_min")
    parser.add_argument("--S-max", type=float, default=float('inf'), help="S_max")
    parser.add_argument("--S-noise", type=int, default=1, help="S_noise")
    parser.add_argument("--mid-t", type=float, default=None)
    parser.add_argument("--disperse_loss", action='store_true')
    parser.add_argument("--disperse_lambda", type=float, default=0.25)
    parser.add_argument("--data", type=str, default='ogamedata')

    args = parser.parse_args()
    main(args)