import os
import sys
import signal

def handle_signal(signum, frame):
    """Report a delivered signal and otherwise ignore it.

    Args:
        signum: Number of the signal that was delivered.
        frame: Stack frame active at delivery time (unused).
    """
    notice = f"Received signal {signum}, ignoring."
    print(notice)


# Route hang-up and termination requests to the handler above. Because the handler
# just prints and returns, these signals no longer stop the process.
for _ignored_signal in (signal.SIGHUP, signal.SIGTERM):
    signal.signal(_ignored_signal, handle_signal)
del _ignored_signal
from ddpo_pytorch.Clock import Clock
from ddpo_pytorch.adaptive_model import Discriminator_CLIP, Discriminator_P, Discriminator_Noise
from scripts.Evals import eval_all
from utils.dropper import zero_tensor_by_ratio
import torch.distributed as dist
# os.environ["PL_TORCH_DISTRIBUTED_BACKEND"] = "gloo"
from collections import defaultdict
import contextlib
import datetime
from concurrent import futures
import time
from absl import app, flags
from ml_collections import config_flags
from accelerate import Accelerator
from accelerate.utils import set_seed, ProjectConfiguration
from accelerate.logging import get_logger
from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
import numpy as np
import ddpo_pytorch.prompts
import ddpo_pytorch.rewards
# from ddpo_pytorch.adaptive_model import Discriminator
from ddpo_pytorch.stat_tracking import PerPromptStatTracker
from ddpo_pytorch.diffusers_patch.pipeline_with_logprob import pipeline_with_logprob, unet_predict_noise
from ddpo_pytorch.diffusers_patch.ddim_with_logprob import ddim_step_with_logprob
import torch
import wandb
from functools import partial
import tqdm
import tempfile
from PIL import Image
from torch import inf
# Report which GPU this process sees and derive coarse hardware flags from its name.
# The name is queried once and reused instead of calling into CUDA four times.
# NOTE(review): presumably fails at import time when no CUDA device is visible — confirm.
_device_name = torch.cuda.get_device_name()
print(_device_name)
on_3060 = "3060" in _device_name
# NOTE(review): the "H20" substring test also matches e.g. "H200" — confirm that is intended.
on_h = "H100" in _device_name or "H20" in _device_name
# Rebind `tqdm` to a progress-bar factory whose width follows the terminal.
tqdm = partial(tqdm.tqdm, dynamic_ncols=True)


FLAGS = flags.FLAGS
config_flags.DEFINE_config_file("config", "config/base.py", "Training configuration.")

logger = get_logger(__name__)
# def force_cudnn_initialization():
#     s = 32
#     dev = torch.device('cuda')
#     torch.nn.functional.conv2d(torch.zeros(s, s, s, s, device=dev), torch.zeros(s, s, s, s, device=dev))
# force_cudnn_initialization()
def save_tensor_to_txt(tensor, file_path):
    """Append ``str(tensor)`` followed by a blank line to a text file.

    Missing parent directories are created first. Repeated calls accumulate
    entries in the same file, each separated by an empty line.

    Args:
        tensor: Any object; it is written via ``str(tensor)``.
        file_path: Destination path, opened in append mode.
    """
    directory = os.path.dirname(file_path)
    if directory:
        # exist_ok=True avoids the check-then-create race when several
        # processes (e.g. accelerate workers) create the same directory.
        os.makedirs(directory, exist_ok=True)

    # Write-only append; an explicit encoding keeps output independent of the locale.
    with open(file_path, 'a', encoding='utf-8') as file:
        file.write(str(tensor) + "\n\n")

def check_all_gradients_unupdated(model):
    """Return True when no parameter of ``model`` carries a non-zero gradient.

    A parameter counts as "updated" only if its ``.grad`` exists and the sum
    of its absolute values is strictly positive; missing or all-zero
    gradients are treated as not updated.
    """
    any_updated = any(
        p.grad is not None and bool(p.grad.abs().sum() > 0)
        for p in model.parameters()
    )
    return not any_updated

def global_grad_norm_(parameters, norm_type=2):
    r"""Compute the total gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. This function only measures the
    norm: gradients are not clipped or modified.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor whose ``.grad`` is measured. Entries whose ``.grad``
            is ``None`` are skipped.
        norm_type (float or int or str): type of the used p-norm. Can be
            ``'inf'`` for the infinity norm.

    Returns:
        Total norm of the gradients (viewed as a single vector): a Python
        float for finite ``norm_type``, a 0-dim tensor for the infinity norm,
        and ``0.0`` when no parameter has a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        # max() over an empty sequence would raise; an empty set of gradients has norm 0.
        return 0.0
    if norm_type == inf:
        return max(g.abs().max() for g in grads)
    # Accumulate per-tensor norms as Python floats (double precision).
    total_norm = 0.0
    for g in grads:
        total_norm += g.norm(norm_type).item() ** norm_type
    return total_norm ** (1. / norm_type)

def delta_mask(target1, delta_abs_t=0.04, delta_delta_t=0.01):
    """Flag, per row, every position from the first "small and flat" point onward.

    Position ``j`` (0 <= j < width - 1) qualifies when ``|target1[:, j]| < delta_abs_t``
    and the forward difference ``|target1[:, j + 1] - target1[:, j]| < delta_delta_t``.
    Once a position qualifies, it and all later positions in that row are flagged.

    Args:
        target1: 2-D tensor; rows are processed independently along dim 1.
        delta_abs_t: Threshold on the absolute value.
        delta_delta_t: Threshold on the absolute forward difference.

    Returns:
        Bool tensor with one column fewer than ``target1``.
    """
    current = target1[:, :-1]
    step_change = target1[:, 1:] - current
    qualifies = (current.abs() < delta_abs_t) & (step_change.abs() < delta_delta_t)
    # A running count > 0 means "some earlier-or-current position qualified".
    return qualifies.float().cumsum(dim=1) > 0

def delta_combine(target1, target1_abs, target2, delta_abs_t, delta_delta_t, first_stage=True, target2_scale=2):
    """Build a "cut" tensor and a blended tensor from two per-position score tensors.

    ``delta_mask`` flags, per row, every position from the first point where
    ``target1`` is small and flat. An always-True column is appended so the
    flag has the same width as ``target1``.

    Args:
        target1: 2-D tensor ``(batch, dim)`` of primary scores.
        target1_abs: Unused; kept for call-site compatibility.
        target2: Tensor broadcastable to ``(batch, dim)`` holding secondary scores.
        delta_abs_t: Absolute-value threshold forwarded to ``delta_mask``.
        delta_delta_t: Forward-difference threshold forwarded to ``delta_mask``.
        first_stage: If True, the unbroken run of unflagged positions starting at
            column 0 whose ``|target1| < 0.1`` is overwritten with 1 before blending.
        target2_scale: Multiplier applied to ``target2`` inside the flagged region.

    Returns:
        ``(result_cut, result2)``: ``result_cut`` is ``target2`` inside the flagged
        region and 100 elsewhere; ``result2`` is ``target2 * target2_scale`` inside
        the flagged region and ``target1`` elsewhere.
    """
    propagated_mask = delta_mask(target1, delta_abs_t=delta_abs_t, delta_delta_t=delta_delta_t)
    # delta_mask works on forward differences and returns one column fewer;
    # pad with an always-True last column so the flag matches target1's width.
    true_column = torch.ones(propagated_mask.size(0), 1, dtype=torch.bool, device=propagated_mask.device)
    flagged = torch.cat((propagated_mask, true_column), dim=1)
    if first_stage:
        # Use the full width so the mask has target1's shape. Slicing to [:, :-1]
        # would give a (batch, dim - 1) mask that torch.where cannot broadcast
        # against target1. The last column is always flagged, so it is never selected.
        leading = (torch.abs(target1) < 0.1) & ~flagged
        # Keep only the unbroken run that starts at column 0.
        leading = leading.long().cumprod(dim=1).bool()
        target1 = torch.where(leading, 1, target1)
    result_cut = torch.where(flagged, target2, 100)
    result2 = torch.where(flagged, target2 * target2_scale, target1)
    return result_cut, result2



def step_predictor_der(target_2=None, stage='train', step_min=25, delta_abs_t=0.001, delta_delta_t=0.001):
    """Choose, per row of ``target_2``, a step count from where its values settle.

    Only ``stage == 'train'`` is implemented; any other stage raises ``ValueError``.

    Column ``j`` (excluding the last) qualifies when all of these hold:
      * ``target_2[:, j] != 100`` (100 is the fill value ``delta_combine`` uses
        outside its flagged region; presumably ``target_2`` comes from there),
      * the signed forward difference ``target_2[:, j + 1] - target_2[:, j]`` is
        below ``delta_delta_t`` (no absolute value is taken),
      * ``|target_2[:, j]| < delta_abs_t``.
    The step count is the 1-based index of the first qualifying column, raised to
    at least ``step_min``.

    NOTE(review): if no column qualifies, ``np.argmax`` returns 0, so the row gets 1
    and is then clamped to ``step_min``. Confirm that this, rather than "use all
    steps", is the intended fallback.

    Returns:
        A tuple containing the same step-count array twice (both entries alias it).
    """
    if stage != 'train':
        raise ValueError('test step predictor is not implemented yet.')
    head = target_2[:, :-1]
    not_cut = head != 100
    small_change = (target_2[:, 1:] - head) < delta_delta_t
    small_value = np.abs(head) < delta_abs_t
    step_num = np.argmax(small_change & not_cut & small_value, axis=1) + 1
    step_num[step_num < step_min] = step_min
    return step_num, step_num  # valid_step+final+1,all_step_num+valid_step+1

def main(_):
    # basic Accelerate and logging setup
    config = FLAGS.config
    # cjk_note: modification for local run
    SDv14 = "CompVis/stable-diffusion-v1-4"
    SDv21 = "stabilityai/stable-diffusion-2-1"
    SDXL="stabilityai/stable-diffusion-xl-base-1.0"
    SDturbo= "stabilityai/sd-turbo"
    SDv15 = 'sd-legacy/stable-diffusion-v1-5'
    try:
        config.pretrained.model = locals()[config.pretrained.model_name]
    except Exception as e:
        pass
    if torch.cuda.device_count()==4 and config.sample.batch_size>4:
        config.sample.batch_size = 16
    try:

    if config.pretrained.model_name=='SDv21':
        config.sample.batch_size = 8
        config.train.batch_size = 1
        config.sample.guidance_scale = 1.0
        config.train.gradient_accumulation_steps = 8
    if config.pretrained.model_name=='SDXL':
        config.train.learning_rate = 3e-4
        config.sample.guidance_scale = 7.0
        config.sample.batch_size = 16
        config.sample.num_steps = 30
    if config.pretrained.model_name=='SDturbo':
        config.train.learning_rate = 3e-4
        config.sample.guidance_scale = 1.0
        config.sample.batch_size = 16
    current_time = datetime.datetime.now()
    human_readable_time = current_time.strftime('%Y-%m-%d_%H_%M_%S')
    # unique_id = datetime.datetime.now().strftime("%Y.%m.%d_%H.%M.%S")
    unique_id = f"{config.intrinsic_reward_fn}{'_cut2' if config.step_cut>0 else ''}{f'_kl{config.reward_weight}' if config.train.use_kl else ''}_{'' if config.intrinsic_reward_weight else config.intrinsic_reward_weight}_{config.seed}_" + datetime.datetime.now().strftime("%m-%d_%H_%M")
    if not config.run_name:
        config.run_name = unique_id
    else:
        config.run_name += "_" + unique_id
    if config.resume_from:
        config.resume_from = os.path.normpath(os.path.expanduser(config.resume_from))
        print(config.resume_from)
        if "checkpoint_" not in os.path.basename(config.resume_from):
            # get the most recent checkpoint in this directory
            checkpoints = list(
                filter(lambda x: "checkpoint_" in x, os.listdir(config.resume_from))
            )
            if len(checkpoints) == 0:
                raise ValueError(f"No checkpoints found in {config.resume_from}")
            config.resume_from = os.path.join(
                config.resume_from,
                sorted(checkpoints, key=lambda x: int(x.split("_")[-1]))[-1],
            )

    # number of timesteps within each trajectory to train on
    num_train_timesteps = int(config.sample.num_steps * config.train.timestep_fraction)

    accelerator_config = ProjectConfiguration(
        project_dir=os.path.join(config.logdir, config.run_name),
        automatic_checkpoint_naming=True,
        total_limit=config.num_checkpoint_limit,
    )
    os.environ["WANDB_MODE"] = "offline"
    accelerator = Accelerator(
        log_with="wandb",
        mixed_precision=config.mixed_precision,
        project_config=accelerator_config,
        # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
        # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
        # the total number of optimizer steps to accumulate across.
        gradient_accumulation_steps=config.train.gradient_accumulation_steps
        * num_train_timesteps if config.step_cut==0 else 1,
    )
    if accelerator.is_main_process:
        model_version_prefix = config.pretrained.model_name
        accelerator.init_trackers(
            project_name=f"{model_version_prefix}-AAAI-MT1-{config.project_name}" if config.project_name else f"{model_version_prefix}-AAAI-MT1-aesthetic",
            config=config.to_dict(),
            init_kwargs={"wandb": {"name": config.run_name}},
        )
    logger.info(f"\n{config}")
    # set seed (device_specific is very important to get different prompts on different devices)
    set_seed(config.seed, device_specific=True)
    # My local downloaded model
    # if config.pretrained.model.endswith('4'):
    #     config.pretrained.model="./ckpts/sd14"
    # load scheduler, tokenizer and models.
    pipeline = StableDiffusionPipeline.from_pretrained(
        config.pretrained.model, revision=config.pretrained.revision
    )
    # freeze parameters of models to save more memory
    pipeline.vae.requires_grad_(False)
    pipeline.text_encoder.requires_grad_(False)
    pipeline.unet.requires_grad_(not config.use_lora)
    # disable safety checker
    pipeline.safety_checker = None
    # make the progress bar nicer
    pipeline.set_progress_bar_config(
        position=1,
        disable=not accelerator.is_local_main_process,
        leave=False,
        desc="Timestep",
        dynamic_ncols=True,
    )
    # switch to DDIM scheduler
    pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

    # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    inference_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        inference_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        inference_dtype = torch.bfloat16

    # Move unet, vae and text_encoder to device and cast to inference_dtype
    pipeline.vae.to(accelerator.device, dtype=inference_dtype)
    pipeline.text_encoder.to(accelerator.device, dtype=inference_dtype)
    if config.use_lora:
        pipeline.unet.to(accelerator.device, dtype=inference_dtype)

    if config.intrinsic_reward_fn.endswith('clip') or config.intrinsic_reward_fn.endswith('clip_ada'):
        disnet_clip = Discriminator_CLIP(model=pipeline,clip_name=config.clip_model,config=config)
        disnet_target2 = Discriminator_CLIP(model=pipeline,clip_name=config.target2_model,config=config)
    else:
        disnet_clip=None
        disnet_target2=None

    if config.train.use_kl:
        unet_copy = UNet2DConditionModel.from_pretrained(
            config.pretrained.model,
            subfolder="unet",
            revision=config.pretrained.revision,
        )
        unet_copy.requires_grad_(False)
        unet_copy.to(accelerator.device, dtype=inference_dtype)
    else:
        unet_copy = None

    if config.use_lora:
        # Set correct lora layers
        lora_attn_procs = {}
        for name in pipeline.unet.attn_processors.keys():
            cross_attention_dim = (
                None
                if name.endswith("attn1.processor")
                else pipeline.unet.config.cross_attention_dim
            )
            if name.startswith("mid_block"):
                hidden_size = pipeline.unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[
                    block_id
                ]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = pipeline.unet.config.block_out_channels[block_id]

            lora_attn_procs[name] = LoRAAttnProcessor(
                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
            )
        pipeline.unet.set_attn_processor(lora_attn_procs)

        # this is a hack to synchronize gradients properly. the module that registers the parameters we care about (in
        # this case, AttnProcsLayers) needs to also be used for the forward pass. AttnProcsLayers doesn't have a
        # `forward` method, so we wrap it to add one and capture the rest of the unet parameters using a closure.
        class _Wrapper(AttnProcsLayers):
            def forward(self, *args, **kwargs):
                return pipeline.unet(*args, **kwargs)

        unet = _Wrapper(pipeline.unet.attn_processors)
    else:
        unet = pipeline.unet

    # set up diffusers-friendly checkpoint saving with Accelerate

    def save_model_hook(models, weights, output_dir):
        assert len(models) >= 1
        if config.use_lora and isinstance(models[0], AttnProcsLayers):
            pipeline.unet.save_attn_procs(output_dir)
        elif not config.use_lora and isinstance(models[0], UNet2DConditionModel):
            models[0].save_pretrained(os.path.join(output_dir, "unet"))
        else:
            raise ValueError(f"Unknown model type {type(models[0])}")
        weights.clear()  # ensures that accelerate doesn't try to handle saving of the model

    def load_model_hook(models, input_dir):
        assert len(models) == 1
        if config.use_lora and isinstance(models[0], AttnProcsLayers):
            # pipeline.unet.load_attn_procs(input_dir)
            tmp_unet = UNet2DConditionModel.from_pretrained(
                config.pretrained.model,
                revision=config.pretrained.revision,
                subfolder="unet",
            )
            tmp_unet.load_attn_procs(input_dir)
            models[0].load_state_dict(
                AttnProcsLayers(tmp_unet.attn_processors).state_dict()
            )
            del tmp_unet
        elif not config.use_lora and isinstance(models[0], UNet2DConditionModel):
            load_model = UNet2DConditionModel.from_pretrained(
                input_dir, subfolder="unet"
            )
            models[0].register_to_config(**load_model.config)
            models[0].load_state_dict(load_model.state_dict())
            del load_model
        else:
            raise ValueError(f"Unknown model type {type(models[0])}")
        models.clear()  # ensures that accelerate doesn't try to handle loading of the model

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if config.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # Initialize the optimizer
    if config.train.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer_unet = optimizer_cls(
        unet.parameters(),
        lr=config.train.learning_rate,
        betas=(config.train.adam_beta1, config.train.adam_beta2),
        weight_decay=config.train.adam_weight_decay,
        eps=config.train.adam_epsilon,
    )
    # prepare prompt and reward fn
    prompt_fn = getattr(ddpo_pytorch.prompts, config.prompt_fn) # simple animal
    reward_fn = getattr(ddpo_pytorch.rewards, config.reward_fn)(accelerator.device)  # DDPO reward
    # -------------------------------------------------------------------------------
    intrinsic_reward_fn = getattr(ddpo_pytorch.rewards, config.intrinsic_reward_fn)()
    disnet_fn = getattr(ddpo_pytorch.rewards,'baseline_ada_v2')()
    # -------------------------------------------------------------------------------

    # generate negative prompt embeddings
    neg_prompt_embed = pipeline.text_encoder(
        pipeline.tokenizer(
            [""],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=pipeline.tokenizer.model_max_length,
        ).input_ids.to(accelerator.device)
    )[0]
    sample_neg_prompt_embeds = neg_prompt_embed.repeat(config.sample.batch_size, 1, 1)
    train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size, 1, 1)

    # initialize stat tracker
    if config.per_prompt_stat_tracking:
        stat_tracker = PerPromptStatTracker(
            config.per_prompt_stat_tracking.buffer_size,
            config.per_prompt_stat_tracking.min_count,
        )
        # 追踪内部奖励均值方差--------------------------------------------------------------------------------
        intrinsic_stat_tracker = PerPromptStatTracker(
            config.per_prompt_stat_tracking.buffer_size,
            config.per_prompt_stat_tracking.min_count,
        )
        # --------------------------------------------------------------------------------

    # for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
    # more memory
    autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast
    # autocast = accelerator.autocast

    # Prepare everything with our `accelerator`.
    # if disnet:
    #     unet, disnet, optimizer = accelerator.prepare(unet, disnet, optimizer)
    # else:
    unet, optimizer = accelerator.prepare(unet, optimizer_unet)

    # executor to perform callbacks asynchronously. this is beneficial for the llava callbacks which makes a request to a
    # remote server running llava inference.
    executor = futures.ThreadPoolExecutor(max_workers=2)
    executor2 = futures.ThreadPoolExecutor(max_workers=2)
    # executor3 = futures.ThreadPoolExecutor(max_workers=2)
    # Train!
    samples_per_epoch = (
        config.sample.batch_size
        * accelerator.num_processes
        * config.sample.num_batches_per_epoch
    )
    total_train_batch_size = (
        config.train.batch_size
        * accelerator.num_processes
        * config.train.gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num Epochs = {config.num_epochs}")
    logger.info(f"  Sample batch size per device = {config.sample.batch_size}")
    logger.info(f"  Train batch size per device = {config.train.batch_size}")
    logger.info(
        f"  Gradient Accumulation steps = {config.train.gradient_accumulation_steps}"
    )
    logger.info("")
    logger.info(f"  Total number of samples per epoch = {samples_per_epoch}")
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}"
    )
    logger.info(
        f"  Number of gradient updates per inner epoch = {samples_per_epoch // total_train_batch_size}"
    )
    logger.info(f"  Number of inner epochs = {config.train.num_inner_epochs}")

    assert config.sample.batch_size >= config.train.batch_size
    assert config.sample.batch_size % config.train.batch_size == 0
    assert samples_per_epoch % total_train_batch_size == 0

    if config.resume_from:
        logger.info(f"Resuming from {config.resume_from}")
        accelerator.load_state(config.resume_from)
        first_epoch = int(config.resume_from.split("_")[-1]) + 1
    else:
        first_epoch = 0

    global_step = 0

    if config.validation:
        score_dict = eval_all(pipeline, config, accelerator=accelerator, logger=logger, test_names=config.test_prompts)
        if accelerator.is_main_process:
            accelerator.log(
                score_dict,
                step=global_step,
            )
        # if config.test_prompts!='all' and accelerator.is_main_process:
        #     score_dict = eval_all(pipeline, config,logger=logger,test_names=config.test_prompts)
        #
        #     accelerator.log(
        #         score_dict,
        #         step=global_step,
        #     )
        #     logger.info(f"  Validation Success on GPU:{accelerator.process_index}")
        # elif config.test_prompts=='all':
        #     for idx,test_prompts in enumerate(['imagenet_animals','simple_animals','activities_animals']):
        #         if accelerator.process_index==idx+1:
        #             score_dict = eval_all(pipeline, config,logger=logger, test_names=test_prompts)
        #             accelerator.log(
        #                 score_dict,
        #                 step=global_step,
        #             )
        #             logger.info(f"  Validation Success on GPU:{accelerator.process_index}")
    if accelerator.is_local_main_process:
        clock=Clock()
        logger.info(f"  Clock enabled at {clock.start()}")
    reward_count=0

    for epoch in range(first_epoch, config.num_epochs):
        #################### SAMPLING ####################
        pipeline.unet.eval()
        samples = []
        prompts = []
        for i in tqdm(
            range(config.sample.num_batches_per_epoch),
            desc=f"Epoch {epoch}: sampling",
            disable=not accelerator.is_local_main_process,
            position=0,
        ):
            # generate prompts
            prompts, prompt_metadata = zip(
                *[
                    prompt_fn(**config.prompt_fn_kwargs)
                    for _ in range(config.sample.batch_size)
                ]
            )

            # encode prompts
            prompt_ids = pipeline.tokenizer(
                prompts,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=pipeline.tokenizer.model_max_length,
            ).input_ids.to(accelerator.device)
            prompt_embeds = pipeline.text_encoder(prompt_ids)[0]

            # sample
            with autocast():
                images, _, latents, log_probs,each_origin = pipeline_with_logprob(
                    pipeline,
                    prompt_embeds=prompt_embeds,
                    negative_prompt_embeds=sample_neg_prompt_embeds,
                    num_inference_steps=config.sample.num_steps,
                    guidance_scale=config.sample.guidance_scale,
                    eta=config.sample.eta,
                    output_type="pt",
                    return_origin=True
                )
            each_origin = torch.stack(
                each_origin, dim=1
            )  # (batch_size, num_steps + 1, 4, 64, 64)
            latents = torch.stack(
                latents, dim=1
            )  # (batch_size, num_steps + 1, 4, 64, 64)
            log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps, 1)
            timesteps = pipeline.scheduler.timesteps.repeat(
                config.sample.batch_size, 1
            )  # (batch_size, num_steps)

            # compute rewards asynchronously
            if accelerator.is_local_main_process and epoch == first_epoch and i == 0:
                print('type(images):', type(images), ', images.shape:', images.shape)  # images.shape: torch.Size([bs, 3, 512, 512])
                print('latents.shape:', latents.shape)  # latents.shape: torch.Size([bs, step+1, 4, 64, 64])
            if config.multi_target=='weight':
                rewards = executor.submit(reward_fn, images, prompts, config.weight_ACI)
            else:
                rewards = executor.submit(reward_fn, images, prompts, prompt_metadata)
            if config.reward_latent=='xt':
                i_reward_latents = latents
            else:
                i_reward_latents = torch.cat([latents[:, 0:1, :, :, :], each_origin, ], dim=1)
            intrinsic_rewards = executor2.submit(intrinsic_reward_fn, images, prompts, prompt_metadata, i_reward_latents.detach())
          # --------------------------------------------------------------------------------
            # yield to make sure reward computation starts
            time.sleep(0)

            samples.append(
                {
                    'prompts':prompts,
                    "prompt_ids": prompt_ids,
                    "prompt_embeds": prompt_embeds,
                    "timesteps": timesteps,
                    'each_origins':each_origin,
                    "latents": latents[
                        :, :-1
                    ],  # each entry is the latent before timestep t
                    "next_latents": latents[
                        :, 1:
                    ],  # each entry is the latent after timestep t
                    "log_probs": log_probs,
                    "rewards": rewards,
                    # --------------------------------------------------------------------------------
                    "intrinsic_rewards": intrinsic_rewards,
                    # 'disnet_output':disnet_output
                    # --------------------------------------------------------------------------------
                }
            )

        # wait for all rewards to be computed
        for sample in tqdm(  # sample = each traj
            samples,
            desc="Waiting for rewards",
            disable=not accelerator.is_local_main_process,
            position=0,
        ):
            rewards, reward_metadata = sample["rewards"].result()
            reward_name = config.reward_fn.split('_')
            sample.update({
                f"r_{reward_name[0]}": reward_metadata[0],
                f"r_{reward_name[1]}": reward_metadata[1],
                f"r_{reward_name[2]}": reward_metadata[2],
            })
            # --------------------------------------------------------------------------------
            intrinsic_rewards, intristic_reward_metadata = sample["intrinsic_rewards"].result()
            # --------------------------------------------------------------------------------
            if config.intrinsic_reward_fn.endswith('clip') or config.intrinsic_reward_fn.endswith('clip_ada'):
                prompt_current = sample['prompts']
                prompt_embed_current = sample['prompt_embeds']
                latents_for_decode = sample['each_origins']
                latents_for_decode = torch.cat([sample['latents'][:, 0:1, :, :, :], latents_for_decode, ], dim=1)
                clip_list = []
                target2_list = []
                for idd in tqdm(range(latents_for_decode.shape[1]), desc="Processing Clip",
                                disable=not accelerator.is_local_main_process,):
                    latents_each = latents_for_decode[:, idd]
                    clip_p = disnet_clip(latents_each,prompt_embed_current,prompt_current)
                    target2_p =disnet_target2(latents_each,prompt_embed_current,prompt_current)
                    clip_list.append(clip_p)
                    target2_list.append(target2_p)
                clip_array = torch.stack(clip_list, axis=1).squeeze()
                target2_array = torch.stack(target2_list,axis=1).squeeze()
                delta_clip_array = accelerator.unwrap_model(disnet_clip).get_delta(clip_array)
                delta_target2_array = accelerator.unwrap_model(disnet_clip).get_delta(target2_array)

                if config.delta_min > -0.9:
                    x_tensor = torch.tensor(config.delta_min)
                    delta_clip_array = torch.where(delta_clip_array < 0, x_tensor.to(delta_clip_array.device),
                                                   delta_clip_array)
                # delta_clip_array=zero_tensor_by_ratio(delta_clip_array,zero_ratio=config.delta_drop_ratio,manner=config.delta_drop_manner)
                # delta_target2_array = zero_tensor_by_ratio(delta_target2_array, zero_ratio=config.delta_drop_ratio,
                #                                         manner=config.delta_drop_manner)
                delta_cut, delta_final = delta_combine(delta_clip_array,clip_array,delta_target2_array,
                                                       delta_abs_t=config.delta_abs_t,delta_delta_t=config.delta_delta_t,
                                                       first_stage=config.first_stage,target2_scale=config.t2_scale)
                delta_final = zero_tensor_by_ratio(delta_final, zero_ratio=config.delta_drop_ratio,
                                                        manner=config.delta_drop_manner)
                sample['disnet_output'] = delta_cut            # disnet_output = disnet_p(latents.detach()[:, 1:])
                intrinsic_rewards=delta_final*intrinsic_rewards
                # save_tensor_to_txt(delta_final.mean(0),
                #                    f'aaai_midres/{human_readable_time}_delta_sum.txt')
                # save_tensor_to_txt(delta_clip_array.mean(0),
                #                    f'aaai_midres/{human_readable_time}_delta_t1.txt')
                # save_tensor_to_txt(delta_target2_array.mean(0),
                #                    f'aaai_midres/{human_readable_time}_delta_t2.txt')
                if config.p_loc==1:
                    p_dis=sample["disnet_output"].detach()
                    if not config.wo_p:
                        if config.remain_05:
                            intrinsic_rewards = intrinsic_rewards * (1.5 - p_dis)
                            rewards_ada = rewards.unsqueeze(1) * (p_dis+0.5)
                        elif config.intri_only:
                            intrinsic_rewards = intrinsic_rewards * (1 - p_dis)

                            rewards_ada = rewards.unsqueeze(1).repeat(1, 50)
                        else:
                            intrinsic_rewards = intrinsic_rewards * (1 - p_dis)

                            rewards_ada = rewards.unsqueeze(1) * p_dis
                    else:
                        rewards_ada=rewards.unsqueeze(1).repeat(1, 50)
                    sample['rewards_ada']=rewards_ada
            del sample['prompts']

            # accelerator.print(reward_metadata)
            sample["rewards"] = torch.as_tensor(rewards, device=accelerator.device)
            # --------------------------------------------------------------------------------
            sample["intrinsic_rewards"] = torch.as_tensor(intrinsic_rewards, device=accelerator.device)

            # --------------------------------------------------------------------------------
        # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
        keys = samples[0].keys()

        result = {}


        for k in keys:

            concatenated_values = [s[k] for s in samples]


            result[k] = torch.cat(concatenated_values)


        samples = result
        # this is a hack to force wandb to log the images as JPEGs instead of PNGs
        if accelerator.is_local_main_process:
            with tempfile.TemporaryDirectory() as tmpdir:
                for i, image in enumerate(images):
                    pil = Image.fromarray(
                        (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
                    )
                    # pil = pil.resize((256, 256))
                    pil.save(os.path.join(tmpdir, f"{i}.jpg"))
                accelerator.log(
                    {
                        "images": [
                            wandb.Image(
                                os.path.join(tmpdir, f"{i}.jpg"),
                                caption=f"{prompt:.25} | {reward:.2f}",
                            )
                            for i, (prompt, reward) in enumerate(
                                zip(prompts, rewards)
                            )  # only log rewards from process 0
                        ],
                    },
                    step=global_step,
                )
        # gather rewards across processes
        rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
        if config.p_loc==1:
            rewards_ada = accelerator.gather(samples["rewards_ada"]).cpu().numpy()
        # --------------------------------------------------------------------------------
        intrinsic_rewards = accelerator.gather(samples["intrinsic_rewards"]).cpu().numpy()
        info = defaultdict(list)
        accelerator.log(accelerator.reduce({f'delta_target2':samples["disnet_output"].mean(axis=0)}, reduction="mean"), step=global_step)

        pred_p = accelerator.gather(samples['disnet_output'].detach()).cpu().numpy()




        if accelerator.is_local_main_process and config.intrinsic_reward_fn != 'baseline':
            print('rewards.shape:', rewards.shape)  # rewards.shape: torch.Size([ngpus * sample.batch_size * sample.num_batches_per_epoch])
            # print('rewards:', rewards)
            print('zero rewards:', len(rewards[np.isclose(rewards, 0.0)]))
            # print('intrinsic_rewards.shape:', intrinsic_rewards.shape)
            # print('intrinsic_rewards:', intrinsic_rewards)
            # print('intrinsic_rewards.mean:', intrinsic_rewards.mean(axis=0))
            # print('intrinsic_advantages:', repr(intrinsic_advantages))
        # per-prompt mean/std tracking
        if config.per_prompt_stat_tracking:
            # gather the prompts across processes
            prompt_ids = accelerator.gather(samples["prompt_ids"]).cpu().numpy()
            prompts = pipeline.tokenizer.batch_decode(
                prompt_ids, skip_special_tokens=True
            )
            if config.p_loc==1:
                advantages = stat_tracker.update(prompts, rewards_ada)
            else:
                advantages = stat_tracker.update(prompts, rewards)
            # --------------------------------------------------------------------------------
            intrinsic_advantages = intrinsic_stat_tracker.update(prompts, intrinsic_rewards)
            # --------------------------------------------------------------------------------
        else:
            if config.p_loc==1:
                advantages = (rewards_ada - rewards_ada.mean(axis=0)) / (rewards_ada.std(axis=0) + 1e-8)
                print(advantages.shape)
                print(rewards.shape)
            else:
                advantages = (rewards - rewards) / (rewards + 1e-8)
            # --------------------------------------------------------------------------------
            # [[1, 2, 3, 4],
            #  [2, 3, 4, 5]]
            # mean = [1.5, 2.5, 3.5, 4.5]
            # std = [...]
            intrinsic_advantages = (intrinsic_rewards - intrinsic_rewards.mean(axis=0)) / (intrinsic_rewards.std(axis=0) + 1e-8)
            # --------------------------------------------------------------------------------
        if epoch>config.cut_warmup and config.step_cut>0:
            if disnet_p:
                # num_train_timesteps, raw_array = step_predictor_noise(pred_delta_noise=pred_p,tao=config.tao)
                num_train_timesteps, raw_array = step_predictor_der\
                    (target_2=pred_p, delta_abs_t=config.step_delta_abs_t,delta_delta_t=config.step_delta_delta_t)
            else:
                num_train_timesteps,raw_array=step_predictor_empty(clip_p=clip_p)

        reward_name = config.reward_fn.split('_')
        info_multi_reward={
            f"r_{reward_name[0]}_mean":samples[f'r_{reward_name[0]}'].mean().item(),
            f"r_{reward_name[1]}_mean":samples[f'r_{reward_name[1]}'].mean().item(),
            f"r_{reward_name[2]}_mean":samples[f'r_{reward_name[2]}'].mean().item(),
        }
        # info_multi_reward=accelerator.reduce(info_multi_reward, reduction="mean")
        info_multi_reward.update({
                "reward": rewards,
                "intrinsic_rewards": intrinsic_rewards,
                "epoch": epoch,
                "reward_mean": rewards.mean(),
                "reward_std": rewards.std(),
                "intrinsic_reward_mean": intrinsic_rewards.mean(axis=0),
                "intrinsic_rewards.all_step_mean":intrinsic_rewards.mean(),
                "intrinsic_reward_std": intrinsic_rewards.std(axis=0),
                # 'intrinsic_advantages': intrinsic_advantages,
                # 'intrinsic_advantages_mean': intrinsic_advantages.mean(axis=0),
                # 'advantages': advantages,
            })
        accelerator.log(
            info_multi_reward,
            step=global_step,
        )
        if accelerator.is_local_main_process and reward_count!=-1:
            if rewards.mean()>config.cut_all_score:
                reward_count+=1
                if reward_count>1:
                    elapsed_time=clock.end()
                    accelerator.log(
                        {
                            'elapsed_time': elapsed_time
                        },
                        step=global_step,
                    )
                    reward_count=-1
        # ungather advantages; we only need to keep the entries corresponding to the samples on this process
        if config.p_loc==1:
            samples["advantages"] = (
                torch.as_tensor(advantages)
                    .reshape(accelerator.num_processes, -1,advantages.shape[-1])[accelerator.process_index]
                    .to(accelerator.device)
            )
        else:
            samples["advantages"] = (
                torch.as_tensor(advantages)
                .reshape(accelerator.num_processes, -1)[accelerator.process_index]
                .to(accelerator.device)
            )
        samples["intrinsic_advantages"] = (
            torch.as_tensor(intrinsic_advantages)
            .reshape(accelerator.num_processes, -1, intrinsic_advantages.shape[-1])[accelerator.process_index]
            .to(accelerator.device)
        )
        if config.step_cut==2:
            samples["raw_array"] = (
                torch.as_tensor(raw_array)
                .reshape(accelerator.num_processes, -1)[accelerator.process_index]
            .to(accelerator.device)
            )
        torch.set_printoptions(threshold=float('inf'))  # 设置阈值为无限大，确保显示所有元素
        for each in samples["intrinsic_advantages"]:
            save_tensor_to_txt(each,
                               f'cjk_midres/intrinsic_clip_advantages_{human_readable_time}.txt')
        del samples["rewards"]
        del samples["intrinsic_rewards"]
        del samples["prompt_ids"]

        total_batch_size, num_timesteps = samples["timesteps"].shape
        assert (
            total_batch_size
            == config.sample.batch_size * config.sample.num_batches_per_epoch
        )
        assert num_timesteps == config.sample.num_steps

        #################### TRAINING ####################
        for inner_epoch in range(config.train.num_inner_epochs):
            # shuffle samples along batch dimension
            # cjk_note: useless since num_in_epoch=1
            if config.step_cut==2:
                # cjk_note: reorder sample by step_num, a light-weight method to approximately align step number of each gpu/batch
                perm = torch.argsort(samples['raw_array'], descending=True).to(accelerator.device)
            else:
                perm = torch.randperm(total_batch_size, device=accelerator.device)
            samples = {k: v[perm] for k, v in samples.items()}
            # if config.step_cut==2:
            #     print(f'Step array on Rank-{accelerator.state.local_process_index}: {samples["raw_array"]}')

            if config.step_cut<2:
                # shuffle along time dimension independently for each sample
                # cjk_note: it breaks the order of denoising, unexpected for step cut. (We want drop last)
                perms = torch.stack(
                    [
                        torch.randperm(num_timesteps, device=accelerator.device)
                        for _ in range(total_batch_size)
                    ]
                )
                for key in ["timesteps", "latents", "next_latents", "log_probs"]:
                    samples[key] = samples[key][
                        torch.arange(total_batch_size, device=accelerator.device)[:, None],
                        perms,
                    ]
            # rebatch for training
            samples_batched = {
                k: v.reshape(-1, config.train.batch_size, *v.shape[1:])
                for k, v in samples.items()
            }

            # dict of lists -> list of dicts for easier iteration
            samples_batched = [
                dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())
            ]

            # train
            pipeline.unet.train()
            info = defaultdict(list)
            # if config.step_cut == 2:
            #     each_accu_steps_arr = \
            #         samples['raw_array'].cpu().detach().numpy().reshape(-1,config.train.gradient_accumulation_steps*config.train.batch_size).sum(axis=1)
            #     # b2_res=each_accu_steps_arr.reshape(-1, 2).max(axis=1)
            #     print(f'Each-step array on Rank-{accelerator.state.local_process_index}: {each_accu_steps_arr}')
            accumulate_count=0
            inner_counter=0
            avg_step_num_array=[]
            avg_step_num_array_v2 = []
            for i, sample in tqdm(
                list(enumerate(samples_batched)),
                desc=f"Epoch {epoch}.{inner_epoch}: training",
                position=0,
                disable=not accelerator.is_local_main_process,
            ):
                if config.train.cfg:
                    # concat negative prompts to sample prompts to avoid two forward passes
                    embeds = torch.cat(
                        [train_neg_prompt_embeds, sample["prompt_embeds"]]
                    )
                else:
                    embeds = sample["prompt_embeds"]
                # cjk_note: get the avg step in one forward across num_gpu * batch_size
                if config.step_cut == 2:
                    num_train_timesteps=round(sample['raw_array'].cpu().detach().numpy().mean())
                    # accelerator.gradient_accumulation_steps = config.train.gradient_accumulation_steps * num_train_timesteps
                if len(sample['raw_array'])==1:
                    num_train_timesteps_gather = accelerator.gather(sample['raw_array'])
                else:
                    num_train_timesteps_gather = accelerator.gather(sample['raw_array'].float().mean().long())
                total = num_train_timesteps_gather.sum().item() # total sample number for gradient accumulation
                try:
                    avg_train_steps=int(num_train_timesteps_gather.cpu().detach().numpy().mean())
                    # print(f"avg_train_steps {avg_train_steps} at i={i}")
                    avg_step_num_array.append(avg_train_steps)
                    avg_step_num_array_v2.append(avg_train_steps)
                    num_train_timesteps=avg_train_steps
                except Exception as e:
                    print(num_train_timesteps_gather)
                # print(num_train_timesteps)
                for j in tqdm(
                    range(num_train_timesteps),
                    desc="Timestep",
                    position=1,
                    leave=False,
                    disable=not accelerator.is_local_main_process,
                ):
                    if config.step_cut == 0:
                        context =  accelerator.accumulate(unet)
                    else:
                        context = contextlib.nullcontext()

                    with context:
                        with autocast():
                            noise_pred = unet_predict_noise(config,
                                                            unet,
                                                            sample["latents"][:, j],
                                                            sample["timesteps"][:, j],
                                                            embeds)
                            # compute the log prob of next_latents given latents under the current model

                            _, log_prob = ddim_step_with_logprob(
                                pipeline.scheduler,
                                noise_pred,
                                sample["timesteps"][:, j],
                                sample["latents"][:, j],
                                eta=config.sample.eta,
                                prev_sample=sample["next_latents"][:, j],
                            )
                        inner_counter+=1
                        # print(f'p value is: {p}')
                        if config.reward_fn.startswith("dummy"):
                            advantages = config.intrinsic_reward_weight * sample["intrinsic_advantages"][:, j]
                        elif config.reward_fn.startswith("extrinsic"):
                            if j == num_train_timesteps - 1:
                                advantages = sample["advantages"]
                            else:
                                advantages = config.intrinsic_reward_weight * sample["intrinsic_advantages"][:, j]
                        else:

                            if config.p_loc==2:
                                advantages = sample["advantages"] + config.intrinsic_reward_weight * sample["intrinsic_advantages"][:, j]
                            else:
                                # print(sample["advantages"])
                                advantages = sample["advantages"][:, j] + config.intrinsic_reward_weight * sample["intrinsic_advantages"][:, j]
                        # ----------------------------------------------------------------------
                        advantages = torch.clamp(
                            advantages,
                            -config.train.adv_clip_max,
                            config.train.adv_clip_max,
                        )
                        ratio = torch.exp(log_prob - sample["log_probs"][:, j])
                        unclipped_loss = -advantages * ratio
                        '''
                        unclipped_loss+ beta (unnclipped_loss_next +un)
                        '''
                        clipped_loss = -advantages * torch.clamp(
                            ratio,
                            1.0 - config.train.clip_range,
                            1.0 + config.train.clip_range,
                        )
                        loss = config.reward_weight * torch.mean(torch.maximum(unclipped_loss, clipped_loss))

                        if config.train.use_kl:
                            with autocast():
                                noise_pred_old = unet_predict_noise(config,
                                                                    unet_copy,
                                                                    sample["latents"][:, j],
                                                                    sample["timesteps"][:, j],
                                                                    embeds)
                            kl_regularizer = (noise_pred - noise_pred_old) ** 2
                            if i >= config.train.kl_warmup:
                                loss += config.train.kl_weight * kl_regularizer.mean()
                                info["kl_regularizer"].append(kl_regularizer.mean())
                        # cjk_note: manually adjust loss for gradient accumulation without accelerator
                        if config.step_cut==2:
                            loss = (loss * config.train.gradient_accumulation_steps * accelerator.num_processes) / total
                        # debugging values
                        # John Schulman says that (ratio - 1) - log(ratio) is a better
                        # estimator, but most existing code uses this so...
                        # http://joschu.net/blog/kl-approx.html
                        info["approx_kl"].append(
                            0.5
                            * torch.mean((log_prob - sample["log_probs"][:, j]) ** 2)
                        )
                        info["clipfrac"].append(
                            torch.mean(
                                (
                                    torch.abs(ratio - 1.0) > config.train.clip_range
                                ).float()
                            )
                        )
                        info["loss"].append(loss)
                        #print('performing backward')
                        # backward pass
                        accelerator.backward(loss)
                        # cjk_note: preserving old code
                        if config.step_cut==0 or config.step_cut==1:
                            if accelerator.sync_gradients:
                                accelerator.clip_grad_norm_(
                                    unet.parameters(), config.train.max_grad_norm
                                )
                            # if config.train.gradient_accumulation_steps* num_train_timesteps
                            optimizer.step()
                            optimizer.zero_grad()
                        # if check_all_gradients_unupdated(unet):
                        #     print(f'rank-{accelerator.state.local_process_index} updated at {i}-{j}')
                    # cjk_note: preserving old code
                    if accelerator.sync_gradients and config.step_cut < 2:
                        # assert (j == num_train_timesteps - 1) and (
                        #     i + 1
                        # ) % config.train.gradient_accumulation_steps == 0
                        # log training-related stuff
                        info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                        info = accelerator.reduce(info, reduction="mean")
                        info.update({"epoch": epoch, "inner_epoch": inner_epoch})
                        accelerator.log(info, step=global_step)

                        global_step += 1
                        info = defaultdict(list)
                accumulate_count+=1
                if accumulate_count==config.train.gradient_accumulation_steps and config.step_cut==2:
                    # print(f'Before Accumulating final on rank{accelerator.state.local_process_index}: sample_No={i},Step_total={j} accelerator.sync_gradients={accelerator.sync_gradients}, check:{check_all_gradients_unupdated(unet)}',flush=True)
                    # old=accelerator.gradient_accumulation_steps
                    #accelerator.gradient_accumulation_steps=1
                    if accelerator.sync_gradients:
                        accelerator.clip_grad_norm_(
                            unet.parameters(), config.train.max_grad_norm
                        )
                    optimizer.step()
                    optimizer.zero_grad()
                    accumulate_count=0
                    # print(f'After Accumulating final on rank{accelerator.state.local_process_index}: sample_No={i}, accelerator.sync_gradients={accelerator.sync_gradients}, check:{check_all_gradients_unupdated(unet)}',flush=True)
                    info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                    info = accelerator.reduce(info, reduction="mean")
                    info.update({"epoch": epoch, "inner_epoch": inner_epoch})
                    accelerator.log({"step_num": wandb.Histogram(np.array(avg_step_num_array))}, step=global_step)
                    avg_step_num_array=[]
                    accelerator.log(info, step=global_step)
                    global_step += 1
                    info = defaultdict(list)
                # Checks if the accelerator has performed an optimization step behind the scenes 121移动了一格


            accelerator.log({"step_num_v2": wandb.Histogram(np.array(avg_step_num_array_v2,dtype=float))}, step=global_step)
            print(f"sync_gradient on {accelerator.state.local_process_index} (before):", accelerator.sync_gradients, flush=True)
            assert accelerator.sync_gradients
        # print(f"---->accelerator.save_state in {accelerator.state.local_process_index}<----")
        if epoch%10==1 or epoch==99:
            accelerator.save_state()
        if config.validation and epoch % config.vali_interval == 0:
            if accelerator.is_main_process and reward_count != -1:
                clock.pause()
            score_dict = eval_all(pipeline, config, accelerator=accelerator, logger=logger, test_names=config.test_prompts)
            if accelerator.is_main_process:
                accelerator.log(
                    score_dict,
                    step=global_step,
                )
            # if config.test_prompts != 'all' and accelerator.is_main_process:
            #     score_dict = eval_all(pipeline, config, logger=logger,test_names=config.test_prompts)
            #
            #     accelerator.log(
            #         score_dict,
            #         step=global_step,
            #     )
            #     logger.info(f"  Validation Success on GPU:{accelerator.process_index}")
            # elif config.test_prompts == 'all':
            #     for idx, test_prompts in enumerate(['imagenet_animals', 'simple_animals', 'activities_animals']):
            #         if accelerator.process_index == idx + 1:
            #             score_dict = eval_all(pipeline, config, logger=logger, test_names=test_prompts)
            #             accelerator.log(
            #                 score_dict,
            #                 step=global_step,
            #             )
            #             logger.info(f"  Validation Success on GPU:{accelerator.process_index}")
            if reward_count != -1 and accelerator.is_main_process:
                clock.restart()
            torch.cuda.empty_cache()
            del samples
            del samples_batched
        if accelerator.is_local_main_process and reward_count != -1:
            elapsed_time = clock.end()
            accelerator.log(
                {
                    'elapsed_time': elapsed_time
                },
                step=global_step,
            )


if __name__ == "__main__":
    # Script entry point: absl's app.run parses the command line into FLAGS
    # (including the --config file registered with config_flags above) and then
    # invokes main with the remaining positional argv.
    app.run(main=main)