import argparse
import logging
import math
import os
import shutil
from pathlib import Path
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torchvision.transforms as TT
import transformers
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
from tqdm.auto import tqdm
from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer

import diffusers
# from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
from diffusers import AutoencoderKLWan, WanPipeline, WanTransformer3DModel
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.optimization import get_scheduler
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
from diffusers.training_utils import cast_training_params, free_memory
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module
from diffusers.video_processor import VideoProcessor
from diffusers.pipelines.wan.pipeline_wan import WanPipelineOutput


import imageio
import numpy as np
from typing import Union

import deepspeed
import torch
import torchvision
import torch.distributed as dist
from einops import rearrange
# import sys
# import traceback
# sys.tracebacklimit = 50  # or a larger number

# os.environ["TOKENIZERS_PARALLELISM"] = "true"
# from accelerate.utils import DeepSpeedPlugin
import torch.nn as nn




def save_videos_grid(videos: torch.Tensor, path: str, rescale=True, n_rows=6, fps=8):
    """Tile a batch of videos into one grid per frame and write the result with imageio.

    Args:
        videos: Tensor laid out as ``(b, c, t, h, w)``; this is the layout the
            ``rearrange`` pattern below requires.
        path: Output file path. Its parent directory is created when missing.
        rescale: When True, values are first mapped from ``[-1, 1]`` to
            ``[0, 1]``. When False the values are used as they are.
        n_rows: Forwarded to ``torchvision.utils.make_grid`` as ``nrow``.
        fps: Frame rate forwarded to ``imageio.mimsave``.
    """
    videos = rearrange(videos, "b c t h w -> t b c h w")
    frames = []
    for frame_batch in videos:
        grid = torchvision.utils.make_grid(frame_batch, nrow=n_rows)
        grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)  # (c, h, w) -> (h, w, c)
        if rescale:
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1
        # The uint8 conversion must happen for both values of `rescale`; it used
        # to live inside the `if`, so `rescale=False` handed raw tensors to imageio.
        # Clamping prevents uint8 wrap-around for values slightly outside [0, 1];
        # detach()/cpu() make the conversion work for grad-tracking or GPU tensors.
        grid = grid.detach().to(torch.float32).clamp(0.0, 1.0)
        frames.append((grid * 255).cpu().numpy().astype(np.uint8))

    # `os.makedirs("")` raises, so only create a directory when `path` has one.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    imageio.mimsave(path, frames, fps=fps)

# wandb is an optional dependency: import it only when diffusers reports it as available.
if is_wandb_available():
    import wandb

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
# check_min_version("0.32.0.dev0")

# Module-level logger created through accelerate's `get_logger`.
logger = get_logger(__name__)

def get_args(argv=None):
    """Build the command-line parser of this training script and parse the arguments.

    Args:
        argv: Optional sequence of argument strings. When ``None`` (the default)
            ``argparse`` reads ``sys.argv[1:]``, which is what the original
            zero-argument call did.

    Returns:
        The ``argparse.Namespace`` returned by ``parser.parse_args``.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script for wanx.")

    # ------------------------------------------------------------------
    # Model information
    # ------------------------------------------------------------------
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )

    # ------------------------------------------------------------------
    # Dataset information
    # ------------------------------------------------------------------
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--instance_data_root",
        type=str,
        default=None,
        help=("A folder containing the training data."),
    )
    parser.add_argument(
        "--video_column",
        type=str,
        default="video",
        help="The column of the dataset containing videos. Or, the name of the file in `--instance_data_root` folder containing the line-separated path to video data.",
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in '--instance_data_root' folder containing the line-separated prompts.",
    )
    parser.add_argument(
        "--id_token",
        type=str,
        default=None,
        help="Identifier token prepended to the start of each prompt if provided.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator'.",
    )
    parser.add_argument(
        "--validation_prompt_separator",
        type=str,
        default=":::",
        help="String that separates multiple validation prompts",
    )
    parser.add_argument(
        "--num_validation_videos",
        type=int,
        default=1,
        help="Number of videos that should be generated during validation per 'validation_prompt'.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=50,
        help=(
            "Run validation every X epochs. Validation consists of running the prompt 'args.validation_prompt' multiple times: 'args.num_validation_videos'."
        ),
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=6,
        help="The guidance scale to use while sampling validation videos.",
    )
    parser.add_argument(
        "--use_dynamic_cfg",
        action="store_true",
        default=False,
        help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.",
    )

    # ------------------------------------------------------------------
    # Training information
    # ------------------------------------------------------------------
    parser.add_argument(
        "--seed", type=int, default=None, help="A seed for reproducible training."
    )
    parser.add_argument(
        "--rank",
        type=int,
        default=128,
        help=("The dimension of the LoRA update matrices."),
    )
    parser.add_argument(
        "--lora_alpha",
        type=float,
        default=128,
        help=("The scaling factor to scale LoRA weight update. The actual scaling factor is `lora_alpha / rank`"),
    )
    # NOTE(review): the previous help text was copy-pasted from --lora_alpha.
    # Presumably a regularization weight -- confirm against the training loop
    # and document it properly.
    parser.add_argument(
        "--lambda_reg",
        type=float,
        default=0.,
        help="Value of the `lambda_reg` hyperparameter.",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="cogvideo-tdm-tcd-lora-fixbug",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    # NOTE(review): the help texts of --k_step, --cfg and --eta were copy-pasted
    # from --height. Their meaning is not visible in this part of the file, so
    # the help below only names the hyperparameter -- please document them.
    parser.add_argument(
        "--k_step",
        type=int,
        default=4,
        help="Value of the integer `k_step` hyperparameter.",
    )
    parser.add_argument(
        "--cfg",
        type=float,
        default=5,
        help="Value of the `cfg` hyperparameter.",
    )
    parser.add_argument(
        "--eta",
        type=float,
        default=0.9,
        help="Value of the `eta` hyperparameter.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=480,
        help="All input videos are resized to this height.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=720,
        help="All input videos are resized to this width.",
    )
    parser.add_argument(
        "--video_reshape_mode",
        type=str,
        default="center",
        help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
    )
    parser.add_argument(
        "--fps", type=int, default=8, help="All input videos will be used at this FPS."
    )
    parser.add_argument(
        "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
    )
    parser.add_argument(
        "--skip_frames_start",
        type=int,
        default=0,
        help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.",
    )
    parser.add_argument(
        "--skip_frames_end",
        type=int,
        default=0,
        help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.",
    )
    parser.add_argument(
        "--random_flip",
        action="store_true",
        help="whether to randomly flip videos horizontally",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides '--num_train_epochs'.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using '--resume_from_checkpoint'."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help="Max number of checkpoints to store.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            " --checkpointing_steps, or 'latest' to automatically select the last available checkpoint."
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--learning_rate_fake",
        type=float,
        default=1e-3,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--learning_rate_g",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            "The scheduler type to use. Choose between ['linear', 'cosine', 'cosine_with_restarts', 'polynomial',"
            "'constant', 'constant_with_warmup']"
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--enable_slicing",
        action="store_true",
        default=False,
        help="Whether or not to use VAE slicing for saving memory.",
    )
    parser.add_argument(
        "--enable_tiling",
        action="store_true",
        default=False,
        help="Whether or not to use VAE tiling for saving memory.",
    )

    # ------------------------------------------------------------------
    # Optimizer
    # ------------------------------------------------------------------
    parser.add_argument(
        "--optimizer",
        type=lambda s: s.lower(),
        default="adam",
        choices=["adam", "adamw", "prodigy"],
        help="The optimizer type to use.",
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to Adam or AdamW",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam and Prodigy optimizers.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.95,
        help="The beta2 parameter for the Adam and Prodigy optimizers.",
    )
    parser.add_argument(
        "--prodigy_beta3",
        type=float,
        default=None,
        help=(
            "Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses"
            " the value of square root of beta2."
        ),
    )
    parser.add_argument("--prodigy_decouple", action="store_true", help="Use AdamW style decoupled weight decay")
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=1e-04,
        help="Weight decay to use for unet params",
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
    )
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--prodigy_use_bias_correction",
        action="store_true",
        help="Turn on Adam's bias correction.",
    )
    parser.add_argument(
        "--prodigy_safeguard_warmup",
        action="store_true",
        help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.",
    )

    # ------------------------------------------------------------------
    # Other information
    # ------------------------------------------------------------------
    parser.add_argument(
        "--tracker_name",
        type=str,
        default=None,
        help="Project tracker name",
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the model to the Hub.",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help="The token to use to push to the Model Hub.",
    )
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help="Directory where logs are stored.",
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    # The two flags below take an explicit value: only the string 'true'
    # (case-insensitive) enables them, every other value parses to False.
    parser.add_argument(
        "--use_sparsity",
        type=lambda x: x.lower() == 'true',
        default=False,
        help="Whether to use sparsity in the model training. Set to 'true' or 'false'.",
    )
    parser.add_argument(
        "--use_lora",
        type=lambda x: x.lower() == 'true',
        default=False,
        help="Whether to use LoRA in the model training. Set to 'true' or 'false'.",
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default=None,
        help=(
            "The integration to report the results and logs to. Supported platforms are `tensorboard`,"
            " `wandb`, and `comet_ml`. Use `all` to report to all integrations. Defaults to None."
        ),
    )

    return parser.parse_args(argv)
        
def save_model_card(
    repo_id: str,
    videos=None,
    base_model: str = None,
    validation_prompt=None,
    repo_folder=None,
    fps=8,
):
    """Export the sample videos and write a ``README.md`` model card into ``repo_folder``.

    Args:
        repo_id: Repository id, used in the card title and passed to
            ``load_or_create_model_card``.
        videos: Optional iterable of videos; each one is written with
            ``export_to_video`` as ``final_video_{i}.mp4`` inside ``repo_folder``.
        base_model: Forwarded to ``load_or_create_model_card`` as ``base_model``.
        validation_prompt: Prompt used as widget text and as the card's ``prompt``.
        repo_folder: Directory that receives the videos and ``README.md``.
        fps: Frame rate used when exporting the videos.
    """
    widget_dict = []
    if videos is not None:
        for i, video in enumerate(videos):
            video_name = f"final_video_{i}.mp4"
            export_to_video(video, os.path.join(repo_folder, video_name), fps=fps)
            # The widget URL must name the file that was just written; it used to
            # reference `video_{i}.mp4`, which is never created.
            widget_dict.append(
                {"text": validation_prompt if validation_prompt else "", "output": {"url": video_name}}
            )

    model_description = f"""
# wanx LoRA - {repo_id}
"""
    model_card = load_or_create_model_card(
        repo_id_or_path=repo_id,
        from_training=True,
        license="other",
        base_model=base_model,
        prompt=validation_prompt,
        model_description=model_description,
        widget=widget_dict,
    )
    # NOTE(review): the card title says "wanx" but the tags still name cogvideox --
    # confirm which tags this model should be published under.
    tags = [
        "text-to-video",
        "diffusers-training",
        "diffusers",
        "lora",
        "cogvideox",
        "cogvideox-diffusers",
        "template:sd-lora",
    ]
    model_card = populate_model_card(model_card, tags=tags)
    model_card.save(os.path.join(repo_folder, "README.md"))

def log_validation(
    pipe,
    args,
    accelerator,
    pipeline_args,
    epoch,
    is_final_validation: bool = False,
):
    """Sample validation videos with ``pipe`` and log them to a wandb tracker if present.

    Args:
        pipe: Pipeline object. Its scheduler is replaced by a
            ``UniPCMultistepScheduler`` built from the current scheduler config,
            and it is moved to ``accelerator.device``.
        args: Parsed arguments; ``num_validation_videos``, ``seed`` and
            ``output_dir`` are read here.
        accelerator: Supplies the target device and the list of trackers.
        pipeline_args: Keyword arguments for the pipeline call. Must contain a
            ``"prompt"`` entry that is a string (it is sliced and sanitized to
            build file names).
        epoch: Not used inside this function.
        is_final_validation: Selects the ``"test"`` phase name instead of
            ``"validation"`` for file names and the logged key.

    Returns:
        A list with one entry per generated video, each being the output of
        ``VaeImageProcessor.numpy_to_pil`` for that video's frames.
    """
    logger.info(
        f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
    )
    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
    scheduler_args = {}
    if "variance_type" in pipe.scheduler.config:
        variance_type = pipe.scheduler.config.variance_type
        if variance_type in ["learned", "learned_range"]:
            variance_type = "fixed_small"
        scheduler_args["variance_type"] = variance_type

    # The overrides must be expanded as keyword arguments: passed positionally
    # the dict lands in `return_unused_kwargs`, so a non-empty dict would make
    # `from_config` return a tuple and the overrides would never be applied.
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, **scheduler_args)
    pipe.to(accelerator.device)

    # pipe.set_progress_bar_config(disable=True)

    # run inference
    # `is not None` so that an explicit seed of 0 still yields a seeded generator.
    generator = (
        torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
    )

    videos = []
    for _ in range(args.num_validation_videos):
        frames = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
        frames = torch.stack(list(frames))

        frames_np = VaeImageProcessor.pt_to_numpy(frames)
        videos.append(VaeImageProcessor.numpy_to_pil(frames_np))

    phase_name = "test" if is_final_validation else "validation"
    for tracker in accelerator.trackers:
        if tracker.name != "wandb":
            continue

        # File-name friendly prefix of the prompt; identical for every video,
        # so it is computed once. A single double quote is replaced as well
        # (the previous pattern only matched two consecutive quotes).
        prompt_slug = (
            pipeline_args["prompt"][:25]
            .replace("  ", "_")
            .replace(" ", "_")
            .replace("'", "_")
            .replace('"', "_")
            .replace("/", "_")
        )
        video_filenames = []
        for i, video in enumerate(videos):
            filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt_slug}.mp4")
            export_to_video(video, filename, fps=8)
            video_filenames.append(filename)

        tracker.log(
            {
                phase_name: [
                    wandb.Video(filename, caption=f"({i}): {pipeline_args['prompt']}")
                    for i, filename in enumerate(video_filenames)
                ]
            }
        )

    # Drops only this function's reference; the caller's reference is unaffected.
    del pipe
    free_memory()
    return videos
    
def _get_t5_prompt_embeds(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids: Optional[torch.Tensor] = None,
):
    """Encode ``prompt`` with ``text_encoder`` and tile the result per requested video.

    When ``tokenizer`` is given, the prompt(s) are tokenized with max-length
    padding and truncation to ``max_sequence_length``. Otherwise the caller
    must supply ``text_input_ids``.

    Returns:
        Tensor viewed as ``(len(prompt) * num_videos_per_prompt, seq_len, -1)``
        where ``seq_len`` is the second dimension of the encoder output.

    Raises:
        ValueError: If both ``tokenizer`` and ``text_input_ids`` are ``None``.
    """
    prompts = [prompt] if isinstance(prompt, str) else prompt
    num_prompts = len(prompts)

    if tokenizer is None:
        if text_input_ids is None:
            raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
    else:
        tokenized = tokenizer(
            prompts,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = tokenized.input_ids

    # The first element of the encoder output is used as the embedding tensor.
    encoder_output = text_encoder(text_input_ids.to(device))
    embeds = encoder_output[0].to(dtype=dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    _, tokens_per_prompt, _ = embeds.shape
    tiled = embeds.repeat(1, num_videos_per_prompt, 1)
    return tiled.view(num_prompts * num_videos_per_prompt, tokens_per_prompt, -1)


def encode_prompt(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids: Optional[torch.Tensor] = None,
):
    """Public wrapper: wrap a single prompt in a list, then defer to ``_get_t5_prompt_embeds``."""
    if isinstance(prompt, str):
        prompt = [prompt]

    return _get_t5_prompt_embeds(
        tokenizer,
        text_encoder,
        prompt=prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
        text_input_ids=text_input_ids,
    )


def compute_prompt_embeddings(
    tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
):
    """Return ``encode_prompt`` embeddings for ``prompt`` (one video per prompt).

    With ``requires_grad=False`` (the default) the call runs under
    ``torch.no_grad()``; with ``requires_grad=True`` the ambient autograd mode
    is left untouched.
    """
    encode_kwargs = dict(
        num_videos_per_prompt=1,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
    )

    if not requires_grad:
        with torch.no_grad():
            return encode_prompt(tokenizer, text_encoder, prompt, **encode_kwargs)

    return encode_prompt(tokenizer, text_encoder, prompt, **encode_kwargs)


def prepare_rotary_positional_embeddings(
    height: int,
    width: int,
    num_frames: int,
    vae_scale_factor_spatial: int = 8,
    patch_size: int = 2,
    attention_head_dim: int = 64,
    device: Optional[torch.device] = None,
    base_height: int = 480,
    base_width: int = 720,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the 3D rotary cos/sin tables via ``get_3d_rotary_pos_embed``.

    ``height``/``width`` and ``base_height``/``base_width`` are each floor-divided
    by ``vae_scale_factor_spatial * patch_size`` to obtain grid sizes; the crop
    region comes from ``get_resize_crop_region_for_grid``.

    Returns:
        ``(freqs_cos, freqs_sin)``, both moved to ``device``.
    """
    # One grid cell covers this many pixels along each spatial axis.
    pixels_per_cell = vae_scale_factor_spatial * patch_size

    grid_hw = (height // pixels_per_cell, width // pixels_per_cell)
    base_grid_w = base_width // pixels_per_cell
    base_grid_h = base_height // pixels_per_cell

    crop_region = get_resize_crop_region_for_grid(grid_hw, base_grid_w, base_grid_h)
    print('attention_head_dim=',attention_head_dim) #64
    cos_table, sin_table = get_3d_rotary_pos_embed(
        embed_dim=attention_head_dim,
        crops_coords=crop_region,
        grid_size=grid_hw,
        temporal_size=num_frames,
    )
    return cos_table.to(device=device), sin_table.to(device=device)


def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
    """Create the optimizer selected by ``args.optimizer``.

    Args:
        args: Parsed arguments. Reads ``optimizer``, ``use_8bit_adam``,
            ``learning_rate``, ``adam_beta1``, ``adam_beta2``, ``adam_epsilon``,
            ``adam_weight_decay`` and, for Prodigy, the ``prodigy_*`` options.
            ``args.optimizer`` is overwritten with ``"adamw"`` when the
            requested name is not supported.
        params_to_optimize: Parameters or parameter groups handed to the
            optimizer constructor. Outside the DeepSpeed path no ``lr`` is
            passed here, so the learning rate comes from the parameter groups
            or from the optimizer class default.
        use_deepspeed: When True, return accelerate's ``DummyOptim`` placeholder
            configured with ``args.learning_rate``.

    Returns:
        The constructed optimizer (or ``DummyOptim``).

    Raises:
        ImportError: If 8-bit Adam/AdamW is requested without ``bitsandbytes``,
            or Prodigy is requested without ``prodigyopt``.
    """
    # Use deepspeed optimizer
    if use_deepspeed:
        from accelerate.utils import DummyOptim

        return DummyOptim(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )

    # Optimizer creation. The name is normalized once so that the "supported"
    # check and the dispatch below agree on case handling.
    supported_optimizers = ["adam", "adamw", "prodigy"]
    optimizer_name = args.optimizer.lower()
    if optimizer_name not in supported_optimizers:
        logger.warning(
            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW"
        )
        args.optimizer = optimizer_name = "adamw"

    is_adam_family = optimizer_name in ("adam", "adamw")
    if args.use_8bit_adam and not is_adam_family:
        logger.warning(
            f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was set to {optimizer_name}"
        )

    if is_adam_family:
        if args.use_8bit_adam:
            # bitsandbytes is imported only when it is actually used, so a
            # Prodigy run with a stray --use_8bit_adam no longer fails on it.
            try:
                import bitsandbytes as bnb
            except ImportError as err:
                raise ImportError(
                    "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
                ) from err
            optimizer_class = bnb.optim.Adam8bit if optimizer_name == "adam" else bnb.optim.AdamW8bit
        else:
            optimizer_class = torch.optim.Adam if optimizer_name == "adam" else torch.optim.AdamW

        if optimizer_name == "adam":
            print('optimizer==adam')

        return optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )

    # Only "prodigy" is left at this point.
    try:
        import prodigyopt
    except ImportError as err:
        raise ImportError(
            "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`."
        ) from err

    if args.learning_rate <= 0.1:
        logger.warning(
            "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
        )

    return prodigyopt.Prodigy(
        params_to_optimize,
        betas=(args.adam_beta1, args.adam_beta2),
        beta3=args.prodigy_beta3,
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
        decouple=args.prodigy_decouple,
        safeguard_warmup=args.prodigy_safeguard_warmup,
        use_bias_correction=args.prodigy_use_bias_correction,
    )


def main(args):
    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token. "
            "Please use `huggingface-cli login` to authenticate with the Hub."
        )
    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
        # due to pytorch #99272, MPS does not yet support bfloat16
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
        )

    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    from accelerate.utils import DeepSpeedPlugin

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[kwargs],
    )
    print('-'*10)
    print('accelerator.state.deepspeed_plugin=',accelerator.state.deepspeed_plugin)
    print('-'*10)
    accelerator_d = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[kwargs],
    )

    # Disable AMP for MPS
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name,
                exist_ok=True,
            ).repo_id
    # TODO:1 offline to save memory
    #Prepare models and scheduler
    # tokenizer = AutoTokenizer.from_pretrained(
    #     args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
    # )

    # text_encoder = T5EncoderModel.from_pretrained(
    #     args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    # )

    # CogVideoX-2b weights are stored in float16
    # CogVideoX-5b and CogVideoX-5b-128 weights are stored in bfloat16
    #load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
    load_dtype = torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16

    transformer = WanTransformer3DModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="transformer",
        torch_dtype=load_dtype,
        revision=args.revision,
        variant=args.variant,
    )
    

    if args.use_sparsity:
        from modify_wan import set_adaptive_block_sparse_attn_wanx
        set_adaptive_block_sparse_attn_wanx(transformer)
        print('Successfully set sparsity to the transformer')
    else:
        print('Not using sparsity')



    transformer_fake = WanTransformer3DModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="transformer",
        torch_dtype=load_dtype,
        revision=args.revision,
        variant=args.variant,
    )

    transformer_real = WanTransformer3DModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="transformer",
        torch_dtype=load_dtype,
        revision=args.revision,
        variant=args.variant,
    ).to(accelerator.device)
    transformer_real.requires_grad_(False)
    vae = AutoencoderKLWan.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant,torch_dtype=torch.float32
    )
    #vae.enable_slicing()

    scheduler = UniPCMultistepScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    scheduler.set_timesteps(1000, device=accelerator.device)
    

    if args.enable_slicing:
        vae.enable_slicing()
    if args.enable_tiling:
        vae.enable_tiling()

    # We only train the additional adapter LoRA Layers
    #text_encoder.requires_grad_(False)
    transformer.requires_grad_(True)
    vae.requires_grad_(False)
    transformer_real.requires_grad_(False)
    transformer_fake.requires_grad_(False)

    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.state.deepspeed_plugin:
        # DeepSpeed is handling precision, use what's in the DeepSpeed config
        if (
            "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
            and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
        ):
            weight_dtype = torch.float16
        if (
            "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
            and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
        ):
            weight_dtype = torch.bfloat16
    else:
        if args.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif args.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
        # due to pytorch #99272, MPS does not yet support bfloat16
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
        )

    #text_encoder.to(accelerator.device, dtype=weight_dtype)
    transformer.to(accelerator.device, dtype=weight_dtype)
    transformer_fake.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=torch.float32)

    if args.gradient_checkpointing:
        transformer.enable_gradient_checkpointing()
        transformer_fake.enable_gradient_checkpointing()

    # We will add new LORA weights to the attention Layers
    transformer_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.lora_alpha,
        init_lora_weights=True,
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )

    if args.use_lora:
        transformer.add_adapter(transformer_lora_config)
        print('Successfully add LoRA to the transformer')
    else:
        print('Not using LoRA')

    from copy import deepcopy
    def unwrap_model(model):
        """Strip the Accelerate wrapper (and a compiled-module wrapper, if any) from ``model``."""
        unwrapped = accelerator.unwrap_model(model)
        if is_compiled_module(unwrapped):
            # A compiled module exposes the underlying module as `orig_mod`.
            return unwrapped.orig_mod
        return unwrapped

    # Create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format.
    def save_model_hook(models, weights, output_dir):
        """Save-state pre-hook that writes only the transformer's LoRA layers.

        Args:
            models: Models handed to the hook by ``accelerator.save_state``.
            weights: Companion list passed by Accelerate; entries are popped so
                the corresponding models are not saved again.
            output_dir: Directory the LoRA weights are written to.

        Raises:
            ValueError: If a model is neither a DeepSpeed engine nor of the same
                type as the unwrapped training transformer.

        Failures while writing the LoRA weights are logged, not raised.
        """
        for model in models:
            if hasattr(model, 'module') and isinstance(model, deepspeed.runtime.engine.DeepSpeedEngine):
                # DeepSpeed support: take the original model out of the DeepSpeedEngine.
                lora_source = accelerator.unwrap_model(model.module)
            elif isinstance(model, type(unwrap_model(transformer))):
                lora_source = model
            else:
                raise ValueError(f"Unexpected save model: {model.__class__}")

            transformer_lora_layers_to_save = get_peft_model_state_dict(lora_source)

            # Pop the weight so the corresponding model is not saved again;
            # the list may be empty, hence the guard.
            if weights:
                weights.pop()

        # One further entry is dropped after the loop when any remain.
        if weights:
            weights.pop()

        try:
            WanPipeline.save_lora_weights(
                output_dir,
                transformer_lora_layers=transformer_lora_layers_to_save,
            )
            logger.info(f"Successfully saved LoRA weights to {output_dir}")
        except Exception as e:
            logger.error(f"Failed to save LoRA weights: {e}")
    
    
    def save_all_model_hook(models, weights, output_dir):
        """Save-state pre-hook that writes the complete transformer (not just LoRA layers).

        Only the main process writes anything. Each model is unwrapped before
        its type is checked, so a wrapped model (e.g. a DeepSpeed engine) is
        accepted just like in ``save_model_hook``.

        Args:
            models: Models handed to the hook by ``accelerator.save_state``.
            weights: Companion list passed by Accelerate; one entry is popped per
                handled model so that model is not saved a second time.
            output_dir: Checkpoint directory; the model is written to
                ``<output_dir>/transformer``.

        Raises:
            ValueError: If an unwrapped model is not of the same type as the
                unwrapped training transformer.
        """
        if not accelerator.is_main_process:
            return

        os.makedirs(output_dir, exist_ok=True)
        expected_type = type(unwrap_model(transformer))

        for model in models:
            # Unwrap first: the hook may receive a wrapper rather than the bare module.
            unwrapped_model = unwrap_model(model)
            if not isinstance(unwrapped_model, expected_type):
                raise ValueError(f"Unexpected save model: {model.__class__}")

            # Save the complete transformer into its own sub-directory.
            # NOTE(review): with DeepSpeed ZeRO-3 the parameters are sharded, so the
            # unwrapped module may not hold full weights — confirm before relying on it.
            transformer_dir = os.path.join(output_dir, "transformer")
            os.makedirs(transformer_dir, exist_ok=True)
            unwrapped_model.save_pretrained(transformer_dir)
            print(f"已保存完整的transformer模型到 {transformer_dir}")

            # Pop the weight so the corresponding model is not saved again. The list
            # can be empty (save_model_hook guards the same way), so do not assume
            # there is something to pop.
            if weights:
                weights.pop()

    def load_model_hook(models, input_dir):
        """Load-state pre-hook that restores the transformer's LoRA adapter weights.

        Args:
            models: Models handed to the hook by ``accelerator.load_state``; the
                list is drained here.
            input_dir: Checkpoint directory containing the saved LoRA weights.

        Raises:
            ValueError: If a model is not of the same type as the unwrapped
                training transformer.
        """
        transformer_ = None
        while models:
            model = models.pop()
            if not isinstance(model, type(unwrap_model(transformer))):
                raise ValueError(f"Unexpected save model: {model.__class__}")
            transformer_ = model

        # Keep only the transformer entries and strip their "transformer." prefix.
        prefix = "transformer."
        lora_state_dict = WanPipeline.lora_state_dict(input_dir)
        transformer_state_dict = {}
        for key, value in lora_state_dict.items():
            if key.startswith(prefix):
                transformer_state_dict[key.replace(prefix, "")] = value
        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)

        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
        # Only unexpected keys are reported; missing keys are not checked.
        unexpected_keys = None
        if incompatible_keys is not None:
            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
        if unexpected_keys:
            logger.warning(
                f"Loading adapter weights from state dict with unexpected keys not found in the model: "
                f" {unexpected_keys}."
            )

        # The base model is held in weight_dtype, so under fp16 the trainable
        # (LoRA) parameters are upcast to float32 again after loading.
        # Background: huggingface/diffusers pull request 6514.
        if args.mixed_precision == "fp16":
            cast_training_params([transformer_])

    if args.use_lora:
        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)
    else:
        accelerator.register_save_state_pre_hook(save_all_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32 and torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Make sure the trainable params are in float32
    if args.mixed_precision == "fp16":
        # only upcast trainable parameters (LoRA) into fp32
        cast_training_params([transformer], dtype=torch.float32)

    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
    

    # Optimization parameters
    transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
    params_to_optimize = [transformer_parameters_with_lr]
    
   

    transformer_fake_lora_parameters = list(filter(lambda p: p.requires_grad, transformer_fake.parameters()))

    # Optimization parameters
    transformer_fake_parameters_with_lr = {"params": transformer_fake_lora_parameters, "lr": args.learning_rate}
    params_to_optimize_fake = [transformer_fake_parameters_with_lr]
    use_deepspeed_optimizer = (
        accelerator.state.deepspeed_plugin is not None
        and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
    )

    use_deepspeed_scheduler = (
        accelerator.state.deepspeed_plugin is not None
        and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
    )

    args.adam_beta1 = 0.
    args.learning_rate = args.learning_rate_g
    optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
    


    class CustomImagePromptDataset(Dataset):
        """Dataset of prompts read from a file written with ``torch.save``.

        Every item is returned as a ``(prompt, prompt)`` pair.

        Args:
            jsonl_file: Path of the serialized prompt collection. Despite the
                name, the file is read with ``torch.load``, not parsed as JSONL.
            transform: Stored on the instance; not applied by this class.
        """

        def __init__(self, jsonl_file, transform=None):
            self.transform = transform
            # self.generator = torch.Generator().manual_seed(42)
            self.pth_base = "prompts"
            self.generator = torch.Generator()
            # Honour the path given by the caller instead of a hard-coded one.
            # NOTE: weights_only=False unpickles arbitrary objects — only point
            # this at prompt files from a trusted source.
            self.data = torch.load(jsonl_file, weights_only=False)

        def __len__(self):
            """Return the number of stored prompts."""
            return len(self.data)

        def __getitem__(self, idx):
            """Return the prompt at ``idx`` twice, as a 2-tuple."""
            text = self.data[idx]
            return text, text


    # Create Dataset
    dataset = CustomImagePromptDataset(jsonl_file="prompts/prompts_shuffled.pt", transform=None)

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=args.train_batch_size,
        num_workers=8,
        pin_memory=True,
    )   

    def encode_video(video, bar):
        """Encode one video with the VAE and return its latent distribution.

        ``bar`` is advanced by one step before encoding.
        """
        bar.update(1)
        batched = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
        # Swap dims 1 and 2; per the original note the result is [B, C, F, H, W].
        batched = batched.permute(0, 2, 1, 3, 4)
        return vae.encode(batched).latent_dist

    progress_encode_bar = tqdm(
        range(0, len(dataset)),
        desc="Loading Encode videos",
    )

    # train_dataset.instance_videos = [encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos]
    progress_encode_bar.close()

    def collate_fn(examples):
        """Collate examples into a dict with batched ``videos`` and a list of ``prompts``.

        Each example's ``instance_video`` is sampled and multiplied by the VAE
        scaling factor before the samples are concatenated.
        """
        scaled_samples = [
            item["instance_video"].sample() * vae.config.scaling_factor for item in examples
        ]
        captions = [item["instance_prompt"] for item in examples]

        batch = torch.cat(scaled_samples).permute(0, 2, 1, 3, 4)
        batch = batch.to(memory_format=torch.contiguous_format).float()
        return {"videos": batch, "prompts": captions}

    # Scheduler and math around the number of training steps.
    override_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        override_max_train_steps = True
    
    if use_deepspeed_scheduler:
        from accelerate.utils import DummyScheduler

        lr_scheduler = DummyScheduler(
            name=args.lr_scheduler,
            optimizer=optimizer,
            total_num_steps=args.max_train_steps * accelerator.num_processes,
            num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        )
    else:
        lr_scheduler = get_scheduler(
            args.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
            num_training_steps=args.max_train_steps * accelerator.num_processes,
            num_cycles=args.lr_num_cycles,
            power=args.lr_power,
        )

    # Prepare everything with our accelerator.
    transformer_fake = deepcopy(transformer)
    transformer_fake_lora_parameters = list(filter(lambda p: p.requires_grad, transformer_fake.parameters()))
    

    # Optimization parameters
    transformer_fake_parameters_with_lr = {"params": transformer_fake_lora_parameters, "lr": args.learning_rate_fake}
    params_to_optimize_fake = [transformer_fake_parameters_with_lr]

    args.learning_rate = args.learning_rate_fake

    args.adam_beta1 = 0.
    optimizer_d = get_optimizer(args, params_to_optimize_fake, use_deepspeed=use_deepspeed_optimizer)


    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        transformer, optimizer, train_dataloader, lr_scheduler
    )
    

    optimizer_d, transformer_fake = accelerator_d.prepare(
        optimizer_d, transformer_fake
    )

    
    
    # we need to recalculate our total training steps as the size of the training dataloader may have changed
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if override_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        tracker_name = args.tracker_name or "wanx-lora"
        accelerator.init_trackers(tracker_name, config=vars(args))
    
    
   
    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])

    logger.info("***** Running training *****")
    logger.info(f"Num trainable parameters: {num_trainable_parameters}")
    logger.info(f"Num examples: {len(dataset)}")
    logger.info(f"Num batches each epoch: {len(train_dataloader)}")
    logger.info(f"Num epochs: {args.num_train_epochs}")
    logger.info(f"Instantaneous batch size per device: {args.train_batch_size}")
    logger.info(f"Total train batch size (w. parallel, distributed & accumulation): {total_batch_size}")
    logger.info(f"Gradient accumulation steps: {args.gradient_accumulation_steps}")
    logger.info(f"Total optimization steps: {args.max_train_steps}")

    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states for a previous save
    if not args.resume_from_checkpoint:
        initial_global_step = 0
            
    else:
        if args.resume_from_checkpoint != "latest":
            
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' not found. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch
    
    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )
    # vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)

    # For DeepSpeed training
    model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config

    def extract_into_tensor(a, t, x_shape):
        """Gather ``a`` at indices ``t`` and reshape to ``(len(t), 1, ..., 1)``.

        The result has as many dimensions as ``x_shape`` has entries, so it can
        broadcast against a tensor of that shape.
        """
        batch_size, *_ = t.shape
        broadcast_shape = (batch_size,) + (1,) * (len(x_shape) - 1)
        return torch.gather(a, -1, t).reshape(broadcast_shape)

    def generate_new(transformer, noise_scheduler, latent,noise, encoder_hidden_states, image_rotary_emb, steps=4, eta=1,
                     return_mid=False, mid_points=None, encode_steps=True, shift=False, total_steps=1000):
        """Run a `steps`-step sampling loop with `transformer`, starting from `noise`.

        Every iteration calls the transformer on the current noisy sample, derives
        an x_0 estimate and a noise estimate from its output, moves to the next
        timestep and re-noises the x_0 estimate via `noise_scheduler.add_noise`.

        Args:
            transformer: Callable model; invoked with `hidden_states`, `timestep`,
                `encoder_hidden_states` and `return_dict=False`.
            noise_scheduler: Scheduler providing `timesteps`, `sigmas`,
                `index_for_timestep`, `_sigma_to_alpha_sigma_t` and `add_noise`.
                Its `timesteps` may be moved to the device of the timestep tensor.
            latent: Only its batch size and device are read (for the timestep
                tensor); the name is rebound inside the loop.
            noise: Initial noisy sample fed to the first iteration.
            encoder_hidden_states: Conditioning passed to the transformer.
            image_rotary_emb: Not used in this function body.
            steps: Number of iterations.
            eta: Weight of the predicted noise when re-noising; the remaining
                `sqrt(1 - eta ** 2)` share is fresh Gaussian noise.
            return_mid: If true, return the per-step lists instead of the last latent.
            mid_points: Optional timesteps; entry `ind + 1` is used after step
                `ind`, so it is indexed up to `steps`.
            encode_steps: Not used in this function body.
            shift: Not used in this function body.
            total_steps: Upper bound for the timestep; the loop starts at
                `total_steps - 1` and, without `mid_points`, decreases by
                `total_steps // steps` per iteration.

        Returns:
            The last x_0 estimate, or `(imgs_list, noisy_imgs_list)` when
            `return_mid` is true. `imgs_list` holds the x_0 estimate of every
            step; `noisy_imgs_list` holds the input of every step followed by
            the final x_0 estimate.
        """
        # randint over [total_steps - 1, total_steps) always yields total_steps - 1.
        T_ = torch.randint(total_steps - 1, total_steps, (latent.shape[0],), device=latent.device).long()

        zero_t = torch.zeros_like(T_)
        imgs_list = []
        pure_noisy = noise
        noisy_imgs_list = []
        for ind in range(steps):
            noisy_imgs_list.append(pure_noisy)
            # v-prediction: the model output is the predicted velocity V
            noise_pred = transformer(
                    hidden_states=pure_noisy,
                    timestep=T_,
                    encoder_hidden_states=encoder_hidden_states,
                    return_dict=False,
                )[0]
            # Keep the scheduler's timesteps on the same device as T_ for the lookup below.
            if T_.device != noise_scheduler.timesteps.device:
                noise_scheduler.timesteps = noise_scheduler.timesteps.to(T_.device)
            step_indices = [noise_scheduler.index_for_timestep(t, noise_scheduler.timesteps) for t in T_]
            sigma = noise_scheduler.sigmas[step_indices].to(pure_noisy.device, pure_noisy.dtype)
            # Append trailing singleton dims until sigma has as many dims as pure_noisy.
            while len(sigma.shape) < len(pure_noisy.shape):
                  sigma = sigma.unsqueeze(-1)
            alpha_t, sigma_t = noise_scheduler._sigma_to_alpha_sigma_t(sigma)
            # model_input is the predicted x_0

            model_input = pure_noisy - sigma_t * noise_pred
            latent = model_input # .to(unet.dtype)

            imgs_list.append(latent)

            pred_epsilon = pure_noisy + alpha_t * noise_pred
            # Advance to the next timestep: explicit mid point if given, else a fixed stride.
            if mid_points is not None:
                T_ = mid_points[ind + 1] + zero_t
            else:
                T_ = T_ - total_steps // steps
            # Blend the predicted noise with fresh noise according to eta.
            add_eps = eta * pred_epsilon + ((1 - eta ** 2) ** 0.5) * torch.randn_like(pred_epsilon)
            # The sample re-noised in the last iteration is not read after the loop.
            pure_noisy = noise_scheduler.add_noise(latent, add_eps, T_)
        noisy_imgs_list.append(latent)
        if return_mid:
            return imgs_list, noisy_imgs_list
        return latent
    #TODO:2
    # uncond_prompt_embeds = compute_prompt_embeddings(
    #     tokenizer,
    #     text_encoder,
    #     [""] * args.train_batch_size,
    #     model_config.max_text_seq_length,
    #     accelerator.device,
    #     weight_dtype,
    #     requires_grad=False,
    # )
    rank = dist.get_rank()
    uncond_prompt_embeds = torch.load(f'prompts/negative_prompt_embeds.pt', map_location=accelerator.device)
    # Expand along the batch dimension when batch_size > 1
    if args.train_batch_size > 1:
        uncond_prompt_embeds = uncond_prompt_embeds.repeat(args.train_batch_size, 1, 1)


    class Predictor:
        """Turns model outputs into x_0 / noise estimates and re-noises samples.

        ``alpha_schedule`` and ``sigma_schedule`` are stored on the instance but
        not read by the methods of this class; all coefficients are looked up
        through ``noise_scheduler``.
        """

        def __init__(self, noise_scheduler, alpha_schedule, sigma_schedule):
            super().__init__()
            self.noise_scheduler = noise_scheduler
            self.alpha_schedule = alpha_schedule
            self.sigma_schedule = sigma_schedule
            # Unconditional embeddings used for the classifier-free-guidance pass.
            self.uncond_prompt_embeds = uncond_prompt_embeds

        def _alpha_sigma_for(self, timesteps, reference):
            """Return ``(alpha_t, sigma_t)`` for ``timesteps``, broadcastable against ``reference``."""
            sched = self.noise_scheduler
            step_indices = [sched.index_for_timestep(t, sched.timesteps) for t in timesteps]
            sigma = sched.sigmas[step_indices].to(reference.device, weight_dtype)
            # Append trailing singleton dims until sigma has as many dims as reference.
            while sigma.dim() < reference.dim():
                sigma = sigma.unsqueeze(-1)
            return sched._sigma_to_alpha_sigma_t(sigma)

        def predict(self, score_model, noisy_samples, timesteps, encoder_hidden_states, image_rotary_emb, cfg=None, steps=1,
                    return_double=False, timestep_cond=None,return_all=False):
            """Evaluate ``score_model`` and convert its output into estimates.

            pred_latents: the predicted x_0
            pred_epsilon: the predicted noise

            When ``cfg`` is given, a second pass with the unconditional prompt
            embeddings is run and both passes are blended with weight ``cfg``.

            Returns:
                ``(pred_epsilon, pred_latents, velocity)`` if ``return_all``;
                otherwise ``(pred_epsilon, pred_latents)`` if ``return_double``;
                otherwise ``pred_latents``.
            """
            alpha_t, sigma_t = self._alpha_sigma_for(timesteps, noisy_samples)

            def run_model(text_embeds):
                return score_model(
                    hidden_states=noisy_samples,
                    timestep=timesteps,
                    encoder_hidden_states=text_embeds,
                    return_dict=False,
                )[0]

            score_pred = run_model(encoder_hidden_states)

            if cfg is None:
                velocity = score_pred
                pred_latents = noisy_samples - sigma_t * score_pred
            else:
                score_uncond_pred = run_model(self.uncond_prompt_embeds)
                velocity = score_uncond_pred + cfg * (score_pred - score_uncond_pred)
                # Blending the x_0 estimates directly is equivalent to blending the scores.
                pred_latents_cond = noisy_samples - sigma_t * score_pred
                pred_latents_uncond = noisy_samples - sigma_t * score_uncond_pred
                pred_latents = pred_latents_uncond + cfg * (pred_latents_cond - pred_latents_uncond)

            pred_epsilon = noisy_samples + alpha_t * velocity

            if return_all:
                return pred_epsilon, pred_latents, velocity
            if return_double:
                return pred_epsilon, pred_latents
            return pred_latents

        def add_noise(self, samples, noise, t1, t2):
            """Move ``samples`` from timestep ``t1`` to timestep ``t2`` using ``noise``.

            x_t2 = x_t1 / alpha_t1 * alpha_t2
                   + sqrt(max(sigma_t2 ** 2 - (alpha_t2 / alpha_t1 * sigma_t1) ** 2, 1e-8)) * noise

            The result is cast to ``weight_dtype``.
            """
            alpha_t1, sigma_t1 = self._alpha_sigma_for(t1, samples)
            alpha_t2, sigma_t2 = self._alpha_sigma_for(t2, samples)

            rescaled = samples / alpha_t1 * alpha_t2

            # Clamp before the square root so the variance gap stays positive.
            variance_gap = sigma_t2 ** 2 - (alpha_t2 / alpha_t1 * sigma_t1) ** 2
            noise_scale = torch.clamp(variance_gap, min=1e-8) ** 0.5

            return (rescaled + noise_scale * noise).to(weight_dtype)

    fixed_prompt = (
        " A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The"
        " panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other"
        " pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo,"
        " casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays."
        " The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical"
        " atmosphere of this unique musical performance "
    )
    
    # fixed_prompt_embeds = compute_prompt_embeddings(
    #     tokenizer,
    #     text_encoder,
    #     [fixed_prompt] * args.train_batch_size,
    #     model_config.max_text_seq_length,
    #     accelerator.device,
    #     weight_dtype,
    #     requires_grad=False,
    # )
    fixed_prompt_embeds = torch.load(f'prompts/fixed_prompt_embeds.pt', map_location=accelerator.device)
    # Expand along the batch dimension when batch_size > 1
    if args.train_batch_size > 1:
        fixed_prompt_embeds = fixed_prompt_embeds.repeat(args.train_batch_size, 1, 1)

    total_steps = 1000

    K_step = args.k_step
    noise_scheduler = scheduler
    alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
    sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
    alpha_schedule = alpha_schedule.to(accelerator.device).to(weight_dtype)
    sigma_schedule = sigma_schedule.to(accelerator.device).to(weight_dtype)

    predictor = Predictor(scheduler, alpha_schedule, sigma_schedule)
    # Use the batch-expanded uncond_prompt_embeds
    predictor.uncond_prompt_embeds = uncond_prompt_embeds
    #prob_drop = torch.rand((args.train_batch_size,), device=noise.device)

    for epoch in range(first_epoch, args.num_train_epochs):
        transformer.train()
        transformer_fake.train()
        for step, batch in enumerate(train_dataloader):
            models_to_accumulate = [transformer]
            
            with accelerator.accumulate(models_to_accumulate):
                # encode prompts
                #prompts = list(batch[0])
                # prompt_embeds = compute_prompt_embeddings(
                #     tokenizer,
                #     text_encoder,
                #     prompts,
                #     model_config.max_text_seq_length,
                #     accelerator.device,
                #     weight_dtype,
                #     requires_grad=False,
                # )
                #TODO:3
                rank = dist.get_rank()
                index = rank * 1500 + step 
                prompts = dataset[index]
                
                # Load prompt_embeds for every sample in the batch and concatenate them along dim 0
                prompt_embeds_list = []
                for i in range(args.train_batch_size):
                    current_index = index if i == 0 else index + i  # 第一个使用index，其他使用连续的index
                    single_prompt_embeds = torch.load(f'prompts/individual_embeddings/{current_index}.pt', map_location=accelerator.device)
                    prompt_embeds_list.append(single_prompt_embeds)
                
                # Concatenate all prompt_embeds along dim 0
                prompt_embeds = torch.cat(prompt_embeds_list, dim=0)


                # uncond_prompt_embeds = compute_prompt_embeddings(
                #     tokenizer,
                #     text_encoder,
                #     [""] * args.train_batch_size,
                #     model_config.max_text_seq_length,
                #     accelerator.device,
                #     weight_dtype,
                #     requires_grad=False,
                # )

                # Sample noise that will be added to the Latents
                # noise = torch.randn_like(model_input)
                

                if '1.3b' in args.pretrained_model_name_or_path.lower():
                    noise = torch.randn(args.train_batch_size, 16, 21, 60, 104).to(accelerator.device).to(dtype=weight_dtype)
                else:
                    error_msg = "Unsupported model size. Please use a model with 1.3b parameters."
                    raise ValueError(error_msg)
                batch_size,  num_channels, num_frames, height, width = noise.shape
                # print("num_frams=",num_frames)

                #Prepare rotary embeds
                # image_rotary_emb = (
                #     prepare_rotary_positional_embeddings(
                #         height=args.height,
                #         width=args.width,
                #         num_frames=num_frames,
                #         vae_scale_factor_spatial=vae_scale_factor_spatial,
                #         patch_size=model_config.patch_size,
                #         attention_head_dim=model_config.attention_head_dim,
                #         device=accelerator.device,
                #     )
                #     if model_config.use_rotary_positional_embeddings
                #     else None
                # )
                image_rotary_emb = None
                with torch.no_grad():
                    new_noise = torch.randn_like(noise)
                    #print('line 1782')
                    imgs_list, noisy_imgs_list = generate_new(transformer, scheduler, new_noise, new_noise, prompt_embeds, image_rotary_emb,eta=args.eta,
                                                            steps=K_step, return_mid=True, total_steps=1000) # [[bs, 4, 64, 64]]
                    #print('line 1785')
                    noisy_imgs_list.reverse()
                    bsz = noise.shape[0]
                    k_list = torch.randint(0, K_step, (bsz,), device=noise.device).long()

                    model_input = torch.randn_like(noise)
                    for ii in range(model_input.shape[0]):
                        model_input[ii] = imgs_list[k_list[ii]][ii]

            # Train Fake Score .
            # print('-'*10)
            # print('train fake score')
            # print('-'*10)
            with accelerator_d.accumulate([transformer_fake]):
                noise = torch.randn_like(model_input)
                Ind_t = torch.randint(1, K_step + 1, (bsz,), device=noise.device).long()

                noisy_latents_ode = torch.randn_like(noise)
                for i in range(noise.shape[0]):
                    noisy_latents_ode[i] = noisy_imgs_list[Ind_t[i]][i]

                timesteps_g = Ind_t * total_steps // K_step - 1
                timesteps_mid = timesteps_g - total_steps // K_step + 1
                timesteps = timesteps_g.clone() * 0
                for ind_bw in range(bsz):
                    timesteps[ind_bw] = torch.randint(timesteps_mid[ind_bw], 980, (1,), device=noise.device)[0]
                timesteps = timesteps.long()
                with torch.no_grad():
                    # print('-'*10)
                    # print('predictor 1803')
                    model_eps, model_latents = predictor.predict(transformer, noisy_latents_ode, timesteps_g, prompt_embeds, image_rotary_emb,
                                                                    return_double=True)
                    #print('-'*10)
                #print(model_eps.shape)
                add_eps = args.eta * model_eps + ((1 - args.eta ** 2) ** 0.5) * torch.randn_like(model_eps)
                noisy_model_latents_ode = noise_scheduler.add_noise(model_latents, add_eps, timesteps_mid).to(weight_dtype)
                noisy_model_latents = predictor.add_noise(noisy_model_latents_ode.detach(), torch.randn_like(noisy_model_latents_ode), timesteps_mid, timesteps).to(weight_dtype)
            
                if args.lambda_reg > 0:
                    with torch.no_grad():
                        _,real_latents, real_velocity = predictor.predict(transformer_real, noisy_model_latents, timesteps, prompt_embeds, image_rotary_emb,return_all=True)

                
                _,fake_latents, fake_velocity = predictor.predict(transformer_fake, noisy_model_latents, timesteps, prompt_embeds, image_rotary_emb,return_all=True)


                if timesteps.device != scheduler.alphas_cumprod.device:
                    scheduler.alphas_cumprod = scheduler.alphas_cumprod.to(timesteps.device)
                # alphas_cumprod = scheduler.alphas_cumprod[timesteps]
                # weights = 1 / (1 - alphas_cumprod)
                # while len(weights.shape) < len(fake_latents.shape):
                #     weights = weights.unsqueeze(-1)
                # target = model_latents
                # loss = torch.mean((weights * (fake_latents - target) ** 2).reshape(batch_size, -1), dim=1)
                # loss = loss.mean()
                step_indices = [scheduler.index_for_timestep(t, scheduler.timesteps) for t in timesteps]
                sigma = scheduler.sigmas[step_indices].to(model_latents.device, weight_dtype)
                while len(sigma.shape) < len(model_latents.shape):
                    sigma = sigma.unsqueeze(-1)
                _, sigma_t = scheduler._sigma_to_alpha_sigma_t(sigma)
                target = (noisy_model_latents - model_latents) / sigma_t
                loss = torch.mean(((fake_velocity - target) ** 2).reshape(batch_size, -1), dim=1)
                loss = loss.mean()
                loss_fake = loss

                # 检查loss是否异常，避免数值不稳定
                if loss_fake > 2.0:
                    print(f"Skipping backward pass due to abnormal loss_fake: {loss_fake.item():.2f}")
                    skip_fake_backward = True
                else:
                    skip_fake_backward = False
                    if args.lambda_reg > 0:
                        # print('---using lambda_reg---')
                        loss_reg = torch.mean(((fake_velocity - real_velocity) ** 2).reshape(batch_size, -1), dim=1)
                        loss_reg = loss_reg.mean()
                        # print('1835')
                        # print('loss=',loss)
                        # print('args.lambda_reg=',args.lambda_reg)
                        # print('loss_reg=',loss_reg)
                        accelerator_d.backward((loss + args.lambda_reg * loss_reg))
                        
                        #print('1837')
                    else:
                        accelerator_d.backward(loss)
                
                # 只有在正常执行backward后才进行梯度处理
                if not skip_fake_backward:
                    if accelerator_d.sync_gradients:
                        params_to_clip = transformer_fake.parameters()
                        #print('1842')
                        accelerator_d.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                    if accelerator_d.state.deepspeed_plugin is None:
                        #print('1844')
                        optimizer_d.step()
                        optimizer_d.zero_grad()
       
            # Train K-step generator
            with accelerator.accumulate([transformer]):
                with torch.no_grad():
                    ind_t = torch.randint(1, K_step + 1, (bsz,), device=noise.device).long()
                    noisy_latents = torch.randn_like(noise)
                    for i in range(noise.shape[0]):
                        noisy_latents[i] = noisy_imgs_list[ind_t[i]][i]

                    timesteps_g = ind_t * total_steps // K_step - 1
                timesteps_mid = timesteps_g - total_steps // K_step + 1
                timesteps = timesteps_g.clone() * 0
                for ind_bw in range(bsz):
                    timesteps[ind_bw] = torch.randint(timesteps_mid[ind_bw], 980, (1,), device=noise.device)[0]
                timesteps = timesteps.long()
                model_eps, model_latents = predictor.predict(transformer, noisy_latents, timesteps_g, prompt_embeds, image_rotary_emb, return_double=True)

                noise = torch.randn_like(model_latents)
                add_eps = args.eta * model_eps +  ((1 - args.eta**2) ** 0.5) * torch.randn_like(model_eps)
                noisy_model_latents_ode = noise_scheduler.add_noise(model_latents, add_eps, timesteps_mid).to(weight_dtype)
                noisy_model_latents = predictor.add_noise(noisy_model_latents_ode.detach(), torch.randn_like(noisy_model_latents_ode), timesteps_mid, timesteps).to(weight_dtype)

                with torch.no_grad():
                    real_latents = predictor.predict(transformer_real, noisy_model_latents, timesteps, prompt_embeds, image_rotary_emb, cfg=args.cfg)
                    fake_latents = predictor.predict(transformer_fake, noisy_model_latents, timesteps, prompt_embeds, image_rotary_emb)
                    revised_latents = (model_latents + real_latents - fake_latents).detach().float()

                # alphas_cumprod = scheduler.alphas_cumprod[timesteps]
                # weights = 1 / (1 - alphas_cumprod)
                # while len(weights.shape) < len(fake_latents.shape):
                #     weights = weights.unsqueeze(-1)
                di_target = revised_latents.detach()
                weighting_factor = torch.abs(model_latents - real_latents).mean(dim=[1, 2, 3, 4], keepdim=True).detach()
                weighting_factor = torch.where(weighting_factor > 5, torch.randn_like(weighting_factor)*0 + 5, weighting_factor)
                args.huber_c = 1e-3 / (((64*64*4)**0.5)) * ((noisy_model_latents.shape[1:].numel())**0.5)
                loss = torch.mean(
                    (torch.sqrt((model_latents.float() - revised_latents.detach().float()) ** 2 + args.huber_c ** 2) - args.huber_c)
                )

                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    params_to_clip = transformer.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                if accelerator.state.deepspeed_plugin is None:
                    optimizer.step()
                    optimizer.zero_grad()
            
                lr_scheduler.step()
                    # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    if global_step % args.checkpointing_steps == 0:
                        
                        with torch.no_grad():
                            model_input = generate_new(transformer, scheduler, noise, noise, fixed_prompt_embeds, image_rotary_emb, eta=1, steps=4)
                            # [ batch_size , num_channels , num frames , height , width ]
                        print(prompts)
                        with torch.no_grad():
                            #model_input = rearrange(model_input, "b f c h w -> b c f h w")
                            latents_mean = (
                                torch.tensor(vae.config.latents_mean)
                                .view(1, vae.config.z_dim, 1, 1, 1)
                                .to(model_input.device, model_input.dtype)
                            )
                            latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(
                                model_input.device, model_input.dtype)
                            model_input = model_input / latents_std + latents_mean
                            images_t = vae.decode(model_input.to(vae.dtype), return_dict=False)[0]
                            if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                                # video_processor = VideoProcessor()
                                # video = video_processor.postprocess_video(images_t, output_type="np")
                                # WanPipelineOutput(frames=video)
                                # export_to_video(video, f'./{args.output_dir}/videos_{global_step}.mp4',fps=16)
                                save_videos_grid(images_t.cpu(), f'./{args.output_dir}/Videos_{global_step}.mp4')

                        with torch.no_grad():
                            model_input = generate_new(transformer, scheduler, noise, noise, prompt_embeds, image_rotary_emb, eta=args.eta, steps=K_step)
                
                        
                        with torch.no_grad():
                            # [ batch_size , num_channels , num frames , height , width ]
                            #model_input = rearrange(model_input, "b f c h w -> b c f h w")
                            latents_mean = (
                                torch.tensor(vae.config.latents_mean)
                                .view(1, vae.config.z_dim, 1, 1, 1)
                                .to(model_input.device, model_input.dtype)
                            )
                            latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(
                                model_input.device, model_input.dtype)
                            model_input = model_input / latents_std + latents_mean
                            images_t = vae.decode(model_input.to(vae.dtype), return_dict=False)[0]
                            if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                                # video_processor = VideoProcessor()
                                # video = video_processor.postprocess_video(images_t, output_type="np")
                                # WanPipelineOutput(frames=video)
                                # export_to_video(video, f'{args.output_dir}/videos_{global_step}.mp4',fps=16)
                                save_videos_grid(images_t.cpu(), f'{args.output_dir}/Videos_{global_step}.mp4')
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED :
                    if global_step % args.checkpointing_steps == 0:
                        # before saving state, check if this save would set us over the checkpoints_total_limit
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]))

                            # before we save the new checkpoint, we need to have at most `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss_fake": loss_fake.detach().item(), "loss_du": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        # if accelerator.is_main_process:
        #     if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
        #         # Create pipeline
        #         pipe = CogVideoXPipeline.from_pretrained(
        #             args.pretrained_model_name_or_path,
        #             transformer=unwrap_model(transformer),
        #             text_encoder=unwrap_model(text_encoder),
        #             scheduler=scheduler,
        #             revision=args.revision,
        #             variant=args.variant,
        #             torch_dtype=weight_dtype,
        #         )
        #         validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
        #         for validation_prompt in validation_prompts:
        #             pipeline_args = {
        #                 "prompt": validation_prompt,
        #                 "guidance_scale": args.guidance_scale,
        #                 "use_dynamic_cfg": args.use_dynamic_cfg,
        #                 "height": args.height,
        #                 "width": args.width,
        #             }
        #             validation_outputs = log_validation(
        #                 pipe=pipe,
        #                 args=args,
        #                 accelerator=accelerator,
        #                 pipeline_args=pipeline_args,
        #                 epoch=epoch,
        #             )

    # Save the Lora Layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        transformer = unwrap_model(transformer)
        dtype = (
            torch.float16
            if args.mixed_precision == "fp16"
            else torch.bfloat16
            if args.mixed_precision == "bf16"
            else torch.float32
        )
        transformer = transformer.to(dtype)
        transformer_lora_layers = get_peft_model_state_dict(transformer)
        WanPipeline.save_lora_weights(
            save_directory=args.output_dir,
            transformer_Lora_Layers=transformer_lora_layers,
        )

        # Cleanup trained models to save memory
        del transformer
        free_memory()

        # Final test inference
        pipe = WanPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            revision=args.revision,
            variant=args.variant,
            torch_dtype=weight_dtype,
        )
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

        if args.enable_slicing:
            pipe.vae.enable_slicing()
        if args.enable_tiling:
            pipe.vae.enable_tiling()

        #Load LORA weights
        lora_scaling = args.lora_alpha / args.rank
        pipe.load_lora_weights(args.output_dir, adapter_name="wanx-lora")
        pipe.set_adapters(["wanx-lora"], [lora_scaling])

        #Run inference
        validation_outputs = []

        if args.validation_prompt and args.num_validation_videos > 0:
            validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
            for validation_prompt in validation_prompts:
                pipeline_args = {
                    "prompt": validation_prompt,
                    "guidance_scale": args.guidance_scale,
                    "use_dynamic_cfg": args.use_dynamic_cfg,
                    "height": args.height,
                    "width": args.width,
                }
                video = log_validation(
                    pipe=pipe,
                    args=args,
                    accelerator=accelerator,
                    pipeline_args=pipeline_args,
                    epoch=epoch,
                    is_final_validation=True,
                )
                validation_outputs.extend(video)

        if args.push_to_hub:
            save_model_card(
                repo_id=repo_id,
                videos=validation_outputs,
                base_model=args.pretrained_model_name_or_path,
                validation_prompt=args.validation_prompt,
                repo_folder=args.output_dir,
                fps=args.fps,
            )
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()

if __name__ == "__main__":
    # Script entry point: parse the CLI arguments, append the lambda_reg,
    # cfg, eta and k_step values to the output directory name, then start
    # training with the updated arguments.
    args = get_args()
    run_suffix = (
        f"_lambda-reg_{args.lambda_reg}_cfg_{args.cfg}_eta_{args.eta}_K_{args.k_step}"
    )
    args.output_dir += run_suffix
    main(args)




