
from typing import Tuple
from einops import rearrange
from torch import nn
import torch.distributed as dist
import torch

from utils.loss import get_denoising_loss
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper, WanCLIPEncoder

import os 

from pipeline import BidirectionalTrainingPipeline

import torch.nn.functional as F
from typing import Optional, Tuple
import torch

from torchvision.io import write_video
from einops import rearrange
import os

from utils.rm_video_align_wrapper import VideoAlignWrapper

class T2V_DMD_DPO(nn.Module):
    def __init__(self, args, device):
        """Set up models, output directories and DMD/DPO hyperparameters.

        Args:
            args: Run configuration. Reads ``logdir``, ``mixed_precision``,
                ``denoising_step_list``, ``num_train_timestep`` and
                ``guidance_scale``; ``timestep_shift`` and
                ``min_score_timestep`` are optional.
            device: Device forwarded to ``_initialize_models`` and stored as
                ``self.device``.
        """
        super().__init__()

        # Models come first: ``self.scheduler`` created there is read below.
        self._initialize_models(args, device)

        self.global_rank = dist.get_rank()
        self.is_main_process = self.global_rank == 0

        # Step counter read by the DPO window and the video-logging cadence.
        # Nothing in this class advances it (presumably the trainer does —
        # TODO confirm).
        self.step = 0

        self.gen_train_videos_dir = os.path.join(args.logdir, "gen_train_videos")
        self.gen_eval_videos_dir = os.path.join(args.logdir, "gen_eval_videos")
        self.gen_log_videos_dir = os.path.join(args.logdir, "gen_log_videos")
        # Every rank writes videos into these directories, so every rank makes
        # sure they exist. ``exist_ok=True`` keeps concurrent creation safe.
        for out_dir in (
            self.gen_train_videos_dir,
            self.gen_eval_videos_dir,
            self.gen_log_videos_dir,
        ):
            os.makedirs(out_dir, exist_ok=True)

        self.device = device
        self.args = args

        self.dtype = torch.bfloat16 if args.mixed_precision else torch.float32

        # Configured denoising steps, e.g. [1000, 750, 500, 250].
        raw_steps = torch.tensor(args.denoising_step_list, dtype=torch.long)
        # Scheduler timesteps with a trailing 0 appended, kept on CPU.
        timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
        print(f"------- raw self.denoising_step_list {raw_steps}")
        # ``1000 - step`` is used as an index into the table above
        # (e.g. indices [0, 250, 500, 750]); assumes a 1000-entry schedule —
        # TODO confirm against the scheduler.
        self.denoising_step_list = timesteps[1000 - raw_steps]
        print(f"------- self.denoising_step_list {self.denoising_step_list}")

        # Built lazily on first use so it wraps the FSDP-wrapped modules.
        self.inference_pipeline: BidirectionalTrainingPipeline = None

        # DMD hyperparameters.
        self.num_train_timestep = args.num_train_timestep  # e.g. 1000
        self.min_step = int(0.02 * self.num_train_timestep)  # e.g. 20
        self.max_step = int(0.98 * self.num_train_timestep)  # e.g. 980

        self.real_guidance_scale = args.guidance_scale  # e.g. 6.0

        self.timestep_shift = getattr(args, "timestep_shift", 5.0)
        self.min_score_timestep = getattr(args, "min_score_timestep", 0)


    def _initialize_models(self, args, device):
        """Instantiate the generator, score networks, VAE, reward model and scheduler.

        Args:
            args: Run configuration. ``real_path``, ``fake_path`` and
                ``generator_path`` are required. ``real_name``, ``fake_name``,
                ``generator_name``, ``model_kwargs`` and ``video_align_path``
                are optional.
            device: Device handed to the reward model and used to hold the
                scheduler timesteps.
        """
        default_name = "Wan2.1-I2V-14B-720P"
        self.real_model_name = getattr(args, "real_name", default_name)
        self.fake_model_name = getattr(args, "fake_name", default_name)
        self.generator_name = getattr(args, "generator_name", default_name)

        # Required settings: a missing attribute raises immediately.
        self.real_model_path = getattr(args, "real_path")
        self.fake_model_path = getattr(args, "fake_path")
        self.generator_path = getattr(args, "generator_path")

        print("--------- loading generator model")
        self.generator = WanDiffusionWrapper(
            **getattr(args, "model_kwargs", {}),  # e.g. timestep_shift: 5.0
            model_name=self.generator_name, model_path=self.generator_path, wan_type="gen"
        )
        self.generator.model.requires_grad_(True)

        print("--------- loading real_score model")
        self.real_score = WanDiffusionWrapper(
            model_name=self.real_model_name, model_path=self.real_model_path, wan_type="real"
        )
        # The real score network is frozen.
        self.real_score.model.requires_grad_(False)

        print("--------- loading fake_score model")
        self.fake_score = WanDiffusionWrapper(
            model_name=self.fake_model_name, model_path=self.fake_model_path, wan_type="fake"
        )
        self.fake_score.model.requires_grad_(True)

        print("loading WanVAEWrapper ...")
        self.vae = WanVAEWrapper(model_name=self.generator_name).eval()
        self.vae.requires_grad_(False)
        self.vae = self.vae.to(dtype=torch.bfloat16)

        self.generator.enable_gradient_checkpointing()
        self.fake_score.enable_gradient_checkpointing()

        print("--------- loading RewardModel ...")
        # The checkpoint location can be configured; the empty string is the
        # previous hard-coded value and remains the default.
        video_align_path = getattr(args, "video_align_path", "")
        self.reward_model = VideoAlignWrapper(video_align_path, device)
        self.reward_model.inferencer.model.requires_grad_(False)

        self.scheduler = self.generator.get_scheduler()
        self.scheduler.timesteps = self.scheduler.timesteps.to(device)
        self.scheduler.alphas_cumprod = None

    def _get_timestep(
            self,
            min_timestep: int, max_timestep: int,
            batch_size: int, num_frame: int
    ) -> torch.Tensor:
        """
        Randomly generate a timestep tensor based on the generator's task type. It uniformly samples a timestep
        from the range [min_timestep, max_timestep], and returns a tensor of shape [batch_size, num_frame].
        - uniform_timestep, it will use the same timestep for all frames.
        """
        timestep = torch.randint(
            min_timestep, max_timestep,
            [batch_size, 1],
            device=self.device, dtype=torch.long
        ).repeat(1, num_frame)

        if self.timestep_shift > 1:
            timestep = self.timestep_shift * (timestep / 1000) / (1 + (self.timestep_shift - 1) * (timestep / 1000)) * 1000

        return timestep

    def _compute_dmd_grad(
        self, 
        noisy_image_or_video: torch.Tensor, estimated_clean_image_or_video: torch.Tensor,
        timestep: torch.Tensor,
        conditional_dict: dict, unconditional_dict: dict,
        normalization: bool = True,
    ) -> Tuple[torch.Tensor, dict]:
        # Step 1: Compute the fake score
        _, pred_fake_image_cond = self.fake_score(
            noisy_image_or_video=noisy_image_or_video,
            conditional_dict=conditional_dict,
            timestep=timestep
        )

        pred_fake_image = pred_fake_image_cond
        # self.save_latent_as_video(pred_fake_image, f"gen_fake_pred_{int(timestep[0, 0].item())}")

        # Step 2: Compute the real score
        _, pred_real_image_cond = self.real_score(
            noisy_image_or_video=noisy_image_or_video,
            conditional_dict=conditional_dict,
            timestep=timestep,
        )

        _, pred_real_image_uncond = self.real_score(
            noisy_image_or_video=noisy_image_or_video,
            conditional_dict=unconditional_dict,
            timestep=timestep,
        )

        pred_real_image = pred_real_image_cond + (
            pred_real_image_cond - pred_real_image_uncond
        ) * self.real_guidance_scale
        # self.save_latent_as_video(pred_real_image, f"gen_real_pred_{int(timestep[0, 0].item())}")

        # Step 3: Compute the DMD gradient (DMD paper eq. 7).
        grad = (pred_fake_image - pred_real_image)
        # self.save_latent_as_video(grad, f"dmd_grad")

        # TODO: Change the normalizer for causal teacher
        if normalization:
            # Step 4: Gradient normalization (DMD paper eq. 8).
            p_real = (estimated_clean_image_or_video - pred_real_image)
            normalizer = torch.abs(p_real).mean(dim=[1, 2, 3, 4], keepdim=True)
            print(f"dmd normalizer = {normalizer}")
            grad = grad / normalizer
    
        if torch.isnan(grad).any() or torch.isinf(grad).any():
            grad = torch.nan_to_num(grad)
            print("WARNING: NaN/Inf in grad, nan_to_num applied")

        # self.save_latent_as_video(grad, f"dmd_grad_norm")
        
        dmd_log_dict = {
            "dmd_gradient_norm": torch.mean(torch.abs(grad)).detach(),
            "timestep": int(timestep[0, 0].item())
        }
        return grad, dmd_log_dict

    def _compute_dpo_grad(
        self, image_or_video, latent,
        conditional_dict: dict, unconditional_dict: dict, denoised_timestep_to,
        max_grad_norm = 0.8, 
        normalization: bool = True,
    ):
        video_latents = image_or_video.detach()

        timestep = self._get_timestep(
            self.min_score_timestep, self.num_train_timestep,
            video_latents.shape[0], video_latents.shape[1]
        )
        noise = torch.randn_like(image_or_video)
        batch_size, num_frame = image_or_video.shape[:2]
        noisy_video_latents = self.scheduler.add_noise(
                video_latents.flatten(0, 1),
                noise.flatten(0, 1), timestep.flatten(0, 1)
        ).detach().unflatten(0, (batch_size, num_frame))
        noisy_latent = self.scheduler.add_noise(
                latent.flatten(0, 1),
                noise.flatten(0, 1), timestep.flatten(0, 1)
        ).detach().unflatten(0, (batch_size, num_frame))

        with torch.no_grad():
            _, pred_bad_image = self.fake_score(
                noisy_image_or_video=noisy_video_latents,
                conditional_dict=conditional_dict,
                timestep=timestep
            )        

            _, pred_good_image = self.fake_score(
                noisy_image_or_video=noisy_latent,
                conditional_dict=conditional_dict,
                timestep=timestep
            )            

        grad = pred_bad_image - pred_good_image

        if normalization:
            p_real = (video_latents - pred_bad_image)
            # p_real = (video_latents - pred_new_image)
            normalizer = torch.abs(p_real).mean(dim=[1, 2, 3, 4], keepdim=True)
            print(f"grpo normalizer = {normalizer}")
            grad = grad / normalizer
        
        # 限制 grad 的绝对值均值不超过 max_grad_norm (0.8)
        raw_grad_norm = torch.mean(torch.abs(grad)).detach()
        # scale = torch.clamp(max_grad_norm / (raw_grad_norm + 1e-8), max=1.0)
        # grad = grad * scale

        # self.save_latent_as_video(grad, f"grpo_grad_norm")
        dpo_log_dict = {
            "raw_dpo_grad_norm": raw_grad_norm,
            "dpo_grad_norm": torch.mean(torch.abs(grad)).detach(),
            "dpo_timestep": int(timestep[0, 0].item())
        }
        return grad, dpo_log_dict



    def compute_distribution_matching_loss(
        self,
        image_or_video: torch.Tensor, image_or_video_end, latent,
        conditional_dict: dict, unconditional_dict: dict,
        denoised_timestep_from: int = 0, denoised_timestep_to: int = 0,
    ) -> Tuple[torch.Tensor, dict]:

        original_latent = image_or_video
        batch_size, num_frame = image_or_video.shape[:2]

        with torch.no_grad():
            # Step 1: Randomly sample timestep based on the given schedule and corresponding noise
            timestep = self._get_timestep(
                self.min_score_timestep, self.num_train_timestep,
                batch_size, num_frame,
            )
            timestep = timestep.clamp(self.min_step, self.max_step)
            # print(f"------------[generator_loss] dmd sampled timestep = {timestep[0, 0].item()}")

            noise = torch.randn_like(image_or_video)
            noisy_latent = self.scheduler.add_noise(
                image_or_video.flatten(0, 1),
                noise.flatten(0, 1), timestep.flatten(0, 1)
            ).detach().unflatten(0, (batch_size, num_frame))

            # Step 2: Compute the dmd grad
            dmd_grad, dmd_log_dict = self._compute_dmd_grad(
                noisy_image_or_video=noisy_latent, estimated_clean_image_or_video=original_latent,
                timestep=timestep,
                conditional_dict=conditional_dict, unconditional_dict=unconditional_dict,
            )
            dpo_grad = 0
            if self.step >= 80 and self.step <= 600:
                dpo_grad, dpo_log_dict = self._compute_dpo_grad(
                    image_or_video=image_or_video_end, latent=latent,
                    conditional_dict=conditional_dict, unconditional_dict=unconditional_dict, denoised_timestep_to=denoised_timestep_to
                )
                dmd_log_dict.update(dpo_log_dict)
        def compute_t(step):
            if step < 80:
                return 0.5  
            elif step > 600:
                return 1.0  
            else:
                return 0.5
        
        t = compute_t(self.step)
        grad = (1-t) * dpo_grad + t * dmd_grad
        print(f"log_dict = {dmd_log_dict}")
        dmd_loss = 0.5 * F.mse_loss(original_latent.double(), (original_latent.double() - grad.double()).detach(), reduction="mean")
        return dmd_loss, dmd_log_dict

    def generator_loss(
        self, prompt, latent,
        image_or_video_shape,
        conditional_dict: dict, unconditional_dict: dict,
    ) -> Tuple[torch.Tensor, dict]:
        # Step 1: Unroll generator to obtain fake videos
        pred_image, denoised_timestep_from, denoised_timestep_to = self._run_generator(
            image_or_video_shape=image_or_video_shape, conditional_dict=conditional_dict,
        )
        pred_image_end = self.get_gen_end_image(pred_image.detach(), conditional_dict, denoised_timestep_to)
        # print(f"------------[generator_loss] timestep from {denoised_timestep_from} to {denoised_timestep_to}")
        # Step 1.1 save video and get reward
        # self.save_latent_as_video(pred_image, f"gen_pred_{denoised_timestep_to}", prompt)
        reward_log_dict = self.save_latent_and_reward(pred_image_end, f"gen_gen_pred_{denoised_timestep_from}_{denoised_timestep_to}", prompt)

        # Step 2: Compute the DMD loss
        dmd_loss, dmd_log_dict = self.compute_distribution_matching_loss(
            image_or_video=pred_image, image_or_video_end=pred_image_end, latent=latent,
            conditional_dict=conditional_dict, unconditional_dict=unconditional_dict,
            denoised_timestep_from=denoised_timestep_from, denoised_timestep_to=denoised_timestep_to
        )

        del pred_image, denoised_timestep_from, denoised_timestep_to
        dmd_log_dict.update(reward_log_dict)
        return dmd_loss, dmd_log_dict

    def critic_loss(
        self, prompt,
        image_or_video_shape,
        conditional_dict: dict, unconditional_dict: dict,
    ) -> Tuple[torch.Tensor, dict]:
        # Step 1: Run generator on backward simulated noisy input
        with torch.no_grad():
            generated_image, denoised_timestep_from, denoised_timestep_to = self._run_generator(
                image_or_video_shape=image_or_video_shape, conditional_dict=conditional_dict
            )
        # print(f"------------[critic_loss] timestep from {denoised_timestep_from} to {denoised_timestep_to}")
        # Step 1.1 save video
        # self.save_latent_as_video(generated_image, f"critic_gen_{denoised_timestep_from}_{denoised_timestep_to}", prompt)
        # Step 2: Compute the fake prediction
        critic_timestep = self._get_timestep(
            self.min_score_timestep, self.num_train_timestep,
            image_or_video_shape[0], image_or_video_shape[1]  # B, F
        )
        critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)

        critic_noise = torch.randn_like(generated_image)
        noisy_generated_image = self.scheduler.add_noise(
            generated_image.flatten(0, 1), critic_noise.flatten(0, 1),
            critic_timestep.flatten(0, 1)
        ).unflatten(0, image_or_video_shape[:2])
        # print(f"------------[critic_loss] critic sampled timestep = {critic_timestep[0, 0].item()}")

        pred_flow, pred_fake_image = self.fake_score(
            noisy_image_or_video=noisy_generated_image, conditional_dict=conditional_dict,
            timestep=critic_timestep
        )
        # self.save_latent_as_video(pred_fake_image, f"critic_gen_pred_{critic_timestep[0, 0].item()}")

        # Step 3: Compute the denoising loss for the fake critic
        denoising_loss = torch.mean((pred_flow - (critic_noise - generated_image)) ** 2)
        grad = pred_flow - (critic_noise - generated_image)
        # self.save_latent_as_video(grad, f"critic_gen_grad_{denoised_timestep_from}_{denoised_timestep_to}")

        # Step 5: Debugging Log
        critic_log_dict = {
            "critic_timestep": critic_timestep.detach(),
            "critic_sft_gradient_norm": torch.mean(torch.abs(grad)).detach(),
        }

        if torch.isnan(denoising_loss).any() or torch.isinf(denoising_loss).any():
            denoising_loss = torch.nan_to_num(denoising_loss)
            print("WARNING: NaN/Inf in denoising_loss, nan_to_num applied")

        return denoising_loss, critic_log_dict

    def save_latent_as_video(self, video_latents, task_type, prompt=None, save_dir=None):
        if self.step % 5 == 4 and self.global_rank < 8:
        # if self.global_rank < 4:
            if save_dir is None:
                save_dir = self.gen_log_videos_dir

            # 1. 解码视频
            video_latents = video_latents.detach()
            self.vae.to(device=self.device, dtype=self.dtype)
            videos = self.vae.decode_to_pixel(video_latents)
            videos = (videos * 0.5 + 0.5).clamp(0, 1)
            videos = rearrange(videos, 'b t c h w -> b t h w c').cpu() * 255.0
            self.vae.model.clear_cache()
            self.vae.to(device='cpu')

            # 2. 保存 mp4
            base_name = f'step{self.step}-{task_type}-rank{self.global_rank}'
            video_path = os.path.join(save_dir, f'{base_name}.mp4')
            write_video(video_path, videos[0], fps=16)

            # 3. 保存 prompt 到同名 txt
            if prompt is not None:
                txt_path = os.path.join(save_dir, f'{base_name}.txt')
                with open(txt_path, 'w', encoding='utf-8') as f:
                    f.write(prompt[0])

            torch.cuda.empty_cache()

    def save_latent_and_reward(self, video_latents, task_type, prompt):
        """Decode latents to an mp4, score it with the reward model and save the prompt.

        The first video of the batch is written to the training-video
        directory and that file is scored. The first entry of ``prompt`` is
        written to a ``.txt`` file with the same base name.

        Returns:
            Dict with ``"avg_reward"`` (0.1 * MQ + 0.1 * VQ + 0.8 * TA) and
            the individual ``"mq_reward"``, ``"vq_reward"``, ``"ta_reward"``.
        """
        out_dir = self.gen_train_videos_dir
        base_name = f'step{self.step}-{task_type}-rank{self.global_rank}'
        video_path = os.path.join(out_dir, f'{base_name}.mp4')
        txt_path = os.path.join(out_dir, f'{base_name}.txt')

        # 1. Decode the latents. The VAE is moved to the compute device for
        #    the decode and back to the CPU afterwards.
        self.vae.to(device=self.device, dtype=self.dtype)
        frames = self.vae.decode_to_pixel(video_latents.detach())
        frames = (frames * 0.5 + 0.5).clamp(0, 1)
        frames = rearrange(frames, 'b t c h w -> b t h w c').cpu() * 255.0
        self.vae.model.clear_cache()
        self.vae.to(device='cpu')

        # 2. Write the mp4.
        write_video(video_path, frames[0], fps=16)
        torch.cuda.empty_cache()

        # 3. Score the saved file.
        scores = self.reward_model.inferencer.reward([video_path], prompt, use_norm=False)[0]
        mq, vq, ta = scores['MQ'], scores['VQ'], scores['TA']
        reward_log_dict = {
            "avg_reward" : 0.1 * mq + 0.1 * vq + ta * 0.8,
            "mq_reward" : mq,
            "vq_reward" : vq,
            "ta_reward" : ta
        }

        # 4. Store the prompt next to the video.
        with open(txt_path, 'w', encoding='utf-8') as f:
            f.write(prompt[0])

        return reward_log_dict

    def _run_generator(
        self,
        image_or_video_shape, conditional_dict: dict
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        noise_shape = image_or_video_shape  # 1, 21, 16, tgt_h // 8, tgt_w // 8
        # _consistency_backward_simulation, run random times (1-4) add_noise-de_noise to get a pred_image, only con_forward
        pred_image_or_video, denoised_timestep_from, denoised_timestep_to = self._consistency_backward_simulation(  
            noise=torch.randn(*noise_shape, device=self.device, dtype=self.dtype),
            **conditional_dict
        )
        pred_image_or_video = pred_image_or_video.to(self.dtype)
        
        return pred_image_or_video, denoised_timestep_from, denoised_timestep_to

    def _consistency_backward_simulation(
        self,
        noise: torch.Tensor,
        **conditional_dict: dict
    ) -> torch.Tensor:
        if self.inference_pipeline is None:
            self._initialize_inference_pipeline()

        return self.inference_pipeline.inference_with_trajectory(
            noise=noise, clip_fea=None, y=None, **conditional_dict
        )

    def _initialize_inference_pipeline(self):
        """
        Create ``self.inference_pipeline`` from this module's own components.

        Called lazily by the first backward simulation. The pipeline receives
        the generator and scheduler objects held by this module rather than
        fresh copies (per the original note, so the FSDP-wrapped modules are
        reused to save memory).
        """
        pipeline_kwargs = dict(
            model_name=self.generator_name,
            denoising_step_list=self.denoising_step_list,
            scheduler=self.scheduler,
            generator=self.generator,
        )
        self.inference_pipeline = BidirectionalTrainingPipeline(**pipeline_kwargs)

    def gen_video(   
        self, idx,
        prompt, image_or_video_shape,
        conditional_dict, unconditional_dict
    ):
        """Sample a video with the generator and write it to the eval directory.

        Starting from random noise, every denoising step but the last predicts
        a clean sample and re-noises it to the next step; the last step's
        prediction is decoded and the first video of the batch is saved.

        Args:
            idx: Index used in the output file name.
            prompt: Not used in this method.
            image_or_video_shape: Shape of the initial noise.
            conditional_dict: Conditioning passed to the generator.
            unconditional_dict: Not used in this method.

        Returns:
            Path of the written mp4 file.
        """
        with torch.no_grad():
            latents = torch.randn(*image_or_video_shape, device=self.device, dtype=self.dtype)
            batch_frames = latents.shape[:2]
            # Per-frame ones; multiplied by a step value to build timestep tensors.
            ones = torch.ones(batch_frames, dtype=torch.long, device=latents.device)

            steps = self.denoising_step_list
            for current_timestep, following_timestep in zip(steps[:-1], steps[1:]):
                _, clean = self.generator(
                    noisy_image_or_video=latents, conditional_dict=conditional_dict,
                    timestep=ones * current_timestep,
                )  # [B, F, C, H, W]

                flat_clean = clean.flatten(0, 1)
                next_timestep = following_timestep * ones
                latents = self.scheduler.add_noise(
                    flat_clean, torch.randn_like(flat_clean),
                    next_timestep.flatten(0, 1)
                ).unflatten(0, batch_frames)

            _, pred_image_or_video = self.generator(
                noisy_image_or_video=latents, conditional_dict=conditional_dict,
                timestep=ones * steps[-1],
            )  # [B, F, C, H, W]

        # Decode on the compute device, then park the VAE on the CPU again.
        self.vae.to(device=self.device, dtype=self.dtype)
        frames = self.vae.decode_to_pixel(pred_image_or_video)
        frames = (frames * 0.5 + 0.5).clamp(0, 1)
        frames = rearrange(frames, 'b t c h w -> b t h w c').cpu() * 255.0
        self.vae.model.clear_cache()
        self.vae.to(device='cpu')

        file_name = f'step{self.step}-rank{self.global_rank}-{idx}.mp4'
        video_path = os.path.join(self.gen_eval_videos_dir, file_name)
        write_video(video_path, frames[0], fps=16)
        torch.cuda.empty_cache()

        return video_path

    def get_gen_end_image(self, pred_image_or_video_raw, conditional_dict, denoised_timestep_to):
        left_step = denoised_timestep_to // 250
        print(f"denoised_timestep_to = {denoised_timestep_to}, so left_step = {left_step}")

        pred_image_or_video = pred_image_or_video_raw.clone()
        noise = torch.randn(*(pred_image_or_video.shape), device=self.device, dtype=self.dtype)  # 1, 31, 16, tgt_h // 8, tgt_w // 8
        
        for current_timestep in self.denoising_step_list[-left_step:]:
            with torch.no_grad():
                cur_timestep = current_timestep * torch.ones(noise.shape[:2], dtype=torch.long, device=noise.device)

                noisy_image_or_video = self.scheduler.add_noise(
                    pred_image_or_video.flatten(0, 1),
                    torch.randn_like(pred_image_or_video.flatten(0, 1)),
                    cur_timestep.flatten(0, 1)
                ).unflatten(0, noise.shape[:2])

                _, pred_image_or_video = self.generator(
                    noisy_image_or_video=noisy_image_or_video, conditional_dict=conditional_dict,
                    timestep=torch.ones(noise.shape[:2], dtype=torch.long, device=noise.device) * current_timestep
                )  # [B, F, C, H, W]

        return pred_image_or_video