from videox_fun.models.wan_image_encoder import CLIPModel
from videox_fun.models.wan_vae import AutoencoderKLWan
from videox_fun.models.wan_text_encoder import WanT5EncoderModel
from videox_fun.pipeline.flow_match import get_sigma_from_timestep
from transformers import AutoTokenizer
import torch.nn.functional as F
from typing import Tuple
from torch import nn
import torch
import os
import math
import random
from omegaconf import OmegaConf
from diffusers import FlowMatchEulerDiscreteScheduler
from videox_fun.models import WanTransformer3DModel, CausalWanModel, Discriminator
import contextlib
from videox_fun.pipeline.self_forcing_RL import SelfForcingRLPipeline


def get_diffusion_wrapper(model_name):
    """Return the transformer class registered under ``model_name``.

    Recognised names are ``"wan"`` and ``"causal_wan"``; any other name
    raises ``KeyError``.
    """
    wrappers = dict(
        wan=WanTransformer3DModel,
        causal_wan=CausalWanModel,
    )
    return wrappers[model_name]

def load_state_dict(model, transformer_path):
    """Load checkpoint weights into ``model`` when a path is provided.

    Args:
        model: module exposing ``load_state_dict(state_dict, strict=False)``
            that returns ``(missing_keys, unexpected_keys)``.
        transformer_path: path to a ``*safetensors`` file or a file readable
            by ``torch.load``. ``None`` makes this function a no-op.

    Raises:
        AssertionError: if the checkpoint contains keys the model does not
            have. Keys missing from the checkpoint are only reported.
    """
    if transformer_path is None:
        return

    print(f"From checkpoint: {transformer_path}")
    if transformer_path.endswith("safetensors"):
        from safetensors.torch import load_file
        state_dict = load_file(transformer_path)
    else:
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only point this at trusted checkpoints; consider
        # ``weights_only=True`` where the installed torch supports it.
        state_dict = torch.load(transformer_path, map_location="cpu")
    # Some checkpoints nest the weights under a "state_dict" entry.
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict

    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
    if len(unexpected) != 0:
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        raise AssertionError(
            f"{len(unexpected)} unexpected keys in checkpoint "
            f"{transformer_path}: {list(unexpected)[:10]}"
        )

def filter_kwargs(cls, kwargs):
    """Keep only the entries of ``kwargs`` that ``cls.__init__`` names.

    Parameter names are read from the signature of ``cls.__init__``;
    ``self`` and ``cls`` are never treated as accepted names.
    """
    import inspect

    accepted = set(inspect.signature(cls.__init__).parameters)
    accepted -= {'self', 'cls'}

    kept = {}
    for name, value in kwargs.items():
        if name in accepted:
            kept[name] = value
    return kept


class wan_RL(nn.Module):
    """Container module holding the models used for Wan RL sampling.

    It owns a trainable generator, a frozen reference copy loaded from the
    same checkpoint, the frozen text encoder / VAE / CLIP image encoder, and
    a lazily-built ``SelfForcingRLPipeline`` that :meth:`sampling` drives.
    """

    def __init__(self, args, device, dtype):
        """
        Build all sub-models and read the sampling hyperparameters.

        Input:
            - args: configuration object read with both attribute access and
              key access (``args['transformer_additional_kwargs']``); its nested
              sections are passed through ``OmegaConf.to_container``.
            - device: device every sub-model and helper tensor is moved to.
            - dtype: dtype every sub-model is cast to.
        """
        super().__init__()

        # Step 1: Initialize all models

        self.generator_model_name = getattr(
            args, "generator_name", args.model_name)

        self.generator_task_type = getattr(
            args, "generator_task_type", args.generator_task)
        self.ref_task_type = getattr(
            args, "generator_task_type", args.generator_task)

        # Trainable generator.
        self.generator = self._load_transformer(args)
        self.generator.requires_grad_(True)
        self.generator.train()
        self.generator.to(device=device, dtype=dtype)

        self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
        self.num_last_frames_with_grad = getattr(args, "num_last_frames_with_grad", 21)

        if self.num_frame_per_block > 1:
            self.generator.num_frame_per_block = self.num_frame_per_block

        # Frozen reference model: same architecture and checkpoint as the generator.
        self.ref_model = self._load_transformer(args)
        self.ref_model.requires_grad_(False).eval()
        self.ref_model.to(device=device, dtype=dtype)

        # NOTE: no real_score / fake_score / discriminator networks are built
        # here. `_compute_kl_grad` needs the first two and raises unless they
        # have been attached to the instance by other code.

        if args.gradient_checkpointing:
            self.generator.enable_gradient_checkpointing()

        self.tokenizer = AutoTokenizer.from_pretrained(os.path.join(args.model_path, 'google/umt5-xxl/'), seq_len=512, clean='whitespace')
        self.text_encoder = WanT5EncoderModel.from_pretrained(
            os.path.join(args.model_path, args.text_encoder_kwargs.get('text_encoder_subpath', 'text_encoder')),
            additional_kwargs=OmegaConf.to_container(args.text_encoder_kwargs))
        self.text_encoder.requires_grad_(False).eval()
        self.text_encoder.to(device=device, dtype=dtype)

        self.vae = AutoencoderKLWan.from_pretrained(os.path.join(args.model_path, 'Wan2.1_VAE.pth'))
        self.vae.requires_grad_(False).eval()
        self.vae.to(device=device, dtype=dtype)

        self.clip_image_encoder = CLIPModel.from_pretrained(os.path.join(args.model_path, 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'))
        self.clip_image_encoder.requires_grad_(False).eval()
        self.clip_image_encoder.to(device=device, dtype=dtype)

        # Built on first use by `_consistency_backward_simulation`.
        self.inference_pipeline = None

        # Step 2: Initialize all hyperparameters

        self.denoising_step_list = torch.tensor(
            args.denoising_step_list, dtype=torch.long, device=device)
        self.num_train_timestep = args.scheduler_kwargs.num_train_timesteps
        self.min_step = int(0.02 * self.num_train_timestep)
        self.max_step = int(0.98 * self.num_train_timestep)
        self.real_guidance_scale = args.real_guidance_scale
        self.timestep_shift = args.scheduler_kwargs.get('shift', 1.0)

        self.args = args
        self.device = device
        self.dtype = dtype
        # Only the scheduler_kwargs entries that the scheduler's __init__ names are forwarded.
        self.scheduler = FlowMatchEulerDiscreteScheduler(
            **filter_kwargs(FlowMatchEulerDiscreteScheduler, args.scheduler_kwargs)
        )
        print(f"DMD scheduler: {self.scheduler}")
        denoising_sigma_list = get_sigma_from_timestep(self.scheduler, self.denoising_step_list, self.dtype)
        print(f"DMD timestep_shift = {self.timestep_shift}, denoising_step_list = {self.denoising_step_list}, denoising_sigma_list = {denoising_sigma_list}")

        self.patch_size = self.generator.config.patch_size

    def _load_transformer(self, args):
        """Instantiate the configured transformer class and load the generator checkpoint into it."""
        model = get_diffusion_wrapper(model_name=self.generator_model_name).from_pretrained(
            args.model_path,
            transformer_additional_kwargs=OmegaConf.to_container(args['transformer_additional_kwargs']))
        load_state_dict(model, args.generator_transformer_path)
        return model

    def _process_timestep(self, timestep: torch.Tensor, type: str) -> torch.Tensor:
        """
        Pre-process the randomly generated timestep based on the generator's task type.
        Input:
            - timestep: [batch_size, num_frame] tensor containing the randomly generated timestep.
            - type: a string indicating the type of the current model (image, bidirectional_video, or causal_video).
        Output Behavior:
            - image: check that the second dimension (num_frame) is 1.
            - bidirectional_video: broadcast the timestep to be the same for all frames.
            - causal_video: broadcast the timestep to be the same for all frames **in a block**.
        Note:
            The video branches write into ``timestep`` in place (the causal branch
            does so whenever ``reshape`` can return a view), so callers should not
            expect the input to keep its original values.
        """
        if type == "image":
            assert timestep.shape[1] == 1
            return timestep
        elif type == "bidirectional_video":
            # Every frame takes the value of frame 0 of its own sample.
            timestep[:, 1:] = timestep[:, :1]
            return timestep
        elif type == "causal_video":
            # make the noise level the same within every motion block
            timestep = timestep.reshape(
                timestep.shape[0], -1, self.num_frame_per_block)
            timestep[:, :, 1:] = timestep[:, :, 0:1]
            timestep = timestep.reshape(timestep.shape[0], -1)
            return timestep
        else:
            raise NotImplementedError(f"Unsupported model type {type}")

    def _compute_kl_grad(
        self, noisy_image_or_video: torch.Tensor,
        estimated_clean_image_or_video: torch.Tensor,
        timestep: torch.Tensor,
        conditional_dict: dict, unconditional_dict: dict,
        inpaint_latents: torch.Tensor, clip_context: torch.Tensor,
        normalization: bool = True
    ) -> Tuple[torch.Tensor, dict]:
        """
        Compute the KL grad (eq 7 in https://arxiv.org/abs/2311.18828).
        Input:
            - noisy_image_or_video: a tensor with shape [B, F, C, H, W] where the number of frame is 1 for images.
            - estimated_clean_image_or_video: a tensor with shape [B, F, C, H, W] representing the estimated clean image or video.
            - timestep: a tensor with shape [B, F] containing the randomly generated timestep.
            - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
            - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
            - inpaint_latents: tensor passed to the score networks as ``y`` after swapping dims 1 and 2.
            - clip_context: value passed to the score networks as ``clip_fea``.
            - normalization: a boolean indicating whether to normalize the gradient.
        Output:
            - kl_grad: a tensor representing the KL grad.
            - kl_log_dict: a dictionary containing the intermediate tensors for logging.
        Raises:
            - AttributeError: if ``fake_score`` or ``real_score`` is not attached to this
              module (``__init__`` does not create them).
        """
        absent = [name for name in ("fake_score", "real_score") if not hasattr(self, name)]
        if absent:
            raise AttributeError(
                f"{self.__class__.__name__}._compute_kl_grad requires {absent}, "
                "which __init__ does not create; attach them before calling."
            )

        seq_len = noisy_image_or_video.shape[1] * noisy_image_or_video.shape[3] * noisy_image_or_video.shape[4] // self.patch_size[1] // self.patch_size[2]
        # The score networks take dims 1 and 2 swapped relative to the [B, F, C, H, W] inputs.
        model_input = noisy_image_or_video.permute(0, 2, 1, 3, 4)
        model_y = inpaint_latents.permute(0, 2, 1, 3, 4)

        # Step 1: Compute the fake score
        pred_fake_image = self.fake_score(
            x=model_input,
            context=conditional_dict['prompt_embeds'],
            t=timestep[:, 0],
            seq_len=seq_len,
            y=model_y,
            clip_fea=clip_context
        ).permute(0, 2, 1, 3, 4)
        pred_fake_image = self.convert_flow_pred_to_x0(pred_fake_image, noisy_image_or_video, timestep)

        # Step 2: Compute the real score
        # We compute the conditional and unconditional prediction
        # and add them together to achieve cfg (https://arxiv.org/abs/2207.12598)
        pred_real_image_cond = self.real_score(
            x=model_input,
            context=conditional_dict['prompt_embeds'],
            t=timestep[:, 0],
            seq_len=seq_len,
            y=model_y,
            clip_fea=clip_context
        ).permute(0, 2, 1, 3, 4)
        pred_real_image_cond = self.convert_flow_pred_to_x0(pred_real_image_cond, noisy_image_or_video, timestep)

        pred_real_image_uncond = self.real_score(
            x=model_input,
            context=unconditional_dict['prompt_embeds'],
            t=timestep[:, 0],
            seq_len=seq_len,
            y=model_y,
            clip_fea=clip_context
        ).permute(0, 2, 1, 3, 4)
        pred_real_image_uncond = self.convert_flow_pred_to_x0(pred_real_image_uncond, noisy_image_or_video, timestep)

        pred_real_image = pred_real_image_uncond + (
            pred_real_image_cond - pred_real_image_uncond
        ) * self.real_guidance_scale

        # Step 3: Compute the DMD gradient (DMD paper eq. 7).
        grad = (pred_fake_image - pred_real_image)

        # TODO: Change the normalizer for causal teacher
        if normalization:
            # Step 4: Gradient normalization (DMD paper eq. 8).
            p_real = (estimated_clean_image_or_video - pred_real_image)
            normalizer = torch.abs(p_real).mean(dim=[1, 2, 3, 4], keepdim=True)
            grad = grad / normalizer
        # A zero normalizer yields inf/nan entries; nan_to_num replaces them with finite values.
        grad = torch.nan_to_num(grad)

        return grad, {
            "dmdtrain_clean_latent": estimated_clean_image_or_video.detach(),
            "dmdtrain_noisy_latent": noisy_image_or_video.detach(),
            "dmdtrain_pred_real_image": pred_real_image.detach(),
            "dmdtrain_pred_fake_image": pred_fake_image.detach(),
            "dmdtrain_gradient_norm": torch.mean(torch.abs(grad)).detach(),
            "timestep": timestep.detach()
        }

    def add_noise(self, latents, noise, timestep):
        """
        Interpolate between ``latents`` and ``noise``: (1 - sigma) * latents + sigma * noise,
        where sigma is looked up from the scheduler for ``timestep``.
        """
        # NOTE(review): the original comment here said timestep == 0 is special-cased
        # so that sigma is 0 as well. That handling is not visible in this method and
        # presumably lives in get_sigma_from_timestep -- verify.
        sigma = get_sigma_from_timestep(self.scheduler, timestep, latents.dtype)
        # One sigma per leading entry, broadcast over three trailing dims.
        # assumes latents/noise are 4-D with one row per timestep entry -- TODO confirm
        sigma = sigma.reshape(-1, 1, 1, 1)
        noisy_latents = (1.0 - sigma) * latents + sigma * noise
        return noisy_latents

    def convert_flow_pred_to_x0(self, flow_pred, xt, timestep):
        """
        Turn a flow prediction into an x0 estimate: x0 = xt - sigma * flow_pred.
        Input:
            - flow_pred: 5-D when ``timestep`` is 2-D ([batch, frames]), otherwise 4-D.
            - xt: the noisy sample ``flow_pred`` was predicted from; its dtype is used for sigma.
            - timestep: per-sample (1-D) or per-sample-per-frame (2-D) timesteps.
        """
        num_frames = 0
        if timestep.ndim == 2:
            batch_size, num_frames = timestep.shape
            timestep = timestep.flatten(0, 1)
        sigma = get_sigma_from_timestep(self.scheduler, timestep, xt.dtype)
        if num_frames > 0:
            assert flow_pred.ndim == 5
            sigma = sigma.reshape(batch_size, num_frames, 1, 1, 1)
        else:
            assert flow_pred.ndim == 4
            sigma = sigma.reshape(-1, 1, 1, 1)
        x0_pred = xt - sigma * flow_pred
        return x0_pred

    def encode_text(self, text):
        """
        Tokenize and encode one prompt or a list of prompts.
        Output:
            - a dict with key "prompt_embeds" holding the text-encoder output, with
              every position past each prompt's attended length set to zero.
        """
        if isinstance(text, str):
            text = [text]
        prompt_ids = self.tokenizer(
            text,
            padding="max_length",
            max_length=512,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt"
        )
        text_input_ids = prompt_ids.input_ids.to(self.device)
        prompt_attention_mask = prompt_ids.attention_mask.to(self.device)
        # Number of attended (non-padding) tokens per prompt.
        seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
        context = self.text_encoder(text_input_ids, prompt_attention_mask)[0]

        for u, v in zip(context, seq_lens):
            u[v:] = 0.0  # set padding to 0.0
        res = {
            "prompt_embeds": context.to(self.device, self.dtype),
        }
        return res

    def _initialize_inference_pipeline(self):
        """
        Lazy initialize the inference pipeline on first use.
        The pipeline receives this module's own generator, reference model, VAE and
        scheduler objects rather than copies.
        """
        self.inference_pipeline = SelfForcingRLPipeline(
            denoising_step_list=self.denoising_step_list,
            scheduler=self.scheduler,
            generator=self.generator,
            ref_model=self.ref_model,
            vae=self.vae,
            image_processor=None,
            num_frame_per_block=self.num_frame_per_block,
            num_last_frames_with_grad=self.num_last_frames_with_grad,
        )

    def _consistency_backward_simulation(self, noise: torch.Tensor, **conditional_dict: dict) -> torch.Tensor:
        """
        Simulate the generator's input from noise to avoid training/inference mismatch.
        See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
        Here we use the consistency sampler (https://arxiv.org/abs/2303.01469)
        Input:
            - noise: a tensor sampled from N(0, 1) with shape [B, F, C, H, W] where the number of frame is 1 for images.
            - conditional_dict: keyword arguments forwarded unchanged to the pipeline (e.g. text embeddings, image embeddings).
        Output:
            - whatever ``inference_pipeline.inference_with_trajectory`` returns.
              NOTE(review): `_run_sampling` unpacks this into three values, so the
              declared ``torch.Tensor`` return annotation looks stale -- confirm
              against SelfForcingRLPipeline.
        """
        if self.inference_pipeline is None:
            self._initialize_inference_pipeline()

        return self.inference_pipeline.inference_with_trajectory(noise=noise, **conditional_dict)

    def _run_sampling(self, image_or_video_shape, conditional_dict: dict, unconditional_dict: dict, clean_latent: torch.Tensor,
                inpaint_latents: torch.Tensor, clip_context: torch.Tensor) -> tuple:
        """
        Sample fresh noise and run the inference pipeline on it.
        Input:
            - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
            - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
            - unconditional_dict: accepted for interface compatibility; not used by this method.
            - clean_latent: a tensor containing the clean latents [B, F, C, H, W]; only its first frame
              (``clean_latent[:, :1]``) is passed on, as ``initial_latent``.
            - inpaint_latents, clip_context: forwarded to the pipeline unchanged.
        Output:
            - (images, sample, cache_sample): the three values returned by the pipeline.
        """
        # Sample noise and backward simulate the generator's input
        images, sample, cache_sample = self._consistency_backward_simulation(
            noise=torch.randn(image_or_video_shape, device=self.device, dtype=self.dtype),
            initial_latent=clean_latent[:, :1],
            inpaint_latents=inpaint_latents,
            clip_context=clip_context,
            **conditional_dict,
        )
        return images, sample, cache_sample

    def sampling(self, image_or_video_shape, conditional_dict: dict, unconditional_dict: dict, clean_latent: torch.Tensor,
                inpaint_latents: torch.Tensor=None, clip_context: torch.Tensor=None) -> tuple:
        """
        Public entry point: run `_run_sampling` with the given arguments and return
        its (images, sample, cache_sample) result unchanged.
        """
        images, sample, cache_sample = self._run_sampling(
            image_or_video_shape=image_or_video_shape,
            conditional_dict=conditional_dict,
            unconditional_dict=unconditional_dict,
            clean_latent=clean_latent,
            inpaint_latents=inpaint_latents,
            clip_context=clip_context,
        )
        return images, sample, cache_sample

    