from videox_fun.models.wan_wrapper import WanDiffusionWrapper
from videox_fun.utils.scheduler import SchedulerInterface
from typing import List, Optional
import torch
import torch.distributed as dist
import math

class SelfForcingTrainingPipeline:
    """Self-Forcing autoregressive rollout for a causal Wan generator.

    Frames are generated block by block (``num_frame_per_block`` frames per block),
    reusing a KV cache across blocks. For every block after the first, one denoising
    step is chosen at random (the same choice on every rank) as the "exit" step. Only
    the generator call at that step may carry gradients, and only for blocks whose
    first frame lies within the last ``num_last_frames_with_grad`` frames.
    """

    def __init__(self,
                 denoising_step_list: List[int],
                 scheduler: "SchedulerInterface",
                 generator: "WanDiffusionWrapper",
                 num_frame_per_block=1,
                 num_max_frames: int = 21,
                 num_last_frames_with_grad: int = 21,  # number of suffix frames to backpropagate gradients
                 **kwargs):
        super().__init__()
        self.scheduler = scheduler
        self.generator = generator
        self.denoising_step_list = denoising_step_list
        if self.denoising_step_list[-1] == 0:
            self.denoising_step_list = self.denoising_step_list[:-1]  # remove the zero timestep for inference
        # NOTE(review): a trailing timestep of 20 is also dropped; the reason is not visible here.
        if self.denoising_step_list[-1] == 20:
            self.denoising_step_list = self.denoising_step_list[:-1]
        print(f"SelfForcingTrainingPipeline using denoising steps: {self.denoising_step_list}")

        # Wan specific hyperparameters
        self.num_transformer_blocks = 30
        # Tokens per frame, used only to size the KV cache. inference_with_trajectory
        # derives its own per-frame token count from the latent size and patch size.
        self.frame_seq_length = 1560
        self.num_frame_per_block = num_frame_per_block
        self.num_last_frames_with_grad = num_last_frames_with_grad

        self.kv_cache1 = None
        self.crossattn_cache = None
        self.kv_cache_size = num_max_frames * self.frame_seq_length

    def generate_and_sync_list(self, num_blocks, num_denoising_steps, device):
        """Draw one random exit-step index per block, identical on every rank.

        Rank 0 samples ``num_blocks`` integers in ``[0, num_denoising_steps)``. When
        torch.distributed is initialized, they are broadcast from rank 0. Otherwise the
        locally sampled values are returned as-is.

        Returns:
            A Python list of ``num_blocks`` ints.
        """
        distributed = dist.is_available() and dist.is_initialized()
        rank = dist.get_rank() if distributed else 0

        if rank == 0:
            # Generate random indices
            indices = torch.randint(
                low=0,
                high=num_denoising_steps,
                size=(num_blocks,),
                device=device
            )
        else:
            indices = torch.empty(num_blocks, dtype=torch.long, device=device)

        if distributed:
            dist.broadcast(indices, src=0)  # Broadcast the random indices to all ranks
        return indices.tolist()

    def _run_generator(self, latents, timestep, *, prompt_embeds, block_y, clip_context,
                       seq_len, current_start, current_end, shift_cache_length):
        """Run the generator on one block and return its output in [B, F, C, H, W] layout.

        ``latents`` and ``block_y`` are frame-major slices ([B, F, C, H, W], the layout of
        ``noise`` in inference_with_trajectory). Dims 1 and 2 are swapped before the call,
        and the generator output is swapped back. ``seq_len`` is the per-frame token count;
        the generator receives it multiplied by ``num_frame_per_block``. The pipeline's
        KV / cross-attention caches are passed through; presumably the generator updates
        them in place (see the cache-refresh pass in inference_with_trajectory).
        """
        return self.generator(
            x=latents.permute(0, 2, 1, 3, 4),
            context=prompt_embeds,
            t=timestep,
            seq_len=seq_len * self.num_frame_per_block,
            y=block_y.permute(0, 2, 1, 3, 4),
            clip_fea=clip_context,
            kv_cache=self.kv_cache1,
            crossattn_cache=self.crossattn_cache,
            current_start=current_start,
            current_end=current_end,
            shift_cache_length=shift_cache_length,
        ).permute(0, 2, 1, 3, 4)

    def inference_with_trajectory(
            self,
            noise: torch.Tensor,
            prompt_embeds: torch.Tensor,
            initial_latent: torch.Tensor,  # first frame latent
            inpaint_latents: torch.Tensor,
            clip_context: torch.Tensor,
    ) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
        """Roll out a video block by block and return the Self-Forcing trajectory.

        Args:
            noise: initial noise, unpacked as [B, F, C, H, W].
            prompt_embeds: passed to the generator as ``context``.
            initial_latent: used verbatim as the first block's output; the first
                block is never denoised.
            inpaint_latents: sliced per block like ``noise`` and passed as ``y``.
            clip_context: passed to the generator as ``clip_fea``.

        Returns:
            ``(output, output_x0, next_timesteps)``:

            - ``output``: each block's x0 prediction at its randomly chosen exit step
              (may carry gradients), concatenated on dim 1 and cast to ``noise.dtype``.
            - ``output_x0``: detached, fully denoised latents for every block.
            - ``next_timesteps``: [B, F] int64, the timestep each frame was advanced to
              at its exit step (0 for the first block and for exits at the last step).

            Frames beyond ``num_blocks * num_frame_per_block`` are not generated.
        """
        batch_size, num_frames, num_channels, height, width = noise.shape
        frames_per_block = self.num_frame_per_block
        output = []
        output_x0 = torch.zeros_like(noise)
        next_timesteps = torch.zeros([batch_size, num_frames], device=noise.device, dtype=torch.int64)
        # The generator may be wrapped (e.g. DDP/FSDP expose the model as ``.module``).
        model = getattr(self.generator, "module", self.generator)
        patch_size = model.config.patch_size
        seq_len = math.ceil((height * width) / (patch_size[1] * patch_size[2]))  # token sequence length per frame
        num_blocks = num_frames // frames_per_block

        # Step 1: Initialize KV cache to all zeros
        self._initialize_kv_cache(
            batch_size=batch_size, dtype=noise.dtype, device=noise.device
        )
        self._initialize_crossattn_cache(
            batch_size=batch_size, dtype=noise.dtype, device=noise.device
        )

        # Step 2: Temporal denoising loop
        num_denoising_steps = len(self.denoising_step_list)
        exit_flags = self.generate_and_sync_list(num_blocks, num_denoising_steps, device=noise.device)
        start_gradient_frame_index = num_frames - self.num_last_frames_with_grad
        shift_cache_length = 0
        cache_capacity = self.kv_cache1[0]["k"].shape[1]

        for block_index in range(num_blocks):
            frame_start = block_index * frames_per_block
            frame_end = frame_start + frames_per_block
            block_latents = noise[:, frame_start:frame_end]
            block_y = inpaint_latents[:, frame_start:frame_end]

            # If this block would run past the end of the KV cache, shift the whole cache
            # left by one block, discarding the oldest entries.
            current_start = frame_start * seq_len - shift_cache_length
            current_end = frame_end * seq_len - shift_cache_length
            block_len = current_end - current_start
            while current_end > cache_capacity:
                shift_cache_length += block_len
                current_end -= block_len
                current_start -= block_len
                for layer_kv_cache in self.kv_cache1:
                    layer_kv_cache["k"][:, :-block_len] = layer_kv_cache["k"][:, block_len:].clone()
                    layer_kv_cache["v"][:, :-block_len] = layer_kv_cache["v"][:, block_len:].clone()
                print(f"Shift cache length: {shift_cache_length}, current_start: {current_start}, current_end: {current_end}")

            gen_kwargs = dict(
                prompt_embeds=prompt_embeds,
                block_y=block_y,
                clip_context=clip_context,
                seq_len=seq_len,
                current_start=current_start,
                current_end=current_end,
                shift_cache_length=shift_cache_length,
            )

            if block_index == 0:
                denoised_pred = initial_latent
                next_timestep_index = num_denoising_steps
                next_timestep = 0
            else:
                # Step 2.1: Spatial denoising loop, stopping at the randomly chosen exit step
                # (the same step on all ranks).
                for index, current_timestep in enumerate(self.denoising_step_list):
                    timestep = torch.ones(
                        [batch_size, frames_per_block],
                        device=noise.device,
                        dtype=torch.int64) * current_timestep

                    if index != exit_flags[block_index]:
                        # Steps before the exit step never carry gradients. The exit index is
                        # always < num_denoising_steps, so index + 1 is valid here.
                        with torch.no_grad():
                            noise_pred = self._run_generator(block_latents, timestep, **gen_kwargs)
                            next_timestep = torch.ones_like(timestep) * self.denoising_step_list[index + 1]
                            block_latents = self._convert_flow_pred_to_next_step(
                                noise_pred, block_latents, timestep, next_timestep
                            )
                        continue

                    # Exit step: gradients only for blocks starting within the last
                    # num_last_frames_with_grad frames.
                    if frame_start < start_gradient_frame_index:
                        with torch.no_grad():
                            noise_pred = self._run_generator(block_latents, timestep, **gen_kwargs)
                    else:
                        noise_pred = self._run_generator(block_latents, timestep, **gen_kwargs)
                    next_timestep_index = index + 1
                    # x0 must come from the latents the prediction was made on, i.e. before
                    # block_latents is advanced to the next timestep.
                    denoised_pred = self._convert_flow_pred_to_x0(
                        noise_pred, block_latents, timestep
                    )
                    if index < num_denoising_steps - 1:
                        next_timestep = torch.ones_like(timestep) * self.denoising_step_list[index + 1]
                        block_latents = self._convert_flow_pred_to_next_step(
                            noise_pred, block_latents, timestep, next_timestep
                        )
                    else:
                        next_timestep = torch.zeros_like(timestep)
                        block_latents = denoised_pred
                    break

            # Step 2.2: record the model's output
            output.append(denoised_pred)
            next_timesteps[:, frame_start:frame_end] = next_timestep

            # Step 2.3: continue denoising to x0 (no gradients)
            with torch.no_grad():
                for index in range(next_timestep_index, num_denoising_steps):
                    current_timestep = self.denoising_step_list[index]
                    timestep = torch.ones(
                        [batch_size, frames_per_block],
                        device=noise.device,
                        dtype=torch.int64) * current_timestep

                    noise_pred = self._run_generator(block_latents, timestep, **gen_kwargs)
                    if index < num_denoising_steps - 1:
                        next_timestep = torch.ones_like(timestep) * self.denoising_step_list[index + 1]
                        block_latents = self._convert_flow_pred_to_next_step(
                            noise_pred, block_latents, timestep, next_timestep
                        )
                    else:
                        block_latents = self._convert_flow_pred_to_x0(
                            noise_pred, block_latents, timestep
                        )
                    denoised_pred = block_latents
            output_x0[:, frame_start:frame_end] = denoised_pred.detach()

            # Step 2.4: rerun with timestep zero on the clean block to update the cache
            timestep = torch.zeros([batch_size, frames_per_block],
                                   device=noise.device, dtype=torch.int64)
            with torch.no_grad():
                self._run_generator(denoised_pred, timestep, **gen_kwargs)

        print(f"exit_flags: {exit_flags}, next_timesteps: {next_timesteps}")
        output = torch.cat(output, dim=1).to(noise.dtype)
        return output, output_x0, next_timesteps

    def _initialize_kv_cache(self, batch_size, dtype, device):
        """
        Allocate a zero-filled self-attention KV cache: one {"k", "v"} dict per transformer block.

        Each tensor is [batch_size, kv_cache_size, 12, 128].
        NOTE(review): 12 / 128 presumably are the Wan head count and head dim; confirm
        against the model config.
        """
        kv_cache1 = []

        for _ in range(self.num_transformer_blocks):
            kv_cache1.append({
                "k": torch.zeros([batch_size, self.kv_cache_size, 12, 128], dtype=dtype, device=device),
                "v": torch.zeros([batch_size, self.kv_cache_size, 12, 128], dtype=dtype, device=device),
            })

        self.kv_cache1 = kv_cache1  # always store the clean cache

    def _initialize_crossattn_cache(self, batch_size, dtype, device):
        """
        Allocate a zero-filled cross-attention cache: one dict per transformer block.

        Each dict holds "k"/"v" tensors of shape [batch_size, 512, 12, 128] and an
        "is_init" flag set to False. Presumably the generator fills the cache and flips
        the flag on first use; that is not visible here.
        """
        crossattn_cache = []

        for _ in range(self.num_transformer_blocks):
            crossattn_cache.append({
                "k": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
                "v": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
                "is_init": False
            })
        self.crossattn_cache = crossattn_cache

    @staticmethod
    def _sigmas_at(timestep, sigmas, timesteps):
        """Look up the scheduler sigma nearest to each entry of ``timestep`` ([B, F]).

        Returns a [B, F, 1, 1, 1] tensor, broadcastable over [B, F, C, H, W] latents.
        """
        batch_size, num_frames = timestep.shape
        timestep_id = torch.argmin(
            (timesteps.unsqueeze(0) - timestep.flatten().unsqueeze(1)).abs(), dim=1)
        return sigmas[timestep_id].reshape(batch_size, num_frames, 1, 1, 1)

    def _convert_flow_pred_to_x0(self, flow_pred: torch.Tensor, xt: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        """
        Convert flow matching's prediction to an x0 prediction.

        flow_pred: the prediction, [B, F, C, H, W] (frame-major, as used in this class)
        xt: the noisy input, [B, F, C, H, W]
        timestep: per-frame timesteps, [B, F]

        pred = noise - x0
        x_t = (1 - sigma_t) * x0 + sigma_t * noise
        hence x0 = x_t - sigma_t * pred
        """
        # Compute in float64 for precision, then cast back.
        original_dtype = flow_pred.dtype
        device = flow_pred.device
        flow_pred, xt, sigmas, timesteps = (
            t.double().to(device)
            for t in (flow_pred, xt, self.scheduler.sigmas, self.scheduler.timesteps)
        )
        sigma_t = self._sigmas_at(timestep, sigmas, timesteps)
        x0_pred = xt - sigma_t * flow_pred
        return x0_pred.to(original_dtype)

    def _convert_flow_pred_to_next_step(self, flow_pred, xt, timestep, next_timestep):
        """
        Advance ``xt`` from ``timestep`` to ``next_timestep`` with an Euler flow step.

        All latents are [B, F, C, H, W]; both timestep tensors are [B, F].
        x_next = x_t + (sigma_next - sigma_t) * pred
        """
        original_dtype = flow_pred.dtype
        device = flow_pred.device
        flow_pred, xt, sigmas, timesteps = (
            t.double().to(device)
            for t in (flow_pred, xt, self.scheduler.sigmas, self.scheduler.timesteps)
        )
        sigma_t = self._sigmas_at(timestep, sigmas, timesteps)
        sigma_next = self._sigmas_at(next_timestep, sigmas, timesteps)

        prev_sample = xt + (sigma_next - sigma_t) * flow_pred
        return prev_sample.to(original_dtype)