
from typing import Tuple
from einops import rearrange
from torch import nn
import torch.distributed as dist
import torch

from utils.loss import get_denoising_loss
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper, WanCLIPEncoder

import os 

from pipeline import BidirectionalTrainingPipeline

import torch.nn.functional as F
from typing import Optional, Tuple
import torch

from torchvision.io import write_video
from einops import rearrange
import os
import math

from utils.rm_video_align_wrapper import VideoAlignWrapper

def all_gather_same_group(tensor, group_size, subgroups):
    """Gather ``tensor`` from every rank of the caller's subgroup and stack the results.

    The subgroup is looked up as ``subgroups[rank // group_size]``, where
    ``rank`` is the value returned by ``dist.get_rank()``.

    Args:
        tensor: Local tensor contributed by this rank. Receive buffers are
            created with ``torch.empty_like(tensor)``.
        group_size: Number of receive buffers to allocate; also the divisor
            used to map a rank to its subgroup index.
        subgroups: Indexable collection of pre-built process groups.

    Returns:
        A tensor of shape ``[group_size, *tensor.shape]`` holding the gathered
        tensors in the order produced by ``dist.all_gather``.
    """
    # Pick this rank's subgroup from the cached list.
    subgroup = subgroups[dist.get_rank() // group_size]

    receive_buffers = []
    for _ in range(group_size):
        receive_buffers.append(torch.empty_like(tensor))

    dist.all_gather(receive_buffers, tensor, group=subgroup)
    return torch.stack(receive_buffers)

class T2V_DMD_GRPO(nn.Module):
    def __init__(self, args, device):
        """Build the models, output directories and DMD hyperparameters.

        Args:
            args: Configuration object. This method reads ``logdir``,
                ``mixed_precision``, ``denoising_step_list``,
                ``num_train_timestep``, ``guidance_scale`` and, optionally,
                ``timestep_shift`` and ``min_score_timestep``.
                ``_initialize_models`` reads further fields.
            device: Device used for sampling; also forwarded to
                ``_initialize_models``.
        """
        super().__init__()

        # Loads all sub-models and sets ``self.scheduler`` (used below).
        self._initialize_models(args, device)

        self.global_rank = dist.get_rank()
        self.is_main_process = self.global_rank == 0

        # Training step counter; read when naming saved videos and when
        # deciding whether the GRPO term is active.
        self.step = 0

        log_root = args.logdir
        self.gen_train_videos_dir = os.path.join(log_root, "gen_train_videos")
        self.gen_eval_videos_dir = os.path.join(log_root, "gen_eval_videos")
        self.gen_log_videos_dir = os.path.join(log_root, "gen_log_videos")
        if self.is_main_process:
            # Only rank 0 creates the output directories.
            for directory in (
                self.gen_train_videos_dir,
                self.gen_eval_videos_dir,
                self.gen_log_videos_dir,
            ):
                os.makedirs(directory, exist_ok=True)

        self.device = device
        self.args = args
        self.dtype = torch.bfloat16 if args.mixed_precision else torch.float32

        # Configured step values, e.g. [1000, 750, 500, 250].
        requested_steps = torch.tensor(args.denoising_step_list, dtype=torch.long)
        # Scheduler timesteps with a trailing 0 appended so index 1000 is valid.
        schedule = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
        print(f"------- raw self.denoising_step_list {requested_steps}")
        # Look the steps up at indices ``1000 - step`` (e.g. [0, 250, 500, 750]).
        self.denoising_step_list = schedule[1000 - requested_steps]
        print(f"------- self.denoising_step_list {self.denoising_step_list}")

        # Created lazily (see ``_initialize_inference_pipeline``) once the
        # modules have been FSDP-wrapped.
        self.inference_pipeline: BidirectionalTrainingPipeline = None

        # DMD hyperparameters.
        self.num_train_timestep = args.num_train_timestep  # e.g. 1000
        self.min_step = int(0.02 * self.num_train_timestep)  # e.g. 20
        self.max_step = int(0.98 * self.num_train_timestep)  # e.g. 980

        self.real_guidance_scale = args.guidance_scale  # e.g. 6.0

        self.timestep_shift = getattr(args, "timestep_shift", 5.0)
        self.min_score_timestep = getattr(args, "min_score_timestep", 0)


    def _initialize_models(self, args, device):
        """Load the generator, score networks, VAE, reward model and scheduler.

        Sets ``self.generator`` (trainable), ``self.real_score`` (frozen),
        ``self.fake_score`` (trainable), ``self.vae`` (frozen, bfloat16, eval
        mode), ``self.reward_model`` (frozen) and ``self.scheduler``.

        Args:
            args: Configuration object. Requires ``real_path``, ``fake_path``,
                ``generator_path`` and ``video_align_path``; optionally reads
                ``real_name``, ``fake_name``, ``generator_name`` and
                ``model_kwargs``.
            device: Device handed to the reward model wrapper and used for the
                scheduler timesteps.

        Raises:
            ValueError: If ``args.video_align_path`` is missing or empty.
            AttributeError: If one of the required model paths is missing.
        """
        # The original assignment had no right-hand side (a syntax error).
        # The checkpoint location is now read from the config and validated
        # up front so a missing setting fails before the large models load.
        video_align_path = getattr(args, "video_align_path", None)
        if not video_align_path:
            raise ValueError(
                "args.video_align_path must point to the VideoAlign reward model checkpoint"
            )

        self.real_model_name = getattr(args, "real_name", "Wan2.1-I2V-14B-720P")
        self.fake_model_name = getattr(args, "fake_name", "Wan2.1-I2V-14B-720P")
        self.generator_name = getattr(args, "generator_name", "Wan2.1-I2V-14B-720P")

        # Required settings: a missing attribute raises AttributeError.
        self.real_model_path = args.real_path
        self.fake_model_path = args.fake_path
        self.generator_path = args.generator_path

        print("--------- loading generator model")
        self.generator = WanDiffusionWrapper(
            **getattr(args, "model_kwargs", {}),  # e.g. timestep_shift: 5.0
            model_name=self.generator_name, model_path=self.generator_path, wan_type="gen"
        )
        self.generator.model.requires_grad_(True)

        print("--------- loading real_score model")
        self.real_score = WanDiffusionWrapper(
            model_name=self.real_model_name, model_path=self.real_model_path, wan_type="real"
        )
        self.real_score.model.requires_grad_(False)

        print("--------- loading fake_score model")
        self.fake_score = WanDiffusionWrapper(
            model_name=self.fake_model_name, model_path=self.fake_model_path, wan_type="fake"
        )
        self.fake_score.model.requires_grad_(True)

        print("loading WanVAEWrapper ...")
        self.vae = WanVAEWrapper(model_name=self.generator_name).eval()
        self.vae.requires_grad_(False)
        self.vae = self.vae.to(dtype=torch.bfloat16)

        # Only the two trainable networks use gradient checkpointing.
        self.generator.enable_gradient_checkpointing()
        self.fake_score.enable_gradient_checkpointing()

        print("--------- loading RewardModel ...")
        self.reward_model = VideoAlignWrapper(video_align_path, device)
        self.reward_model.inferencer.model.requires_grad_(False)

        self.scheduler = self.generator.get_scheduler()
        self.scheduler.timesteps = self.scheduler.timesteps.to(device)
        self.scheduler.alphas_cumprod = None

    def _get_timestep(
            self,
            min_timestep: int, max_timestep: int,
            batch_size: int, num_frame: int
    ) -> torch.Tensor:
        """
        Randomly generate a timestep tensor based on the generator's task type. It uniformly samples a timestep
        from the range [min_timestep, max_timestep], and returns a tensor of shape [batch_size, num_frame].
        - uniform_timestep, it will use the same timestep for all frames.
        """
        timestep = torch.randint(
            min_timestep, max_timestep,
            [batch_size, 1],
            device=self.device, dtype=torch.long
        ).repeat(1, num_frame)

        if self.timestep_shift > 1:
            timestep = self.timestep_shift * (timestep / 1000) / (1 + (self.timestep_shift - 1) * (timestep / 1000)) * 1000

        return timestep

    def _compute_dmd_grad(
        self, 
        noisy_image_or_video: torch.Tensor, estimated_clean_image_or_video: torch.Tensor,
        timestep: torch.Tensor,
        conditional_dict: dict, unconditional_dict: dict,
        normalization: bool = True,
    ) -> Tuple[torch.Tensor, dict]:
        # Step 1: Compute the fake score
        _, pred_fake_image_cond = self.fake_score(
            noisy_image_or_video=noisy_image_or_video,
            conditional_dict=conditional_dict,
            timestep=timestep
        )

        pred_fake_image = pred_fake_image_cond
        # self.save_latent_as_video(pred_fake_image, f"gen_fake_pred_{int(timestep[0, 0].item())}")

        # Step 2: Compute the real score
        _, pred_real_image_cond = self.real_score(
            noisy_image_or_video=noisy_image_or_video,
            conditional_dict=conditional_dict,
            timestep=timestep,
        )

        _, pred_real_image_uncond = self.real_score(
            noisy_image_or_video=noisy_image_or_video,
            conditional_dict=unconditional_dict,
            timestep=timestep,
        )

        pred_real_image = pred_real_image_cond + (
            pred_real_image_cond - pred_real_image_uncond
        ) * self.real_guidance_scale
        # self.save_latent_as_video(pred_real_image, f"gen_real_pred_{int(timestep[0, 0].item())}")

        # Step 3: Compute the DMD gradient (DMD paper eq. 7).
        grad = (pred_fake_image - pred_real_image)
        # self.save_latent_as_video(grad, f"dmd_grad")

        # TODO: Change the normalizer for causal teacher
        if normalization:
            # Step 4: Gradient normalization (DMD paper eq. 8).
            p_real = (estimated_clean_image_or_video - pred_real_image)
            normalizer = torch.abs(p_real).mean(dim=[1, 2, 3, 4], keepdim=True)
            print(f"dmd normalizer = {normalizer}")
            grad = grad / normalizer
    
        if torch.isnan(grad).any() or torch.isinf(grad).any():
            grad = torch.nan_to_num(grad)
            print("WARNING: NaN/Inf in grad, nan_to_num applied")

        # self.save_latent_as_video(grad, f"dmd_grad_norm")
        
        dmd_log_dict = {
            "dmd_gradient_norm": torch.mean(torch.abs(grad)).detach(),
            "timestep": int(timestep[0, 0].item())
        }
        return grad, dmd_log_dict

    def _compute_grpo_grad(
        self, image_or_video, reward_log_dict,
        conditional_dict: dict, unconditional_dict: dict, denoised_timestep_to,
        max_grad_norm = 0.8,
        normalization: bool = True,
    ):
        """Compute a reward-weighted gradient from group statistics.

        Rewards and latents are gathered across the caller's subgroup. The
        local advantage is the standardized reward of this rank, and the
        direction is the fake-score prediction for the mean of the top-4
        reward latents minus the prediction for this rank's own latents.

        ``self.group_size`` and ``self.subgroups`` are read here but are not
        assigned anywhere in this class; they must be set on the instance
        before this method is called.

        Args:
            image_or_video: Latents of this rank's sample (detached here).
            reward_log_dict: Must contain ``'avg_reward'``.
            conditional_dict: Conditioning passed to ``self.fake_score``.
            unconditional_dict: Accepted for signature parity; unused.
            denoised_timestep_to: Accepted for signature parity; unused.
            max_grad_norm: Upper bound on the mean absolute value of the
                returned gradient.
            normalization: Whether to divide by the per-sample mean absolute
                difference between the latents and their prediction.

        Returns:
            ``(grad, grpo_log_dict, mean_top_reward)`` where
            ``mean_top_reward`` is the mean reward of the top-k group members.
        """
        video_latents = image_or_video.detach()
        batch_size, num_frame = video_latents.shape[:2]

        reward = reward_log_dict['avg_reward']
        reward_tensor = torch.tensor([reward], dtype=torch.float32, device=self.device)
        # Both gathers are collectives: every rank in the subgroup must issue
        # them in this order.
        all_rewards = all_gather_same_group(
            reward_tensor, self.group_size, self.subgroups
        ).reshape(-1)  # [group_size]
        all_video_latents = all_gather_same_group(
            video_latents, self.group_size, self.subgroups
        )  # [group_size, *video_latents.shape]

        # Within-group advantages. A single-member group has no spread
        # (``std`` would be NaN), so its advantage is defined as zero.
        group_mean = all_rewards.mean()
        if all_rewards.numel() > 1:
            group_std = all_rewards.std() + 1e-8
        else:
            group_std = torch.full_like(group_mean, 1e-8)
        # Kept 1-D on purpose: an extra squeeze would collapse a one-member
        # group to 0-dim and make the indexing below fail.
        advantages = (all_rewards - group_mean) / group_std

        # Index of this rank inside its group, and its advantage.
        local_rank_in_group = self.global_rank % self.group_size
        local_advantage = advantages[local_rank_in_group]

        # Mean latent and mean reward of the (up to) four best samples.
        _, top4_idx = torch.topk(all_rewards, k=min(4, all_rewards.size(0)))
        mean_latent = all_video_latents[top4_idx].mean(dim=0)  # same shape as video_latents
        mean_adv = all_rewards[top4_idx].mean(dim=0)           # scalar: mean top-k reward

        # NOTE(review): unlike the DMD path, this timestep is not clamped to
        # [min_step, max_step] - confirm that is intended.
        timestep = self._get_timestep(
            self.min_score_timestep, self.num_train_timestep,
            batch_size, num_frame
        )
        # The same noise and timestep are used for both inputs.
        noise = torch.randn_like(video_latents)
        flat_noise = noise.flatten(0, 1)
        flat_timestep = timestep.flatten(0, 1)
        noisy_video_latents = self.scheduler.add_noise(
            video_latents.flatten(0, 1), flat_noise, flat_timestep
        ).detach().unflatten(0, (batch_size, num_frame))
        noisy_mean_latent = self.scheduler.add_noise(
            mean_latent.flatten(0, 1), flat_noise, flat_timestep
        ).detach().unflatten(0, (batch_size, num_frame))

        with torch.no_grad():
            _, pred_gen_image = self.fake_score(
                noisy_image_or_video=noisy_video_latents,
                conditional_dict=conditional_dict,
                timestep=timestep
            )

            _, pred_mean_image = self.fake_score(
                noisy_image_or_video=noisy_mean_latent,
                conditional_dict=conditional_dict,
                timestep=timestep
            )

        grad = local_advantage * (pred_mean_image - pred_gen_image)

        if normalization:
            p_real = (video_latents - pred_gen_image)
            normalizer = torch.abs(p_real).mean(dim=[1, 2, 3, 4], keepdim=True)
            print(f"grpo normalizer = {normalizer}")
            grad = grad / normalizer

        # The normalizer can be zero, which would turn the whole update into
        # NaN after clipping. Non-finite entries are zeroed so they cannot
        # dominate the norm either.
        if torch.isnan(grad).any() or torch.isinf(grad).any():
            grad = torch.nan_to_num(grad, nan=0.0, posinf=0.0, neginf=0.0)
            print("WARNING: NaN/Inf in grpo grad, nan_to_num applied")

        # Rescale so the mean absolute gradient does not exceed max_grad_norm.
        raw_grad_norm = torch.mean(torch.abs(grad)).detach()
        scale = torch.clamp(max_grad_norm / (raw_grad_norm + 1e-8), max=1.0)
        grad = grad * scale

        grpo_log_dict = {
            "raw_grpo_grad_norm": raw_grad_norm,
            "grpo_grad_norm": torch.mean(torch.abs(grad)).detach(),
            "grpo_timestep": int(timestep[0, 0].item()),
            "local_advantage": local_advantage,
            "top4_adv": mean_adv
        }
        return grad, grpo_log_dict, mean_adv



    def compute_distribution_matching_loss(
        self, 
        reward_log_dict,
        image_or_video: torch.Tensor, image_or_video_end,
        conditional_dict: dict, unconditional_dict: dict,
        denoised_timestep_from: int = 0, denoised_timestep_to: int = 0,
    ) -> Tuple[torch.Tensor, dict]:

        original_latent = image_or_video
        batch_size, num_frame = image_or_video.shape[:2]

        with torch.no_grad():
            # Step 1: Randomly sample timestep based on the given schedule and corresponding noise
            timestep = self._get_timestep(
                self.min_score_timestep, self.num_train_timestep,
                batch_size, num_frame,
            )
            timestep = timestep.clamp(self.min_step, self.max_step)
            # print(f"------------[generator_loss] dmd sampled timestep = {timestep[0, 0].item()}")

            noise = torch.randn_like(image_or_video)
            noisy_latent = self.scheduler.add_noise(
                image_or_video.flatten(0, 1),
                noise.flatten(0, 1), timestep.flatten(0, 1)
            ).detach().unflatten(0, (batch_size, num_frame))

            # Step 2: Compute the dmd grad
            dmd_grad, dmd_log_dict = self._compute_dmd_grad(
                noisy_image_or_video=noisy_latent, estimated_clean_image_or_video=original_latent,
                timestep=timestep,
                conditional_dict=conditional_dict, unconditional_dict=unconditional_dict
            )
            
            if self.step >= 80 and self.step <= 600:
                grpo_grad, grpo_log_dict, group_mean = self._compute_grpo_grad(
                    image_or_video=image_or_video_end, reward_log_dict=reward_log_dict,
                    conditional_dict=conditional_dict, unconditional_dict=unconditional_dict, denoised_timestep_to=denoised_timestep_to
                )
                dmd_log_dict.update(grpo_log_dict)
            else:
                grpo_grad = 0
                group_mean = 0

        def compute_t(step, group_mean):
            if step < 80:
                return 1.0  
            elif step > 600:
                return 1.0 
            else:
                return 0.5

        # 使用方式
        t = compute_t(self.step, group_mean)
        grad = (1-t) * grpo_grad + t * dmd_grad
        print(f"log_dict = {dmd_log_dict}")
        dmd_loss = 0.5 * F.mse_loss(original_latent.double(), (original_latent.double() - grad.double()).detach(), reduction="mean")
        return dmd_loss, dmd_log_dict

    def generator_loss(
        self, prompt,
        image_or_video_shape,
        conditional_dict: dict, unconditional_dict: dict,
    ) -> Tuple[torch.Tensor, dict]:
        """Run the generator, score its output and return the generator loss.

        Args:
            prompt: Prompt(s) forwarded to ``save_latent_and_reward``.
            image_or_video_shape: Shape of the noise fed to the generator.
            conditional_dict: Conditional inputs.
            unconditional_dict: Unconditional inputs.

        Returns:
            ``(loss, log_dict)``; the log merges the DMD/GRPO statistics with
            the reward values.
        """
        # Step 1: unroll the generator to obtain the predicted latents.
        rollout = self._run_generator(
            image_or_video_shape=image_or_video_shape, conditional_dict=conditional_dict,
        )
        pred_image, timestep_from, timestep_to = rollout

        # Step 1.1: finish the remaining denoising steps on a detached copy,
        # save the resulting video and score it.
        pred_image_end = self.get_gen_end_image(pred_image.detach(), conditional_dict, timestep_to)
        video_tag = f"gen_gen_pred_{timestep_from}_{timestep_to}"
        reward_log_dict = self.save_latent_and_reward(pred_image_end, video_tag, prompt)

        # Step 2: distribution-matching loss.
        dmd_loss, log_dict = self.compute_distribution_matching_loss(
            reward_log_dict,
            image_or_video=pred_image, image_or_video_end=pred_image_end,
            conditional_dict=conditional_dict, unconditional_dict=unconditional_dict,
            denoised_timestep_from=timestep_from, denoised_timestep_to=timestep_to
        )

        log_dict.update(reward_log_dict)
        return dmd_loss, log_dict

    def critic_loss(
        self, prompt,
        image_or_video_shape,
        conditional_dict: dict, unconditional_dict: dict,
    ) -> Tuple[torch.Tensor, dict]:
        """Denoising loss that trains the fake score on generator samples.

        The generator is run without gradients, its output is noised at a
        random timestep, and the fake score's flow prediction is regressed
        onto ``noise - generated_sample`` with a mean squared error.

        Args:
            prompt: Not used by this method.
            image_or_video_shape: Noise shape; its first two entries are used
                as the batch and frame sizes.
            conditional_dict: Conditional inputs for generator and critic.
            unconditional_dict: Not used by this method.

        Returns:
            ``(loss, log_dict)`` with ``critic_timestep`` and
            ``critic_sft_gradient_norm`` in the log.
        """
        # Step 1: sample from the generator without building a graph.
        with torch.no_grad():
            generated_image, _, _ = self._run_generator(
                image_or_video_shape=image_or_video_shape, conditional_dict=conditional_dict
            )

        # Step 2: noise the sample at a random, clamped timestep.
        batch_size, num_frame = image_or_video_shape[0], image_or_video_shape[1]
        critic_timestep = self._get_timestep(
            self.min_score_timestep, self.num_train_timestep,
            batch_size, num_frame
        ).clamp(self.min_step, self.max_step)

        critic_noise = torch.randn_like(generated_image)
        noisy_generated_image = self.scheduler.add_noise(
            generated_image.flatten(0, 1), critic_noise.flatten(0, 1),
            critic_timestep.flatten(0, 1)
        ).unflatten(0, image_or_video_shape[:2])

        pred_flow, _ = self.fake_score(
            noisy_image_or_video=noisy_generated_image, conditional_dict=conditional_dict,
            timestep=critic_timestep
        )

        # Step 3: regression of the predicted flow onto (noise - sample).
        flow_residual = pred_flow - (critic_noise - generated_image)
        denoising_loss = torch.mean(flow_residual ** 2)

        # Step 4: debugging log.
        critic_log_dict = {
            "critic_timestep": critic_timestep.detach(),
            "critic_sft_gradient_norm": flow_residual.abs().mean().detach(),
        }

        # Replace a non-finite loss instead of propagating it.
        if not torch.isfinite(denoising_loss).all():
            denoising_loss = torch.nan_to_num(denoising_loss)
            print("WARNING: NaN/Inf in denoising_loss, nan_to_num applied")

        return denoising_loss, critic_log_dict

    # Disabled debug helper, kept for reference.
    # def save_latent_as_video(self, video_latents, task_type, prompt=None, save_dir=None):
    #     if self.step % 5 == 4 and self.global_rank < 8:
    #     # if self.global_rank < 4:
    #         if save_dir is None:
    #             save_dir = self.gen_log_videos_dir

    #         # 1. Decode the latents to pixel-space video
    #         video_latents = video_latents.detach()
    #         self.vae.to(device=self.device, dtype=self.dtype)
    #         videos = self.vae.decode_to_pixel(video_latents)
    #         videos = (videos * 0.5 + 0.5).clamp(0, 1)
    #         videos = rearrange(videos, 'b t c h w -> b t h w c').cpu() * 255.0
    #         self.vae.model.clear_cache()
    #         self.vae.to(device='cpu')

    #         # 2. Save the mp4
    #         base_name = f'step{self.step}-{task_type}-rank{self.global_rank}'
    #         video_path = os.path.join(save_dir, f'{base_name}.mp4')
    #         write_video(video_path, videos[0], fps=16)

    #         # 3. Save the prompt to a txt file with the same base name
    #         if prompt is not None:
    #             txt_path = os.path.join(save_dir, f'{base_name}.txt')
    #             with open(txt_path, 'w', encoding='utf-8') as f:
    #                 f.write(prompt[0])

            # torch.cuda.empty_cache()

    def save_latent_and_reward(self, video_latents, task_type, prompt):
        """Decode latents to a video file, score it and store the prompt next to it.

        The first video of the batch is written to
        ``<gen_train_videos_dir>/step{step}-{task_type}-rank{rank}.mp4`` and
        ``prompt[0]`` to a ``.txt`` file with the same stem. The reward model
        is called on the saved file.

        Args:
            video_latents: Latents to decode (detached before decoding).
            task_type: Tag embedded in the output file name.
            prompt: Prompt(s); passed whole to the reward model, and element 0
                is written to disk.

        Returns:
            Dict with ``avg_reward`` (0.1 * MQ + 0.1 * VQ + 0.8 * TA) and the
            individual ``mq_reward``, ``vq_reward`` and ``ta_reward`` values.
        """
        # 1. Decode: move the VAE to the compute device, then back to the CPU.
        self.vae.to(device=self.device, dtype=self.dtype)
        pixels = self.vae.decode_to_pixel(video_latents.detach())
        pixels = (pixels * 0.5 + 0.5).clamp(0, 1)
        frames = rearrange(pixels, 'b t c h w -> b t h w c').cpu() * 255.0
        self.vae.model.clear_cache()
        self.vae.to(device='cpu')

        # 2. Save the mp4 (first batch element only).
        stem = os.path.join(
            self.gen_train_videos_dir,
            f'step{self.step}-{task_type}-rank{self.global_rank}',
        )
        video_path = stem + '.mp4'
        write_video(video_path, frames[0], fps=16)
        torch.cuda.empty_cache()

        # 3. Score the saved video.
        scores = self.reward_model.inferencer.reward([video_path], prompt, use_norm=False)[0]
        avg_reward = 0.1 * scores['MQ'] + 0.1 * scores["VQ"] + 0.8 * scores["TA"]
        reward_log_dict = {
            "avg_reward": avg_reward,
            "mq_reward": scores['MQ'],
            "vq_reward": scores['VQ'],
            "ta_reward": scores['TA'],
        }

        # 4. Save the prompt under the same stem.
        with open(stem + '.txt', 'w', encoding='utf-8') as prompt_file:
            prompt_file.write(prompt[0])

        return reward_log_dict

    def _run_generator(
        self,
        image_or_video_shape, conditional_dict: dict
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        noise_shape = image_or_video_shape  # 1, 21, 16, tgt_h // 8, tgt_w // 8
        # _consistency_backward_simulation, run random times (1-4) add_noise-de_noise to get a pred_image, only con_forward
        pred_image_or_video, denoised_timestep_from, denoised_timestep_to = self._consistency_backward_simulation(  
            noise=torch.randn(*noise_shape, device=self.device, dtype=self.dtype),
            **conditional_dict
        )
        pred_image_or_video = pred_image_or_video.to(self.dtype)
        
        return pred_image_or_video, denoised_timestep_from, denoised_timestep_to

    def _consistency_backward_simulation(
        self,
        noise: torch.Tensor,
        **conditional_dict: dict
    ) -> torch.Tensor:
        if self.inference_pipeline is None:
            self._initialize_inference_pipeline()

        return self.inference_pipeline.inference_with_trajectory(
            noise=noise, clip_fea=None, y=None, **conditional_dict
        )

    def _initialize_inference_pipeline(self):
        """Create ``self.inference_pipeline`` from the already-built components.

        Called lazily on the first backward simulation so the pipeline
        receives the generator as it exists at that time (per the note in
        ``__init__``, after FSDP wrapping) rather than a second copy.
        """
        pipeline_kwargs = dict(
            model_name=self.generator_name,
            denoising_step_list=self.denoising_step_list,
            scheduler=self.scheduler,
            generator=self.generator,
        )
        self.inference_pipeline = BidirectionalTrainingPipeline(**pipeline_kwargs)

    def gen_video(
        self, idx,
        prompt, image_or_video_shape,
        conditional_dict, unconditional_dict
    ):
        """Sample a video with the full few-step schedule and save it as mp4.

        Starting from Gaussian noise, the generator is applied at every entry
        of ``self.denoising_step_list``; between entries the prediction is
        re-noised to the next timestep. The decoded first batch element is
        written to
        ``<gen_eval_videos_dir>/step{step}-rank{rank}-{idx}.mp4``.

        Args:
            idx: Index embedded in the output file name.
            prompt: Not used by this method.
            image_or_video_shape: Shape of the initial noise.
            conditional_dict: Conditional inputs for the generator.
            unconditional_dict: Not used by this method.

        Returns:
            Path of the written video file.
        """
        steps = self.denoising_step_list

        with torch.no_grad():
            latents = torch.randn(*image_or_video_shape, device=self.device, dtype=self.dtype)
            batch_frame = latents.shape[:2]
            target_device = latents.device

            def timestep_grid(value):
                # One timestep per (batch, frame) entry.
                return torch.ones(batch_frame, dtype=torch.long, device=target_device) * value

            for current_timestep, following_timestep in zip(steps[:-1], steps[1:]):
                _, prediction = self.generator(
                    noisy_image_or_video=latents, conditional_dict=conditional_dict,
                    timestep=timestep_grid(current_timestep),
                )  # [B, F, C, H, W]

                # Re-noise the prediction to the next timestep of the schedule.
                next_timestep = timestep_grid(following_timestep)
                flat_prediction = prediction.flatten(0, 1)
                latents = self.scheduler.add_noise(
                    flat_prediction, torch.randn_like(flat_prediction),
                    next_timestep.flatten(0, 1)
                ).unflatten(0, batch_frame)

            # Final pass at the last timestep of the schedule.
            _, prediction = self.generator(
                noisy_image_or_video=latents, conditional_dict=conditional_dict,
                timestep=timestep_grid(steps[-1]),
            )  # [B, F, C, H, W]

        # Decode on the compute device, then move the VAE back to the CPU.
        self.vae.to(device=self.device, dtype=self.dtype)
        pixels = self.vae.decode_to_pixel(prediction)
        pixels = (pixels * 0.5 + 0.5).clamp(0, 1)
        frames = rearrange(pixels, 'b t c h w -> b t h w c').cpu() * 255.0
        self.vae.model.clear_cache()
        self.vae.to(device='cpu')

        file_name = f'step{self.step}-rank{self.global_rank}-{idx}' + '.mp4'
        video_path = os.path.join(self.gen_eval_videos_dir, file_name)
        write_video(video_path, frames[0], fps=16)
        torch.cuda.empty_cache()

        return video_path


    def get_gen_end_image(self, pred_image_or_video_raw, conditional_dict, denoised_timestep_to):
        left_step = denoised_timestep_to // 250
        print(f"denoised_timestep_to = {denoised_timestep_to}, so left_step = {left_step}")

        pred_image_or_video = pred_image_or_video_raw.clone()
        noise = torch.randn(*(pred_image_or_video.shape), device=self.device, dtype=self.dtype)  # 1, 31, 16, tgt_h // 8, tgt_w // 8
        
        for current_timestep in self.denoising_step_list[-left_step:]:
            with torch.no_grad():
                cur_timestep = current_timestep * torch.ones(noise.shape[:2], dtype=torch.long, device=noise.device)

                noisy_image_or_video = self.scheduler.add_noise(
                    pred_image_or_video.flatten(0, 1),
                    torch.randn_like(pred_image_or_video.flatten(0, 1)),
                    cur_timestep.flatten(0, 1)
                ).unflatten(0, noise.shape[:2])

                _, pred_image_or_video = self.generator(
                    noisy_image_or_video=noisy_image_or_video, conditional_dict=conditional_dict,
                    timestep=torch.ones(noise.shape[:2], dtype=torch.long, device=noise.device) * current_timestep
                )  # [B, F, C, H, W]

        return pred_image_or_video