"""
Train a latent consistency trajectory model (CTM) on audio.
"""
import time
import copy
import argparse
import json
import logging
import math
import os
import wandb
import numpy as np
import pandas as pd
import torch.nn as nn
from accelerate import Accelerator, DistributedDataParallelKwargs, DeepSpeedPlugin
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
# from tqdm.auto import tqdm
from datasets import load_dataset
import datasets
import diffusers
from diffusers.utils.import_utils import is_xformers_available
from packaging import version
import transformers
from transformers import SchedulerType
import random

# from cm import dist_util
# from cm.image_datasets import load_data
from cm.script_util_v3 import (
    create_model_and_diffusion,
    ctm_train_defaults,
    ctm_eval_defaults,
    add_dict_to_argparser,
    create_ema_and_scales_fn,
    others_defaults,
    gan_defaults,
)
from cm.train_util_light_2 import CMTrainLoop
import torch as th
from tango_edm.models_edm import build_pretrained_models
from dac_dev.model.discriminator import Discriminator as DAC_GAN_Discriminator
from dac_dev.model.san_discriminator import SAN_Discriminator as DAC_SAN_Discriminator
from dac_dev.model.discriminator import ConditionalDiscriminator as DAC_GAN_CondDiscriminator
from latent_discriminator.vqgan_discriminator import NLayerDiscriminator as VQGAN_Discriminator
from latent_discriminator.vqgan_discriminator import ConditionalNLayerDiscriminator as CVQGAN_Discriminator

from latent_discriminator.mel_vqgan_discriminator import NLayerDiscriminator as MelVQGAN_Discriminator
from latent_discriminator.mel_vqgan_discriminator import ConditionalNLayerDiscriminator as MelCVQGAN_Discriminator
from latent_discriminator.mel_mb_discriminator import MBDiscriminator
from latent_discriminator.mel_mb_discriminator import ConditionalMBDiscriminator

# Accelerate-aware module logger; its methods accept the `main_process_only`
# keyword (used in main()) in addition to the standard logging arguments.
logger = get_logger(__name__)
def rand_fix(seed):
    """Seed Python's, NumPy's and PyTorch's RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random`` and ``torch``.
    """
    for seed_fn in (random.seed, np.random.seed, th.manual_seed):
        seed_fn(seed)
    # Disable autotuning and pick deterministic kernels for reproducibility.
    cudnn = th.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    

def create_argparser():
    """Build the command-line interface for CTM audio training and parse ``sys.argv``.

    Despite its name this returns the parsed namespace, not the parser.
    Besides the explicit flags below, every key of ``defaults`` is registered as a
    ``--<key>`` option via ``add_dict_to_argparser``. ``defaults`` is the local dict
    updated with ``others_defaults()``, ``ctm_train_defaults()``,
    ``ctm_eval_defaults()`` and ``gan_defaults()``, in that order. Later dicts
    override earlier keys.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--seed", type=int, default=5031,  # 43
        help="A seed for reproducible training."
    )
    parser.add_argument(
        "--tango", action="store_true", default=False,
        help="Boolean `tango` flag forwarded to the model and training loop via args.",
    )
    parser.add_argument(
        "--train_file", type=str, default="data/train_audiocaps.json",
        help="A csv or a json file containing the training data."
    )
    parser.add_argument(
        "--validation_file", type=str, default="data/train_audiocaps.json",
        help="A csv or a json file containing the validation data (defaults to the training file)."
    )
    parser.add_argument(
        "--num_examples", type=int, default=-1,
        help="How many examples to use for training and validation.",
    )
    parser.add_argument(
        "--text_encoder_name", type=str, default="google/flan-t5-large",
        help="Text encoder identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--unet_model_config", type=str, default=None,
        help="UNet model config json path.",
    )
    parser.add_argument(
        "--ctm_unet_model_config", type=str, default=None,
        help="CTM UNet model config json path.",
    )
    parser.add_argument(
        "--freeze_text_encoder", action="store_true", default=False,
        help="Freeze the text encoder model.",
    )
    parser.add_argument(
        "--text_column", type=str, default="captions",
        help="The name of the column in the datasets containing the input texts.",
    )
    parser.add_argument(
        "--audio_column", type=str, default="location",
        help="The name of the column in the datasets containing the audio paths.",
    )
    parser.add_argument(
        "--tango_data_augment", action="store_true", default=False,
        help="Augment training data.",
    )
    parser.add_argument(
        "--augment_num", type=int, default=1,
        help="Number of augmented training examples.",
    )
    parser.add_argument(
        "--uncond_prob", type=float, default=0.1,
        help="Probability of dropping the condition text (replacing it with an empty prompt).",
    )
    parser.add_argument(
        "--prefix", type=str, default=None,
        help="Add prefix in text prompts.",
    )
    parser.add_argument(
        "--per_device_train_batch_size", type=int, default=2,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size", type=int, default=2,
        help="Batch size (per device) for the validation dataloader.",
    )
    parser.add_argument(
        "--num_train_epochs", type=int, default=40,
        help="Total number of training epochs to perform."
    )
    parser.add_argument(
        "--gradient_accumulation_steps", type=int, default=4,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--lr_scheduler_type", type=SchedulerType, default="linear",
        help="The scheduler type to use for the model.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--d_lr_scheduler_type", type=SchedulerType, default="linear",
        help="The scheduler type to use for the discriminator.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=0,
        help="Number of steps for the warmup in the model lr scheduler."
    )
    parser.add_argument(
        "--d_num_warmup_steps", type=int, default=0,
        help="Number of steps for the warmup in the discriminator lr scheduler."
    )
    parser.add_argument(
        "--adam_beta1", type=float, default=0.9,
        help="The beta1 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_beta2", type=float, default=0.999,
        help="The beta2 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_epsilon", type=float, default=1e-08,
        help="Epsilon value for the Adam optimizer"
    )
    parser.add_argument(
        "--output_dir", type=str, default=None,
        help="Where to store the final model."
    )
    parser.add_argument(
        "--duration", type=float, default=10.0,
        help="input audio duration"
    )
    parser.add_argument(
        "--checkpointing_steps", type=str, default="best",
        help="Whether the various states should be saved at the end of every 'epoch' or 'best' whenever validation loss decreases.",
    )

    parser.add_argument(
        "--model_grad_clip_value", type=float, default=1000.,
        help="Clipping value for gradient of model"
    )
    parser.add_argument(
        "--disc_grad_clip_value", type=float, default=1000.,
        help="Clipping value for gradient of discriminator"
    )
    parser.add_argument(
        "--sigma_data", type=float, default=0.25,
        help="sigma_data for the model"
    )
    parser.add_argument(
        "--resume_from_checkpoint", type=str, default=None,
        help="If the training should continue from a local checkpoint folder.",
    )

    parser.add_argument(
        "--generated_path", type=str, default=None,
        help="Path to the generated audio for evaluation.",
    )
    parser.add_argument(
        "--valid_data_path", type=str, default=None,
        help="Path to the validation audio for evaluation.",
    )

    # ----Mixed Precision----
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default='bf16',
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to bf16; the value is passed to the Accelerator and"
            " overrides the accelerate config of the current system."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument(
        "--with_tracking", action="store_true",
        help="Whether to enable experiment trackers for logging.",
    )
    parser.add_argument(
        "--report_to", type=str, default="wandb",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"` (default), `"comet_ml"` and `"clearml"`. Use `"all"` to report to all integrations.'
            " Only applicable when `--with_tracking` is passed."
        ),
    )

    # Script-level defaults; later updates from the cm.script_util_v3 helpers
    # override any key they share with this dict.
    defaults = dict(
        teacher_model_path="ckpt/sigma_025/best/pytorch_model_2.bin",
        stage1_path="ckpt/audioldm-s-full.ckpt",
        schedule_sampler="uniform",
        lr=0.00004,
        weight_decay=0.0,
        lr_anneal_steps=0,
        ema_rate="0.999",  # comma-separated list of EMA rates, e.g. "0.999,0.9999"
        total_training_steps=600000,
        save_interval=3000,  # 4000
        unet_mode='half',
    )
    defaults.update(others_defaults())
    defaults.update(ctm_train_defaults())
    defaults.update(ctm_eval_defaults())
    defaults.update(gan_defaults())

    add_dict_to_argparser(parser, defaults)
    args = parser.parse_args()

    return args



class Text2AudioDataset(Dataset):
    """Map-style dataset yielding ``(text, audio, index)`` triples.

    Args:
        dataset: Column-indexable container (e.g. a HuggingFace ``datasets`` split);
            ``dataset[column]`` must be iterable.
        prefix: String prepended to every text prompt.
        text_column: Name of the column holding the text prompts.
        audio_column: Name of the column holding the audio entries (paths, per the CLI help).
        uncond_prob: Probability that ``__getitem__`` replaces the prompt with ``""``
            (condition dropout).
        num_examples: If not ``-1``, keep only the first ``num_examples`` items.
    """

    def __init__(self, dataset, prefix, text_column, audio_column, uncond_prob=0.1 ,num_examples=-1):
        raw_texts = list(dataset[text_column])
        self.inputs = [prefix + caption for caption in raw_texts]
        self.audios = list(dataset[audio_column])
        self.indices = list(range(len(self.inputs)))
        self.uncond_prob = uncond_prob

        # Row index -> [audio, un-prefixed text]; built over the full dataset,
        # i.e. before the optional truncation below.
        self.mapper = {
            row: [audio, caption]
            for row, audio, caption in zip(self.indices, self.audios, raw_texts)
        }

        if num_examples != -1:
            self.inputs = self.inputs[:num_examples]
            self.audios = self.audios[:num_examples]
            self.indices = self.indices[:num_examples]

    def __len__(self):
        return len(self.inputs)

    def get_num_instances(self):
        """Return the number of (possibly truncated) examples."""
        return len(self)

    def __getitem__(self, index):
        caption = self.inputs[index]
        # Condition dropout: with probability uncond_prob, use an empty prompt.
        if random.random() < self.uncond_prob:
            caption = ""
        return caption, self.audios[index], self.indices[index]

    def collate_fn(self, data):
        """Transpose a list of ``(text, audio, index)`` tuples into three column lists."""
        frame = pd.DataFrame(data)
        return [frame[column].tolist() for column in frame.columns]

# class EMADummyModel(nn.Module):
#     def __init__(self, parameters_list):
#         super().__init__()
#         self.param_sets = nn.ParameterList([
#             nn.ParameterList([nn.Parameter(p.detach(), requires_grad=False) for p in params])
#             for params in parameters_list
#         ])

#     def get_param_sets(self):
        
#         return [[p for p in param_set] for param_set in self.param_sets]


def _build_discriminator(args, device):
    """Instantiate the discriminator selected by ``args.d_architecture``.

    Args:
        args: Parsed command-line namespace. Only ``d_architecture`` and the
            hyper-parameters of the selected architecture are read.
        device: Forwarded as ``device=`` to the conditional mel/latent/multi-band
            discriminators, which take it as a constructor argument.

    Returns:
        The constructed discriminator module.

    Raises:
        ValueError: If ``args.d_architecture`` is not a supported name.
    """
    arch = args.d_architecture
    if arch == 'DAC_GAN':
        return DAC_GAN_Discriminator(
            args.dac_dis_rates,
            args.dac_dis_periods,
            args.dac_dis_fft_sizes,
            args.dac_dis_sample_rate,
            args.dac_dis_bands)
    if arch == 'DAC_SAN':
        return DAC_SAN_Discriminator(
            args.dac_dis_rates,
            args.dac_dis_periods,
            args.dac_dis_fft_sizes,
            args.dac_dis_sample_rate,
            args.dac_dis_bands)
    if arch == 'DAC_CGAN':
        return DAC_GAN_CondDiscriminator(
            args.dac_dis_rates,
            args.dac_dis_periods,
            args.dac_dis_fft_sizes,
            args.dac_dis_sample_rate,
            args.dac_dis_bands,
            args.d_cond_type,
            args.c_dim,
            args.cmap_dim)
    if arch == 'MEL_VQGAN':
        return MelVQGAN_Discriminator(
            input_nc=1,
            ndf=args.vqgan_ndf,
            n_layers=args.vqgan_n_layers,
            use_spectral_norm=args.vqgan_use_spectral_norm,
            use_actnorm=False)
    if arch == 'MEL_CVQGAN':
        return MelCVQGAN_Discriminator(
            input_nc=1,
            ndf=args.vqgan_ndf,
            n_layers=args.vqgan_n_layers,
            use_spectral_norm=args.vqgan_use_spectral_norm,
            use_actnorm=False,
            d_cond_type=args.d_cond_type,
            c_dim=args.c_dim,
            cmap_dim=args.cmap_dim,
            device=device)
    if arch == 'L_VQGAN':
        return VQGAN_Discriminator(
            input_nc=args.latent_channels,
            ndf=args.vqgan_ndf,
            n_layers=args.vqgan_n_layers,
            use_spectral_norm=args.vqgan_use_spectral_norm,
            use_actnorm=False)
    if arch == 'L_CVQGAN':
        return CVQGAN_Discriminator(
            input_nc=args.latent_channels,
            ndf=args.vqgan_ndf,
            n_layers=args.vqgan_n_layers,
            use_spectral_norm=args.vqgan_use_spectral_norm,
            use_actnorm=False,
            d_cond_type=args.d_cond_type,
            c_dim=args.c_dim,
            cmap_dim=args.cmap_dim,
            device=device)
    if arch == 'MBDisc':
        return MBDiscriminator(
            ndf=args.mbdisc_ndf,
            n_bins=args.n_bins,
            bands=args.dac_dis_bands,
            increase_ch=args.increase_ch)
    if arch == 'CMBDisc':
        return ConditionalMBDiscriminator(
            ndf=args.mbdisc_ndf,
            n_bins=args.n_bins,
            bands=args.dac_dis_bands,
            increase_ch=args.increase_ch,
            d_cond_type=args.d_cond_type,
            c_dim=args.c_dim,
            cmap_dim=args.cmap_dim,
            device=device)
    # A conditional SAN variant ('DAC_CSAN') is not wired up yet.
    raise ValueError(f"Unsupported discriminator architecture: {arch!r}")


def _copy_matching_parameters(dst_module, src_module):
    """Copy parameter values from ``src_module`` into matching parameters of ``dst_module``.

    A destination parameter matches a source parameter whose full name equals it,
    or whose name with the first dotted component removed equals it. When several
    source parameters match, the first one in ``named_parameters()`` order wins.
    Unmatched destination parameters are left untouched.
    """
    src_by_name = {}
    for src_name, src in src_module.named_parameters():
        for key in ('.'.join(src_name.split('.')[1:]), src_name):
            # setdefault keeps the earliest source parameter for each key.
            src_by_name.setdefault(key, src)
    for dst_name, dst in dst_module.named_parameters():
        src = src_by_name.get(dst_name)
        if src is not None:
            dst.data.copy_(src.data)


def main():
    """Entry point: parse arguments, build models and data, then run CTM training.

    Steps:
      1. Parse CLI arguments and create the ``Accelerator``.
      2. On the main process, create the output directory and append the
         argument summary. The directory is then shared with all ranks.
      3. Load the train/validation splits with ``datasets.load_dataset``.
      4. Build the optional discriminator, the frozen VAE/STFT stage-1 models,
         the student model, the optional teacher model and the target model.
      5. Build the optimizers and ``accelerator.prepare`` everything.
         Optionally resume from a checkpoint, then run ``CMTrainLoop``.

    Raises:
        NotImplementedError: If the text encoder is not frozen or no CTM UNet
            config is given (only CTM UNet training is supported).
        ValueError: If discriminator training is enabled with an unknown
            ``--d_architecture``.
    """
    # Local import: only needed to share the output directory across ranks.
    from accelerate.utils import broadcast_object_list

    args = create_argparser()

    accelerator_log_kwargs = {}
    if args.with_tracking:
        accelerator_log_kwargs["log_with"] = args.report_to
        accelerator_log_kwargs["project_dir"] = args.output_dir

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        **accelerator_log_kwargs,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()
        datasets.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)
        rand_fix(args.seed)

    # Handle output directory creation and wandb tracking (main process only).
    if accelerator.is_main_process:
        if not args.output_dir:
            # No directory given: use a timestamped one under ./saved.
            args.output_dir = "saved/" + str(int(time.time()))
        # Also creates args.output_dir (and ./saved) if missing.
        os.makedirs(os.path.join(args.output_dir, "outputs"), exist_ok=True)
        with open(os.path.join(args.output_dir, "summary.jsonl"), "a") as f:
            f.write(json.dumps(dict(vars(args))) + "\n\n")

        accelerator.project_configuration.automatic_checkpoint_naming = False

        # NOTE(review): placeholder project name. With --with_tracking and
        # --report_to wandb, init_trackers below may also initialise wandb; confirm.
        wandb.init(project="XXXXXXXX")

    accelerator.wait_for_everyone()

    # The timestamped output directory is chosen on the main process only.
    # Broadcast it so every rank sees the same args.output_dir.
    shared_output_dir = [args.output_dir]
    broadcast_object_list(shared_output_dir, from_process=0)
    args.output_dir = shared_output_dir[0]

    # Get the datasets.
    data_files = {}
    if args.train_file is not None:
        data_files["train"] = args.train_file
    # NOTE: --validation_file defaults to the training file.
    if args.validation_file is not None:
        data_files["validation"] = args.validation_file

    extension = args.train_file.split(".")[-1]
    raw_datasets = load_dataset(extension, data_files=data_files)
    text_column, audio_column = args.text_column, args.audio_column

    ema_scale_fn = create_ema_and_scales_fn(
        target_ema_mode=args.target_ema_mode,
        start_ema=args.start_ema,
        scale_mode=args.scale_mode,
        start_scales=args.start_scales,
        end_scales=args.end_scales,
        total_steps=args.total_training_steps,  # 600000 by default
        distill_steps_per_iter=args.distill_steps_per_iter,
    )

    # Load the (optional) discriminator. NOTE: We start from DAC's discriminator.
    discriminator = None
    # Passed to create_model_and_diffusion below. Defined unconditionally so that
    # call does not raise NameError when discriminator training is disabled.
    feature_networks = None
    if args.discriminator_training:
        discriminator = _build_discriminator(args, accelerator.device)
        if accelerator.is_main_process:
            logger.info("Discriminator params: {}".format(sum(p.numel() for p in discriminator.parameters())))
        discriminator.train()

    # Load stage1 models, vocoder, and mel preprocess function (frozen).
    pretrained_model_name = "audioldm-s-full"
    vae, stft = build_pretrained_models(pretrained_model_name)
    vae.requires_grad_(False)
    stft.requires_grad_(False)
    vae.eval()
    stft.eval()

    # Load the student model.
    logger.info("creating the student model")
    model, diffusion = create_model_and_diffusion(args, feature_networks)
    model.train()

    if args.gradient_checkpointing:
        model.ctm_unet.enable_gradient_checkpointing()

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warning(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            model.ctm_unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Load the teacher model (optional) and initialise the student from it.
    if args.teacher_model_path:
        logger.info(f"loading the teacher model from {args.teacher_model_path}")
        teacher_model, _ = create_model_and_diffusion(args, teacher=True)

        if os.path.exists(args.teacher_model_path):
            model_ckpt = th.load(args.teacher_model_path, map_location=accelerator.device)
            teacher_model.load_state_dict(model_ckpt, strict=False)
        else:
            logger.warning(
                f"Teacher checkpoint {args.teacher_model_path} not found; "
                "the teacher keeps its freshly constructed weights."
            )

        # Initialize student parameters with the teacher's (matched by name).
        _copy_matching_parameters(model.ctm_unet, teacher_model.unet)
        _copy_matching_parameters(model.text_encoder, teacher_model.text_encoder)

        teacher_model.requires_grad_(False)
        teacher_model.eval()
        logger.info(f"Initialized parameters of student (online) model synced with the teacher model from {args.teacher_model_path}")
    else:
        teacher_model = None

    # The target model starts as an exact copy of the student and is frozen.
    logger.info("creating the target model")
    target_model, _ = create_model_and_diffusion(args)
    logger.info("Copy parameters of student model with the target_model model")
    for dst, src in zip(target_model.parameters(), model.parameters()):
        dst.data.copy_(src.data)
    target_model.requires_grad_(False)
    target_model.train()

    # Move the frozen stage-1, target and teacher models to the device (no dtype cast).
    vae.to(accelerator.device)
    stft.to(accelerator.device)
    target_model.to(accelerator.device)
    if teacher_model is not None:
        teacher_model.to(accelerator.device)

    # Define dataloaders.
    logger.info("creating data loader...")
    prefix = args.prefix if args.prefix else ""

    with accelerator.main_process_first():
        train_dataset = Text2AudioDataset(raw_datasets["train"], prefix, text_column, audio_column, args.uncond_prob, args.num_examples)
        eval_dataset = Text2AudioDataset(raw_datasets["validation"], prefix, text_column, audio_column, 0.0, args.num_examples)
        accelerator.print("Num instances in train: {}, validation: {}".format(train_dataset.get_num_instances(), eval_dataset.get_num_instances()))

    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.per_device_train_batch_size, collate_fn=train_dataset.collate_fn)
    eval_dataloader = DataLoader(eval_dataset, shuffle=False, batch_size=args.per_device_eval_batch_size, collate_fn=eval_dataset.collate_fn)

    # Optimizer: only the CTM UNet is trained; the text encoder must be frozen.
    if not args.freeze_text_encoder:
        raise NotImplementedError(
            "Training the text encoder is not supported; pass --freeze_text_encoder."
        )
    for param in model.text_encoder.parameters():
        param.requires_grad = False
    model.text_encoder.eval()

    if not args.ctm_unet_model_config:
        raise NotImplementedError(
            "Only CTM UNet training is supported; pass --ctm_unet_model_config."
        )
    optimizer_parameters = model.ctm_unet.parameters()
    accelerator.print("Optimizing CTM UNet parameters.")

    num_trainable_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    accelerator.print("Num CTM UNet trainable parameters: {}".format(num_trainable_parameters))

    optimizer = th.optim.RAdam(
        optimizer_parameters, lr=args.lr,
        weight_decay=args.weight_decay,
    )

    # The discriminator optimizer only exists when a discriminator was built.
    d_optimizer = None
    if discriminator is not None:
        accelerator.print("Optimizing discriminator parameters.")
        d_num_trainable_parameters = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
        accelerator.print("Num discriminator's trainable parameters: {}".format(d_num_trainable_parameters))
        d_optimizer = th.optim.RAdam(
            discriminator.parameters(), lr=args.d_lr,
            weight_decay=args.weight_decay,
            betas=(0.5, 0.9),
        )

    # NOTE(review): one update step is counted per batch (no division by
    # gradient_accumulation_steps). Confirm this matches CMTrainLoop.
    overrode_total_training_steps = False
    num_update_steps_per_epoch = len(train_dataloader)
    if args.total_training_steps is None:
        args.total_training_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_total_training_steps = True

    # NOTE: lr schedulers (transformers.get_scheduler with args.lr_scheduler_type /
    # args.d_lr_scheduler_type) are currently disabled.

    # Prepare everything with our `accelerator`.
    if discriminator is not None:
        model, discriminator, optimizer, d_optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
            model, discriminator, optimizer, d_optimizer, train_dataloader, eval_dataloader
        )
    else:
        model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
            model, optimizer, train_dataloader, eval_dataloader
        )

    # Recalculate total training steps, since prepare() may have changed the
    # size of the training dataloader.
    num_update_steps_per_epoch = len(train_dataloader)
    if overrode_total_training_steps:
        args.total_training_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs.
    args.num_train_epochs = math.ceil(args.total_training_steps / num_update_steps_per_epoch)

    # Initialize the trackers and store the configuration (on the main process).
    if args.with_tracking:
        experiment_config = vars(args)
        accelerator.init_trackers("text_to_audio_diffusion", experiment_config)

    # Train!
    total_batch_size = (args.per_device_train_batch_size + args.augment_num) * args.gradient_accumulation_steps * accelerator.num_processes
    th.cuda.empty_cache()
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device + augment_num = {args.per_device_train_batch_size} + {args.augment_num}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation & data augmentation) = {total_batch_size}")
    logger.info(f"  Total optimization steps = {args.total_training_steps}")

    resume_epoch = 0
    resume_step = 0
    resume_global_step = 0

    # Potentially load in the weights and states from a previous save.
    # A truthy resume_from_checkpoint is neither None nor "".
    if args.resume_from_checkpoint:
        progress_state = th.load(os.path.join(args.resume_from_checkpoint, "progress_state.pth"), map_location=accelerator.device)
        resume_step = progress_state['completed_steps']
        resume_global_step = progress_state['completed_global_steps']
        resume_epoch = progress_state['completed_epochs']
        accelerator.load_state(args.resume_from_checkpoint)

        # The target model is restored from the EMA snapshot saved for resume_step.
        state_dict = th.load(os.path.join(args.resume_from_checkpoint, f"ema_{args.ema_rate}_{resume_step:06d}.pt"), map_location=accelerator.device)
        target_model.load_state_dict(state_dict, strict=False)
        target_model.requires_grad_(False)
        target_model.train()
        target_model.to(accelerator.device)
        accelerator.print(f"Resumed from local checkpoint: {args.resume_from_checkpoint}")

    th.cuda.empty_cache()
    CMTrainLoop(
        model=model,
        target_model=target_model,
        teacher_model=teacher_model,
        latent_decoder=vae,
        stft=stft,
        discriminator=discriminator,
        ema_scale_fn=ema_scale_fn,
        diffusion=diffusion,
        data=train_dataloader,
        eval_dataloader=eval_dataloader,
        args=args,
        accelerator=accelerator,
        opt=optimizer,
        d_opt=d_optimizer,
        resume_step=resume_step,
        resume_global_step=resume_global_step,
        resume_epoch=resume_epoch,
        # lr_scheduler = lr_scheduler,
        # d_lr_scheduler = d_lr_scheduler,
    ).run_loop()


# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()