import argparse
import logging
import math
import os
import shutil
from pathlib import Path
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torchvision.transforms as TT
import transformers
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
from tqdm.auto import tqdm
from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer

import diffusers
from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.optimization import get_scheduler
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
from diffusers.training_utils import cast_training_params, free_memory
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module

import os
import imageio
import numpy as np
from typing import Union

import torch
import torchvision
import torch.distributed as dist
from einops import rearrange
import sys
import traceback
# Cap how many traceback levels Python prints for uncaught exceptions.
# NOTE(review): when this attribute is unset Python uses 1000 levels, so 50 actually
# shortens tracebacks; raise it (or delete this line) if deeper stack traces are needed.
sys.tracebacklimit = 50  # or a larger number
import os
# Environment flag presumably read by the Hugging Face `tokenizers` library to decide
# whether to use its internal parallelism; it is set here at import time, before any
# tokenizer is created — verify semantics against the tokenizers docs.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
from accelerate.utils import DeepSpeedPlugin
import torch.nn as nn
import deepspeed




def save_videos_grid(videos: torch.Tensor, path: str, rescale=True, n_rows=6, fps=8):
    """Tile a batch of videos into a grid per time step and write the result as a video file.

    Args:
        videos: tensor laid out as ``(b, c, t, h, w)`` (fixed by the ``rearrange`` pattern
            below).
        path: output file path; its parent directory is created when it has one.
        rescale: if True, values are mapped from ``[-1, 1]`` to ``[0, 1]`` before conversion;
            if False, values are expected to already lie in ``[0, 1]``.
        n_rows: forwarded to ``torchvision.utils.make_grid`` as ``nrow`` (images per grid row).
        fps: frame rate passed to ``imageio.mimsave``.
    """
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:  # one grid image per time step
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        # (c, H, W) -> (H, W, c); squeeze drops a trailing singleton channel if present.
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        # Convert every frame to uint8 (previously this only happened when rescale=True,
        # so rescale=False passed float tensors to imageio). Clamp first so slight overshoot
        # cannot wrap around in the uint8 cast; detach/cpu let .numpy() work for GPU tensors
        # and tensors that require grad.
        x = (x.detach().to(torch.float32).clamp(0.0, 1.0) * 255).cpu().numpy().astype(np.uint8)
        outputs.append(x)
    out_dir = os.path.dirname(path)
    if out_dir:  # os.makedirs("") raises, so skip when path is a bare filename
        os.makedirs(out_dir, exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)

# wandb is imported only when installed; `log_validation` references `wandb.Video`
# when a wandb tracker is active, so it relies on this conditional import.
if is_wandb_available():
    import wandb

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
# NOTE(review): the check below is commented out, so no minimum diffusers version is enforced.
# check_min_version("0.32.0.dev0")

# Accelerate-aware logger (from accelerate.logging.get_logger) used by the functions in this module.
logger = get_logger(__name__)

def get_args(argv: Optional[List[str]] = None):
    """Parse the command-line options of the CogVideoX LoRA training script.

    Args:
        argv: argument list to parse. ``None`` (the default) parses ``sys.argv[1:]``,
            so the original no-argument call keeps working; passing a list makes the
            parser usable from tests and notebooks.

    Returns:
        ``argparse.Namespace`` with one attribute per option defined below.
    """

    def _str2bool(value: str) -> bool:
        # Strict boolean parser. The previous ``x.lower() == 'true'`` silently turned any
        # typo (and also "1"/"yes") into False; unknown values now raise a CLI error.
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"Expected 'true' or 'false', got {value!r}.")

    parser = argparse.ArgumentParser(description="Simple example of a training script for CogVideoX.")

    # Model information
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_lora_model_name_or_path",
        type=str,
        default=None,
        required=False,
        help="Optional path to pretrained LoRA weights, or their model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )

    # Dataset information
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--instance_data_root",
        type=str,
        default=None,
        help=("A folder containing the training data."),
    )
    parser.add_argument(
        "--video_column",
        type=str,
        default="video",
        help="The column of the dataset containing videos. Or, the name of the file in `--instance_data_root` folder containing the line-separated path to video data.",
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in '--instance_data_root' folder containing the line-separated prompts.",
    )
    parser.add_argument(
        "--id_token",
        type=str,
        default=None,
        help="Identifier token appended to the start of each prompt if provided.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )

    # Validation
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator'.",
    )
    parser.add_argument(
        "--validation_prompt_separator",
        type=str,
        default=":::",
        help="String that separates multiple validation prompts",
    )
    parser.add_argument(
        "--num_validation_videos",
        type=int,
        default=1,
        help="Number of videos that should be generated during validation per 'validation_prompt'.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=50,
        help=(
            "Run validation every X epochs. Validation consists of running the prompt 'args.validation_prompt' multiple times: 'args.num_validation_videos'."
        ),
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=6,
        help="The guidance scale to use while sampling validation videos.",
    )
    parser.add_argument(
        "--use_dynamic_cfg",
        action="store_true",
        default=False,
        help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.",
    )

    # Training information
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--rank",
        type=int,
        default=128,
        help=("The dimension of the LoRA update matrices."),
    )
    parser.add_argument(
        "--lora_alpha",
        type=float,
        default=128,
        help=("The scaling factor to scale LoRA weight update. The actual scaling factor is `lora_alpha / rank`"),
    )
    # NOTE(review): the help texts of --lambda_reg, --k_step, --cfg and --eta were copy-pasted
    # placeholders. The descriptions below are inferred from the option names — confirm them
    # against the training loop in `main`.
    parser.add_argument(
        "--lambda_reg",
        type=float,
        default=0.0,
        help="Weight (lambda) of the regularization term in the training loss.",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the flag"
            " passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="cogvideo-tdm-tcd-lora-fixbug",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--k_step",
        type=int,
        default=4,
        help="Number of steps (K) used by the few-step training/sampling schedule.",
    )
    parser.add_argument(
        "--cfg",
        type=float,
        default=5,
        help="Classifier-free guidance scale used during training (validation sampling uses --guidance_scale).",
    )
    parser.add_argument(
        "--eta",
        type=float,
        default=0.9,
        help="Sampler eta (stochasticity) hyperparameter used during training.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=480,
        help="All input videos are resized to this height.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=720,
        help="All input videos are resized to this width.",
    )
    parser.add_argument(
        "--video_reshape_mode",
        type=str,
        default="center",
        help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
    )
    parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
    parser.add_argument(
        "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
    )
    parser.add_argument(
        "--skip_frames_start",
        type=int,
        default=0,
        help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.",
    )
    parser.add_argument(
        "--skip_frames_end",
        type=int,
        default=0,
        help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.",
    )
    parser.add_argument(
        "--random_flip",
        action="store_true",
        help="whether to randomly flip videos horizontally",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides '--num_train_epochs'.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using '--resume_from_checkpoint'."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help="Max number of checkpoints to store.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            " --checkpointing_steps, or 'latest' to automatically select the last available checkpoint."
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--learning_rate_fake",
        type=float,
        default=1e-3,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--learning_rate_g",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            "The scheduler type to use. Choose between ['linear', 'cosine', 'cosine_with_restarts', 'polynomial',"
            "'constant', 'constant_with_warmup']"
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--enable_slicing",
        action="store_true",
        default=False,
        help="Whether or not to use VAE slicing for saving memory.",
    )
    parser.add_argument(
        "--enable_tiling",
        action="store_true",
        default=False,
        help="Whether or not to use VAE tiling for saving memory.",
    )

    # Optimizer
    parser.add_argument(
        "--optimizer",
        type=lambda s: s.lower(),
        default="adam",
        choices=["adam", "adamw", "prodigy"],
        help="The optimizer type to use.",
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to Adam or AdamW",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam and Prodigy optimizers.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.95,
        help="The beta2 parameter for the Adam and Prodigy optimizers.",
    )
    parser.add_argument(
        "--prodigy_beta3",
        type=float,
        default=None,
        help=(
            "Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses"
            " the value of square root of beta2."
        ),
    )
    parser.add_argument("--prodigy_decouple", action="store_true", help="Use AdamW style decoupled weight decay")
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=1e-04,
        help="Weight decay to use for unet params",
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
    )
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--prodigy_use_bias_correction",
        action="store_true",
        help="Turn on Adam's bias correction.",
    )
    parser.add_argument(
        "--prodigy_safeguard_warmup",
        action="store_true",
        help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.",
    )

    # Other information
    parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help="Directory where logs are stored.",
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--use_sparsity",
        type=_str2bool,
        default=False,
        help="Whether to use sparsity in the model training. Set to 'true' or 'false'.",
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default=None,
        help=(
            "The integration to report the results and logs to. Supported platforms are `tensorboard`,"
            " `wandb`, and `comet_ml`. Use `all` to report to all integrations. Defaults to no reporting."
        ),
    )

    return parser.parse_args(argv)


        
def save_model_card(
    repo_id: str,
    videos=None,
    base_model: str = None,
    validation_prompt=None,
    repo_folder=None,
    fps=8,
):
    """Write a Hugging Face Hub model card (``README.md``) into ``repo_folder``.

    Args:
        repo_id: Hub repository id the LoRA weights are published under.
        videos: optional iterable of videos; each one is exported to
            ``final_video_{i}.mp4`` inside ``repo_folder`` and referenced from the card widget.
        base_model: base model the LoRA was trained on. It is shown in the description and
            used in the usage snippet, which falls back to ``THUDM/CogVideoX-5b`` when it is not given.
        validation_prompt: prompt shown with the widget videos and stored on the card.
        repo_folder: local folder that receives the videos and ``README.md``; must not be None.
        fps: frame rate used when exporting ``videos``.
    """
    widget_dict = []
    if videos is not None:
        for i, video in enumerate(videos):
            video_filename = f"final_video_{i}.mp4"
            export_to_video(video, os.path.join(repo_folder, video_filename), fps=fps)
            # The widget must reference the file that was actually written (it previously
            # pointed at a non-existent "video_{i}.mp4").
            widget_dict.append(
                {"text": validation_prompt if validation_prompt else "", "output": {"url": video_filename}}
            )

    usage_model = base_model or "THUDM/CogVideoX-5b"
    # The usage snippet's code fence is now closed and actually loads the LoRA weights.
    model_description = f"""
# CogVideoX LoRA - {repo_id}

## Model description

These are {repo_id} LoRA weights for {base_model}.

The weights were trained using the [CogVideoX diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).

Was LoRA for the text encoder enabled? No.

## Download model

[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.

## Use it with the [diffusers library](https://github.com/huggingface/diffusers)

```py
from diffusers import CogVideoXPipeline
import torch

pipe = CogVideoXPipeline.from_pretrained("{usage_model}", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("{repo_id}")
```
"""
    model_card = load_or_create_model_card(
        repo_id_or_path=repo_id,
        from_training=True,
        license="other",
        base_model=base_model,
        prompt=validation_prompt,
        model_description=model_description,
        widget=widget_dict,
    )
    tags = [
        "text-to-video",
        "diffusers-training",
        "diffusers",
        "lora",
        "cogvideox",
        "cogvideox-diffusers",
        "template:sd-lora",
    ]
    model_card = populate_model_card(model_card, tags=tags)
    model_card.save(os.path.join(repo_folder, "README.md"))

def log_validation(
    pipe,
    args,
    accelerator,
    pipeline_args,
    epoch,
    is_final_validation: bool = False,
):
    """Generate validation videos with ``pipe`` and log them to any wandb tracker.

    Side effects: replaces ``pipe.scheduler`` with a ``CogVideoXDPMScheduler`` built from the
    current scheduler config, moves ``pipe`` to ``accelerator.device``, and (for wandb trackers)
    writes mp4 files into ``args.output_dir``.

    Args:
        pipe: diffusers pipeline, called with ``**pipeline_args``.
        args: parsed options; reads ``num_validation_videos``, ``seed`` and ``output_dir``.
        accelerator: provides ``device`` and ``trackers``.
        pipeline_args: keyword arguments for ``pipe``. Must contain ``"prompt"``, which is
            sliced and ``.replace``-d when building filenames, i.e. expected to be a string.
        epoch: currently unused.
        is_final_validation: use the ``"test"`` instead of the ``"validation"`` log key and
            filename prefix.

    Returns:
        One entry per generated video, each as returned by ``VaeImageProcessor.numpy_to_pil``.
    """
    logger.info(
        f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
    )
    # We train on the simplified learning objective. If we were previously predicting a variance,
    # we need the scheduler to ignore it.
    scheduler_args = {}
    if "variance_type" in pipe.scheduler.config:
        variance_type = pipe.scheduler.config.variance_type
        if variance_type in ["learned", "learned_range"]:
            variance_type = "fixed_small"
        scheduler_args["variance_type"] = variance_type

    # Overrides must be keyword arguments: the second positional parameter of `from_config` is
    # `return_unused_kwargs`, so passing the dict positionally would make a non-empty dict
    # return a (scheduler, unused_kwargs) tuple and silently drop the override.
    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
    pipe.to(accelerator.device)

    # `is not None` so that --seed 0 still produces a seeded, reproducible generator.
    generator = (
        torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
    )

    videos = []
    for _ in range(args.num_validation_videos):
        pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
        image_np = VaeImageProcessor.pt_to_numpy(pt_images)
        image_pil = VaeImageProcessor.numpy_to_pil(image_np)
        videos.append(image_pil)

    phase_name = "test" if is_final_validation else "validation"
    for tracker in accelerator.trackers:
        if tracker.name != "wandb":
            continue
        video_filenames = []
        for i, video in enumerate(videos):
            # Filesystem-friendly slug built from the first 25 prompt characters.
            prompt = (
                pipeline_args["prompt"][:25]
                .replace("  ", "_")
                .replace(" ", "_")
                .replace("'", "_")
                .replace('"', "_")  # was '""', which only matched doubled quotes
                .replace("/", "_")
            )
            filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
            export_to_video(video, filename, fps=8)
            video_filenames.append(filename)

        tracker.log(
            {
                phase_name: [
                    wandb.Video(filename, caption=f"({i}): {pipeline_args['prompt']}")
                    for i, filename in enumerate(video_filenames)
                ]
            }
        )

    # Drops only this function's reference; the caller's `pipe` object is still alive.
    del pipe
    free_memory()
    return videos
    
def _get_t5_prompt_embeds(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids: Optional[torch.Tensor] = None,
):
    """Encode prompts with the text encoder and repeat each embedding per generated video.

    Args:
        tokenizer: tokenizer applied to ``prompt`` (padded/truncated to
            ``max_sequence_length``). If None, ``text_input_ids`` must be given and
            ``prompt`` is ignored.
        text_encoder: called on the token ids; element ``[0]`` of its output is used.
        prompt: a single prompt or a list of prompts.
        num_videos_per_prompt: number of consecutive copies of each prompt's embedding.
        max_sequence_length: tokenizer padding/truncation length.
        device: device for the token ids and the returned embeddings.
        dtype: dtype of the returned embeddings.
        text_input_ids: pre-tokenized ids, used only when ``tokenizer`` is None.

    Returns:
        Tensor of shape ``(batch * num_videos_per_prompt, seq_len, hidden)`` in which the
        copies of each prompt are adjacent.

    Raises:
        ValueError: if both ``tokenizer`` and ``text_input_ids`` are None.
    """
    if tokenizer is not None:
        prompt = [prompt] if isinstance(prompt, str) else prompt
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    elif text_input_ids is None:
        raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    # Batch size comes from the encoded tensor rather than len(prompt): with pre-tokenized
    # ids the prompt may be absent or disagree with the ids, which previously crashed or
    # produced a wrong view.
    batch_size, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
    return prompt_embeds

def encode_prompt(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids: Optional[torch.Tensor] = None,
):
    """Wrap a single prompt in a list and delegate to ``_get_t5_prompt_embeds``.

    Every other argument is forwarded unchanged, and the result of
    ``_get_t5_prompt_embeds`` is returned as-is.
    """
    prompt_list = prompt if not isinstance(prompt, str) else [prompt]
    embeddings = _get_t5_prompt_embeds(
        tokenizer,
        text_encoder,
        text_input_ids=text_input_ids,
        prompt=prompt_list,
        dtype=dtype,
        device=device,
        max_sequence_length=max_sequence_length,
        num_videos_per_prompt=num_videos_per_prompt,
    )
    return embeddings


def compute_prompt_embeddings(
    tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
):
    """Encode ``prompt`` (one video per prompt), optionally tracking gradients.

    Args:
        tokenizer, text_encoder, prompt, max_sequence_length, device, dtype: forwarded to
            ``encode_prompt`` with ``num_videos_per_prompt=1``.
        requires_grad: if False, encoding runs under ``torch.no_grad()``. If True, the
            caller's current grad mode is left untouched.

    Returns:
        The prompt embeddings produced by ``encode_prompt``.
    """
    import contextlib

    # One call site instead of two copy-pasted ones; only the grad context differs.
    grad_context = contextlib.nullcontext() if requires_grad else torch.no_grad()
    with grad_context:
        prompt_embeds = encode_prompt(
            tokenizer,
            text_encoder,
            prompt,
            num_videos_per_prompt=1,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )
    return prompt_embeds


def prepare_rotary_positional_embeddings(
    height: int,
    width: int,
    num_frames: int,
    vae_scale_factor_spatial: int = 8,
    patch_size: int = 2,
    attention_head_dim: int = 64,
    device: Optional[torch.device] = None,
    base_height: int = 480,
    base_width: int = 720,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build 3D rotary positional embeddings (cos, sin) for a video of the given size.

    Pixel sizes are converted to grid sizes by integer division by
    ``vae_scale_factor_spatial * patch_size``. The crop region of that grid, relative to the
    ``base_height`` x ``base_width`` reference grid, is obtained from
    ``get_resize_crop_region_for_grid``. Both tensors returned by ``get_3d_rotary_pos_embed``
    are moved to ``device``.
    """
    pixels_per_cell = vae_scale_factor_spatial * patch_size
    grid_size = (height // pixels_per_cell, width // pixels_per_cell)
    reference_grid_w = base_width // pixels_per_cell
    reference_grid_h = base_height // pixels_per_cell

    crop_region = get_resize_crop_region_for_grid(grid_size, reference_grid_w, reference_grid_h)
    cos_part, sin_part = get_3d_rotary_pos_embed(
        embed_dim=attention_head_dim,
        crops_coords=crop_region,
        grid_size=grid_size,
        temporal_size=num_frames,
    )
    return cos_part.to(device=device), sin_part.to(device=device)


def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
    """Create the optimizer selected by ``args.optimizer`` for ``params_to_optimize``.

    Args:
        args: parsed options. Reads ``optimizer``, ``use_8bit_adam``, ``adam_beta1``,
            ``adam_beta2``, ``adam_epsilon``, ``adam_weight_decay`` and ``learning_rate``,
            plus the ``prodigy_*`` options for Prodigy. An unsupported ``args.optimizer`` is
            overwritten with ``"adamw"``.
        params_to_optimize: parameters or param-group dicts accepted by the optimizer.
        use_deepspeed: return accelerate's ``DummyOptim`` placeholder instead of a real optimizer.

    Returns:
        The optimizer instance.

    Raises:
        ImportError: if bitsandbytes (8-bit Adam/AdamW) or prodigyopt (Prodigy) is required
            but not installed.

    Note:
        Apart from the DeepSpeed path, no ``lr`` is passed to the constructor. The learning
        rate must therefore come from per-group ``"lr"`` entries in ``params_to_optimize``;
        otherwise the optimizer class default applies.
    """
    # Use deepspeed optimizer
    if use_deepspeed:
        from accelerate.utils import DummyOptim

        return DummyOptim(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )

    supported_optimizers = ["adam", "adamw", "prodigy"]
    # Compare case-insensitively, matching the branch selection below.
    if args.optimizer.lower() not in supported_optimizers:
        logger.warning(
            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW"
        )
        args.optimizer = "adamw"
    optimizer_name = args.optimizer.lower()

    # 8-bit variants exist only for Adam/AdamW. Only import bitsandbytes when it will
    # actually be used, so an ignored --use_8bit_adam cannot fail a Prodigy run.
    use_8bit = args.use_8bit_adam and optimizer_name in ("adam", "adamw")
    if args.use_8bit_adam and not use_8bit:
        logger.warning(
            f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was set to {args.optimizer.lower()}"
        )
    if use_8bit:
        try:
            import bitsandbytes as bnb
        except ImportError as err:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            ) from err

    adam_kwargs = dict(
        betas=(args.adam_beta1, args.adam_beta2),
        eps=args.adam_epsilon,
        weight_decay=args.adam_weight_decay,
    )
    if optimizer_name == "adam":
        optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam
        return optimizer_class(params_to_optimize, **adam_kwargs)
    if optimizer_name == "adamw":
        optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW
        return optimizer_class(params_to_optimize, **adam_kwargs)

    # Remaining supported choice: prodigy.
    try:
        import prodigyopt
    except ImportError as err:
        raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`.") from err

    if args.learning_rate <= 0.1:
        logger.warning(
            "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
        )

    return prodigyopt.Prodigy(
        params_to_optimize,
        betas=(args.adam_beta1, args.adam_beta2),
        beta3=args.prodigy_beta3,
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
        decouple=args.prodigy_decouple,
        safeguard_warmup=args.prodigy_safeguard_warmup,
        use_bias_correction=args.prodigy_use_bias_correction,
    )


def main(args):
    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token. "
            "Please use `huggingface-cli login` to authenticate with the Hub."
        )
    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
        # due to pytorch #99272, MPS does not yet support bfloat16
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
        )

    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[kwargs],
    )
    # accelerator_second = Accelerator()
    print('-'*10)
    print('accelerator.state.deepspeed_plugin=',accelerator.state.deepspeed_plugin)
    print('-'*10)
    accelerator_d = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[kwargs],
    )

    # Disable AMP for MPS
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name,
                exist_ok=True,
            ).repo_id

    TODO: 离线处理embedding以节约显存
    #Prepare models and scheduler
    # tokenizer = AutoTokenizer.from_pretrained(
    #     args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
    # )

    # text_encoder = T5EncoderModel.from_pretrained(
    #     args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    # )

    # CogVideoX-2b weights are stored in float16
    # CogVideoX-5b and CogVideoX-5b-128 weights are stored in bfloat16
    #load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16

    load_dtype = torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16
    transformer = CogVideoXTransformer3DModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="transformer",
        torch_dtype=load_dtype,
        revision=args.revision,
        variant=args.variant,
    )
    
        
    set_block_sparse = args.use_sparsity
    if set_block_sparse:
        from modify_cogvideo import set_block_sparse_attn_cogvideox
        set_block_sparse_attn_cogvideox(transformer)
        print('----successfuly set_block_sparse-----')
    else:
        print('----not set_block_sparse-----')
    
    
    transformer_fake = CogVideoXTransformer3DModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="transformer",
        torch_dtype=load_dtype,
        revision=args.revision,
        variant=args.variant,
    )

    transformer_real = CogVideoXTransformer3DModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="transformer",
        torch_dtype=load_dtype,
        revision=args.revision,
        variant=args.variant,
    ).to(accelerator.device)
    
    transformer_real.requires_grad_(False)
    vae = AutoencoderKLCogVideoX.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
    )
    vae.enable_slicing()

    scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    if args.enable_slicing:
        vae.enable_slicing()
    if args.enable_tiling:
        vae.enable_tiling()

    # We only train the additional adapter LoRA Layers
    #text_encoder.requires_grad_(False)
    transformer.requires_grad_(False)
    vae.requires_grad_(False)
    transformer_real.requires_grad_(False)
    
    transformer_fake.requires_grad_(False)

    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.state.deepspeed_plugin:
        # DeepSpeed is handling precision, use what's in the DeepSpeed config
        if (
            "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
            and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
        ):
            weight_dtype = torch.float16
        if (
            "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
            and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
        ):
            weight_dtype = torch.bfloat16
    else:
        if args.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif args.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
        # due to pytorch #99272, MPS does not yet support bfloat16
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
        )
    

    #text_encoder.to(accelerator.device, dtype=weight_dtype)
    transformer.to(accelerator.device, dtype=weight_dtype)
    
    transformer_fake.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    if args.gradient_checkpointing:
        transformer.enable_gradient_checkpointing()
        transformer_fake.enable_gradient_checkpointing()

    

    # 1. 使用配置创建适配器
    if args.pretrained_lora_model_name_or_path is not None:
        transformer_lora_config = LoraConfig(
            r=args.rank,
            lora_alpha=args.lora_alpha,
            init_lora_weights=True,
            target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        )
        transformer.add_adapter(transformer_lora_config, adapter_name="default")

        # 2. 从指定路径加载状态字典（最好使用参数而非硬编码路径）
        input_dir = args.pretrained_lora_model_name_or_path  # 使用参数代替硬编码路径
        lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)

        # 3. 处理状态字典与PEFT兼容
        transformer_state_dict = {
            f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() 
            if k.startswith("transformer.")
        }
        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)

        # 4. 使用PEFT的API加载权重
        incompatible_keys = set_peft_model_state_dict(transformer, transformer_state_dict, adapter_name="default")

        # 5. 保留精度处理代码
        if args.mixed_precision == "fp16":
            cast_training_params([transformer])

        print('----successfully load_adapter-----')
    else:
        transformer_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.lora_alpha,
        init_lora_weights=True,
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        )
        transformer.add_adapter(transformer_lora_config)
        print('---use baseline lora weights---')

    # transformer_fake.add_adapter(transformer_lora_config)
    from copy import deepcopy
    def unwrap_model(model):
        """Return the bare module underneath accelerate and ``torch.compile`` wrappers.

        ``accelerator.unwrap_model`` strips the distributed/mixed-precision wrappers; a
        module produced by ``torch.compile`` additionally keeps the original module on its
        ``_orig_mod`` attribute, which is returned instead.
        """
        model = accelerator.unwrap_model(model)
        # torch.compile's OptimizedModule stores the wrapped module as `_orig_mod`; there is
        # no `orig_mod` attribute, so the previous lookup raised AttributeError.
        model = model._orig_mod if is_compiled_module(model) else model
        return model

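    # Custom save/load hooks so that `accelerator.save_state(...)` / `accelerator.load_state(...)` store and restore the transformer's LoRA weights via `CogVideoXPipeline.save_lora_weights` / `CogVideoXPipeline.lora_state_dict`.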
    def save_model_hook(models, weights, output_dir):
        """Accelerate save-state pre-hook that writes the transformer's LoRA weights.

        Only the main process writes anything. Each entry of ``models`` must be either a
        DeepSpeed engine (its wrapped module is unwrapped first) or an instance of the
        unwrapped ``transformer``'s class; anything else raises ``ValueError``. For every
        model one entry is popped from ``weights`` (when it is non-empty) so accelerate does
        not serialize that model a second time.
        """
        if not accelerator.is_main_process:
            return

        lora_layers = None
        for candidate in models:
            if hasattr(candidate, 'module') and isinstance(candidate, deepspeed.runtime.engine.DeepSpeedEngine):
                # DeepSpeed wraps the real module; unwrap it before reading the adapter state.
                inner = accelerator.unwrap_model(candidate.module)
                lora_layers = get_peft_model_state_dict(inner)
            elif isinstance(candidate, type(unwrap_model(transformer))):
                lora_layers = get_peft_model_state_dict(candidate)
            else:
                raise ValueError(f"Unexpected save model: {candidate.__class__}")
            # `weights` can be empty (presumably when DeepSpeed manages the state), so only
            # pop when there is something to remove.
            if len(weights) > 0:
                weights.pop()

        CogVideoXPipeline.save_lora_weights(
            output_dir,
            transformer_lora_layers=lora_layers,
        )

    def load_model_hook(models, input_dir):
        """Accelerate load-state pre-hook that restores the transformer's LoRA weights.

        Picks the transformer out of ``models`` (unwrapping a DeepSpeed engine if needed),
        reads the LoRA state dict written by the save hook from ``input_dir`` and loads it
        into the ``"default"`` adapter. Under fp16 the trainable (LoRA) params are upcast
        back to fp32 afterwards.

        If ``models`` is empty there is nothing to load into, and the hook returns without
        touching any weights.

        Raises:
            ValueError: if ``models`` contains something other than the transformer.
        """
        transformer_ = None
        while len(models) > 0:
            model = models.pop()
            if hasattr(model, 'module') and isinstance(model, deepspeed.runtime.engine.DeepSpeedEngine):
                # DeepSpeed engine: load into the module it wraps.
                transformer_ = accelerator.unwrap_model(model.module)
            elif isinstance(model, type(unwrap_model(transformer))):
                transformer_ = model
            else:
                raise ValueError(f"Unexpected load model: {model.__class__}")

        if transformer_ is None:
            # No model was handed over (presumably accelerate restored it itself, as it does
            # for DeepSpeed engines via their own checkpoint). Passing None on to
            # `set_peft_model_state_dict` would crash, so skip the adapter restore.
            logger.info("load_model_hook received no model; skipping LoRA adapter restore.")
            return

        lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
        # Keep only transformer entries and strip their "transformer." prefix.
        transformer_state_dict = {
            f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
        }
        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
        if incompatible_keys is not None:
            # check only for unexpected keys
            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
            if unexpected_keys:
                logger.warning(
                    f"Loading adapter weights from state dict with unexpected keys not found in the model: "
                    f" {unexpected_keys}."
                )

        # Make sure the trainable params are in float32. This is again needed since the base models
        # are in weight_dtype. More details:
        # https://github.com/huggingface/diffusers/pull/6514@discussion_r1449796884
        if args.mixed_precision == "fp16":
            # only upcast trainable parameters (LoRA) into fp32
            cast_training_params([transformer_])

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32 and torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Make sure the trainable params are in float32
    if args.mixed_precision == "fp16":
        # only upcast trainable parameters (LoRA) into fp32
        cast_training_params([transformer], dtype=torch.float32)

    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
    

    # Optimization parameters
    transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
    params_to_optimize = [transformer_parameters_with_lr]
    
   

    transformer_fake_lora_parameters = list(filter(lambda p: p.requires_grad, transformer_fake.parameters()))

    # Optimization parameters
    transformer_fake_parameters_with_lr = {"params": transformer_fake_lora_parameters, "lr": args.learning_rate}
    params_to_optimize_fake = [transformer_fake_parameters_with_lr]
    use_deepspeed_optimizer = (
        accelerator.state.deepspeed_plugin is not None
        and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
    )

    use_deepspeed_scheduler = (
        accelerator.state.deepspeed_plugin is not None
        and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
    )

    args.adam_beta1 = 0.
    args.learning_rate = args.learning_rate_g
    optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
    

    # args.adam_beta1 = 0.
    # args.learning_rate = = 1e-3
    # optimizer_d = get_optimizer(args, params_to_optimize_fake, use_deepspeed=use_deepspeed_optimizer)


    class CustomImagePromptDataset(Dataset):
        """Prompt dataset backed by a file written with ``torch.save``.

        Each item is the stored prompt returned twice, as ``(text, text)``.

        Args:
            jsonl_file: path of the file holding the indexable prompt collection. Despite the
                name it is read with ``torch.load``, not parsed as JSON Lines.
            transform: stored on the instance; ``__getitem__`` does not apply it.
        """

        def __init__(self, jsonl_file, transform=None):
            self.transform = transform
            self.pth_base = "prompts"
            self.generator = torch.Generator()
            # Load from the path the caller passes rather than a hard-coded copy of it.
            # NOTE(review): weights_only=False unpickles arbitrary objects — only point this
            # at trusted files.
            self.data = torch.load(jsonl_file, weights_only=False)

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            text = self.data[idx]
            return text, text


    # Create Dataset
    dataset = CustomImagePromptDataset(jsonl_file="prompts/shuffled_prompts.pt", transform=None)

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=args.train_batch_size,
        num_workers=8,
        pin_memory=True,
    )   

    progress_encode_bar = tqdm(
        range(0, len(dataset)),
        desc="Loading Encode videos",
    )

    progress_encode_bar.close()
    override_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        override_max_train_steps = True
    
    if use_deepspeed_scheduler:
        from accelerate.utils import DummyScheduler

        lr_scheduler = DummyScheduler(
            name=args.lr_scheduler,
            optimizer=optimizer,
            total_num_steps=args.max_train_steps * accelerator.num_processes,
            num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        )
    else:
        lr_scheduler = get_scheduler(
            args.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
            num_training_steps=args.max_train_steps * accelerator.num_processes,
            num_cycles=args.lr_num_cycles,
            power=args.lr_power,
        )

    # Prepare everything with our accelerator.
    transformer_fake = deepcopy(transformer)
    transformer_fake_lora_parameters = list(filter(lambda p: p.requires_grad, transformer_fake.parameters()))
    

    # Optimization parameters
    transformer_fake_parameters_with_lr = {"params": transformer_fake_lora_parameters, "lr": args.learning_rate_fake}
    params_to_optimize_fake = [transformer_fake_parameters_with_lr]

    args.learning_rate = args.learning_rate_fake

    args.adam_beta1 = 0.
    num_trainable_parameters = sum(param.numel() for model in params_to_optimize_fake for param in model["params"])
    optimizer_d = get_optimizer(args, params_to_optimize_fake, use_deepspeed=use_deepspeed_optimizer)

    
    num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
    
    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        transformer, optimizer, train_dataloader, lr_scheduler
    )
    
    # transformer_real  = accelerator_second.prepare(transformer_real)
    optimizer_d, transformer_fake = accelerator_d.prepare(
        optimizer_d, transformer_fake
    )
    #from accelerate.utils import is_deepspeed_available
    num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
    
    # we need to recalculate our total training steps as the size of the training dataloader may have changed
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if override_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        tracker_name = args.tracker_name or "cogvideox-lora"
        accelerator.init_trackers(tracker_name, config=vars(args))
    
    
   
    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])

    logger.info("***** Running training *****")
    logger.info(f"Num trainable parameters: {num_trainable_parameters}")
    logger.info(f"Num examples: {len(dataset)}")
    logger.info(f"Num batches each epoch: {len(train_dataloader)}")
    logger.info(f"Num epochs: {args.num_train_epochs}")
    logger.info(f"Instantaneous batch size per device: {args.train_batch_size}")
    logger.info(f"Total train batch size (w. parallel, distributed & accumulation): {total_batch_size}")
    logger.info(f"Gradient accumulation steps: {args.gradient_accumulation_steps}")
    logger.info(f"Total optimization steps: {args.max_train_steps}")

    global_step = 0
    first_epoch = 0

    if global_step == 0 and accelerator.is_main_process:
        from deepspeed.module_inject.layers import LinearLayer, LinearAllreduce
        logger.info(f"Tensor Parallel config: {accelerator.state.deepspeed_plugin.deepspeed_config.get('tensor_parallel', 'Not enabled')}")
        # 打印一个张量并行层的配置
        for name, module in transformer.named_modules():
            if isinstance(module, (LinearLayer, LinearAllreduce)):
                logger.info(f"Found tensor parallel layer: {name}, tp_group: {module.tp_group}")
                break

    # Potentially load in the weights and states for a previous save
    if not args.resume_from_checkpoint:
        initial_global_step = 0

    else:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' not found. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch
    
    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )
    vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)

    # For DeepSpeed training
    model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config

    def extract_into_tensor(a, t, x_shape):
        """Gather ``a`` at indices ``t`` along its last dim, shaped to broadcast against ``x_shape``.

        The result has shape ``(t.shape[0], 1, ..., 1)`` with ``len(x_shape) - 1`` trailing
        singleton dimensions.
        """
        batch, *_rest = t.shape
        picked = a.gather(dim=-1, index=t)
        singleton_dims = (1,) * (len(x_shape) - 1)
        return picked.reshape(batch, *singleton_dims)

    def generate_new(transformer, noise_scheduler, latent,noise, encoder_hidden_states, image_rotary_emb, steps=4, eta=1,
                     return_mid=False, mid_points=None, encode_steps=True, shift=False, total_steps=1000):
        """Run a ``steps``-step stochastic sampler with ``transformer``, starting from ``noise``.

        Each step turns the model output into a clean-latent estimate via
        ``noise_scheduler.get_velocity``. Every step except the last then re-noises that
        estimate to the next timestep, using a mix of the implied epsilon (weight ``eta``)
        and fresh Gaussian noise (weight ``sqrt(1 - eta**2)``).

        Args:
            transformer: denoiser called with hidden_states / encoder_hidden_states /
                timestep / image_rotary_emb and ``return_dict=False``.
            noise_scheduler: supplies ``alphas_cumprod``, ``get_velocity`` and ``add_noise``.
            latent: only its batch size and device are read (for the first timestep). It is
                replaced by the first prediction; if ``steps == 0`` it is returned as-is.
            noise: the initial noisy latent fed to the first step.
            encoder_hidden_states, image_rotary_emb: forwarded to ``transformer``.
            steps: number of denoising steps.
            eta: weight of the predicted epsilon in the re-noising mix.
            return_mid: if True, return ``(imgs_list, noisy_imgs_list)``.
            mid_points: optional explicit schedule; ``mid_points[ind + 1]`` is the timestep
                used after step ``ind``.
            encode_steps, shift: accepted for compatibility; unused.
            total_steps: schedule length. The first step runs at ``total_steps - 1``; without
                ``mid_points`` each step lowers the timestep by ``total_steps // steps``.

        Returns:
            The final clean-latent estimate. With ``return_mid`` it returns
            ``(imgs_list, noisy_imgs_list)`` instead:
            - ``imgs_list`` holds the ``steps`` per-step estimates.
            - ``noisy_imgs_list`` holds the ``steps`` model inputs followed by the final estimate.
        """
        # randint over [total_steps - 1, total_steps) always yields total_steps - 1; kept as a
        # random draw so the RNG stream seen by callers is unchanged.
        T_ = torch.randint(total_steps - 1, total_steps, (latent.shape[0],), device=latent.device).long()
        zero_t = torch.zeros_like(T_)
        imgs_list = []
        noisy_imgs_list = []
        pure_noisy = noise
        for ind in range(steps):
            noisy_imgs_list.append(pure_noisy)
            model_output = transformer(
                hidden_states=pure_noisy,
                encoder_hidden_states=encoder_hidden_states,
                timestep=T_,
                image_rotary_emb=image_rotary_emb,
                return_dict=False,
            )[0]
            # get_velocity(out, x_t, t) = sqrt(a_t) * x_t - sqrt(1 - a_t) * out, which is the x0
            # estimate when `out` is a v-prediction (presumably the case for this model).
            latent = noise_scheduler.get_velocity(model_output, pure_noisy, T_)
            imgs_list.append(latent)

            if ind == steps - 1:
                # The last estimate is final. Re-noising it would use a timestep below 0
                # (silently wrapping when indexing alphas_cumprod) or read
                # mid_points[steps], and the result was never used.
                break

            alpha_prod_t = extract_into_tensor(noise_scheduler.alphas_cumprod, T_, latent.shape).to(weight_dtype)
            beta_prod_t = 1 - alpha_prod_t
            # epsilon implied by x_t = sqrt(a) * x0 + sqrt(1 - a) * eps
            pred_epsilon = (pure_noisy - alpha_prod_t ** (0.5) * latent) / beta_prod_t ** (0.5)
            if mid_points is not None:
                T_ = mid_points[ind + 1] + zero_t
            else:
                T_ = T_ - total_steps // steps
            add_eps = eta * pred_epsilon + ((1 - eta ** 2) ** 0.5) * torch.randn_like(pred_epsilon)
            pure_noisy = noise_scheduler.add_noise(latent, add_eps, T_)
        noisy_imgs_list.append(latent)
        if return_mid:
            return imgs_list, noisy_imgs_list
        return latent
    #TODO:2
    # uncond_prompt_embeds = compute_prompt_embeddings(
    #     tokenizer,
    #     text_encoder,
    #     [""] * args.train_batch_size,
    #     model_config.max_text_seq_length,
    #     accelerator.device,
    #     weight_dtype,
    #     requires_grad=False,
    # )
    rank = dist.get_rank()
    uncond_prompt_embeds = torch.load( f'prompts/uncond_prompt_embed.pt',map_location=accelerator.device)
    # 在batch维度扩展uncond_prompt_embeds为train_batch_size
    uncond_prompt_embeds = uncond_prompt_embeds.repeat(args.train_batch_size, 1, 1)
    


    class Predictor():
        """Wrap a score model to produce clean-latent predictions, optionally with CFG.

        Args:
            noise_scheduler: supplies ``alphas_cumprod`` and ``get_velocity``.
            alpha_schedule: per-timestep signal scale used by :meth:`add_noise`.
            sigma_schedule: per-timestep noise scale used by :meth:`add_noise`.

        ``uncond_prompt_embeds`` from the enclosing scope is captured at construction and
        used as the unconditional branch for classifier-free guidance.
        """

        def __init__(self, noise_scheduler, alpha_schedule, sigma_schedule):
            super().__init__()
            self.noise_scheduler = noise_scheduler
            self.alpha_schedule = alpha_schedule
            self.sigma_schedule = sigma_schedule
            self.uncond_prompt_embeds = uncond_prompt_embeds

        def predict(self, score_model, noisy_samples, timesteps, encoder_hidden_states, image_rotary_emb, cfg=None, steps=1,
                    return_double=False, timestep_cond=None):
            """Predict the clean latent for ``noisy_samples`` at ``timesteps``.

            The model output is converted to a clean-latent estimate with
            ``noise_scheduler.get_velocity``. When ``cfg`` is given, the model is also run
            on ``self.uncond_prompt_embeds``; the two estimates are combined as
            ``uncond + cfg * (cond - uncond)``.

            Returns:
                ``pred_latents``, or ``(pred_epsilon, pred_latents)`` when ``return_double``
                is True. ``steps`` and ``timestep_cond`` are accepted but unused.
            """
            alpha_prod_t = extract_into_tensor(self.noise_scheduler.alphas_cumprod, timesteps, noisy_samples.shape).to(weight_dtype)
            beta_prod_t = 1 - alpha_prod_t
            score_pred = score_model(
                hidden_states=noisy_samples.to(weight_dtype),
                timestep=timesteps.to(weight_dtype),
                encoder_hidden_states=encoder_hidden_states.to(weight_dtype),
                image_rotary_emb=image_rotary_emb,
                return_dict=False)[0]

            if cfg is not None:
                score_uncond_pred = score_model(
                    hidden_states=noisy_samples.to(weight_dtype),
                    timestep=timesteps.to(weight_dtype),
                    encoder_hidden_states=self.uncond_prompt_embeds.to(weight_dtype),
                    image_rotary_emb=image_rotary_emb,
                    return_dict=False)[0]
                # get_velocity is linear in the model output, so guiding the clean-latent
                # estimates is equivalent to guiding the raw model outputs.
                pred_latents_cond = self.noise_scheduler.get_velocity(score_pred, noisy_samples, timesteps)
                pred_latents_uncond = self.noise_scheduler.get_velocity(score_uncond_pred, noisy_samples, timesteps)
                pred_latents = pred_latents_uncond + cfg * (pred_latents_cond - pred_latents_uncond)
            else:
                pred_latents = self.noise_scheduler.get_velocity(score_pred, noisy_samples, timesteps)

            if return_double:
                # epsilon implied by x_t = sqrt(a) * x0 + sqrt(1 - a) * eps
                pred_epsilon = (noisy_samples - alpha_prod_t ** (0.5) * pred_latents) / beta_prod_t ** (0.5)
                return pred_epsilon, pred_latents

            return pred_latents

        def add_noise(self, samples, noise, t1, t2):
            """Move ``samples`` from noise level ``t1`` to noise level ``t2``.

            The signal is rescaled by ``alpha(t2) / alpha(t1)``. ``noise`` is then added,
            scaled by ``sqrt(sigma(t2)**2 - (alpha(t2) / alpha(t1) * sigma(t1))**2)``.
            NOTE(review): that radicand is only non-negative when t2 is at least as noisy as
            t1 (assuming alpha**2 + sigma**2 == 1); otherwise the result is NaN.

            Returns:
                The re-noised samples cast to ``weight_dtype``.
            """
            sigmas = extract_into_tensor(self.sigma_schedule, t1, samples.shape)
            alphas = extract_into_tensor(self.alpha_schedule, t1, samples.shape)
            sigmas_new = extract_into_tensor(self.sigma_schedule, t2, samples.shape)
            alphas_new = extract_into_tensor(self.alpha_schedule, t2, samples.shape)
            samples = samples / alphas * alphas_new

            # Remaining noise standard deviation needed to reach level t2.
            beta = sigmas_new ** 2 - (alphas_new / alphas * sigmas) ** 2
            beta = beta ** 0.5
            samples = samples + beta * noise

            return samples.to(weight_dtype)

    fixed_prompt = (
        " A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The"
        " panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other"
        " pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo,"
        " casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays."
        " The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical"
        " atmosphere of this unique musical performance "
    )
    
    # fixed_prompt_embeds = compute_prompt_embeddings(
    #     tokenizer,
    #     text_encoder,
    #     [fixed_prompt] * args.train_batch_size,
    #     model_config.max_text_seq_length,
    #     accelerator.device,
    #     weight_dtype,
    #     requires_grad=False,
    # )
    fixed_prompt_embeds = torch.load(f'prompts/fixed_prompt_embedding.pt',map_location=accelerator.device)
    fixed_prompt_embeds = fixed_prompt_embeds.repeat(args.train_batch_size, 1, 1)

    total_steps = 1000

    K_step = args.k_step
    noise_scheduler = scheduler
    alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
    sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
    alpha_schedule = alpha_schedule.to(accelerator.device).to(weight_dtype)
    sigma_schedule = sigma_schedule.to(accelerator.device).to(weight_dtype)

    predictor = Predictor(scheduler, alpha_schedule, sigma_schedule)
    #prob_drop = torch.rand((args.train_batch_size,), device=noise.device)

    for epoch in range(first_epoch, args.num_train_epochs):
        transformer.train()
        transformer_fake.train()
        for step, batch in enumerate(train_dataloader):
            models_to_accumulate = [transformer]
            
            with accelerator.accumulate(models_to_accumulate):
                # encode prompts
                #prompts = list(batch[0])
                # print('-'*10)
                # print('loop encode')
                # print('-'*10)
                # prompt_embeds = compute_prompt_embeddings(
                #     tokenizer,
                #     text_encoder,
                #     prompts,
                #     model_config.max_text_seq_length,
                #     accelerator.device,
                #     weight_dtype,
                #     requires_grad=False,
                # )
                index = rank * 600 + step #this is for 6 gpus and batch_size == 120
                prompts = dataset[index]
                
                # 为batch中的每个样本加载prompt_embeds并在0维拼接
                prompt_embeds_list = []
                for i in range(args.train_batch_size):
                    current_index = index if i == 0 else index + i  # 第一个使用index，其他使用连续的index
                    single_prompt_embeds = torch.load(f'prompts/individual_embeddings/{current_index}.pt', map_location=accelerator.device)
                    prompt_embeds_list.append(single_prompt_embeds)
                
                # 在第0维拼接所有的prompt_embeds
                prompt_embeds = torch.cat(prompt_embeds_list, dim=0)


                # uncond_prompt_embeds = compute_prompt_embeddings(
                #     tokenizer,
                #     text_encoder,
                #     [""] * args.train_batch_size,
                #     model_config.max_text_seq_length,
                #     accelerator.device,
                #     weight_dtype,
                #     requires_grad=False,
                # )

                # Sample noise that will be added to the Latents
                # noise = torch.randn_like(model_input)
                if "5b" in args.pretrained_model_name_or_path.lower(): # shorten the frames for training .
                    noise = torch.randn(args.train_batch_size, 13, 16, 60, 90).to(accelerator.device).to(dtype=weight_dtype)
                else:
                    noise = torch.randn(args.train_batch_size, 13, 16, 60, 90).to(accelerator.device).to(dtype=weight_dtype)
                batch_size, num_frames, num_channels, height, width = noise.shape

                #Prepare rotary embeds
                image_rotary_emb = (
                    prepare_rotary_positional_embeddings(
                        height=args.height,
                        width=args.width,
                        num_frames=num_frames,
                        vae_scale_factor_spatial=vae_scale_factor_spatial,
                        patch_size=model_config.patch_size,
                        attention_head_dim=model_config.attention_head_dim,
                        device=accelerator.device,
                    )
                    if model_config.use_rotary_positional_embeddings
                    else None
                )

                with torch.no_grad():
                    new_noise = torch.randn_like(noise)
                    # print('line 1782')
                    imgs_list, noisy_imgs_list = generate_new(transformer, scheduler, new_noise, new_noise, prompt_embeds, image_rotary_emb,eta=args.eta,
                                                            steps=K_step, return_mid=True, total_steps=1000) # [[bs, 4, 64, 64]]
                    # print('line 1785')
                    noisy_imgs_list.reverse()
                    bsz = noise.shape[0]
                    k_list = torch.randint(0, K_step, (bsz,), device=noise.device).long()

                    model_input = torch.randn_like(noise)
                    for ii in range(model_input.shape[0]):
                        model_input[ii] = imgs_list[k_list[ii]][ii]

            # Train Fake Score .
            with accelerator_d.accumulate([transformer_fake]):
                noise = torch.randn_like(model_input)
                Ind_t = torch.randint(1, K_step + 1, (bsz,), device=noise.device).long()

                noisy_latents_ode = torch.randn_like(noise)
                for i in range(noise.shape[0]):
                    noisy_latents_ode[i] = noisy_imgs_list[Ind_t[i]][i]

                timesteps_g = Ind_t * total_steps // K_step - 1
                timesteps_mid = timesteps_g - total_steps // K_step + 1
                timesteps = timesteps_g.clone() * 0
                for ind_bw in range(bsz):
                    timesteps[ind_bw] = torch.randint(timesteps_mid[ind_bw], 980, (1,), device=noise.device)[0]
                timesteps = timesteps.long()
                
                with torch.no_grad():
                    model_eps, model_latents = predictor.predict(transformer, noisy_latents_ode, timesteps_g, prompt_embeds, image_rotary_emb,
                                                                    return_double=True)
                add_eps = args.eta * model_eps + ((1 - args.eta ** 2) ** 0.5) * torch.randn_like(model_eps)
                noisy_model_latents_ode = noise_scheduler.add_noise(model_latents, add_eps, timesteps_mid).to(weight_dtype)
                noisy_model_latents = predictor.add_noise(noisy_model_latents_ode.detach(), torch.randn_like(noisy_model_latents_ode), timesteps_mid, timesteps).to(weight_dtype)
            
                if args.lambda_reg > 0:
                    with torch.no_grad():
                        real_latents = predictor.predict(transformer_real, noisy_model_latents, timesteps, prompt_embeds, image_rotary_emb)
                fake_latents = predictor.predict(transformer_fake, noisy_model_latents, timesteps, prompt_embeds, image_rotary_emb)
                alphas_cumprod = scheduler.alphas_cumprod[timesteps]
                weights = 1 / (1 - alphas_cumprod)
                while len(weights.shape) < len(fake_latents.shape):
                    weights = weights.unsqueeze(-1)
                target = model_latents
                loss = torch.mean((weights * (fake_latents - target) ** 2).reshape(batch_size, -1), dim=1)
                loss = loss.mean()
                loss_fake = loss

                if args.lambda_reg > 0:
                    loss_reg = torch.mean((weights * (fake_latents - real_latents) ** 2).reshape(batch_size, -1), dim=1)
                    loss_reg = loss_reg.mean()                    
                    accelerator_d.backward((loss + args.lambda_reg * loss_reg))
                else:
                    accelerator_d.backward(loss)
                if accelerator_d.sync_gradients:
                    params_to_clip = transformer_fake.parameters()
                    accelerator_d.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                if accelerator_d.state.deepspeed_plugin is None:
                    optimizer_d.step()
                    optimizer_d.zero_grad()
       
            # Train K-step generator
            with accelerator.accumulate([transformer]):
                with torch.no_grad():
                    ind_t = torch.randint(1, K_step + 1, (bsz,), device=noise.device).long()
                    noisy_latents = torch.randn_like(noise)
                    for i in range(noise.shape[0]):
                        noisy_latents[i] = noisy_imgs_list[ind_t[i]][i]

                    timesteps_g = ind_t * total_steps // K_step - 1
                timesteps_mid = timesteps_g - total_steps // K_step + 1
                timesteps = timesteps_g.clone() * 0
                for ind_bw in range(bsz):
                    timesteps[ind_bw] = torch.randint(timesteps_mid[ind_bw], 980, (1,), device=noise.device)[0]
                timesteps = timesteps.long()
                model_eps, model_latents = predictor.predict(transformer, noisy_latents, timesteps_g, prompt_embeds, image_rotary_emb, return_double=True)

                noise = torch.randn_like(model_latents)
                add_eps = args.eta * model_eps +  ((1 - args.eta**2) ** 0.5) * torch.randn_like(model_eps)
                noisy_model_latents_ode = noise_scheduler.add_noise(model_latents, add_eps, timesteps_mid).to(weight_dtype)
                noisy_model_latents = predictor.add_noise(noisy_model_latents_ode.detach(), torch.randn_like(noisy_model_latents_ode), timesteps_mid, timesteps).to(weight_dtype)

                with torch.no_grad():
                    real_latents = predictor.predict(transformer_real, noisy_model_latents, timesteps, prompt_embeds, image_rotary_emb, cfg=args.cfg)
                    fake_latents = predictor.predict(transformer_fake, noisy_model_latents, timesteps, prompt_embeds, image_rotary_emb)
                    revised_latents = (model_latents + real_latents - fake_latents).detach().float()

                # alphas_cumprod = scheduler.alphas_cumprod[timesteps]
                # weights = 1 / (1 - alphas_cumprod)
                # while len(weights.shape) < len(fake_latents.shape):
                #     weights = weights.unsqueeze(-1)
                di_target = revised_latents.detach()
                weighting_factor = torch.abs(model_latents - real_latents).mean(dim=[1, 2, 3, 4], keepdim=True).detach()
                weighting_factor = torch.where(weighting_factor > 5, torch.randn_like(weighting_factor)*0 + 5, weighting_factor)
                args.huber_c = 1e-3 / (((64*64*4)**0.5) * (60*90*16*13)**0.5)
                loss = torch.mean(
                    (torch.sqrt((model_latents.float() - revised_latents.detach().float()) ** 2 + args.huber_c ** 2) - args.huber_c)
                    / weighting_factor
                )

                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    params_to_clip = transformer.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                if accelerator.state.deepspeed_plugin is None:
                    optimizer.step()
                    optimizer.zero_grad()
            
                lr_scheduler.step()
                    # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    if global_step % args.checkpointing_steps == 0:
                        with torch.no_grad():
                            model_input = generate_new(transformer, scheduler, noise, noise, fixed_prompt_embeds, image_rotary_emb, eta=1, steps=4)
                            # [ batch_size , num_channels , num frames , height , width ]
                        print(prompts)
                        with torch.no_grad():
                            model_input = rearrange(model_input, "b f c h w -> b c f h w")
                            images_t = vae.decode(model_input.to(vae.dtype) / vae.config.scaling_factor, return_dict=False)[0].clamp(-1, 1)
                            if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                                save_videos_grid(images_t.cpu(), f'./{args.output_dir}/Videos_{global_step}.mp4')

                        with torch.no_grad():
                            model_input = generate_new(transformer, scheduler, noise, noise, prompt_embeds, image_rotary_emb, eta=args.eta, steps=K_step)
                
                        
                        with torch.no_grad():
                            # [ batch_size , num_channels , num frames , height , width ]
                            model_input = rearrange(model_input, "b f c h w -> b c f h w")
                            images_t = vae.decode(model_input.to(vae.dtype) / vae.config.scaling_factor, return_dict=False)[0].clamp(-1, 1)
                            if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                                save_videos_grid(images_t.cpu(), f'{args.output_dir}/videos_{global_step}.mp4')

            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                    if global_step % args.checkpointing_steps == 0:
                        # before saving state, check if this save would set us over the checkpoints_total_limit
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]))

                            # before we save the new checkpoint, we need to have at most `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss_fake": loss_fake.detach().item(), "loss_du": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        # if accelerator.is_main_process:
        #     if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
        #         # Create pipeline
        #         pipe = CogVideoXPipeline.from_pretrained(
        #             args.pretrained_model_name_or_path,
        #             transformer=unwrap_model(transformer),
        #             text_encoder=unwrap_model(text_encoder),
        #             scheduler=scheduler,
        #             revision=args.revision,
        #             variant=args.variant,
        #             torch_dtype=weight_dtype,
        #         )
        #         validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
        #         for validation_prompt in validation_prompts:
        #             pipeline_args = {
        #                 "prompt": validation_prompt,
        #                 "guidance_scale": args.guidance_scale,
        #                 "use_dynamic_cfg": args.use_dynamic_cfg,
        #                 "height": args.height,
        #                 "width": args.width,
        #             }
        #             validation_outputs = log_validation(
        #                 pipe=pipe,
        #                 args=args,
        #                 accelerator=accelerator,
        #                 pipeline_args=pipeline_args,
        #                 epoch=epoch,
        #             )

    # Save the Lora Layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        transformer = unwrap_model(transformer)
        dtype = (
            torch.float16
            if args.mixed_precision == "fp16"
            else torch.bfloat16
            if args.mixed_precision == "bf16"
            else torch.float32
        )
        transformer = transformer.to(dtype)
        transformer_lora_layers = get_peft_model_state_dict(transformer)
        CogVideoXPipeline.save_lora_weights(
            save_directory=args.output_dir,
            transformer_Lora_Layers=transformer_lora_layers,
        )

        # Cleanup trained models to save memory
        del transformer
        free_memory()

        # Final test inference
        pipe = CogVideoXPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            revision=args.revision,
            variant=args.variant,
            torch_dtype=weight_dtype,
        )
        pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)

        if args.enable_slicing:
            pipe.vae.enable_slicing()
        if args.enable_tiling:
            pipe.vae.enable_tiling()

        #Load LORA weights
        lora_scaling = args.lora_alpha / args.rank
        pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora")
        pipe.set_adapters(["cogvideox-lora"], [lora_scaling])

        #Run inference
        validation_outputs = []

        if args.validation_prompt and args.num_validation_videos > 0:
            validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
            for validation_prompt in validation_prompts:
                pipeline_args = {
                    "prompt": validation_prompt,
                    "guidance_scale": args.guidance_scale,
                    "use_dynamic_cfg": args.use_dynamic_cfg,
                    "height": args.height,
                    "width": args.width,
                }
                video = log_validation(
                    pipe=pipe,
                    args=args,
                    accelerator=accelerator,
                    pipeline_args=pipeline_args,
                    epoch=epoch,
                    is_final_validation=True,
                )
                validation_outputs.extend(video)

        if args.push_to_hub:
            save_model_card(
                repo_id=repo_id,
                videos=validation_outputs,
                base_model=args.pretrained_model_name_or_path,
                validation_prompt=args.validation_prompt,
                repo_folder=args.output_dir,
                fps=args.fps,
            )
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()

if __name__ == "__main__":
    # Script entry point: parse CLI options, then append the run's key
    # hyper-parameters (lambda_reg, cfg, eta, k_step) to the output directory
    # so that checkpoints, sampled videos and LoRA weights written under
    # `args.output_dir` by `main` are kept apart for different settings.
    args = get_args()
    args.output_dir += (
        f"_lambda-reg_{args.lambda_reg}"
        f"_cfg_{args.cfg}"
        f"_eta_{args.eta}"
        f"_K_{args.k_step}"
    )
    main(args)




