#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Modified from the SGA version of this script.
import sys
sys.path.insert(0, "")

import argparse
import itertools
import logging
import os
from pathlib import Path
import random

import datasets
import diffusers
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, DDIMScheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
from attacks.utils import StableDiffuser
import torch.optim as optim
import re

# Fail fast if the installed diffusers is older than the minimum version this script targets.
# Remove at your own risk.
_MIN_DIFFUSERS_VERSION = "0.13.0.dev0"
check_min_version(_MIN_DIFFUSERS_VERSION)

# accelerate-aware module logger (main() logs through it with `main_process_only=False`).
logger = get_logger(name=__name__)


def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    """Return the text-encoder class named by the checkpoint's ``text_encoder`` config.

    Reads ``architectures[0]`` from the config in the ``text_encoder`` subfolder of
    ``pretrained_model_name_or_path`` (at ``revision``) and lazily imports the matching class.

    Args:
        pretrained_model_name_or_path: Local path or hub id of the diffusion checkpoint.
        revision: Model revision passed through to ``PretrainedConfig.from_pretrained``.

    Returns:
        The text-encoder class (not an instance).

    Raises:
        ValueError: If the architecture is not one of the supported classes.
    """
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    else:
        # Fail loudly instead of returning None, which would only surface later as an
        # opaque AttributeError on `None.from_pretrained`.
        raise ValueError(f"{model_class} is not supported.")




def parse_args(input_args=None):
    """Build the command-line parser for the attack and parse the arguments.

    Besides argparse's own checks, this rejects configurations that would otherwise crash
    deep inside ``main()`` after all models are loaded:
    - a missing ``--max_train_steps``;
    - a non-positive ``--checkpointing_steps``, ``--inner_iters`` or ``--pgd_batch_size``.

    Args:
        input_args: Optional list of argument strings. When None, the process command line
            (``sys.argv[1:]``) is parsed.

    Returns:
        argparse.Namespace holding every option.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help=(
            "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
            " float32 precision."
        ),
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        required=True,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--train_text_encoder",
        action="store_true",
        help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention",
        action="store_true",
        help="Whether or not to use xformers.",
    )
    parser.add_argument(
        "--pgd_alpha",
        type=float,
        default=1.0 / 255,
        help="The step size for pgd.",
    )
    parser.add_argument(
        "--pgd_eps",
        type=float,
        default=16.0 / 255,
        help="The noise budget for pgd.",
    )
    parser.add_argument(
        "--target_image_path",
        type=str,
        default=None,
        help="target image for attacking",
    )

    # Attack-specific arguments added for this script.
    parser.add_argument(
        "--pgd_batch_size",
        type=int,
        default=4,
        help="Batch size for PGD attack",
    )
    parser.add_argument(
        "--inner_iters",
        type=int,
        default=10,
        help="Number of inner gradient iterations per outer step of the attack",
    )
    # NOTE(review): `input_nums` is not read by main() in this file.
    parser.add_argument(
        "--input_nums",
        type=int,
        default=4,
        help="How much images will be used for training from the input images",
    )

    # parse_args(None) falls back to the process command line.
    args = parser.parse_args(input_args)

    # main() iterates `range(args.max_train_steps)`. None would only fail there, after
    # every model has already been loaded.
    if args.max_train_steps is None:
        parser.error("--max_train_steps is required.")
    # Used in main() as a modulus, as the inner-loop count / loss divisor, and as a batch size.
    for option in ("checkpointing_steps", "inner_iters", "pgd_batch_size"):
        value = getattr(args, option)
        if value < 1:
            parser.error(f"--{option} must be >= 1, got {value}.")

    return args


class DreamBoothDataset(Dataset):
    """
    Dataset of instance (and optionally class) images paired with tokenized prompts.

    Each item contains the transformed image tensor(s) together with the prompt token ids
    produced by ``tokenizer``, padded/truncated to ``tokenizer.model_max_length``. When
    ``class_data_root`` is supplied, the dataset length is the larger of the two image counts
    and indices wrap around each list independently.
    """

    # Lower-cased file suffixes that are accepted as instance images.
    _INSTANCE_SUFFIXES = (".jpg", ".jpeg", ".png")

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        root = Path(instance_data_root)
        self.instance_data_root = root
        if not root.exists():
            raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.")

        self.instance_images_path = [
            candidate
            for candidate in root.iterdir()
            if candidate.suffix.lower() in self._INSTANCE_SUFFIXES
        ]
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is None:
            self.class_data_root = None
        else:
            self.class_data_root = Path(class_data_root)
            # The class directory is created on demand so it can be populated later.
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt

        crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
        # Resize -> crop -> tensor -> map [0, 1] to [-1, 1].
        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                crop,
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def _tokenize(self, prompt):
        """Tokenize ``prompt`` padded/truncated to the tokenizer's max length; return ``input_ids``."""
        encoded = self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        return encoded.input_ids

    @staticmethod
    def _open_rgb(image_path):
        """Open ``image_path`` with PIL, converting to RGB if it is in another mode."""
        image = Image.open(image_path)
        if image.mode != "RGB":
            image = image.convert("RGB")
        return image

    def __getitem__(self, index):
        instance_path = self.instance_images_path[index % self.num_instance_images]
        example = {
            "instance_images": self.image_transforms(self._open_rgb(instance_path)),
            "instance_prompt_ids": self._tokenize(self.instance_prompt),
        }

        if self.class_data_root:
            class_path = self.class_images_path[index % self.num_class_images]
            example["class_images"] = self.image_transforms(self._open_rgb(class_path))
            example["class_prompt_ids"] = self._tokenize(self.class_prompt)

        return example


def collate_fn(examples):
    """Merge a list of dataset examples into a single batch dictionary.

    ``instance_prompt_ids`` tensors are concatenated along dim 0 into ``input_ids``.
    ``instance_images`` tensors are stacked along a new leading dim into ``pixel_values``,
    which is made contiguous and cast to float32.
    """
    stacked_images = torch.stack([item["instance_images"] for item in examples])
    pixel_values = stacked_images.to(memory_format=torch.contiguous_format).float()
    input_ids = torch.cat([item["instance_prompt_ids"] for item in examples], dim=0)
    return {"input_ids": input_ids, "pixel_values": pixel_values}


class PromptDataset(Dataset):
    """Dataset that yields one fixed prompt ``num_samples`` times, each tagged with its index.

    Intended for distributing prompt-driven class-image generation across multiple GPUs.
    """

    def __init__(self, prompt, num_samples):
        self.prompt, self.num_samples = prompt, num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        return {"prompt": self.prompt, "index": index}


def infer(checkpoint_path, prompts=None, n_img=16, bs=8, n_steps=100, guidance_scale=7.5):
    """Generate sample images from a saved Stable Diffusion checkpoint.

    For every prompt, ``n_img`` images are generated in batches of at most ``bs``. They are
    saved as ``{checkpoint_path}/dreambooth/{norm_prompt}/{batch}_{index}.png``, where
    ``norm_prompt`` is the prompt lower-cased, with commas removed and spaces replaced by
    underscores. If ``n_img`` is not a multiple of ``bs``, a final smaller batch covers the
    remainder.

    Args:
        checkpoint_path: Path or id accepted by ``StableDiffusionPipeline.from_pretrained``.
        prompts: Iterable of prompt strings. Required despite the ``None`` default.
        n_img: Number of images to generate per prompt.
        bs: Maximum number of images per pipeline call; must be >= 1.
        n_steps: Number of inference steps.
        guidance_scale: Classifier-free guidance scale.

    Raises:
        ValueError: If ``prompts`` is None or ``bs`` < 1.
    """
    # Validate before loading the pipeline so bad arguments fail fast, without touching the GPU.
    if prompts is None:
        raise ValueError("`prompts` must be an iterable of prompt strings, got None.")
    if bs < 1:
        raise ValueError(f"`bs` must be >= 1, got {bs}.")
    prompts = list(prompts)
    if not prompts:
        # Nothing to generate; skip loading the pipeline entirely.
        return

    pipe = StableDiffusionPipeline.from_pretrained(checkpoint_path, torch_dtype=torch.float16, safety_checker=None).to(
        "cuda"
    )

    for prompt in prompts:
        norm_prompt = prompt.lower().replace(",", "").replace(" ", "_")
        out_path = f"{checkpoint_path}/dreambooth/{norm_prompt}"
        os.makedirs(out_path, exist_ok=True)
        for batch_idx, start in enumerate(range(0, n_img, bs)):
            # The last batch may be smaller so exactly n_img images are produced.
            cur_bs = min(bs, n_img - start)
            images = pipe(
                [prompt] * cur_bs,
                num_inference_steps=n_steps,
                guidance_scale=guidance_scale,
            ).images
            for idx, image in enumerate(images):
                image.save(f"{out_path}/{batch_idx}_{idx}.png")
    del pipe


def main(args):
    """Craft a universal adversarial perturbation against DreamBooth-style fine-tuning.

    Steps, as implemented below:
      1. Load the tokenizer, DDPM noise scheduler, text encoder, VAE and UNet from
         ``args.pretrained_model_name_or_path`` and prepare them with ``accelerate``.
      2. Identity-subspace modelling. For ``pre_optimize_steps`` iterations, optimize one copy
         of the instance-prompt embedding per instance image (AdamW) against the denoising loss
         on those images. The per-coordinate mean/std across the copies then define a Gaussian
         from which conditioning embeddings are sampled.
      3. Outer loop (``args.max_train_steps`` steps).
         - Starting from the current perturbation ``delta``, run ``args.inner_iters``
           signed-gradient ascent steps. Each step uses images predicted by the UNet from
           partially denoised latents, which were generated with sampled embeddings.
         - Then move ``delta`` by the sign of the inner gradients averaged over those steps,
           clamped to ``[-args.pgd_eps, args.pgd_eps]``.
      4. Every ``args.checkpointing_steps`` steps, and once at the end, save ``delta`` together
         with visualizations and example images under ``args.output_dir``.

    Args:
        args: Parsed command-line namespace from ``parse_args``.

    Raises:
        ValueError: If the instance directory contains no supported images, or the scheduler
            uses an unknown prediction type.
        FileNotFoundError: If ``args.target_image_path`` is given but is not a file.
    """
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_dir=logging_dir,
    )

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    # Load the tokenizer
    if args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
    elif args.pretrained_model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="tokenizer",
            revision=args.revision,
            use_fast=False,
        )

    # import correct text encoder class
    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)

    # Load scheduler and models
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    text_encoder = text_encoder_cls.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    )

    vae.requires_grad_(False)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    # Check that all trainable models are in full precision
    low_precision_error_string = (
        "Please make sure to always have all model weights in full float32 precision when starting training - even if"
        " doing mixed precision training. copy of the weights should still be float32."
    )

    if accelerator.unwrap_model(unet).dtype != torch.float32:
        raise ValueError(
            f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
        )

    if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32:
        raise ValueError(
            f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
            f" {low_precision_error_string}"
        )

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # The dataset is used here for image discovery, image transforms and the tokenizer.
    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )

    # Collect all image paths, sorted so the order does not depend on directory listing order.
    all_image_paths = sorted(train_dataset.instance_images_path)
    total_images = len(all_image_paths)
    if total_images == 0:
        # Without this check the first torch.stack([]) fails with an unhelpful message.
        raise ValueError(f"No .jpg/.jpeg/.png images found in {args.instance_data_dir}.")
    batch_size = args.pgd_batch_size

    # Prepare everything with our `accelerator`.
    unet, text_encoder = accelerator.prepare(
        unet, text_encoder
    )

    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move vae and text_encoder to device and cast to weight_dtype
    vae.to(accelerator.device, dtype=weight_dtype)
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth", config=vars(args))

    def process_image_batch(image_paths):
        """Load `image_paths` as RGB, apply the dataset transforms and stack them into one batch."""
        images = [Image.open(i).convert("RGB") for i in image_paths]
        images = [train_dataset.image_transforms(i) for i in images]
        return torch.stack(images).contiguous()

    def get_tokenizer_input(prompt, batch_size):
        """Tokenize `prompt` once and repeat the token ids `batch_size` times along dim 0."""
        return train_dataset.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=train_dataset.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids.repeat(batch_size, 1)

    logger.info("***** Running training *****")
    logger.info(f"  Total number of images = {total_images}")
    logger.info(f"  Batch size for PGD = {batch_size}")
    logger.info(f"  Num training steps = {args.max_train_steps}")

    global_step = 0
    progress_bar = tqdm(
        range(global_step, args.max_train_steps),
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    # Optional target image for the target-shift loss, encoded once into VAE latent space.
    target_latent_tensor = None
    if args.target_image_path is not None:
        target_image_path = Path(args.target_image_path)
        if not target_image_path.is_file():
            raise FileNotFoundError(f"Target image path {target_image_path} does not exist")

        target_image = Image.open(target_image_path).convert("RGB").resize((args.resolution, args.resolution))
        target_image = np.array(target_image)[None].transpose(0, 3, 1, 2)
        # uint8 [0, 255] -> [-1, 1], the same range the dataset transforms produce.
        target_image_tensor = torch.from_numpy(target_image).to(accelerator.device, dtype=torch.float32) / 127.5 - 1.0
        with torch.no_grad():
            # Encode in weight_dtype to match the VAE parameters (cast to weight_dtype above),
            # and keep that dtype instead of forcing bfloat16.
            target_latent_tensor = (
                vae.encode(target_image_tensor.to(dtype=weight_dtype)).latent_dist.sample()
                * vae.config.scaling_factor
            )

    unet.train()
    if args.train_text_encoder:
        text_encoder.train()

    ddim_scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    modules = {
        'vae': vae,
        'tokenizer': tokenizer,
        'unet': unet,
        'text_encoder': text_encoder,
        'scheduler': ddim_scheduler
    }
    diffuser = StableDiffuser(modules)

    # --- Identity-subspace modelling ---
    # Optimize one copy of the instance-prompt embedding per instance image against the
    # denoising loss on those images.
    pre_optimize_steps = 50
    r = 0.001  # Learning rate for text embedding optimization

    for step in range(pre_optimize_steps):
        unet.zero_grad()
        text_encoder.zero_grad()

        # Reloaded every step. With random cropping, each step therefore sees a fresh crop.
        original_images = process_image_batch(all_image_paths).to(accelerator.device, dtype=weight_dtype)

        # Encode images to latents
        latents = vae.encode(original_images).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

        # Add noise to latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        if step == 0:
            # Get text embeddings (one copy per image), detached so only they are optimized.
            input_ids = get_tokenizer_input(args.instance_prompt, bsz)
            encoder_hidden_states = text_encoder(input_ids.to(accelerator.device))[0].detach()
            encoder_hidden_states.requires_grad_(True)

            # Initialize Adam optimizer for encoder_hidden_states after it is created
            optimizer = optim.AdamW([encoder_hidden_states], lr=r)

        # Predict noise
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

        # Calculate loss
        target = noise
        tloss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        # Optimize text embeddings with Adam optimizer
        optimizer.zero_grad()  # Clear previous gradients
        tloss.backward()  # Compute gradients
        optimizer.step()  # Update encoder_hidden_states

        print(f"Pre-optimization step {step}, loss: {tloss.item()}")

    # Per-coordinate mean and standard deviation over the per-image embedding copies.
    encoder_hidden_states = encoder_hidden_states.detach()
    mean_c = encoder_hidden_states.mean(dim=0)  # Shape: (L, dE)
    if encoder_hidden_states.shape[0] > 1:
        std_c = encoder_hidden_states.std(dim=0)  # Shape: (L, dE)
    else:
        # The unbiased std of a single sample is NaN, which would make every sampled embedding
        # NaN. A single image has no spread, so use zeros.
        std_c = torch.zeros_like(mean_c)

    del optimizer

    # Parameters for SGA
    k = args.inner_iters  # Number of inner iterations
    minibatch = args.pgd_batch_size
    alpha = args.pgd_alpha  # Step size
    eps = args.pgd_eps  # Maximum per-element perturbation magnitude

    # The attack loss always regresses onto the unperturbed noise prediction. The prediction
    # type is therefore only validated, once, rather than used to build a target every step.
    if noise_scheduler.config.prediction_type not in ("epsilon", "v_prediction"):
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    # Universal perturbation shared by every image, in [-1, 1] pixel space.
    delta = torch.zeros(3, args.resolution, args.resolution)
    with torch.no_grad():
        input_ids = get_tokenizer_input(args.instance_prompt, minibatch)
        # Embedding of the (not pre-optimized) instance prompt; conditions the UNet during the attack.
        encoder_hidden_states = text_encoder(input_ids.to(accelerator.device))[0]
        unconditional_tokens = diffuser.text_tokenize([""] * minibatch)
        unconditional_embeddings = diffuser.text_encode(unconditional_tokens)
    for i in range(args.max_train_steps):
        total_batch_loss = 0

        # Inner-loop gradients: one slot per inner iteration (k in total).
        noise_inner_all = torch.zeros(k, 3, args.resolution, args.resolution).to(accelerator.device)
        delta = delta.to(accelerator.device)
        delta_inner = delta.clone()
        for j in range(k):
            with torch.no_grad():
                nsteps = 50

                # Reparameterization trick: sample embeddings element-wise from N(mean_c, std_c^2).
                epsilon = torch.randn(minibatch, *mean_c.shape).to(accelerator.device)  # Shape: (minibatch, L, dE)
                sampled_c = mean_c + epsilon * std_c

                infer_embeds = torch.cat([unconditional_embeddings, sampled_c])

                diffuser.set_scheduler_timesteps(nsteps)

                iteration = torch.randint(1, nsteps - 1, (1,)).item()

                # Use the configured resolution so generated images match delta's spatial size.
                latents = diffuser.get_initial_latents(minibatch, args.resolution, 1)

                # Note: end_iteration=iteration counts sampler steps from the start, i.e. in
                # the opposite direction to the scheduler's timestep values.
                latents_steps, _ = diffuser.diffusion(
                    latents,
                    infer_embeds,
                    start_iteration=0,
                    end_iteration=iteration,
                    guidance_scale=7.5,  # TODO(review): why did the original setting use 3?
                    show_progress=False
                )

                diffuser.set_scheduler_timesteps(1000)

                # Map the sampler step index onto the 1000-step timestep schedule.
                iteration = int(iteration / nsteps * 1000)

                # UNet noise prediction on the generated latents, conditioned on the instance prompt.
                # It is reused below both as the noise added to the perturbed latents and as the target.
                noise_prediction = unet(latents_steps[0], ddim_scheduler.timesteps[iteration], encoder_hidden_states).sample

                pred_x0 = ddim_scheduler.step(
                    noise_prediction,
                    ddim_scheduler.timesteps[iteration],
                    latents_steps[0],
                ).pred_original_sample
                pred_imgs = diffuser.decode(pred_x0)

            # One copy of the current inner perturbation per image: (minibatch, 3, H, W).
            batch_delta = delta_inner.detach().clone().unsqueeze(0).repeat([minibatch, 1, 1, 1])
            batch_delta.requires_grad_()
            pertubed_images = torch.clamp((pred_imgs + batch_delta), -1, 1)

            # Convert images to latent space
            latents = vae.encode(pertubed_images.to(dtype=weight_dtype)).latent_dist.sample()
            latents = latents * vae.config.scaling_factor

            # Re-noise with the same noise and timestep as the unperturbed prediction.
            noise = noise_prediction
            timesteps = torch.as_tensor(ddim_scheduler.timesteps[iteration], dtype=torch.long, device=latents.device)

            # Add noise to latents
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Predict noise residual
            noise_prediction_adv = unet(noisy_latents, timesteps, encoder_hidden_states).sample

            unet.zero_grad()
            text_encoder.zero_grad()

            # Push the prediction on perturbed images away from the prediction on clean ones.
            target = noise_prediction
            model_pred = noise_prediction_adv
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            # Target-shift loss if applicable
            if target_latent_tensor is not None:
                current_target_latent = target_latent_tensor.repeat(minibatch, 1, 1, 1)
                # All samples share one 0-d timestep. Slicing it per sample would raise, so pass
                # it to the scheduler step as a plain int.
                step_t = int(timesteps)
                xtm1_pred = torch.cat([
                    noise_scheduler.step(
                        model_pred[idx:idx + 1],
                        step_t,
                        noisy_latents[idx:idx + 1],
                    ).prev_sample
                    for idx in range(minibatch)
                ])
                xtm1_target = noise_scheduler.add_noise(current_target_latent, noise, timesteps - 1)
                # Compare in float32: the two operands may be in different dtypes.
                loss = loss - F.mse_loss(xtm1_pred.float(), xtm1_target.float())

            total_batch_loss += loss.item()  # Accumulate loss for logging

            accelerator.backward(loss)

            # Signed-gradient ascent on the inner perturbation, using the gradient averaged over the batch copies.
            grad_inner = batch_delta.grad.mean(dim=0)
            delta_inner = delta_inner + grad_inner.sign() * alpha
            delta_inner = torch.clamp(delta_inner, -eps, eps)
            noise_inner_all[j, :, :, :] = grad_inner

            # NOTE(review): no optimizer steps the UNet/text encoder in this loop, so this
            # clipping does not influence `delta`.
            if accelerator.sync_gradients:
                params_to_clip = (
                    itertools.chain(unet.parameters(), text_encoder.parameters())
                    if args.train_text_encoder
                    else unet.parameters()
                )
                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

        # Outer (batch) update: sign of the inner gradients averaged over the k inner iterations.
        grad_sign = torch.mean(noise_inner_all.detach().clone(), dim=0, keepdim=True).squeeze(0).sign()

        delta = delta + grad_sign * alpha
        delta = torch.clamp(delta, -eps, eps)

        global_step += 1
        if accelerator.sync_gradients and global_step % args.checkpointing_steps == 0:
            if accelerator.is_main_process:
                save_folder = f"{args.output_dir}/noise-ckpt/{global_step}"
                os.makedirs(save_folder, exist_ok=True)

                # Save universal perturbation
                perturbation_path = os.path.join(save_folder, "universal_perturbation.pt")
                torch.save(delta.cpu(), perturbation_path)

                # Visualize the perturbation with the same pixel mapping as the final save.
                vis_perturbation = (delta.cpu() * 127.5 + 128).clamp(0, 255).to(torch.uint8)
                Image.fromarray(vis_perturbation.permute(1, 2, 0).numpy()).save(
                    os.path.join(save_folder, "universal_perturbation_vis.png")
                )

                # Save the last generated batch, with and without the perturbation.
                original_batch = pred_imgs.to(accelerator.device)
                perturbed_batch = torch.clamp(
                    pred_imgs + delta.unsqueeze(0).repeat(len(original_batch), 1, 1, 1),
                    min=-1,
                    max=1
                )

                for idx, (orig, pert) in enumerate(zip(original_batch, perturbed_batch)):
                    # Save original images
                    Image.fromarray(
                        (orig * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).cpu().numpy()
                    ).save(os.path.join(save_folder, f"example_{idx}_original.png"))

                    # Save perturbed images
                    Image.fromarray(
                        (pert * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).cpu().numpy()
                    ).save(os.path.join(save_folder, f"example_{idx}_perturbed.png"))

        # Update progress
        if accelerator.sync_gradients:
            progress_bar.update(1)

            logs = {
                "loss": total_batch_loss / k,  # Average over the k inner iterations
                "perturbation_norm": torch.norm(delta).item()
            }
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

    # Save the final universal perturbation
    if accelerator.is_main_process:
        final_save_folder = f"{args.output_dir}/final"
        os.makedirs(final_save_folder, exist_ok=True)
        torch.save(delta.cpu(), os.path.join(final_save_folder, "universal_perturbation.pt"))

        # Save its visualization
        vis_perturbation = (delta.cpu() * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        Image.fromarray(vis_perturbation.permute(1, 2, 0).numpy()).save(
            os.path.join(final_save_folder, "universal_perturbation_vis.png")
        )

        final_adv_save_folder = f"{args.output_dir}/final_adv_imgs"
        os.makedirs(final_adv_save_folder, exist_ok=True)

        # Save a few example images: the first (up to) four instance images with delta applied.
        original_batch = process_image_batch(all_image_paths[:4]).to(accelerator.device)
        perturbed_batch = torch.clamp(
            original_batch + delta.unsqueeze(0).repeat(len(original_batch), 1, 1, 1),
            min=-1,
            max=1
        )

        for idx, pert in enumerate(perturbed_batch):
            # Save the perturbed image
            Image.fromarray(
                (pert * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).cpu().numpy()
            ).save(os.path.join(final_adv_save_folder, f"example_{idx}_perturbed.png"))

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        print("Done")
    accelerator.end_training()

if __name__ == "__main__":
    # Script entry point: parse the command line, then run the attack.
    args = parse_args()
    main(args=args)
