# Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/pipelines/pipeline_animation.py
import os, inspect
import math
from dataclasses import dataclass
from typing import Callable, List, Optional, Union, Dict, Union, Any, Tuple
import PIL
from PIL import Image
from collections import OrderedDict
import cv2
import scipy.ndimage as ndimage
import torch.fft as fft

import numpy as np
import torch
import torchvision.transforms.functional as transforms_f
from diffusers import DiffusionPipeline, StableDiffusionMixin, ControlNetModel
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.pag.pag_utils import PAGMixin
from diffusers.loaders import LoraLoaderMixin, StableDiffusionLoraLoaderMixin
from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
from diffusers.schedulers import (
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from diffusers.utils import (
    PIL_INTERPOLATION,
    USE_PEFT_BACKEND,
    BaseOutput,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from einops import rearrange
from tqdm import tqdm
from transformers import CLIPImageProcessor

from modules.unet_3d_blocks import CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D, UNetMidBlock3DCrossAttn
from modules import T2IAdapter, MultiAdapter, DDIMScheduler
from .context import get_context_scheduler
from .utils import load_img, get_face_mask, load_masked_image_from_faceinfo
from utils.ip_adapter_utils import (
    init_proj,
    set_ip_adapter,
    load_ip_adapter,
    set_ipa_scale,
)
from utils.convert_lora_safetensor_to_diffusers import load_diffusers_lora
from utils.postprocess import correct_color_offset
from utils.dynthres_core import DynThresh
from utils.hook_utils import get_net_attn_map, get_net_qk


def rearrange_images(image_sublists):
    """Transpose a sequence of equally sized sublists.

    Entry ``i`` of the result collects the ``i``-th item of every sublist,
    keeping the sublists' original order.

    Raises:
        ValueError: If the sublists do not all share a single length (an
            empty ``image_sublists`` also ends up here).
    """
    distinct_lengths = {len(group) for group in image_sublists}
    if len(distinct_lengths) != 1:
        raise ValueError("Not all sublists are of the same length")

    (item_count,) = distinct_lengths
    return [
        [group[position] for group in image_sublists]
        for position in range(item_count)
    ]


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend `noise_cfg` with a copy of itself rescaled to the per-sample standard deviation of `noise_pred_text`.

    Follows Section 3.4 of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is the weight given to the rescaled
    prediction; `1 - guidance_rescale` is the weight kept on the untouched `noise_cfg`.
    """

    def per_sample_std(tensor):
        # Reduce over every dimension except the leading (batch) one.
        reduce_dims = list(range(1, tensor.ndim))
        return tensor.std(dim=reduce_dims, keepdim=True)

    std_ratio = per_sample_std(noise_pred_text) / per_sample_std(noise_cfg)
    # Matching the text prediction's std counters overexposure from guidance.
    rescaled = noise_cfg * std_ratio
    # Mixing the original back in avoids "plain looking" images.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor,
    generator: Optional[torch.Generator] = None,
    sample_mode: str = "sample",
):
    """Pull latents out of an encoder output object.

    If the output has a ``latent_dist`` attribute, it is sampled with
    `generator` (``sample_mode == "sample"``) or its mode is taken
    (``sample_mode == "argmax"``). Otherwise, or for any other
    `sample_mode`, a ``latents`` attribute is returned when present.

    Raises:
        AttributeError: If none of the above applies.
    """
    if hasattr(encoder_output, "latent_dist"):
        distribution = encoder_output.latent_dist
        if sample_mode == "sample":
            return distribution.sample(generator)
        if sample_mode == "argmax":
            return distribution.mode()

    if hasattr(encoder_output, "latents"):
        return encoder_output.latents

    raise AttributeError("Could not access latents of provided encoder_output")


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Configure the scheduler through `set_timesteps` and hand back the resulting schedule.

    Args:
        scheduler (`SchedulerMixin`):
            Scheduler whose `set_timesteps` is called and whose `timesteps` attribute is read afterwards.
        num_inference_steps (`int`):
            Number of diffusion steps. Only used when `timesteps` is `None`.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `scheduler.set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule. When given, the scheduler's `set_timesteps` must accept a `timesteps`
            argument, and the returned step count is the length of the resulting schedule.
        **kwargs:
            Extra keyword arguments forwarded to `scheduler.set_timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: The scheduler's timestep schedule and the number of inference steps.

    Raises:
        ValueError: If `timesteps` is given but `scheduler.set_timesteps` has no `timesteps` parameter.
    """
    if timesteps is None:
        # Default path: let the scheduler pick its own spacing.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    set_timesteps_params = inspect.signature(scheduler.set_timesteps).parameters
    if "timesteps" not in set_timesteps_params:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" timestep schedules. Please check whether you are using the correct scheduler."
        )

    scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)


def _preprocess_adapter_image(image, height, width):
    """Convert adapter conditioning input into a batched tensor.

    * A tensor is returned untouched (it is not resized).
    * A PIL image, or a list of PIL images, is resized to ``(width, height)``,
      scaled to ``[0, 1]`` float32 and returned as a ``(b, c, h, w)`` tensor.
    * A list of 3-D tensors is stacked; a list of 4-D tensors is concatenated
      along dim 0.
    * A list whose first element is neither a PIL image nor a tensor is
      returned as-is.

    Raises:
        ValueError: If a list of tensors holds tensors that are not 3-D or 4-D.
    """
    if isinstance(image, torch.Tensor):
        return image
    if isinstance(image, PIL.Image.Image):
        image = [image]

    first = image[0]

    if isinstance(first, PIL.Image.Image):
        frames = []
        for pil_frame in image:
            pixels = np.array(
                pil_frame.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
            )
            # Grayscale arrays are [h, w]; give them a channel axis as well as
            # the batch axis so every frame ends up [1, h, w, c].
            if pixels.ndim == 2:
                pixels = pixels[None, ..., None]
            else:
                pixels = pixels[None, ...]
            frames.append(pixels)

        batch = np.concatenate(frames, axis=0).astype(np.float32) / 255.0
        # [b, h, w, c] -> [b, c, h, w]
        return torch.from_numpy(batch.transpose(0, 3, 1, 2))

    if isinstance(first, torch.Tensor):
        if first.ndim == 3:
            return torch.stack(image, dim=0)
        if first.ndim == 4:
            return torch.cat(image, dim=0)
        raise ValueError(
            f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but recive: {image[0].ndim}"
        )

    return image


def Fourier_filter(x, threshold, scale):
    """Scale the centre (low-frequency) band of `x`'s 3-D spectrum.

    An FFT is taken over the last three dimensions of `x`. After centring
    the spectrum with `fftshift`, a box of half-width `threshold` around the
    centre of each of those dimensions is multiplied by `scale`; everything
    outside the box is left unchanged. The result is transformed back and its
    real part is returned.

    Args:
        x: Tensor with at least three dimensions. The original code unpacked
            the shape as ``(B, C, T, H, W)``; only the last three dimensions
            are actually used.
        threshold: Half-width (in frequency bins) of the scaled box.
        scale: Multiplier applied inside the box.

    Returns:
        Tensor with the same shape and dtype as `x`.
    """
    dtype = x.dtype
    # The transform is carried out in float32 regardless of the input dtype.
    x = x.type(torch.float32)
    spectral_dims = (-3, -2, -1)

    x_freq = fft.fftn(x, dim=spectral_dims)
    x_freq = fft.fftshift(x_freq, dim=spectral_dims)  # centre the zero frequency

    T, H, W = x_freq.shape[-3:]
    tmid, crow, ccol = T // 2, H // 2, W // 2

    # A (T, H, W) mask broadcasts over all leading dimensions.
    mask = torch.ones((T, H, W), device=x.device)
    # Clamp the slice starts at 0: a negative start would be interpreted by
    # Python as "count from the end" and silently select the wrong region
    # whenever threshold exceeds half of a dimension.
    mask[
        max(tmid - threshold, 0) : tmid + threshold,
        max(crow - threshold, 0) : crow + threshold,
        max(ccol - threshold, 0) : ccol + threshold,
    ] = scale
    x_freq = x_freq * mask

    x_freq = fft.ifftshift(x_freq, dim=spectral_dims)  # undo the centring
    x_filtered = fft.ifftn(x_freq, dim=spectral_dims).real

    return x_filtered.type(dtype)


@dataclass
class PipelineOutput(BaseOutput):
    """Output container for the pipeline.

    NOTE(review): the field descriptions below are inferred from the field
    names and type hints only; confirm against the code that builds this
    object.
    """

    # Tensor or ndarray; presumably the final video latents (TODO confirm).
    video_latents: Union[torch.Tensor, np.ndarray]
    # Tensor, ndarray or list of tensors; presumably latents collected across
    # steps (TODO confirm).
    all_latents: Union[torch.Tensor, np.ndarray, List[torch.Tensor]]


class VExpressPipelinePrefixMeanVarFace(
    DiffusionPipeline,
    StableDiffusionMixin,
    LoraLoaderMixin,
    StableDiffusionLoraLoaderMixin,
    PAGMixin,
):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
    https://arxiv.org/abs/2302.08453

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`):
            Provides additional conditioning to the unet during the denoising process. If you set multiple Adapter as a
            list, the outputs from each Adapter are added together to create one combined additional conditioning.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
    """

    _optional_components = []

    def __init__(
        self,
        vae,
        unet,
        v_kps_guider,
        audio_processor,
        audio_encoder,
        audio_projection,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
        image_encoder=None,
        tokenizer=None,
        text_encoder=None,
        face_analysis_app=None,
        controlnet=None,
        ip_ckpt: str = None,
        ip_mode: str = None,
        resampler_depth: int = 4,
        num_tokens: int = 16,
        n_cond: int = 1,
        device: str = "cuda",
        adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter], None] = None,
        lora_path: str = None,
        is_inverted: bool = False,
        invert_kwargs: Optional[Dict[str, Any]] = None,
        lora_scale: Optional[float] = 0.5,
        extra_lora_setting: Optional[Dict] = None,
        pag_applied_layers: Union[str, List[str], None] = [],
        store_attn: bool = False,
        store_attn_key: str = None,
        store_qk: bool = False,
        store_qk_key: str = None,
    ):
        """Register the sub-modules and set up IP-Adapter, PAG and LoRA state.

        Args:
            vae, unet, v_kps_guider, audio_processor, audio_encoder,
            audio_projection, scheduler, image_encoder, tokenizer,
            text_encoder, controlnet:
                Components registered on the pipeline via `register_modules`.
                `v_kps_guider`, `audio_encoder` and `image_encoder` may be
                `None`; that sets `disable_kps`, `disable_audio` and
                `disable_ipa` respectively. When `disable_ipa` is set, all
                IP-Adapter setup is skipped.
            face_analysis_app: Stored on the instance as-is.
            ip_ckpt, ip_mode, resampler_depth, num_tokens, n_cond:
                Stored on the instance; presumably consumed by the IP-Adapter
                helpers (`set_ip_adapter`, `init_proj`, `load_ip_adapter`).
                With `ip_mode == 'faceid-decoupled'`, `init_proj` is expected
                to return two projection models.
            device: Not used; the local is overwritten with
                `self._execution_device`.
            adapter: A list or tuple is wrapped in a `MultiAdapter` before
                registration.
            lora_path: Optional safetensors file. Only keys starting with
                ``"base_model.model."`` are used. Keys containing
                ``"motion_module"`` are loaded through `load_diffusers_lora`
                with ``alpha=lora_scale``; the rest go through
                `unet.load_attn_procs`.
            is_inverted, invert_kwargs: Not used in this method.
            lora_scale: Used as `alpha` for motion-module LoRA and as
                `lora_scale` when fusing the extra LoRAs.
            extra_lora_setting: Optional mapping whose values are dicts with
                the keys ``"lora_path"``, ``"lora_name"`` and
                ``"lora_weight"``. `None` is treated as "no extra LoRA".
            pag_applied_layers: Forwarded to `set_pag_applied_layers`.
            store_attn, store_attn_key, store_qk, store_qk_key:
                Forwarded to `set_ip_adapter`.
        """
        # NOTE(review): `pag_applied_layers=[]` is a mutable default. It is not
        # mutated here and is left as-is to keep the signature unchanged.
        super().__init__()
        if isinstance(adapter, (list, tuple)):
            adapter = MultiAdapter(adapter)

        self.register_modules(
            vae=vae,
            unet=unet,
            v_kps_guider=v_kps_guider,
            audio_processor=audio_processor,
            audio_encoder=audio_encoder,
            audio_projection=audio_projection,
            scheduler=scheduler,
            image_encoder=image_encoder,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            adapter=adapter,
            controlnet=controlnet,
        )
        # Evaluated before the class name is overwritten further down.
        self.is_animatediff = self.unet.__class__.__name__ == "UNetMotionModel"
        self.ip_ckpt = ip_ckpt
        self.ip_mode = ip_mode
        self.resampler_depth = resampler_depth
        self.num_tokens = num_tokens
        self.n_cond = n_cond
        self.disable_kps = self.v_kps_guider is None
        self.apply_t2i_adapter = self.adapter is not None
        self.disable_audio = self.audio_encoder is None
        self.disable_ipa = self.image_encoder is None
        self.face_analysis_app = face_analysis_app
        self.store_attn = store_attn
        self.store_qk = store_qk

        device = self._execution_device

        if not self.disable_ipa:
            set_ip_adapter(
                self,
                device,
                store_attn=store_attn,
                hook_attn_key=store_attn_key,
                store_qk=store_qk,
                hook_qk_key=store_qk_key,
            )

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.clip_image_processor = CLIPImageProcessor()
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
        self.condition_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor,
            do_convert_rgb=True,
            do_normalize=False,
        )

        self.set_pag_applied_layers(pag_applied_layers)

        # IP-Adapter: image proj model, IPAttnProcessor
        if not self.disable_ipa:
            if self.ip_mode in ['faceid-decoupled']:
                self.image_proj_model_1, self.image_proj_model_2 = init_proj(self)
            else:
                self.image_proj_model = init_proj(self)
            load_ip_adapter(self)

        # LoRA Settings
        if lora_path is not None:
            print(f"[INFO] loading LoRA weights from {lora_path}")
            # Import the submodule explicitly: a bare `import safetensors` does
            # not guarantee that `safetensors.torch` is loaded.
            from safetensors.torch import load_file

            lora_state_dict = load_file(lora_path, device="cpu")
            lora_new_state, motion_module_lora_state_dict = OrderedDict(), OrderedDict()
            for k, tensor in lora_state_dict.items():
                if not k.startswith("base_model.model."):
                    continue
                key = k.replace("base_model.model.", "")
                if "motion_module" in key:
                    motion_module_lora_state_dict[key] = tensor
                else:
                    lora_new_state[key] = tensor

            if len(motion_module_lora_state_dict) != 0:
                self.unet = load_diffusers_lora(
                    self.unet, motion_module_lora_state_dict, alpha=lora_scale
                )
            if len(lora_new_state) != 0:
                self.unet.load_attn_procs(lora_new_state)
            print("[INFO] loaded LoRA weights!")

        # TEST: Load LoRA Weights
        # NOTE(review): this renames the UNet *class* (affecting every instance
        # of it), presumably so the LoRA loader accepts it. Confirm it is
        # still required.
        self.unet.__class__.__name__ = "UNetMotionModel"
        lora_names, lora_weights = [], []
        # `extra_lora_setting` defaults to None; treat that as "nothing to load"
        # instead of failing on `.items()`.
        for lora_set in (extra_lora_setting or {}).values():
            self.load_lora_weights(lora_set["lora_path"], adapter_name=lora_set["lora_name"])
            lora_names.append(lora_set["lora_name"])
            lora_weights.append(lora_set["lora_weight"])
        if len(lora_names) != 0:
            self.set_adapters(lora_names, adapter_weights=lora_weights)
            self.fuse_lora(adapter_names=lora_names, lora_scale=lora_scale)

    def enable_vae_slicing(self):
        """Enable VAE slicing by delegating to ``self.vae.enable_slicing()``."""
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        """Disable VAE slicing by delegating to ``self.vae.disable_slicing()``."""
        self.vae.disable_slicing()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        """Wrap the UNet, text encoder and VAE with accelerate's `cpu_offload`.

        Args:
            gpu_id: Index of the CUDA device (``cuda:{gpu_id}``) passed to
                `cpu_offload` as the execution device.

        Raises:
            ImportError: If `accelerate` is not installed.
        """
        # Try the import directly instead of calling `is_accelerate_available`,
        # which this module never imports.
        try:
            from accelerate import cpu_offload
        except ImportError as err:
            raise ImportError(
                "Please install accelerate via `pip install accelerate`"
            ) from err

        device = torch.device(f"cuda:{gpu_id}")

        # Components that were not provided (None) are skipped.
        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    def _execution_device(self):
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    @torch.no_grad()
    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
        """Encode `image` with the VAE and scale the latents.

        When `generator` is a list, every sample along dim 0 of `image` is
        encoded separately with its own generator and the results are
        concatenated. Otherwise the whole batch is encoded in one call. The
        latents are multiplied by ``self.vae.config.scaling_factor``.
        """
        if not isinstance(generator, list):
            encoded = retrieve_latents(
                self.vae.encode(image), generator=generator, sample_mode="sample"
            )
        else:
            per_sample = []
            for idx in range(image.shape[0]):
                sample_output = self.vae.encode(image[idx : idx + 1])
                per_sample.append(
                    retrieve_latents(sample_output, generator=generator[idx])
                )
            encoded = torch.cat(per_sample, dim=0)

        return self.vae.config.scaling_factor * encoded

    @torch.no_grad()
    def decode_latents(self, latents):
        """Decode video latents frame by frame into a float32 numpy array.

        Args:
            latents: Tensor laid out as ``(b, c, f, h, w)``.

        Returns:
            ``numpy.ndarray`` laid out as ``(b, c, f, h, w)`` with values
            clamped to ``[0, 1]``.
        """
        video_length = latents.shape[2]
        # Undo the scaling applied in `_encode_vae_image`; use the VAE's own
        # factor rather than a hard-coded 0.18215 so both directions agree.
        latents = 1 / self.vae.config.scaling_factor * latents
        latents = rearrange(latents, "b c f h w -> (b f) c h w")

        # Decode one frame per call rather than the whole batch at once.
        video = []
        for frame_idx in tqdm(range(latents.shape[0])):
            image = self.vae.decode(
                latents[frame_idx : frame_idx + 1].to(self.vae.device)
            ).sample
            video.append(image)
        video = torch.cat(video)
        video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
        # Map from [-1, 1] to [0, 1].
        video = (video / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        video = video.cpu().float().numpy()
        return video

    def prepare_extra_step_kwargs(self, generator, eta):
        """Build the optional keyword arguments for ``self.scheduler.step``.

        Schedulers differ in their `step` signature, so `eta` and `generator`
        are only included when `step` declares a parameter of that name.
        `eta` corresponds to η in the DDIM paper
        (https://arxiv.org/abs/2010.02502) and should be in [0, 1].
        """
        step_params = inspect.signature(self.scheduler.step).parameters
        candidates = (("eta", eta), ("generator", generator))
        return {name: value for name, value in candidates if name in step_params}

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        width,
        height,
        video_length,
        dtype,
        device,
        generator,
        latents=None,
        image=None,
        timestep=None,
        is_strength_max=True,
        return_noise=False,
        return_image_latents=False,
    ):
        """Create the starting latents for denoising.

        * If `latents` is given, it is treated as the noise and scaled by the
          scheduler's ``init_noise_sigma``.
        * Otherwise fresh noise of shape ``(batch_size, num_channels_latents,
          video_length, height // vae_scale_factor, width // vae_scale_factor)``
          is drawn. With `is_strength_max` it is scaled by
          ``init_noise_sigma``; without it, the noise is added to the image
          latents at `timestep` through ``scheduler.add_noise``.

        Image latents are taken from `image` directly when it has 4 channels,
        otherwise `image` is VAE-encoded. They are repeated to `batch_size`
        and across `video_length` frames.

        Returns:
            Tuple ``(latents[, noise][, image_latents])``; the optional
            entries follow `return_noise` and `return_image_latents`.

        Raises:
            ValueError: If a generator list does not match `batch_size`, or
                if `is_strength_max` is false while `image` or `timestep` is
                missing.
        """
        latent_shape = (
            batch_size,
            num_channels_latents,
            video_length,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        missing_init_inputs = image is None or timestep is None
        if missing_init_inputs and not is_strength_max:
            raise ValueError(
                "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
                "However, either the image or the noise timestep has not been provided."
            )

        needs_image_latents = return_image_latents or (
            latents is None and not is_strength_max
        )
        if needs_image_latents:
            image = image.to(device=device, dtype=dtype)

            # A 4-channel input is taken to already be in latent space.
            if image.shape[1] == 4:
                image_latents = image
            else:
                image_latents = self._encode_vae_image(image=image, generator=generator)

            batch_repeats = batch_size // image_latents.shape[0]
            image_latents = image_latents.repeat(batch_repeats, 1, 1, 1)
            # Insert the frame axis and copy the image latents to every frame.
            image_latents = image_latents.unsqueeze(2).repeat(1, 1, video_length, 1, 1)

        if latents is not None:
            # Caller-provided latents act as the noise.
            noise = latents.to(device)
            latents = noise * self.scheduler.init_noise_sigma
        else:
            noise = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)
            if is_strength_max:
                # Pure noise start, scaled by the scheduler's init sigma.
                latents = noise * self.scheduler.init_noise_sigma
            else:
                print(f"Init Start Latents by adding Noise at timestep {timestep} to image latents")
                latents = self.scheduler.add_noise(image_latents, noise, timestep)

        outputs = (latents,)
        if return_noise:
            outputs += (noise,)
        if return_image_latents:
            outputs += (image_latents,)
        return outputs

    def _encode_prompt(
        self,
        prompt,
        device,
        num_videos_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
    ):
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(
            prompt, padding="longest", return_tensors="pt"
        ).input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )

        if (
            hasattr(self.text_encoder.config, "use_attention_mask")
            and self.text_encoder.config.use_attention_mask
        ):
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        text_embeddings = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        text_embeddings = text_embeddings[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
        text_embeddings = text_embeddings.view(
            bs_embed * num_videos_per_prompt, seq_len, -1
        )

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if (
                hasattr(self.text_encoder.config, "use_attention_mask")
                and self.text_encoder.config.use_attention_mask
            ):
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            uncond_embeddings = uncond_embeddings[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(
                batch_size * num_videos_per_prompt, seq_len, -1
            )
        else:
            uncond_embeddings = torch.zeros_like(text_embeddings)

        return text_embeddings, uncond_embeddings

    def get_timesteps(self, num_inference_steps, strength, device):
        """Select the tail of the scheduler's timesteps that matches `strength`.

        ``int(num_inference_steps * strength)`` steps (capped at
        `num_inference_steps`) are kept from the end of
        ``self.scheduler.timesteps``; the start offset is multiplied by
        ``self.scheduler.order``. `device` is not used.

        Returns:
            ``(timesteps, number_of_steps_kept)``.
        """
        steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
        first_step = max(num_inference_steps - steps_to_run, 0)

        offset = first_step * self.scheduler.order
        remaining = self.scheduler.timesteps[offset:]

        return remaining, num_inference_steps - first_step

    def prepare_kps_feature(
        self, kps_images, height, width, do_classifier_free_guidance
    ):
        """Turn per-frame keypoint images into guider features.

        Every entry of `kps_images` is preprocessed with
        ``self.condition_image_processor`` and the results are concatenated
        along a new frame axis (dim 2) before being passed to
        ``self.v_kps_guider``. With classifier-free guidance, an all-zero copy
        of the feature is prepended along the batch dimension.
        """
        per_frame = [
            self.condition_image_processor.preprocess(
                frame, height=height, width=width
            ).unsqueeze(2)  # [bs, c, 1, h, w]
            for frame in kps_images
        ]
        stacked = torch.cat(per_frame, dim=2)  # [bs, c, t, h, w]
        stacked = stacked.to(device=self.device, dtype=self.dtype)

        guided = self.v_kps_guider(stacked)
        if not do_classifier_free_guidance:
            return guided

        # Unconditional half comes first, conditional half second.
        return torch.cat([torch.zeros_like(guided), guided], dim=0)

    def prepare_t2iadapter_feature(
        self,
        adapter_images,
        video_length,
        height,
        width,
        num_images_per_prompt,
        do_classifier_free_guidance,
        t2i_adapter_conditioning_scale,
        device,
    ):
        """Run the T2I adapter on every frame and stack the results over time.

        Args:
            adapter_images: One entry per frame. For a `MultiAdapter`, each
                entry is itself an iterable with one image per sub-adapter.
            video_length: Not used by the computation; the frame axis of the
                result has one slot per entry of `adapter_images`.
            height, width: Target size for PIL inputs.
            num_images_per_prompt: Each adapter state is repeated this many
                times along the batch dimension when greater than 1.
            do_classifier_free_guidance: When true, an all-zero copy of each
                state is prepended along the batch dimension.
            t2i_adapter_conditioning_scale: Passed to a `MultiAdapter`
                directly; multiplied onto the output of a single adapter.
            device: Device the adapter inputs are moved to.

        Returns:
            A list with one tensor per adapter level, each with the frames
            stacked along dim 2, or `None` when `adapter_images` is empty.
        """
        per_frame_states = []
        for image in adapter_images:
            if isinstance(self.adapter, MultiAdapter):
                adapter_input = [
                    _preprocess_adapter_image(one_image, height, width).to(
                        device=device, dtype=self.adapter.dtype
                    )  # pil -> tensor (b, c, h, w)
                    for one_image in image
                ]
                # The multi-adapter receives the scale itself.
                adapter_state = list(
                    self.adapter(adapter_input, t2i_adapter_conditioning_scale)
                )
            else:
                adapter_input = _preprocess_adapter_image(image, height, width).to(
                    device=device, dtype=self.adapter.dtype
                )
                adapter_state = [
                    state * t2i_adapter_conditioning_scale
                    for state in self.adapter(adapter_input)
                ]

            if num_images_per_prompt > 1:
                adapter_state = [
                    state.repeat(num_images_per_prompt, 1, 1, 1)
                    for state in adapter_state
                ]
            if do_classifier_free_guidance:
                # REVIEW: the unconditional half uses zeroed T2I conditions
                # rather than a repeat of the conditional ones.
                adapter_state = [
                    torch.cat([torch.zeros_like(state), state], dim=0)
                    for state in adapter_state
                ]

            per_frame_states.append([state.clone() for state in adapter_state])

        if not per_frame_states:
            return None

        # Group the same level across frames and stack along a new frame axis,
        # giving (bs, c, f, h, w) tensors. No zero-filled buffers are needed:
        # every level is produced directly by the stack.
        return [
            torch.stack(level_states, dim=2)
            for level_states in zip(*per_frame_states)
        ]

    def prepare_audio_embeddings(
        self,
        audio_waveform,
        video_length,
        num_pad_audio_frames,
        do_classifier_free_guidance,
    ):
        """Encode audio and cut it into one context window per video frame.

        The audio encoder output is linearly resampled to two embeddings per
        video frame, zero-padded with ``2 * num_pad_audio_frames`` rows on
        each side, and sliced into per-frame windows of
        ``2 * (2 * num_pad_audio_frames + 1)`` rows. The windows are passed
        through ``self.audio_projection``.

        Args:
            audio_waveform: Raw audio handed to ``self.audio_processor`` with
                a sampling rate of 16000.
            video_length: Number of video frames.
            num_pad_audio_frames: Number of context frames on each side of a
                frame's window.
            do_classifier_free_guidance: When true, an all-zero copy of the
                projected embeddings is prepended along dim 0.

        Returns:
            The projected embeddings with a leading batch axis added.
        """
        audio_waveform = self.audio_processor(
            audio_waveform, return_tensors="pt", sampling_rate=16000
        )["input_values"]
        audio_waveform = audio_waveform.to(self.device, self.dtype)
        audio_embeddings = self.audio_encoder(
            audio_waveform
        ).last_hidden_state  # presumably [1, num_embeds, d]; only batch entry 0 is used below

        audio_embeddings = torch.nn.functional.interpolate(
            audio_embeddings.permute(0, 2, 1),
            size=2 * video_length,
            mode="linear",
        )[0, :, :].permute(
            1, 0
        )  # [2*vid_len, dim]

        # Allocate the padding explicitly. Slicing `zeros_like(audio_embeddings)`
        # would cap it at 2*vid_len rows and yield short windows whenever
        # num_pad_audio_frames > video_length.
        padding = audio_embeddings.new_zeros(
            2 * num_pad_audio_frames, audio_embeddings.shape[1]
        )
        audio_embeddings = torch.cat(
            [padding, audio_embeddings, padding], dim=0
        )  # [2*num_pad+2*vid_len+2*num_pad, dim]

        # Frame i covers rows [2*i, 2*i + window_rows).
        window_rows = 2 * (2 * num_pad_audio_frames + 1)
        audio_embeddings = torch.stack(
            [
                audio_embeddings[2 * frame_idx : 2 * frame_idx + window_rows, :]
                for frame_idx in range(video_length)
            ],
            dim=0,
        )  # [vid_len, 2*(2*num_pad+1), dim]

        audio_embeddings = self.audio_projection(audio_embeddings).unsqueeze(0)
        if do_classifier_free_guidance:
            uc_audio_embeddings = torch.zeros_like(audio_embeddings)
            audio_embeddings = torch.cat([uc_audio_embeddings, audio_embeddings], dim=0)
        return audio_embeddings

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, faceid_embeds=None):
        """Build conditional and unconditional image-prompt embeddings.

        Which inputs are consumed depends on ``self.ip_mode``:

        * ``faceid`` / ``faceid-lora`` / ``portrait``: ``faceid_embeds`` only,
          projected by ``self.image_proj_model``.
        * ``faceid-plus`` / ``faceid-plus-lora``: ``faceid_embeds`` together
          with the image-encoder features of ``pil_image``.
        * ``vanilla`` / ``plus`` / ``full_face``: image-encoder features of
          ``pil_image`` only.
        * ``faceid-decoupled``: ``faceid_embeds`` through
          ``self.image_proj_model_1`` and image features through
          ``self.image_proj_model_2``, concatenated along dim 1.

        The unconditional counterpart is produced from an all-zero face
        embedding and/or the encoder features of an all-zero image.

        Args:
            pil_image: A PIL image or list of PIL images, or ``None``.
            faceid_embeds: Face-identity embedding tensor, or ``None``. A 3-D
                tensor ``(b, n, c)`` is flattened to ``(b * n, c)``.

        Returns:
            Tuple ``(image_prompt_embeds, uncond_image_prompt_embeds)``.

        Raises:
            ValueError: If ``self.ip_mode`` is not one of the modes above, or an
                input required by the active mode is ``None``.
        """
        faceid_only_modes = ("faceid", "faceid-lora", "portrait")
        faceid_plus_modes = ("faceid-plus", "faceid-plus-lora")
        clip_only_modes = ("vanilla", "plus", "full_face")
        decoupled_modes = ("faceid-decoupled",)

        # Validate up front: previously an unknown mode or a missing input only
        # surfaced later as an UnboundLocalError / TypeError with no context.
        ip_mode = self.ip_mode
        uses_faceid = ip_mode in faceid_only_modes + faceid_plus_modes + decoupled_modes
        uses_clip = ip_mode in clip_only_modes + faceid_plus_modes + decoupled_modes
        if not (uses_faceid or uses_clip):
            raise ValueError(f"Unsupported ip_mode: {ip_mode!r}")
        if uses_faceid and faceid_embeds is None:
            raise ValueError(f"ip_mode {ip_mode!r} requires `faceid_embeds`, got None")
        if uses_clip and pil_image is None:
            raise ValueError(f"ip_mode {ip_mode!r} requires `pil_image`, got None")

        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(
                images=pil_image, return_tensors="pt"
            ).pixel_values
            # NOTE(review): float16 is hard-coded here while the face embedding
            # below follows self.vae.dtype - confirm the image encoder is fp16.
            clip_image = clip_image.to(self.device, dtype=torch.float16)
            # Second-to-last hidden state of the image encoder is used.
            clip_image_embeds = self.image_encoder(
                clip_image, output_hidden_states=True
            ).hidden_states[-2]
            # Unconditional features come from encoding an all-zero image.
            uncond_clip_image_embeds = self.image_encoder(
                torch.zeros_like(clip_image), output_hidden_states=True
            ).hidden_states[-2]
        if faceid_embeds is not None:
            if faceid_embeds.dim() == 3:
                b, n, c = faceid_embeds.shape
                faceid_embeds = faceid_embeds.reshape(b * n, c)

            faceid_embeds = faceid_embeds.to(self.device, dtype=self.vae.dtype)

        if ip_mode in faceid_only_modes:
            image_prompt_embeds = self.image_proj_model(faceid_embeds)
            uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds))
        elif ip_mode in faceid_plus_modes:
            image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds)
            uncond_image_prompt_embeds = self.image_proj_model(
                torch.zeros_like(faceid_embeds), uncond_clip_image_embeds
            )
        elif ip_mode in clip_only_modes:
            image_prompt_embeds = self.image_proj_model(clip_image_embeds)
            uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
        else:  # decoupled_modes; unsupported modes were rejected above
            image_prompt_embeds_1 = self.image_proj_model_1(faceid_embeds)
            uncond_image_prompt_embeds_1 = self.image_proj_model_1(torch.zeros_like(faceid_embeds))
            image_prompt_embeds_2 = self.image_proj_model_2(clip_image_embeds)
            uncond_image_prompt_embeds_2 = self.image_proj_model_2(uncond_clip_image_embeds)
            image_prompt_embeds = torch.cat([image_prompt_embeds_1, image_prompt_embeds_2], dim=1)
            uncond_image_prompt_embeds = torch.cat(
                [uncond_image_prompt_embeds_1, uncond_image_prompt_embeds_2], dim=1
            )

        return image_prompt_embeds, uncond_image_prompt_embeds

    def prepare_control_image(
        self,
        image,
        width,
        height,
        batch_size,
        device,
        dtype,
        num_images_per_prompt=1,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        """Preprocess a conditioning image and expand it to the working batch.

        The image goes through ``self.condition_image_processor.preprocess`` at
        the requested ``height`` / ``width`` and is cast to float32. If the
        result holds a single image it is repeated ``batch_size`` times;
        otherwise every entry is repeated ``num_images_per_prompt`` times, with
        copies of the same entry kept adjacent. The tensor is then moved to
        ``device`` / ``dtype``.

        Returns:
            The expanded tensor, concatenated with itself along the batch axis
            when ``do_classifier_free_guidance`` is set and ``guess_mode`` is not.
        """
        control = self.condition_image_processor.preprocess(
            image, height=height, width=width
        )
        control = control.to(dtype=torch.float32)

        # A lone image is broadcast over the whole batch; a batch of images is
        # taken to line up with the prompts and only multiplied per prompt.
        copies = batch_size if control.shape[0] == 1 else num_images_per_prompt
        control = control.repeat_interleave(copies, dim=0)
        control = control.to(device=device, dtype=dtype)

        if guess_mode or not do_classifier_free_guidance:
            return control
        # Duplicate so both guidance branches see the same conditioning image.
        return torch.cat([control, control])

    def process_controlnet_inputs(self, latent_model_input, t, controlnet_prompt_embeds, controlnet_image, cond_scale):
        """Call ``self.controlnet`` once per frame and re-stack the residuals.

        ``latent_model_input`` must be 5-D; its third axis is iterated as the
        frame axis and each 4-D slice is passed to ``self.controlnet``. When
        ``controlnet_image`` is a list with exactly one entry per frame, the
        entry for the current frame is converted with ``pil2tensor`` and used
        as the conditioning; otherwise ``controlnet_image`` is passed through
        unchanged for every frame.

        Returns:
            ``(down_residuals, mid_residual)``: a list with one tensor per
            down-block residual and a single mid-block tensor, each formed by
            stacking the per-frame outputs along dim 2.
        """
        _, _, num_frames, _, _ = latent_model_input.shape
        # The check does not depend on the frame, so evaluate it once.
        per_frame_images = isinstance(controlnet_image, List) and len(controlnet_image) == num_frames

        down_per_frame = []
        mid_per_frame = []
        for frame_idx in range(num_frames):
            frame_cond = controlnet_image
            if per_frame_images:
                from modules.adapter.face_adapter.utils import pil2tensor

                pil_frame = controlnet_image[frame_idx]
                # assumes a square, 3-channel PIL-like image (.size[0] used for
                # both spatial dims) - TODO confirm against callers
                side = pil_frame.size[0]
                frame_cond = pil2tensor(pil_frame).view(1, 3, side, side)
                frame_cond = frame_cond.to(
                    device=latent_model_input.device, dtype=latent_model_input.dtype
                )

            down_res, mid_res = self.controlnet(
                latent_model_input[:, :, frame_idx, :, :],
                t,
                encoder_hidden_states=controlnet_prompt_embeds,
                controlnet_cond=frame_cond,
                conditioning_scale=cond_scale,
                return_dict=False,
            )
            down_per_frame.append(down_res)
            mid_per_frame.append(mid_res)

        # Regroup from "per frame -> per level" to "per level -> all frames"
        # and restore the frame axis at dim 2.
        stacked_down = []
        for level in range(len(down_per_frame[0])):
            level_frames = [frame_res[level] for frame_res in down_per_frame]
            stacked_down.append(torch.stack(level_frames, dim=2))
        stacked_mid = torch.stack(mid_per_frame, dim=2)

        return stacked_down, stacked_mid

    @torch.no_grad()
    def __call__(
        self,
        audio_waveform,
        width: int,
        height: int,
        video_length: int,
        num_inference_steps: int,
        guidance_scale: float,
        kps_images=None,
        mask_images=None,
        reference_image_path=None,
        refbg_image_path=None,
        face_embeds=None,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: int = 1,
        strength: float = 1.0,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "tensor",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        context_schedule: str = "uniform_prefix",
        context_frames: int = 24,
        context_stride: int = 1,
        context_overlap: int = 4,
        context_batch_size: int = 1,
        num_pad_audio_frames: int = 2,
        guidance_rescale: float = 0.0,
        ipa_scale: float = 1.0,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        t2i_adapter_control_type: Optional[Union[str, List[str]]] = None,
        t2i_adapter_conditioning_scale: Union[float, List[float]] = 1.0,
        align_color_alpha: float = 0.6,
        reference_adain: bool = True,
        b1=1.2,
        b2=1.4,
        s1=0.9,
        s2=0.2,
        threshold=1,
        dynthresh_kwargs: Optional[Dict] = None,
        controlnet_image=None,
        controlnet_prompt_embeds: Optional[torch.FloatTensor] = None,
        controlnet_negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
        point_kps_images=None,
        ctrl_kps=False,
        pag_scale: float = 0.0,
        pag_adaptive_scale: float = 0,
        **kwargs,
    ):
        # 1. Define call parameters
        self._pag_scale = pag_scale
        self._pag_adaptive_scale = pag_adaptive_scale
        # align format for control guidance
        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
            mult = len(self.controlnet.nets) if isinstance(self.controlnet, MultiControlNetModel) else 1
            control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
                control_guidance_end
            ]

        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        n_motion_frames = (
            cross_attention_kwargs.pop("n_motion_frames", 0)
            if cross_attention_kwargs is not None
            else 0
        )

        device = self._execution_device

        do_classifier_free_guidance = guidance_scale > 1.0
        batch_size = 1

        set_ipa_scale(self, ipa_scale)  # set ip_attn_processor scale
        print(f"{'='*30}\nPROMPT: {prompt}\n{'='*30}")

        # 2. Prepare timesteps
        timesteps = None
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps
        )
        timesteps, num_inference_steps = self.get_timesteps(
            num_inference_steps, strength, device
        )
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
        print(f"With strength {strength}; The timesteps are {timesteps}")
        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
        is_strength_max = strength == 1.0

        # 3. Prepare latent variables
        # TEST: Temp Add Prefix Ref Image Latents
        ref_image = load_img(image_path=reference_image_path, device=device)
        ref_image_vae = self.image_processor.preprocess(
            ref_image, height=height, width=width
        )
        num_channels_latents = self.unet.in_channels

        # 3.1 start latents for all frames
        latents, noise, image_latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            width,
            height,
            video_length,
            self.dtype,
            device,
            generator,
            image=ref_image_vae,
            timestep=latent_timestep,
            is_strength_max=is_strength_max,
            return_noise=True,
            return_image_latents=(num_channels_latents == 4),
        )  # [1, 4, 120, 64, 64]) bs x c x video_len x h x w

        # 4. prepare weak-condition prompts
        # text prompts embeds
        if prompt is None:
            prompt = "best quality, high quality"
        # TEST: enhanced prompt
        else:
            prompt_prefix = "((ultra realistic, RAW photo)), (high quality, best quality:1.25), (lifelike texture), Realistic lighting, "
            prompt = prompt_prefix + prompt
            prompt += "sharp focus, hyper realistic, Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - (better illustration), (better shadow), beautiful detailed shine, realistic maximum detail, Color Grading, skin pore detailing, hyper sharpness, perfect without deformations."
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
        if not isinstance(prompt, List):
            prompt = [prompt]
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt]
        with torch.inference_mode():
            prompt_embeds_, negative_prompt_embeds_ = self._encode_prompt(
                prompt,
                device=device,
                num_videos_per_prompt=num_images_per_prompt,
                do_classifier_free_guidance=do_classifier_free_guidance,
                negative_prompt=negative_prompt,
            )
            # prompt_embeds_, negative_prompt_embeds_ = prompt_embeds_.unsqueeze(0), negative_prompt_embeds_.unsqueeze(0)
            if not self.disable_ipa:
                # face-id image latents
                if self.ip_mode in ['vanilla', 'plus', "full_face", 'faceid-plus', 'faceid-plus-lora', 'faceid-decoupled']:
                    source_image = Image.open(reference_image_path).resize((512, 512))  # ((256, 256))
                    # TEST: set the source image as another image to provide bg information
                    if refbg_image_path is not None and self.ip_mode in ['faceid-decoupled']:
                        refbg_image = Image.open(refbg_image_path).resize((height, width))
                        refbg_image_cv2 = cv2.cvtColor(np.array(refbg_image), cv2.COLOR_RGB2BGR)
                        refbg_faces = self.face_analysis_app.get(refbg_image_cv2)
                        print(f"Number of faces in the BG Image: {len(refbg_faces)}")
                        try:
                            source_image = load_masked_image_from_faceinfo(refbg_image, refbg_faces[0])
                        except:
                            print("[Warning] Cannot Load Mask from the Source Image !!")
                            if kwargs.get('face_info', None):
                                faces_info = kwargs["face_info"]
                                refbg_image = load_masked_image_from_faceinfo(refbg_image, faces_info[0])
                            source_image = refbg_image
                else:
                    source_image = None
                image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
                    pil_image=source_image, faceid_embeds=face_embeds
                )
                prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
                negative_prompt_embeds = torch.cat(
                    [negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1
                )
                controlnet_prompt_embeds = torch.cat([controlnet_negative_prompt_embeds, controlnet_prompt_embeds]).to(prompt_embeds)
            else:
                prompt_embeds = prompt_embeds_
                negative_prompt_embeds = negative_prompt_embeds_

            # TEST: Zeros NEG Prompts
            # prompt_embeds = torch.cat([torch.zeros_like(prompt_embeds), prompt_embeds])
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            prompt_embeds = prompt_embeds.unsqueeze(1).repeat(1, video_length, 1, 1)

        if controlnet_image is not None:
            controlnet_image = self.prepare_control_image(
                controlnet_image, width, height, batch_size, device, self.controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance,
            )

        # 5. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 6. Prepare extra conditions
        # 6.1 prepare residual embeddings
        if self.apply_t2i_adapter:
            # T2I-Adapter Inputs
            all_adapter_images = []
            if "kps" in t2i_adapter_control_type or "openpose" in t2i_adapter_control_type:
                all_adapter_images.append(kps_images)
            if "mask" in t2i_adapter_control_type:
                all_adapter_images.append(mask_images)
            if len(t2i_adapter_control_type) > 1:
                adapter_images = rearrange_images(all_adapter_images)
            elif len(t2i_adapter_control_type) == 1:
                adapter_images = all_adapter_images[0]
            else:
                raise NotImplementedError("t2i_adapter_control_type not supported!")

            # Encode Adapter Inputs Images as residuals
            down_intrablock_additional_residuals = self.prepare_t2iadapter_feature(
                adapter_images, video_length, height, width, num_images_per_prompt, do_classifier_free_guidance, t2i_adapter_conditioning_scale, device,
            )
        if not self.disable_kps:
            # KPS Inputs
            kps_feature = self.prepare_kps_feature(
                kps_images, height, width, do_classifier_free_guidance
            )  # ([2, 320, 120, 64, 64]), (bs, c, f, h, w)
        if self.disable_kps and not self.apply_t2i_adapter:
            print("[INFO] No residual conditions are applied!")

        # 6.2 prepare audio embeddings
        if not self.disable_audio:
            audio_embeddings = self.prepare_audio_embeddings(
                audio_waveform, video_length, num_pad_audio_frames, do_classifier_free_guidance,
            )  # ([2, 120, 5, 768])

        # TEST: Hack as a Charm
        MODE = "write"
        _TIMESTEP = 0

        def hack_CrossAttnDownBlock3D_forward(
            self,
            hidden_states: torch.FloatTensor,
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            text_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            additional_residuals: Optional[torch.FloatTensor] = None,
        ):
            eps = 1e-6
            output_states = ()

            blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
            for i, (resnet, attn, motion_module) in enumerate(blocks):
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    text_hidden_states=text_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                ).sample
                if i == len(blocks) - 1 and additional_residuals is not None:
                    hidden_states = hidden_states + additional_residuals
                hidden_states = (
                    motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
                    if motion_module is not None else hidden_states
                )
                if MODE == "write":
                    var, mean = torch.var_mean(hidden_states[:, :, n_motion_frames:], dim=(2, 3, 4), keepdim=True, correction=0)
                    self.mean_bank[_TIMESTEP].append([mean])
                    self.var_bank[_TIMESTEP].append([var])
                if MODE == "read":
                    if len(self.mean_bank[_TIMESTEP]) > 0 and len(self.var_bank[_TIMESTEP]) > 0:
                        var, mean = torch.var_mean(hidden_states[:, :, n_motion_frames:], dim=(2, 3, 4), keepdim=True, correction=0)
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank[_TIMESTEP][i]) / float(len(self.mean_bank[_TIMESTEP][i]))
                        var_acc = sum(self.var_bank[_TIMESTEP][i]) / float(len(self.var_bank[_TIMESTEP][i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states[:, :, n_motion_frames:] = (
                            ((hidden_states[:, :, n_motion_frames:] - mean) / std) * std_acc
                        ) + mean_acc

                output_states = output_states + (hidden_states,)

            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)

                output_states = output_states + (hidden_states,)

            return hidden_states, output_states

        def hack_DownBlock3D_forward(
            self,
            hidden_states: torch.FloatTensor,
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            additional_residuals: Optional[torch.FloatTensor] = None,
        ):
            eps = 1e-6
            output_states = ()

            blocks = list(zip(self.resnets, self.motion_modules))
            for i, (resnet, motion_module) in enumerate(blocks):
                if i == len(blocks) - 1 and additional_residuals is not None:
                    hidden_states = hidden_states + additional_residuals
                hidden_states = (
                    motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
                    if motion_module is not None else hidden_states
                )
                if MODE == "write":
                    var, mean = torch.var_mean(hidden_states[:, :, n_motion_frames:], dim=(2, 3, 4), keepdim=True, correction=0)
                    self.mean_bank[_TIMESTEP].append([mean])
                    self.var_bank[_TIMESTEP].append([var])
                if MODE == "read":
                    if len(self.mean_bank[_TIMESTEP]) > 0 and len(self.var_bank[_TIMESTEP]) > 0:
                        var, mean = torch.var_mean(hidden_states[:, :, n_motion_frames:], dim=(2, 3, 4), keepdim=True, correction=0)
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank[_TIMESTEP][i]) / float(len(self.mean_bank[_TIMESTEP][i]))
                        var_acc = sum(self.var_bank[_TIMESTEP][i]) / float(len(self.var_bank[_TIMESTEP][i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states[:, :, n_motion_frames:] = (
                            ((hidden_states[:, :, n_motion_frames:] - mean) / std) * std_acc
                        ) + mean_acc

                output_states = output_states + (hidden_states,)

            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)

                output_states = output_states + (hidden_states,)

            return hidden_states, output_states

        def hack_CrossAttnUpBlock3D_forward(
            self,
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            text_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            eps = 1e-6

            blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
            for i, (resnet, attn, motion_module) in enumerate(blocks):
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                # --------------- FreeU code -----------------------
                # Only operate on the first two stages
                if hidden_states.shape[1] == 1280:
                    hidden_states[:,:640] = hidden_states[:,:640] * b1
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=threshold, scale=s1)
                if hidden_states.shape[1] == 640:
                    hidden_states[:,:320] = hidden_states[:,:320] * b2
                    res_hidden_states = Fourier_filter(
                        res_hidden_states, threshold=threshold, scale=s2
                    )
                # ---------------------------------------------------------
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    text_hidden_states=text_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                ).sample
                hidden_states = (
                    motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
                    if motion_module is not None else hidden_states
                )
                if MODE == "write":
                    var, mean = torch.var_mean(hidden_states[:, :, n_motion_frames:], dim=(2, 3, 4), keepdim=True, correction=0)
                    self.mean_bank[_TIMESTEP].append([mean])
                    self.var_bank[_TIMESTEP].append([var])
                if MODE == "read":
                    if len(self.mean_bank[_TIMESTEP]) > 0 and len(self.var_bank[_TIMESTEP]) > 0:
                        var, mean = torch.var_mean(hidden_states[:, :, n_motion_frames:], dim=(2, 3, 4), keepdim=True, correction=0)
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank[_TIMESTEP][i]) / float(len(self.mean_bank[_TIMESTEP][i]))
                        var_acc = sum(self.var_bank[_TIMESTEP][i]) / float(len(self.var_bank[_TIMESTEP][i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states[:, :, n_motion_frames:] = (
                            ((hidden_states[:, :, n_motion_frames:] - mean) / std) * std_acc
                        ) + mean_acc

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        def hack_UpBlock3D_forward(
            self,
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            upsample_size: Optional[int] = None,
        ):
            eps = 1e-6

            blocks = list(zip(self.resnets, self.motion_modules))
            for i, (resnet, motion_module) in enumerate(blocks):
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                # --------------- FreeU code -----------------------
                # Only operate on the first two stages
                if hidden_states.shape[1] == 1280:
                    hidden_states[:,:640] = hidden_states[:,:640] * b1
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=threshold, scale=s1)
                if hidden_states.shape[1] == 640:
                    hidden_states[:,:320] = hidden_states[:,:320] * b2
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=threshold, scale=s2)
                # ---------------------------------------------------------
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                hidden_states = resnet(hidden_states, temb)
                hidden_states = (
                    motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
                    if motion_module is not None else hidden_states
                )
                if MODE == "write":
                    var, mean = torch.var_mean(hidden_states[:, :, n_motion_frames:], dim=(2, 3, 4), keepdim=True, correction=0)
                    self.mean_bank[_TIMESTEP].append([mean])
                    self.var_bank[_TIMESTEP].append([var])
                if MODE == "read":
                    if len(self.mean_bank[_TIMESTEP]) > 0 and len(self.var_bank[_TIMESTEP]) > 0:
                        var, mean = torch.var_mean(hidden_states[:, :, n_motion_frames:], dim=(2, 3, 4), keepdim=True, correction=0)
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank[_TIMESTEP][i]) / float(len(self.mean_bank[_TIMESTEP][i]))
                        var_acc = sum(self.var_bank[_TIMESTEP][i]) / float(len(self.var_bank[_TIMESTEP][i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states[:, :, n_motion_frames:] = (
                            ((hidden_states[:, :, n_motion_frames:] - mean) / std) * std_acc
                        ) + mean_acc

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        if reference_adain:
            gn_modules = [self.unet.mid_block]

            down_blocks = self.unet.down_blocks
            for w, module in enumerate(down_blocks):
                gn_modules.append(module)

            up_blocks = self.unet.up_blocks
            for w, module in enumerate(up_blocks):
                gn_modules.append(module)

            for i, module in enumerate(gn_modules):
                if getattr(module, "original_forward", None) is None:
                    module.original_forward = module.forward
                if i == 0:
                    # mid_block
                    # module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
                    pass
                # elif isinstance(module, CrossAttnDownBlock3D):
                #     module.forward = hack_CrossAttnDownBlock3D_forward.__get__(module, CrossAttnDownBlock3D)
                # elif isinstance(module, DownBlock3D):
                #     module.forward = hack_DownBlock3D_forward.__get__(module, DownBlock3D)
                elif isinstance(module, CrossAttnUpBlock3D):
                    module.forward = hack_CrossAttnUpBlock3D_forward.__get__(module, CrossAttnUpBlock3D)
                elif isinstance(module, UpBlock3D):
                    module.forward = hack_UpBlock3D_forward.__get__(module, UpBlock3D)

                module.mean_bank, module.var_bank = {}, {}
                for t in timesteps:
                    t = t.cpu().item()
                    module.mean_bank[t], module.var_bank[t] = [], []

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        # Set PAG
        if self.do_perturbed_attention_guidance:
            original_attn_proc = self.unet.attn_processors
            self._set_pag_attn_processor(
                pag_applied_layers=self.pag_applied_layers,
                do_classifier_free_guidance=do_classifier_free_guidance,
            )
        # 7.0 Create tensor stating which controlnets to keep
        if not self.disable_ipa:
            controlnet_keep = []
            for i in range(len(timesteps)):
                keeps = [
                    1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                    for s, e in zip(control_guidance_start, control_guidance_end)
                ]
                controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

        # 7.1 Initialize Context Queue
        context_scheduler = get_context_scheduler(context_schedule)
        context_queue = list(
            context_scheduler(
                step=0,
                num_frames=video_length,
                context_size=context_frames,
                context_stride=context_stride,
                context_overlap=context_overlap,
                closed_loop=True,
                start_num_frames=context_frames + n_motion_frames,
            )
        )  # video context, 120 video_len, 24 context_frames, 1 context_stride, 4 context_overlap -> 6 queues, each queue contains 24 frames,

        num_context_batches = math.ceil(len(context_queue) / context_batch_size)
        global_context = []
        for k in range(num_context_batches):
            global_context.append(
                context_queue[
                    k * context_batch_size : (k + 1) * context_batch_size
                ]
            )

        tensor_result = []

        context_ids_prefix = []

        # 7.2 Denoising by Iterating Each Queue
        for idx, context in enumerate(global_context):

            context_ids = context[0]

            print("Current Context Window: ", context_ids)
            print("Prefix Context Window: ", context_ids_prefix)

            # 7.2.x.1 Prepare prefix Motion Frames [1, 4, n_motion_frames, 64, 64]
            if n_motion_frames != 0:
                if len(tensor_result) == 0:
                    # The first iteration
                    prefix_frames_ori = None
                else:
                    prefix_frames_ori = tensor_result[-1]
                    prefix_frames_ori = prefix_frames_ori[:, :, 0 - n_motion_frames :].to(
                        dtype=self.dtype, device=device
                    )
            else:
                prefix_frames_ori = None

            # 7.3 Noise Prediction within current queue
            with self.progress_bar(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):

                    latent_model_input = latents[:, :, context_ids, ...].to(device)

                    # 7.2.x.2 Prepend prefix Motion Frames (if any) to the current window [1, 4, n_motion_frames, 64, 64]
                    if prefix_frames_ori is not None:

                        # 7.3.1 expand the latents if we are doing classifier free guidance
                        # Concat Prefix with Latents,  # [2, 4, n_motion_frames + context_frames, 64, 64])
                        latent_model_input = torch.cat([prefix_frames_ori.clone(), latent_model_input], dim=2)

                    latent_model_input = latent_model_input.repeat(
                        2 if do_classifier_free_guidance else 1, 1, 1, 1, 1
                    )  # [2, 4, n_motion_frames + context_frames, 64, 64])

                    latent_model_input = self.scheduler.scale_model_input(
                        latent_model_input, t
                    )

                    # 7.3.2 Process Various Conditions based on the Context Window
                    # 7.3.2.1 Process Prefix Conditions of Various Conditions based on the Context Window
                    if prefix_frames_ori is not None:
                        prefix_kps = kps_feature[:, :, context_ids_prefix] if not self.disable_kps else None
                        prefix_residuals = [
                            residual[:, :, context_ids_prefix, ...]
                            for residual in down_intrablock_additional_residuals
                        ] if self.apply_t2i_adapter else None
                        prefix_audio_embeds = audio_embeddings[:, context_ids_prefix] if not self.disable_audio else None
                        prefix_prompt_embeds = prompt_embeds[:, context_ids_prefix] if prompt_embeds is not None else None
                    else:
                        prefix_kps = None
                        prefix_residuals = None
                        prefix_audio_embeds = None
                        prefix_prompt_embeds = None

                    # 7.3.2.2 Process Various Conditions based on the Context Window
                    if not self.disable_kps:
                        latent_kps_feature = kps_feature[:, :, context_ids, ...]  # [bs, c, context frames, h, w]
                        if prefix_kps is not None:
                            latent_kps_feature = torch.cat([prefix_kps, latent_kps_feature], dim=2)
                    else:
                        latent_kps_feature = None

                    if self.apply_t2i_adapter:
                        latent_down_intrablock_additional_residuals = [
                            residual[:, :, context_ids, ...]  # [bs, c, context frames, h, w]
                            for residual in down_intrablock_additional_residuals
                        ]
                        if prefix_residuals is not None:
                            latent_down_intrablock_additional_residuals = [
                                torch.cat([prefix, latent], dim=2)
                                for prefix, latent in zip(prefix_residuals, latent_down_intrablock_additional_residuals)
                            ]
                    else:
                        latent_down_intrablock_additional_residuals = None

                    if not self.disable_audio:
                        latent_audio_embeddings = audio_embeddings[:, context_ids, ...]  # [bs, n_motion_frames, num_tokens, dim]
                        if prefix_audio_embeds is not None:
                            latent_audio_embeddings = torch.cat([prefix_audio_embeds, latent_audio_embeddings], dim=1)
                        _, _, num_tokens, dim = latent_audio_embeddings.shape
                        latent_audio_embeddings = latent_audio_embeddings.reshape(
                            -1, num_tokens, dim
                        )  # ([24, 5, 768])
                    else:
                        latent_audio_embeddings = None

                    if prompt_embeds is not None:
                        latent_prompt_embeds = prompt_embeds[:, context_ids, ...]  # [bs, n_motion_frames, num_tokens, dim]
                        if prefix_prompt_embeds is not None:
                            latent_prompt_embeds = torch.cat([prefix_prompt_embeds, latent_prompt_embeds], dim=1)
                        _, _, num_tokens_prompt, dim_prompt = latent_prompt_embeds.shape
                        latent_prompt_embeds = latent_prompt_embeds.reshape(
                            -1, num_tokens_prompt, dim_prompt
                        )
                    else:
                        latent_prompt_embeds = None

                    if not self.disable_ipa:
                        # prepare controlnet input
                        if isinstance(controlnet_keep[i], list):
                            cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                        else:
                            controlnet_cond_scale = controlnet_conditioning_scale
                            if isinstance(controlnet_cond_scale, list):
                                controlnet_cond_scale = controlnet_cond_scale[0]
                            cond_scale = controlnet_cond_scale * controlnet_keep[i]
                        if ctrl_kps:
                            # controlnet_image = [kps_images[i] for i in context_ids_prefix+context_ids]
                            controlnet_image = [point_kps_images[i] for i in context_ids_prefix+context_ids]
                        down_block_res_samples, mid_block_res_sample = self.process_controlnet_inputs(
                            latent_model_input,
                            t,
                            controlnet_prompt_embeds,
                            controlnet_image,
                            cond_scale,
                        )
                    else:
                        down_block_res_samples, mid_block_res_sample = None, None

                    # 7.3.3 Apply Denoising UNet for Noise Prediction
                    MODE = "write" if idx == 0 else "read"
                    _TIMESTEP = t.cpu().item()

                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=latent_audio_embeddings,  # [bs x context frames, 5, 768]
                        text_hidden_states=latent_prompt_embeds,  # [bs x context frames, 93, 768]
                        kps_features=latent_kps_feature,
                        cross_attention_kwargs=cross_attention_kwargs,
                        down_intrablock_additional_residuals=latent_down_intrablock_additional_residuals,
                        down_block_additional_residuals=down_block_res_samples,
                        mid_block_additional_residual=mid_block_res_sample,
                        return_dict=False,
                    )[0]

                    # 7.3.4 perform classifier free guidance
                    # TEST: Dynamic Threshold
                    # Initialize DynThresh with arguments from dynthresh_kwargs
                    dynamic_thresh = DynThresh(
                        mimic_scale=dynthresh_kwargs["mimic_scale"],
                        threshold_percentile=dynthresh_kwargs["threshold_percentile"],
                        mimic_mode=dynthresh_kwargs["mimic_mode"],
                        mimic_scale_min=dynthresh_kwargs["mimic_scale_min"],
                        cfg_mode=dynthresh_kwargs["cfg_mode"],
                        cfg_scale_min=dynthresh_kwargs["cfg_scale_min"],
                        sched_val=dynthresh_kwargs["sched_val"],
                        experiment_mode=dynthresh_kwargs["experiment_mode"],
                        max_steps=dynthresh_kwargs["max_steps"],
                        separate_feature_channels=(dynthresh_kwargs["separate_feature_channels"] == "enable"),
                        scaling_startpoint=dynthresh_kwargs["scaling_startpoint"],
                        variability_measure=dynthresh_kwargs["variability_measure"],
                        interpolate_phi=dynthresh_kwargs["interpolate_phi"],
                    )

                    # Perform classifier-free guidance with DynThresh
                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        dynamic_thresh.step = 999 - t.item()

                        if guidance_scale == dynthresh_kwargs["mimic_scale"]:
                            noise_pred = noise_pred_uncond + guidance_scale * (
                                noise_pred_text - noise_pred_uncond
                            )
                        else:
                            noise_pred = dynamic_thresh.dynthresh(
                                noise_pred_text,
                                noise_pred_uncond,
                                guidance_scale,
                                None,
                                dynamic_thresh.step,
                            )

                    # Perform Guidance Rescale
                    if do_classifier_free_guidance and guidance_rescale > 0.0:
                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        noise_pred = rescale_noise_cfg(
                            noise_pred,
                            noise_pred_text,
                            guidance_rescale=guidance_rescale,
                        )

                    # 7.3.5 Denoising: Get the next step latents (less noisy)
                    if latent_model_input.shape[0] == 2:
                        output_latents = self.scheduler.step(
                            noise_pred, t, latent_model_input[0], **extra_step_kwargs
                        ).prev_sample
                    else:
                        output_latents = self.scheduler.step(
                            noise_pred, t, latent_model_input, **extra_step_kwargs
                        ).prev_sample

                    if len(tensor_result) == 0:
                        latents[:, :, context_ids] = output_latents.clone()
                    else:
                        latents[:, :, context_ids] = output_latents.clone()[:, :, n_motion_frames:]  # ! Important: Clone can keep the latents intact? Check if this is necessary

                    # 7.3.7 update progress bar
                    if i == len(timesteps) - 1 or (
                        (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                    ):
                        progress_bar.update()
                        if callback is not None and i % callback_steps == 0:
                            step_idx = i // getattr(self.scheduler, "order", 1)
                            callback(step_idx, t, latents)

                    # 7.3.8 store attns
                    if idx == 0:
                        if self.store_attn:
                            attn_maps = get_net_attn_map((height, width), layerwise_average=True)
                            torch.save(attn_maps, f"output/intermediates/attns/adriana_tys_0.6-attn2-t_{i}.pt")
                            print(f"{i}-th Timestep Attn stored")
                        if self.store_qk:
                            qs, ks = get_net_qk()
                            torch.save(qs, f"output/intermediates/attns/adriana_tys_0.6-attn1_qs-t_{i}.pt")
                            torch.save(ks, f"output/intermediates/attns/adriana_tys_0.6-attn1_ks-t_{i}.pt")
                            print(f"{i}-th Timestep qks stored")

            # Store this window's final latents (the next window slices its prefix motion frames from them) and record the matching frame indices
            tensor_result.append(output_latents.clone())
            context_ids_prefix = [] if n_motion_frames == 0 else context_ids[-n_motion_frames:]

        # CORRECT COLOR OFFSET
        if align_color_alpha != 0.:
            print("[PostProcess] Performing Low-Pass Filter for Color Correction")
            for idx, context in enumerate(global_context):
                context_ids = context[0]
                if idx == 0:
                    ref_latents = latents[:, :, context_ids].squeeze(0)
                    continue

                cur_latents = latents[:, :, context_ids].squeeze(0)
                corrected_latents = correct_color_offset(ref_latents, cur_latents, alpha=align_color_alpha)
                latents[:, :, context_ids] = corrected_latents.unsqueeze(0)

        # Convert to tensor
        if output_type == "tensor":
            latents = latents

        if not return_dict:
            return latents

        return PipelineOutput(video_latents=latents, all_latents=tensor_result)
