from videox_fun.models.wan_wrapper import WanDiffusionWrapper
from videox_fun.utils.scheduler import SchedulerInterface
from typing import List, Optional
import torch
import torch.distributed as dist
import math
import random
from typing import Optional, Union
from diffusers.utils.torch_utils import randn_tensor


class SelfForcingRLPipeline:
    def __init__(self,
                 denoising_step_list: List[int],
                 scheduler: SchedulerInterface,
                 generator: WanDiffusionWrapper,
                 ref_model: Optional[WanDiffusionWrapper] = None,
                 vae=None,
                 image_processor=None,
                 num_frame_per_block=1,
                 num_max_frames: int = 21,
                 num_last_frames_with_grad: int = 21,  # number of suffix frames to backpropagate gradients
                 **kwargs):
        """Store the models, scheduler and sampling hyperparameters.

        Args:
            denoising_step_list: Timesteps visited when denoising one block.
                A trailing ``0`` and then a trailing ``20`` are stripped (in
                that order) so the stored list only holds steps the model is
                actually evaluated at.
            scheduler: Object exposing ``timesteps`` and ``sigmas`` tensors.
            generator: Policy model that is sampled from.
            ref_model: Optional reference model, used by ``compute_log_prob``
                when ``is_ref`` is truthy.
            vae: Decoder used by ``decode_latents``.
            image_processor: Stored as-is; not used inside this class.
            num_frame_per_block: Number of latent frames generated per block.
            num_max_frames: Number of frames the self-attention KV cache can
                hold.
            num_last_frames_with_grad: Stored as-is; not read by the methods
                of this class.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__()
        self.scheduler = scheduler
        self.generator = generator
        self.ref_model = ref_model
        self.vae = vae
        self.image_processor = image_processor

        # Strip the sentinel tail values. The emptiness guards keep inputs such
        # as ``[]`` or ``[0]`` from raising IndexError on the ``[-1]`` lookup.
        steps = denoising_step_list
        if len(steps) > 0 and steps[-1] == 0:
            steps = steps[:-1]  # remove the zero timestep for inference
        if len(steps) > 0 and steps[-1] == 20:
            steps = steps[:-1]
        self.denoising_step_list = steps
        print(f"SelfForcingRLPipeline using denoising steps: {self.denoising_step_list}")

        # Wan specific hyperparameters
        self.num_transformer_blocks = 30
        # Presumably the number of tokens per latent frame; it is multiplied by
        # num_max_frames below to size the KV cache.
        self.frame_seq_length = 1560
        self.num_frame_per_block = num_frame_per_block
        self.num_last_frames_with_grad = num_last_frames_with_grad

        # Caches are allocated lazily by the _initialize_*_cache helpers.
        self.kv_cache1 = None
        self.kv_cache2 = None
        self.crossattn_cache = None
        self.kv_cache_size = num_max_frames * self.frame_seq_length


    def generate_and_sync_list(self, num_blocks, num_denoising_steps, device):
        rank = dist.get_rank() if dist.is_initialized() else 0

        if rank == 0:
            # Generate random indices
            indices = torch.randint(
                low=0,
                high=num_denoising_steps,
                size=(num_blocks,),
                device=device
            )
            random_block = torch.randint(1, num_blocks, (1,), device=device)

        else:
            indices = torch.empty(num_blocks, dtype=torch.long, device=device)
            random_block = torch.empty(1, dtype=torch.long, device=device)
        if dist.is_initialized():
            dist.broadcast(indices, src=0)  # Broadcast the random indices to all ranks
            dist.broadcast(random_block, src=0)
        return indices.tolist(), random_block.item()
    
    def decode_latents(self, latents: torch.Tensor, vae_cache = None) -> torch.Tensor:
        frames, vae_cache = self.vae.decode(latents.to(self.vae.dtype), cache=vae_cache, return_dict=False)
        frames = (frames / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
        frames = frames.cpu().float()
        return frames, vae_cache
    
    def compute_log_prob(self, sample, i, j, cache_sample, noise_level: float = 0.7, is_ref=None):
        # noise_pred = self.generator(
        #     hidden_states=sample["latents"][:, j].permute(0,2,1,3,4),
        #     timestep=sample["timesteps"][:, j],
        #     encoder_hidden_states=embeds,
        #     pooled_projections=pooled_embeds,
        #     return_dict=False,
        # ).permute(0,2,1,3,4)
        # print(">>>>>> x_shape:", sample["latents"][:, j].shape)
        # print(">>>>>> prompt_embeds shape:", sample["prompt_embeds"].shape,sample["timesteps"].shape,sample["block_y"].shape,sample["clip_context"].shape)
        if is_ref:
            noise_pred = self.ref_model(
                        x=sample["latents"][:, j].permute(0,2,1,3,4),
                        context=sample["prompt_embeds"],
                        t=sample["timesteps"][:, j],
                        seq_len=cache_sample["seq_len"] * self.num_frame_per_block,
                        y=sample["block_y"].permute(0,2,1,3,4),
                        clip_fea=sample["clip_context"],
                        kv_cache=cache_sample["kv_cache1"],
                        crossattn_cache=cache_sample["crossattn_cache"],
                        current_start=cache_sample["current_start"],
                        current_end=cache_sample["current_end"],
                        shift_cache_length=cache_sample["shift_cache_length"],
                    ).permute(0,2,1,3,4)
        else:
            noise_pred = self.generator(
                        x=sample["latents"][:, j].permute(0,2,1,3,4),
                        context=sample["prompt_embeds"],
                        t=sample["timesteps"][:, j],
                        seq_len=cache_sample["seq_len"] * self.num_frame_per_block,
                        y=sample["block_y"].permute(0,2,1,3,4),
                        clip_fea=sample["clip_context"],
                        kv_cache=cache_sample["kv_cache1"],
                        crossattn_cache=cache_sample["crossattn_cache"],
                        current_start=cache_sample["current_start"],
                        current_end=cache_sample["current_end"],
                        shift_cache_length=cache_sample["shift_cache_length"],
                    ).permute(0,2,1,3,4)
        # compute the log prob of next_latents given latents under the current model
        timestep = sample["timesteps"][:, j]
        next_index = cache_sample["next_timestep_index"]
        if next_index < len(self.denoising_step_list) - 1:
            next_timestep = torch.ones_like(timestep) * self.denoising_step_list[next_index]
            prev_sample, log_prob, prev_sample_mean, std_dev_t, dt = self.sde_step_with_logprob_to_next_step(
                noise_pred.float(), 
                timestep, 
                sample["latents"][:, j].float(),
                next_timestep,
                noise_level=noise_level,
                prev_sample=sample["next_latents"][:, j].float()
            )

        else:
            prev_sample, log_prob, prev_sample_mean, std_dev_t, dt = self.sde_step_with_logprob_to_x0(
                noise_pred.float(), 
                timestep, 
                sample["latents"][:, j].float(),
                prev_sample=sample["next_latents"][:, j].float(),
                noise_level=noise_level,
            )
        
        
        return prev_sample, log_prob, prev_sample_mean, std_dev_t, dt

    def sde_step_with_logprob_to_next_step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        next_timestep: Union[float, torch.FloatTensor],
        noise_level: float = 0.7,
        prev_sample: Optional[torch.FloatTensor] = None,
        generator: Optional[torch.Generator] = None,
    ):
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
        process from the learned model outputs (most often the predicted velocity).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned flow model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
        """
        # bf16 can overflow here when compute prev_sample_mean, we must convert all variable to fp32

        model_output=model_output.float()
        sample=sample.float()
        if prev_sample is not None:
            prev_sample=prev_sample.float()
        device = model_output.device
        self.scheduler.timesteps = self.scheduler.timesteps.to(device)
        timesteps_list = self.scheduler.timesteps
        self.scheduler.sigmas = self.scheduler.sigmas.to(device)
        step_index = [torch.argmin(torch.abs(timesteps_list - t)).item() for t in timestep]
        next_timestep_index = [torch.argmin(torch.abs(timesteps_list - t)).item() for t in next_timestep]
        # prev_step_index = [step+1 for step in step_index]
        sigma = self.scheduler.sigmas[step_index].view(-1, *([1] * (len(sample.shape) - 1)))
        # print("######sigma",sigma, timestep)
        sigma_prev = self.scheduler.sigmas[next_timestep_index].view(-1, *([1] * (len(sample.shape) - 1)))
        sigma_max = self.scheduler.sigmas[1].item()
        sigma_min = self.scheduler.sigmas[-1].item()
        dt = sigma_prev - sigma

        # std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma)))*noise_level
        std_dev_t = sigma_min + (sigma_max - sigma_min) * sigma
        if noise_level == 0:
            # print("noise_level is 0, no noise will be added!")
            std_dev_t = std_dev_t * 0
        # our sde
        prev_sample_mean = sample*(1+std_dev_t**2/(2*sigma)*dt)+model_output*(1+std_dev_t**2*(1-sigma)/(2*sigma))*dt
        
        if prev_sample is None:
            variance_noise = randn_tensor(
                model_output.shape,
                generator=generator,
                device=model_output.device,
                dtype=model_output.dtype,
            )
            prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1*dt) * variance_noise

        log_prob = (
            -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1*dt))**2))
            - torch.log(std_dev_t * torch.sqrt(-1*dt))
            - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
        )

        # mean along all but batch dimension
        log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
        
        return prev_sample, log_prob, prev_sample_mean, std_dev_t, torch.sqrt(-1*dt)

    def sde_step_with_logprob_to_x0(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        noise_level: float = 0.7,
        prev_sample: Optional[torch.FloatTensor] = None,
        generator: Optional[torch.Generator] = None,
    ):
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
        process from the learned model outputs (most often the predicted velocity).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned flow model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
        """
        # bf16 can overflow here when compute prev_sample_mean, we must convert all variable to fp32
        model_output=model_output.float()
        sample=sample.float()
        if prev_sample is not None:
            prev_sample=prev_sample.float()
        device = model_output.device
        self.scheduler.timesteps = self.scheduler.timesteps.to(device)
        timesteps_list = self.scheduler.timesteps
        self.scheduler.sigmas = self.scheduler.sigmas.to(device)
        step_index = [torch.argmin(torch.abs(timesteps_list - t)).item() for t in timestep]
        # prev_step_index = [step+1 for step in step_index]
        sigma = self.scheduler.sigmas[step_index].view(-1, *([1] * (len(sample.shape) - 1)))
        # sigma_prev = self.scheduler.sigmas[prev_step_index].view(-1, *([1] * (len(sample.shape) - 1)))
        sigma_max = self.scheduler.sigmas[1].item()
        sigma_min = self.scheduler.sigmas[-1].item()
        dt = - sigma

        # std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma)))*noise_level
        std_dev_t = sigma_min + (sigma_max - sigma_min) * sigma
        if noise_level == 0:
            std_dev_t = std_dev_t * 0
        
        # our sde
        prev_sample_mean = sample*(1+std_dev_t**2/(2*sigma)*dt)+model_output*(1+std_dev_t**2*(1-sigma)/(2*sigma))*dt
        
        if prev_sample is None:
            variance_noise = randn_tensor(
                model_output.shape,
                generator=generator,
                device=model_output.device,
                dtype=model_output.dtype,
            )
            prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1*dt) * variance_noise

        log_prob = (
            -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1*dt))**2))
            - torch.log(std_dev_t * torch.sqrt(-1*dt))
            - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
        )

        # mean along all but batch dimension
        log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
        
        return prev_sample, log_prob, prev_sample_mean, std_dev_t, torch.sqrt(-1*dt)

    def inference_with_trajectory(
            self,
            noise: torch.Tensor,
            prompt_embeds: torch.Tensor,
            initial_latent: torch.Tensor,  # first frame latent
            inpaint_latents: torch.Tensor,
            clip_context: torch.Tensor,
            output_type: Optional[str] = "pil",
            mini_num_image_per_prompt: int = 2,
            noise_level: float = 0.7,
            train_num_steps: int = 1,
            process_index: int = 0,
            sample_num_steps: int = 10,
            random_timestep: Optional[int] = None,
            random_block: Optional[int] = None,
    ) -> tuple:
        """Sample a video block by block and record one stochastic transition.

        Every block after the first is denoised through
        ``self.denoising_step_list`` with the noise term disabled, except for a
        single (block, step) pair: there the batch is replicated
        ``mini_num_image_per_prompt`` times, noise is enabled, and the latents
        before/after the step, the log-probability and the timestep are
        recorded. From that point on every tensor and both caches keep the
        replicated batch size.

        Args:
            noise: Initial latents; unpacked as
                ``(batch, frames, channels, height, width)``.
            prompt_embeds: Text conditioning passed to the generator as
                ``context``.
            initial_latent: Latent used verbatim for block 0 (no denoising).
            inpaint_latents: Per-frame conditioning, sliced per block and
                passed to the generator as ``y``.
            clip_context: Conditioning passed to the generator as ``clip_fea``.
            output_type: Currently unused.
            mini_num_image_per_prompt: Replication factor applied at the
                recorded step.
            noise_level: Noise level used for the recorded step only.
            train_num_steps: Currently unused.
            process_index: Currently unused.
            sample_num_steps: Currently unused.
            random_timestep: Currently unused.
            random_block: Currently ignored; the block is always drawn with
                ``random.randint`` inside this method.

        Returns:
            A tuple ``(sampling_video, sample, cache_sample)``:
            ``sampling_video`` is the output of ``decode_latents`` for all
            blocks with dims 1 and 2 swapped; ``sample`` holds ``block_y``,
            ``clip_context``, ``prompt_embeds``, ``timesteps``, ``latents``,
            ``next_latents`` and ``log_probs`` for the recorded step;
            ``cache_sample`` holds the cache objects and positions needed to
            re-run that step (see ``compute_log_prob``).

        Raises:
            ValueError: From ``random.randint`` when fewer than two blocks are
                available.
        """
        batch_size, num_frames, num_channels, height, width = noise.shape

        # The generator may be wrapped (e.g. by DDP), in which case its config
        # lives on ``.module``; fall back to the bare model otherwise.
        try:
            patch_size = self.generator.module.config.patch_size
        except AttributeError:
            patch_size = self.generator.config.patch_size
        seq_len = math.ceil((height * width) / (patch_size[1] * patch_size[2]))  # token sequence length per frame
        num_blocks = num_frames // self.num_frame_per_block

        # Step 1: Initialize KV cache to all zeros
        self._initialize_kv_cache(
            batch_size=batch_size, dtype=noise.dtype, device=noise.device
        )
        self._initialize_crossattn_cache(
            batch_size=batch_size, dtype=noise.dtype, device=noise.device
        )

        # Step 2: choose which (block, denoising step) pair gets recorded.
        num_denoising_steps = len(self.denoising_step_list)
        exit_flags, _ = self.generate_and_sync_list(num_frames, num_denoising_steps, device=noise.device)
        shift_cache_length = 0
        # NOTE(review): the block index is drawn locally with Python's RNG, so
        # it replaces the value synchronised by generate_and_sync_list and the
        # ``random_block`` argument. Ranks only agree if their ``random`` seeds
        # match. Confirm whether this is intended.
        random_block = random.randint(1, num_blocks - 1)

        output = []
        all_latents = []
        all_log_probs = []
        all_timesteps = []
        exit_flag = False  # becomes (and stays) True once the batch was replicated

        # Step 3: Temporal denoising loop
        for block_index in range(num_blocks):
            repeat_flag = False
            frame_slice = slice(
                block_index * self.num_frame_per_block,
                (block_index + 1) * self.num_frame_per_block,
            )
            block_latents = noise[:, frame_slice]
            block_y = inpaint_latents[:, frame_slice]

            # If the block would run past the cache capacity, shift the whole
            # cache to the left and drop the oldest (left-most) entries.
            current_start = block_index * self.num_frame_per_block * seq_len - shift_cache_length
            current_end = (block_index + 1) * self.num_frame_per_block * seq_len - shift_cache_length
            block_len = current_end - current_start
            while current_end > self.kv_cache1[0]["k"].shape[1]:
                shift_cache_length += block_len
                current_end -= block_len
                current_start -= block_len
                for layer_kv_cache in self.kv_cache1:
                    # clone() avoids reading and writing overlapping memory.
                    layer_kv_cache["k"][:, :-block_len] = layer_kv_cache["k"][:, block_len:].clone()
                    layer_kv_cache["v"][:, :-block_len] = layer_kv_cache["v"][:, block_len:].clone()
                print(f"Shift cache length: {shift_cache_length}, current_start: {current_start}, current_end: {current_end}")

            if block_index == 0:
                # The first block is given, not generated.
                block_latents = initial_latent
            else:
                # Step 3.1: Spatial denoising loop
                for index, current_timestep in enumerate(self.denoising_step_list):
                    is_recorded_step = (
                        index == exit_flags[block_index] and block_index == random_block
                    )
                    cur_noise_level = 0
                    if is_recorded_step:
                        exit_flag = True
                        # Replicate both caches and the conditioning so every
                        # copy of the batch continues from the same history.
                        self.kv_cache1 = [
                            {
                                "k": cache["k"].repeat(mini_num_image_per_prompt, 1, 1, 1),
                                "v": cache["v"].repeat(mini_num_image_per_prompt, 1, 1, 1),
                            }
                            for cache in self.kv_cache1
                        ]
                        self.crossattn_cache = [
                            {
                                "k": cache["k"].repeat(mini_num_image_per_prompt, 1, 1, 1),
                                "v": cache["v"].repeat(mini_num_image_per_prompt, 1, 1, 1),
                                "is_init": cache["is_init"]
                            }
                            for cache in self.crossattn_cache
                        ]
                        cur_noise_level = noise_level
                        prompt_embeds = prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
                        clip_context = clip_context.repeat(mini_num_image_per_prompt, 1, 1)

                    timestep = torch.ones(
                        [batch_size, self.num_frame_per_block],
                        device=noise.device,
                        dtype=torch.int64) * current_timestep
                    if exit_flag:
                        timestep = timestep.repeat(mini_num_image_per_prompt, 1)

                    # Replicate the block inputs once per block after the
                    # batch has been expanded.
                    if exit_flag and not repeat_flag:
                        block_latents = block_latents.repeat(mini_num_image_per_prompt, 1, 1, 1, 1)
                        block_y = block_y.repeat(mini_num_image_per_prompt, 1, 1, 1, 1)
                        repeat_flag = True
                    if is_recorded_step:
                        all_latents.append(block_latents)  # latent before the step

                    with torch.no_grad():
                        noise_pred = self.generator(
                            x=block_latents.permute(0, 2, 1, 3, 4),
                            context=prompt_embeds,
                            t=timestep,
                            seq_len=seq_len * self.num_frame_per_block,
                            y=block_y.permute(0, 2, 1, 3, 4),
                            clip_fea=clip_context,
                            kv_cache=self.kv_cache1,
                            crossattn_cache=self.crossattn_cache,
                            current_start=current_start,
                            current_end=current_end,
                            shift_cache_length=shift_cache_length,
                        ).permute(0, 2, 1, 3, 4)

                        latents_dtype = block_latents.dtype
                        if index < len(self.denoising_step_list) - 1:
                            next_timestep = torch.ones_like(timestep) * self.denoising_step_list[index + 1]
                            block_latents, log_prob, prev_latents_mean, std_dev_t, dt = self.sde_step_with_logprob_to_next_step(
                                noise_pred.float(),
                                timestep,
                                block_latents.float(),
                                next_timestep,
                                noise_level=cur_noise_level,
                            )
                        else:
                            block_latents, log_prob, prev_latents_mean, std_dev_t, dt = self.sde_step_with_logprob_to_x0(
                                noise_pred.float(),
                                timestep,
                                block_latents.float(),
                                noise_level=cur_noise_level,
                            )
                        # The SDE step runs in fp32; cast back to the working dtype.
                        if block_latents.dtype != latents_dtype:
                            block_latents = block_latents.to(latents_dtype)

                        if is_recorded_step:
                            all_latents.append(block_latents)  # latent after the step
                            all_log_probs.append(log_prob)
                            all_timesteps.append(timestep)
                            sample = {
                                "block_y": block_y,
                                "clip_context": clip_context,
                                "prompt_embeds": prompt_embeds
                            }
                            cache_sample = {
                                "random_block": random_block,
                                "seq_len": seq_len,
                                "kv_cache1": self.kv_cache1,
                                "crossattn_cache": self.crossattn_cache,
                                "current_start": current_start,
                                "current_end": current_end,
                                "shift_cache_length": shift_cache_length,
                                "next_timestep_index": index + 1
                            }

            # Blocks finished before the batch was expanded are replicated so
            # all blocks can be concatenated along the frame dimension.
            if block_index < random_block:
                output.append(block_latents.repeat(mini_num_image_per_prompt, 1, 1, 1, 1))
            else:
                output.append(block_latents)

            # Step 3.3: rerun with timestep zero to update the cache
            timestep = torch.zeros([batch_size, self.num_frame_per_block],
                                   device=noise.device, dtype=torch.int64)
            if exit_flag:
                timestep = timestep.repeat(mini_num_image_per_prompt, 1)

            with torch.no_grad():
                self.generator(
                    x=block_latents.permute(0, 2, 1, 3, 4),
                    context=prompt_embeds,
                    t=timestep,
                    seq_len=seq_len * self.num_frame_per_block,
                    y=block_y.permute(0, 2, 1, 3, 4),
                    clip_fea=clip_context,
                    kv_cache=self.kv_cache1,
                    crossattn_cache=self.crossattn_cache,
                    current_start=current_start,
                    current_end=current_end,
                    shift_cache_length=shift_cache_length,
                )

        # Step 4: decode the full video and assemble the training sample.
        output = torch.cat(output, dim=1).to(noise.dtype)
        sampling_video = self.decode_latents(output.permute(0, 2, 1, 3, 4))[0].permute(0, 2, 1, 3, 4)

        # Stack the "before" and "after" latents of the recorded step along a
        # new dimension 1.
        latents = torch.stack(all_latents, dim=1)
        log_probs = torch.stack(all_log_probs, dim=1)
        timesteps = torch.stack(all_timesteps, dim=1)[:, 0]
        sample.update({
            "timesteps": timesteps,
            "latents": latents[:, :-1],  # each entry is the latent before timestep t
            "next_latents": latents[:, 1:],  # each entry is the latent after timestep t
            "log_probs": log_probs,
        })

        # Check random numbers.
        # NOTE(review): debugging aid that draws from the global RNG and prints
        # on every call; kept so RNG consumption and log output are unchanged.
        rng_probe = randn_tensor(
            (2, 3),
            generator=None,
            device=noise_pred.device,
            dtype=noise_pred.dtype,
        )
        print("#######", rng_probe, "#####", random_block)

        return sampling_video, sample, cache_sample

    def _initialize_kv_cache(self, batch_size, dtype, device):
        """
        Initialize a Per-GPU KV cache for the Wan model.
        """
        kv_cache1 = []

        for _ in range(self.num_transformer_blocks):
            kv_cache1.append({
                "k": torch.zeros([batch_size, self.kv_cache_size, 12, 128], dtype=dtype, device=device),
                "v": torch.zeros([batch_size, self.kv_cache_size, 12, 128], dtype=dtype, device=device),
                # "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                # "local_end_index": torch.tensor([0], dtype=torch.long, device=device)
            })

        self.kv_cache1 = kv_cache1  # always store the clean cache

    def _initialize_crossattn_cache(self, batch_size, dtype, device):
        """
        Initialize a Per-GPU cross-attention cache for the Wan model.
        """
        crossattn_cache = []

        for _ in range(self.num_transformer_blocks):
            crossattn_cache.append({
                "k": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
                "v": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
                "is_init": False
            })
        self.crossattn_cache = crossattn_cache

    def _convert_flow_pred_to_x0(self, flow_pred: torch.Tensor, xt: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        """
        Convert flow matching's prediction to x0 prediction.
        flow_pred: the prediction with shape [B, C, F, H, W]
        xt: the input noisy data with shape [B, C, F, H, W]
        timestep: the timestep with shape [B, F]

        pred = noise - x0
        x_t = (1-sigma_t) * x0 + sigma_t * noise
        we have x0 = x_t - sigma_t * pred
        see derivations https://chatgpt.com/share/67bf8589-3d04-8008-bc6e-4cf1a24e2d0e
        """
        # use higher precision for calculations
        original_dtype = flow_pred.dtype
        flow_pred, xt, sigmas, timesteps = map(
            lambda x: x.double().to(flow_pred.device), [flow_pred, xt,
                                                        self.scheduler.sigmas,
                                                        self.scheduler.timesteps]
        )

        batch_size, num_frames = timestep.shape
        timestep = timestep.flatten()
        timestep_id = torch.argmin(
            (timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
        sigma_t = sigmas[timestep_id].reshape(batch_size, 1, num_frames, 1, 1)
        x0_pred = xt - sigma_t * flow_pred
        return x0_pred.to(original_dtype)

    def _convert_flow_pred_to_next_step(self, flow_pred, xt, timestep, next_timestep):
        original_dtype = flow_pred.dtype
        flow_pred, xt, sigmas, timesteps = map(
            lambda x: x.double().to(flow_pred.device), [flow_pred, xt,
                                                        self.scheduler.sigmas,
                                                        self.scheduler.timesteps]
        )

        batch_size, num_frames = timestep.shape
        timestep = timestep.flatten()
        timestep_id = torch.argmin(
            (timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
        sigma_t = sigmas[timestep_id].reshape(batch_size, 1, num_frames, 1, 1)

        next_timestep = next_timestep.flatten()
        next_timestep_id = torch.argmin(
            (timesteps.unsqueeze(0) - next_timestep.unsqueeze(1)).abs(), dim=1)
        sigma_next = sigmas[next_timestep_id].reshape(batch_size, 1, num_frames, 1, 1)

        prev_sample = xt + (sigma_next - sigma_t) * flow_pred
        return prev_sample.to(original_dtype)