from videox_fun.models.wan_image_encoder import CLIPModel
from videox_fun.models.wan_vae import AutoencoderKLWan
from videox_fun.models.wan_text_encoder import WanT5EncoderModel
from videox_fun.pipeline.flow_match import get_sigma_from_timestep
from transformers import AutoTokenizer
import torch.nn.functional as F
from typing import Tuple
from torch import nn
import torch
import os
import math
import random
from omegaconf import OmegaConf
from diffusers import FlowMatchEulerDiscreteScheduler
from videox_fun.models import WanTransformer3DModel, CausalWanModel, Discriminator
import contextlib
from videox_fun.pipeline.self_forcing_training import SelfForcingTrainingPipeline


def get_diffusion_wrapper(model_name):
    """Return the transformer class registered under ``model_name``.

    Known names are ``"wan"`` (bidirectional ``WanTransformer3DModel``) and
    ``"causal_wan"`` (``CausalWanModel``).

    Raises:
        KeyError: if ``model_name`` is not one of the registered names.
    """
    wrappers = dict(wan=WanTransformer3DModel, causal_wan=CausalWanModel)
    return wrappers[model_name]

def load_state_dict(model, transformer_path):
    """Load checkpoint weights from ``transformer_path`` into ``model`` (non-strict).

    ``*.safetensors`` files are read with ``safetensors``; anything else goes
    through ``torch.load`` onto the CPU. A top-level ``"state_dict"`` entry, if
    present, is unwrapped first. Missing keys are tolerated (and counted), but
    keys that the model does not know about are treated as an error, since they
    usually mean the checkpoint belongs to a different architecture.

    Args:
        model: ``torch.nn.Module`` to load into. It is modified in place.
        transformer_path: checkpoint path, or ``None`` to do nothing.

    Raises:
        RuntimeError: if the checkpoint contains keys unexpected by ``model``.
    """
    if transformer_path is None:
        return

    print(f"From checkpoint: {transformer_path}")
    if transformer_path.endswith("safetensors"):
        from safetensors.torch import load_file
        state_dict = load_file(transformer_path)
    else:
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        state_dict = torch.load(transformer_path, map_location="cpu")
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict

    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
    # An explicit raise (not `assert`) so the check survives `python -O`.
    if unexpected:
        preview = ", ".join(unexpected[:10])
        raise RuntimeError(
            f"{len(unexpected)} unexpected key(s) in checkpoint {transformer_path}: {preview}"
        )

def filter_kwargs(cls, kwargs):
    """Keep only the entries of ``kwargs`` that ``cls.__init__`` accepts by name.

    ``self``/``cls`` are never forwarded. Useful for building objects (e.g. a
    scheduler) from a config mapping that carries extra, unrelated keys.
    """
    import inspect

    accepted = {
        name
        for name in inspect.signature(cls.__init__).parameters
        if name not in ('self', 'cls')
    }
    return {key: value for key, value in kwargs.items() if key in accepted}


class DMD(nn.Module):
    def __init__(self, args, device, dtype):
        """
        Initialize the DMD (Distribution Matching Distillation) module.

        This class is self-contained: it owns every network needed for DMD
        training and exposes methods that compute the generator loss and the
        critic (fake score) loss.

        Networks built here (all moved to ``device``/``dtype``):
            - ``generator``: trainable student (``requires_grad_(True)``, ``train()``).
            - ``real_score``: frozen teacher (``requires_grad_(False)``, ``eval()``).
            - ``fake_score``: trainable critic (``requires_grad_(True)``, ``train()``).
            - ``tokenizer`` / ``text_encoder``, ``vae``, ``clip_image_encoder``: frozen.
            - ``discriminator``: trainable GAN head on ``fake_score`` features, only
              created when ``args.adv_g_loss_weight > 0``.

        Args:
            args: training config. It is read both by attribute (``args.model_path``)
                and by key (``args['transformer_additional_kwargs']``), and parts of it
                are passed to ``OmegaConf.to_container``; presumably an OmegaConf
                ``DictConfig`` -- confirm against the caller.
            device: device every sub-module (and ``denoising_step_list``) is placed on.
            dtype: dtype every sub-module is cast to.
        """
        super().__init__()

        # Step 1: Initialize all models

        # Per-role model names / task types; each falls back to the shared
        # args.model_name / args.generator_task when the role-specific key is absent.
        self.generator_model_name = getattr(
            args, "generator_name", args.model_name)
        self.real_model_name = getattr(args, "real_name", args.model_name)
        self.fake_model_name = getattr(args, "fake_name", args.model_name)

        self.generator_task_type = getattr(
            args, "generator_task_type", args.generator_task)
        self.real_task_type = getattr(
            args, "real_task_type", args.generator_task)
        self.fake_task_type = getattr(
            args, "fake_task_type", args.generator_task)

        # Student: loaded from generator_transformer_path and trained.
        self.generator = get_diffusion_wrapper(model_name=self.generator_model_name).from_pretrained(
            args.model_path,
            transformer_additional_kwargs=OmegaConf.to_container(args['transformer_additional_kwargs']))
        load_state_dict(self.generator, args.generator_transformer_path)
        self.generator.requires_grad_(True)
        self.generator.train()
        self.generator.to(device=device, dtype=dtype)

        self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
        # Number of trailing frames that receive DMD loss under backward simulation.
        self.num_last_frames_with_grad = getattr(args, "num_last_frames_with_grad", 21)

        # Only override the generator's own block size when blocks span several frames.
        if self.num_frame_per_block > 1:
            self.generator.num_frame_per_block = self.num_frame_per_block

        # Teacher: loaded from critic_transformer_path and frozen.
        self.real_score = get_diffusion_wrapper(model_name=self.real_model_name).from_pretrained(
            args.model_path,
            transformer_additional_kwargs=OmegaConf.to_container(args['transformer_additional_kwargs']))
        load_state_dict(self.real_score, args.critic_transformer_path)
        self.real_score.requires_grad_(False).eval()
        self.real_score.to(device=device, dtype=dtype)

        # Critic: initialised from fake_transformer_path when given, else from the
        # same checkpoint as the teacher; it is trained.
        self.fake_score = get_diffusion_wrapper(model_name=self.fake_model_name).from_pretrained(
            args.model_path,
            transformer_additional_kwargs=OmegaConf.to_container(args['transformer_additional_kwargs']))
        if hasattr(args, 'fake_transformer_path') and args.fake_transformer_path is not None:
            print(f"Fake score loading from checkpoint: {args.fake_transformer_path}")
            load_state_dict(self.fake_score, args.fake_transformer_path)
        else:
            load_state_dict(self.fake_score, args.critic_transformer_path)
        self.fake_score.requires_grad_(True)
        self.fake_score.train()
        self.fake_score.to(device=device, dtype=dtype)

        # Checkpointing only matters for the two networks that are trained.
        if args.gradient_checkpointing:
            self.generator.enable_gradient_checkpointing()
            self.fake_score.enable_gradient_checkpointing()

        # Frozen conditioning encoders.
        self.tokenizer = AutoTokenizer.from_pretrained(os.path.join(args.model_path, 'google/umt5-xxl/'), seq_len=512, clean='whitespace')
        self.text_encoder = WanT5EncoderModel.from_pretrained(
            os.path.join(args.model_path, args.text_encoder_kwargs.get('text_encoder_subpath', 'text_encoder')),
            additional_kwargs=OmegaConf.to_container(args.text_encoder_kwargs))
        self.text_encoder.requires_grad_(False).eval()
        self.text_encoder.to(device=device, dtype=dtype)

        self.vae = AutoencoderKLWan.from_pretrained(os.path.join(args.model_path, 'Wan2.1_VAE.pth'))
        self.vae.requires_grad_(False).eval()
        self.vae.to(device=device, dtype=dtype)

        self.clip_image_encoder = CLIPModel.from_pretrained(os.path.join(args.model_path, 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'))
        self.clip_image_encoder.requires_grad_(False).eval()
        self.clip_image_encoder.to(device=device, dtype=dtype)

        # Built lazily by _initialize_inference_pipeline on the first backward simulation.
        self.inference_pipeline = None

        # Step 2: Initialize all dmd hyperparameters

        self.denoising_step_list = torch.tensor(
            args.denoising_step_list, dtype=torch.long, device=device)
        self.num_train_timestep = args.scheduler_kwargs.num_train_timesteps
        # DMD timesteps are clamped to [2%, 98%] of the training range.
        self.min_step = int(0.02 * self.num_train_timestep)
        self.max_step = int(0.98 * self.num_train_timestep)
        self.real_guidance_scale = args.real_guidance_scale
        self.timestep_shift = args.scheduler_kwargs.get('shift', 1.0)

        self.args = args
        self.device = device
        self.dtype = dtype
        self.scheduler = FlowMatchEulerDiscreteScheduler(
            **filter_kwargs(FlowMatchEulerDiscreteScheduler, args.scheduler_kwargs)
        )
        print(f"DMD scheduler: {self.scheduler}")
        # Sigma list is computed only for the log line below.
        denoising_sigma_list = get_sigma_from_timestep(self.scheduler, self.denoising_step_list, self.dtype)
        print(f"DMD timestep_shift = {self.timestep_shift}, denoising_step_list = {self.denoising_step_list}, denoising_sigma_list = {denoising_sigma_list}")
    
        # Used to derive transformer sequence lengths from latent shapes.
        self.patch_size = self.generator.config.patch_size

        # Optional GAN head, attached to fake_score's intermediate features.
        self.discriminator = None
        if self.args.adv_g_loss_weight > 0:
            self.discriminator = Discriminator(
                stride=self.args.output_features_stride, num_h_per_head=self.args.discriminator_num_heads, 
                adapter_channel_dims=[self.fake_score.dim], 
                adapter_out_dim=math.prod(self.fake_score.patch_size) * self.fake_score.out_dim,
                total_layers=self.fake_score.num_layers
            ).to(dtype=dtype, device=device)
            print(f"Initialize discriminator, prepare for gan training")
            self.discriminator.requires_grad_(True)
            self.discriminator.train()

    def _process_timestep(self, timestep: torch.Tensor, type: str) -> torch.Tensor:
        """
        Pre-process the randomly generated timestep based on the generator's task type.
        Input:
            - timestep: [batch_size, num_frame] tensor containing the randomly generated timestep.
            - type: a string indicating the type of the current model (image, bidirectional_video, or causal_video).
        Output Behavior:
            - image: check that the second dimension (num_frame) is 1.
            - bidirectional_video: broadcast the timestep to be the same for all frames.
            - causal_video: broadcast the timestep to be the same for all frames **in a block**.
        """
        if type == "image":
            assert timestep.shape[1] == 1
            return timestep
        elif type == "bidirectional_video":
            for index in range(timestep.shape[0]):
                timestep[index] = timestep[index, 0]
            return timestep
        elif type == "causal_video":
            # make the noise level the same within every motion block
            timestep = timestep.reshape(
                timestep.shape[0], -1, self.num_frame_per_block)
            timestep[:, :, 1:] = timestep[:, :, 0:1]
            timestep = timestep.reshape(timestep.shape[0], -1)
            return timestep
        else:
            raise NotImplementedError("Unsupported model type {}".format(type))

    def _compute_kl_grad(
        self, noisy_image_or_video: torch.Tensor,
        estimated_clean_image_or_video: torch.Tensor,
        timestep: torch.Tensor,
        conditional_dict: dict, unconditional_dict: dict,
        inpaint_latents: torch.Tensor, clip_context: torch.Tensor,
        normalization: bool = True
    ) -> Tuple[torch.Tensor, dict]:
        """
        Compute the KL gradient of DMD (eq. 7 in https://arxiv.org/abs/2311.18828).

        The critic (``fake_score``) and the teacher (``real_score``, with
        classifier-free guidance at ``self.real_guidance_scale``) each predict x0
        from the same noisy input. Their difference is the gradient.

        Input:
            - noisy_image_or_video: [B, F, C, H, W] noisy latents (F == 1 for images).
            - estimated_clean_image_or_video: [B, F, C, H, W] generator x0 estimate.
            - timestep: [B, F] timesteps. Only the first frame's value, ``timestep[:, 0]``,
              is passed to the networks. The full tensor is used for the x0 conversion.
            - conditional_dict / unconditional_dict: must contain "prompt_embeds".
            - inpaint_latents: 5-D latents passed as ``y`` after the same (0, 2, 1, 3, 4)
              permute as the noisy input; presumably [B, F, C, H, W] -- confirm.
            - clip_context: image features passed as ``clip_fea``.
            - normalization: divide the gradient by the mean |x0_est - x0_real|
              per sample (DMD eq. 8).
        Output:
            - kl_grad: gradient tensor shaped like the inputs, NaN/inf-sanitised.
            - kl_log_dict: detached intermediates for logging.
        """
        dims = noisy_image_or_video.shape
        seq_len = dims[1] * dims[3] * dims[4] // self.patch_size[1] // self.patch_size[2]

        # The networks take channel-first video: [B, C, F, H, W].
        x_in = noisy_image_or_video.permute(0, 2, 1, 3, 4)
        y_in = inpaint_latents.permute(0, 2, 1, 3, 4)
        t_in = timestep[:, 0]

        def predict_x0(network, prompt_embeds):
            """Run ``network`` on the shared inputs and convert its flow output to x0."""
            flow = network(
                x=x_in,
                context=prompt_embeds,
                t=t_in,
                seq_len=seq_len,
                y=y_in,
                clip_fea=clip_context
            ).permute(0, 2, 1, 3, 4)
            return self.convert_flow_pred_to_x0(flow, noisy_image_or_video, timestep)

        # Step 1: critic ("fake") score.
        pred_fake_image = predict_x0(self.fake_score, conditional_dict['prompt_embeds'])

        # Step 2: teacher ("real") score with classifier-free guidance
        # (https://arxiv.org/abs/2207.12598).
        real_cond = predict_x0(self.real_score, conditional_dict['prompt_embeds'])
        real_uncond = predict_x0(self.real_score, unconditional_dict['prompt_embeds'])
        pred_real_image = real_uncond + (real_cond - real_uncond) * self.real_guidance_scale

        # Step 3: DMD gradient (eq. 7).
        grad = (pred_fake_image - pred_real_image)

        # TODO: Change the normalizer for causal teacher
        if normalization:
            # Step 4: per-sample gradient normalisation (eq. 8).
            distance = torch.abs(estimated_clean_image_or_video - pred_real_image)
            grad = grad / distance.mean(dim=[1, 2, 3, 4], keepdim=True)
        grad = torch.nan_to_num(grad)

        log_dict = {
            "dmdtrain_clean_latent": estimated_clean_image_or_video.detach(),
            "dmdtrain_noisy_latent": noisy_image_or_video.detach(),
            "dmdtrain_pred_real_image": pred_real_image.detach(),
            "dmdtrain_pred_fake_image": pred_fake_image.detach(),
            "dmdtrain_gradient_norm": torch.mean(torch.abs(grad)).detach(),
            "timestep": timestep.detach()
        }
        return grad, log_dict

    def _compute_cls_logits(self, image_or_video, next_timestep, conditional_dict,
                inpaint_latents: torch.Tensor, clip_context: torch.Tensor,
                no_grad=False):
        """
        Return discriminator logits for ``image_or_video`` re-noised to ``next_timestep``.

        Fresh Gaussian noise is drawn and mixed in with ``add_noise`` at
        ``next_timestep`` (shape [B, F]). The backbone is then asked for intermediate
        features (``output_features=True``), which feed ``self.discriminator``.

        Args:
            no_grad: if True, the frozen ``real_score`` backbone is used under
                ``torch.no_grad()``. Otherwise the trainable ``fake_score`` backbone
                runs with autograd. The discriminator itself always runs with autograd.
        """
        noisy = self.add_noise(
            image_or_video.flatten(0, 1),
            torch.randn_like(image_or_video).flatten(0, 1),
            next_timestep.flatten(0, 1)
        ).unflatten(0, next_timestep.shape)
        seq_len = noisy.shape[1] * noisy.shape[3] * noisy.shape[4] // self.patch_size[1] // self.patch_size[2]

        if no_grad:
            backbone, grad_ctx = self.real_score, torch.no_grad()
        else:
            backbone, grad_ctx = self.fake_score, contextlib.nullcontext()

        with grad_ctx:
            features = backbone(
                x=noisy.permute(0, 2, 1, 3, 4),
                context=conditional_dict['prompt_embeds'],
                t=next_timestep,
                seq_len=seq_len,
                y=inpaint_latents.permute(0, 2, 1, 3, 4),
                clip_fea=clip_context,
                output_features=True, output_features_stride=self.args.output_features_stride
            )
        return self.discriminator(features)

    def compute_gan_g_loss(self, fake_image, next_timestep, conditional_dict, inpaint_latents, clip_context, gradient_mask=None):
        """
        Non-saturating GAN loss for the generator: mean over feature levels of
        ``softplus(-D(fake))``.

        Gradients flow through the ``fake_score`` features back into
        ``fake_image``. ``fake_score``'s own parameters are also in this graph;
        presumably the training loop only steps the generator optimizer on this
        loss -- confirm.

        Args:
            gradient_mask: accepted for API compatibility but currently ignored;
                every frame contributes to the loss.

        Returns:
            dict with key "gan_g_loss" (scalar tensor).
        """
        level_logits = self._compute_cls_logits(
            fake_image, next_timestep, conditional_dict, inpaint_latents, clip_context
        )
        total = sum(F.softplus(-logit).mean() for logit in level_logits)
        return {"gan_g_loss": total / len(level_logits)}

    def compute_gan_d_loss(self, real_image, fake_image, next_timestep, conditional_dict, inpaint_latents, clip_context):
        """
        Discriminator loss: mean over feature levels of
        ``softplus(-D(real)) + softplus(D(fake))``.

        Real and fake samples are detached and concatenated along the batch
        dimension, so the backbone and discriminator run once on both. The
        trainable ``fake_score`` backbone is used (``no_grad=False``), so this loss
        also produces gradients for ``fake_score``, but none for the generator.

        Returns:
            dict with key "gan_d_loss" (scalar tensor).
        """
        def doubled(tensor):
            """Stack ``tensor`` twice along the batch dimension (real half, fake half)."""
            return torch.cat([tensor, tensor], dim=0)

        merged_logits = self._compute_cls_logits(
            torch.cat([real_image.detach(), fake_image.detach()], dim=0),
            doubled(next_timestep),
            {"prompt_embeds": doubled(conditional_dict["prompt_embeds"])},
            doubled(inpaint_latents),
            doubled(clip_context),
            no_grad=False
        )

        per_level = []
        for logits in merged_logits:
            real_logit, fake_logit = logits.chunk(2, dim=0)
            per_level.append(F.softplus(-real_logit).mean() + F.softplus(fake_logit).mean())
        return {"gan_d_loss": sum(per_level) / len(per_level)}
    
    def add_noise(self, latents, noise, timestep):
        """
        Flow-matching forward process: ``(1 - sigma) * latents + sigma * noise``.

        ``sigma`` comes from ``get_sigma_from_timestep`` and is reshaped to
        (-1, 1, 1, 1), so ``latents``/``noise`` are expected as 4-D tensors whose
        leading dimension matches the number of timesteps (callers pass
        frame-flattened [B*F, C, H, W] tensors).

        The original author's note says timestep == 0 is special-cased so that
        sigma is 0 as well. That handling is not in this method; presumably it
        lives in ``get_sigma_from_timestep`` -- confirm.
        """
        level = get_sigma_from_timestep(self.scheduler, timestep, latents.dtype).reshape(-1, 1, 1, 1)
        return (1.0 - level) * latents + level * noise

    def convert_flow_pred_to_x0(self, flow_pred, xt, timestep):
        """
        Turn a flow (velocity) prediction into an x0 estimate: ``x0 = xt - sigma * flow``.

        This inverts ``add_noise``'s ``xt = (1 - sigma) * x0 + sigma * noise`` for a
        velocity of ``noise - x0``.

        - 2-D ``timestep`` [B, F] (with F > 0): ``flow_pred`` must be 5-D and sigma
          is broadcast as [B, F, 1, 1, 1].
        - otherwise: ``flow_pred`` must be 4-D and sigma is broadcast as [N, 1, 1, 1].
        """
        frames = 0
        if timestep.ndim == 2:
            bsz, frames = timestep.shape
            timestep = timestep.flatten(0, 1)

        sigma = get_sigma_from_timestep(self.scheduler, timestep, xt.dtype)

        if frames > 0:
            assert flow_pred.ndim == 5
            sigma_shape = (bsz, frames, 1, 1, 1)
        else:
            assert flow_pred.ndim == 4
            sigma_shape = (-1, 1, 1, 1)
        return xt - sigma.reshape(sigma_shape) * flow_pred
    
    def encode_text(self, text):
        """
        Encode one prompt (str) or a list of prompts with the frozen T5 encoder.

        Prompts are padded/truncated to 512 tokens. Embedding positions at or
        beyond each prompt's real token count are zeroed in place.

        Returns:
            {"prompt_embeds": tensor on ``self.device`` with dtype ``self.dtype``}.
        """
        prompts = [text] if isinstance(text, str) else text
        tokens = self.tokenizer(
            prompts,
            padding="max_length",
            max_length=512,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt"
        )
        input_ids = tokens.input_ids.to(self.device)
        attention_mask = tokens.attention_mask.to(self.device)
        valid_lengths = attention_mask.gt(0).sum(dim=1).long()

        embeddings = self.text_encoder(input_ids, attention_mask)[0]
        for row, length in zip(embeddings, valid_lengths):
            row[length:] = 0.0  # zero out padding positions

        return {"prompt_embeds": embeddings.to(self.device, self.dtype)}

    def compute_distribution_matching_loss(
        self, image_or_video: torch.Tensor, conditional_dict: dict,
        unconditional_dict: dict, gradient_mask: torch.Tensor = None,
        inpaint_latents: torch.Tensor = None, clip_context: torch.Tensor = None
    ) -> Tuple[torch.Tensor, dict]:
        """
        Compute the DMD loss (eq 7 in https://arxiv.org/abs/2311.18828).
        Input:
            - image_or_video: a tensor with shape [B, F, C, H, W] where the number of frame is 1 for images.
            - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
            - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
            - gradient_mask: a boolean tensor with the same shape as image_or_video indicating which pixels to compute loss .
        Output:
            - dmd_loss: a scalar tensor representing the DMD loss.
            - dmd_log_dict: a dictionary containing the intermediate tensors for logging.
        """
        original_latent = image_or_video

        batch_size, num_frame = image_or_video.shape[:2]

        with torch.no_grad():
            # Step 1: Randomly sample timestep based on the given schedule and corresponding noise
            timestep = torch.randint(
                0,
                self.num_train_timestep,
                [batch_size, num_frame],
                device=self.device,
                dtype=torch.long
            )

            timestep = self._process_timestep(
                timestep, type=self.real_task_type)

            # TODO: Add timestep warping
            if self.timestep_shift > 1:
                timestep = self.timestep_shift * \
                    (timestep / 1000) / \
                    (1 + (self.timestep_shift - 1) * (timestep / 1000)) * 1000
            timestep = timestep.clamp(self.min_step, self.max_step)
            print(f"add noise to generator simulation with timestep: {timestep}")

            noise = torch.randn_like(image_or_video)
            noisy_latent = self.add_noise(
                image_or_video.flatten(0, 1),
                noise.flatten(0, 1),
                timestep.flatten(0, 1)
            ).detach().unflatten(0, (batch_size, num_frame))

            # Step 2: Compute the KL grad
            grad, dmd_log_dict = self._compute_kl_grad(
                noisy_image_or_video=noisy_latent,
                estimated_clean_image_or_video=original_latent,
                timestep=timestep,
                conditional_dict=conditional_dict,
                unconditional_dict=unconditional_dict,
                inpaint_latents=inpaint_latents,
                clip_context=clip_context,
            )

        if gradient_mask is not None:
            gradient_mask = gradient_mask.expand_as(original_latent)
            dmd_loss = 0.5 * F.mse_loss(original_latent.double(
            )[gradient_mask], (original_latent.double() - grad.double()).detach()[gradient_mask], reduction="mean")
        else:
            dmd_loss = 0.5 * F.mse_loss(original_latent.double(
            ), (original_latent.double() - grad.double()).detach(), reduction="mean")
        return dmd_loss, dmd_log_dict

    def _initialize_inference_pipeline(self):
        """
        Build the ``SelfForcingTrainingPipeline`` used for backward simulation.

        Called lazily from ``_consistency_backward_simulation`` the first time it
        runs. The pipeline receives this module's own ``generator`` object and
        ``scheduler``. If the generator has been wrapped (e.g. by FSDP) before
        this point, the wrapped module is what the pipeline gets.

        Note: the pipeline keeps the ``denoising_step_list`` tensor passed here.
        ``_run_generator`` may later rebind ``self.denoising_step_list`` to a
        sorted copy, which the pipeline will not see.
        """
        pipeline_kwargs = dict(
            denoising_step_list=self.denoising_step_list,
            scheduler=self.scheduler,
            generator=self.generator,
            num_frame_per_block=self.num_frame_per_block,
            num_last_frames_with_grad=self.num_last_frames_with_grad,
        )
        self.inference_pipeline = SelfForcingTrainingPipeline(**pipeline_kwargs)

    def _consistency_backward_simulation(self, noise: torch.Tensor, **conditional_dict: dict) -> torch.Tensor:
        """
        Roll the generator out from ``noise`` to avoid a training/inference mismatch.

        See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867). The
        self-forcing pipeline is created on first use. All keyword arguments
        (prompt embeddings, ``initial_latent``, ``inpaint_latents``,
        ``clip_context``, ...) are forwarded to ``inference_with_trajectory``.

        Input:
            - noise: Gaussian noise shaped [B, F, C, H, W].
        Output:
            Whatever ``inference_with_trajectory`` returns. ``_run_generator``
            unpacks it as ``(pred_image_or_video, pred_image_or_video_x0,
            next_timestep)``, so the annotated ``torch.Tensor`` return type is not
            accurate.
        """
        pipeline = self.inference_pipeline
        if pipeline is None:
            self._initialize_inference_pipeline()
            pipeline = self.inference_pipeline
        return pipeline.inference_with_trajectory(noise=noise, **conditional_dict)

    def _run_generator(self, image_or_video_shape, conditional_dict: dict, unconditional_dict: dict, clean_latent: torch.tensor,
                inpaint_latents: torch.Tensor, clip_context: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Optionally simulate the generator's input from noise using backward simulation
        and then run the generator for one-step.
        Input:
            - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
            - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
            - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
            - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
        Output:
            - pred_image: a tensor with shape [B, F, C, H, W].
        """
        # Step 1: Sample noise and backward simulate the generator's input
        if getattr(self.args, "backward_simulation", True):
            pred_image_or_video, pred_image_or_video_x0, next_timestep = self._consistency_backward_simulation(
                noise=torch.randn(image_or_video_shape, device=self.device, dtype=self.dtype),
                initial_latent=clean_latent[:, :1],
                inpaint_latents=inpaint_latents,
                clip_context=clip_context,
                **conditional_dict,
            )
            bsz, num_frames = image_or_video_shape[0], image_or_video_shape[1]
            gradient_mask = torch.zeros((bsz, num_frames, 1, 1, 1), device=self.device, dtype=torch.bool)
            gradient_mask[:, -self.num_last_frames_with_grad:, 0, 0, 0] = 1
            return pred_image_or_video, pred_image_or_video_x0, gradient_mask, next_timestep
        else:
            # Randomly sample a timestep and pick the corresponding input
            # index = torch.randint(0, len(self.denoising_step_list), [
            #                       image_or_video_shape[0], image_or_video_shape[1]], device=self.device, dtype=torch.long)
            bsz, num_frames = image_or_video_shape[0], image_or_video_shape[1]
            timestep_shape = (bsz, num_frames)
            self.denoising_step_list = torch.sort(self.denoising_step_list)[0]
            if self.args.timestep_sampling_method == 'uniform':
                index = torch.randint(0, len(self.denoising_step_list), timestep_shape, device=self.device, dtype=torch.long)
            elif self.args.timestep_sampling_method == 'mono_inc':
                assert self.denoising_step_list[0] == 0 and self.denoising_step_list[1] == 20
                index = torch.zeros(timestep_shape, device=self.device, dtype=torch.long)
                for i in range(bsz):
                    prefix_end = random.randint(0, num_frames - 2)
                    index[i, 1:prefix_end+1] = 1  # 0 for first frame, 20 for perturbed prefix
                    positions = sorted(random.choices(range(2, len(self.denoising_step_list)), k=num_frames-prefix_end-1))
                    index[i, prefix_end+1:] = torch.Tensor(positions).to(device=self.device).long()

            index = self._process_timestep(index, type=self.generator_task_type)
            timestep = self.denoising_step_list[index]
            next_index = index - 1
            if self.denoising_step_list[1] == 20:
                next_index[next_index == 1] = 0
            next_timestep = self.denoising_step_list[next_index]
            next_timestep[next_index < 0] = 0
            print(f"generator index: {index}, timestep: {timestep}, next_timestep: {next_timestep}")

            noise = torch.randn(
                image_or_video_shape, device=self.device, dtype=self.dtype)
            noisy_input = self.add_noise(
                clean_latent.flatten(0, 1),
                noise.flatten(0, 1),
                timestep.flatten(0, 1)
            ).unflatten(0, image_or_video_shape[:2])

            seq_len = image_or_video_shape[1] * image_or_video_shape[3] * image_or_video_shape[4] // self.patch_size[1] // self.patch_size[2]
            flow_pred = self.generator(
                x=noisy_input.permute(0,2,1,3,4),
                context=conditional_dict['prompt_embeds'],
                t=timestep[:,0] if self.generator_task_type == "bidirectional_video" else timestep,
                seq_len=seq_len,
                y=inpaint_latents.permute(0,2,1,3,4),
                clip_fea=clip_context,
            ).permute(0,2,1,3,4)
            pred_image_or_video = self.convert_flow_pred_to_x0(flow_pred, noisy_input, timestep)

            gradient_mask = (timestep > 20).view(bsz, num_frames, 1, 1, 1)

            # pred_image_or_video = noisy_input * \
            #     (1-gradient_mask.float()).reshape(*gradient_mask.shape, 1, 1, 1) + \
            #     pred_image_or_video * gradient_mask.float().reshape(*gradient_mask.shape, 1, 1, 1)

            pred_image_or_video = pred_image_or_video.type_as(noisy_input)

            return pred_image_or_video, noisy_input, gradient_mask, next_timestep

    def generator_loss(self, image_or_video_shape, conditional_dict: dict, unconditional_dict: dict, clean_latent: torch.Tensor,
                inpaint_latents: torch.Tensor=None, clip_context: torch.Tensor=None) -> Tuple[dict, dict]:
        """
        Generate images/videos and compute the generator's losses.

        The generator input is backward simulated (or noised from ``clean_latent``
        when backward simulation is disabled; see ``_run_generator``). See Sec 4.5
        of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.

        Input:
            - image_or_video_shape: shape of the image or video [B, F, C, H, W].
            - conditional_dict: conditional information; must contain "prompt_embeds".
            - unconditional_dict: unconditional (null/negative) information for CFG.
            - clean_latent: clean latents [B, F, C, H, W].
            - inpaint_latents, clip_context: forwarded to every network call.
        Output:
            - loss_dict: {"dmd_loss": scalar tensor} plus "gan_g_loss" when a
              discriminator exists.
            - generator_log_dict: intermediate tensors for logging. It includes
              "generator_noisy_input", which holds the second output of
              ``_run_generator``.
        """
        # Step 1: Run generator on backward simulated noisy input
        pred_image, generator_noisy_input, gradient_mask, next_timestep = self._run_generator(
            image_or_video_shape=image_or_video_shape,
            conditional_dict=conditional_dict,
            unconditional_dict=unconditional_dict,
            clean_latent=clean_latent,
            inpaint_latents=inpaint_latents,
            clip_context=clip_context,
        )

        loss_dict = {}
        # Step 2: Compute the DMD loss
        dmd_loss, dmd_log_dict = self.compute_distribution_matching_loss(
            image_or_video=pred_image,
            conditional_dict=conditional_dict,
            unconditional_dict=unconditional_dict,
            gradient_mask=gradient_mask,
            inpaint_latents=inpaint_latents,
            clip_context=clip_context,
        )
        loss_dict["dmd_loss"] = dmd_loss
        if generator_noisy_input is not None:
            dmd_log_dict['generator_noisy_input'] = generator_noisy_input.detach()

        # Step 3: Compute the GAN loss
        if self.discriminator is not None:
            g_loss = self.compute_gan_g_loss(
                pred_image, next_timestep, conditional_dict, inpaint_latents, clip_context, gradient_mask=gradient_mask,
            )
            loss_dict.update(g_loss)

        return loss_dict, dmd_log_dict

    def critic_loss(self, image_or_video_shape, conditional_dict: dict, unconditional_dict: dict, clean_latent: torch.Tensor,
                inpaint_latents: torch.Tensor, clip_context: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        Generate image/videos from noise and train the critic with generated samples.
        The noisy input to the generator is backward simulated.
        This removes the need of any datasets during distillation.
        See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
        Input:
            - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
            - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
            - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
            - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
        Output:
            - loss: a scalar tensor representing the generator loss.
            - critic_log_dict: a dictionary containing the intermediate tensors for logging.
        """

        # Step 1: Run generator on backward simulated noisy input
        with torch.no_grad():
            generated_image, _, _, next_timestep = self._run_generator(
                image_or_video_shape=image_or_video_shape,
                conditional_dict=conditional_dict,
                unconditional_dict=unconditional_dict,
                clean_latent=clean_latent,
                inpaint_latents=inpaint_latents,
                clip_context=clip_context,
            )

        # Step 2: Compute the fake prediction
        critic_timestep = torch.randint(
            0,
            self.num_train_timestep,
            image_or_video_shape[:2],
            device=self.device,
            dtype=torch.long
        )
        critic_timestep = self._process_timestep(
            critic_timestep, type=self.fake_task_type)

        # TODO: Add timestep warping
        if self.timestep_shift > 1:
            critic_timestep = self.timestep_shift * \
                (critic_timestep / 1000) / (1 + (self.timestep_shift - 1) * (critic_timestep / 1000)) * 1000

        critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)

        critic_noise = torch.randn_like(generated_image)
        noisy_generated_image = self.add_noise(
            generated_image.flatten(0, 1),
            critic_noise.flatten(0, 1),
            critic_timestep.flatten(0, 1)
        ).unflatten(0, image_or_video_shape[:2])

        seq_len = image_or_video_shape[1] * image_or_video_shape[3] * image_or_video_shape[4] // self.patch_size[1] // self.patch_size[2]
        pred_fake_image_flow = self.fake_score(
            x=noisy_generated_image.permute(0,2,1,3,4),
            context=conditional_dict['prompt_embeds'],
            t=critic_timestep[:,0],
            seq_len=seq_len,
            y=inpaint_latents.permute(0,2,1,3,4),
            clip_fea=clip_context
        ).permute(0,2,1,3,4)
        pred_fake_image = self.convert_flow_pred_to_x0(
            pred_fake_image_flow, noisy_generated_image, critic_timestep)
        target = critic_noise - generated_image
        denoising_loss = torch.mean((pred_fake_image_flow - target) ** 2)

        loss_dict = {}
        loss_dict['critic_loss'] = denoising_loss
        # Step 4: Compute the GAN loss
        if self.discriminator is not None:
            d_loss = self.compute_gan_d_loss(
                real_image=clean_latent, 
                fake_image=generated_image,
                next_timestep=next_timestep,
                conditional_dict=conditional_dict, inpaint_latents=inpaint_latents, clip_context=clip_context
            )
            loss_dict.update(d_loss)

        # Step 5: Debugging Log
        critic_log_dict = {
            "critictrain_latent": generated_image.detach(),
            "critictrain_noisy_latent": noisy_generated_image.detach(),
            "critictrain_pred_image": pred_fake_image.detach(),
            "critic_timestep": critic_timestep.detach()
        }

        return loss_dict, critic_log_dict
