import argparse
import copy
import logging
import math
import os
import os.path as osp
import pathlib
import random
from itertools import chain
import time
import warnings
from collections import OrderedDict
from datetime import datetime
from pathlib import Path
from typing import Callable, List, Optional, Union, Dict, Tuple, Union, Any
from PIL import Image

import diffusers
import mlflow
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import torchvision.transforms.functional as transforms_f
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.optimization import get_scheduler
from einops import rearrange
from omegaconf import OmegaConf, ListConfig
from tqdm.auto import tqdm
from transformers import (
    CLIPVisionModelWithProjection,
    CLIPImageProcessor,
    Wav2Vec2Model,
    Wav2Vec2Processor,
    CLIPTextModel,
    CLIPTokenizer,
)

from datasets import TalkingFaceVideo, TalkingFaceImage
from modules import (
    UNet3DConditionModel,
    VKpsGuider,
    AudioProjection,
    T2IAdapter,
    MultiAdapter,
)

from utils.module_utils import zero_module_params, is_torch2_available
from utils.utils import (
    seed_everything,
    count_updated_params,
    check_zero_initialization,
    get_module_params,
    print_highlighted_block_log,
    ColoredLogger,
)
from utils.adapter_utils import t2i_adapter_map_keys

# Silence all Python warnings process-wide (no category filter is given).
warnings.filterwarnings("ignore")

from peft import get_peft_model
from utils.lora_utils import load_denoise_unet_lora


# Select attention-processor implementations: the "2_0" variants when
# `is_torch2_available()` is True, the legacy classes otherwise. Both branches
# bind the same local names (AttnProcessor, IPAttnProcessor,
# IPAttnProcessorDecoupled) so the rest of the file does not need to branch.
if is_torch2_available():
    from modules.attention_processor import (
        AttnProcessor2_0 as AttnProcessor,
    )
    from modules.attention_processor import (
        IPAttnProcessor2_0 as IPAttnProcessor,
    )
    from modules.attention_processor_decoupled import IPAttnProcessorDecoupled2_0 as IPAttnProcessorDecoupled
else:
    from modules.attention_processor import AttnProcessor, IPAttnProcessor
    from modules.attention_processor_decoupled import IPAttnProcessorDecoupled

# Module-level logger (accelerate's `get_logger` wrapper) at INFO level.
logger = get_logger(__name__, log_level="INFO")


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend a classifier-free-guidance prediction with a std-matched copy of itself.

    Implements Section 3.4 of "Common Diffusion Noise Schedules and Sample Steps
    are Flawed" (https://arxiv.org/pdf/2305.08891.pdf): `noise_cfg` is rescaled so
    its per-sample standard deviation (over every axis except dim 0) equals that of
    `noise_pred_text`. The rescaled tensor is then linearly mixed with the original
    using `guidance_rescale` (0.0 returns `noise_cfg` unchanged, 1.0 returns the fully
    rescaled tensor).
    """

    def _per_sample_std(tensor):
        # Standard deviation over all non-batch axes, kept broadcastable.
        return tensor.std(dim=list(range(1, tensor.ndim)), keepdim=True)

    std_ratio = _per_sample_std(noise_pred_text) / _per_sample_std(noise_cfg)
    # Matching the text-branch std counteracts over-exposure from strong guidance.
    rescaled = noise_cfg * std_ratio
    # Partial mixing avoids overly "plain looking" results.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg


def custom_sample_timesteps(num_timesteps, batch_size, skew_type="early", skew_factor=2.0, device='cpu'):
    """
    Draw `batch_size` timesteps in [0, num_timesteps) from a linearly skewed distribution.

    Sampling weights ramp linearly across the timestep range. With "early" they
    fall from `skew_factor` to 1, so small timesteps are favoured. With "late" they
    rise from 1 to `skew_factor`. Sampling is with replacement.

    Args:
        num_timesteps (int): Number of timesteps to sample from.
        batch_size (int): Number of samples to draw.
        skew_type (str): "early" or "late".
        skew_factor (float): Weight at the favoured end (the other end has weight 1).
        device (torch.device): Device the returned tensor is moved to.

    Returns:
        torch.Tensor: Integer timestep indices of shape (batch_size,).

    Raises:
        ValueError: If `skew_type` is neither "early" nor "late".
    """
    if skew_type == "early":
        start, end = skew_factor, 1
    elif skew_type == "late":
        start, end = 1, skew_factor
    else:
        raise ValueError("Invalid skew_type. Choose 'early' or 'late'.")

    ramp = torch.linspace(start, end, num_timesteps)
    probabilities = ramp / ramp.sum()
    return torch.multinomial(probabilities, batch_size, replacement=True).to(device)

class Net(nn.Module):
    """Conditioning wrapper around the 3D denoising UNet.

    Gathers the conditioning signals for one noise prediction and calls
    ``denoising_unet`` once, optionally with classifier-free guidance (CFG).
    The signals are:

    * keypoint (KPS) guider features;
    * optional T2I-Adapter residuals;
    * IP-Adapter image / face-id prompt tokens;
    * CLIP text prompt tokens;
    * projected audio frame embeddings.

    Optional branches are disabled by passing ``None`` for the corresponding
    module (``v_kps_guider``, ``t2i_adapter``, ``audio_projection``).
    IP-Adapter weights are installed, loaded and frozen only when ``ip_ckpt``
    is given.
    """

    def __init__(
        self,
        denoising_unet: UNet3DConditionModel,
        v_kps_guider: VKpsGuider,
        t2i_adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]],
        audio_projection: AudioProjection,
        image_encoder: CLIPVisionModelWithProjection,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        device: str,
        weight_dtype,
        kps_drop_rate: float = 0.0,
        faceid_drop_rate: float = 0.0,
        ip_ckpt: str = None,
        ip_mode: str = None,
        resampler_depth: int = 4,
        num_tokens: Union[int, List[int]] = 16,
        n_cond: int = 1,
    ):
        """
        Args:
            denoising_unet: 3D UNet producing the noise prediction.
            v_kps_guider: Keypoint guider; ``None`` disables KPS features.
            t2i_adapter: T2I-Adapter or ``MultiAdapter``; ``None`` disables it.
            audio_projection: Projects audio frame embeddings; ``None``
                disables the audio condition.
            image_encoder: CLIP vision encoder used by the CLIP-based IP modes.
            tokenizer: CLIP tokenizer used for prompt embeddings.
            text_encoder: CLIP text encoder used for prompt embeddings.
            device: Device for the IP-Adapter modules and encoder inputs.
            weight_dtype: Dtype for the IP-Adapter modules and encoder inputs.
            kps_drop_rate: Per-sample probability of zeroing KPS features.
            faceid_drop_rate: Per-sample probability of zeroing the IP prompt tokens.
            ip_ckpt: IP-Adapter checkpoint path; ``None`` skips IP-Adapter setup.
            ip_mode: One of ``'faceid'``, ``'portrait'``, ``'plus'``,
                ``'vanilla'``, ``'faceid-plus'``, ``'faceid-decoupled'``.
            resampler_depth: Resampler depth in ``'faceid-decoupled'`` mode.
            num_tokens: IP tokens per condition. In ``'faceid-decoupled'``
                mode it is indexed as a pair (face-id tokens, CLIP tokens).
            n_cond: Multiplier applied to ``num_tokens`` when building the IP
                attention processors.
        """
        super().__init__()
        self.denoising_unet = denoising_unet
        self.v_kps_guider = v_kps_guider
        self.t2i_adapter = t2i_adapter
        self.audio_projection = audio_projection
        self.image_encoder = image_encoder
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.device = device
        self.weight_dtype = weight_dtype
        self.kps_drop_rate = kps_drop_rate
        self.ip_ckpt = ip_ckpt
        self.ip_mode = ip_mode
        self.resampler_depth = resampler_depth
        self.num_tokens = num_tokens
        self.n_cond = n_cond
        self.faceid_drop_rate = faceid_drop_rate
        self.disable_kps = self.v_kps_guider is None
        self.apply_t2i_adapter = self.t2i_adapter is not None

        self.clip_image_processor = CLIPImageProcessor()

        if ip_ckpt is not None:
            self.set_ip_adapter()
            # The image projection models are loaded from the checkpoint and kept frozen.
            if self.ip_mode in ['faceid-decoupled']:
                self.image_proj_model_1, self.image_proj_model_2 = self.init_proj()
                self.image_proj_model_1.requires_grad_(False) # REVIEW
                self.image_proj_model_2.requires_grad_(False) # REVIEW
            else:
                self.image_proj_model = self.init_proj()
                self.image_proj_model.requires_grad_(False) # REVIEW
            self.load_ip_adapter()

    def init_proj(self):
        """Build the IP-Adapter image projection model(s) for ``self.ip_mode``.

        Returns:
            A single projection module, or a ``(MLPProjModel, Resampler)``
            pair for ``'faceid-decoupled'``. Modules are placed on
            ``self.device`` with ``self.weight_dtype``.

        Raises:
            ValueError: If ``ip_mode`` is not supported.
        """
        if self.ip_mode in ['faceid', 'portrait']:
            from modules.image_projection import MLPProjModel

            image_proj_model = MLPProjModel(
                cross_attention_dim=self.denoising_unet.config.cross_attention_dim,
                id_embeddings_dim=512,
                num_tokens=self.num_tokens,
            ).to(self.device, dtype=self.weight_dtype)
            return image_proj_model
        elif self.ip_mode in ['plus', 'vanilla']:
            if self.ip_mode == 'plus':
                from modules.image_projection import Resampler
                img_projector = Resampler
            else:
                from modules.image_projection import ImageProjection
                img_projector = ImageProjection

            # NOTE(review): depth is fixed at 4 here, while 'faceid-decoupled'
            # uses self.resampler_depth. Confirm this is intended.
            image_proj_model = img_projector(
                dim=self.denoising_unet.config.cross_attention_dim,
                depth=4,
                dim_head=64,
                heads=12,
                num_queries=self.num_tokens,
                embedding_dim=self.image_encoder.config.hidden_size,
                output_dim=self.denoising_unet.config.cross_attention_dim,
                ff_mult=4,
            ).to(self.device, dtype=self.weight_dtype)
            return image_proj_model
        elif self.ip_mode == 'faceid-plus':
            from modules.image_projection import ProjPlusModel

            image_proj_model = ProjPlusModel(
                cross_attention_dim=self.denoising_unet.config.cross_attention_dim,
                id_embeddings_dim=512,
                clip_embeddings_dim=self.image_encoder.config.hidden_size,
                num_tokens=self.num_tokens,
            ).to(self.device, dtype=self.weight_dtype)
            return image_proj_model
        elif self.ip_mode in ['faceid-decoupled']:
            from modules.image_projection import MLPProjModel, Resampler
            # Face-id branch uses num_tokens[0]; CLIP branch uses num_tokens[1].
            image_proj_model_1 = MLPProjModel(
                cross_attention_dim=self.denoising_unet.config.cross_attention_dim,
                id_embeddings_dim=512,
                num_tokens=self.num_tokens[0],
            ).to(self.device, dtype=self.weight_dtype)

            image_proj_model_2 = Resampler(
                dim=self.denoising_unet.config.cross_attention_dim,
                depth=self.resampler_depth,
                dim_head=64,
                heads=12,
                num_queries=self.num_tokens[1],
                embedding_dim=self.image_encoder.config.hidden_size,
                output_dim=self.denoising_unet.config.cross_attention_dim,
                ff_mult=4,
            ).to(self.device, dtype=self.weight_dtype)

            return image_proj_model_1, image_proj_model_2

        else:
            raise ValueError('Not supported type of IP-Adpater Image Projection Model!')

    def set_ip_adapter(self):
        """Install IP-Adapter attention processors on the denoising UNet.

        Processors whose name ends with ``"attn1_7.processor"`` are replaced
        by IP attention processors (the decoupled variant in
        ``'faceid-decoupled'`` mode), with their ``to_k_ip*`` / ``to_v_ip*``
        projections frozen. Every other processor is reset to a plain
        ``AttnProcessor``.

        Raises:
            ValueError: If the hidden size of an IP layer cannot be inferred
                from its processor name.
        """
        unet_config = self.denoising_unet.config
        block_out_channels = unet_config.block_out_channels
        cross_attention_dim = unet_config.cross_attention_dim

        attn_procs = {}
        for name in self.denoising_unet.attn_processors.keys():
            if not name.endswith("attn1_7.processor"):
                attn_procs[name] = AttnProcessor()
                continue

            # Infer the layer's hidden size from its position in the UNet.
            if name.startswith("mid_block"):
                hidden_size = block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = block_out_channels[block_id]
            else:
                raise ValueError(
                    f"Cannot infer hidden size for attention processor {name!r}."
                )

            # NOTE(review): when num_tokens is a list ('faceid-decoupled'),
            # `* n_cond` repeats the list rather than scaling it. Confirm
            # against IPAttnProcessorDecoupled.
            if self.ip_mode in ['faceid-decoupled']:
                processor = IPAttnProcessorDecoupled(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=[0.7, 0.3],
                    num_tokens=self.num_tokens * self.n_cond,
                )
                frozen_layers = ("to_k_ip_1", "to_v_ip_1", "to_k_ip_2", "to_v_ip_2")
            else:
                processor = IPAttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=1.0,
                    num_tokens=self.num_tokens * self.n_cond,
                )
                frozen_layers = ("to_k_ip", "to_v_ip")

            # Use the same dtype as the image projection models built in init_proj.
            processor = processor.to(self.device, dtype=self.weight_dtype)
            # IP key/value projections are loaded from the checkpoint and stay frozen.
            for layer_name in frozen_layers:
                getattr(processor, layer_name).requires_grad_(False)
            attn_procs[name] = processor

        self.denoising_unet.set_attn_processor(attn_procs)

    def load_ip_adapter(self):
        """Load IP-Adapter weights from ``self.ip_ckpt``.

        Weights go into the projection model(s) and the IP attention
        processors installed by ``set_ip_adapter``. The checkpoint's
        ``"ip_adapter"`` entry is keyed ``"<id>.<layer>.weight"`` with ids
        1, 3, 5, .... The k-th odd id is remapped to the position of the k-th
        IP processor among ``denoising_unet.attn_processors``.

        Raises:
            ValueError: If a checkpoint key names an unexpected layer.
        """
        logger.info("[INFO] loading IP-Adapter checkpoints...")
        # NOTE(security): torch.load unpickles arbitrary objects, so only
        # load trusted checkpoints.
        state_dict = torch.load(self.ip_ckpt, map_location="cpu")

        if self.ip_mode in ['faceid-decoupled']:
            self.image_proj_model_1.load_state_dict(state_dict["image_proj_1"], strict=True)
            self.image_proj_model_2.load_state_dict(state_dict["image_proj_2"], strict=True)
            ip_attn_name = (
                "IPAttnProcessorDecoupled2_0" if is_torch2_available() else "IPAttnProcessorDecoupled"
            )
            layer_names = ("to_k_ip_1", "to_v_ip_1", "to_k_ip_2", "to_v_ip_2")
        else:
            self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
            ip_attn_name = (
                "IPAttnProcessor2_0" if is_torch2_available() else "IPAttnProcessor"
            )
            layer_names = ("to_k_ip", "to_v_ip")

        # Step 1: map checkpoint ids (1, 3, 5, ...) to IP processor positions in the model.
        ip_attn_positions = [
            pos
            for pos, processor in enumerate(self.denoising_unet.attn_processors.values())
            if ip_attn_name in str(processor)
        ]
        state_id_to_model_pos_mapping = {
            2 * rank + 1: pos for rank, pos in enumerate(ip_attn_positions)
        }

        # Step 2: rewrite checkpoint keys to "<model position>.<layer>.weight".
        new_state_dict = {}
        for k, v in state_dict["ip_adapter"].items():
            mapped_pos = state_id_to_model_pos_mapping[int(k.split(".")[0])]
            # First matching layer name wins (same order as the explicit checks).
            layer = next((layer_name for layer_name in layer_names if layer_name in k), None)
            if layer is None:
                raise ValueError("Unexpected key format in state_dict.")
            new_state_dict[f"{mapped_pos}.{layer}.weight"] = v

        self.ip_layers = torch.nn.ModuleList(
            self.denoising_unet.attn_processors.values()
        )
        self.ip_layers.load_state_dict(new_state_dict, strict=True)
        logger.info("[INFO] loaded IP-Adapter checkpoints...")

    def encode_prompt(
        self,
        prompt,
        num_videos_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
    ):
        """Encode text prompts with the CLIP text encoder.

        Args:
            prompt: A prompt string or list of prompts.
            num_videos_per_prompt: Number of copies of each embedding.
            do_classifier_free_guidance: Whether to also encode
                ``negative_prompt``. When False the unconditional embeddings
                are zeros.
            negative_prompt: Negative prompt(s). Must have the same type as
                ``prompt``, and the same length when both are lists.

        Returns:
            ``(text_embeddings, uncond_embeddings)``.

        Raises:
            TypeError: If ``negative_prompt`` and ``prompt`` differ in type.
            ValueError: If their batch sizes differ.
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1
        use_attention_mask = getattr(self.text_encoder.config, "use_attention_mask", False)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(
            prompt, padding="longest", return_tensors="pt"
        ).input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        attention_mask = (
            text_inputs.attention_mask.to(self.device) if use_attention_mask else None
        )
        text_embeddings = self.text_encoder(
            text_input_ids.to(self.device),
            attention_mask=attention_mask,
        )
        text_embeddings = text_embeddings[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
        text_embeddings = text_embeddings.view(
            bs_embed * num_videos_per_prompt, seq_len, -1
        )

        if not do_classifier_free_guidance:
            return text_embeddings, torch.zeros_like(text_embeddings)

        # get unconditional embeddings for classifier free guidance
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        max_length = text_input_ids.shape[-1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        attention_mask = (
            uncond_input.attention_mask.to(self.device) if use_attention_mask else None
        )
        uncond_embeddings = self.text_encoder(
            uncond_input.input_ids.to(self.device),
            attention_mask=attention_mask,
        )
        uncond_embeddings = uncond_embeddings[0]

        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = uncond_embeddings.shape[1]
        uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
        uncond_embeddings = uncond_embeddings.view(
            batch_size * num_videos_per_prompt, seq_len, -1
        )

        return text_embeddings, uncond_embeddings

    def get_image_embeds(self, pil_image=None, faceid_embeds=None):
        """Project reference-image / face-id inputs into IP-Adapter prompt tokens.

        Args:
            pil_image: A PIL image or list of them. It is encoded with the CLIP
                image encoder (second-to-last hidden state) for the CLIP-based
                modes.
            faceid_embeds: Face-id embeddings of shape (b, c) or (b, n, c).
                The 3-D form is flattened to (b * n, c).

        Returns:
            ``(image_prompt_embeds, uncond_image_prompt_embeds)``. The
            unconditional branch is computed from all-zero inputs.

        Raises:
            ValueError: If ``ip_mode`` is unsupported, or an input required by
                the current mode is missing.
        """
        clip_image_embeds = None
        uncond_clip_image_embeds = None
        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(
                images=pil_image, return_tensors="pt"
            ).pixel_values
            clip_image = clip_image.to(self.device, dtype=self.weight_dtype)
            clip_image_embeds = self.image_encoder(
                clip_image, output_hidden_states=True
            ).hidden_states[-2]
            uncond_clip_image_embeds = self.image_encoder(
                torch.zeros_like(clip_image), output_hidden_states=True
            ).hidden_states[-2]
        if faceid_embeds is not None:
            if faceid_embeds.dim() == 3:
                # Several faces per sample: (b, n, c) -> (b * n, c).
                b, n, c = faceid_embeds.shape
                faceid_embeds = faceid_embeds.reshape(b * n, c)
            faceid_embeds = faceid_embeds.to(self.device, dtype=self.weight_dtype)

        if self.ip_mode in ['faceid', 'portrait', 'faceid-plus', 'faceid-decoupled'] and faceid_embeds is None:
            raise ValueError(f"ip_mode={self.ip_mode!r} requires `faceid_embeds`.")
        if self.ip_mode in ['faceid-plus', 'vanilla', 'plus', 'faceid-decoupled'] and clip_image_embeds is None:
            raise ValueError(f"ip_mode={self.ip_mode!r} requires a reference image (`pil_image`).")

        if self.ip_mode in ['faceid', 'portrait']:
            image_prompt_embeds = self.image_proj_model(faceid_embeds)
            uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds))
        elif self.ip_mode == 'faceid-plus':
            image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds)
            uncond_image_prompt_embeds = self.image_proj_model(
                torch.zeros_like(faceid_embeds), uncond_clip_image_embeds
            )
        elif self.ip_mode in ['vanilla', 'plus']:
            image_prompt_embeds = self.image_proj_model(clip_image_embeds)
            uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
        elif self.ip_mode in ['faceid-decoupled']:
            image_prompt_embeds_1 = self.image_proj_model_1(faceid_embeds)
            uncond_image_prompt_embeds_1 = self.image_proj_model_1(torch.zeros_like(faceid_embeds))
            image_prompt_embeds_2 = self.image_proj_model_2(clip_image_embeds)
            uncond_image_prompt_embeds_2 = self.image_proj_model_2(uncond_clip_image_embeds)
            # Face-id tokens first, then CLIP tokens, along the token axis.
            image_prompt_embeds = torch.cat([image_prompt_embeds_1, image_prompt_embeds_2], dim=1)
            uncond_image_prompt_embeds = torch.cat(
                [uncond_image_prompt_embeds_1, uncond_image_prompt_embeds_2], dim=1
            )
        else:
            raise ValueError(f"Unsupported ip_mode for image embeddings: {self.ip_mode!r}")

        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_ipa_scale(self, scale):
        """Set ``scale`` on every IP attention processor of the denoising UNet."""
        for attn_processor in self.denoising_unet.attn_processors.values():
            if isinstance(attn_processor, (IPAttnProcessor, IPAttnProcessorDecoupled)):
                attn_processor.scale = scale

    def prepare_t2iadapter_feature(
        self,
        adapter_images,
        do_classifier_free_guidance,
        t2i_adapter_conditioning_scale,
    ):
        """Run the T2I-Adapter on every frame and stack the results over frames.

        Args:
            adapter_images: Iterable of per-frame adapter inputs. For a
                ``MultiAdapter``, each item is a list of tensors, one per adapter.
            do_classifier_free_guidance: If True, each feature map is
                duplicated along dim 0 so it matches a CFG-doubled latent batch.
            t2i_adapter_conditioning_scale: Passed to a ``MultiAdapter`` as its
                weights, or multiplied into a single adapter's outputs.

        Returns:
            One tensor per adapter feature level, each with the frame axis
            inserted at dim 2 (bs, c, f, h, w), or ``None`` when
            ``adapter_images`` is empty.
        """
        is_multi_adapter = isinstance(self.t2i_adapter, MultiAdapter)

        # Convert images to the correct device and dtype upfront
        adapter_inputs = [
            (
                [
                    one_image.to(device=self.device, dtype=self.t2i_adapter.dtype)
                    for one_image in image
                ]
                if is_multi_adapter
                else image.to(device=self.device, dtype=self.t2i_adapter.dtype)
            )
            for image in adapter_images
        ]

        per_frame_states = []
        for adapter_input in adapter_inputs:
            if is_multi_adapter:
                adapter_state = self.t2i_adapter(
                    adapter_input, t2i_adapter_conditioning_scale
                )
            else:
                adapter_state = [
                    v * t2i_adapter_conditioning_scale
                    for v in self.t2i_adapter(adapter_input)
                ]
            if do_classifier_free_guidance:
                adapter_state = [torch.cat([v] * 2, dim=0) for v in adapter_state]
            per_frame_states.append(list(adapter_state))

        if not per_frame_states:
            return None
        # zip(*...) regroups per-frame lists into per-level tuples; stacking on
        # dim=2 injects the frame dimension.
        return [torch.stack(level, dim=2) for level in zip(*per_frame_states)]

    def forward(
        self,
        noisy_latents: torch.Tensor,
        timesteps,
        audio_frame_embeddings: Optional[torch.Tensor],
        kps_images: Union[List[Image.Image], torch.FloatTensor],
        ref_images_pil: Optional[Union[List[Image.Image]]] = None,
        mask_images: Optional[List[Image.Image]] = None,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_unconditional_forward: bool = False,
        guidance_scale: float = 5.0,
        guidance_rescale: float = 0.0,
        face_embeds=None,
        ipa_scale: float = 1.0,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        t2i_adapter_control_type: Optional[str] = None,
        t2i_adapter_conditioning_scale: Union[float, List[float]] = 1.0,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ):
        """Predict noise for ``noisy_latents`` under all configured conditions.

        Args:
            noisy_latents: Latents laid out as (bs, c, f, H, W).
            timesteps: Diffusion timesteps, shape (bsz,).
            audio_frame_embeddings: (bs, f, num_embeds, dim) audio features.
                Only used when ``audio_projection`` is set.
            kps_images: Keypoint conditioning. It is rearranged as
                "b c f h w" for the T2I-Adapter kps/openpose control.
            ref_images_pil: Reference image(s) for the CLIP-based IP modes.
            mask_images: Per-frame adapter inputs when
                ``t2i_adapter_control_type == "mask"`` (must support ``.to``).
            prompt: Text prompt(s); a default prompt is used when None.
            negative_prompt: Negative prompt(s); a default is used when None.
            do_unconditional_forward: Zero every condition tensor.
            guidance_scale: CFG is enabled when this is > 1.0.
            guidance_rescale: Std-rescale factor applied after CFG (> 0 enables it).
            face_embeds: Face-id embeddings, (b, c) or (b, n, c).
            ipa_scale: Value written to every IP attention processor's ``scale``.
            cross_attention_kwargs: Forwarded to the denoising UNet.
            t2i_adapter_control_type: "kps", "openpose" or "mask".
            t2i_adapter_conditioning_scale: See ``prepare_t2iadapter_feature``.
            added_cond_kwargs: Forwarded to the denoising UNet.

        Returns:
            The UNet noise prediction, guided when CFG is enabled.
        """
        self.set_ipa_scale(ipa_scale)
        do_classifier_free_guidance = guidance_scale > 1.0

        ######### COND 1: Residual Conditions ##########
        # KPS Features
        if not self.disable_kps:
            # DIM: (batch_size, channel, num_frames, height, width)
            kps_features = self.v_kps_guider(kps_images)
            if do_unconditional_forward:
                kps_features = torch.zeros_like(kps_features)
            elif self.kps_drop_rate != 0.0:
                # Per-sample condition dropout.
                drop_mask = torch.rand(kps_features.shape[0]) < self.kps_drop_rate
                kps_features[drop_mask, ...] = 0.0
        else:
            kps_features = None

        # T2I-Adapter Features
        down_intrablock_additional_residuals = None
        if self.apply_t2i_adapter:
            if t2i_adapter_control_type in ["kps", "openpose"]:
                # Split into per-frame adapter inputs.
                adapter_images = rearrange(kps_images, "b c f h w -> f b c h w")
            elif t2i_adapter_control_type == "mask":
                adapter_images = mask_images
            else:
                raise NotImplementedError("t2i_adapter_control_type not supported!")
            # Already duplicated along the batch axis when CFG is enabled.
            down_intrablock_additional_residuals = self.prepare_t2iadapter_feature(
                adapter_images,
                do_classifier_free_guidance,
                t2i_adapter_conditioning_scale,
            )
            if do_unconditional_forward and down_intrablock_additional_residuals is not None:
                down_intrablock_additional_residuals = [
                    torch.zeros_like(item)
                    for item in down_intrablock_additional_residuals
                ]
            # TODO: apply t2i_drop_rate

        ### COND 1.7: WEAK-CONDS (Text + FaceID) ###
        # DIM: (batch_size, num_frames, num_embeds, dim)
        # face-id image latents
        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
            pil_image=ref_images_pil, faceid_embeds=face_embeds
        )
        # TEST: FACEID Dropout
        if self.faceid_drop_rate != 0.0:
            drop_mask = torch.rand(image_prompt_embeds.shape[0]) < self.faceid_drop_rate
            image_prompt_embeds[drop_mask, ...] = 0.0
        # text prompts embeds
        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = (
                "monochrome, lowres, bad anatomy, worst quality, low quality"
            )
        if not isinstance(prompt, List):
            prompt = [prompt]
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt]
        prompt_embeds_, negative_prompt_embeds_ = self.encode_prompt(
            prompt,
            num_videos_per_prompt=1,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
        )
        # Text tokens followed by IP-Adapter tokens along the token axis.
        prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
        negative_prompt_embeds = torch.cat(
            [negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1
        )

        if do_classifier_free_guidance:
            # Unconditional half first; matches noise_pred.chunk(2) below.
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        prompt_embeds = prompt_embeds.unsqueeze(1).repeat(
            1, noisy_latents.shape[2], 1, 1
        )  # extend to the frames length
        if do_unconditional_forward:
            prompt_embeds = torch.zeros_like(prompt_embeds)
        _, _, num_tokens_prompt, dim_prompt = prompt_embeds.shape

        ######### COND 2: AUDIO EMBEDS ############
        # DIM: (batch_size, num_frames, num_embeds, dim)
        if self.audio_projection is not None:
            batch_size, num_frames, num_embeds, dim = audio_frame_embeddings.shape
            audio_frame_embeddings = audio_frame_embeddings.reshape(-1, num_embeds, dim)
            audio_frame_embeddings = self.audio_projection(audio_frame_embeddings)
            _, num_embeds, dim = audio_frame_embeddings.shape
            audio_frame_embeddings = audio_frame_embeddings.reshape(
                batch_size, num_frames, num_embeds, dim
            )
            if do_unconditional_forward:
                audio_frame_embeddings = torch.zeros_like(audio_frame_embeddings)
        else:
            audio_frame_embeddings = None

        ########## Noise Prediction ##########
        # timesteps: (bsz, )
        # kps: (bs, c, f, H, W)
        # noisy_latent: (bs, c, f, H, W)
        # aud_embeds = prompt_embds: (bs, f, num_embeds, dim)
        if do_classifier_free_guidance:
            noisy_latents = noisy_latents.repeat(2, 1, 1, 1, 1)
            if kps_features is not None:
                kps_features = kps_features.repeat(2, 1, 1, 1, 1)
            # T2I-Adapter residuals were already doubled inside
            # prepare_t2iadapter_feature. Repeating them again would give a
            # 4x batch that no longer matches noisy_latents.
            if audio_frame_embeddings is not None:
                audio_frame_embeddings = audio_frame_embeddings.repeat(2, 1, 1, 1)

        if audio_frame_embeddings is not None:
            audio_frame_embeddings = audio_frame_embeddings.reshape(-1, num_embeds, dim) # (bs x f, num_embeds, dim)
        if prompt_embeds is not None:
            prompt_embeds = prompt_embeds.reshape(-1, num_tokens_prompt, dim_prompt) # (bs x f, num_embeds, dim)

        noise_pred = self.denoising_unet(
            noisy_latents,
            timesteps,
            kps_features=kps_features,
            down_intrablock_additional_residuals=down_intrablock_additional_residuals,
            encoder_hidden_states=audio_frame_embeddings,
            text_hidden_states=prompt_embeds,
            cross_attention_kwargs=cross_attention_kwargs,
            added_cond_kwargs=added_cond_kwargs,
        ).sample

        # perform CFG
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

            if guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                # Use the argument; `self.guidance_rescale` is never defined.
                noise_pred = rescale_noise_cfg(
                    noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
                )

        return noise_pred


def get_denoising_unet_state_dict(
    old_state_dict, state_dict_type, motion_module_path=None
):
    """Remap a denoising U-Net checkpoint onto the current U-Net parameter names.

    Args:
        old_state_dict: Mapping of parameter name -> value (e.g. the object
            returned by ``torch.load``). It is not modified.
        state_dict_type: Layout of ``old_state_dict``:
            - ``"old_attn"``: every entry whose name contains ``norm1`` /
              ``attn1`` is additionally stored under the same name with
              ``norm1_5`` / ``attn1_5`` substituted, and every
              ``attn2.to_{q,k,v,out}`` entry is overwritten with the matching
              ``attn2.processor.to_{q,k,v,out}_aud`` entry.
            - ``"moore_pretrained"``: only the ``norm1_5`` / ``attn1_5``
              duplication.
            - ``"new_attn"``: entries are returned unchanged (in a new mapping).
        motion_module_path: Unused; kept for backward compatibility.

    Returns:
        A new mapping containing all original entries plus the remapped ones.
        Values are shared with ``old_state_dict`` (no tensor is cloned).

    Raises:
        ValueError: If ``state_dict_type`` is not one of the supported layouts.
        KeyError: (``"old_attn"`` only) if an ``attn2.to_*`` entry has no
            matching ``attn2.processor.*_aud`` entry.
    """
    # (substring in old name, replacement) -> value is duplicated under the new name.
    self_attn_duplicates = (("norm1", "norm1_5"), ("attn1", "attn1_5"))
    # (substring in target name, substring in source name) -> the audio
    # cross-attention weights kept on the attention processor are copied onto
    # the plain ``attn2`` projections ("old_attn" only).
    audio_attn_sources = tuple(
        (f"attn2.{proj}", f"attn2.processor.{proj}_aud")
        for proj in ("to_q", "to_k", "to_v", "to_out")
    )

    if state_dict_type == "old_attn":
        remap_audio_attn = True
    elif state_dict_type == "moore_pretrained":
        #! TODO: since referenceNet is disabled, the params loading should be modified???
        remap_audio_attn = False
    elif state_dict_type == "new_attn":
        return copy.copy(old_state_dict)
    else:
        raise ValueError(
            f"The state_dict_type {state_dict_type} is not supported. "
            f'Only support "moore_pretrained", "old_attn", and "new_attn".'
        )

    # A shallow copy is enough: every value is either kept or re-pointed at
    # another entry of ``old_state_dict``, so deep-copying each tensor only
    # doubled peak memory. ``copy.copy`` also keeps the mapping type and its
    # ``_metadata`` attribute (used by ``load_state_dict``).
    new_denoising_unet_state_dict = copy.copy(old_state_dict)
    for name in old_state_dict.keys():
        for src, dst in self_attn_duplicates:
            if src in name:
                new_denoising_unet_state_dict[name.replace(src, dst)] = (
                    old_state_dict[name]
                )
        if remap_audio_attn:
            for target, source in audio_attn_sources:
                if target in name:
                    new_denoising_unet_state_dict[name] = old_state_dict[
                        name.replace(target, source)
                    ]
    return new_denoising_unet_state_dict


def count_params(net: Net):
    """Log total / trainable parameter counts (in M) of each sub-module of ``net``.

    ``net`` may be the bare ``Net`` or the object returned by
    ``accelerator.prepare``. A ``DistributedDataParallel`` / ``DataParallel``
    wrapper is unwrapped first so the sub-modules are found in multi-GPU runs.
    Sub-modules that are missing or set to ``None`` are skipped.
    """
    # DDP stores the wrapped model as ``.module`` and does not forward attribute
    # access to it; without unwrapping, every lookup below would miss and
    # nothing would be logged when training on several processes.
    if isinstance(net, (nn.parallel.DistributedDataParallel, nn.DataParallel)):
        net = net.module

    def _log(label, module, **kwargs):
        # Logs one "#parameters of <label> ..." line for ``module``.
        num_params, num_trainable_params = get_module_params(module, **kwargs)
        logger.info(
            f"#parameters of {label} is {num_params:.3f} M ({num_trainable_params:.3f} M is trainable)."
        )

    denoising_unet = getattr(net, "denoising_unet", None)
    if denoising_unet is not None:
        _log("Denoising U-Net", denoising_unet)
        _log("Motion Module", denoising_unet, specify_key="motion_module")
        _log("Attn2 in Denoising U-Net", denoising_unet, specify_key="attn2")

    for attr_name, label in (
        ("ip_layers", "IPA Layers"),
        ("image_encoder", "IP-Adapter Image Encoder"),
        ("v_kps_guider", "V-Kps Guider"),
        ("t2i_adapter", "T2I-Adapter"),
        ("audio_projection", "Audio Projection"),
    ):
        module = getattr(net, attr_name, None)
        if module is not None:
            _log(label, module)


def compute_snr(noise_scheduler, timesteps):
    """
    Return the per-timestep signal-to-noise ratio ``(sqrt(a) / sqrt(1 - a)) ** 2``,
    where ``a = noise_scheduler.alphas_cumprod[timesteps]``.

    Follows
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    The result is float32, on ``timesteps.device``, with the shape of ``timesteps``.
    """

    def _pick(per_step):
        # Move to the timesteps' device, index by timestep, cast to float32,
        # then add trailing axes until the rank matches before expanding.
        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        picked = per_step.to(device=timesteps.device)[timesteps].float()
        while picked.dim() < timesteps.dim():
            picked = picked[..., None]
        return picked.expand(timesteps.shape)

    cumprod = noise_scheduler.alphas_cumprod
    signal = _pick(cumprod**0.5)
    noise_level = _pick((1.0 - cumprod) ** 0.5)
    return (signal / noise_level) ** 2


def convert_omegaconf_list(obj):
    """Return ``list(obj)`` for an OmegaConf ``ListConfig``; return anything else unchanged.

    Only the top level is converted; nested containers are left as they are.
    """
    return list(obj) if isinstance(obj, ListConfig) else obj

def check_configs(cfg):
    """Reject configs that enable both the V-Kps guider and the T2I-Adapter.

    Raises:
        ValueError: If ``cfg.disable_kps`` is False (V-Kps guider enabled) while
            ``cfg.apply_t2i_adapter`` is True.
    """
    if not cfg.disable_kps and cfg.apply_t2i_adapter:
        raise ValueError(
            "`disable_kps=False` (V-Kps guider enabled) and `apply_t2i_adapter=True` "
            "cannot be used together; enable at most one of them."
        )


def main():
    """
    =============== 1. CONFIGURATION ===============
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/training/stage2.yaml")
    args = parser.parse_args()

    cfg = OmegaConf.load(args.config)
    # check_configs(cfg)
    ipa_scale =  convert_omegaconf_list(cfg.ipa_scale)
    num_tokens = convert_omegaconf_list(cfg.num_tokens)

    exp_name = ".".join(Path(args.config).name.split(".")[:-1])

    # 1.1 Configure Accelerator
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
        mixed_precision=cfg.solver.mixed_precision,
        log_with="mlflow",
        project_dir="./mlruns",
        kwargs_handlers=[kwargs],
    )

    # 1.2 Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    logger.info(accelerator.state, main_process_only=True)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if cfg.seed is not None:
        seed_everything(cfg.seed)

    # 1.3 Configure Device, Type, Save Dir
    local_rank = accelerator.device

    if cfg.weight_dtype == "fp16":
        weight_dtype = torch.float16
    elif cfg.weight_dtype == "fp32":
        weight_dtype = torch.float32
    else:
        raise ValueError(
            f"Do not support weight dtype: {cfg.weight_dtype} during training"
        )

    save_dir = f"{cfg.output_dir}/{exp_name}"
    pathlib.Path(save_dir).mkdir(exist_ok=True, parents=True)

    # Dump the configuration as a YAML file
    config_save_path = os.path.join(save_dir, "config_dump.yaml")
    with open(config_save_path, "w") as f:
        OmegaConf.save(config=cfg, f=f)

    # Load LoRA Configs
    APPLY_LORA = OmegaConf.select(cfg, "lora.denoise_unet.rank")
    if APPLY_LORA:
        lora_config = load_denoise_unet_lora(cfg, denoising_unet=None)
        logger.info(
            f"LoRA Rank: {cfg.lora.denoise_unet.rank} \n"
            f"LoRA Alpha: {cfg.lora.denoise_unet.alpha} \n"
            f"LoRA Target Modules: {cfg.lora.denoise_unet.target_modules} \n"
            f"LoRA DropOut: {cfg.lora.denoise_unet.dropout} \n"
        )

    """
    =============== 2. MODULES LOADING ===============
    """
    # 2.1 initialize the noise scheduler
    scheduler_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
    if cfg.enable_zero_snr:
        scheduler_kwargs.update(
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
            prediction_type="v_prediction",
        )
    noise_scheduler = DDIMScheduler(**scheduler_kwargs)
    logger.info(f"Noise Schedule Setting: {cfg.noise_scheduler_kwargs}")

    # 2.2 initialize the pretrained fixed modules
    vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
        local_rank, dtype=weight_dtype
    )
    audio_encoder = Wav2Vec2Model.from_pretrained(cfg.audio_encoder_path).to(
        dtype=weight_dtype, device=local_rank
    )
    audio_processor = Wav2Vec2Processor.from_pretrained(cfg.audio_encoder_path)

    # initialize our modules
    # 2.3 Diffusion Models
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        cfg.base_model_path,
        cfg.motion_module_path,
        subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(cfg.unet_additional_kwargs),
        mm_zero_proj_out=cfg.unet_additional_kwargs.mm_zero_proj_out,
        logger=logger,
    ).to(device=local_rank)

    # 2.4 Residual Condition Encoder
    if not cfg.disable_kps:
        v_kps_guider = VKpsGuider(
            conditioning_embedding_channels=320,
            block_out_channels=(16, 32, 96, 256),
        ).to(device=local_rank, dtype=weight_dtype)
    else:
        v_kps_guider = None

    if cfg.apply_t2i_adapter:
        if os.path.isdir(cfg.t2i_adapter_model_path):
            t2i_adapter = T2IAdapter.from_pretrained(
                cfg.t2i_adapter_model_path, torch_dtype=weight_dtype
            ).to(local_rank)
        elif cfg.t2i_adapter_model_path.endswith(".pth"):
            t2i_adapter = T2IAdapter(
                in_channels=3,
                channels=[320, 640, 1280, 1280],
                num_res_blocks=2,
                downscale_factor=8,
                adapter_type="full_adapter",
            ).to(dtype=weight_dtype, device=local_rank)
            t2i_state_dict_ = torch.load(cfg.t2i_adapter_model_path, map_location="cpu")
            t2i_state_dict = t2i_adapter_map_keys(
                t2i_state_dict_
            )  # Create the new state dictionary with mapped keys
            # zero initialized the last block weights of the t2i-adapter
            if cfg.t2i_adapter_zero_out:
                new_t2i_adapter_state_dict = OrderedDict()
                for k in t2i_state_dict:
                    if "block2" in k:
                        continue
                    new_t2i_adapter_state_dict[k] = t2i_state_dict[k]
                info = t2i_adapter.load_state_dict(
                    new_t2i_adapter_state_dict, strict=False
                )  # Load the renamed state dictionary into the model
            else:
                info = t2i_adapter.load_state_dict(
                    t2i_state_dict, strict=True
                )  # Load the renamed state dictionary into the model
        else:
            raise NotImplementedError("T2I Adapter checkpoint incorrect configuration!")
        logger.info(f"Loaded T2I-Adapter from {cfg.t2i_adapter_model_path}. Info: {info}")
    else:
        t2i_adapter = None

    # 2.5 Audio Projection and Encoder
    if cfg.module_training.audio_projection:
        if cfg.data.audio_embeddings_type == "global":
            inp_dim = 768
        elif cfg.data.audio_embeddings_type == "xlsr_global":
            inp_dim = 1024
        else:
            raise ValueError(
                f"Do not support {cfg.data.audio_embeddings_type}. "
                f'Now only support "global" and "xlsr_global".'
            )
        mid_dim = denoising_unet.config.cross_attention_dim
        out_dim = denoising_unet.config.cross_attention_dim
        inp_seq_len = 2 * (2 * cfg.data.num_padding_audio_frames + 1)
        out_seq_len = 2 * cfg.data.num_padding_audio_frames + 1
        if OmegaConf.select(cfg, "aud_proj_depth"):
            aud_proj_depth = cfg.aud_proj_depth
        else:
            aud_proj_depth = 4
        audio_projection = AudioProjection(
            dim=mid_dim,
            depth=aud_proj_depth,
            dim_head=64,
            heads=12,
            num_queries=out_seq_len,
            embedding_dim=inp_dim,
            output_dim=out_dim,
            ff_mult=4,
            max_seq_len=inp_seq_len,
        )
        print_highlighted_block_log(
            title="Initialize Audio Projection",
            message=f"""
            aud_proj_depth: {aud_proj_depth}; num_queries: {out_seq_len}; max_seq_len: {inp_seq_len}
            """,
            logger=logger
        )
    else:
        audio_projection = None

    # 2.6 Text and Image Encoders
    logger.info(f"[INFO] loading Image / Text Encoders...")
    tokenizer = CLIPTokenizer.from_pretrained(
        cfg.base_model_path,
        subfolder="tokenizer",
        torch_dtype=weight_dtype,
    )
    text_encoder = CLIPTextModel.from_pretrained(
        cfg.base_model_path,
        subfolder="text_encoder",
        torch_dtype=weight_dtype,
    ).to(device=local_rank)

    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        cfg.image_encoder_path
    ).to(device=local_rank, dtype=weight_dtype)
    logger.info(f"[INFO] loaded Image / Text Encoders!")

    # 2.7 load parameters to our modules
    if not cfg.disable_kps and cfg.v_kps_guider_path != "":
        kps_state_dict = torch.load(cfg.v_kps_guider_path, map_location="cpu")
        # zero initialized the proj-out weights of the kps-guider
        if cfg.v_kps_guider_zero_out:
            new_kps_state_dict = OrderedDict()
            for k in kps_state_dict:
                if "conv_out" in k:
                    continue
                new_kps_state_dict[k] = kps_state_dict[k]
            info = v_kps_guider.load_state_dict(new_kps_state_dict, strict=False)
        else:
            info = v_kps_guider.load_state_dict(kps_state_dict)
        logger.info(f"Loaded VKpsGuider from {cfg.v_kps_guider_path}. Info: {info}")

    if cfg.module_training.audio_projection and cfg.audio_projection_path != "":
        info = audio_projection.load_state_dict(
            torch.load(cfg.audio_projection_path, map_location="cpu")
        )
        logger.info(
            f"Loaded AudioProjection from {cfg.audio_projection_path}. Info: {info}"
        )

    if cfg.denoising_unet_path != "":
        state_dict = torch.load(cfg.denoising_unet_path, map_location="cpu")
        new_state_dict = get_denoising_unet_state_dict(
            state_dict, cfg.denoising_unet_state_dict_type
        )
        m, u = denoising_unet.load_state_dict(new_state_dict, strict=False)
        logger.info(
            f"Loaded Denoising U-Net from {cfg.denoising_unet_path} in type of {cfg.denoising_unet_state_dict_type}.\n "
            f"### missing keys: {m}; \n### unexpected keys: {u};"
        )

    if cfg.motion_module_path is not None and cfg.motion_module_path != "":
        state_dict = torch.load(cfg.motion_module_path, map_location="cpu")
        m, u = denoising_unet.load_state_dict(state_dict, strict=False)
        logger.info(
            f"Loaded Motion Module from {cfg.motion_module_path}. "
            f"Info: ### missing keys: {len(m)}; ### unexpected keys: {len(u)};"
        )

    for name, params in denoising_unet.named_parameters():
        if 'stage_1' in exp_name:
            if 'temporal_transformer.proj_out' in name:
                zero_module_params(params)
            if 'attn2.to_out' in name:
                zero_module_params(params)
        elif 'stage_2' in exp_name:
            if cfg.unet_additional_kwargs.mm_zero_proj_out:
                if 'temporal_transformer.proj_out' in name:
                    zero_module_params(params)
            if 'attn2.to_out' in name:
                zero_module_params(params)
        elif 'stage_3' in exp_name:
            pass
        else:
            raise NotImplementedError(f"{exp_name} not implement")

    """
    =============== 3. MODULES CONFIGURATION ===============
    """
    # 3.1 set gradient state of all modules
    vae.requires_grad_(False)
    audio_encoder.requires_grad_(False)
    image_encoder.requires_grad_(False)
    text_encoder.requires_grad_(False)

    denoising_unet.requires_grad_(cfg.module_training.denoising_unet)
    if cfg.module_training.get('trainable_modules', None):
        for name, param in denoising_unet.named_parameters():
            for trainable_module_name in cfg.module_training.trainable_modules:
                if trainable_module_name in name:
                    param.requires_grad = True
                    break

    if v_kps_guider:
        v_kps_guider.requires_grad_(cfg.module_training.v_kps_guider)
    if t2i_adapter:
        t2i_adapter.requires_grad_(cfg.module_training.t2i_adapter)
    if audio_projection:
        audio_projection.requires_grad_(cfg.module_training.audio_projection)

    # Make Motion Module Trainable
    if cfg.module_training.get('motion_trainable_modules', None):
        for name, param in denoising_unet.named_parameters():
            for trainable_module_name in cfg.module_training.motion_trainable_modules:
                if "motion_modules" in name and trainable_module_name in name:
                    param.requires_grad = True
                    break
    else:
        for name, module in denoising_unet.named_modules():
            if "motion_modules" in name:
                for params in module.parameters():
                    params.requires_grad = cfg.module_training.motion_module

    # TODO: LoRA for denoising net
    if APPLY_LORA:
        logger.info(
            f"Before PEFT: denoising_unet: {get_module_params(denoising_unet, only_trainable=True)} M"
        )

        denoising_unet = get_peft_model(denoising_unet, lora_config)
        denoising_unet.print_trainable_parameters()
        if cfg.lora.denoise_unet.only_mm:
            logger.info(
                f"After PEFT, denoising_unet: "
                f"attn1: {denoising_unet.down_blocks[1].motion_modules[0].temporal_transformer.transformer_blocks[0].attention_blocks[0].to_q.weight.requires_grad}; "
                f"lora: {denoising_unet.down_blocks[1].motion_modules[0].temporal_transformer.transformer_blocks[0].attention_blocks[0].to_q.lora_A.default.weight.requires_grad}"
            )
        else:
            logger.info(
                f"After PEFT, denoising_unet: "
                f"attn1: {denoising_unet.down_blocks[1].attentions[0].transformer_blocks[0].attn1.to_q.weight.requires_grad}; "
                f"lora: {denoising_unet.down_blocks[1].attentions[0].transformer_blocks[0].attn1.to_q.lora_A.default.weight.requires_grad}"
            )
        logger.info("""
            After PEFT, trainable params:
            vae: {} M;
            image_encoder: {} M;
            audio_encoder: {} M;
            audio_projection: {} M;
            denoising_unet : {} M;
            v_kps_guider : {} M;
            """.format(
                get_module_params(vae, only_trainable=True),
                get_module_params(image_encoder, only_trainable=True),
                get_module_params(audio_encoder, only_trainable=True),
                get_module_params(audio_projection, only_trainable=True),
                get_module_params(denoising_unet, only_trainable=True),
                get_module_params(v_kps_guider, only_trainable=True),
            )
        )

    # Make Attention 2 (Audio Cross-Attention) Trainable
    for name, module in denoising_unet.named_modules():
        if "attentions" in name and ("attn2" in name or "norm2" in name):
            logger.info(name)
            for params in module.parameters():
                params.requires_grad = cfg.module_training.audio_projection

    # 3.2 initialize VExpress Net

    net = Net(
        denoising_unet=denoising_unet,
        v_kps_guider=v_kps_guider,
        t2i_adapter=t2i_adapter,
        audio_projection=audio_projection,
        image_encoder=image_encoder,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        device=local_rank,
        weight_dtype=weight_dtype,
        kps_drop_rate=cfg.data.kps_drop_rate,
        faceid_drop_rate=cfg.data.faceid_drop_rate,
        ip_ckpt=cfg.ip_ckpt,
        num_tokens=num_tokens,
        n_cond=cfg.n_cond,
        ip_mode=cfg.ip_mode,
    )

    """
    =============== 4. SOLVER SETTINGS ===============
    """

    if cfg.solver.denoising_unet_gradient_checkpointing:
        denoising_unet.enable_gradient_checkpointing()

    if cfg.solver.scale_lr:
        learning_rate = (
            cfg.solver.learning_rate
            * cfg.solver.gradient_accumulation_steps
            * cfg.data.train_bs
            * accelerator.num_processes
        )
    else:
        learning_rate = cfg.solver.learning_rate

    # 4.1 initialize the optimizer and lr scheduler
    if cfg.solver.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    # 4.2 Get the parameters to optimize
    trainable_params = list(filter(lambda p: p.requires_grad, net.parameters()))
    if cfg.module_training.ip_adapter:
        net.image_encoder.requires_grad_(True)
        net.ip_layers.requires_grad_(True)
        params_to_opt = chain(
            net.image_encoder.parameters(), net.ip_layers.parameters()
        )
        trainable_params = list(set(chain(trainable_params, params_to_opt)))
    optimizer = optimizer_cls(
        trainable_params,
        lr=learning_rate,
        betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
        weight_decay=cfg.solver.adam_weight_decay,
        eps=cfg.solver.adam_epsilon,
    )
    lr_scheduler = get_scheduler(
        cfg.solver.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.solver.lr_warmup_steps * cfg.solver.gradient_accumulation_steps,
        num_training_steps=cfg.solver.max_train_steps * cfg.solver.gradient_accumulation_steps,
    )

    """
    =============== 5. DATA LOADING ===============
    """
    if cfg.data.data_format in ["video_only", "video_image"]:
        vid_dataset = TalkingFaceVideo(
            image_size=(cfg.data.train_height, cfg.data.train_width),
            meta_paths=cfg.data.meta_paths,
            prompt_paths=cfg.data.prompt_paths,
            flip_rate=cfg.data.flip_rate,
            sample_rate=cfg.data.sample_rate,
            num_frames=cfg.data.num_frames,
            reference_margin=cfg.data.reference_margin,
            num_padding_audio_frames=cfg.data.num_padding_audio_frames,
        )
    if cfg.data.data_format in ["image_only", "video_image"]:
        img_dataset = TalkingFaceImage(
            image_size=(cfg.data.train_height, cfg.data.train_width),
            meta_paths=cfg.data.image_meta_paths,
            flip_rate=cfg.data.flip_rate,
        )
    if cfg.data.data_format not in ["video_only", "image_only", "video_image"]:
        raise KeyError("Incorrect Setting for data_format!")

    if cfg.data.data_format == "image_only":
        dataset = img_dataset
    elif cfg.data.data_format == "video_only":
        dataset = vid_dataset
    elif cfg.data.data_format == "video_image":
        from datasets.utils import CombinedDataset
        dataset = CombinedDataset(vid_dataset, img_dataset)

    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=4
    )

    """
    =============== 6. Setting Accelerator ===============
    """
    (net, optimizer, dataloader, lr_scheduler) = accelerator.prepare(
        net, optimizer, dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(dataloader) / cfg.solver.gradient_accumulation_steps
    )
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(
        cfg.solver.max_train_steps / num_update_steps_per_epoch
    )

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        run_time = datetime.now().strftime("%Y%m%d-%H%M")
        accelerator.init_trackers(
            exp_name,
            init_kwargs={"mlflow": {"run_name": f"{save_dir.replace('/', '-')}-{run_time}"}},
        )
        # dump config file
        mlflow.log_dict(OmegaConf.to_container(cfg), "config.yaml")

    total_batch_size = (
        cfg.data.train_bs
        * accelerator.num_processes
        * cfg.solver.gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(dataset)}")
    logger.info(f"  Num Epochs = {num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {cfg.data.train_bs}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {cfg.solver.max_train_steps}")
    logger.info(f"  Prediction Type: {noise_scheduler.prediction_type}")
    logger.info(f"  Using Classifier Free Guidance: {cfg.guidance_scale > 1.0}")

    global_step = 0
    first_epoch = 0
    if accelerator.is_main_process:
        count_params(net)
        # verify zero initialization for finetuning
        if t2i_adapter and cfg.t2i_adapter_zero_out:
            check_zero_initialization(net.t2i_adapter, "block2", logger=logger)
        if v_kps_guider and cfg.v_kps_guider_zero_out:
            check_zero_initialization(net.v_kps_guider, "conv_out", logger=logger)
        if cfg.module_training.motion_module:
            check_zero_initialization(
                net.denoising_unet, "temporal_transformer.proj_out", logger=logger
            )
        check_zero_initialization(
            net.denoising_unet, "attn2.to_out", logger=logger
        )

    # Potentially load in the weights and states from a previous save
    if cfg.resume_from_checkpoint:
        if cfg.resume_from_checkpoint != "latest":
            resume_dir = cfg.resume_from_checkpoint
        else:
            resume_dir = save_dir
        # Get the most recent checkpoint
        dirs = os.listdir(resume_dir)
        dirs = [d for d in dirs if d.startswith("checkpoint")]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        if len(dirs) > 0:
            path = dirs[-1]
            accelerator.load_state(
                os.path.join(resume_dir, path), load_module_strict=False
            )
            accelerator.print(f"Resuming from checkpoint {path}")
            global_step = int(path.split("-")[1])

            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = global_step % num_update_steps_per_epoch
        else:
            logger.info(f'There is no checkpoint to load in path {resume_dir}. Resuming skipped.')

    """
    =============== 7. Starting Training!!! ===============
    """
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(global_step, cfg.solver.max_train_steps),
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description(f"{exp_name}, Steps")

    for epoch in range(first_epoch, num_train_epochs):
        train_loss = 0.0
        t_data_start = time.time()
        for step, batch in enumerate(dataloader):
            t_data = time.time() - t_data_start
            with accelerator.accumulate(net):
                # 7.1 Load Inputs
                target_images = batch["target_images"].to(weight_dtype)
                with torch.no_grad():
                    length = target_images.shape[2]
                    target_images = rearrange(target_images, "b c f h w -> (b f) c h w")
                    latents = vae.encode(target_images).latent_dist.sample()
                    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=length)
                    latents = latents * 0.18215

                target_images = batch["target_images"].to(weight_dtype)

                # 7.2 Load Conditions
                if cfg.ip_mode in ['plus', 'vanilla', 'faceid-decoupled']:
                    ref_images = batch["reference_image"].to(weight_dtype)
                    ref_images_pil = []
                    for tensor_image in ref_images:
                        pil_image = transforms_f.to_pil_image((tensor_image + 1.) / 2.)
                        ref_images_pil.append(pil_image)
                else:
                    ref_images_pil = None

                with torch.no_grad():
                    kps_images = batch["kps_images"].to(local_rank, dtype=weight_dtype)  # (bs, c, f, H, W)
                    audio_frame_embeddings = batch["audio_frame_embeddings"].to(local_rank, dtype=weight_dtype)
                    prompt = batch["caption"]
                    face_masks = batch["face_masks"].to(dtype=vae.dtype, device=vae.device)
                    lip_masks = batch["lip_masks"].to(dtype=vae.dtype, device=vae.device)
                    ref_face_embed = batch.get(
                        "ref_face_embed_mask" if cfg.data.load_face_mask else "ref_face_embed"
                    ).to(dtype=vae.dtype, device=vae.device)

                # 7.3 Generate Noise
                noise = torch.randn_like(latents)
                if cfg.repeat_start:
                    repeat_idx = random.randint(0, length - 2)
                    noise = noise[:, :, repeat_idx : repeat_idx + 1].repeat(1, 1, length, 1, 1)
                if cfg.noise_offset > 0:
                    noise += cfg.noise_offset * torch.randn(
                        (latents.shape[0], latents.shape[1], 1, 1, 1),
                        device=latents.device,
                    )

                bsz = latents.shape[0]
                # 7.4 Sample a random timestep for each video
                if OmegaConf.select(cfg, "sample_time_skew") and cfg.sample_time_skew != 1.0:
                    timesteps = custom_sample_timesteps(
                        noise_scheduler.num_train_timesteps,
                        bsz,
                        skew_type="late",  # early, late
                        skew_factor=cfg.sample_time_skew,
                        device=latents.device,
                    )
                else:
                    timesteps = torch.randint(
                        0,
                        noise_scheduler.num_train_timesteps,
                        (bsz,),
                        device=latents.device,
                    )
                timesteps = timesteps.long()

                # 7.5 # REVIEW: Prefix logic | Add noise
                is_prefix = (random.random() < cfg.prefix_ratio) and (cfg.data.num_prefix_frames != 0)
                if is_prefix:
                    _noisy_latents = noise_scheduler.add_noise(
                        latents[:, :, cfg.data.num_prefix_frames :],
                        noise[:, :, cfg.data.num_prefix_frames :],
                        timesteps,
                    )
                else:
                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # 7.6 Get the target for loss depending on the prediction type
                if noise_scheduler.prediction_type == "epsilon":
                    if is_prefix:
                        noisy_latents = torch.cat(
                            (latents[:, :, : cfg.data.num_prefix_frames], _noisy_latents), # NOTE: VAE image latents as Prefix 
                            dim=2,
                        )
                        target = noise[:, :, cfg.data.num_prefix_frames:]
                    else:
                        target = noise

                elif noise_scheduler.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(
                        f"Unknown prediction type {noise_scheduler.prediction_type}"
                    )

                # 7.6.x: prepare added_cond_kwargs for UNet
                added_cond_kwargs = {}

                # 7.7 Call UNet to predict noise
                model_pred = net(
                    noisy_latents,
                    timesteps,
                    audio_frame_embeddings,
                    kps_images,
                    ref_images_pil=ref_images_pil,
                    prompt=prompt,
                    negative_prompt=cfg.data.negative_prompt,
                    do_unconditional_forward=random.random() < cfg.uncond_ratio,
                    guidance_scale=cfg.guidance_scale,
                    face_embeds=ref_face_embed,
                    t2i_adapter_control_type=cfg.t2i_adapter_control_type,
                    t2i_adapter_conditioning_scale=cfg.t2i_adapter_conditioning_scale,
                    ipa_scale=ipa_scale,
                    added_cond_kwargs=added_cond_kwargs,
                )

                # 7.8 Compute Loss
                if is_prefix:
                    model_pred = model_pred[:, :, cfg.data.num_prefix_frames :]
                    face_masks = face_masks[:, :, cfg.data.num_prefix_frames :]
                    lip_masks = lip_masks[:, :, cfg.data.num_prefix_frames :]

                loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                if "bg_loss_weight" in cfg.data:
                    loss *= ((cfg.data.bg_loss_weight - 1) * (1. - face_masks) + 1.0)
                if "lip_loss_weight" in cfg.data:
                    loss *= ((cfg.data.lip_loss_weight - 1) * lip_masks + 1.0)

                if cfg.snr_gamma != 0:
                    snr = compute_snr(noise_scheduler, timesteps)
                    if noise_scheduler.config.prediction_type == "v_prediction":
                        # Velocity objective requires that we add one to SNR values before we divide by them.
                        snr = snr + 1
                    mse_loss_weights = (
                        torch.stack(
                            [snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1
                        ).min(dim=1)[0]
                        / snr
                    )
                    loss = (
                        loss.mean(dim=list(range(1, len(loss.shape))))
                        * mse_loss_weights
                    )
                loss = loss.mean()

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(cfg.data.train_bs)).mean()
                train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps

                # 7.9 Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        trainable_params, cfg.solver.max_grad_norm
                    )
                if global_step == 1:
                    logger.info(
                        f"Denoising UNet: {count_updated_params(net.denoising_unet)}"
                    )
                    if ("stage_1" in exp_name or cfg.module_training.t2i_adapter) and net.t2i_adapter:
                        logger.info(f"T2I-Adapter: {count_updated_params(net.t2i_adapter)}")
                    if "stage_1" in exp_name and net.v_kps_guider:
                        logger.info(f"KPS Guider: {count_updated_params(net.v_kps_guider)}")
                    if "stage_2" in exp_name:
                        if net.audio_projection:
                            logger.info(
                                f"Audio: {count_updated_params(net.audio_projection)}"
                            )
                        logger.info(
                            f"Attn2: {count_updated_params(net.denoising_unet, specify_key='attn2')}"
                        )
                        logger.info(
                            f"Motion: {count_updated_params(net.denoising_unet, specify_key='motion_module')}"
                        )
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

            logs = {
                "step_loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
                "td": f"{t_data:.2f}s",
            }
            t_data_start = time.time()
            progress_bar.set_postfix(**logs)

            if global_step >= cfg.solver.max_train_steps:
                break

            # 7.10 save model after each epoch
            if global_step % cfg.checkpointing_steps == 0 or global_step == 1:
                save_path = os.path.join(save_dir, f"checkpoint-{global_step}")
                accelerator.save_state(save_path)

                if accelerator.is_main_process:
                    module = accelerator.unwrap_model(net.denoising_unet)
                    save_module_checkpoint(
                        module, save_dir, "denoising_unet", global_step
                    )
                    if APPLY_LORA:
                        accelerator.unwrap_model(net).denoising_unet.save_pretrained(
                            save_path
                        )  # 保存lora

                    if cfg.module_training.v_kps_guider and net.v_kps_guider:
                        module = accelerator.unwrap_model(net.v_kps_guider)
                        save_module_checkpoint(
                            module, save_dir, "v_kps_guider", global_step
                        )

                    if cfg.module_training.t2i_adapter and net.t2i_adapter:
                        module = accelerator.unwrap_model(net.t2i_adapter)
                        save_module_checkpoint(
                            module,
                            save_dir,
                            f"t2i_adapter_{cfg.t2i_adapter_control_type}",
                            global_step,
                        )

                    if cfg.module_training.audio_projection:
                        module = accelerator.unwrap_model(net.audio_projection)
                        save_module_checkpoint(
                            module, save_dir, "audio_projection", global_step
                        )

                    if cfg.module_training.motion_module:
                        module = accelerator.unwrap_model(net.denoising_unet)
                        save_motion_module_checkpoint(
                            module, save_dir, "motion_module", global_step
                        )

                    if cfg.module_training.ip_adapter:
                        module = accelerator.unwrap_model(net.image_encoder)
                        save_motion_module_checkpoint(
                            module, save_dir, "image_encoder", global_step
                        )
                        module = accelerator.unwrap_model(net.ip_layers)
                        save_motion_module_checkpoint(
                            module, save_dir, "ip_layers", global_step
                        )

    # save model after each epoch
    # if accelerator.is_main_process:
    save_path = os.path.join(save_dir, f"checkpoint-{global_step}")
    accelerator.save_state(save_path)

    if accelerator.is_main_process:
        module = accelerator.unwrap_model(net.denoising_unet)
        save_module_checkpoint(module, save_dir, "denoising_unet", global_step)
        if APPLY_LORA:
            accelerator.unwrap_model(net).denoising_unet.save_pretrained(
                save_path
            )  # 保存lora

        if cfg.module_training.v_kps_guider and net.v_kps_guider:
            module = accelerator.unwrap_model(net.v_kps_guider)
            save_module_checkpoint(module, save_dir, "v_kps_guider", global_step)

        if cfg.module_training.t2i_adapter and net.t2i_adapter:
            module = accelerator.unwrap_model(net.t2i_adapter)
            save_module_checkpoint(
                module,
                save_dir,
                f"t2i_adapter_{cfg.t2i_adapter_control_type}",
                global_step,
            )

        if cfg.module_training.audio_projection:
            module = accelerator.unwrap_model(net.audio_projection)
            save_module_checkpoint(module, save_dir, "audio_projection", global_step)

        if cfg.module_training.motion_module:
            module = accelerator.unwrap_model(net.denoising_unet)
            save_motion_module_checkpoint(
                module, save_dir, "motion_module", global_step
            )

        if cfg.module_training.ip_adapter:
            module = accelerator.unwrap_model(net.image_encoder)
            save_motion_module_checkpoint(
                module, save_dir, "image_encoder", global_step
            )
            module = accelerator.unwrap_model(net.ip_layers)
            save_motion_module_checkpoint(module, save_dir, "ip_layers", global_step)

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    accelerator.end_training()


def save_module_checkpoint(module, save_dir, prefix, ckpt_num):
    """Save a full copy of ``module``'s state dict as ``<save_dir>/<prefix>-<ckpt_num>.pth``.

    Every tensor in the state dict is cloned first, so the saved file holds
    independent copies of the module's current parameters/buffers.

    Args:
        module: Module whose ``state_dict()`` is written.
        save_dir: Directory the checkpoint file is written into.
        prefix: Filename prefix (e.g. the sub-network name).
        ckpt_num: Step/checkpoint number appended to the filename.
    """
    checkpoint_file = osp.join(save_dir, f"{prefix}-{ckpt_num}.pth")
    snapshot = {name: tensor.clone() for name, tensor in module.state_dict().items()}
    torch.save(snapshot, checkpoint_file)


def save_motion_module_checkpoint(model, save_dir, prefix, ckpt_num, key_substr="motion_module"):
    """Save a filtered subset of ``model``'s state dict as ``<save_dir>/<prefix>-<ckpt_num>.pth``.

    Only entries whose key contains ``key_substr`` are kept (default
    ``"motion_module"``, which is how the trainer extracts the motion-module
    weights from the denoising UNet). Each kept tensor is cloned so the file
    holds independent copies. Pass ``key_substr=None`` to keep every entry,
    e.g. when saving a whole sub-network whose keys never contain
    ``"motion_module"``.

    If no key matches, an empty checkpoint is still written (original
    behavior), but a warning is logged so the silent no-op is visible.

    Args:
        model: Module whose ``state_dict()`` is filtered and saved.
        save_dir: Directory the checkpoint file is written into.
        prefix: Filename prefix.
        ckpt_num: Step/checkpoint number appended to the filename.
        key_substr: Substring a key must contain to be kept; ``None`` keeps
            all keys.
    """
    save_path = osp.join(save_dir, f"{prefix}-{ckpt_num}.pth")

    # Preserve the state dict's key order in the saved file.
    mm_state_dict = OrderedDict(
        (key, tensor.clone())
        for key, tensor in model.state_dict().items()
        if key_substr is None or key_substr in key
    )
    if not mm_state_dict:
        # Filtering a module that has no matching keys yields an empty file;
        # surface that instead of failing silently.
        logging.getLogger(__name__).warning(
            "No state_dict keys matched %r; saving an empty checkpoint to %s",
            key_substr,
            save_path,
        )

    torch.save(mm_state_dict, save_path)


# Script entry point: run training via main() only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
