# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tqdm import tqdm
import inspect
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import html
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import ftfy
import regex as re
import torch
from transformers import AutoTokenizer, UMT5EncoderModel

from ..loader import WanLoraLoaderMixin
from ..autoencoder import AutoencoderKLWan
from ..transformer_joint_s3 import WanTransformer3DModel
from ..flow_frame import FlowMatchScheduler
from .pipeline_output import WanPipelineOutput

from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor

# Whether torch_xla is installed; when it is, `xm` is imported for `xm.mark_step()` calls.
XLA_AVAILABLE = bool(is_torch_xla_available())
if XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm

# Module-level logger.
logger = logging.get_logger(__name__)

def transform_and_repeat_timesteps(
    timesteps,
    device=None,
    num_segments=4,
    prefix=(358, 358, 358, 0, 0),
    repeat=4,
):
    """
    Turn a flat timestep schedule into per-step rows of per-frame timesteps.

    The N input timesteps are split into `num_segments` consecutive segments of length K = N // num_segments.
    Output row `r` (0 <= r < K) consists of `prefix` followed by the `r`-th element of every segment, sorted in
    ascending order, with each of those values repeated `repeat` times. Every row therefore has
    `len(prefix) + num_segments * repeat` entries.

    Args:
        timesteps (list or torch.Tensor): List or 1D tensor of N numbers; N must be divisible by `num_segments`.
        device (torch.device, optional): If given (or if `timesteps` is a tensor), the result is returned as an
            int64 tensor on this device (defaulting to the input tensor's device).
        num_segments (int): Number of segments to split the input into (default 4). Must be positive.
        prefix (sequence): Values prepended to every row (default (358, 358, 358, 0, 0)). Any sequence works;
            a tuple default is used so the default is never shared mutable state.
        repeat (int): How many times each segment value is repeated in a row (default 4).

    Returns:
        list or torch.Tensor: A list of K lists, or a (K, len(prefix) + num_segments * repeat) int64 tensor when
        a device is known. Float timesteps are truncated by the int64 conversion.

    Raises:
        ValueError: If the input is not a list / 1D tensor, if `num_segments` is not positive, or if the input
            length is not divisible by `num_segments`.

    Example:
        >>> transform_and_repeat_timesteps([8, 7, 6, 5, 4, 3, 2, 1], prefix=[], repeat=1)
        [[2, 4, 6, 8], [1, 3, 5, 7]]
    """
    if isinstance(timesteps, torch.Tensor):
        if timesteps.dim() != 1:
            raise ValueError("Tensor input must be 1D")
        input_device = timesteps.device if device is None else device
        timesteps = timesteps.cpu().tolist()
    elif not isinstance(timesteps, list):
        raise ValueError("Input must be a list or tensor of numbers")
    else:
        input_device = device

    if num_segments <= 0:
        raise ValueError(f"num_segments must be a positive integer, got {num_segments}")

    N = len(timesteps)
    if N % num_segments != 0:
        raise ValueError(f"Input length {N} must be divisible by num_segments={num_segments}")
    segment_size = N // num_segments

    segments = [timesteps[i * segment_size:(i + 1) * segment_size] for i in range(num_segments)]
    prefix = list(prefix)

    result = []
    for r in range(segment_size):
        # One value per segment, taken at the same offset, in ascending order.
        group = sorted(segment[r] for segment in segments)
        # `prefix + [...]` builds a fresh list, so rows never alias `prefix` or each other.
        result.append(prefix + [value for value in group for _ in range(repeat)])

    if input_device is not None:
        result = torch.tensor(result, dtype=torch.int64, device=input_device)

    return result

# Usage example handed to `replace_example_docstring` on `WanPipeline_flow.__call__` (presumably spliced into
# that method's docstring in place of its `Examples:` line).
# NOTE(review): this snippet is written for the stock diffusers `WanPipeline` and does not match
# `WanPipeline_flow.__call__`. That method takes `prompts` (not `prompt`) and indexes its `latents` argument
# unconditionally. With `return_dict=True` it also returns a tuple of three `WanPipelineOutput` objects, so
# `.frames[0]` on the result would fail. Update it before relying on it.
EXAMPLE_DOC_STRING = """
    Examples:
        ```python
        >>> import torch
        >>> from diffusers.utils import export_to_video
        >>> from diffusers import AutoencoderKLWan, WanPipeline
        >>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler

        >>> model_id = "path/to/model"
        >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
        >>> pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
        >>> flow_shift = 5.0
        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
        >>> pipe.to("cuda")

        >>> prompt = "A generic prompt for video generation."
        >>> negative_prompt = "Low quality, blurry, static, deformed, incomplete."

        >>> output = pipe(
        ...     prompt=prompt,
        ...     negative_prompt=negative_prompt,
        ...     height=720,
        ...     width=1280,
        ...     num_frames=81,
        ...     guidance_scale=5.0,
        ... ).frames[0]
        >>> export_to_video(output, "output.mp4", fps=16)
        ```
"""

def randn_like(tensor, generator=None):
    """Sample Gaussian noise with the same shape, dtype and device as ``tensor``.

    The sampling itself is delegated to ``randn_tensor``, with ``generator`` forwarded unchanged.
    """
    target_shape = tensor.shape
    return randn_tensor(target_shape, generator=generator, dtype=tensor.dtype, device=tensor.device)

def generate_list_and_repeat(num_range, repeat_factor, divisible_by):
    """Return ``0 .. num_range - 1`` where every non-zero multiple of ``divisible_by`` appears
    ``repeat_factor + 1`` times in a row and every other value appears once.

    Example:
        >>> generate_list_and_repeat(8, 3, 4)
        [0, 1, 2, 3, 4, 4, 4, 4, 5, 6, 7]
    """
    expanded = []
    for value in range(num_range):
        # The modulo is evaluated first, as before, so divisible_by == 0 still raises ZeroDivisionError.
        is_marker = value % divisible_by == 0 and value != 0
        copies = repeat_factor + 1 if is_marker else 1
        expanded.extend([value] * copies)
    return expanded

def basic_clean(text):
    """Normalize ``text`` with ``ftfy.fix_text``, HTML-unescape it twice (covers double-escaped input) and strip it."""
    repaired = ftfy.fix_text(text)
    for _ in range(2):
        repaired = html.unescape(repaired)
    return repaired.strip()

def whitespace_clean(text):
    """Collapse every run of whitespace in ``text`` to a single space and trim both ends."""
    return re.sub(r"\s+", " ", text).strip()

def prompt_clean(text):
    """Clean a prompt: ``basic_clean`` first, then ``whitespace_clean``."""
    cleaned = basic_clean(text)
    return whitespace_clean(cleaned)

class WanPipeline_flow(DiffusionPipeline, WanLoraLoaderMixin):
    """
    Pipeline for streaming text-to-video generation with a jointly denoised depth stream.

    The pipeline keeps a window of latent frames. The first 5 latent frames of the window are conditioning frames
    that the scheduler never steps. The remaining frames carry per-frame timesteps produced by
    `transform_and_repeat_timesteps`. After every sample group, the first `window_size` non-conditioning frames
    are emitted, the remaining frames shift forward and fresh noise is appended at the end. Latent channels
    `[:16]` are decoded as RGB video and channels `[16:]` as depth.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        tokenizer ([`AutoTokenizer`]):
            Tokenizer matching `text_encoder`.
        text_encoder ([`UMT5EncoderModel`]):
            UMT5 text encoder used to embed the prompts.
        transformer ([`WanTransformer3DModel`]):
            Conditional Transformer to denoise the input latents.
        scheduler ([`FlowMatchScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKLWan`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
    """

    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        text_encoder: UMT5EncoderModel,
        transformer: WanTransformer3DModel,
        vae: AutoencoderKLWan,
        scheduler: FlowMatchScheduler,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
        )

        # Down-sampling factors derived from the VAE (`temperal_downsample` is the attribute name as spelled on
        # the VAE); fall back to 4 (temporal) / 8 (spatial) when no VAE is registered.
        self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
        self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Encode `prompt` with the text encoder and return zero-padded embeddings.

        Each prompt is cleaned with `prompt_clean`, tokenized to exactly `max_sequence_length` tokens and encoded.
        Hidden states after the real (unmasked) tokens are replaced by zeros.

        Returns:
            `torch.Tensor` of shape `(batch_size * num_videos_per_prompt, max_sequence_length, hidden_dim)` in
            `dtype` (defaults to the text encoder's dtype).
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt = [prompt_clean(u) for u in prompt]
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
        seq_lens = mask.gt(0).sum(dim=1).long()

        prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        # Keep only the real tokens, then zero-pad back to `max_sequence_length`.
        prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
        prompt_embeds = torch.stack(
            [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
        )

        # Duplicate the embeddings for each video generated per prompt (repeat + view keeps copies of the same
        # prompt adjacent in the batch).
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return prompt_embeds

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Defaults to the empty string.
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings; returned unchanged when given.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings; returned unchanged when given.
            max_sequence_length (`int`, *optional*, defaults to 226):
                Maximum sequence length for tokenization.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype

        Returns:
            `(prompt_embeds, negative_prompt_embeds)`. `negative_prompt_embeds` is `None` when classifier-free
            guidance is disabled and no negative embeddings were passed in.

        Raises:
            TypeError: If `negative_prompt` and `prompt` have different types.
            ValueError: If `negative_prompt` and `prompt` have different batch sizes.
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, negative_prompt_embeds

    def check_inputs(
        self,
        prompt,
        negative_prompt,
        height,
        width,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        """
        Validate the user-facing arguments of `__call__`.

        Raises:
            ValueError: If `height`/`width` are not divisible by 16, or if an unknown callback tensor input is
                requested. Also raised if both a prompt and its embeddings are given, if neither `prompt` nor
                `prompt_embeds` is given, or if a prompt is neither a `str` nor a `list`.
        """
        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        elif negative_prompt is not None and (
            not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
        ):
            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

    def prepare_latents(
        self,
        batch_size: int,
        num_channels_latents: int = 16,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Return `latents` moved to `device`/`dtype`, or sample fresh Gaussian latents for the given video size.

        NOTE: `__call__` in this class uses `prepare_latents_random_tensor` instead of this method.
        """
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        shape = (
            batch_size,
            num_channels_latents,
            num_latent_frames,
            int(height) // self.vae_scale_factor_spatial,
            int(width) // self.vae_scale_factor_spatial,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        return latents

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1.0

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def current_timestep(self):
        return self._current_timestep

    @property
    def interrupt(self):
        return self._interrupt

    @property
    def attention_kwargs(self):
        return self._attention_kwargs

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompts: Union[str, List[str]] = None,
        negative_prompt: Union[str, List[str]] = None,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "np",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        num_sample_groups: int = 8,
        num_noise_groups: int = 4,
        infer_real: bool = False,
    ):
        """
        The call function to the pipeline for generation.

        Args:
            prompts (`str` or `List[str]`, *optional*):
                The prompts to guide the generation. Prompt `k` conditions sample groups `4k` to `4k + 3`, so at
                least `(num_sample_groups - 1) // 4 + 1` prompts are needed. A single `str` is treated as a
                one-element list. If omitted, `prompt_embeds` is used as the only prompt segment.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the generation; shared by all `prompts`.
            height (`int`, defaults to `480`):
                The height in pixels of the generated video.
            width (`int`, defaults to `832`):
                The width in pixels of the generated video.
            num_frames (`int`, defaults to `81`):
                Number of pixel frames covered by the latent window (`num_frames // 4 + 1` latent frames).
            num_inference_steps (`int`, defaults to `50`):
                Total number of scheduler timesteps. Each sample group runs
                `num_inference_steps // num_noise_groups` denoising steps.
            guidance_scale (`float`, defaults to `5.0`):
                Guidance scale for classifier-free guidance; guidance is disabled when it is `<= 1`.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A torch generator to make generation deterministic.
            latents (`torch.Tensor`):
                Required. Reference latents with `2 * transformer.config.in_channels` channels (presumably video
                channels followed by depth channels). They are noised to the first row of the timestep schedule
                to build the initial window. Their first 5 latent frames seed the conditioning frames, and their
                first 16 channels are decoded as the ground-truth video.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings, used only when `prompts` is not given.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings, shared by all prompts.
            output_type (`str`, *optional*, defaults to `"np"`):
                The output format of the generated video. `"latent"` skips VAE decoding.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return the three [`WanPipelineOutput`] objects or a one-element tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary passed to the `AttentionProcessor`.
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                Accepted for API compatibility. NOTE: it is currently never invoked.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                Validated against `_callback_tensor_inputs`; otherwise unused.
            max_sequence_length (`int`, *optional*, defaults to 512):
                Maximum sequence length for tokenization.
            num_sample_groups (`int`, *optional*, defaults to 8):
                Number of sample groups. Each group denoises the window once and, outside warm-up, emits
                `window_size` latent frames.
            num_noise_groups (`int`, *optional*, defaults to 4):
                Number of groups the non-conditioning latent frames are split into; also divides
                `num_inference_steps`.
            infer_real (`bool`, *optional*, defaults to False):
                If `True`, the window is built from latent frames 0-4 of `latents` (frame 4 repeated to fill the
                rest), and the first four entries of the sample-group schedule are skipped.

        Examples:

        Returns:
            `tuple`:
                If `return_dict` is `True`, returns a tuple of three [`WanPipelineOutput`] objects. They hold the
                generated video, the generated depth (values in `[0, 1]`) and the decoded `latents` (ground
                truth). With `output_type="latent"` they hold the corresponding latents instead. If `return_dict`
                is `False`, returns a one-element tuple containing only the generated video.
        """
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
        # Denoising steps run for every sample group.
        inner_steps = num_inference_steps // num_noise_groups
        # Latent frames emitted per sample group: the non-conditioning frames (all but the first 5) split
        # evenly across the noise groups.
        window_size = (num_frames // 4 + 1 - 5) // num_noise_groups

        self.check_inputs(
            prompts,
            negative_prompt,
            height,
            width,
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
        )

        if latents is None:
            raise ValueError(
                "`latents` is required: it seeds the latent window and its first 16 channels are decoded as the"
                " ground-truth video."
            )
        if isinstance(prompts, str):
            prompts = [prompts]
        # Without `prompts`, the caller's `prompt_embeds` (guaranteed by check_inputs) form a single segment.
        prompt_list = prompts if prompts is not None else [None]
        num_prompt_segments = max(1, (num_sample_groups - 1) // 4 + 1)
        if len(prompt_list) < num_prompt_segments:
            raise ValueError(
                f"`num_sample_groups={num_sample_groups}` needs at least {num_prompt_segments} prompt(s) (one per 4"
                f" sample groups), but got {len(prompt_list)}."
            )

        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        device = self._execution_device

        # NOTE: the prompt batch size is fixed to 1; only `num_videos_per_prompt` scales the latent batch.
        batch_size = 1

        # Index k holds the embeddings used by sample groups 4k .. 4k + 3.
        prompt_embeds_positive = []
        prompt_embeds_negative = []

        transformer_dtype = self.transformer.dtype
        for prompt in prompt_list:
            # Always pass the caller's embeddings (not the previous iteration's result); otherwise every prompt
            # after the first would silently reuse the first prompt's embeddings.
            pos_embeds, neg_embeds = self.encode_prompt(
                prompt=prompt,
                negative_prompt=negative_prompt,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                num_videos_per_prompt=num_videos_per_prompt,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                max_sequence_length=max_sequence_length,
                device=device,
            )
            prompt_embeds_positive.append(pos_embeds.to(transformer_dtype))
            # Negative embeddings are None when classifier-free guidance is disabled.
            prompt_embeds_negative.append(None if neg_embeds is None else neg_embeds.to(transformer_dtype))
        # Freshly sampled noise below uses the dtype of the (un-cast) prompt embeddings.
        latents_dtype = pos_embeds.dtype

        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=1.0, shift=5.0)

        timesteps = self.scheduler.timesteps.to(torch.bfloat16)

        # Row k is the per-frame timestep vector used at inner step k.
        timesteps_grouped = transform_and_repeat_timesteps(timesteps)

        num_channels_latents = self.transformer.config.in_channels
        num_latent_frames = num_frames // 4 + 1

        latents_prompt = latents
        latents_prompt_decode = latents
        if infer_real:
            # Fill the window with copies of latent frame 4, then restore the original first 5 frames.
            latents_prompt_relatents = latents_prompt[:, :, 4].unsqueeze(2).repeat(1, 1, num_latent_frames, 1, 1)
            latents_prompt_relatents[:, :, :5] = latents_prompt[:, :, 0:5]
            latents_prompt = latents_prompt_relatents
        latents = self.prepare_latents_random_tensor(
            batch_size * num_videos_per_prompt,
            num_channels_latents * 2,
            num_frames,
            height,
            width,
            latents_dtype,
            device,
            generator,
            latents_prompt,
            timesteps_grouped[0],
        )

        self._num_timesteps = len(timesteps)
        # Sample-group schedule, e.g. [0, 1, 2, 3, 4, 4, 4, 4, 5, 6, 7] for 8 groups: every non-zero multiple of 4
        # (the first group of a new prompt) appears 4 times, the first 3 of which run as warm-up passes.
        outer_steps = generate_list_and_repeat(num_range=num_sample_groups, repeat_factor=3, divisible_by=4)
        # Emitted latent frames, starting with the initial conditioning frames.
        latents_pop_stream = [latents[:, :, :5]]

        for group_step, group_idx in enumerate(outer_steps):
            # With real-video conditioning, the first four schedule entries are skipped.
            if infer_real and group_step < 4:
                continue
            prompt_embeds_index = group_idx // 4
            prompt_embeds_p = prompt_embeds_positive[prompt_embeds_index]
            prompt_embeds_n = prompt_embeds_negative[prompt_embeds_index]
            warm_up = False
            re_latents = False

            if group_step + 1 < len(outer_steps):
                next_group_idx = outer_steps[group_step + 1]
                # Last group of a prompt right before the next prompt begins.
                if next_group_idx % 4 == 0 and group_idx % 4 != 0:
                    re_latents = True
                # Warm-up pass on a new prompt: its popped frames are not emitted.
                if next_group_idx % 4 == 0 and group_idx % 4 == 0:
                    warm_up = True

            with self.progress_bar(total=inner_steps) as progress_bar:
                for step_idx in range(inner_steps):
                    if self.interrupt:
                        continue
                    latent_model_input = latents.to(transformer_dtype)
                    # One timestep vector per latent sample in the batch.
                    timestep = (
                        timesteps_grouped[step_idx : step_idx + 1]
                        .repeat(latent_model_input.shape[0], 1)
                        .to(latent_model_input.device)
                    )
                    noise_pred = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=prompt_embeds_p,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                    )[0]

                    if self.do_classifier_free_guidance:
                        noise_uncond = self.transformer(
                            hidden_states=latent_model_input,
                            timestep=timestep,
                            encoder_hidden_states=prompt_embeds_n,
                            attention_kwargs=attention_kwargs,
                            return_dict=False,
                        )[0]
                        noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
                    # Only frames after the 5 conditioning frames are stepped; the conditioning frames are kept.
                    latents_cond = latents[:, :, :5]
                    latents = self.scheduler.step(noise_pred[:, :, 5:], timestep[:, 5:], latents[:, :, 5:])
                    latents = torch.cat([latents_cond, latents], dim=2)

                    progress_bar.update()

                    if XLA_AVAILABLE:
                        xm.mark_step()

            # `latents_pop` leaves the window; `latents_remain` shifts forward.
            latents_pop = latents[:, :, 5 : window_size + 5]
            latents_remain = latents[:, :, window_size + 5 :]
            if not re_latents:
                if not warm_up:
                    # New conditioning frames: latent frames 4..8 re-noised to the conditioning-prefix timesteps.
                    latents_cond_new = latents[:, :, 4:9]
                    timesteps_readd = timesteps_grouped[0][:5]
                    noise = randn_tensor(
                        latents_cond_new.shape, generator=generator, device=device, dtype=latents_dtype
                    )
                    latents_cond_new = self.scheduler.add_noise(latents_cond_new, noise, timesteps_readd)
                else:
                    # Warm-up: keep the current conditioning frames.
                    latents_cond_new = latents[:, :, 0:5]

                latents_new = torch.cat(
                    [latents_cond_new, latents_remain, randn_like(latents_pop, generator)],
                    dim=2,
                )
            # NOTE: when `re_latents` is set, `latents_new` still holds the window prepared by the previous group,
            # so this group's denoised window is discarded apart from the emitted frames.
            if not warm_up:
                latents_pop_stream.append(latents_pop)
            latents = latents_new

        latents = torch.cat(latents_pop_stream, dim=2)
        self._current_timestep = None
        # Channels [:16] hold the RGB video latents, channels [16:] the depth latents.
        latents_video = latents[:, :16, :, :, :]
        latents_depth = latents[:, 16:, :, :, :]
        gt = latents_prompt_decode[:, :16, :, :, :]

        if not output_type == "latent":
            latents_video = latents_video.to(self.vae.dtype)
            # The reference latents may live on another device than the generated ones.
            gt = gt.to(device=latents_video.device, dtype=self.vae.dtype)
            latents_mean = (
                torch.tensor(self.vae.config.latents_mean)
                .view(1, self.vae.config.z_dim, 1, 1, 1)
                .to(latents_video.device, latents_video.dtype)
            )

            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
                latents_video.device, latents_video.dtype
            )

            # Undo the latent normalization before decoding.
            latents_video = latents_video / latents_std + latents_mean
            gt = gt / latents_std + latents_mean
            video = self.vae.decode(latents_video, return_dict=False)[0]
            video = self.video_processor.postprocess_video(video, output_type=output_type)
            video_gt = self.vae.decode(gt, return_dict=False)[0]
            video_gt = self.video_processor.postprocess_video(video_gt, output_type=output_type)

            latents_depth = latents_depth.to(self.vae.dtype)
            latents_depth = latents_depth / latents_std + latents_mean
            video_depth = self.vae.decode(latents_depth, return_dict=False)[0]
            # Average the decoded RGB channels into a single depth channel and map [-1, 1] to [0, 1].
            video_depth = video_depth.mean(dim=1, keepdim=True)
            video_depth = torch.clamp(video_depth, -1, 1)

            video_depth = (video_depth.squeeze() + 1.0) / 2
            video_depth = video_depth.float().cpu().numpy().astype(np.float32)

        else:
            video = latents
            video_depth = latents_depth
            video_gt = gt

        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return WanPipelineOutput(frames=video), WanPipelineOutput(frames=video_depth), WanPipelineOutput(frames=video_gt)

    def prepare_latents_random_tensor(
        self,
        batch_size: int = 1,
        num_channels_latents: int = 16,
        num_frames: int = 49,
        height: int = 60,
        width: int = 104,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
        timestep: Optional[torch.Tensor] = None,
    ):
        """
        Build the initial latent window from reference latents.

        `latents` is noised with `self.scheduler.add_noise` at the per-frame `timestep`, then its last 4 latent
        frames are overwritten with fresh Gaussian noise.

        Args:
            batch_size (`int`): Batch size of the sampled noise.
            num_channels_latents (`int`): Channel count of the sampled noise.
            num_frames (`int`): Pixel frame count; the noise has `num_frames // 4 + 1` latent frames.
            height (`int`), width (`int`): Divided by `self.vae_scale_factor_spatial` to get the latent size.
                NOTE(review): the defaults (60, 104) look like latent sizes, while `__call__` passes pixel sizes.
            dtype, device, generator: Forwarded to `randn_tensor`.
            latents (`torch.Tensor`): Reference latents; must not be None (passed straight to `add_noise`).
            timestep (`torch.Tensor`): Per-frame timesteps forwarded to `add_noise`.

        Returns:
            `torch.Tensor`: The noised latents, moved to `device`.

        Raises:
            ValueError: If a list of generators does not match `batch_size`.
        """
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        shape = (
            batch_size,
            num_channels_latents,
            num_frames // 4 + 1,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
        )
        shape_2 = (
            batch_size,
            num_channels_latents,
            4,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
        )

        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        latents = self.scheduler.add_noise(latents, noise, timestep=timestep)
        # The last 4 latent frames start from pure noise.
        new_tensor = randn_tensor(shape_2, generator=generator, device=device, dtype=dtype)
        latents[:, :, -4:] = new_tensor

        latents = latents.to(device)

        return latents