
from typing import Tuple
from einops import rearrange
from torch import nn
import torch.distributed as dist
import torch

from utils.loss import get_denoising_loss
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper, WanCLIPEncoder

import os 

from pipeline import BidirectionalTrainingPipeline

import torch.nn.functional as F
from typing import Optional, Tuple
import torch

from torchvision.io import write_video
from einops import rearrange
import os

from utils.rm_video_align_wrapper import VideoAlignWrapper

class T2V_FLOW_DPO(nn.Module):
    def __init__(self, args, device):
        super().__init__()
        
        self._initialize_models(args, device)

        self.global_rank = dist.get_rank()
        self.is_main_process = self.global_rank == 0
        
        self.step = 0
        
        self.gen_train_videos_dir = os.path.join(args.logdir, "gen_train_videos")
        self.gen_eval_videos_dir = os.path.join(args.logdir, "gen_eval_videos")
        self.gen_log_videos_dir = os.path.join(args.logdir, "gen_log_videos")
        if self.is_main_process:
            os.makedirs(self.gen_train_videos_dir, exist_ok=True)
            os.makedirs(self.gen_eval_videos_dir, exist_ok=True)
            os.makedirs(self.gen_log_videos_dir, exist_ok=True)

        self.device = device
        self.args = args

        self.dtype = torch.bfloat16 if args.mixed_precision else torch.float32

        self.denoising_step_list = torch.tensor(args.denoising_step_list, dtype=torch.long) # [1000, 750, 500, 250]
        timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
        print(f"------- raw self.denoising_step_list {self.denoising_step_list}")
        self.denoising_step_list = timesteps[1000 - self.denoising_step_list] # [0, 250, 500, 750]
        print(f"------- self.denoising_step_list {self.denoising_step_list}")

        # this will be init later with fsdp-wrapped modules
        self.inference_pipeline: BidirectionalTrainingPipeline = None

        # Step 2: Initialize all dmd hyperparameters
        self.num_train_timestep = args.num_train_timestep  # 1000
        self.min_step = int(0.02 * self.num_train_timestep)  # 20
        self.max_step = int(0.98 * self.num_train_timestep)  # 980

        self.real_guidance_scale = args.guidance_scale  # 6.0
        
        self.timestep_shift = getattr(args, "timestep_shift", 5.0)  # 5.0
        self.min_score_timestep = getattr(args, "min_score_timestep", 0)


    def _initialize_models(self, args, device):
        self.real_model_name = getattr(args, "real_name", "Wan2.1-I2V-14B-720P") 
        self.fake_model_name = getattr(args, "fake_name", "Wan2.1-I2V-14B-720P") 
        self.generator_name = getattr(args, "generator_name", "Wan2.1-I2V-14B-720P") 

        self.real_model_path = getattr(args, "real_path") 
        self.fake_model_path = getattr(args, "fake_path") 
        self.generator_path = getattr(args, "generator_path")

        print("--------- loading generator model")
        self.generator = WanDiffusionWrapper(
            **getattr(args, "model_kwargs", {}), # timestep_shift: 5.0
            model_name=self.generator_name, model_path=self.generator_path, wan_type="gen"
        )
        self.generator.model.requires_grad_(True)

        print("--------- loading real_score model")
        self.real_score = WanDiffusionWrapper(model_name=self.real_model_name, model_path=self.real_model_path, wan_type="real")
        self.real_score.model.requires_grad_(False)

        self.generator.enable_gradient_checkpointing()

        self.scheduler = self.generator.get_scheduler()
        self.scheduler.timesteps = self.scheduler.timesteps.to(device)
        self.scheduler.alphas_cumprod = None  # None

    def _get_timestep(
            self,
            min_timestep: int, max_timestep: int,
            batch_size: int, num_frame: int
    ) -> torch.Tensor:
        """
        Randomly generate a timestep tensor based on the generator's task type. It uniformly samples a timestep
        from the range [min_timestep, max_timestep], and returns a tensor of shape [batch_size, num_frame].
        - uniform_timestep, it will use the same timestep for all frames.
        """
        timestep = torch.randint(
            min_timestep, max_timestep,
            [batch_size, 1],
            device=self.device, dtype=torch.long
        ).repeat(1, num_frame)

        if self.timestep_shift > 1:
            timestep = self.timestep_shift * (timestep / 1000) / (1 + (self.timestep_shift - 1) * (timestep / 1000)) * 1000

        return timestep

    def _compute_dpo_grad(self, image_or_video, latent,
            conditional_dict: dict, unconditional_dict: dict, 
            denoised_timestep_to,
            max_grad_norm = 0.8, 
            beta: float = 100.0):
        """
        Flow-DPO loss
        latent : [B, F, C, H, W]  优样本
        image_or_video: [B, F, C, H, W]  差样本
        conditional_dict / clip_fea / y : 与 generator 前向保持一致
        beta       : DPO 正则项系数 按照flow-GRPO设为100
        """
        device = latent.device
        B = 1
        assert latent.size(0) == B and image_or_video.size(0) == B

        def _single_forward(latent: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
            """返回 (model_err, ref_err) 单个样本"""
            # 1) 采样 t 与噪声
            noise = torch.randn_like(latent)
            batch_size, num_frame = latent.shape[:2]
            timestep = self._get_timestep(
                self.min_score_timestep, self.num_train_timestep,
                latent.shape[0], latent.shape[1]
            )

            x_t = self.scheduler.add_noise(
                    latent.flatten(0, 1),
                    noise.flatten(0, 1), timestep.flatten(0, 1)
            ).detach().unflatten(0, (batch_size, num_frame))
            
            target = noise - latent

            # 2) 训练模型
            pred_flow, _ = self.generator(
                noisy_image_or_video=x_t,
                conditional_dict=conditional_dict,
                timestep=timestep
            )

            # 3) 参考模型
            with torch.no_grad():
                ref_flow, _ = self.real_score(
                    noisy_image_or_video=x_t,
                    conditional_dict=conditional_dict,
                    timestep=timestep
                )

            # 4) 计算 L2 误差
            # print(f"target = {target}, pred_flow = {pred_flow}")
            model_err = torch.mean((pred_flow - target) ** 2)
            ref_err   = torch.mean((ref_flow - target) ** 2) 
            return model_err, ref_err

        model_win_err, ref_win_err   = _single_forward(latent)
        # allocated = torch.cuda.memory_allocated() / 1024**3
        # print(f"[after win _single_forward allocated: {allocated:.1f} GB")
        model_lose_err, ref_lose_err = _single_forward(image_or_video)
        # allocated = torch.cuda.memory_allocated() / 1024**3
        # print(f"[after loss _single_forward allocated:  {allocated:.1f} GB")

        print(f"model_win_err={model_win_err}, ref_win_err={ref_win_err}")
        print(f"model_lose_err={model_lose_err}, ref_lose_err={ref_lose_err}")
        # 5) DPO 损失
        w_diff = model_win_err - ref_win_err
        l_diff = model_lose_err - ref_lose_err
        inside = -0.5 * beta * (w_diff - l_diff)
        loss = -torch.nn.functional.logsigmoid(inside)
        print(f"DPO Loss = {loss}")
        log_dict = {
            "w_diff": w_diff.detach(),
            "l_diff": l_diff.detach()
        }
        return loss, log_dict



    def compute_flow_dpo_loss(
        self,
        image_or_video: torch.Tensor, image_or_video_end, latent,
        conditional_dict: dict, unconditional_dict: dict,
        denoised_timestep_from: int = 0, denoised_timestep_to: int = 0,
    ) -> Tuple[torch.Tensor, dict]:

        dpo_loss, dpo_log_dict = self._compute_dpo_grad(
            image_or_video=image_or_video_end, latent=latent,
            conditional_dict=conditional_dict, unconditional_dict=unconditional_dict, denoised_timestep_to=denoised_timestep_to
        )

        return dpo_loss, dpo_log_dict

    def generator_loss(
        self, prompt, latent,
        image_or_video_shape,
        conditional_dict: dict, unconditional_dict: dict,
    ) -> Tuple[torch.Tensor, dict]:
        # Step 1: Unroll generator to obtain fake videos
        with torch.no_grad():
            pred_image, denoised_timestep_from, denoised_timestep_to = self._run_generator(
                image_or_video_shape=image_or_video_shape, conditional_dict=conditional_dict,
            )
        pred_image_end = self.get_gen_end_image(pred_image.detach(), conditional_dict, denoised_timestep_to)

        # Step 2: Compute the DMD loss
        dmd_loss, dmd_log_dict = self.compute_flow_dpo_loss(
            image_or_video=pred_image, image_or_video_end=pred_image_end, latent=latent,
            conditional_dict=conditional_dict, unconditional_dict=unconditional_dict,
            denoised_timestep_from=denoised_timestep_from, denoised_timestep_to=denoised_timestep_to
        )

        del pred_image, denoised_timestep_from, denoised_timestep_to
        # dmd_log_dict.update(reward_log_dict)
        return dmd_loss, dmd_log_dict


    def _run_generator(
        self,
        image_or_video_shape, conditional_dict: dict
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        noise_shape = image_or_video_shape  # 1, 21, 16, tgt_h // 8, tgt_w // 8
        # _consistency_backward_simulation, run random times (1-4) add_noise-de_noise to get a pred_image, only con_forward
        pred_image_or_video, denoised_timestep_from, denoised_timestep_to = self._consistency_backward_simulation(  
            noise=torch.randn(*noise_shape, device=self.device, dtype=self.dtype),
            **conditional_dict
        )
        pred_image_or_video = pred_image_or_video.to(self.dtype)
        
        return pred_image_or_video, denoised_timestep_from, denoised_timestep_to

    def _consistency_backward_simulation(
        self,
        noise: torch.Tensor,
        **conditional_dict: dict
    ) -> torch.Tensor:
        if self.inference_pipeline is None:
            self._initialize_inference_pipeline()

        return self.inference_pipeline.inference_with_trajectory(
            noise=noise, clip_fea=None, y=None, **conditional_dict
        )

    def _initialize_inference_pipeline(self):
        """
        Lazy initialize the inference pipeline during the first backward simulation run.
        Here we encapsulate the inference code with a model-dependent outside function.
        We pass our FSDP-wrapped modules into the pipeline to save memory.
        """
        self.inference_pipeline = BidirectionalTrainingPipeline(
            model_name=self.generator_name,
            denoising_step_list=self.denoising_step_list,
            scheduler=self.scheduler,
            generator=self.generator,
        )

    def gen_video(   
        self, idx,
        prompt, image_or_video_shape,
        conditional_dict, unconditional_dict
    ):
        with torch.no_grad():
            noisy_image_or_video = torch.randn(*image_or_video_shape, device=self.device, dtype=self.dtype) 
            noise = noisy_image_or_video
            for index, current_timestep in enumerate(self.denoising_step_list[:-1]):
                _, pred_image_or_video = self.generator(
                    noisy_image_or_video=noisy_image_or_video, conditional_dict=conditional_dict,
                    timestep=torch.ones(noise.shape[:2], dtype=torch.long, device=noise.device) * current_timestep, 
                )  # [B, F, C, H, W]

                next_timestep = self.denoising_step_list[index + 1] * torch.ones(noise.shape[:2], dtype=torch.long, device=noise.device)
                noisy_image_or_video = self.scheduler.add_noise(
                    pred_image_or_video.flatten(0, 1), torch.randn_like(pred_image_or_video.flatten(0, 1)),
                    next_timestep.flatten(0, 1)
                ).unflatten(0, noise.shape[:2])

            _, pred_image_or_video = self.generator(
                noisy_image_or_video=noisy_image_or_video, conditional_dict=conditional_dict,
                timestep=torch.ones(noise.shape[:2], dtype=torch.long, device=noise.device) * self.denoising_step_list[-1], 
            )  # [B, F, C, H, W]

        self.vae.to(device=self.device, dtype=self.dtype)
        videos = self.vae.decode_to_pixel(pred_image_or_video)
        videos = (videos * 0.5 + 0.5).clamp(0, 1)
        videos = rearrange(videos, 'b t c h w -> b t h w c').cpu() * 255.0
        self.vae.model.clear_cache()
        self.vae.to(device='cpu')
        
        base_name = f'step{self.step}-rank{self.global_rank}-{idx}'
        video_path = os.path.join(self.gen_eval_videos_dir, f'{base_name}.mp4')
        write_video(video_path, videos[0], fps=16)
        torch.cuda.empty_cache()

        return video_path

    def get_gen_end_image(self, pred_image_or_video_raw, conditional_dict, denoised_timestep_to):
        left_step = denoised_timestep_to // 250
        print(f"denoised_timestep_to = {denoised_timestep_to}, so left_step = {left_step}")

        pred_image_or_video = pred_image_or_video_raw.clone()
        noise = torch.randn(*(pred_image_or_video.shape), device=self.device, dtype=self.dtype)  # 1, 31, 16, tgt_h // 8, tgt_w // 8
        
        for current_timestep in self.denoising_step_list[-left_step:]:
            with torch.no_grad():
                cur_timestep = current_timestep * torch.ones(noise.shape[:2], dtype=torch.long, device=noise.device)

                noisy_image_or_video = self.scheduler.add_noise(
                    pred_image_or_video.flatten(0, 1),
                    torch.randn_like(pred_image_or_video.flatten(0, 1)),
                    cur_timestep.flatten(0, 1)
                ).unflatten(0, noise.shape[:2])

                _, pred_image_or_video = self.generator(
                    noisy_image_or_video=noisy_image_or_video, conditional_dict=conditional_dict,
                    timestep=torch.ones(noise.shape[:2], dtype=torch.long, device=noise.device) * current_timestep
                )  # [B, F, C, H, W]

        return pred_image_or_video