import os
from typing import Any, Dict, List, Optional, Tuple

import torch
from accelerate import init_empty_weights
from diffusers import (
    AutoencoderKLWan,
    FlowMatchEulerDiscreteScheduler,
    WanImageToVideoPipeline,
    WanPipeline,
    # WanTransformer3DModel,
)
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from PIL.Image import Image
from transformers import AutoModel, AutoTokenizer, UMT5EncoderModel

import finetrainers.functional as FF
from finetrainers.data import VideoArtifact
from finetrainers.logging import get_logger
from finetrainers.models.modeling_utils import ModelSpecification
from finetrainers.processors import ProcessorMixin, T5Processor
from finetrainers.typing import ArtifactType, SchedulerType
from finetrainers.utils import get_non_null_items

from .model import WanTransformer3DModel


import types
from typing import Any, Callable, Dict, List, Optional, Union
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
# torch_xla is optional: import it only when diffusers reports it as available, and record the
# result in XLA_AVAILABLE (presumably consulted by pipeline code elsewhere in this module).
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

# Module-level logger obtained from finetrainers.logging.
logger = get_logger()


class WanLatentEncodeProcessor(ProcessorMixin):
    r"""
    Processor to encode image/video into latents using the Wan VAE.

    Args:
        output_names (`List[str]`):
            The names of the outputs that the processor returns. Exactly three names are required; the
            outputs are returned in the following order:
            - latents: The latents of the input image/video. With `compute_posterior=True` this is a sample
              from the VAE posterior; with `compute_posterior=False` it is the raw output of `vae._encode`
              (the posterior moments), left for the caller to normalize and sample.
            - latents_mean: `vae.config.latents_mean` as a tensor.
            - latents_std: The *reciprocal* of `vae.config.latents_std` (i.e. `1 / std`) as a tensor, so
              normalization is a multiplication.
    """

    def __init__(self, output_names: List[str]):
        super().__init__()
        self.output_names = output_names
        # Explicit check instead of `assert` so the validation survives `python -O`.
        if len(self.output_names) != 3:
            raise ValueError(
                f"WanLatentEncodeProcessor expects exactly 3 output names, got {len(self.output_names)}"
            )

    def forward(
        self,
        vae: AutoencoderKLWan,
        image: Optional[torch.Tensor] = None,
        video: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
        compute_posterior: bool = True,
    ) -> Dict[str, torch.Tensor]:
        r"""
        Encode `image` or `video` with `vae`.

        Args:
            vae: The Wan VAE; its device/dtype are used for the input and the returned latents.
            image: Optional image batch laid out as `[B, C, H, W]`; it takes precedence over `video` and
                is treated as a one-frame video.
            video: Optional video batch laid out as `[B, F, C, H, W]`.
            generator: RNG used when sampling from the posterior.
            compute_posterior: If True, sample the posterior; otherwise return the raw `vae._encode` output.

        Returns:
            A dict mapping the three `output_names` to (latents, latents_mean, 1 / latents_std).

        Raises:
            ValueError: If neither `image` nor `video` is given, or the input is not 5D after expansion.
        """
        device = vae.device
        dtype = vae.dtype

        if image is not None:
            # [B, C, H, W] -> [B, 1, C, H, W]: an image is a single-frame video.
            video = image.unsqueeze(1)
        if video is None:
            raise ValueError("Either `image` or `video` must be provided.")
        if video.ndim != 5:
            raise ValueError(f"Expected 5D tensor, got {video.ndim}D tensor")

        video = video.to(device=device, dtype=dtype)
        video = video.permute(0, 2, 1, 3, 4).contiguous()  # [B, F, C, H, W] -> [B, C, F, H, W]

        if compute_posterior:
            latents = vae.encode(video).latent_dist.sample(generator=generator)
        else:
            # TODO(aryan): refactor in diffusers to have use_slicing attribute
            # if vae.use_slicing and video.shape[0] > 1:
            #     encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
            #     moments = torch.cat(encoded_slices)
            # else:
            #     moments = vae._encode(video)
            # Raw moments are returned so the caller can normalize them before sampling.
            latents = vae._encode(video)
        latents = latents.to(dtype=dtype)

        latents_mean = torch.tensor(vae.config.latents_mean)
        # Reciprocal std: downstream normalization computes (x - mean) * latents_std.
        latents_std = 1.0 / torch.tensor(vae.config.latents_std)

        return {self.output_names[0]: latents, self.output_names[1]: latents_mean, self.output_names[2]: latents_std}


class WanModelSpecification(ModelSpecification):
    r"""
    Model specification for Wan text-to-video models (default checkpoint:
    ``Wan-AI/Wan2.1-T2V-1.3B-Diffusers``).

    On top of the standard load / prepare / forward / validation / save hooks, ``forward`` and
    ``validation`` accept an experimental ``apply_target_noise_only`` argument. In ``forward`` it
    selects an entry of ``_TARGET_NOISE_MODES`` that keeps a prefix of latent frames clean and/or
    caps the noise level of selected frames. Any value not in the table (including ``None``) noises
    every frame at the sampled sigma.
    """

    # How ``forward`` modifies the fully-noised latents for each ``apply_target_noise_only`` mode.
    # Each value is ``(debug_message, clean_prefix, capped_slices)``:
    #   * ``debug_message``: printed every time the mode is applied (``None`` prints nothing).
    #   * ``clean_prefix``: number of leading latent frames (dim 2) replaced by the clean latents.
    #   * ``capped_slices``: ``(start, end, cap)`` triples. For samples whose sigma is greater than
    #     ``cap``, frames ``start:end`` are re-noised at sigma ``cap`` instead of the sampled sigma.
    # Slices beyond the number of latent frames are empty and have no effect.
    _TARGET_NOISE_MODES: Dict[str, Tuple[Optional[str], int, Tuple[Tuple[int, int, float], ...]]] = {
        "front": ("[DEBUG] front noise applied", 1, ()),
        "front-none": ("[DEBUG] front noise applied", 1, ()),
        "front-2": ("[DEBUG] front-2 noise applied", 2, ()),
        "front-2-none": ("[DEBUG] front-2 noise applied", 2, ()),
        "front-long": ("[DEBUG] front-long noise applied", 6, ()),
        "front-long-none": ("[DEBUG] front-long noise applied", 6, ()),
        "front-long-none-F81": (None, 10, ()),
        "front-5": ("[DEBUG] front-5 noise applied", 5, ()),
        "front-5-none": ("[DEBUG] front-5 noise applied", 5, ()),
        "front-4-noise-none": (
            "[DEBUG] front-4-noise-none noise applied",
            1,
            ((3, 4, 0.75), (2, 3, 0.5), (1, 2, 0.25)),
        ),
        "front-4-noise-none-buffer": (
            "[DEBUG] front-4-noise-none-buffer noise applied",
            1,
            ((3, 4, 0.75), (2, 3, 0.5), (1, 2, 0.25)),
        ),
        "front-4-noise-none-only25": ("[DEBUG] front-4-noise-none-only25 noise applied", 1, ((1, 4, 0.25),)),
        "front-4-noise-none-only50": ("[DEBUG] front-4-noise-none-only50 noise applied", 1, ((1, 4, 0.5),)),
        "front-4-noise-none-only75": ("[DEBUG] front-4-noise-none-only75 noise applied", 1, ((1, 4, 0.75),)),
        "front-4-none": ("[DEBUG] front-4-none noise applied", 4, ()),
        "front-4-noise-none-dual-cond": (
            "[DEBUG] front-4-noise-none-dual-cond noise applied",
            5,
            ((7, 8, 0.75), (6, 7, 0.5), (5, 6, 0.25)),
        ),
        "front-7-noise-none": (
            "[DEBUG] front-7-noise-none noise applied",
            4,
            ((6, 7, 0.75), (5, 6, 0.5), (4, 5, 0.25)),
        ),
        "front-5-noise-none": (
            "[DEBUG] front-5-noise-none noise applied",
            2,
            ((4, 5, 0.75), (3, 4, 0.5), (2, 3, 0.25)),
        ),
        "none": ("[DEBUG] No noise applied", 0, ()),
        "buffer-1-noise-none": ("[DEBUG] buffer-1-noise-none noise applied", 1, ((1, 2, 0.5),)),
        "buffer-2-noise-none": (
            "[DEBUG] buffer-2-noise-none noise applied",
            1,
            ((1, 2, 0.33), (2, 3, 0.66)),
        ),
        "buffer-4-noise-none": (
            "[DEBUG] buffer-4-noise-none noise applied",
            1,
            ((1, 2, 0.2), (2, 3, 0.4), (3, 4, 0.6), (4, 5, 0.8)),
        ),
        "buffer-5-noise-none": (
            "[DEBUG] buffer-5-noise-none noise applied",
            1,
            ((1, 2, 0.17), (2, 3, 0.33), (3, 4, 0.5), (4, 5, 0.67), (5, 6, 0.83)),
        ),
        "buffer-log-convex-noise-none": (
            "[DEBUG] buffer-log-convex-noise-none noise applied",
            1,
            ((1, 2, 0.5), (2, 3, 0.7), (3, 4, 0.9), (4, 5, 0.95)),
        ),
        # NOTE(review): the three modes below do NOT keep frame 0 clean (clean_prefix == 0), unlike every
        # other buffer-* mode. This mirrors the original branches exactly; confirm whether it is intended.
        "buffer-log-concave-noise-none": (
            "[DEBUG] buffer-log-concave-noise-none noise applied",
            0,
            ((1, 2, 0.05), (2, 3, 0.1), (3, 4, 0.3), (4, 5, 0.5)),
        ),
        "buffer-log-new-concave-noise-none": (
            "[DEBUG] buffer-log-new-concave-noise-none noise applied",
            0,
            ((1, 2, 0.4), (2, 3, 0.7), (3, 4, 0.9)),
        ),
        "buffer-log-new-convex-noise-none": (
            "[DEBUG] buffer-log-new-convex-noise-none noise applied",
            0,
            ((1, 2, 0.1), (2, 3, 0.3), (3, 4, 0.6)),
        ),
    }

    def __init__(
        self,
        pretrained_model_name_or_path: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
        tokenizer_id: Optional[str] = None,
        text_encoder_id: Optional[str] = None,
        transformer_id: Optional[str] = None,
        vae_id: Optional[str] = None,
        text_encoder_dtype: torch.dtype = torch.bfloat16,
        transformer_dtype: torch.dtype = torch.bfloat16,
        vae_dtype: torch.dtype = torch.bfloat16,
        revision: Optional[str] = None,
        cache_dir: Optional[str] = None,
        condition_model_processors: List[ProcessorMixin] = None,
        latent_model_processors: List[ProcessorMixin] = None,
        **kwargs,
    ) -> None:
        super().__init__(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            tokenizer_id=tokenizer_id,
            text_encoder_id=text_encoder_id,
            transformer_id=transformer_id,
            vae_id=vae_id,
            text_encoder_dtype=text_encoder_dtype,
            transformer_dtype=transformer_dtype,
            vae_dtype=vae_dtype,
            revision=revision,
            cache_dir=cache_dir,
        )

        # Default processors: T5 text encoding for conditions, Wan VAE encoding for latents.
        if condition_model_processors is None:
            condition_model_processors = [T5Processor(["encoder_hidden_states", "prompt_attention_mask"])]
        if latent_model_processors is None:
            latent_model_processors = [WanLatentEncodeProcessor(["latents", "latents_mean", "latents_std"])]

        self.condition_model_processors = condition_model_processors
        self.latent_model_processors = latent_model_processors

    @property
    def _resolution_dim_keys(self):
        # ``latents`` has its frame/height/width axes at dims 2, 3 and 4.
        return {"latents": (2, 3, 4)}

    def load_condition_models(self) -> Dict[str, torch.nn.Module]:
        """Load the tokenizer and text encoder, from explicit ids if given, else from the base checkpoint."""
        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}

        if self.tokenizer_id is not None:
            tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
        else:
            tokenizer = AutoTokenizer.from_pretrained(
                self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
            )

        if self.text_encoder_id is not None:
            text_encoder = AutoModel.from_pretrained(
                self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
            )
        else:
            text_encoder = UMT5EncoderModel.from_pretrained(
                self.pretrained_model_name_or_path,
                subfolder="text_encoder",
                torch_dtype=self.text_encoder_dtype,
                **common_kwargs,
            )

        return {"tokenizer": tokenizer, "text_encoder": text_encoder}

    def load_latent_models(self) -> Dict[str, torch.nn.Module]:
        """Load the Wan VAE, from ``vae_id`` if given, else from the base checkpoint's ``vae`` subfolder."""
        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}

        if self.vae_id is not None:
            vae = AutoencoderKLWan.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
        else:
            vae = AutoencoderKLWan.from_pretrained(
                self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
            )

        return {"vae": vae}

    def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
        """Load the (locally defined) Wan transformer and a default flow-matching Euler scheduler."""
        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}

        if self.transformer_id is not None:
            transformer = WanTransformer3DModel.from_pretrained(
                self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs
            )
        else:
            transformer = WanTransformer3DModel.from_pretrained(
                self.pretrained_model_name_or_path,
                subfolder="transformer",
                torch_dtype=self.transformer_dtype,
                **common_kwargs,
            )

        scheduler = FlowMatchEulerDiscreteScheduler()

        return {"transformer": transformer, "scheduler": scheduler}

    def load_pipeline(
        self,
        tokenizer: Optional[AutoTokenizer] = None,
        text_encoder: Optional[UMT5EncoderModel] = None,
        transformer: Optional[WanTransformer3DModel] = None,
        vae: Optional[AutoencoderKLWan] = None,
        scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
        enable_slicing: bool = False,
        enable_tiling: bool = False,
        enable_model_cpu_offload: bool = False,
        training: bool = False,
        **kwargs,
    ) -> WanPipeline:
        """
        Build a ``WanPipeline``; components passed as ``None`` are loaded from the base checkpoint.

        The text encoder and VAE are cast to their configured dtypes. The transformer is cast only when
        ``training`` is False. ``enable_slicing`` / ``enable_tiling`` are currently ignored (see TODO).
        """
        components = {
            "tokenizer": tokenizer,
            "text_encoder": text_encoder,
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
        }
        components = get_non_null_items(components)

        pipe = WanPipeline.from_pretrained(
            self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
        )
        pipe.text_encoder.to(self.text_encoder_dtype)
        pipe.vae.to(self.vae_dtype)

        if not training:
            pipe.transformer.to(self.transformer_dtype)

        # TODO(aryan): add support in diffusers
        # if enable_slicing:
        #     pipe.vae.enable_slicing()
        # if enable_tiling:
        #     pipe.vae.enable_tiling()
        if enable_model_cpu_offload:
            pipe.enable_model_cpu_offload()

        return pipe

    @torch.no_grad()
    def prepare_conditions(
        self,
        tokenizer: AutoTokenizer,
        text_encoder: UMT5EncoderModel,
        caption: str,
        max_sequence_length: int = 512,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Run the condition processors on ``caption`` and return only the newly produced entries.

        ``prompt_attention_mask`` is dropped from the result.
        """
        conditions = {
            "tokenizer": tokenizer,
            "text_encoder": text_encoder,
            "caption": caption,
            "max_sequence_length": max_sequence_length,
            **kwargs,
        }
        input_keys = set(conditions.keys())
        conditions = super().prepare_conditions(**conditions)
        # Keep only outputs, not the inputs echoed back by the base implementation.
        conditions = {k: v for k, v in conditions.items() if k not in input_keys}
        conditions.pop("prompt_attention_mask", None)
        return conditions

    @torch.no_grad()
    def prepare_latents(
        self,
        vae: AutoencoderKLWan,
        image: Optional[torch.Tensor] = None,
        video: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
        compute_posterior: bool = True,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Run the latent processors and return only the newly produced entries.

        ``compute_posterior`` is ignored and forced to False (see the comment below).
        """
        conditions = {
            "vae": vae,
            "image": image,
            "video": video,
            "generator": generator,
            # We must force this to False because the latent normalization should be done before
            # the posterior is computed. The VAE does not handle this any more:
            # https://github.com/huggingface/diffusers/pull/10998
            "compute_posterior": False,
            **kwargs,
        }
        input_keys = set(conditions.keys())
        conditions = super().prepare_latents(**conditions)
        conditions = {k: v for k, v in conditions.items() if k not in input_keys}
        return conditions

    @classmethod
    def _apply_target_noise_mode(
        cls,
        noisy_latents: torch.Tensor,
        latents: torch.Tensor,
        noise: torch.Tensor,
        sigmas: torch.Tensor,
        apply_target_noise_only: Optional[str],
    ) -> None:
        """Modify ``noisy_latents`` in place according to ``cls._TARGET_NOISE_MODES[apply_target_noise_only]``."""
        spec = cls._TARGET_NOISE_MODES.get(apply_target_noise_only)
        if spec is None:
            print("[DEBUG] Default[Full] noise applied")
            return

        message, clean_prefix, capped_slices = spec
        if message is not None:
            print(message)
        if clean_prefix > 0:
            noisy_latents[:, :, :clean_prefix] = latents[:, :, :clean_prefix]
        for start, end, cap in capped_slices:
            # Per-sample selection: only samples whose sigma exceeds ``cap`` get the capped noise level.
            mask = (sigmas > cap).view(-1, 1, 1, 1, 1)
            capped = FF.flow_match_xt(latents[:, :, start:end], noise[:, :, start:end], cap)
            noisy_latents[:, :, start:end] = torch.where(mask, capped, noisy_latents[:, :, start:end])

    def forward(
        self,
        transformer: WanTransformer3DModel,
        condition_model_conditions: Dict[str, torch.Tensor],
        latent_model_conditions: Dict[str, torch.Tensor],
        sigmas: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        compute_posterior: bool = True,
        apply_target_noise_only: Optional[str] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, ...]:
        r"""
        Run one flow-matching training step and return ``(pred, target, sigmas)``.

        Args:
            transformer: The denoiser. It also receives ``apply_target_noise_only``.
            condition_model_conditions: Text conditions, forwarded to the transformer as keyword arguments.
            latent_model_conditions: Output of ``prepare_latents``. ``latents``, ``latents_mean`` and
                ``latents_std`` are popped from it (the dict is mutated). ``hidden_states`` is then set to
                the noisy latents and the remaining entries are forwarded to the transformer.
            sigmas: Per-sample noise levels. Timesteps are ``sigmas * 1000`` truncated to integers.
            generator: RNG used for posterior sampling and for the noise.
            compute_posterior: Ignored; see ``prepare_latents``.
            apply_target_noise_only: Key of ``_TARGET_NOISE_MODES`` selecting which frames stay clean or
                are noised at a capped level. Any other value noises all frames at ``sigmas``.
        """
        # ``compute_posterior`` is deliberately ignored: ``prepare_latents`` always yields raw moments,
        # because the latents must be normalized before the posterior is sampled.
        latents = latent_model_conditions.pop("latents")
        latents_mean = latent_model_conditions.pop("latents_mean")
        latents_std = latent_model_conditions.pop("latents_std")

        # Split the moments in half along channels, normalize each half, then sample the posterior.
        mu, logvar = torch.chunk(latents, 2, dim=1)
        mu = self._normalize_latents(mu, latents_mean, latents_std)
        logvar = self._normalize_latents(logvar, latents_mean, latents_std)
        posterior = DiagonalGaussianDistribution(torch.cat([mu, logvar], dim=1))
        latents = posterior.sample(generator=generator)
        del posterior

        noise = torch.zeros_like(latents).normal_(generator=generator)
        noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
        self._apply_target_noise_mode(noisy_latents, latents, noise, sigmas, apply_target_noise_only)

        timesteps = (sigmas.flatten() * 1000.0).long()

        latent_model_conditions["hidden_states"] = noisy_latents.to(latents)

        pred = transformer(
            **latent_model_conditions,
            **condition_model_conditions,
            timestep=timesteps,
            return_dict=False,
            apply_target_noise_only=apply_target_noise_only,
        )[0]
        target = FF.flow_match_target(noise, latents)

        return pred, target, sigmas

    def validation(
        self,
        pipeline: WanPipeline,
        prompt: str,
        image: Optional[Image] = None,
        video=None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_frames: Optional[int] = None,
        num_inference_steps: int = 50,
        generator: Optional[torch.Generator] = None,
        apply_target_noise_only: Optional[str] = None,
        **kwargs,
    ) -> List[ArtifactType]:
        """
        Generate one validation video and wrap it in a ``VideoArtifact``.

        ``video`` is encoded twice via ``process_video``: once with ``apply_target_noise_only`` (used as the
        starting ``latents`` unless the mode is ``"none"``) and once in ``"plain"`` mode (passed as
        ``init_latents``). The module-level ``custom_call`` is bound onto ``pipeline`` as ``custom_call``
        (mutating the pipeline object) and used for generation. When ``image`` is given, the pipeline is
        first converted to ``WanImageToVideoPipeline``.
        """
        if image is not None:
            pipeline = WanImageToVideoPipeline.from_pipe(pipeline)

        init_latents = process_video(pipeline, video, pipeline.dtype, generator, height, width, apply_target_noise_only)
        plain_latents = process_video(pipeline, video, pipeline.dtype, generator, height, width, "plain")
        pipeline.custom_call = types.MethodType(custom_call, pipeline)

        if apply_target_noise_only == "none":
            input_latents = None
        else:
            input_latents = init_latents

        generation_kwargs = {
            "prompt": prompt,
            "image": image,
            "height": height,
            "width": width,
            "latents": input_latents,
            "init_latents": plain_latents,
            "num_frames": num_frames,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
            "return_dict": True,
            "output_type": "pil",
            "apply_target_noise_only": apply_target_noise_only,
        }
        generation_kwargs = get_non_null_items(generation_kwargs)
        video = pipeline.custom_call(**generation_kwargs).frames[0]
        return [VideoArtifact(value=video)]

    def _save_lora_weights(
        self,
        directory: str,
        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
        scheduler: Optional[SchedulerType] = None,
        *args,
        **kwargs,
    ) -> None:
        """Save transformer LoRA weights (safetensors) and/or the scheduler config under ``directory``."""
        # TODO(aryan): this needs refactoring
        if transformer_state_dict is not None:
            WanPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True)
        if scheduler is not None:
            scheduler.save_pretrained(os.path.join(directory, "scheduler"))

    def _save_model(
        self,
        directory: str,
        transformer: WanTransformer3DModel,
        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
        scheduler: Optional[SchedulerType] = None,
    ) -> None:
        """
        Save a full transformer built from ``transformer.config`` + ``transformer_state_dict``, and/or the
        scheduler, under ``directory``.
        """
        # TODO(aryan): this needs refactoring
        if transformer_state_dict is not None:
            # Build the skeleton without allocating weights; ``assign=True`` adopts the given tensors.
            with init_empty_weights():
                transformer_copy = WanTransformer3DModel.from_config(transformer.config)
            transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
            transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
        if scheduler is not None:
            scheduler.save_pretrained(os.path.join(directory, "scheduler"))

    @staticmethod
    def _normalize_latents(
        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
    ) -> torch.Tensor:
        """
        Return ``(latents - mean) * latents_std`` per channel (dim 1), computed in float32 and cast back.

        ``latents_std`` is multiplied, so it must already be the reciprocal std, as produced by
        ``WanLatentEncodeProcessor``. The result has the dtype/device of ``latents``.
        """
        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
        latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)
        latents = ((latents.float() - latents_mean) * latents_std).to(latents)
        return latents

# Noising layout for each ``apply_target_noise_only`` mode handled by ``process_video``.
# Every value is ``(noise_start, partial_frames)``:
#   * ``noise_start``: index along dim 2 (presumably the latent-frame axis) from which
#     every frame is replaced by pure noise. ``None`` keeps all frames clean.
#   * ``partial_frames``: ``(frames, fraction)`` pairs. ``frames`` (an int index or a
#     slice along dim 2) is re-noised with ``scheduler.add_noise`` at
#     ``timesteps[int(len(timesteps) * (1 - fraction))]``.
# Frames covered by neither field keep their clean encoded latents. The pure-noise
# region always starts after the last partially noised frame, so no buffer frame is
# overwritten after being noised. The "back" mode is handled separately in the function.
_NOISE_LAYOUTS = {
    # Clean prefix followed by pure noise.
    "front": (1, ()),
    "front-none": (1, ()),
    "front-2": (2, ()),
    "front-2-none": (2, ()),
    "front-4-none": (4, ()),
    "front-5": (5, ()),
    "front-5-none": (5, ()),
    "front-long": (6, ()),
    "front-long-none": (6, ()),
    # Clean prefix, partially noised buffer frames, then pure noise.
    "front-4-noise-none": (4, ((1, 0.25), (2, 0.5), (3, 0.75))),
    "front-4-noise-none-buffer": (4, ((1, 0.25), (2, 0.5), (3, 0.75))),
    "front-4-noise-none-only25": (4, ((slice(1, 4), 0.25),)),
    "front-4-noise-none-only50": (4, ((slice(1, 4), 0.5),)),
    "front-4-noise-none-only75": (4, ((slice(1, 4), 0.75),)),
    "front-4-noise-none-dual-cond": (8, ((5, 0.25), (6, 0.5), (7, 0.75))),
    "front-5-noise-none": (5, ((2, 0.25), (3, 0.5), (4, 0.75))),
    "front-7-noise-none": (7, ((4, 0.25), (5, 0.5), (6, 0.75))),
    "buffer-1-noise-none": (2, ((1, 0.5),)),
    "buffer-2-noise-none": (3, ((1, 0.33), (2, 0.66))),
    "buffer-4-noise-none": (5, ((1, 0.2), (2, 0.4), (3, 0.6), (4, 0.8))),
    "buffer-5-noise-none": (6, ((1, 0.17), (2, 0.33), (3, 0.5), (4, 0.67), (5, 0.83))),
    "buffer-log-convex-noise-none": (5, ((1, 0.5), (2, 0.7), (3, 0.9), (4, 0.95))),
    "buffer-log-concave-noise-none": (5, ((1, 0.05), (2, 0.1), (3, 0.3), (4, 0.5))),
    "buffer-log-new-concave-noise-none": (4, ((1, 0.4), (2, 0.7), (3, 0.9))),
    "buffer-log-new-convex-noise-none": (4, ((1, 0.1), (2, 0.3), (3, 0.6))),
    # Leave the encoded latents untouched.
    "none": (None, ()),
    "plain": (None, ()),
}


@torch.no_grad()
def process_video(pipe, video, dtype, generator, height, width, apply_target_noise_only):
    """Encode a conditioning video into normalized Wan latents and noise them per mode.

    The video is preprocessed to ``height`` x ``width``, encoded with ``pipe.vae`` and
    normalized with the VAE config's ``latents_mean`` / ``latents_std``. Then,
    depending on ``apply_target_noise_only``, slices along dim 2 are kept clean,
    re-noised with ``pipe.scheduler.add_noise`` at a fixed fraction of the scheduler's
    current ``timesteps``, or replaced by pure noise (see ``_NOISE_LAYOUTS``).

    Args:
        pipe: Wan pipeline providing ``vae``, ``video_processor``, ``scheduler`` and
            ``_execution_device``.
        video: Conditioning video in any form accepted by
            ``pipe.video_processor.preprocess_video``.
        dtype: dtype of the normalization statistics and of the sampled noise.
        generator: Optional generator(s) for the noise. Used only when the pipeline
            executes on a CUDA device; otherwise it is discarded.
        height: Target height passed to the preprocessor.
        width: Target width passed to the preprocessor.
        apply_target_noise_only: ``None``, ``"back"`` or a key of ``_NOISE_LAYOUTS``.

    Returns:
        The processed latents on the pipeline's execution device, or ``None`` when
        ``apply_target_noise_only`` is ``None``.

    Raises:
        ValueError: If ``apply_target_noise_only`` is not a recognised mode.
    """
    mode = apply_target_noise_only
    if mode is None:
        return None
    if mode != "back" and mode not in _NOISE_LAYOUTS:
        valid_modes = ", ".join(sorted(["back", *_NOISE_LAYOUTS]))
        raise ValueError(f"apply_target_noise_only must be one of [{valid_modes}], but got {mode!r}")

    from diffusers.pipelines.wan.pipeline_wan_video2video import retrieve_latents
    from diffusers.utils.torch_utils import randn_tensor

    # Run on the device the pipeline executes on rather than a hard-coded "cuda".
    device = torch.device(pipe._execution_device)
    # Compare the device *type*: a device such as ``cuda:0`` is not equal to the string
    # "cuda", so comparing the device itself would drop the generator even on GPU.
    if device.type != "cuda":
        generator = None

    video = pipe.video_processor.preprocess_video(video, height=height, width=width)
    video = video.to(device, dtype=torch.float32)
    video_latents = retrieve_latents(pipe.vae.encode(video))

    vae_config = pipe.vae.config
    stats_shape = (1, vae_config.z_dim, 1, 1, 1)
    latents_mean = torch.tensor(vae_config.latents_mean).view(stats_shape).to(device, dtype)
    # Reciprocal std, so normalization below is (x - mean) * (1 / std).
    latents_inv_std = 1.0 / torch.tensor(vae_config.latents_std).view(stats_shape).to(device, dtype)
    init_latents = ((video_latents - latents_mean) * latents_inv_std).to(device)

    # Sampled once for every mode; each layout copies only the slices it needs.
    noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
    logger.debug(f"[DEBUG] applied noise mode : {mode}")

    if mode == "back":
        # Keep only the last slice along dim 2 clean.
        init_latents[:, :, :-1] = noise[:, :, :-1]
        return init_latents

    noise_start, partial_frames = _NOISE_LAYOUTS[mode]
    if partial_frames:
        scheduler = pipe.scheduler
        timesteps = scheduler.timesteps
        n_timesteps = timesteps.shape[0]
        for frames, fraction in partial_frames:
            t = timesteps[int(n_timesteps * (1 - fraction))]
            init_latents[:, :, frames] = scheduler.add_noise(
                init_latents[:, :, frames], noise[:, :, frames], torch.tensor([t])
            )
    if noise_start is not None:
        init_latents[:, :, noise_start:] = noise[:, :, noise_start:]
    return init_latents

@torch.no_grad()
def retrieve_video(pipe, init_latents):
    """Decode normalized Wan latents back into a video.

    Reverses the latent normalization with the VAE config statistics
    (``x / (1 / std) + mean``), decodes with ``pipe.vae`` and post-processes with
    ``pipe.video_processor`` using ``output_type="pil"``.

    Args:
        pipe: Wan pipeline providing ``vae`` and ``video_processor``.
        init_latents: Normalized latents; cast to ``pipe.vae.dtype`` before decoding.

    Returns:
        The first entry of the post-processed batch.
    """
    vae = pipe.vae
    cfg = vae.config
    z = init_latents.to(vae.dtype)
    stat_shape = (1, cfg.z_dim, 1, 1, 1)
    mean = torch.tensor(cfg.latents_mean).view(stat_shape).to(z.device, z.dtype)
    inv_std = 1.0 / torch.tensor(cfg.latents_std).view(stat_shape).to(z.device, z.dtype)
    z = z / inv_std + mean

    decoded = vae.decode(z, return_dict=False)[0]
    frames = pipe.video_processor.postprocess_video(decoded, output_type="pil")
    return frames[0]

@torch.no_grad()
def custom_call(
    self,
    prompt: Union[str, List[str]] = None,
    negative_prompt: Union[str, List[str]] = None,
    height: int = 480,
    width: int = 832,
    num_frames: int = 81,
    num_inference_steps: int = 50,
    guidance_scale: float = 5.0,
    num_videos_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    init_latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "np",
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
    apply_target_noise_only: Optional[str] = None,
):

    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        negative_prompt,
        height,
        width,
        prompt_embeds,
        negative_prompt_embeds,
        callback_on_step_end_tensor_inputs,
    )

    if num_frames % self.vae_scale_factor_temporal != 1:
        logger.warning(
            f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
        )
        num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
    num_frames = max(num_frames, 1)

    self._guidance_scale = guidance_scale
    self._attention_kwargs = attention_kwargs
    self._current_timestep = None
    self._interrupt = False

    device = self._execution_device

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt=prompt,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        num_videos_per_prompt=num_videos_per_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        max_sequence_length=max_sequence_length,
        device=device,
    )

    transformer_dtype = self.transformer.dtype
    prompt_embeds = prompt_embeds.to(transformer_dtype)
    if negative_prompt_embeds is not None:
        negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_videos_per_prompt,
        num_channels_latents,
        height,
        width,
        num_frames,
        torch.float32,
        device,
        generator,
        latents,
    )

    # 6. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)
    print(f"[pipeline] apply_target_noise_only: {apply_target_noise_only}")
    if "noise" in apply_target_noise_only:
        timesteps = self.scheduler.timesteps # torch.Size([1000]), torch.float32, 999~0
        scheduler = self.scheduler
        n_timesteps = timesteps.shape[0]
        #t_100 = timesteps[0]
        t_25 = timesteps[int(n_timesteps * (1 - 0.25))]
        t_50 = timesteps[int(n_timesteps * (1 - 0.5))]
        t_75 = timesteps[int(n_timesteps * (1 - 0.75))]
        if "buffer" in apply_target_noise_only:
            t_20 = timesteps[int(n_timesteps * (1 - 0.2))]
            t_40 = timesteps[int(n_timesteps * (1 - 0.4))]
            t_60 = timesteps[int(n_timesteps * (1 - 0.6))]
            t_80 = timesteps[int(n_timesteps * (1 - 0.8))]
            t_17 = timesteps[int(n_timesteps * (1 - 0.17))]
            t_33 = timesteps[int(n_timesteps * (1 - 0.33))]
            t_50 = timesteps[int(n_timesteps * (1 - 0.5))]
            t_67 = timesteps[int(n_timesteps * (1 - 0.67))]
            t_83 = timesteps[int(n_timesteps * (1 - 0.83))]
            t_66 = timesteps[int(n_timesteps * (1 - 0.66))]
            if "log" in apply_target_noise_only:
                if "convex" in apply_target_noise_only:
                    t_10 = timesteps[int(n_timesteps * (1 - 0.1))]
                    t_30 = timesteps[int(n_timesteps * (1 - 0.3))]
                    t_60 = timesteps[int(n_timesteps * (1 - 0.6))]
                elif "concave" in apply_target_noise_only:
                    t_40 = timesteps[int(n_timesteps * (1 - 0.4))]
                    t_70 = timesteps[int(n_timesteps * (1 - 0.7))]
                    t_90 = timesteps[int(n_timesteps * (1 - 0.9))]
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            if apply_target_noise_only == "none":
                from diffusers.utils.torch_utils import randn_tensor
                noise = randn_tensor(init_latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
                latents[:, :, :4] = self.scheduler.add_noise(init_latents[:, :, :4], noise[:, :, :4], torch.tensor([t], device=latents.device))
            
            if apply_target_noise_only == "front-5-noise-none":
                from diffusers.utils.torch_utils import randn_tensor
                noise = randn_tensor(init_latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
                latents[:, :, 5] = self.scheduler.add_noise(init_latents[:, :, 5], noise[:, :, 5], torch.tensor([t], device=latents.device))
                latents[:, :, 8] = self.scheduler.add_noise(init_latents[:, :, 8], noise[:, :, 8], torch.tensor([t], device=latents.device))
                latents[:, :, 11] = self.scheduler.add_noise(init_latents[:, :, 11], noise[:, :, 11], torch.tensor([t], device=latents.device))
                latents[:, :, 14] = self.scheduler.add_noise(init_latents[:, :, 14], noise[:, :, 14], torch.tensor([t], device=latents.device))
                latents[:, :, 17] = self.scheduler.add_noise(init_latents[:, :, 17], noise[:, :, 17], torch.tensor([t], device=latents.device))
                latents[:, :, 20] = self.scheduler.add_noise(init_latents[:, :, 20], noise[:, :, 20], torch.tensor([t], device=latents.device))
            
            self._current_timestep = t
            latent_model_input = latents.to(transformer_dtype)
            timestep = t.expand(latents.shape[0])

            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                attention_kwargs=attention_kwargs,
                return_dict=False,
                apply_target_noise_only=apply_target_noise_only,
            )[0]

            if self.do_classifier_free_guidance:
                noise_uncond = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=negative_prompt_embeds,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                    apply_target_noise_only=apply_target_noise_only,
                )[0]
                noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

            if apply_target_noise_only == "front" or apply_target_noise_only == "front-none":
                noise_pred[:, :, 0] = 0
            elif apply_target_noise_only == "front-long" or apply_target_noise_only == "front-long-none":
                noise_pred[:, :, :6] = 0
            elif apply_target_noise_only == "front-2" or apply_target_noise_only == "front-2-none":
                noise_pred[:, :, :2] = 0
            elif apply_target_noise_only == "front-5" or apply_target_noise_only == "front-5-none":
                noise_pred[:, :, :5] = 0
            elif apply_target_noise_only == "front-4-none":
                noise_pred[:, :, :4] = 0
            elif apply_target_noise_only == "front-4-noise-none" or apply_target_noise_only == "front-4-noise-none-buffer":
                noise_pred[:, :, 0] = 0
                if t > t_25:
                    print(f"[DEBUG] not reached t_25")
                    noise_pred[:, :, 1] = 0
                if t > t_50:
                    print(f"[DEBUG] not reached t_50")
                    noise_pred[:, :, 2] = 0
                if t > t_75:
                    print(f"[DEBUG] not reached t_75")
                    noise_pred[:, :, 3] = 0
            elif apply_target_noise_only == "buffer-1-noise-none":
                noise_pred[:, :, 0] = 0
                if t > t_50:
                    print(f"[DEBUG] not reached t_50")
                    noise_pred[:, :, 1] = 0
            elif apply_target_noise_only == "buffer-2-noise-none":
                noise_pred[:, :, 0] = 0
                if t > t_33:
                    print(f"[DEBUG] not reached t_33")
                    noise_pred[:, :, 1] = 0
                if t > t_66:
                    print(f"[DEBUG] not reached t_66")
                    noise_pred[:, :, 2] = 0
            elif apply_target_noise_only == "buffer-4-noise-none":
                noise_pred[:, :, 0] = 0
                if t > t_20:
                    print(f"[DEBUG] not reached t_20")
                    noise_pred[:, :, 1] = 0
                if t > t_40:
                    print(f"[DEBUG] not reached t_40")
                    noise_pred[:, :, 2] = 0
                if t > t_60:
                    print(f"[DEBUG] not reached t_60")
                    noise_pred[:, :, 3] = 0
                if t > t_80:
                    print(f"[DEBUG] not reached t_80")
                    noise_pred[:, :, 4] = 0
            elif apply_target_noise_only == "buffer-5-noise-none":
                noise_pred[:, :, 0] = 0
                if t > t_17:
                    print(f"[DEBUG] not reached t_17")
                    noise_pred[:, :, 1] = 0
                if t > t_33:
                    print(f"[DEBUG] not reached t_33")
                    noise_pred[:, :, 2] = 0
                if t > t_50:
                    print(f"[DEBUG] not reached t_50")
                    noise_pred[:, :, 3] = 0
                if t > t_67:
                    print(f"[DEBUG] not reached t_67")
                    noise_pred[:, :, 4] = 0
                if t > t_83:
                    print(f"[DEBUG] not reached t_83")
                    noise_pred[:, :, 5] = 0
            elif apply_target_noise_only == "buffer-log-convex-noise-none":
                noise_pred[:, :, 0] = 0
                if t > t_50:
                    print(f"[DEBUG] not reached t_50")
                    noise_pred[:, :, 1] = 0
                if t > t_70:
                    print(f"[DEBUG] not reached t_70")
                    noise_pred[:, :, 2] = 0
                if t > t_90:
                    print(f"[DEBUG] not reached t_90")
                    noise_pred[:, :, 3] = 0
                if t > t_95:
                    print(f"[DEBUG] not reached t_95")
                    noise_pred[:, :, 4] = 0
            elif apply_target_noise_only == "buffer-log-concave-noise-none":
                noise_pred[:, :, 0] = 0
                if t > t_05:
                    print(f"[DEBUG] not reached t_05")
                    noise_pred[:, :, 1] = 0
                if t > t_10:
                    print(f"[DEBUG] not reached t_10")
                    noise_pred[:, :, 2] = 0
                if t > t_30:
                    print(f"[DEBUG] not reached t_30")
                    noise_pred[:, :, 3] = 0
                if t > t_50:
                    print(f"[DEBUG] not reached t_50")
                    noise_pred[:, :, 4] = 0

            elif apply_target_noise_only == "front-4-noise-none-only25":
                noise_pred[:, :, 0] = 0
                if t > t_25:
                    print(f"[DEBUG] not reached t_25")
                    noise_pred[:, :, 1:4] = 0
            elif apply_target_noise_only == "front-4-noise-none-only50":
                noise_pred[:, :, 0] = 0
                if t > t_50:
                    print(f"[DEBUG] not reached t_50")
                    noise_pred[:, :, 1:4] = 0
            elif apply_target_noise_only == "front-4-noise-none-only75":
                noise_pred[:, :, 0] = 0
                if t > t_75:
                    print(f"[DEBUG] not reached t_75")
                    noise_pred[:, :, 1:4] = 0
            elif apply_target_noise_only == "front-4-noise-none-dual-cond":
                noise_pred[:, :, :5] = 0
                if t > t_25:
                    print(f"[DEBUG] not reached t_25")
                    noise_pred[:, :, 5] = 0
                if t > t_50:
                    print(f"[DEBUG] not reached t_50")
                    noise_pred[:, :, 6] = 0
                if t > t_75:
                    print(f"[DEBUG] not reached t_75")
                    noise_pred[:, :, 7] = 0
            elif apply_target_noise_only == "front-7-noise-none":
                noise_pred[:, :, :4] = 0
                if t > t_25:
                    print(f"[DEBUG] not reached t_25")
                    noise_pred[:, :, 4] = 0
                if t > t_50:
                    print(f"[DEBUG] not reached t_50")
                    noise_pred[:, :, 5] = 0
                if t > t_75:
                    print(f"[DEBUG] not reached t_75")
                    noise_pred[:, :, 6] = 0
            elif apply_target_noise_only == "none":
                pass
            elif apply_target_noise_only == "front-5-noise-none":
                noise_pred[:, :, :2] = 0
                if t > t_25:
                    print(f"[DEBUG] not reached t_25")
                    noise_pred[:, :, 2] = 0
                if t > t_50:
                    print(f"[DEBUG] not reached t_50")
                    noise_pred[:, :, 3] = 0
                if t > t_75:
                    print(f"[DEBUG] not reached t_75")
                    noise_pred[:, :, 4] = 0
            elif apply_target_noise_only == "buffer-log-new-concave-noise-none":
                noise_pred[:, :, 0] = 0
                if t > t_40:
                    print(f"[DEBUG] not reached t_40")
                    noise_pred[:, :, 1] = 0
                if t > t_70:
                    print(f"[DEBUG] not reached t_70")
                    noise_pred[:, :, 2] = 0
                if t > t_90:
                    print(f"[DEBUG] not reached t_90")
                    noise_pred[:, :, 3] = 0
            elif apply_target_noise_only == "buffer-log-new-convex-noise-none":
                noise_pred[:, :, 0] = 0
                if t > t_10:
                    print(f"[DEBUG] n1ot reached t_10")
                    noise_pred[:, :, 1] = 0
                if t > t_30:
                    print(f"[DEBUG] not reached t_30")
                    noise_pred[:, :, 2] = 0
                if t > t_60:
                    print(f"[DEBUG] not reached t_60")
                    noise_pred[:, :, 3] = 0

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    self._current_timestep = None

    if not output_type == "latent":
        latents = latents.to(self.vae.dtype)
        latents_mean = (
            torch.tensor(self.vae.config.latents_mean)
            .view(1, self.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
            latents.device, latents.dtype
        )
        latents = latents / latents_std + latents_mean
        video = self.vae.decode(latents, return_dict=False)[0]
        video = self.video_processor.postprocess_video(video, output_type=output_type)
    else:
        video = latents

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (video,)

    return WanPipelineOutput(frames=video)