import os
import math
import wandb
import random
import logging
import inspect
import argparse
import datetime
import subprocess

from pathlib import Path
from tqdm.auto import tqdm
from einops import rearrange
from omegaconf import OmegaConf
from safetensors import safe_open
from typing import Dict, Optional, Tuple, List
import numpy as np
import re
from PIL import Image

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.distributed as dist
from torch.optim.swa_utils import AveragedModel
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

import diffusers
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.models import UNet2DConditionModel
from diffusers.pipelines import StableDiffusionPipeline
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available

import transformers
from transformers import CLIPTextModel, CLIPTokenizer

from animatediff.data.dataset import WebVid10M, HAAVideo
from animatediff.models.unet import UNet3DConditionModel
from animatediff.models.sparse_controlnet import SparseControlNetModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import save_videos_grid, zero_rank_print
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora_unet_text_encoder

from animatediff.models.additional import LatentRectify, cross_frame_corr, MaskPredictor
from animatediff.models import global_utils
from animatediff.models.vae_decoder import DetailAutoencoderKL

from animatediff.data.data_distort_utils import negative_videos, change_app_video_v1, change_app_video_v2, change_app_video_v3, change_app_video_v5, change_app_video_v6, change_app_video_v7, change_app_video_v8, change_app_video_v9, change_app_video_v10, change_app_video_v11


def _bind_local_cuda_device(global_rank):
    """Select this process's GPU as ``global_rank % num_visible_gpus`` and make it current.

    Returns:
        ``(num_gpus, local_rank)``: the visible CUDA device count and the chosen index.

    Raises:
        RuntimeError: if no CUDA device is visible. Without this check the modulo
            below would fail with an uninformative ``ZeroDivisionError``.
    """
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        raise RuntimeError(
            "init_dist requires at least one visible CUDA device, "
            "but torch.cuda.device_count() returned 0."
        )
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)
    return num_gpus, local_rank


def init_dist(launcher="slurm", backend='nccl', port=29500, **kwargs):
    """Initializes distributed environment.

    Binds the current process to a CUDA device and calls
    ``torch.distributed.init_process_group``.

    Args:
        launcher: ``'pytorch'`` or ``'slurm'``.
            - ``'pytorch'``: reads the global rank from ``$RANK``. The remaining
              rendezvous variables are presumably exported by the launcher
              (e.g. torchrun); this function does not set them.
            - ``'slurm'``: reads ``$SLURM_PROCID``, ``$SLURM_NTASKS`` and
              ``$SLURM_NODELIST``. It then exports ``MASTER_ADDR`` (the first
              host reported by ``scontrol show hostname``), ``WORLD_SIZE``,
              ``RANK`` and ``MASTER_PORT``.
        backend: process-group backend passed to ``init_process_group``.
        port: default master port for the slurm launcher; ``$PORT`` overrides it.
        **kwargs: extra keyword arguments forwarded to ``init_process_group``
            (now for both launchers; previously the slurm path ignored them).

    Returns:
        The local CUDA device index selected for this process
        (``global_rank % torch.cuda.device_count()``).

    Raises:
        KeyError: if a required environment variable is missing.
        RuntimeError: if no CUDA device is visible.
        NotImplementedError: for any other ``launcher`` value.
    """
    if launcher == 'pytorch':
        rank = int(os.environ['RANK'])
        _, local_rank = _bind_local_cuda_device(rank)
        dist.init_process_group(backend=backend, **kwargs)

    elif launcher == 'slurm':
        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        num_gpus, local_rank = _bind_local_cuda_device(proc_id)
        # The first host in the allocation acts as the rendezvous master.
        addr = subprocess.getoutput(
            f'scontrol show hostname {node_list} | head -n1')
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        port = os.environ.get('PORT', port)
        os.environ['MASTER_PORT'] = str(port)
        dist.init_process_group(backend=backend, **kwargs)
        zero_rank_print(f"proc_id: {proc_id}; local_rank: {local_rank}; ntasks: {ntasks}; node_list: {node_list}; num_gpus: {num_gpus}; addr: {addr}; port: {port}")

    else:
        raise NotImplementedError(f'Not implemented launcher type: `{launcher}`!')

    return local_rank



def main(
    image_finetune: bool,
    
    name: str,
    use_wandb: bool,
    launcher: str,
    
    output_dir: str,
    pretrained_model_path: str,

    train_data: Dict,
    validation_data: Dict,

    cfg_random_null_text: bool = False,
    cfg_random_null_text_ratio: float = 0.1,

    cfg_negative_video: float = 0.0,
    cfg_app_change_video: float = 0.0,
    cfg_app_change_video_version: str = "1",
    cfg_app_change_video_type: str = "replace",
    cfg_loss_scaling_app_change_video: bool = False,
    cfg_loss_scaling_app_change_video_weight: float = 0.5,
    cfg_same_noise_app_change_video: bool = False,
    cfg_mix_video_source: str = "",
    cfg_diff_loss_weight: float = 0.0,
    cfg_mask_loss_factor: float = 0.0,
    motion_importance_sampling: bool = False,
    motion_importance_sampling_alpha: float = 0.5,
    latent_rectify: bool = False,
    latent_rectify_dim: int = 256,
    latent_rectify_loss_weight: float = 1.0,
    latent_rectify_add_static: bool = False,
    latent_rectify_add_shuffle: bool = False,
    latent_rectify_scale: float = 0.0,
    latent_rectify_clamp: float = 1.0,
    cfg_latent_corr_loss_weight: float = 0.0,
    cfg_latent_corr_loss_global: bool = False,
    cfg_feature_ranking_loss_weight: float = 0.0,
    loss_on_down_block_idx: Optional[List[int]] = None,
    loss_on_up_block_idx: Optional[List[int]] = None,
    cfg_feature_ranking_loss_sampling: bool = False,
    cfg_on_predictor_mask_loss_weight: float = 0.0,
    cfg_on_generator_mask_loss_weight: float = 0.0,
    mask_predictor_checkpoint_path: str = "",
    cfg_transformation_consistency_loss_weight: float = 0.0,
    cfg_transformation_smoothness_loss_weight: float = 0.0,
    cfg_noise_gt_first_frame: bool = False,
    cfg_unnoise_first_frame: bool = False,
    cfg_input_perturbation: Optional[List[float]] = None,
    cfg_loss_aug_strength: bool = False,
    cfg_loss_input_perturbation: bool = False,
    
    noise_scheduler_kwargs = None,
    unet_additional_kwargs: Dict = {},

    detail_vae_checkpoint_path: str = "",
    
    unet_checkpoint_path: str = "",
    pretrained_image_adapter_path: str = "",
    pretrained_image_adapter_alpha: float = 1.0,
    image_adapter_lora_path: str = "",
    image_adapter_lora_alpha: float = 1.0,

    pretrained_motion_module_path: str = "",

    controlnet_config_file: str = "",
    controlnet_checkpoint_path: str = "",
    controlnet_not_load_image_params: bool = False,
    train_controlnet: bool = False,
    trainable_controlnet_modules: Tuple[str] = (),
    
    max_train_epoch: int = -1,
    max_train_steps: int = 100,
    validation_steps: int = 100,
    validation_steps_tuple: Tuple = (-1,),

    learning_rate: float = 3e-5,
    scale_lr: bool = False,
    lr_warmup_steps: int = 0,
    lr_scheduler: str = "constant",

    trainable_modules: Tuple[str] = (None, ),

    num_workers: int = 16,
    train_batch_size: int = 1,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_weight_decay: float = 1e-2,
    adam_epsilon: float = 1e-08,
    max_grad_norm: float = 1.0,
    gradient_accumulation_steps: int = 1,
    gradient_checkpointing: bool = True,
    checkpointing_epochs: int = 5,
    checkpointing_steps: int = -1,
    checkpointing_steps_tuple: Tuple = (-1,),

    mixed_precision_training: bool = True,
    enable_xformers_memory_efficient_attention: bool = True,

    global_seed: int = 42,
    is_debug: bool = False,

    dataset: str = 'WebVid',
):
    check_min_version("0.10.0.dev0")

    # Initialize distributed training
    local_rank      = init_dist(launcher=launcher)
    global_rank     = dist.get_rank()
    num_processes   = dist.get_world_size()
    is_main_process = global_rank == 0

    seed = global_seed + global_rank
    torch.manual_seed(seed)
    
    # Logging folder
    folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H-%M-%S")
    output_dir = os.path.join(output_dir, folder_name)
    if is_debug and os.path.exists(output_dir):
        os.system(f"rm -rf {output_dir}")

    *_, config = inspect.getargvalues(inspect.currentframe())

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    if is_main_process and (not is_debug) and use_wandb:
        run = wandb.init(project="animatediff", name=folder_name, config=config)

    # Handle the output folder creation
    if is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/samples", exist_ok=True)
        os.makedirs(f"{output_dir}/sanity_check", exist_ok=True)
        OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))

    # Load scheduler, tokenizer and models.
    noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))

    if detail_vae_checkpoint_path == "":
        vae          = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
    else:
        vae = DetailAutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae", low_cpu_mem_usage=False) # low_cpu_mem_usage=False since it has new parameters
    vae.enable_slicing()
    tokenizer    = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
    if not image_finetune:
        unet = UNet3DConditionModel.from_pretrained_2d(
            pretrained_model_path, subfolder="unet", 
            unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs)
        )
    else:
        unet = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")

    # Load vae weights
    if detail_vae_checkpoint_path != "":
        zero_rank_print(f"Load pretrained vae from checkpoint: {detail_vae_checkpoint_path}")
        vae_checkpoint = torch.load(detail_vae_checkpoint_path, map_location="cpu")
        state_dict = vae_checkpoint["state_dict"] if "state_dict" in vae_checkpoint else vae_checkpoint
        extra_param_state_dict = {}
        for key in list(state_dict.keys()):
            if 'ref_attn_blocks' in key or 'channel_matching_convs' in key:
                extra_param_state_dict[key.replace('module.', '')] = state_dict.pop(key)
        m, u = vae.decoder.load_state_dict(extra_param_state_dict, strict=False)
        zero_rank_print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
        assert len(u) == 0

    # Load pretrained image adapter weights
    if pretrained_image_adapter_path != "":
        zero_rank_print(f"Load pretrained image adapter from checkpoint: {pretrained_image_adapter_path}")
        domain_lora_state_dict = torch.load(pretrained_image_adapter_path, map_location="cpu")
        domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict
        for name, param in domain_lora_state_dict.copy().items():
            if name.startswith("module."):
                domain_lora_state_dict.pop(name)
                name = name.replace("module.", "")
                domain_lora_state_dict[name] = param

        for key in domain_lora_state_dict:
            if "up." in key: continue
            up_key = key.replace(".down.", ".up.")
            model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "").replace("to_out.", "to_out.0.")
            layer_infos = model_key.split(".")[:-1]
            curr_layer = unet
            while len(layer_infos) > 0:
                temp_name = layer_infos.pop(0)
                curr_layer = curr_layer.__getattr__(temp_name)
            weight_down = domain_lora_state_dict[key]
            weight_up   = domain_lora_state_dict[up_key]
            curr_layer.weight.data += pretrained_image_adapter_alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)

    # Load image adapter lora weights
    if image_adapter_lora_path != "":
        zero_rank_print(f"Load lora model from {image_adapter_lora_path}")
        assert image_adapter_lora_path.endswith(".safetensors")
        lora_state_dict = {}
        with safe_open(image_adapter_lora_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                lora_state_dict[key] = f.get_tensor(key)
        unet, text_encoder = convert_lora_unet_text_encoder(unet, text_encoder, lora_state_dict, alpha=image_adapter_lora_alpha)

    # Load pretrained motion module weights
    if pretrained_motion_module_path != "":
        unet_motion_state_dict = {}
        zero_rank_print(f"Load pretrained motion module from {pretrained_motion_module_path}")
        motion_module_state_dict = torch.load(pretrained_motion_module_path, map_location="cpu")
        motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
        unet_motion_state_dict.update({name.replace('module.',''): param for name, param in motion_module_state_dict.items() if "motion_modules." in name})
        m, u = unet.load_state_dict(unet_motion_state_dict, strict=False)
        zero_rank_print(f"{len(u)} parameters cannot be loaded. {len(m)} parameters are missing.")
        del unet_motion_state_dict

    # Load pretrained unet weights
    if unet_checkpoint_path != "":
        zero_rank_print(f"Load pretrained unet from checkpoint: {unet_checkpoint_path}")
        unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu")
        state_dict = unet_checkpoint_path["state_dict"] if "state_dict" in unet_checkpoint_path else unet_checkpoint_path
        for key in list(state_dict.keys()):
            state_dict[key.replace('module.', '')] = state_dict.pop(key)
        m, u = unet.load_state_dict(state_dict, strict=False)
        zero_rank_print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
        assert len(u) == 0 and len(m) == 0
        global_step = unet_checkpoint_path["global_step"]
        first_epoch = unet_checkpoint_path["epoch"]
    else:
        global_step = 0
        first_epoch = 0

    ### copy fix parameters in unet
    num = 0
    for k in unet.state_dict().keys():
        if 'copy' in k:
            num += 1
            copy_k = k.replace('motion_modules', 'attentions').replace('temporal_transformer.transformer_blocks', 'transformer_blocks').replace('attention_blocks.0.motion_predictor.visual_copy_trans', 'attn2.to_q').replace('attention_blocks.0.motion_predictor.textual_copy_trans', 'attn2.to_k')
            unet.state_dict()[k].data = unet.state_dict()[copy_k].data
            num -= 1
    zero_rank_print(f"{num} parameters can not copy.")

    # Controlnet
    if controlnet_config_file != "":
        unet.config.num_attention_heads = 8
        unet.config.projection_class_embeddings_input_dim = None
        controlnet_config = OmegaConf.load(controlnet_config_file)
        controlnet = SparseControlNetModel.from_unet(unet, controlnet_additional_kwargs=controlnet_config.get("controlnet_additional_kwargs", {}))
        if controlnet_checkpoint_path != "":
            controlnet_state_dict = torch.load(controlnet_checkpoint_path, map_location="cpu")
            controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict
            controlnet_state_dict.pop("animatediff_config", "")
            if controlnet_not_load_image_params:
                new_controlnet_state_dict = {}
                for name, param in controlnet_state_dict.items():
                    if "motion_modules." in name or "conv_in" in name or "time_proj" in name or "time_embedding" in name or "controlnet" in name:
                        new_controlnet_state_dict[name] = param
                zero_rank_print(f"Load {len(new_controlnet_state_dict)} pretrained controlnet params")
                m, u = controlnet.load_state_dict(new_controlnet_state_dict, strict=False)
                zero_rank_print(f"{len(u)} parameters cannot be loaded. {len(m)} parameters are missing.")
                del new_controlnet_state_dict
            else:
                zero_rank_print(f"Load {len(controlnet_state_dict)} pretrained controlnet params")
                m, u = controlnet.load_state_dict(controlnet_state_dict, strict=False)
                zero_rank_print(f"{len(u)} parameters cannot be loaded. {len(m)} parameters are missing.")
                del controlnet_state_dict
        num = 0
        for k in controlnet.state_dict().keys():
            if 'copy' in k:
                num += 1
                copy_k = k.replace('motion_modules', 'attentions').replace('temporal_transformer.transformer_blocks', 'transformer_blocks').replace('attention_blocks.0.motion_predictor.visual_copy_trans', 'attn2.to_q').replace('attention_blocks.0.motion_predictor.textual_copy_trans', 'attn2.to_k')
                controlnet.state_dict()[k].data = controlnet.state_dict()[copy_k].data
                num -= 1
        zero_rank_print(f"{num} parameters can not copy.")

    if latent_rectify:
        latent_rectify_module = LatentRectify(inner_dim=latent_rectify_dim)

    if cfg_on_predictor_mask_loss_weight > 0.0 or cfg_on_generator_mask_loss_weight > 0.0:
        mask_predictor = MaskPredictor(4,1)
        if cfg_on_generator_mask_loss_weight > 0.0:
            assert mask_predictor_checkpoint_path != ""
            mask_predictor_state_dict = torch.load(mask_predictor_checkpoint_path, map_location="cpu")["state_dict"]
            new_mask_predictor_state_dict = {}
            new_mask_predictor_state_dict.update({name.replace('module.',''): param for name, param in mask_predictor_state_dict.items()})
            mask_predictor.load_state_dict(new_mask_predictor_state_dict)
        
    # Freeze vae, text_encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    
    # Set unet trainable parameters
    unet.requires_grad_(False)
    for name, param in unet.named_parameters():
        for trainable_module_name in trainable_modules:
            if trainable_module_name in name:
                param.requires_grad = True
                break
    trainable_params = list(filter(lambda p: p.requires_grad, unet.parameters()))
    train_unet = len(trainable_params) > 0
    # Set controlnet trainable parameters
    if controlnet_config_file != "":
        if train_controlnet:
            if len(trainable_controlnet_modules) > 0:
                controlnet.requires_grad_(False)
                for name, param in controlnet.named_parameters():
                    for trainable_module_name in trainable_controlnet_modules:
                        if trainable_module_name in name:
                            param.requires_grad = True
                            break
            trainable_params += list(filter(lambda p: p.requires_grad, controlnet.parameters()))
        else:
            controlnet.requires_grad_(False)

    if latent_rectify:
        trainable_params += list(filter(lambda p: p.requires_grad, latent_rectify_module.parameters()))

    if cfg_on_generator_mask_loss_weight > 0.0:
        mask_predictor.requires_grad_(False)
    if cfg_on_predictor_mask_loss_weight > 0.0:
        trainable_params += list(filter(lambda p: p.requires_grad, mask_predictor.parameters()))

    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )

    if is_main_process:
        zero_rank_print(f"trainable params number: {len(trainable_params)}")
        zero_rank_print(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")

    ### delete due to different versions of diffusers and torch
    # https://huggingface.co/docs/diffusers/en/using-diffusers/img2img
    # # Enable xformers
    # if enable_xformers_memory_efficient_attention:
    #     if is_xformers_available():
    #         unet.enable_xformers_memory_efficient_attention()
    #     else:
    #         raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Enable gradient checkpointing
    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if train_controlnet:
            controlnet.enable_gradient_checkpointing()

    # Move models to GPU
    vae.to(local_rank)
    text_encoder.to(local_rank)
    unet.to(local_rank)
    if controlnet_config_file != "":
        controlnet.to(local_rank)
    if latent_rectify:
        latent_rectify_module.to(local_rank)
    if cfg_on_predictor_mask_loss_weight > 0.0 or cfg_on_generator_mask_loss_weight > 0.0:
        mask_predictor.to(local_rank)

    # Get the training dataset
    if dataset == 'WebVid':
        train_dataset = WebVid10M(**train_data, is_image=image_finetune)
    elif dataset == 'HAA':
        train_dataset = HAAVideo(**train_data)
    else:
        assert False, "Not Implemented."
    distributed_sampler = DistributedSampler(
        train_dataset,
        num_replicas=num_processes,
        rank=global_rank,
        shuffle=True,
        seed=global_seed,
    )

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=False,
        sampler=distributed_sampler,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )

    # Get the training iteration
    if max_train_steps == -1:
        assert max_train_epoch != -1
        max_train_steps = max_train_epoch * len(train_dataloader)
        
    if checkpointing_steps == -1:
        assert checkpointing_epochs != -1
        checkpointing_steps = checkpointing_epochs * len(train_dataloader)

    if scale_lr:
        learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size * num_processes)

    # Scheduler
    lr_scheduler = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
    )

    # Validation pipeline
    if not image_finetune:
        validation_pipeline = AnimationPipeline(
            unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, controlnet=controlnet if controlnet_config_file != "" else None, latent_rectify_module=latent_rectify_module if latent_rectify_scale > 0.0 else None,
        ).to("cuda")
    else:
        validation_pipeline = StableDiffusionPipeline.from_pretrained(
            pretrained_model_path,
            unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, safety_checker=None,
        )

    # DDP warpper
    if train_unet:
        unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)
    if controlnet_config_file != "" and train_controlnet:
        controlnet = DDP(controlnet, device_ids=[local_rank], output_device=local_rank)
    if latent_rectify:
        latent_rectify_module = DDP(latent_rectify_module, device_ids=[local_rank], output_device=local_rank)
    if cfg_on_predictor_mask_loss_weight > 0.0:
        mask_predictor = DDP(mask_predictor, device_ids=[local_rank], output_device=local_rank)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps

    if is_main_process:
        logging.info("***** Running training *****")
        logging.info(f"  Num examples = {len(train_dataset)}")
        logging.info(f"  Num Epochs = {num_train_epochs}")
        logging.info(f"  Instantaneous batch size per device = {train_batch_size}")
        logging.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logging.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
        logging.info(f"  Total optimization steps = {max_train_steps}")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(total=max_train_steps, initial=global_step, disable=not is_main_process)
    progress_bar.set_description("Steps")

    # Support mixed-precision training
    scaler = torch.cuda.amp.GradScaler() if mixed_precision_training else None

    # load candidates to mix video
    if cfg_app_change_video > 0.0 and cfg_app_change_video_version == "5":
        mix_images = os.listdir(cfg_mix_video_source)
        assert len(mix_images) > 0

    if cfg_loss_aug_strength or cfg_loss_input_perturbation:
        sample_loss_dict = {}
        if cfg_loss_aug_strength:
            assert cfg_app_change_video_version in ["6", "7", "8", "9"]

    for epoch in range(first_epoch, num_train_epochs):
        train_dataloader.sampler.set_epoch(epoch)
        unet.train()
        if train_controlnet:
            controlnet.train()
        
        for step, batch in enumerate(train_dataloader):
            if 'global_random' in unet_additional_kwargs['unet_use_cross_frame_attention']:
                global_utils.update_global_random()
            if cfg_random_null_text:
                for ii in range(len(batch['text'])):
                    random_num = random.random()
                    if random_num <= cfg_random_null_text_ratio:
                        batch['text'][ii] = ""

            # allow training with appearance-changed videos
            if cfg_app_change_video > 0.0:
                batch['pixel_values'] = batch['pixel_values'].to(local_rank)
                if cfg_app_change_video_version == "1":
                    batch['pixel_values'], batch['text'] = change_app_video_v1(batch['pixel_values'], batch['text'], cfg_app_change_video_type, cfg_app_change_video)
                elif cfg_app_change_video_version == "2":
                    batch['hed_values'] = batch['hed_values'].to(local_rank)
                    batch['pixel_values'], batch['text'] = change_app_video_v2(batch['pixel_values'], batch['hed_values'], batch['text'], cfg_app_change_video_type, cfg_app_change_video)
                elif cfg_app_change_video_version == "3":
                    batch['hed_values'] = batch['hed_values'].to(local_rank)
                    batch['pixel_values'], batch['text'] = change_app_video_v3(batch['pixel_values'], batch['hed_values'], batch['text'], cfg_app_change_video_type, cfg_app_change_video)
                elif cfg_app_change_video_version == "5":
                    batch['pixel_values'], batch['text'] = change_app_video_v5(batch['pixel_values'], batch['text'], cfg_app_change_video_type, cfg_app_change_video, cfg_mix_video_source, mix_images)
                elif cfg_app_change_video_version == "6":
                    batch['pixel_values'], batch['text'] = change_app_video_v6(batch['pixel_values'], batch['text'], cfg_app_change_video_type, cfg_app_change_video, batch['video_id'], sample_loss_dict)
                elif cfg_app_change_video_version == "7":
                    batch['pixel_values'], batch['text'] = change_app_video_v7(batch['pixel_values'], batch['text'], cfg_app_change_video_type, cfg_app_change_video, batch['video_id'], sample_loss_dict)
                elif cfg_app_change_video_version == "8":
                    batch['pixel_values'], batch['text'] = change_app_video_v8(batch['pixel_values'], batch['text'], cfg_app_change_video_type, cfg_app_change_video, batch['video_id'], sample_loss_dict)
                elif cfg_app_change_video_version == "9":
                    batch['pixel_values'], batch['text'] = change_app_video_v9(batch['pixel_values'], batch['text'], cfg_app_change_video_type, cfg_app_change_video, batch['video_id'], sample_loss_dict)
                elif cfg_app_change_video_version == "10":
                    batch['pixel_values'], batch['text'] = change_app_video_v10(batch['pixel_values'], batch['text'], cfg_app_change_video_type, cfg_app_change_video)
                elif cfg_app_change_video_version == "11":
                    batch['pixel_values'], batch['text'] = change_app_video_v11(batch['pixel_values'], batch['text'], cfg_app_change_video_type, cfg_app_change_video)
                else:
                    assert False, "Not Implemented."
                
            # allow training with negative videos
            if cfg_negative_video > 0.0:
                batch['pixel_values'], batch['text'] = negative_videos(batch['pixel_values'], batch['text'], cfg_negative_video)
                
            # Data batch sanity check
            assert batch['pixel_values'].shape[1] == train_data.sample_n_frames
            if epoch == first_epoch and step == 0:
                pixel_values, texts = batch['pixel_values'].cpu(), batch['text']
                if not image_finetune:
                    pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
                    for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
                        pixel_value = pixel_value[None, ...]
                        save_videos_grid(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.gif", rescale=True)
                else:
                    for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
                        pixel_value = pixel_value / 2. + 0.5
                        torchvision.utils.save_image(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.png")
                    
            ### >>>> Training >>>> ###
            
            # Convert videos to latent space            
            pixel_values = batch["pixel_values"].to(local_rank)
            if cfg_feature_ranking_loss_weight > 0.0:
                global_utils.reset_feature_ranking_loss()
                if "depth_values" in batch.keys():
                    global_utils.update_ranking_reference(pixel_values, batch["depth_values"].to(local_rank))
                else:
                    global_utils.update_ranking_reference(pixel_values, None)
            video_length = pixel_values.shape[1]
            with torch.no_grad():
                if not image_finetune:
                    pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
                    latents = vae.encode(pixel_values).latent_dist.sample() * 0.18215
                    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
                    if cfg_mask_loss_factor > 0:
                        latent_mask = torch.nn.functional.sigmoid(cfg_mask_loss_factor * torch.norm(latents - torch.cat([latents[:,:,0,:,:].unsqueeze(2), latents[:,:,:-1,:,:]], dim=2), dim=1, keepdim=True))
                        latent_mask = latent_mask / torch.mean(latent_mask)
                else:
                    latents = vae.encode(pixel_values).latent_dist.sample() * 0.18215

                if controlnet_config_file != "":
                    controlnet_images = latents[:,:,0,:,:].clone().unsqueeze(2)

                if cfg_latent_corr_loss_weight > 0.0:
                    attn_blocks = unet_additional_kwargs['motion_module_kwargs']['attention_block_types']
                    if controlnet_config_file != "":
                        attn_blocks.extend(controlnet_config.get("controlnet_additional_kwargs", {})['motion_module_kwargs']['attention_block_types'])
                    corr_mean_filtering = []
                    for blk in attn_blocks:
                        if 'CORR' in blk:
                            corr_mean_filtering.append('Mean' in blk)
                    if len(corr_mean_filtering) > 0:
                        assert all(elem == corr_mean_filtering[0] for elem in corr_mean_filtering)
                        corr_mean_filtering = corr_mean_filtering[0]
                    else:
                        corr_mean_filtering = False
                    corr_local_cal = []
                    for blk in attn_blocks:
                        if 'CORR' in blk:
                            corr_local_cal.append('Local' in blk)
                    if len(corr_local_cal) > 0:
                        assert all(elem == corr_local_cal[0] for elem in corr_local_cal)
                        corr_local_cal = corr_local_cal[0]
                    else:
                        corr_local_cal = False
                    corr_normalize = []
                    for blk in attn_blocks:
                        if 'CORR' in blk:
                            corr_normalize.append('Norm' in blk)
                    if len(corr_normalize) > 0:
                        assert all(elem == corr_normalize[0] for elem in corr_normalize)
                        corr_normalize = corr_normalize[0]
                    else:
                        corr_normalize = False
                    corr_intra_ref = []
                    for blk in attn_blocks:
                        if 'CORR' in blk:
                            corr_intra_ref.append('REFIntra' in blk or 'REFOnlyIntra' in blk)
                    if len(corr_intra_ref) > 0:
                        assert all(elem == corr_intra_ref[0] for elem in corr_intra_ref)
                        corr_intra_ref = corr_intra_ref[0]
                    else:
                        corr_intra_ref = False
                    corr_adj_ref = []
                    for blk in attn_blocks:
                        if 'CORR' in blk:
                            corr_adj_ref.append('REFFirst' not in blk and 'REFOnlyIntra' not in blk)
                    if len(corr_adj_ref) > 0:
                        assert all(elem == corr_adj_ref[0] for elem in corr_adj_ref)
                        corr_adj_ref = corr_adj_ref[0]
                    else:
                        corr_adj_ref = False
                    corr_first_ref = []
                    for blk in attn_blocks:
                        if 'CORR' in blk:
                            corr_first_ref.append('REFFirst' in blk or 'REFBoth' in blk)
                    if len(corr_first_ref) > 0:
                        assert all(elem == corr_first_ref[0] for elem in corr_first_ref)
                        corr_first_ref = corr_first_ref[0]
                    else:
                        corr_first_ref = False
                    corr_masking = []
                    for blk in attn_blocks:
                        if 'CORR' in blk:
                            corr_masking.append('MASK' in blk)
                    if len(corr_masking) > 0:
                        assert all(elem == corr_masking[0] for elem in corr_masking)
                        corr_masking = corr_masking[0]
                    else:
                        corr_masking = False
                    global_utils.reset_video_corr_loss()
                    if corr_masking:
                        global_utils.update_masking()
                    global_utils.update_video_corr(cross_frame_corr(apply_mean_filtering=corr_mean_filtering, apply_local_calculation=corr_local_cal, apply_intra_frame_ref=corr_intra_ref, apply_adjacent_frame_ref=corr_adj_ref, apply_first_frame_ref=corr_first_ref, apply_masking=corr_masking).calculate_corr(latents))

            # Reset transformation loss
            global_utils.reset_transformation_consistency_loss()
            global_utils.reset_transformation_smoothness_loss()
            # Sample noise that we'll add to the latents
            if cfg_same_noise_app_change_video:
                tmp_latents = torch.chunk(latents,2,0)
                tmp_noise = torch.randn_like(tmp_latents[0])
                noise = torch.cat([tmp_noise]*2, dim=0)
            else:
                noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            
            # Sample a random timestep for each video
            if motion_importance_sampling:
                prob_dist = [1 / noise_scheduler.config.num_train_timesteps * (1 - motion_importance_sampling_alpha * math.cos(math.pi * t / noise_scheduler.config.num_train_timesteps)) for t in range(noise_scheduler.config.num_train_timesteps)]
                prob_sum = 0
                for pp in prob_dist:
                    prob_sum += pp
                prob_dist = [x / prob_sum for x in prob_dist]
                timesteps = np.random.choice(
                    list(range(noise_scheduler.config.num_train_timesteps)),
                    size=bsz,
                    replace=True,
                    p=prob_dist)
                timesteps = torch.tensor(timesteps, device=latents.device)
            else:
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()
            
            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            ### input perturbation
            if cfg_input_perturbation is not None:
                if len(cfg_input_perturbation) == 1:
                    new_noise = cfg_input_perturbation[0] * torch.randn_like(noise)
                elif len(cfg_input_perturbation) == 2:
                    new_noise = []
                    for ii in range(noise.shape[2]):
                        new_noise.append(random.uniform(cfg_input_perturbation[0], cfg_input_perturbation[1]) * torch.randn_like(noise[:,:,ii,:,:]))
                    new_noise = torch.stack(new_noise, dim=2)
                else:
                    assert False
                noisy_latents = noise_scheduler.add_noise(latents, noise + new_noise, timesteps)
            elif cfg_loss_input_perturbation:
                all_sample_loss = []
                for sample_key in sample_loss_dict.keys():
                    all_sample_loss.append(sample_loss_dict[sample_key])
                sample_loss_33 = np.percentile(all_sample_loss, 33) if len(all_sample_loss) > 0 else 0
                new_noise = []
                assert noise.shape[0] % len(batch['video_id']) == 0
                for ii in range(noise.shape[0]):
                    video_id_ind = ii % len(batch['video_id'])
                    if batch['video_id'][video_id_ind] in sample_loss_dict.keys() and sample_loss_dict[batch['video_id'][video_id_ind]] < sample_loss_33:
                        new_noise.append(0.005 * torch.randn_like(noise[ii,:,:,:,:]))
                    else:
                        new_noise.append(torch.zeros_like(noise[ii,:,:,:,:]))
                new_noise = torch.stack(new_noise, dim=0)
                noisy_latents = noise_scheduler.add_noise(latents, noise + new_noise, timesteps)
            else:
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            if cfg_unnoise_first_frame:
                noisy_latents[:,:,0,:,:] = latents[:,:,0,:,:]
                noise[:,:,0,:,:] = 0
            
            with torch.no_grad():
                # Get the text embedding for conditioning
                prompt_ids = tokenizer(
                    batch['text'], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
                ).input_ids.to(latents.device)
                encoder_hidden_states = text_encoder(prompt_ids)[0]

                # controlnet configs
                if controlnet_config_file != "":
                    controlnet_image_index = [0]

                    controlnet_cond_shape    = list(controlnet_images.shape)
                    controlnet_cond_shape[2] = video_length
                    controlnet_cond = torch.zeros(controlnet_cond_shape).to(latents.device)

                    controlnet_conditioning_mask_shape    = list(controlnet_cond.shape)
                    controlnet_conditioning_mask_shape[1] = 1
                    controlnet_conditioning_mask          = torch.zeros(controlnet_conditioning_mask_shape).to(latents.device)

                    assert controlnet_images.shape[2] >= len(controlnet_image_index)
                    controlnet_cond[:,:,controlnet_image_index] = controlnet_images[:,:,:len(controlnet_image_index)]
                    if "mask" not in batch.keys() and "rel_pixel_values" not in batch.keys():
                        controlnet_conditioning_mask[:,:,controlnet_image_index] = 1
                    elif "mask" in batch.keys():
                        if cfg_on_predictor_mask_loss_weight > 0.0 or cfg_on_generator_mask_loss_weight > 0.0:
                            tmp_mask = rearrange(batch["mask"].to(controlnet_conditioning_mask.device), "b f c h w -> (b f) c h w")
                            tmp_mask = F.interpolate(tmp_mask,(controlnet_conditioning_mask.shape[-2],controlnet_conditioning_mask.shape[-1]))
                            controlnet_conditioning_mask = rearrange(tmp_mask, "(b f) c h w -> b c f h w", f=video_length)
                        else:
                            controlnet_conditioning_mask[:,:,controlnet_image_index] = 1
                    elif "rel_pixel_values" in batch.keys():
                        tmp_mask = rearrange(batch["rel_pixel_values"].to(controlnet_conditioning_mask.device), "b f c h w -> (b f) c h w")
                        tmp_mask = F.interpolate(tmp_mask,(controlnet_conditioning_mask.shape[-2],controlnet_conditioning_mask.shape[-1]))
                        controlnet_conditioning_mask = rearrange(tmp_mask, "(b f) c h w -> b c f h w", f=video_length)
                    else:
                        assert False

                    if not train_controlnet:
                        down_block_additional_residuals, mid_block_additional_residual = controlnet(
                            noisy_latents, timesteps,
                            encoder_hidden_states=encoder_hidden_states,
                            controlnet_cond=controlnet_cond,
                            conditioning_mask=controlnet_conditioning_mask,
                            conditioning_scale=1.0,
                            guess_mode=False, return_dict=False
                        )
            if latent_rectify:
                latent_rectify_loss = latent_rectify_module(noisy_latents, encoder_hidden_states, timesteps, include_static=latent_rectify_add_static, include_shuffle=latent_rectify_add_shuffle)[0]
            if controlnet_config_file != "" and train_controlnet:
                down_block_additional_residuals, mid_block_additional_residual = controlnet(
                    noisy_latents, timesteps,
                    encoder_hidden_states=encoder_hidden_states,
                    controlnet_cond=controlnet_cond,
                    conditioning_mask=controlnet_conditioning_mask,
                    conditioning_scale=1.0,
                    guess_mode=False, return_dict=False
                    )
                
            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                raise NotImplementedError
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            # Predict the noise residual and compute loss
            # Mixed-precision training
            with torch.cuda.amp.autocast(enabled=mixed_precision_training):
                if controlnet_config_file != "":
                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, down_block_additional_residuals=down_block_additional_residuals, mid_block_additional_residual=mid_block_additional_residual, loss_on_down_block_idx=loss_on_down_block_idx if cfg_feature_ranking_loss_weight > 0 else None, loss_on_up_block_idx=loss_on_up_block_idx if cfg_feature_ranking_loss_weight > 0 else None, loss_sampling=cfg_feature_ranking_loss_weight > 0 and cfg_feature_ranking_loss_sampling).sample
                else:
                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, loss_on_down_block_idx=loss_on_down_block_idx if cfg_feature_ranking_loss_weight > 0 else None, loss_on_up_block_idx=loss_on_up_block_idx if cfg_feature_ranking_loss_weight > 0 else None, loss_sampling=cfg_feature_ranking_loss_weight > 0 and cfg_feature_ranking_loss_sampling).sample
                if cfg_mask_loss_factor > 0:
                    loss = (latent_mask * F.mse_loss(model_pred.float(), target.float(), reduction="none")).mean()
                elif cfg_loss_scaling_app_change_video:
                    tmp_loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                    tmp_loss = tmp_loss.reshape((tmp_loss.shape[0], -1))
                    tmp_loss = torch.chunk(tmp_loss,2)
                    loss = (1 - cfg_loss_scaling_app_change_video_weight) * tmp_loss[0].mean() + cfg_loss_scaling_app_change_video_weight * tmp_loss[1].mean()
                elif cfg_loss_aug_strength or cfg_loss_input_perturbation:
                    tmp_loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                    tmp_loss = tmp_loss.reshape((tmp_loss.shape[0], -1)).mean(-1)
                    assert tmp_loss.shape[0] == bsz
                    if len(batch['video_id']) == bsz:
                        for tmp_loss_ind in range(bsz):
                            if batch['video_id'][tmp_loss_ind] not in sample_loss_dict.keys():
                                sample_loss_dict[batch['video_id'][tmp_loss_ind]] = tmp_loss[tmp_loss_ind].item()
                            else:
                                sample_loss_dict[batch['video_id'][tmp_loss_ind]] = 0.5 * sample_loss_dict[batch['video_id'][tmp_loss_ind]] + 0.5 * tmp_loss[tmp_loss_ind].item()
                    else:
                        assert bsz % len(batch['video_id']) == 0
                        tmp_loss = tmp_loss.reshape((-1, len(batch['video_id'])))
                        assert tmp_loss.shape[0] == bsz // len(batch['video_id'])
                        tmp_loss = tmp_loss.mean(0)
                        for tmp_loss_ind in range(len(batch['video_id'])):
                            if batch['video_id'][tmp_loss_ind] not in sample_loss_dict.keys():
                                sample_loss_dict[batch['video_id'][tmp_loss_ind]] = tmp_loss[tmp_loss_ind].item()
                            else:
                                sample_loss_dict[batch['video_id'][tmp_loss_ind]] = 0.5 * sample_loss_dict[batch['video_id'][tmp_loss_ind]] + 0.5 * tmp_loss[tmp_loss_ind].item()
                    loss = tmp_loss.mean()
                else:
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                if cfg_diff_loss_weight > 0:
                    model_pred_diff = model_pred[:,:,1:,:,:] - model_pred[:,:,:-1,:,:]
                    target_diff = target[:,:,1:,:,:] - target[:,:,:-1,:,:]
                    loss += cfg_diff_loss_weight * F.mse_loss(model_pred_diff.float(), target_diff.float(), reduction="mean")
                if latent_rectify:
                    loss += latent_rectify_loss_weight * latent_rectify_loss
                if cfg_latent_corr_loss_weight > 0.0:
                    if cfg_latent_corr_loss_global:
                        noise_scheduler.set_timesteps(noise_scheduler.config.num_train_timesteps)
                        model_pred_video = noise_scheduler.step(model_pred, timesteps, noisy_latents).prev_sample
                        model_pred_video_corr = cross_frame_corr(apply_mean_filtering=corr_mean_filtering, apply_local_calculation=corr_local_cal, apply_intra_frame_ref=corr_intra_ref, apply_adjacent_frame_ref=corr_adj_ref, apply_first_frame_ref=corr_first_ref).calculate_corr(model_pred_video)
                        global_utils.update_video_corr_loss(model_pred_video_corr, corr_normalize)
                    loss += cfg_latent_corr_loss_weight * global_utils.video_corr_loss
                if cfg_transformation_consistency_loss_weight > 0.0:
                    loss += cfg_transformation_consistency_loss_weight * global_utils.transformation_consistency_loss
                if cfg_transformation_smoothness_loss_weight > 0.0:
                    loss += cfg_transformation_smoothness_loss_weight * global_utils.transformation_smoothness_loss
                if cfg_on_predictor_mask_loss_weight > 0.0:
                    mask_predictor_in = rearrange(latents.float(), "b c f h w -> (b f) c h w")
                    mask_predictor_out = mask_predictor(mask_predictor_in)
                    mask_predictor_out = rearrange(mask_predictor_out, "(b f) c h w -> b c f h w", b=bsz)
                    mask_predictor_gt = rearrange(tmp_mask, "(b f) c h w -> b c f h w", f=video_length)
                    mask_predictor_gt[mask_predictor_gt >= 0.5] = 1.0
                    mask_predictor_gt[mask_predictor_gt < 0.5] = 0.0
                    loss += cfg_on_predictor_mask_loss_weight * F.binary_cross_entropy_with_logits(mask_predictor_out, mask_predictor_gt)
                if cfg_on_generator_mask_loss_weight > 0.0:
                    noise_scheduler.set_timesteps(noise_scheduler.config.num_train_timesteps)
                    model_pred_video = noise_scheduler.step(model_pred, timesteps, noisy_latents).pred_original_sample
                    mask_predictor_in = rearrange(model_pred_video.float(), "b c f h w -> (b f) c h w")
                    mask_predictor_out = mask_predictor(mask_predictor_in)
                    mask_predictor_out = rearrange(mask_predictor_out, "(b f) c h w -> b c f h w", b=bsz)
                    mask_predictor_gt = rearrange(tmp_mask, "(b f) c h w -> b c f h w", f=video_length)
                    loss += cfg_on_predictor_mask_loss_weight * F.binary_cross_entropy_with_logits(mask_predictor_out, mask_predictor_gt)
                if cfg_feature_ranking_loss_weight > 0.0:
                    loss += cfg_feature_ranking_loss_weight * global_utils.feature_ranking_loss

            # optimizer.zero_grad()
            loss = loss / gradient_accumulation_steps

            # Backpropagate
            if mixed_precision_training:
                scaler.scale(loss).backward()
                """ >>> gradient clipping >>> """
                if (step + 1) % gradient_accumulation_steps == 0:
                    scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
                """ <<< gradient clipping <<< """
                if (step + 1) % gradient_accumulation_steps == 0:
                    scaler.step(optimizer)
                    scaler.update()
            else:
                loss.backward()
                """ >>> gradient clipping >>> """
                torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
                """ <<< gradient clipping <<< """
                if (step + 1) % gradient_accumulation_steps == 0:
                    optimizer.step()

            if (step + 1) % gradient_accumulation_steps == 0:
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1*gradient_accumulation_steps)
            global_step += 1
            
            ### <<<< Training <<<< ###
            
            # Wandb logging
            if is_main_process and (not is_debug) and use_wandb:
                wandb.log({"train_loss": loss.item()}, step=global_step)
                
            # Save checkpoint
            if is_main_process and global_step % checkpointing_steps == 0 or global_step in checkpointing_steps_tuple:
                if train_unet:
                    save_path = os.path.join(output_dir, f"checkpoints-unet")
                    os.makedirs(save_path, exist_ok=True)
                    state_dict = {
                        "epoch": epoch,
                        "global_step": global_step,
                        "state_dict": unet.state_dict(),
                    }
                    torch.save(state_dict, os.path.join(save_path, f"checkpoint-{global_step}.ckpt"))
                    logging.info(f"Saved state to {save_path} (global_step: {global_step})")
                if train_controlnet:
                    save_path = os.path.join(output_dir, f"checkpoints-controlnet")
                    os.makedirs(save_path, exist_ok=True)
                    state_dict = {
                        "epoch": epoch,
                        "global_step": global_step,
                        "state_dict": controlnet.state_dict(),
                    }
                    torch.save(state_dict, os.path.join(save_path, f"checkpoint-{global_step}.ckpt"))
                    logging.info(f"Saved state to {save_path} (global_step: {global_step})")
                if latent_rectify:
                    save_path = os.path.join(output_dir, f"checkpoints-LatentRectify")
                    os.makedirs(save_path, exist_ok=True)
                    state_dict = {
                        "epoch": epoch,
                        "global_step": global_step,
                        "state_dict": latent_rectify_module.state_dict(),
                    }
                    torch.save(state_dict, os.path.join(save_path, f"checkpoint-{global_step}.ckpt"))
                    logging.info(f"Saved state to {save_path} (global_step: {global_step})")
                if cfg_on_predictor_mask_loss_weight > 0.0:
                    save_path = os.path.join(output_dir, f"checkpoints-MaskPredictor")
                    os.makedirs(save_path, exist_ok=True)
                    state_dict = {
                        "epoch": epoch,
                        "global_step": global_step,
                        "state_dict": mask_predictor.state_dict(),
                    }
                    torch.save(state_dict, os.path.join(save_path, f"checkpoint-{global_step}.ckpt"))
                    logging.info(f"Saved state to {save_path} (global_step: {global_step})")
                
            # Periodically validation
            if is_main_process and (global_step % validation_steps == 0 or global_step in validation_steps_tuple):
                unet.eval()
                if train_controlnet:
                    controlnet.eval()
                samples = []
                
                generator = torch.Generator(device=latents.device)
                generator.manual_seed(global_seed)
                
                height = train_data.sample_size[0] if not isinstance(train_data.sample_size, int) else train_data.sample_size
                width  = train_data.sample_size[1] if not isinstance(train_data.sample_size, int) else train_data.sample_size

                prompts = validation_data.prompts[:2] if global_step < 1000 and (not image_finetune) else validation_data.prompts
                if hasattr(validation_data, "n_prompts"):
                    n_prompts = validation_data.n_prompts[:2] if global_step < 1000 and (not image_finetune) else validation_data.n_prompts
                else:
                    n_prompts = [""]*len(prompts)

                if controlnet_config_file != "":
                    controlnet_images_path = validation_data.controlnet_images_path[:2] if global_step < 1000 and (not image_finetune) else validation_data.controlnet_images_path
                    if "controlnet_masks_path" in validation_data.keys():
                        controlnet_masks_path = validation_data.controlnet_masks_path[:2] if global_step < 1000 and (not image_finetune) else validation_data.controlnet_masks_path
                        assert len(controlnet_masks_path) == len(controlnet_images_path)
                    assert len(prompts) == len(controlnet_images_path)
                    image_transforms = transforms.Compose([
                        transforms.Resize((height, width), antialias=True),
                        # transforms.ColorJitter(brightness=0.2, contrast=0.2),
                        transforms.ToTensor(),
                    ])
                if controlnet_config_file == "" and (cfg_noise_gt_first_frame or cfg_unnoise_first_frame):
                    controlnet_images_path = validation_data.controlnet_images_path[:2] if global_step < 1000 and (not image_finetune) else validation_data.controlnet_images_path
                    assert len(prompts) == len(controlnet_images_path)
                    image_transforms = transforms.Compose([
                        transforms.Resize((height, width), antialias=True),
                        # transforms.ColorJitter(brightness=0.2, contrast=0.2),
                        transforms.ToTensor(),
                    ])

                for idx, (prompt, n_prompt) in enumerate(zip(prompts, n_prompts)):
                    if not image_finetune:
                        ### load groundtruth video and add noise
                        if "groundtruth_video_path" in validation_data.keys():
                            
                            from decord import VideoReader
                            video_transforms = transforms.Compose([
                                transforms.Resize((height, width), antialias=True),
                            ])
                            video_reader = VideoReader(validation_data["groundtruth_video_path"][idx])
                            batch_index = np.linspace(0, len(video_reader) - 1, train_data.sample_n_frames, dtype=int)
                            gt_video = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
                            gt_video = video_transforms(gt_video).cuda()
                            gt_video_latent = vae.encode(gt_video / 255. * 2. - 1.).latent_dist.sample() * 0.18215
                            gt_video_latent = rearrange(gt_video_latent.unsqueeze(0), "b f c h w -> b c f h w")
                            validation_pipeline.scheduler.set_timesteps(validation_data.num_inference_steps)
                            latents = validation_pipeline.scheduler.add_noise(gt_video_latent, torch.randn_like(gt_video_latent), validation_pipeline.scheduler.timesteps[0])
                        else:
                            latents = None
                        if controlnet_config_file != "":
                            controlnet_images = image_transforms(Image.open(controlnet_images_path[idx]).convert("RGB")).cuda().unsqueeze(0)
                            if detail_vae_checkpoint_path != "":
                                ref_pyramid_features = vae.encoder.pyramid_feature_forward(controlnet_images * 2. - 1.)
                            controlnet_images = vae.encode(controlnet_images * 2. - 1.).latent_dist.sample().unsqueeze(2) * 0.18215
                            if "controlnet_masks_path" in validation_data.keys():
                                assert "controlnet_rel_video_path" not in validation_data.keys()
                                if "mask_oracle" in validation_data.keys() and validation_data.mask_oracle:
                                    controlnet_masks_root_path = os.path.dirname(controlnet_masks_path[idx])
                                    num_of_mask = len(os.listdir(controlnet_masks_root_path))
                                    if num_of_mask > 1:
                                        num_of_sample = train_data.sample_n_frames
                                        sample_ind = [int(i/(num_of_sample-1)*(num_of_mask-1))+1 for i in range(num_of_sample)]
                                        controlnet_masks = []
                                        for ind in sample_ind:
                                            tmp_mask = image_transforms(Image.open(os.path.join(controlnet_masks_root_path, "%05d.jpg"%max(1,min(ind,num_of_mask)))).convert("L")).cuda()
                                            controlnet_masks.append(tmp_mask)
                                        controlnet_masks = torch.stack(controlnet_masks)
                                        controlnet_masks = F.interpolate(controlnet_masks,(controlnet_images.shape[-2],controlnet_images.shape[-1])).unsqueeze(0)
                                        controlnet_masks = rearrange(controlnet_masks, "b f c h w -> b c f h w")
                                    else:
                                        controlnet_masks = image_transforms(Image.open(controlnet_masks_path[idx]).convert("L")).cuda().unsqueeze(0)
                                        controlnet_masks = F.interpolate(controlnet_masks,(controlnet_images.shape[-2],controlnet_images.shape[-1])).unsqueeze(2)
                                else:
                                    controlnet_masks = image_transforms(Image.open(controlnet_masks_path[idx]).convert("L")).cuda().unsqueeze(0)
                                    controlnet_masks = F.interpolate(controlnet_masks,(controlnet_images.shape[-2],controlnet_images.shape[-1])).unsqueeze(2)
                            if "controlnet_rel_video_path" in validation_data.keys():
                                assert "controlnet_masks_path" not in validation_data.keys()
                                from decord import VideoReader
                                rel_video_reader = VideoReader(validation_data.controlnet_rel_video_path[idx])
                                rel_video_length = len(rel_video_reader)
                                rel_frame_index = np.linspace(0, rel_video_length -1, train_data.sample_n_frames, dtype=int)
                                rel_pixel_values = torch.from_numpy(np.dot(rel_video_reader.get_batch(rel_frame_index).asnumpy(), [0.299, 0.587, 0.114])).unsqueeze(-1).permute(0, 3, 1, 2).contiguous().to(controlnet_images.dtype)
                                rel_pixel_values = rel_pixel_values / 255.
                                controlnet_masks = F.interpolate(rel_pixel_values,(controlnet_images.shape[-2],controlnet_images.shape[-1])).unsqueeze(0)
                                controlnet_masks = rearrange(controlnet_masks, "b f c h w -> b c f h w")

                            controlnet_image_index = [0]
                            sample = validation_pipeline(
                                prompt,
                                negative_prompt = n_prompt,
                                generator    = generator,
                                video_length = train_data.sample_n_frames,
                                height       = height,
                                width        = width,
                                controlnet_images = controlnet_images,
                                controlnet_masks = controlnet_masks if "controlnet_masks_path" in validation_data.keys() or "controlnet_rel_video_path" in validation_data.keys() else None,
                                controlnet_image_index = controlnet_image_index,
                                latent_rectify_scale=latent_rectify_scale,
                                latent_rectify_clamp=latent_rectify_clamp,
                                cfg_noise_gt_first_frame=cfg_noise_gt_first_frame,
                                cfg_unnoise_first_frame = cfg_unnoise_first_frame,
                                ref_features = ref_pyramid_features if detail_vae_checkpoint_path != "" else None,
                                latents = latents,
                                **validation_data,
                            ).videos
                        else:
                            if cfg_noise_gt_first_frame or cfg_unnoise_first_frame:
                                controlnet_images = image_transforms(Image.open(controlnet_images_path[idx]).convert("RGB")).cuda().unsqueeze(0)
                                if detail_vae_checkpoint_path != "":
                                    ref_pyramid_features = vae.encoder.pyramid_feature_forward(controlnet_images * 2. - 1.)
                                controlnet_images = vae.encode(controlnet_images * 2. - 1.).latent_dist.sample().unsqueeze(2) * 0.18215
                                sample = validation_pipeline(
                                    prompt,
                                    negative_prompt = n_prompt,
                                    generator    = generator,
                                    video_length = train_data.sample_n_frames,
                                    height       = height,
                                    width        = width,
                                    controlnet_images = controlnet_images,
                                    cfg_noise_gt_first_frame=cfg_noise_gt_first_frame,
                                    cfg_unnoise_first_frame = cfg_unnoise_first_frame,
                                    ref_features = ref_pyramid_features if detail_vae_checkpoint_path != "" else None,
                                    latents = latents,
                                    **validation_data,
                                ).videos
                            else:
                                sample = validation_pipeline(
                                    prompt,
                                    generator    = generator,
                                    video_length = train_data.sample_n_frames,
                                    height       = height,
                                    width        = width,
                                    **validation_data,
                                ).videos
                        save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{idx}.gif")
                        samples.append(sample)
                        
                    else:
                        sample = validation_pipeline(
                            prompt,
                            generator           = generator,
                            height              = height,
                            width               = width,
                            num_inference_steps = validation_data.get("num_inference_steps", 25),
                            guidance_scale      = validation_data.get("guidance_scale", 8.),
                        ).images[0]
                        sample = torchvision.transforms.functional.to_tensor(sample)
                        samples.append(sample)
                
                if not image_finetune:
                    samples = torch.concat(samples)
                    save_path = f"{output_dir}/samples/sample-{global_step}.gif"
                    save_videos_grid(samples, save_path)
                else:
                    samples = torch.stack(samples)
                    save_path = f"{output_dir}/samples/sample-{global_step}.png"
                    torchvision.utils.save_image(samples, save_path, nrow=4)

                logging.info(f"Saved samples to {save_path}")

                unet.train()
                if train_controlnet:
                    controlnet.train()
                
            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            
            if global_step >= max_train_steps:
                break

            torch.cuda.empty_cache()
            
    dist.destroy_process_group()



if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config",   type=str, required=True)
    parser.add_argument("--launcher", type=str, choices=["pytorch", "slurm"], default="pytorch")
    parser.add_argument("--wandb",    action="store_true")
    args = parser.parse_args()

    # name   = Path(args.config).stem
    name = args.config.split('/')[-3] + '_' + args.config.split('/')[-2] + '_' + args.config.split('/')[-1].split('.')[0]
    config = OmegaConf.load(args.config)

    main(name=name, launcher=args.launcher, use_wandb=args.wandb, **config)
