import json
import logging
import math
import os
import random
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List

import diffusers
import torch
import torch.backends
import torch.nn as nn
import transformers
import wandb
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import (
    DistributedDataParallelKwargs,
    InitProcessGroupKwargs,
    ProjectConfiguration,
    gather_object,
    set_seed,
)
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params
from diffusers.utils import export_to_video, load_image, load_video
from huggingface_hub import create_repo, upload_folder
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from tqdm import tqdm

from .args import Args, validate_args
from .constants import (
    FINETRAINERS_LOG_LEVEL,
    PRECOMPUTED_CONDITIONS_DIR_NAME,
    PRECOMPUTED_DIR_NAME,
    PRECOMPUTED_LATENTS_DIR_NAME,
)
from .dataset import BucketSampler, ImageOrVideoDatasetWithResizing, PrecomputedDataset
from .hooks import apply_layerwise_upcasting
from .models import get_config_from_model_name
from .patches import perform_peft_patches
from .state import State
from .utils.checkpointing import (
    get_intermediate_ckpt_path,
    get_latest_ckpt_path_to_resume_from,
)
from .utils.data_utils import should_perform_precomputation
from .utils.diffusion_utils import (
    get_scheduler_alphas,
    get_scheduler_sigmas,
    prepare_loss_weights,
    prepare_sigmas,
    prepare_target,
)
from .utils.file_utils import string_to_filename
from .utils.hub_utils import save_model_card
from .utils.memory_utils import free_memory, get_memory_statistics, make_contiguous
from .utils.model_utils import resolve_vae_cls_from_ckpt_path
from .utils.optimizer_utils import get_optimizer
from .utils.torch_utils import align_device_and_dtype, expand_tensor_dims, unwrap_model
from .quant.quant import QuantLoRALinear, replace_linear, QuantLinear


# Module-level logger obtained through accelerate's `get_logger`; its verbosity is
# set from FINETRAINERS_LOG_LEVEL (imported from `.constants`).
logger = get_logger("qvgen")
logger.setLevel(FINETRAINERS_LOG_LEVEL)


class Trainer:
    def __init__(self, args: Args) -> None:
        """Validate ``args`` and initialize trainer state, helpers and model configuration.

        All tokenizers, text encoders, denoisers, the VAE and the scheduler start as
        ``None``; the ``prepare_*`` methods populate them later. The ``_init_*`` helpers
        (distributed setup, logging, directories/repositories, config options) run here.

        Args:
            args: Training arguments, validated with ``validate_args``. If ``args.seed``
                is ``None`` it is set to the current year; an explicit seed (including 0)
                is kept unchanged.
        """
        validate_args(args)

        self.args = args
        # Substitute a default only when no seed was given: `seed or <year>` would also
        # replace a deliberately chosen seed of 0.
        if self.args.seed is None:
            self.args.seed = datetime.now().year
        self.state = State()

        # Tokenizers
        self.tokenizer = None
        self.tokenizer_2 = None
        self.tokenizer_3 = None

        # Text encoders
        self.text_encoder = None
        self.text_encoder_2 = None
        self.text_encoder_3 = None

        # Denoisers
        self.transformer = None
        self.teacher = None  # For distillation
        self.unet = None

        # Autoencoders
        self.vae = None

        # Scheduler
        self.scheduler = None

        self.transformer_config = None
        self.vae_config = None
        # Learning-rate bookkeeping keyed by param-group name; entries are created in
        # `prepare_optimizer`.
        self.lr_dict = dict()

        self._init_distributed()
        self._init_logging()
        self._init_directories_and_repositories()
        self._init_config_options()

        # Perform the PEFT patches needed for training when layerwise upcasting is requested
        # for any module.
        if self.args.layerwise_upcasting_modules:
            perform_peft_patches()
        # NOTE: text encoders are not patched; a `perform_text_encoder_patches()` call for
        # text-encoder entries of `layerwise_upcasting_modules` is intentionally left out.

        self.state.model_name = self.args.model_name
        self.model_config = get_config_from_model_name(
            self.args.model_name, self.args.training_type
        )

    def prepare_dataset(self) -> None:
        """Create the raw image/video dataset and the training dataloader over it.

        The dataset is stored on ``self.dataset`` and the loader on ``self.dataloader``.
        The DataLoader itself uses ``batch_size=1`` while ``BucketSampler`` receives
        ``args.batch_size``; presumably the sampler yields whole buckets that the
        model-specific ``collate_fn`` (from ``self.model_config``) turns into a batch —
        TODO confirm against BucketSampler.
        """
        # TODO: make a background process for fetching.
        logger.info("Initializing dataset and dataloader")

        args = self.args
        dataset = ImageOrVideoDatasetWithResizing(
            data_root=args.data_root,
            caption_column=args.caption_column,
            video_column=args.video_column,
            resolution_buckets=args.video_resolution_buckets,
            dataset_file=args.dataset_file,
            id_token=args.id_token,
            remove_llm_prefixes=args.remove_common_llm_caption_prefixes,
        )
        self.dataset = dataset

        bucket_sampler = BucketSampler(dataset, batch_size=args.batch_size, shuffle=True)
        self.dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=1,
            sampler=bucket_sampler,
            collate_fn=self.model_config.get("collate_fn"),
            num_workers=args.dataloader_num_workers,
            pin_memory=args.pin_memory,
        )

    def prepare_models(self) -> None:
        """Load model components via the loaders in ``self.model_config``.

        Unless conditions are precomputed, the condition, latent and diffusion components
        are loaded (in that order) and registered with ``_set_components``; otherwise an
        empty component mapping is registered. VAE slicing/tiling is enabled when
        requested, and when ``args.use_teacher`` is set a second copy of the diffusion
        transformer is loaded as ``self.teacher``.
        """
        logger.info("Initializing models")

        load_kwargs = self._get_load_components_kwargs()
        accelerator = self.state.accelerator

        loaded = {}
        if not self.args.precompute_conditions:
            # The main process downloads the model files first (if not already present);
            # the other processes then load the cached files.
            with accelerator.main_process_first():
                for loader_key in (
                    "load_condition_models",
                    "load_latent_models",
                    "load_diffusion_models",
                ):
                    loaded.update(self.model_config[loader_key](**load_kwargs))
        self._set_components(loaded)

        vae = self.vae
        if vae is not None:
            if self.args.enable_slicing:
                vae.enable_slicing()
            if self.args.enable_tiling:
                vae.enable_tiling()

        if not self.args.use_teacher:
            return
        with accelerator.main_process_first():
            teacher_components = self.model_config["load_diffusion_models"](**load_kwargs)
        self.teacher = teacher_components.get("transformer", self.teacher)

    def prepare_precomputations(self) -> None:
        """Precompute (or reuse) text conditions and VAE latents and switch to a precomputed dataloader.

        Does nothing unless ``args.precompute_conditions`` is set. If
        ``should_perform_precomputation`` reports that the precomputation directory under
        ``args.data_root`` does not need to be regenerated, the existing files are reused.
        Otherwise each process encodes its share of ``self.dataset`` (samples ``i`` with
        ``i % num_processes == process_index``). The condition models run first and are
        freed; then the VAE runs. Results are written as
        ``conditions-<process>-<k>.pt`` / ``latents-<process>-<k>.pt``. Finally
        ``self.dataloader`` is replaced by a loader over ``PrecomputedDataset``.

        Raises:
            ValueError: if ``args.batch_size != 1``.
        """
        if not self.args.precompute_conditions:
            return

        logger.info("Initializing precomputations")

        if self.args.batch_size != 1:
            raise ValueError(
                "Precomputation is only supported with batch size 1. This will be supported in future."
            )

        def merge_conditions(per_sample):
            # Tensor entries are concatenated along dim 0; for non-tensor entries only the
            # first sample's value is kept (batching of those is not implemented yet).
            merged = {}
            for key, value in per_sample[0].items():
                if torch.is_tensor(value):
                    merged[key] = torch.cat([sample[key] for sample in per_sample], dim=0)
                else:
                    merged[key] = value
            return merged

        def collate_fn(batch):
            return {
                "latent_conditions": merge_conditions([x["latent_conditions"] for x in batch]),
                "text_conditions": merge_conditions([x["text_conditions"] for x in batch]),
            }

        cleaned_model_id = string_to_filename(self.args.pretrained_model_name_or_path)
        precomputation_dir = (
            Path(self.args.data_root)
            / f"{self.args.model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}"
        )

        def build_precomputed_dataloader():
            # Used both when precomputed data already exists and after a fresh run.
            return torch.utils.data.DataLoader(
                PrecomputedDataset(
                    data_root=self.args.data_root,
                    model_name=self.args.model_name,
                    cleaned_model_id=cleaned_model_id,
                ),
                batch_size=self.args.batch_size,
                shuffle=True,
                collate_fn=collate_fn,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.pin_memory,
            )

        if not should_perform_precomputation(precomputation_dir):
            logger.info(
                "Precomputed conditions and latents found. Loading precomputed data."
            )
            self.dataloader = build_precomputed_dataloader()
            return

        logger.info(
            "Precomputed conditions and latents not found. Running precomputation."
        )

        accelerator = self.state.accelerator
        num_samples = len(self.dataset)
        # Samples owned by this process. Indexing the dataset directly (rather than
        # enumerating all of it and skipping) avoids loading every other process's
        # samples on every process.
        local_indices = range(
            accelerator.process_index, num_samples, accelerator.num_processes
        )

        def precompute_shard(label, output_dir, encode):
            # Encode this process's samples with `encode` and save one file per sample.
            with tqdm(
                total=len(local_indices),
                desc=f"Precomputing {label}",
                disable=not accelerator.is_local_main_process,
            ) as progress_bar:
                for index, i in enumerate(local_indices):
                    logger.debug(
                        f"Precomputing {label} for batch {i + 1}/{num_samples} on process {accelerator.process_index}"
                    )
                    result = encode(self.dataset[i])
                    filename = (
                        output_dir / f"{label}-{accelerator.process_index}-{index}.pt"
                    )
                    torch.save(result, filename.as_posix())
                    progress_bar.update(1)

        def encode_prompt(data):
            return self.model_config["prepare_conditions"](
                tokenizer=self.tokenizer,
                tokenizer_2=self.tokenizer_2,
                tokenizer_3=self.tokenizer_3,
                text_encoder=self.text_encoder,
                text_encoder_2=self.text_encoder_2,
                text_encoder_3=self.text_encoder_3,
                prompt=data["prompt"],
                device=accelerator.device,
                dtype=self.args.transformer_dtype,
            )

        def encode_video(data):
            return self.model_config["prepare_latents"](
                vae=self.vae,
                image_or_video=data["video"].unsqueeze(0),
                device=accelerator.device,
                dtype=self.args.transformer_dtype,
                generator=self.state.generator,
                precompute=True,
            )

        # At this point, no models are loaded, so load the condition models first.
        with accelerator.main_process_first():
            condition_components = self.model_config["load_condition_models"](
                **self._get_load_components_kwargs()
            )
        self._set_components(condition_components)
        self._move_components_to_device()
        self._disable_grad_for_components(
            [self.text_encoder, self.text_encoder_2, self.text_encoder_3]
        )

        if (
            self.args.caption_dropout_p > 0
            and self.args.caption_dropout_technique == "empty"
        ):
            logger.warning(
                "Caption dropout is not supported with precomputation yet. This will be supported in the future."
            )

        conditions_dir = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME
        latents_dir = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME
        conditions_dir.mkdir(parents=True, exist_ok=True)
        latents_dir.mkdir(parents=True, exist_ok=True)

        # Precompute conditions
        precompute_shard("conditions", conditions_dir, encode_prompt)
        self._delete_components()

        memory_statistics = get_memory_statistics()
        logger.info(
            f"Memory after precomputing conditions: {json.dumps(memory_statistics, indent=4)}"
        )
        torch.cuda.reset_peak_memory_stats(accelerator.device)

        # Precompute latents
        with accelerator.main_process_first():
            latent_components = self.model_config["load_latent_models"](
                **self._get_load_components_kwargs()
            )
        self._set_components(latent_components)
        self._move_components_to_device()
        self._disable_grad_for_components([self.vae])

        if self.vae is not None:
            if self.args.enable_slicing:
                self.vae.enable_slicing()
            if self.args.enable_tiling:
                self.vae.enable_tiling()

        precompute_shard("latents", latents_dir, encode_video)
        self._delete_components()

        accelerator.wait_for_everyone()
        logger.info("Precomputation complete")

        memory_statistics = get_memory_statistics()
        logger.info(
            f"Memory after precomputing latents: {json.dumps(memory_statistics, indent=4)}"
        )
        torch.cuda.reset_peak_memory_stats(accelerator.device)

        # Update dataloader to use precomputed conditions and latents
        self.dataloader = build_precomputed_dataloader()

    def prepare_quantized_model(self) -> None:
        """Swap the transformer's linear layers for quantized ones and initialize the quantizers.

        Steps:
          1. Build a cache-file suffix from the progressive-training and per-group settings.
             Progressive training also forces ``args.w_quantizer = "progressive_lora"``.
          2. When resuming with a non-minmax calibration or a non-empty suffix, look for
             ``quant_{w_cali}_w{w_bit}{suffix}.bin`` in ``args.output_dir``. If it exists,
             calibration is reduced to minmax with one iteration. The cached quantizer
             state, and ``transformer_prev`` weights when present, are loaded afterwards.
          3. ``replace_linear`` installs the weight/activation quantizers, optionally with
             quantization-aware LoRA.
          4. Each quantized layer's weight quantizer is run on its weight, both quantizers
             are built, and ``set_quant_state(True, args.use_aq)`` is applied.
          5. After a fresh non-minmax calibration, the main process saves every
             ``wquantizer`` parameter/buffer to the cache file.
        """
        logger.info("Initializing quantized models")
        args = self.args

        # replace Linear -> QuantLinear
        suffix = ""
        if args.progress_training:
            suffix = f"_{args.progress_rank}_{args.progress_iter}"
            args.w_quantizer = "progressive_lora"
        if args.w_granularity == "per_group":
            suffix = suffix + f"_g{args.w_group_size}"

        assert self.transformer is not None
        # Cache of calibrated weight-quantizer state: written below after calibration and
        # read back when resuming.
        q_path = os.path.join(
            args.output_dir,
            f"quant_{args.w_cali}_w{args.w_bit}{suffix}.bin",
        )

        quantization_parameters = None
        w_cali = args.w_cali
        iters = args.progress_iter
        resume = False
        if args.resume_from_checkpoint and (args.w_cali != "minmax" or suffix != ""):
            logger.info(q_path)
            if os.path.exists(q_path):
                logger.info(f"Load {q_path}")
                with self.state.accelerator.main_process_first():
                    # Load on CPU so every rank does not materialize the cached tensors on
                    # the device they were saved from; load_state_dict copies them into
                    # the module's own parameters/buffers anyway.
                    quantization_parameters = torch.load(q_path, map_location="cpu")
                # The cached state replaces calibration, so a single minmax pass suffices.
                w_cali = "minmax"
                iters = 1
                resume = True

        quantizer_type = {
            "w": args.w_quantizer,
            "act": args.act_quantizer,
        }
        q_params = {
            "w": {
                "bit": args.w_bit,
                "sym": False,
                "granularity": args.w_granularity,
                "cali": w_cali,
                "group_size": args.w_group_size,
                "round_zero": True,
                "use_grad_scaling": args.w_use_grad_scaling,
                "round_func": args.round_func,
                "rank": args.progress_rank,
                "iters": iters,
                "cur_alpha_strategy": args.progress_cur_alpha_strategy,
                "alpha_T": args.progress_alpha_T,
                "warm_up": args.progress_warm_up,
                "resume": resume,
            },
            "act": {
                "bit": args.act_bit,
                "sym": False,
                "granularity": "per_token",
                "cali": "minmax",
                "round_zero": True,
                "clip_ratio": args.act_clip_ratio,
                "round_func": args.round_func,
            },
        }
        lora_params = {
            "rank": args.qa_rank,
            "dropout": args.qa_dropout,
            "alpha": args.qa_alpha,
        }
        if args.act_quantizer == "learnable_clipped_dynamic":
            q_params["act"]["timesteps"] = args.clip_group_num
        if args.round_func == "LearnableAdaSTE":
            q_params["w"]["timesteps_ste"] = args.round_group_num
            q_params["act"]["timesteps_ste"] = args.round_group_num
        logger.info(f"Quantization type: {json.dumps(quantizer_type, indent=4)}")
        logger.info(f"Quantization parameters: {json.dumps(q_params, indent=4)}")
        if args.qa_lora:
            logger.info(f"Quantization-aware LoRA: {json.dumps(lora_params, indent=4)}")

        random_list = []
        if args.wo_quant_random:
            # 245 is the number of layers in CogVideoX-2b
            # discard norm1, norm2, time_embedding
            # NOTE(review): every rank must draw the same sample; this relies on the global
            # `random` state being seeded identically on all ranks — confirm.
            random_list = random.sample(
                [
                    p
                    for p in range(3, 245)
                    if ((p - 3) % 8 != 0 and (p < 8 or (p - 8) % 8 != 0)) or p >= 243
                ],
                args.wo_quant_num,
            )
        replace_linear(
            self.transformer,
            quantizer_type,
            q_params,
            lora_params,
            args.qa_lora,
            args.wo_quant_adaln,
            args.wo_quant_only_time_embedding,
            args.wo_quant_only_norm,
            random_list,
        )

        # initialize quantization parameters
        logger.info("Initializing weight quantization parameters")
        for name, module in self.transformer.named_modules():
            if not isinstance(module, (QuantLinear, QuantLoRALinear)):
                continue
            if w_cali != "minmax":
                logger.info(f"Init. {name}.")
            module.wquantizer(module.w.detach())
            module.wquantizer.build()
            module.aquantizer.build()
            module.set_quant_state(True, args.use_aq)

        if quantization_parameters is not None:
            # Only quantizer entries are cached, so strict=False is required; still surface
            # entries the current model does not have (e.g. changed quantizer settings).
            load_result = self.transformer.load_state_dict(
                quantization_parameters, strict=False
            )
            unexpected_keys = getattr(load_result, "unexpected_keys", None)
            if unexpected_keys:
                logger.warning(
                    f"Cached quantization parameters contain keys not found in the model: {unexpected_keys}"
                )
            import safetensors.torch

            prev_path = os.path.join(
                args.output_dir,
                "transformer_prev/diffusion_pytorch_model.safetensors",
            )
            if os.path.exists(prev_path):
                logger.info("Load transformer_prev")
                self.transformer.load_state_dict(
                    safetensors.torch.load_file(prev_path),
                    strict=True,
                )
        elif w_cali != "minmax" and self.state.accelerator.is_main_process:
            # for L, R, scale, zero_point
            named_tensors = {
                **dict(self.transformer.named_parameters()),
                **dict(self.transformer.named_buffers()),
            }
            q_dict = {n: p for n, p in named_tensors.items() if "wquantizer" in n}
            torch.save(q_dict, q_path)

    def prepare_trainable_parameters(self) -> None:
        """Set ``requires_grad`` flags, move components to device and register checkpoint hooks.

        Text encoders and the VAE are always frozen. For ``training_type == "full-finetune"``
        the whole transformer is unfrozen, then partially re-frozen by mode:

        * ``qa_lora``: each ``QuantLoRALinear.w`` is frozen. Any plain ``nn.Linear`` left
          must have "lora" in its name (asserted).
        * ``progress_training`` without ``progress_fix``: each ``QuantLinear``'s ``w`` and
          weight-quantizer ``scale`` / ``zero_point`` are frozen. Every other transformer
          parameter stays trainable.
        * otherwise, with ``freeze_adaln``: all parameters of plain ``nn.Linear`` modules
          are frozen.

        Any other training type freezes the transformer. For ``"lora"``, a PEFT adapter is
        added after the components are moved to the device. A teacher, if used, is frozen.
        """
        logger.info("Initializing trainable parameters")

        args = self.args
        self._disable_grad_for_components(
            [self.text_encoder, self.text_encoder_2, self.text_encoder_3, self.vae]
        )

        if args.training_type != "full-finetune":
            logger.info("Finetuning transformer with PEFT parameters")
            self._disable_grad_for_components([self.transformer])
        elif args.qa_lora:
            logger.info("Finetuning transformer with QA-LoRA")
            self._enable_grad_for_components([self.transformer])
            for module_name, submodule in self.transformer.named_modules():
                assert (
                    not isinstance(submodule, nn.Linear) or "lora" in module_name
                ), "All `nn.Linear`s are required to be replaced with `QuantLoRALinear`."
                if isinstance(submodule, QuantLoRALinear):
                    submodule.w.requires_grad = False
        elif args.progress_training:
            logger.info("Finetuning transformer with Progressive LoRA")
            self._enable_grad_for_components([self.transformer])
            if not args.progress_fix:
                for submodule in self.transformer.modules():
                    if not isinstance(submodule, QuantLinear):
                        continue
                    submodule.w.requires_grad = False
                    weight_quantizer = submodule.wquantizer
                    weight_quantizer.scale.requires_grad = False
                    weight_quantizer.zero_point.requires_grad = False
        else:
            logger.info("Finetuning transformer with no additional parameters")
            self._enable_grad_for_components([self.transformer])
            if args.freeze_adaln:
                # NOTE(review): this freezes every remaining plain `nn.Linear`; presumably
                # only the AdaLN projections are still un-replaced here — confirm.
                for module_name, submodule in self.transformer.named_modules():
                    if not isinstance(submodule, nn.Linear):
                        continue
                    for param in submodule.parameters():
                        param.requires_grad = False
                    logger.info(f"Freeze: {module_name}")

        if args.use_teacher:
            logger.info("Disable teacher parameters")
            self._disable_grad_for_components([self.teacher])

        is_lora = args.training_type == "lora"

        # Layerwise upcasting has to happen before the LoRA adapter is added, and on CPU
        # before moving to device to avoid OOM on the GPU.
        if is_lora and "transformer" in args.layerwise_upcasting_modules:
            apply_layerwise_upcasting(
                self.transformer,
                storage_dtype=args.layerwise_upcasting_storage_dtype,
                compute_dtype=args.transformer_dtype,
                skip_modules_pattern=args.layerwise_upcasting_skip_modules_pattern,
                non_blocking=True,
            )

        self._move_components_to_device()

        if args.gradient_checkpointing:
            self.transformer.enable_gradient_checkpointing()

        transformer_lora_config = None
        if is_lora:
            transformer_lora_config = LoraConfig(
                r=args.rank,
                lora_alpha=args.lora_alpha,
                init_lora_weights=True,
                target_modules=args.target_modules,
            )
            self.transformer.add_adapter(transformer_lora_config)

        # TODO: assert that LoRA parameters remain fp32 even with layerwise upcasting.
        self.register_saving_loading_hooks(transformer_lora_config)

    def register_saving_loading_hooks(self, transformer_lora_config):
        """Register accelerate pre-hooks that control checkpoint serialization.

        Args:
            transformer_lora_config: The ``LoraConfig`` attached to the transformer for LoRA
                training, or ``None`` for full fine-tuning.

        Save hook (main process only):
            * LoRA runs store only the adapter weights, via the pipeline class's
              ``save_lora_weights``.
            * Full fine-tuning writes the transformer with ``save_pretrained`` to
              ``<output_dir>/transformer`` and the scheduler config to
              ``<output_dir>/scheduler``.
            * Model entries are popped from ``weights`` so accelerate does not serialize
              them again.

        Load hook:
            * Resolves the transformer to restore: outside DeepSpeed it is popped from
              ``models``; under DeepSpeed with LoRA it is rebuilt from the pretrained
              checkpoint.
            * For LoRA runs, loads the adapter weights written by the save hook into that
              transformer.
            * Full fine-tuning weights are not reloaded by this hook.

        Raises (from the hooks):
            ValueError: if a model of an unexpected type is passed, or if the full
                fine-tuning save hook receives no transformer.
        """

        def unwrap(model):
            return unwrap_model(self.state.accelerator, model)

        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            if not self.state.accelerator.is_main_process:
                return

            transformer_type = type(unwrap(self.transformer))
            transformer_to_save = None
            transformer_lora_layers_to_save = None
            for model in models:
                unwrapped = unwrap(model)
                if not isinstance(unwrapped, transformer_type):
                    raise ValueError(f"Unexpected save model: {model.__class__}")
                transformer_to_save = unwrapped
                if self.args.training_type == "lora":
                    transformer_lora_layers_to_save = get_peft_model_state_dict(unwrapped)

                # make sure to pop weight so that corresponding model is not saved again
                if weights:
                    weights.pop()

            if self.args.training_type == "lora":
                self.model_config["pipeline_cls"].save_lora_weights(
                    output_dir,
                    transformer_lora_layers=transformer_lora_layers_to_save,
                )
                return

            if transformer_to_save is None:
                raise ValueError("No transformer model was passed to the save hook.")
            transformer_to_save.save_pretrained(os.path.join(output_dir, "transformer"))

            # In some cases, the scheduler needs to be loaded with specific config (e.g. in CogVideoX). Since we need
            # to able to load all diffusion components from a specific checkpoint folder during validation, we need to
            # ensure the scheduler config is serialized as well.
            self.scheduler.save_pretrained(os.path.join(output_dir, "scheduler"))

        def load_model_hook(models, input_dir):
            transformer_ = None
            if self.state.accelerator.distributed_type != DistributedType.DEEPSPEED:
                # Popping the models stops accelerate from loading them itself.
                while models:
                    unwrapped = unwrap(models.pop())
                    if not isinstance(unwrapped, type(unwrap(self.transformer))):
                        raise ValueError(f"Unexpected save model: {unwrapped.__class__}")
                    transformer_ = unwrapped
            elif self.args.training_type == "lora":
                transformer_cls_ = unwrap(self.transformer).__class__
                transformer_ = transformer_cls_.from_pretrained(
                    self.args.pretrained_model_name_or_path, subfolder="transformer"
                )
                transformer_.add_adapter(transformer_lora_config)
            # NOTE: full fine-tuning weights are not restored here (loading from
            # `<input_dir>/transformer` is disabled); presumably resume relies on another
            # path — confirm.

            if self.args.training_type != "lora":
                return
            # The save hook kept the adapter weights out of accelerate's own checkpoint,
            # so they must be restored here in every distributed mode.
            if transformer_ is None:
                logger.warning(
                    "No transformer was provided to the load hook; LoRA weights were not restored."
                )
                return

            lora_state_dict = self.model_config["pipeline_cls"].lora_state_dict(input_dir)
            transformer_state_dict = {
                f'{k.replace("transformer.", "")}': v
                for k, v in lora_state_dict.items()
                if k.startswith("transformer.")
            }
            incompatible_keys = set_peft_model_state_dict(
                transformer_, transformer_state_dict, adapter_name="default"
            )
            if incompatible_keys is not None:
                # check only for unexpected keys
                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
                if unexpected_keys:
                    logger.warning(
                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                        f" {unexpected_keys}. "
                    )

        self.state.accelerator.register_save_state_pre_hook(save_model_hook)
        self.state.accelerator.register_load_state_pre_hook(load_model_hook)

    def prepare_optimizer(self) -> None:
        logger.info("Initializing optimizer and lr scheduler")

        self.state.train_epochs = self.args.train_epochs
        self.state.train_steps = self.args.train_steps

        # Make sure the trainable params are in float32
        if self.args.training_type == "lora":
            cast_training_params([self.transformer], dtype=torch.float32)

        self.state.learning_rate = self.args.lr
        if self.args.scale_lr:
            self.state.learning_rate = (
                self.state.learning_rate
                * self.args.gradient_accumulation_steps
                * self.args.batch_size
                * self.state.accelerator.num_processes
            )

        # transformer_trainable_parameters = list(filter(lambda p: p.requires_grad, self.transformer.parameters()))
        transformer_trainable_parameters = [
            p
            for n, p in self.transformer.named_parameters()
            if p.requires_grad and "quantizer" not in n
        ]
        if not (self.args.progress_training and self.args.progress_fix):
            w_quantization_trainable_parameters = [
                p
                for n, p in self.transformer.named_parameters()
                if p.requires_grad and "wquantizer" in n and ".ps" not in n
            ]
        else:
            w_quantization_trainable_parameters = [
                p
                for n, p in self.transformer.named_parameters()
                if p.requires_grad
                and "wquantizer" in n
                and ".ps" not in n
                and ("L" in n or "R" in n)
            ]
        act_quantization_trainable_parameters = [
            p
            for n, p in self.transformer.named_parameters()
            if p.requires_grad and "aquantizer" in n and ".ps" not in n
        ]
        round_quantization_trainable_parameters = [
            p
            for n, p in self.transformer.named_parameters()
            if p.requires_grad and "quantizer" in n and ".ps" in n
        ]
        p1, p2 = [], []
        if self.args.progress_training and not self.args.progress_fix:
            for n, p in self.transformer.named_parameters():
                if p.requires_grad == True and "aquantizer" not in n:
                    if "wquantizer" in n:
                        assert (
                            "L" in n or "R" in n
                        ), f"{n} is required to disable gradient."
                    else:
                        assert not n.endswith(
                            ".w"
                        ), f"{n} is required to disable gradient."
        if not self.args.qa_lora:
            params_to_optimize = [
                {
                    "params": transformer_trainable_parameters,
                    "lr": self.state.learning_rate,
                },
                {
                    "params": w_quantization_trainable_parameters,
                    "lr": self.args.w_lr
                    if not self.args.progress_training
                    else self.args.progress_lr,  # for L, R in progressive training
                },
            ]
            self.lr_dict["lr"] = 0
            self.lr_dict[
                "q_lr" if not self.args.progress_training else "progress_lr"
            ] = 0
        else:
            params_to_optimize = [
                {
                    "params": w_quantization_trainable_parameters,
                    "lr": self.args.w_lr
                    if not self.args.progress_training
                    else self.args.progress_lr,  # for L, R in progressive training
                },
            ]
            self.lr_dict[
                "q_lr" if not self.args.progress_training else "progress_lr"
            ] = 0
        if len(act_quantization_trainable_parameters) != 0:
            params_to_optimize.append(
                {
                    "params": act_quantization_trainable_parameters,
                    "lr": self.args.act_lr,
                }
            )
            self.lr_dict["q_act_lr"] = 0
            self.state.num_trainable_act_quantization_parameters = sum(
                p.numel() for p in act_quantization_trainable_parameters
            )
        if self.args.progress_training:
            if not self.args.progress_fix:
                p1 = [
                    p
                    for n, p in self.transformer.named_parameters()
                    if n.endswith(".w")
                ]
            p2 = [
                p
                for n, p in self.transformer.named_parameters()
                if (n.endswith(".scale") or n.endswith(".zero_point"))
                and "wquantizer" in n
            ]
            if not self.args.progress_fix:
                params_to_optimize.append(
                    {
                        "params": p1,
                        "lr": self.state.learning_rate,
                    }
                )
            params_to_optimize.append(
                {
                    "params": p2,
                    "lr": self.args.w_lr,
                }
            )
            if not self.args.progress_fix:
                self.lr_dict["lr_discard"] = 0
            self.lr_dict["q_lr"] = 0
            if not self.args.progress_fix:
                logger.info(f"Discard First: {len(p1)} | Quantizer: {len(p2)}")
            self.temp_p = p1 + p2
        if len(round_quantization_trainable_parameters) != 0:
            params_to_optimize.append(
                {
                    "params": round_quantization_trainable_parameters,
                    "lr": self.args.round_lr,
                }
            )
            self.state.num_trainable_round_quantization_parameters = sum(
                p.numel() for p in round_quantization_trainable_parameters
            )
            self.lr_dict["q_round_lr"] = 0
        if self.args.qa_lora:
            count = 0
            for n, m in self.transformer.named_modules():
                if isinstance(m, QuantLoRALinear):
                    avg_scale = torch.sum(m.wquantizer.scale) / torch.numel(
                        m.wquantizer.scale
                    )
                    if count < 20:
                        logger.info(
                            f"Avg. Scaling Factor: {avg_scale} | Layer: {n} | LR: {self.state.learning_rate}"
                        )
                        count += 1
                    elif count == 20:
                        logger.info(f"...")
                        count += 1
                    params_to_optimize.append(
                        {
                            "params": [m.loraA.weight, m.loraB.weight],
                            "lr": self.state.learning_rate * avg_scale,
                        }
                    )
        self.state.num_trainable_transformer_parameters = sum(
            p.numel() for p in transformer_trainable_parameters
        )
        self.state.num_trainable_w_quantization_parameters = sum(
            p.numel() for p in w_quantization_trainable_parameters
        )
        if self.args.progress_training and self.args.progress_fix:
            self.state.num_trainable_w_quantization_parameters += sum(
                p.numel() for p in p2
            )
        use_deepspeed_opt = (
            self.state.accelerator.state.deepspeed_plugin is not None
            and "optimizer"
            in self.state.accelerator.state.deepspeed_plugin.deepspeed_config
        )
        optimizer = get_optimizer(
            params_to_optimize=params_to_optimize,
            optimizer_name=self.args.optimizer,
            learning_rate=self.state.learning_rate,
            beta1=self.args.beta1,
            beta2=self.args.beta2,
            beta3=self.args.beta3,
            epsilon=self.args.epsilon,
            weight_decay=self.args.weight_decay,
            use_8bit=self.args.use_8bit_bnb,
            use_deepspeed=use_deepspeed_opt,
        )
        # logger.info(f"Optimizer: {optimizer}")

        num_update_steps_per_epoch = math.ceil(
            len(self.dataloader)
            / self.args.gradient_accumulation_steps
            / self.state.accelerator.num_processes
        )  # fix, here num_update_steps_per_epoch is not divided by num_processes
        if self.state.train_steps is None:
            self.state.train_steps = (
                self.state.train_epochs * num_update_steps_per_epoch
            )
            self.state.overwrote_max_train_steps = True

        use_deepspeed_lr_scheduler = (
            self.state.accelerator.state.deepspeed_plugin is not None
            and "scheduler"
            in self.state.accelerator.state.deepspeed_plugin.deepspeed_config
        )
        total_training_steps = (
            self.state.train_steps * self.state.accelerator.num_processes
        )
        num_warmup_steps = (
            self.args.lr_warmup_steps * self.state.accelerator.num_processes
        )

        if use_deepspeed_lr_scheduler:
            from accelerate.utils import DummyScheduler

            lr_scheduler = DummyScheduler(
                name=self.args.lr_scheduler,
                optimizer=optimizer,
                total_num_steps=total_training_steps,
                num_warmup_steps=num_warmup_steps,
            )
        else:
            lr_scheduler = get_scheduler(
                name=self.args.lr_scheduler,
                optimizer=optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=total_training_steps,
                num_cycles=self.args.lr_num_cycles,
                power=self.args.lr_power,
            )

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def prepare_for_training(self) -> None:
        """Wrap transformer, optimizer, dataloader and LR scheduler with Accelerate.

        When progressive training is enabled without ``progress_fix``, the
        parameters collected in ``self.temp_p`` are temporarily marked trainable
        while ``accelerator.prepare`` runs, because DeepSpeed pre-allocates memory
        for all parameters it is handed. They are reset to ``requires_grad=False``
        afterwards -- also when ``prepare`` raises, so a failed prepare does not
        leave them trainable.

        Because ``prepare`` may change the dataloader length (e.g. by sharding it
        across processes), the number of optimizer updates per epoch is recomputed
        from the prepared dataloader. If ``train_steps`` was derived from the
        epoch count (``state.overwrote_max_train_steps``) it is recomputed as well;
        ``train_epochs`` is then re-derived from ``train_steps``.

        Raises:
            ValueError: if the prepared dataloader yields no batches, which would
                otherwise surface as a division by zero.
        """
        toggle_progress_params = (
            self.args.progress_training and not self.args.progress_fix
        )
        if toggle_progress_params:
            # deepspeed will pre-allocate memory for all parameters
            for p in self.temp_p:
                p.requires_grad = True
        try:
            (
                self.transformer,
                self.optimizer,
                self.dataloader,
                self.lr_scheduler,
            ) = self.state.accelerator.prepare(
                self.transformer, self.optimizer, self.dataloader, self.lr_scheduler
            )
        finally:
            # Always undo the temporary toggle, even if prepare() failed.
            if toggle_progress_params:
                for p in self.temp_p:
                    p.requires_grad = False

        # We need to recalculate our total training steps as the size of the
        # training dataloader may have changed.
        num_update_steps_per_epoch = math.ceil(
            len(self.dataloader) / self.args.gradient_accumulation_steps
        )
        if num_update_steps_per_epoch == 0:
            raise ValueError(
                "The prepared dataloader is empty; cannot compute the number of "
                "update steps per epoch."
            )
        if self.state.overwrote_max_train_steps:
            self.state.train_steps = (
                self.state.train_epochs * num_update_steps_per_epoch
            )
        # Afterwards we recalculate our number of training epochs
        self.state.train_epochs = math.ceil(
            self.state.train_steps / num_update_steps_per_epoch
        )
        self.state.num_update_steps_per_epoch = num_update_steps_per_epoch

    def prepare_trackers(self) -> None:
        """Register the experiment trackers on the accelerator.

        The project is named after ``args.tracker_name``, falling back to
        ``"finetrainers-experiment"`` when that value is empty or unset. The
        dictionary returned by ``self._get_training_info()`` is passed as the
        tracker config. The wandb run is named ``args.name``.
        """
        logger.info("Initializing trackers")

        project_name = self.args.tracker_name
        if not project_name:
            project_name = "finetrainers-experiment"

        accelerator = self.state.accelerator
        run_config = self._get_training_info()
        wandb_options = {"name": self.args.name}
        accelerator.init_trackers(
            project_name,
            config=run_config,
            init_kwargs={"wandb": wandb_options},
        )

    def train(self) -> None:
        logger.info("Starting training")

        memory_statistics = get_memory_statistics()
        logger.info(
            f"Memory before training start: {json.dumps(memory_statistics, indent=4)}"
        )

        if self.vae_config is None:
            # If we've precomputed conditions and latents already, and are now re-using it, we will never load
            # the VAE so self.vae_config will not be set. So, we need to load it here.
            vae_cls = resolve_vae_cls_from_ckpt_path(
                self.args.pretrained_model_name_or_path,
                revision=self.args.revision,
                cache_dir=self.args.cache_dir,
            )
            vae_config = vae_cls.load_config(
                self.args.pretrained_model_name_or_path,
                subfolder="vae",
                revision=self.args.revision,
                cache_dir=self.args.cache_dir,
            )
            self.vae_config = FrozenDict(**vae_config)

        # In some cases, the scheduler needs to be loaded with specific config (e.g. in CogVideoX). Since we need
        # to able to load all diffusion components from a specific checkpoint folder during validation, we need to
        # ensure the scheduler config is serialized as well.
        if self.args.training_type == "full-finetune":
            self.scheduler.save_pretrained(
                os.path.join(self.args.output_dir, "scheduler")
            )

        self.state.train_batch_size = (
            self.args.batch_size
            * self.state.accelerator.num_processes
            * self.args.gradient_accumulation_steps
        )
        info = {
            "trainable transformer parameters": self.state.num_trainable_transformer_parameters,
            "trainable weight quantization parameters": self.state.num_trainable_w_quantization_parameters,
            "trainable activation quantization parameters": self.state.num_trainable_act_quantization_parameters,
            "trainable round quantization parameters (Learnable STE)": self.state.num_trainable_round_quantization_parameters,
            "total samples": len(self.dataset),
            "train epochs": self.state.train_epochs,
            "train steps": self.state.train_steps,
            "batches per device": self.args.batch_size,
            "total batches observed per epoch": len(self.dataloader),
            "train batch size": self.state.train_batch_size,
            "gradient accumulation steps": self.args.gradient_accumulation_steps,
        }
        logger.info(f"Training configuration: {json.dumps(info, indent=4)}")

        global_step = 0
        first_epoch = 0
        initial_global_step = 0

        # Potentially load in the weights and states from a previous save
        (
            resume_from_checkpoint_path,
            initial_global_step,
            global_step,
            first_epoch,
        ) = get_latest_ckpt_path_to_resume_from(
            resume_from_checkpoint=self.args.resume_from_checkpoint,
            num_update_steps_per_epoch=self.state.num_update_steps_per_epoch,
            output_dir=self.args.output_dir,
        )
        rk = self.args.progress_rank
        if resume_from_checkpoint_path:
            if self.args.progress_training and global_step > self.args.progress_warm_up:
                temp_step = (global_step - 1) - self.args.progress_warm_up
                reduced_num = temp_step // self.args.progress_alpha_T
                rk = self.args.progress_rank // (2**reduced_num)
                if rk < 1:
                    rk = 0
                transformer = unwrap_model(self.state.accelerator, self.transformer)
                if not self.args.progress_fix:
                    for n, m in transformer.named_modules():
                        if isinstance(m, QuantLinear):
                            # m.wquantizer.L.requires_grad = False
                            # m.wquantizer.R.requires_grad = False
                            m.wquantizer.scale.requires_grad = True
                            m.wquantizer.zero_point.requires_grad = True
                            m.w.requires_grad = True
                # for param_group in self.optimizer.param_groups:
                #     for param in param_group["params"]:
                #         if not param.requires_grad:
                #             del self.optimizer.state[param]
                prev_num_trainable_transformer_parameters = (
                    self.state.num_trainable_transformer_parameters
                )
                prev_trainable_w_quantization_parameters = (
                    self.state.num_trainable_w_quantization_parameters
                )
                self.state.num_trainable_transformer_parameters = sum(
                    p.numel()
                    for n, p in transformer.named_parameters()
                    if p.requires_grad and "quantizer" not in n
                )
                self.state.num_trainable_w_quantization_parameters = sum(
                    p.numel()
                    for n, p in transformer.named_parameters()
                    if p.requires_grad and "wquantizer" in n
                )
                info = {
                    "trainable transformer parameters": self.state.num_trainable_transformer_parameters,
                    "trainable weight quantization parameters": self.state.num_trainable_w_quantization_parameters,
                    "prev. trainable transformer parameters": prev_num_trainable_transformer_parameters,
                    "prev. trainable weight quantization parameters": prev_trainable_w_quantization_parameters,
                }
                logger.info(
                    f"Changed training configuration: {json.dumps(info, indent=4)}"
                )
                for n, m in transformer.named_modules():
                    if isinstance(m, QuantLinear):
                        m.wquantizer.set_alpha_L_R(self.args.progress_warm_up)
                progress_steps = (
                    reduced_num * self.args.progress_alpha_T
                    if rk != 0
                    else self.args.progress_alpha_T
                    * (math.log2(self.args.progress_rank) + 1)
                )
                if rk != self.args.progress_rank:
                    for n, m in transformer.named_modules():
                        if isinstance(m, QuantLinear):
                            m.wquantizer.rank = rk * 2 if rk != 0 else 1
                            m.wquantizer.set_alpha_L_R(
                                self.args.progress_warm_up + progress_steps
                            )
                logger.info(
                    f"Resume for Progressive Training after Warm-Up (last global step: {global_step - 1}, rank: {rk})"
                )
            self.state.accelerator.load_state(resume_from_checkpoint_path)
            free_memory()

        progress_bar = tqdm(
            range(0, self.state.train_steps),
            initial=initial_global_step,
            desc="Training steps",
            disable=not self.state.accelerator.is_local_main_process,
        )

        accelerator = self.state.accelerator
        generator = torch.Generator(device=accelerator.device)
        if self.args.seed is not None:
            generator = generator.manual_seed(self.args.seed)
        self.state.generator = generator

        scheduler_sigmas = get_scheduler_sigmas(self.scheduler)
        scheduler_sigmas = (
            scheduler_sigmas.to(device=accelerator.device, dtype=torch.float32)
            if scheduler_sigmas is not None
            else None
        )
        scheduler_alphas = get_scheduler_alphas(self.scheduler)
        scheduler_alphas = (
            scheduler_alphas.to(device=accelerator.device, dtype=torch.float32)
            if scheduler_alphas is not None
            else None
        )
        for epoch in range(first_epoch, self.state.train_epochs):
            logger.debug(f"Starting epoch ({epoch + 1}/{self.state.train_epochs})")

            self.transformer.train()
            models_to_accumulate = [self.transformer]
            epoch_loss = 0.0
            num_loss_updates = 0

            for step, batch in enumerate(self.dataloader):
                logger.debug(
                    f"Starting step {step + 1} (Global Step {global_step + 1})"
                )
                logs = {}
                if self.args.progress_training:
                    transformer = unwrap_model(accelerator, self.transformer)
                    if global_step == self.args.progress_warm_up:
                        if not self.args.progress_fix:
                            for n, m in transformer.named_modules():
                                if isinstance(m, QuantLinear):
                                    # m.wquantizer.L.requires_grad = False
                                    # m.wquantizer.R.requires_grad = False
                                    m.wquantizer.scale.requires_grad = True
                                    m.wquantizer.zero_point.requires_grad = True
                                    m.w.requires_grad = True
                        # for param_group in self.optimizer.param_groups:
                        #     for param in param_group["params"]:
                        #         if not param.requires_grad:
                        #             del self.optimizer.state[param]
                        prev_num_trainable_transformer_parameters = (
                            self.state.num_trainable_transformer_parameters
                        )
                        prev_trainable_w_quantization_parameters = (
                            self.state.num_trainable_w_quantization_parameters
                        )
                        self.state.num_trainable_transformer_parameters = sum(
                            p.numel()
                            for n, p in transformer.named_parameters()
                            if p.requires_grad and "quantizer" not in n
                        )
                        self.state.num_trainable_w_quantization_parameters = sum(
                            p.numel()
                            for n, p in transformer.named_parameters()
                            if p.requires_grad and "wquantizer" in n
                        )
                        info = {
                            "trainable transformer parameters": self.state.num_trainable_transformer_parameters,
                            "trainable weight quantization parameters": self.state.num_trainable_w_quantization_parameters,
                            "prev. trainable transformer parameters": prev_num_trainable_transformer_parameters,
                            "prev. trainable weight quantization parameters": prev_trainable_w_quantization_parameters,
                        }
                        logger.info(
                            f"Changed training configuration: {json.dumps(info, indent=4)}"
                        )
                    for n, m in transformer.named_modules():
                        if isinstance(m, QuantLinear):
                            m.wquantizer.set_alpha_L_R(global_step)
                    temp_step = global_step - self.args.progress_warm_up
                    reduced_num = temp_step // self.args.progress_alpha_T
                    total_reduced_num = math.log2(self.args.progress_rank) + 1
                    if global_step == self.args.progress_warm_up or (
                        temp_step % self.args.progress_alpha_T == 0
                        and 0 <= reduced_num <= total_reduced_num
                    ):
                        free_memory()
                        if global_step == self.args.progress_warm_up:
                            logger.info(f"LoRA Turning (i.e., warm-up) ended.")
                        if (
                            temp_step % self.args.progress_alpha_T == 0
                            and 0 <= reduced_num <= total_reduced_num
                        ):
                            rk = self.args.progress_rank // (2**reduced_num)
                            if reduced_num != total_reduced_num:
                                logger.info(
                                    f"Progressive Training with Rank={rk} started."
                                )
                            else:
                                rk = 0
                                logger.info(f"Final Training without LoRA.")

                with accelerator.accumulate(models_to_accumulate):
                    if not self.args.precompute_conditions:
                        videos = batch["videos"]
                        prompts = batch["prompts"]
                        batch_size = len(prompts)

                        if self.args.caption_dropout_technique == "empty":
                            if random.random() < self.args.caption_dropout_p:
                                prompts = [""] * batch_size

                        latent_conditions = self.model_config["prepare_latents"](
                            vae=self.vae,
                            image_or_video=videos,
                            patch_size=self.transformer_config.patch_size,
                            patch_size_t=self.transformer_config.patch_size_t
                            if hasattr(self.transformer_config, "patch_size_t")
                            else None,
                            device=accelerator.device,
                            dtype=self.args.transformer_dtype,
                            generator=self.state.generator,
                        )
                        text_conditions = self.model_config["prepare_conditions"](
                            tokenizer=self.tokenizer,
                            text_encoder=self.text_encoder,
                            tokenizer_2=self.tokenizer_2,
                            text_encoder_2=self.text_encoder_2,
                            prompt=prompts,
                            device=accelerator.device,
                            dtype=self.args.transformer_dtype,
                        )
                    else:
                        latent_conditions = batch["latent_conditions"]
                        text_conditions = batch["text_conditions"]
                        latent_conditions["latents"] = DiagonalGaussianDistribution(
                            latent_conditions["latents"]
                        ).sample(self.state.generator)

                        # This method should only be called for precomputed latents.
                        #  rename this in separate PR
                        latent_conditions = self.model_config[
                            "post_latent_preparation"
                        ](
                            vae_config=self.vae_config,
                            patch_size=self.transformer_config.patch_size,
                            patch_size_t=self.transformer_config.patch_size_t
                            if hasattr(self.transformer_config, "patch_size_t")
                            else None,
                            **latent_conditions,
                        )
                        align_device_and_dtype(
                            latent_conditions,
                            accelerator.device,
                            self.args.transformer_dtype,
                        )
                        align_device_and_dtype(
                            text_conditions,
                            accelerator.device,
                            self.args.transformer_dtype,
                        )
                        batch_size = latent_conditions["latents"].shape[0]

                    latent_conditions = make_contiguous(latent_conditions)
                    text_conditions = make_contiguous(text_conditions)

                    if self.args.caption_dropout_technique == "zero":
                        if random.random() < self.args.caption_dropout_p:
                            text_conditions["prompt_embeds"].fill_(0)
                            text_conditions["prompt_attention_mask"].fill_(False)

                            #  refactor later
                            if "pooled_prompt_embeds" in text_conditions:
                                text_conditions["pooled_prompt_embeds"].fill_(0)

                    sigmas = prepare_sigmas(
                        scheduler=self.scheduler,
                        sigmas=scheduler_sigmas,
                        batch_size=batch_size,
                        num_train_timesteps=self.scheduler.config.num_train_timesteps,
                        flow_weighting_scheme=self.args.flow_weighting_scheme,
                        flow_logit_mean=self.args.flow_logit_mean,
                        flow_logit_std=self.args.flow_logit_std,
                        flow_mode_scale=self.args.flow_mode_scale,
                        device=accelerator.device,
                        generator=self.state.generator,
                    )
                    timesteps = (sigmas * 1000.0).long()
                    if self.args.act_quantizer == "learnable_clipped_dynamic":
                        transformer = unwrap_model(accelerator, self.transformer)
                        for n, m in transformer.named_modules():
                            if isinstance(m, (QuantLinear, QuantLoRALinear)):
                                m.aquantizer.set_idx(timesteps)
                    if self.args.round_func == "LearnableAdaSTE":
                        transformer = unwrap_model(accelerator, self.transformer)
                        for n, m in transformer.named_modules():
                            if isinstance(m, (QuantLinear, QuantLoRALinear)):
                                m.wquantizer.set_idx_ste(timesteps)
                                m.aquantizer.set_idx_ste(timesteps)

                    noise = torch.randn(
                        latent_conditions["latents"].shape,
                        generator=self.state.generator,
                        device=accelerator.device,
                        dtype=self.args.transformer_dtype,
                    )
                    sigmas = expand_tensor_dims(sigmas, ndim=noise.ndim)

                    #  We probably don't need calculate_noisy_latents because we can determine the type of
                    # scheduler and calculate the noisy latents accordingly. Look into this later.
                    if "calculate_noisy_latents" in self.model_config.keys():
                        noisy_latents = self.model_config["calculate_noisy_latents"](
                            scheduler=self.scheduler,
                            noise=noise,
                            latents=latent_conditions["latents"],
                            timesteps=timesteps,
                        )
                    else:
                        # Default to flow-matching noise addition
                        noisy_latents = (1.0 - sigmas) * latent_conditions[
                            "latents"
                        ] + sigmas * noise
                        noisy_latents = noisy_latents.to(
                            latent_conditions["latents"].dtype
                        )

                    latent_conditions.update({"noisy_latents": noisy_latents})

                    weights = prepare_loss_weights(
                        scheduler=self.scheduler,
                        alphas=scheduler_alphas[timesteps]
                        if scheduler_alphas is not None
                        else None,
                        sigmas=sigmas,
                        flow_weighting_scheme=self.args.flow_weighting_scheme,
                    )
                    weights = expand_tensor_dims(weights, noise.ndim)

                    pred = self.model_config["forward_pass"](
                        transformer=self.transformer,
                        scheduler=self.scheduler,
                        timesteps=timesteps,
                        **latent_conditions,
                        **text_conditions,
                    )
                    if not self.args.use_teacher:
                        target = prepare_target(
                            scheduler=self.scheduler,
                            noise=noise,
                            latents=latent_conditions["latents"],
                        )
                    else:
                        target = self.model_config["forward_pass"](
                            transformer=self.teacher,
                            scheduler=self.scheduler,
                            timesteps=timesteps,
                            **latent_conditions,
                            **text_conditions,
                        )
                        target = target["latents"]

                    loss = weights.float() * (
                        pred["latents"].float() - target.float()
                    ).pow(2)
                    # Average loss across all but batch dimension
                    loss = loss.mean(list(range(1, loss.ndim)))
                    # Average loss across batch dimension
                    loss = loss.mean()
                    accelerator.backward(loss)

                    if accelerator.sync_gradients:
                        if accelerator.distributed_type == DistributedType.DEEPSPEED:
                            grad_norm = self.transformer.get_global_grad_norm()
                            # In some cases the grad norm may not return a float
                            if torch.is_tensor(grad_norm):
                                grad_norm = grad_norm.item()
                        else:
                            grad_norm = accelerator.clip_grad_norm_(
                                self.transformer.parameters(), self.args.max_grad_norm
                            )
                            if torch.is_tensor(grad_norm):
                                grad_norm = grad_norm.item()

                        logs["grad_norm"] = grad_norm

                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()

                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1

                    # Checkpointing
                    if (
                        accelerator.distributed_type == DistributedType.DEEPSPEED
                        or accelerator.is_main_process
                    ):
                        if global_step % self.args.checkpointing_steps == 0:
                            save_path = get_intermediate_ckpt_path(
                                checkpointing_limit=self.args.checkpointing_limit,
                                step=global_step,
                                output_dir=self.args.output_dir,
                            )
                            accelerator.save_state(save_path)
                            free_memory()

                    # Maybe run validation
                    should_run_validation = (
                        self.args.validation_every_n_steps is not None
                        and global_step % self.args.validation_every_n_steps == 0
                    ) or global_step >= self.state.train_steps
                    if should_run_validation:
                        self.validate(global_step)
                        if self.args.progress_training and rk != 0:
                            self.validate(global_step, progress=True, rank=rk)

                loss_item = loss.detach().item()
                epoch_loss += loss_item
                num_loss_updates += 1
                logs["step_loss"] = loss_item
                for i, (k, v) in enumerate(self.lr_dict.items()):
                    if "discard" not in k:
                        logs[k] = self.lr_scheduler.get_last_lr()[i]
                logs["timesteps"] = float(timesteps[0].item())
                progress_bar.set_postfix(logs)
                accelerator.log(logs, step=global_step)

                if global_step >= self.state.train_steps:
                    break

            if num_loss_updates > 0:
                epoch_loss /= num_loss_updates
            accelerator.log({"epoch_loss": epoch_loss}, step=global_step)
            memory_statistics = get_memory_statistics()
            logger.info(
                f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}"
            )

            # Maybe run validation
            should_run_validation = (
                self.args.validation_every_n_epochs is not None
                and (epoch + 1) % self.args.validation_every_n_epochs == 0
            )
            if should_run_validation:
                self.validate(global_step)
                if self.args.progress_training and rk != 0:
                    self.validate(global_step, progress=True, rank=rk)

        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            transformer = unwrap_model(accelerator, self.transformer)

            if self.args.training_type == "lora":
                transformer_lora_layers = get_peft_model_state_dict(transformer)

                self.model_config["pipeline_cls"].save_lora_weights(
                    save_directory=self.args.output_dir,
                    transformer_lora_layers=transformer_lora_layers,
                )
            else:
                transformer.save_pretrained(
                    os.path.join(self.args.output_dir, "transformer")
                )
        accelerator.wait_for_everyone()
        self.validate(step=global_step, final_validation=True)
        if self.args.progress_training and rk != 0:
            self.validate(global_step, progress=True, final_validation=True, rank=rk)

        if accelerator.is_main_process:
            if self.args.push_to_hub:
                upload_folder(
                    repo_id=self.state.repo_id,
                    folder_path=self.args.output_dir,
                    ignore_patterns=["checkpoint-*"],
                )

        self._delete_components()
        memory_statistics = get_memory_statistics()
        logger.info(
            f"Memory after training end: {json.dumps(memory_statistics, indent=4)}"
        )

        accelerator.end_training()

    def validate(
        self,
        step: int,
        final_validation: bool = False,
        progress: bool = False,
        rank: int = 32,
    ) -> None:
        """Run the validation pipeline on this process's share of prompts and log artifacts.

        Validation samples are sharded round-robin: sample ``i`` runs only on process
        ``i % num_processes``. Each process saves the input conditioning image/video and
        the generated artifacts under ``args.output_dir``. The resulting wandb objects
        are gathered onto the main process and logged under a ``"validation"`` or
        ``"final"`` key, prefixed with ``progress-no-rank{rank}-`` in progress mode.

        If there are no validation prompts, a warning is logged. The main process then
        writes a model card without videos, and the method returns.

        Args:
            step: Global step, used in filenames and as the tracker step.
            final_validation: Build the pipeline via the final-validation path of
                ``_get_and_prepare_pipeline_for_validation`` and prefix files with
                ``final-``. With ``push_to_hub``, also write the model card.
            progress: Switch every ``QuantLinear``/``QuantLoRALinear`` weight quantizer
                into validation mode for the duration of each pipeline call.
            rank: Only used to tag filenames and tracker keys in progress mode.
        """
        logger.info("Starting validation")

        accelerator = self.state.accelerator
        num_validation_samples = len(self.args.validation_prompts)

        if num_validation_samples == 0:
            logger.warning("No validation samples found. Skipping validation.")
            if accelerator.is_main_process:
                save_model_card(
                    args=self.args,
                    repo_id=self.state.repo_id,
                    videos=None,
                    validation_prompts=None,
                )
            return

        # A previous final validation has already dropped the training components
        # (`_get_and_prepare_pipeline_for_validation` calls `_delete_components`), so
        # the transformer can be None here, e.g. for the progress pass that follows
        # the final validation.
        if self.transformer is not None:
            self.transformer.eval()

        memory_statistics = get_memory_statistics()
        logger.info(
            f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}"
        )

        pipeline = self._get_and_prepare_pipeline_for_validation(
            final_validation=final_validation
        )

        all_processes_artifacts = []
        prompts_to_filenames = {}
        for i in range(num_validation_samples):
            # Round-robin sharding: skip samples that belong to another process.
            if i % accelerator.num_processes != accelerator.process_index:
                continue

            prompt = self.args.validation_prompts[i]
            image = self.args.validation_images[i]
            video = self.args.validation_videos[i]
            height = self.args.validation_heights[i]
            width = self.args.validation_widths[i]
            num_frames = self.args.validation_num_frames[i]
            frame_rate = self.args.validation_frame_rate
            if image is not None:
                image = load_image(image)
            if video is not None:
                video = load_video(video)

            logger.debug(
                f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
                main_process_only=False,
            )
            if self.args.act_quantizer == "learnable_clipped_dynamic":
                for _, m in pipeline.transformer.named_modules():
                    if isinstance(m, (QuantLinear, QuantLoRALinear)):
                        m.aquantizer.set_auto(True)
            if progress:
                for _, m in pipeline.transformer.named_modules():
                    if isinstance(m, (QuantLinear, QuantLoRALinear)):
                        m.wquantizer.validation(True)
            try:
                validation_artifacts = self.model_config["validation"](
                    pipeline=pipeline,
                    prompt=prompt,
                    image=image,
                    video=video,
                    height=height,
                    width=width,
                    num_frames=num_frames,
                    frame_rate=frame_rate,
                    num_videos_per_prompt=self.args.num_validation_videos_per_prompt,
                    generator=torch.Generator(device=accelerator.device).manual_seed(
                        self.args.seed if self.args.seed is not None else 0
                    ),
                    # todo support passing `fps` for supported pipelines.
                )
            finally:
                # Restore the weight quantizers even if the pipeline call raised. In
                # non-final validation these modules belong to the training transformer.
                if progress:
                    for _, m in pipeline.transformer.named_modules():
                        if isinstance(m, (QuantLinear, QuantLoRALinear)):
                            m.wquantizer.validation(False)

            prompt_filename = string_to_filename(prompt)[:25]
            # The input conditioning (if any) is saved alongside the generated outputs.
            artifacts = {
                "image": {"type": "image", "value": image},
                "video": {"type": "video", "value": video},
            }
            # Use a distinct index name so the sample index `i` is not shadowed.
            for artifact_idx, (artifact_type, artifact_value) in enumerate(
                validation_artifacts
            ):
                if artifact_value:
                    artifacts[f"artifact_{artifact_idx}"] = {
                        "type": artifact_type,
                        "value": artifact_value,
                    }
            logger.debug(
                f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
                main_process_only=False,
            )

            for index, value in enumerate(artifacts.values()):
                artifact_type = value["type"]
                artifact_value = value["value"]
                if artifact_type not in ["image", "video"] or artifact_value is None:
                    continue

                extension = "png" if artifact_type == "image" else "mp4"
                filename = "validation-" if not final_validation else "final-"
                if progress:
                    filename = f"progress-no-rank{rank}-" + filename
                filename += f"{step}-{accelerator.process_index}-{index}-{prompt_filename}.{extension}"
                # Only the main process records filenames for the model card.
                if accelerator.is_main_process and extension == "mp4":
                    prompts_to_filenames[prompt] = filename
                filename = os.path.join(self.args.output_dir, filename)

                if artifact_type == "image" and artifact_value:
                    logger.debug(f"Saving image to {filename}")
                    artifact_value.save(filename)
                    artifact_value = wandb.Image(filename)
                elif artifact_type == "video" and artifact_value:
                    logger.debug(f"Saving video to {filename}")
                    # TODO: this should be configurable here as well as in validation runs where we call the pipeline that has `fps`.
                    export_to_video(artifact_value, filename, fps=frame_rate)
                    artifact_value = wandb.Video(filename, caption=prompt)

                all_processes_artifacts.append(artifact_value)

        all_artifacts = gather_object(all_processes_artifacts)

        if accelerator.is_main_process:
            tracker_key = "final" if final_validation else "validation"
            if progress:
                tracker_key = f"progress-no-rank{rank}-" + tracker_key
            for tracker in accelerator.trackers:
                if tracker.name == "wandb":
                    artifact_log_dict = {}

                    image_artifacts = [
                        artifact
                        for artifact in all_artifacts
                        if isinstance(artifact, wandb.Image)
                    ]
                    if len(image_artifacts) > 0:
                        artifact_log_dict["images"] = image_artifacts
                    video_artifacts = [
                        artifact
                        for artifact in all_artifacts
                        if isinstance(artifact, wandb.Video)
                    ]
                    if len(video_artifacts) > 0:
                        artifact_log_dict["videos"] = video_artifacts
                    tracker.log({tracker_key: artifact_log_dict}, step=step)

            if self.args.push_to_hub and final_validation:
                video_filenames = list(prompts_to_filenames.values())
                prompts = list(prompts_to_filenames.keys())
                save_model_card(
                    args=self.args,
                    repo_id=self.state.repo_id,
                    videos=video_filenames,
                    validation_prompts=prompts,
                )

        # Remove all hooks that might have been added during pipeline initialization to the models
        pipeline.remove_all_hooks()
        del pipeline

        accelerator.wait_for_everyone()

        free_memory()
        memory_statistics = get_memory_statistics()
        logger.info(
            f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}"
        )
        # Peak-memory stats are a CUDA-only API that expects a CUDA device.
        if accelerator.device.type == "cuda":
            torch.cuda.reset_peak_memory_stats(accelerator.device)

        if not final_validation and self.transformer is not None:
            self.transformer.train()

    def evaluate(self) -> None:
        """Placeholder for a standalone evaluation loop.

        Raises:
            NotImplementedError: always; evaluation is not supported yet.
        """
        message = "Evaluation has not been implemented yet."
        raise NotImplementedError(message)

    def _init_distributed(self) -> None:
        """Create the ``Accelerator``, store it on ``self.state``, and seed RNGs if requested.

        DDP is configured with ``find_unused_parameters=True``. The process group uses
        the NCCL backend with a timeout of ``args.nccl_timeout`` seconds. Tracking is
        disabled when ``args.report_to`` equals ``"none"`` (case-insensitive).
        """
        args = self.args
        tracker_target = None if args.report_to.lower() == "none" else args.report_to

        handlers = [
            DistributedDataParallelKwargs(find_unused_parameters=True),
            InitProcessGroupKwargs(
                backend="nccl", timeout=timedelta(seconds=args.nccl_timeout)
            ),
        ]
        accelerator = Accelerator(
            project_config=ProjectConfiguration(
                project_dir=args.output_dir,
                logging_dir=Path(args.output_dir, args.logging_dir),
            ),
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            log_with=tracker_target,
            kwargs_handlers=handlers,
        )

        # Native AMP is switched off whenever an MPS backend is available.
        if torch.backends.mps.is_available():
            accelerator.native_amp = False

        self.state.accelerator = accelerator

        seed = args.seed
        if seed is None:
            return
        self.state.seed = seed
        set_seed(seed)

    def _init_logging(self) -> None:
        """Configure root logging and per-library verbosity for this process.

        On the local main process, ``transformers`` logs at WARNING and ``diffusers``
        at INFO. On every other process, both libraries are reduced to ERROR.
        """
        logging.basicConfig(
            level=FINETRAINERS_LOG_LEVEL,
            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
        )

        accelerator = self.state.accelerator
        if not accelerator.is_local_main_process:
            transformers.utils.logging.set_verbosity_error()
            diffusers.utils.logging.set_verbosity_error()
        else:
            transformers.utils.logging.set_verbosity_warning()
            diffusers.utils.logging.set_verbosity_info()

        logger.info("Initialized FineTrainers")
        logger.info(accelerator.state, main_process_only=False)

    def _init_directories_and_repositories(self) -> None:
        if self.state.accelerator.is_main_process:
            self.args.output_dir = Path(self.args.output_dir)
            self.args.output_dir.mkdir(parents=True, exist_ok=True)
            self.state.output_dir = Path(self.args.output_dir)

            if self.args.push_to_hub:
                repo_id = self.args.hub_model_id or Path(self.args.output_dir).name
                self.state.repo_id = create_repo(
                    token=self.args.hub_token, repo_id=repo_id, exist_ok=True
                ).repo_id

    def _init_config_options(self) -> None:
        # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
        if self.args.allow_tf32 and torch.cuda.is_available():
            torch.backends.cuda.matmul.allow_tf32 = True

    def _move_components_to_device(self):
        if self.text_encoder is not None:
            self.text_encoder = self.text_encoder.to(self.state.accelerator.device)
        if self.text_encoder_2 is not None:
            self.text_encoder_2 = self.text_encoder_2.to(self.state.accelerator.device)
        if self.text_encoder_3 is not None:
            self.text_encoder_3 = self.text_encoder_3.to(self.state.accelerator.device)
        if self.transformer is not None:
            self.transformer = self.transformer.to(self.state.accelerator.device)
        if self.unet is not None:
            self.unet = self.unet.to(self.state.accelerator.device)
        if self.vae is not None:
            self.vae = self.vae.to(self.state.accelerator.device)
        if self.teacher is not None:
            self.teacher = self.teacher.to(self.state.accelerator.device)

    def _get_load_components_kwargs(self) -> Dict[str, Any]:
        load_component_kwargs = {
            "text_encoder_dtype": self.args.text_encoder_dtype,
            "text_encoder_2_dtype": self.args.text_encoder_2_dtype,
            "text_encoder_3_dtype": self.args.text_encoder_3_dtype,
            "transformer_dtype": self.args.transformer_dtype,
            "vae_dtype": self.args.vae_dtype,
            "shift": self.args.flow_shift,
            "revision": self.args.revision,
            "cache_dir": self.args.cache_dir,
        }
        if self.args.pretrained_model_name_or_path is not None:
            load_component_kwargs["model_id"] = self.args.pretrained_model_name_or_path
        return load_component_kwargs

    def _set_components(self, components: Dict[str, Any]) -> None:
        # Set models
        self.tokenizer = components.get("tokenizer", self.tokenizer)
        self.tokenizer_2 = components.get("tokenizer_2", self.tokenizer_2)
        self.tokenizer_3 = components.get("tokenizer_3", self.tokenizer_3)
        self.text_encoder = components.get("text_encoder", self.text_encoder)
        self.text_encoder_2 = components.get("text_encoder_2", self.text_encoder_2)
        self.text_encoder_3 = components.get("text_encoder_3", self.text_encoder_3)
        self.transformer = components.get("transformer", self.transformer)
        self.unet = components.get("unet", self.unet)
        self.vae = components.get("vae", self.vae)
        self.scheduler = components.get("scheduler", self.scheduler)

        # Set configs
        self.transformer_config = (
            self.transformer.config
            if self.transformer is not None
            else self.transformer_config
        )
        self.vae_config = self.vae.config if self.vae is not None else self.vae_config

    def _delete_components(self) -> None:
        """Drop references to the tokenizers, encoders, denoiser, VAE and scheduler.

        This calls ``free_memory()`` and, on CUDA devices, synchronizes the device.
        ``torch.cuda.synchronize`` expects a CUDA device, so it is skipped on CPU/MPS
        instead of raising.

        NOTE(review): ``self.teacher`` is moved to the device in
        ``_move_components_to_device`` but is not cleared here. Confirm whether it
        should be released as well.
        """
        self.tokenizer = None
        self.tokenizer_2 = None
        self.tokenizer_3 = None
        self.text_encoder = None
        self.text_encoder_2 = None
        self.text_encoder_3 = None
        self.transformer = None
        self.unet = None
        self.vae = None
        self.scheduler = None
        free_memory()

        device = self.state.accelerator.device
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    def _get_and_prepare_pipeline_for_validation(
        self, final_validation: bool = False
    ) -> DiffusionPipeline:
        """Build the diffusion pipeline used by ``validate``.

        Intermediate validation (``final_validation=False``) reuses the in-memory
        training components, with the transformer unwrapped from the accelerator.

        Final validation first drops those components via ``_delete_components``. It
        then builds a pipeline from ``args.pretrained_model_name_or_path`` with
        ``transformer=None``. For LoRA training, the weights saved in ``args.output_dir``
        are loaded onto that pipeline.

        NOTE(review): for full-finetune runs, ``transformer=None`` is passed as well,
        so the final pipeline appears to use the pretrained transformer rather than
        the one saved under ``output_dir/transformer``. Confirm this is intended.
        """
        accelerator = self.state.accelerator
        shared_kwargs = dict(
            model_id=self.args.pretrained_model_name_or_path,
            device=accelerator.device,
            revision=self.args.revision,
            cache_dir=self.args.cache_dir,
            enable_slicing=self.args.enable_slicing,
            enable_tiling=self.args.enable_tiling,
            enable_model_cpu_offload=self.args.enable_model_cpu_offload,
        )

        if final_validation:
            self._delete_components()
            pipeline = self.model_config["initialize_pipeline"](
                transformer=None,
                is_training=False,
                **shared_kwargs,
            )
            if self.args.training_type == "lora":
                pipeline.load_lora_weights(self.args.output_dir)
            return pipeline

        return self.model_config["initialize_pipeline"](
            tokenizer=self.tokenizer,
            text_encoder=self.text_encoder,
            tokenizer_2=self.tokenizer_2,
            text_encoder_2=self.text_encoder_2,
            transformer=unwrap_model(accelerator, self.transformer),
            vae=self.vae,
            is_training=True,
            **shared_kwargs,
        )

    def _disable_grad_for_components(self, components: List[torch.nn.Module]):
        for component in components:
            if component is not None:
                component.requires_grad_(False)

    def _enable_grad_for_components(self, components: List[torch.nn.Module]):
        for component in components:
            if component is not None:
                component.requires_grad_(True)

    def _get_training_info(self) -> dict:
        args = self.args.to_dict()

        training_args = args.get("training_arguments", {})
        training_type = training_args.get("training_type", "")

        # LoRA/non-LoRA stuff.
        if training_type == "full-finetune":
            filtered_training_args = {
                k: v
                for k, v in training_args.items()
                if k not in {"rank", "lora_alpha", "target_modules"}
            }
        else:
            filtered_training_args = training_args

        # Diffusion/flow stuff.
        diffusion_args = args.get("diffusion_arguments", {})
        scheduler_name = self.scheduler.__class__.__name__
        if scheduler_name != "FlowMatchEulerDiscreteScheduler":
            filtered_diffusion_args = {
                k: v for k, v in diffusion_args.items() if "flow" not in k
            }
        else:
            filtered_diffusion_args = diffusion_args

        # Rest of the stuff.
        updated_training_info = args.copy()
        updated_training_info["training_arguments"] = filtered_training_args
        updated_training_info["diffusion_arguments"] = filtered_diffusion_args
        return updated_training_info
