import os
# NOTE(review): an empty CURL_CA_BUNDLE is a common workaround that makes
# `requests`-based clients (e.g. Hugging Face Hub downloads) skip TLS
# certificate verification. That weakens transport security — confirm it is
# still required. Environment variables like these are presumably read when the
# HF libraries are imported, so keep these lines above those imports.
os.environ['CURL_CA_BUNDLE'] = ''
# Point Hugging Face Hub traffic at the hf-mirror.com mirror.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# Log the working directory: relative paths used later (e.g. "config/base.py",
# "./ckpts/sd14") are resolved against it.
print(os.getcwd())
import sys
import signal

def handle_signal(signum, frame):
    """Log a received signal and otherwise ignore it.

    Registered below as the handler for SIGHUP and SIGTERM so that those
    signals do not stop the process.

    Args:
        signum: Number of the signal that was delivered.
        frame: Interrupted stack frame supplied by ``signal``; unused.
    """
    message = "Received signal {}, ignoring.".format(signum)
    print(message)

# Route hangup (e.g. the launching terminal closing) and termination requests
# to handle_signal, which only prints them, so the training run keeps going.
# NOTE(review): with SIGTERM ignored, `kill` / job schedulers cannot shut the
# process down gracefully and will typically have to escalate to SIGKILL
# (which cannot be caught) — confirm this is intended.
signal.signal(signal.SIGHUP, handle_signal)
signal.signal(signal.SIGTERM, handle_signal)
# Disabled alternative mirror / proxy settings, kept for reference:
# os.environ["HUGGINGFACE_HUB_URL"] = "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-hub"
# os.environ["HF_DATASETS_URL"] = "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-datasets"
# os.environ["HF_METRICS_URL"] = "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-metrics"
# os.environ['HTTP_PROXY'] = 'http://127.0.0.1:10809'
from ddpo_pytorch.Clock import Clock
from ddpo_pytorch.adaptive_model import Discriminator_CLIP, Discriminator_P, Discriminator_Eta
from scripts.Evals import eval_all
from utils.dropper import zero_tensor_by_ratio, clone_tensor_with_grad
import torch.distributed as dist
# os.environ["PL_TORCH_DISTRIBUTED_BACKEND"] = "gloo"
from collections import defaultdict
import contextlib
# os.environ["HF_HOME"] = "/path/to/your/hf_home_directory"
# os.environ["TRANSFORMERS_CACHE"] = "/path/to/your/transformers_cache_directory"
# os.environ["HF_DATASETS_CACHE"] = "/path/to/your/datasets_cache_directory"

import datetime
from concurrent import futures
import time
from absl import app, flags
from ml_collections import config_flags
from accelerate import Accelerator
from accelerate.utils import set_seed, ProjectConfiguration
from accelerate.logging import get_logger
from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
import numpy as np
import ddpo_pytorch.prompts
import ddpo_pytorch.rewards
# from ddpo_pytorch.adaptive_model import Discriminator
from ddpo_pytorch.stat_tracking import PerPromptStatTracker
from ddpo_pytorch.diffusers_patch.pipeline_with_logprob import pipeline_with_logprob, unet_predict_noise
from ddpo_pytorch.diffusers_patch.ddim_with_logprob import ddim_step_with_logprob
import torch
import wandb
from functools import partial
import tqdm
import tempfile
from PIL import Image
from torch import inf
# Query the CUDA device name once and derive the hardware flags from it.
# NOTE(review): torch.cuda.get_device_name() is expected to fail when no CUDA
# device is available, so importing this module presumably requires a GPU.
_device_name = torch.cuda.get_device_name()
print(_device_name)
on_3060 = "3060" in _device_name
on_h100 = "H100" in _device_name
# Rebind `tqdm` from the module to a progress-bar factory whose width follows
# the terminal size; later code calls `tqdm(...)` directly.
tqdm = partial(tqdm.tqdm, dynamic_ncols=True)


# absl flag registry; main() reads the parsed config from FLAGS.config.
# NOTE(review): presumably populated by absl `app.run(main)` further down the
# file (not visible in this chunk).
FLAGS = flags.FLAGS
# Register `--config`: an ml_collections config file, defaulting to config/base.py.
config_flags.DEFINE_config_file("config", "config/base.py", "Training configuration.")

# Module-level logger from accelerate's logging helper.
logger = get_logger(__name__)
# def force_cudnn_initialization():
#     s = 32
#     dev = torch.device('cuda')
#     torch.nn.functional.conv2d(torch.zeros(s, s, s, s, device=dev), torch.zeros(s, s, s, s, device=dev))
# force_cudnn_initialization()
def save_tensor_to_txt(tensor, file_path):
    """Append the string form of ``tensor`` to a text file.

    Each call writes ``str(tensor)`` followed by a blank line, so successive
    dumps into the same file stay visually separated. The file is created if
    it does not exist yet.

    Args:
        tensor: Any object; it is written via ``str()``.
        file_path: Path of the text file to append to.
    """
    # Append-only mode: the file is never read back here. Explicit UTF-8 keeps
    # the output independent of the platform's locale encoding.
    # str() (not an f-string) on purpose: torch tensors define __format__, which
    # renders 0-dim tensors differently from str().
    with open(file_path, 'a', encoding='utf-8') as file:
        file.write(str(tensor) + "\n\n")
def check_all_gradients_unupdated(model):
    """Return True if no parameter of ``model`` holds a non-zero gradient.

    A parameter counts as "updated" when its ``.grad`` exists and contains at
    least one entry that is not exactly zero; NaN and inf entries count as
    non-zero. Parameters whose ``.grad`` is ``None`` are ignored.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. ``torch.nn.Module``).

    Returns:
        bool: True when every gradient is missing or all-zero, else False.
    """
    for _name, param in model.named_parameters():
        grad = param.grad
        # `(grad != 0).any()` instead of `grad.abs().sum() > 0`: a single NaN
        # turns the sum into NaN (and NaN > 0 is False), which would hide real
        # non-zero gradients in the same tensor.
        if grad is not None and bool((grad != 0).any()):
            return False
    return True

def global_grad_norm_(parameters, norm_type=2):
    r"""Compute the total gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are only read, never modified:
    unlike ``torch.nn.utils.clip_grad_norm_``, no clipping is performed.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor whose ``.grad`` is inspected. Entries whose ``.grad``
            is ``None`` are skipped.
        norm_type (float or int): type of the used p-norm. Can be ``inf`` for
            the infinity (max-abs) norm.

    Returns:
        float: Total norm of the gradients (viewed as a single vector);
        ``0.0`` when no parameter has a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        # Nothing to measure; also avoids calling max() on an empty sequence.
        return 0.0
    if norm_type == inf:
        # Largest absolute gradient entry; .item() keeps the return type a
        # Python float, consistent with the p-norm branch below.
        return max(g.abs().max().item() for g in grads)
    total_norm = 0.0
    for g in grads:
        # Accumulate ||g||_p ** p per tensor, then take the p-th root once.
        total_norm += g.norm(norm_type).item() ** norm_type
    return total_norm ** (1.0 / norm_type)
def step_predictor(rewards, stage='train', rate_threshold=0.99, abs_threshold=-0.52, valid_step=36, option='mean'):
    """Predict a denoising-step count from per-step reward trajectories.

    Only columns ``valid_step:`` of ``rewards`` are considered. A step counts as
    "converged" when either

    * the reward ratio to the previous step exceeds ``rate_threshold`` and the
      reward itself exceeds ``-abs_threshold``, or
    * the reward is exactly zero.

    The first converged step of each row is located, and the rows are then
    aggregated according to ``option``.

    Args:
        rewards: 2-D array indexed as ``rewards[row, step]``.
        stage: only ``'train'`` is supported.
        rate_threshold: lower bound on ``r[t] / r[t-1]``.
        abs_threshold: the reward must exceed ``-abs_threshold``.
        valid_step: number of leading steps to skip.
        option: ``'mean'`` (rounded mean), ``'min'``; any other value uses max.

    Returns:
        tuple: ``(aggregated_step, per_row_steps)``, both offset by
        ``valid_step + 1``.

    Raises:
        ValueError: if ``stage`` is not ``'train'``.
    """
    if stage != 'train':
        raise ValueError('test step predictor is not implemented yet.')
    reward_valid = rewards[:, valid_step:]
    next_reward = reward_valid[:, 1:]
    # Ratio of consecutive rewards; 1e-8 guards against division by zero.
    factor_rate = next_reward / (reward_valid[:, :-1] + 1e-8)
    condition = ((factor_rate > rate_threshold) & (next_reward > -abs_threshold)) | (next_reward == 0)
    # NOTE(review): np.argmax returns 0 for a row with no True entry, so a row
    # that never converges looks the same as one converging at the first step.
    all_step_num = np.argmax(condition, axis=1)
    if option == 'mean':
        final = round(all_step_num.mean())
    elif option == 'min':
        final = all_step_num.min()
    else:
        final = all_step_num.max()
    return valid_step + final + 1, all_step_num + valid_step + 1
def step_predictor_empty(clip_p=None, stage='train',step_min=20):
    """Per-row step count: 1-based index of the first entry of ``clip_p`` above 1.

    The last column of the mask is forced True, so a row whose values never
    exceed 1 maps to the final step instead of argmax's default of index 0.
    Every count is then clamped from below to ``step_min``.

    Args:
        clip_p: array compared elementwise with ``> 1``. NOTE(review): the
            ``[:, -1]`` indexing and ``axis=1`` argmax presume a 2-D NumPy
            array shaped (rows, steps) — confirm against callers.
        stage: only ``'train'`` is supported.
        step_min: lower bound applied to every predicted count.

    Returns:
        tuple: the same integer array twice, ``(steps, steps)``.

    Raises:
        ValueError: if ``stage`` is not ``'train'``.
    """
    if stage != 'train':
        raise ValueError('test step predictor is not implemented yet.')
    exceeded = clip_p > 1
    exceeded[:, -1] = True
    steps = 1 + np.argmax(exceeded, axis=1)
    steps[steps < step_min] = step_min
    return steps, steps
def step_predictor_v2(pred_p, stage='train',step_min=30):
    """Per-row step count: 1-based index of the first entry of ``pred_p`` close to 1.0.

    Closeness uses ``np.isclose`` with its default tolerances. The last column
    of the mask is forced True, so rows with no such entry map to the final
    step. Every count is then clamped from below to ``step_min``.

    Args:
        pred_p: array of predictions. NOTE(review): the ``[:, -1]`` indexing and
            ``axis=1`` argmax presume a 2-D array shaped (rows, steps).
        stage: only ``'train'`` is supported.
        step_min: lower bound applied to every predicted count.

    Returns:
        tuple: the same integer array twice, ``(steps, steps)``.

    Raises:
        ValueError: if ``stage`` is not ``'train'``.
    """
    if stage != 'train':
        raise ValueError('test step predictor is not implemented yet.')
    reached_one = np.isclose(pred_p, 1.0)
    reached_one[:, -1] = True
    steps = 1 + np.argmax(reached_one, axis=1)
    steps[steps < step_min] = step_min
    return steps, steps
def main(_):
    # basic Accelerate and logging setup
    config = FLAGS.config
    # cjk_note: modification for local run
    SDv14 = "CompVis/stable-diffusion-v1-4"
    SDv21 = "stabilityai/stable-diffusion-2-1"
    SDXL="stabilityai/stable-diffusion-xl-base-1.0"
    SDturbo= "stabilityai/sd-turbo"
    SDv15 = 'sd-legacy/stable-diffusion-v1-5'
    try:
        config.pretrained.model = locals()[config.pretrained.model_name]
    except Exception as e:
        pass
    if torch.cuda.device_count()==4 and config.sample.batch_size>4:
        config.sample.batch_size = 16
    try:
        if on_h100:
            config.train.batch_size = 2
        config.train.batch_size=2

    if config.pretrained.model_name=='SDv21':
        config.sample.batch_size = 8
        config.train.batch_size = 1
        config.sample.guidance_scale = 1.0
        config.train.gradient_accumulation_steps = 8
    if config.pretrained.model_name=='SDXL':
        config.train.learning_rate = 3e-4
        config.sample.guidance_scale = 7.0
        config.sample.batch_size = 16
        config.sample.num_steps = 30
    if config.pretrained.model_name=='SDturbo':
        config.train.learning_rate = 3e-4
        config.sample.guidance_scale = 1.0
        config.sample.batch_size = 16
    current_time = datetime.datetime.now()
    human_readable_time = current_time.strftime('%Y-%m-%d_%H_%M_%S')
    # unique_id = datetime.datetime.now().strftime("%Y.%m.%d_%H.%M.%S")
    unique_id = f"{config.intrinsic_reward_fn}{'_cut2' if config.step_cut>0 else ''}{f'_kl{config.reward_weight}' if config.train.use_kl else ''}_{'' if config.intrinsic_reward_weight else config.intrinsic_reward_weight}_{config.seed}_" + datetime.datetime.now().strftime("%m-%d_%H_%M")
    if not config.run_name:
        config.run_name = unique_id
    else:
        config.run_name += "_" + unique_id
    if config.resume_from:
        config.resume_from = os.path.normpath(os.path.expanduser(config.resume_from))
        print(config.resume_from)
        if "checkpoint_" not in os.path.basename(config.resume_from):
            # get the most recent checkpoint in this directory
            checkpoints = list(
                filter(lambda x: "checkpoint_" in x, os.listdir(config.resume_from))
            )
            if len(checkpoints) == 0:
                raise ValueError(f"No checkpoints found in {config.resume_from}")
            config.resume_from = os.path.join(
                config.resume_from,
                sorted(checkpoints, key=lambda x: int(x.split("_")[-1]))[-1],
            )

    # number of timesteps within each trajectory to train on
    num_train_timesteps = int(config.sample.num_steps * config.train.timestep_fraction)

    accelerator_config = ProjectConfiguration(
        project_dir=os.path.join(config.logdir, config.run_name),
        automatic_checkpoint_naming=True,
        total_limit=config.num_checkpoint_limit,
    )
    os.environ["WANDB_MODE"] = "offline"
    accelerator = Accelerator(
        log_with="wandb",
        mixed_precision=config.mixed_precision,
        project_config=accelerator_config,
        # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
        # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
        # the total number of optimizer steps to accumulate across.
        gradient_accumulation_steps=config.train.gradient_accumulation_steps
        * num_train_timesteps if config.step_cut==0 else 1,
    )
    if accelerator.is_main_process:
        model_version_prefix = config.pretrained.model_name
        accelerator.init_trackers(
            project_name=f"{model_version_prefix}-Revision-{config.project_name}" if config.project_name else f"{model_version_prefix}-revision-aesthetic",
            config=config.to_dict(),
            init_kwargs={"wandb": {"name": config.run_name}},
        )
    logger.info(f"\n{config}")
    # set seed (device_specific is very important to get different prompts on different devices)
    set_seed(config.seed, device_specific=True)
    # My local downloaded model
    if config.pretrained.model.endswith('4'):
        config.pretrained.model="./ckpts/sd14"
    # load scheduler, tokenizer and models.
    pipeline = StableDiffusionPipeline.from_pretrained(
        config.pretrained.model, revision=config.pretrained.revision
    )
    # freeze parameters of models to save more memory
    pipeline.vae.requires_grad_(False)
    pipeline.text_encoder.requires_grad_(False)
    pipeline.unet.requires_grad_(not config.use_lora)
    # disable safety checker
    pipeline.safety_checker = None
    # make the progress bar nicer
    pipeline.set_progress_bar_config(
        position=1,
        disable=not accelerator.is_local_main_process,
        leave=False,
        desc="Timestep",
        dynamic_ncols=True,
    )
    # switch to DDIM scheduler
    pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

    # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    inference_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        inference_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        inference_dtype = torch.bfloat16

    # Move unet, vae and text_encoder to device and cast to inference_dtype
    pipeline.vae.to(accelerator.device, dtype=inference_dtype)
    pipeline.text_encoder.to(accelerator.device, dtype=inference_dtype)
    if config.use_lora:
        pipeline.unet.to(accelerator.device, dtype=inference_dtype)

    if config.intrinsic_reward_fn.endswith('ada'):
        disnet_p = Discriminator_P(64, label_len=config.disnet_label_len,config=config)
        pipeline.disnet_p = disnet_p
        pipeline.disnet_p.to(accelerator.device)
        pipeline.disnet_p.train()
        if config.pretrained_disnet_p!='None':
            disnet_p.load_state_dict(torch.load(config.pretrained_disnet_p, map_location=accelerator.device)) # ckpts/disnet_p
    else:
        disnet_p = None
        pipeline.disnet_p = disnet_p
    if config.intrinsic_reward_fn.endswith('clip') or config.intrinsic_reward_fn.endswith('clip_ada'):
        disnet_clip = Discriminator_CLIP(model=pipeline,config=config)

    if config.train.use_kl or disnet_p:
        unet_copy = UNet2DConditionModel.from_pretrained(
            config.pretrained.model,
            subfolder="unet",
            revision=config.pretrained.revision,
        )
        unet_copy.requires_grad_(False)
        unet_copy.to(accelerator.device, dtype=inference_dtype)
    else:
        unet_copy = None

    if config.use_lora:
        # Set correct lora layers
        lora_attn_procs = {}
        for name in pipeline.unet.attn_processors.keys():
            cross_attention_dim = (
                None
                if name.endswith("attn1.processor")
                else pipeline.unet.config.cross_attention_dim
            )
            if name.startswith("mid_block"):
                hidden_size = pipeline.unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[
                    block_id
                ]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = pipeline.unet.config.block_out_channels[block_id]

            lora_attn_procs[name] = LoRAAttnProcessor(
                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
            )
        pipeline.unet.set_attn_processor(lora_attn_procs)

        # this is a hack to synchronize gradients properly. the module that registers the parameters we care about (in
        # this case, AttnProcsLayers) needs to also be used for the forward pass. AttnProcsLayers doesn't have a
        # `forward` method, so we wrap it to add one and capture the rest of the unet parameters using a closure.
        class _Wrapper(AttnProcsLayers):
            def forward(self, *args, **kwargs):
                return pipeline.unet(*args, **kwargs)

        unet = _Wrapper(pipeline.unet.attn_processors)
    else:
        unet = pipeline.unet

    # set up diffusers-friendly checkpoint saving with Accelerate

    def save_model_hook(models, weights, output_dir):
        assert len(models) >= 1
        if config.use_lora and isinstance(models[0], AttnProcsLayers):
            pipeline.unet.save_attn_procs(output_dir)
            if len(models) > 1:
                models[1].save_model(os.path.join(output_dir, "disnet_p"))
               #  models[2].save_model(os.path.join(output_dir, "disnet_eta"))
        elif not config.use_lora and isinstance(models[0], UNet2DConditionModel):
            models[0].save_pretrained(os.path.join(output_dir, "unet"))
        else:
            raise ValueError(f"Unknown model type {type(models[0])}")
        weights.clear()  # ensures that accelerate doesn't try to handle saving of the model

    def load_model_hook(models, input_dir):
        assert len(models) == 1
        if config.use_lora and isinstance(models[0], AttnProcsLayers):
            # pipeline.unet.load_attn_procs(input_dir)
            tmp_unet = UNet2DConditionModel.from_pretrained(
                config.pretrained.model,
                revision=config.pretrained.revision,
                subfolder="unet",
            )
            tmp_unet.load_attn_procs(input_dir)
            models[0].load_state_dict(
                AttnProcsLayers(tmp_unet.attn_processors).state_dict()
            )
            if len(models) > 1:
                models[1].load_state_dict(torch.load(os.path.join(input_dir, "disnet_p")))
                # models[2].load_state_dict(torch.load(os.path.join(input_dir, "disnet_eta")))
            del tmp_unet
        elif not config.use_lora and isinstance(models[0], UNet2DConditionModel):
            load_model = UNet2DConditionModel.from_pretrained(
                input_dir, subfolder="unet"
            )
            models[0].register_to_config(**load_model.config)
            models[0].load_state_dict(load_model.state_dict())
            del load_model
        else:
            raise ValueError(f"Unknown model type {type(models[0])}")
        models.clear()  # ensures that accelerate doesn't try to handle loading of the model

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if config.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # Initialize the optimizer
    if config.train.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer_unet = optimizer_cls(
        unet.parameters(),
        lr=config.train.learning_rate,
        betas=(config.train.adam_beta1, config.train.adam_beta2),
        weight_decay=config.train.adam_weight_decay,
        eps=config.train.adam_epsilon,
    )
    if disnet_p:
        optimizer_disnet_p = optimizer_cls(
            disnet_p.parameters(),
            lr=config.train.learning_rate,
            betas=(config.train.adam_beta1, config.train.adam_beta2),
            weight_decay=config.train.adam_weight_decay,
            eps=config.train.adam_epsilon,
        )
    # prepare prompt and reward fn
    prompt_fn = getattr(ddpo_pytorch.prompts, config.prompt_fn) # simple animal
    reward_fn = getattr(ddpo_pytorch.rewards, config.reward_fn)() # DDPO reward
    # -------------------------------------------------------------------------------
    intrinsic_reward_fn = getattr(ddpo_pytorch.rewards, config.intrinsic_reward_fn)()
    disnet_fn = getattr(ddpo_pytorch.rewards,'baseline_ada_v2')()
    # -------------------------------------------------------------------------------

    # generate negative prompt embeddings
    neg_prompt_embed = pipeline.text_encoder(
        pipeline.tokenizer(
            [""],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=pipeline.tokenizer.model_max_length,
        ).input_ids.to(accelerator.device)
    )[0]
    sample_neg_prompt_embeds = neg_prompt_embed.repeat(config.sample.batch_size, 1, 1)
    train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size, 1, 1)

    # initialize stat tracker
    if config.per_prompt_stat_tracking:
        stat_tracker = PerPromptStatTracker(
            config.per_prompt_stat_tracking.buffer_size,
            config.per_prompt_stat_tracking.min_count,
        )
        intrinsic_stat_tracker = PerPromptStatTracker(
            config.per_prompt_stat_tracking.buffer_size,
            config.per_prompt_stat_tracking.min_count,
        )
        # --------------------------------------------------------------------------------

    # for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
    # more memory
    autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast
    # autocast = accelerator.autocast

    # Prepare everything with our `accelerator`.
    # if disnet:
    #     unet, disnet, optimizer = accelerator.prepare(unet, disnet, optimizer)
    # else:
    if disnet_p:
        unet, disnet_p, optimizer, optimizer_disnet_p = \
            accelerator.prepare(unet, disnet_p, optimizer_unet,optimizer_disnet_p)
    else:
        unet, optimizer = accelerator.prepare(unet, optimizer_unet)

    # executor to perform callbacks asynchronously. this is beneficial for the llava callbacks which makes a request to a
    # remote server running llava inference.
    executor = futures.ThreadPoolExecutor(max_workers=2)
    executor2 = futures.ThreadPoolExecutor(max_workers=2)
    # executor3 = futures.ThreadPoolExecutor(max_workers=2)
    # Train!
    samples_per_epoch = (
        config.sample.batch_size
        * accelerator.num_processes
        * config.sample.num_batches_per_epoch
    )
    total_train_batch_size = (
        config.train.batch_size
        * accelerator.num_processes
        * config.train.gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num Epochs = {config.num_epochs}")
    logger.info(f"  Sample batch size per device = {config.sample.batch_size}")
    logger.info(f"  Train batch size per device = {config.train.batch_size}")
    logger.info(
        f"  Gradient Accumulation steps = {config.train.gradient_accumulation_steps}"
    )
    logger.info("")
    logger.info(f"  Total number of samples per epoch = {samples_per_epoch}")
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}"
    )
    logger.info(
        f"  Number of gradient updates per inner epoch = {samples_per_epoch // total_train_batch_size}"
    )
    logger.info(f"  Number of inner epochs = {config.train.num_inner_epochs}")

    assert config.sample.batch_size >= config.train.batch_size
    assert config.sample.batch_size % config.train.batch_size == 0
    assert samples_per_epoch % total_train_batch_size == 0

    if config.resume_from:
        logger.info(f"Resuming from {config.resume_from}")
        accelerator.load_state(config.resume_from)
        first_epoch = int(config.resume_from.split("_")[-1]) + 1
    else:
        first_epoch = 0

    global_step = 0

    if config.validation:
        score_dict = eval_all(pipeline, config, accelerator=accelerator, logger=logger, test_names=config.test_prompts)
        if accelerator.is_main_process:
            accelerator.log(
                score_dict,
                step=global_step,
            )
        # if config.test_prompts!='all' and accelerator.is_main_process:
        #     score_dict = eval_all(pipeline, config,logger=logger,test_names=config.test_prompts)
        #
        #     accelerator.log(
        #         score_dict,
        #         step=global_step,
        #     )
        #     logger.info(f"  Validation Success on GPU:{accelerator.process_index}")
        # elif config.test_prompts=='all':
        #     for idx,test_prompts in enumerate(['imagenet_animals','simple_animals','activities_animals']):
        #         if accelerator.process_index==idx+1:
        #             score_dict = eval_all(pipeline, config,logger=logger, test_names=test_prompts)
        #             accelerator.log(
        #                 score_dict,
        #                 step=global_step,
        #             )
        #             logger.info(f"  Validation Success on GPU:{accelerator.process_index}")
    if accelerator.is_local_main_process:
        clock=Clock()
        logger.info(f"  Clock enabled at {clock.start()}")
    reward_count=0

    for epoch in range(first_epoch, config.num_epochs):
        #################### SAMPLING ####################
        pipeline.unet.eval()
        samples = []
        prompts = []
        for i in tqdm(
            range(config.sample.num_batches_per_epoch),
            desc=f"Epoch {epoch}: sampling",
            disable=not accelerator.is_local_main_process,
            position=0,
        ):
            # generate prompts
            prompts, prompt_metadata = zip(
                *[
                    prompt_fn(**config.prompt_fn_kwargs)
                    for _ in range(config.sample.batch_size)
                ]
            )

            # encode prompts
            prompt_ids = pipeline.tokenizer(
                prompts,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=pipeline.tokenizer.model_max_length,
            ).input_ids.to(accelerator.device)
            prompt_embeds = pipeline.text_encoder(prompt_ids)[0]

            # sample
            with autocast():
                images, _, latents, log_probs,each_origin = pipeline_with_logprob(
                    pipeline,
                    prompt_embeds=prompt_embeds,
                    negative_prompt_embeds=sample_neg_prompt_embeds,
                    num_inference_steps=config.sample.num_steps,
                    guidance_scale=config.sample.guidance_scale,
                    eta=config.sample.eta,
                    output_type="pt",
                    return_origin=True
                )
            each_origin = torch.stack(
                each_origin, dim=1
            )  # (batch_size, num_steps + 1, 4, 64, 64)
            latents = torch.stack(
                latents, dim=1
            )  # (batch_size, num_steps + 1, 4, 64, 64)
            log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps, 1)
            timesteps = pipeline.scheduler.timesteps.repeat(
                config.sample.batch_size, 1
            )  # (batch_size, num_steps)

            # compute rewards asynchronously
            if accelerator.is_local_main_process and epoch == first_epoch and i == 0:
                print('type(images):', type(images), ', images.shape:', images.shape)  # images.shape: torch.Size([bs, 3, 512, 512])
                print('latents.shape:', latents.shape)  # latents.shape: torch.Size([bs, step+1, 4, 64, 64])
            rewards = executor.submit(reward_fn, images, prompts, prompt_metadata)
            if disnet_p:
                #disoutput = disnet(latents.detach()[:, 1:].reshape(-1, *latents.shape[2:])).reshape(latents.shape[0], -1)
                intrinsic_rewards = executor2.submit(intrinsic_reward_fn, images, prompts, prompt_metadata, latents.detach())
                disnet_output = disnet_fn(images, prompts, prompt_metadata, latents.detach(),disnet_p)
            else:
                intrinsic_rewards = executor2.submit(intrinsic_reward_fn, images, prompts, prompt_metadata, latents.detach())
                disnet_output=None
          # --------------------------------------------------------------------------------
            # yield to make sure reward computation starts
            time.sleep(0)

            samples.append(
                {
                    'prompts':prompts,
                    "prompt_ids": prompt_ids,
                    "prompt_embeds": prompt_embeds,
                    "timesteps": timesteps,
                    'each_origins':each_origin,
                    "latents": latents[
                        :, :-1
                    ],  # each entry is the latent before timestep t
                    "next_latents": latents[
                        :, 1:
                    ],  # each entry is the latent after timestep t
                    "log_probs": log_probs,
                    "rewards": rewards,
                    # --------------------------------------------------------------------------------
                    "intrinsic_rewards": intrinsic_rewards,
                    'disnet_output':disnet_output
                    # --------------------------------------------------------------------------------
                }
            )

        # wait for all rewards to be computed
        for sample in tqdm(
            samples,
            desc="Waiting for rewards",
            disable=not accelerator.is_local_main_process,
            position=0,
        ):
            rewards, reward_metadata = sample["rewards"].result()
            # --------------------------------------------------------------------------------
            intrinsic_rewards, intristic_reward_metadata = sample["intrinsic_rewards"].result()
            # --------------------------------------------------------------------------------
            # if disnet_p:
            #     sample['disnet_output'] = sample['disnet_output'].result()
                # sample['disnet_output']=sample['disnet_output'].pop()
            if config.intrinsic_reward_fn.endswith('clip') or config.intrinsic_reward_fn.endswith('clip_ada'):
                prompt_current = sample['prompts']
                prompt_embed_current = sample['prompt_embeds']
                latents_for_decode = sample['each_origins']
                latents_for_decode = torch.cat([sample['latents'][:, 0:1, :, :, :], latents_for_decode, ], dim=1)
                clip_list = []
                for idd in tqdm(range(latents_for_decode.shape[1]), desc="Processing Clip",
                                disable=not accelerator.is_local_main_process,):
                    latents_each = latents_for_decode[:, idd]
                    clip_p = disnet_clip(latents_each,prompt_embed_current,prompt_current)
                    clip_list.append(clip_p)
                clip_array = torch.stack(clip_list, axis=1).squeeze()

                delta_clip_array = accelerator.unwrap_model(disnet_clip).get_delta(clip_array)
                if config.delta_min > -0.9:
                    x_tensor = torch.tensor(config.delta_min)
                    delta_clip_array = torch.where(delta_clip_array < 0, x_tensor.to(delta_clip_array.device),
                                                   delta_clip_array)
                delta_clip_array=zero_tensor_by_ratio(delta_clip_array,zero_ratio=config.delta_drop_ratio,manner=config.delta_drop_manner)
                sample['clip_p']=clip_array
                intrinsic_rewards=delta_clip_array*intrinsic_rewards
                if config.p_loc==1:
                    p_dis=sample["disnet_output"].detach()
                    if not config.wo_p:
                        if config.remain_05:
                            intrinsic_rewards = intrinsic_rewards * (1.5 - p_dis)
                            rewards_ada = rewards.unsqueeze(1) * (p_dis+0.5)
                        elif config.intri_only:
                            intrinsic_rewards = intrinsic_rewards * (1 - p_dis)
                            rewards_ada = rewards.unsqueeze(1).repeat(1, 50)
                        else:
                            intrinsic_rewards = intrinsic_rewards * (1 - p_dis)
                            rewards_ada = rewards.unsqueeze(1) * p_dis
                    else:
                        rewards_ada=rewards.unsqueeze(1).repeat(1, 50)
                    sample['rewards_ada']=rewards_ada
            del sample['prompts']
            # accelerator.print(reward_metadata)
            sample["rewards"] = torch.as_tensor(rewards, device=accelerator.device)
            # --------------------------------------------------------------------------------
            sample["intrinsic_rewards"] = torch.as_tensor(intrinsic_rewards, device=accelerator.device)

            # --------------------------------------------------------------------------------
        # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
        samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
        # this is a hack to force wandb to log the images as JPEGs instead of PNGs
        if accelerator.is_local_main_process:
            with tempfile.TemporaryDirectory() as tmpdir:
                for i, image in enumerate(images):
                    pil = Image.fromarray(
                        (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
                    )
                    # pil = pil.resize((256, 256))
                    pil.save(os.path.join(tmpdir, f"{i}.jpg"))
                accelerator.log(
                    {
                        "images": [
                            wandb.Image(
                                os.path.join(tmpdir, f"{i}.jpg"),
                                caption=f"{prompt:.25} | {reward:.2f}",
                            )
                            for i, (prompt, reward) in enumerate(
                                zip(prompts, rewards)
                            )  # only log rewards from process 0
                        ],
                    },
                    step=global_step,
                )
        # gather rewards across processes
        rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
        if config.p_loc==1:
            rewards_ada = accelerator.gather(samples["rewards_ada"]).cpu().numpy()
        # --------------------------------------------------------------------------------
        intrinsic_rewards = accelerator.gather(samples["intrinsic_rewards"]).cpu().numpy()
        info = defaultdict(list)
        if disnet_p:
            disnet_p.train()
            for _ in range(1):
                with autocast():
                    forward_loss_disnet, outputs_f, outputs_t, output_all = accelerator.unwrap_model(
                        disnet_p).loss_v2(samples["disnet_output"])
                    forward_loss1 = forward_loss_disnet
                    # forward_loss1 = (forward_loss1 * mask).sum() / torch.max(mask.sum(), torch.Tensor([1]).to(self.device))
                    # print(sample["disnet_output"].dtype)
                    # print(disnet.dtype)
                    loss1 = forward_loss1
                    # loss+=loss1
                    accelerator.backward(loss1)
                    # accelerator.backward(loss1)
                    accelerator.clip_grad_norm_(
                        disnet_p.parameters(), config.train.max_grad_norm
                    )
                    # # global_grad_norm_(list(self.unet.pa rameters()))
                    optimizer_disnet_p.step()
                    optimizer_disnet_p.zero_grad()
            info["disnet_f"].append(outputs_f)
            info["disnet_all"].append(output_all)
            info["disnet_t"].append(outputs_t)
            info = {k: torch.mean(torch.stack(v)) if k != 'disnet_all' else
            torch.stack(v).mean(axis=0) for k, v in info.items()}
            info = accelerator.reduce(info, reduction="mean")
            accelerator.log(info, step=global_step)
            if epoch<config.disnet_start_apply_epoch:
                # ----- option 1 -----
                # disnet.eval()
                # samples["disnet_output"]=disnet(samples['latents'].detach()[:, 1:].reshape(-1, *samples['latents'].detach().shape[2:])).reshape(samples['latents'].detach().shape[0], -1)
                #----- option 2 -----
                samples["disnet_output"]=accelerator.unwrap_model(
                        disnet_p).label_output(samples["disnet_output"])
                accelerator.log(accelerator.reduce({f'fixed p':samples["disnet_output"].mean(axis=0)}, reduction="mean"), step=global_step)
            pred_p = accelerator.gather(samples['disnet_output'].detach()).cpu().numpy()




        if config.intrinsic_reward_fn.endswith('clip') or config.intrinsic_reward_fn.endswith('clip_ada'):
            clip_p = accelerator.gather(samples['clip_p'].detach()).cpu().numpy()
            accelerator.log({
                'clip_p':wandb.Histogram(clip_p.mean(axis=0)),
                'clip_p_mean':clip_p.mean(),
            },step=global_step)
            save_tensor_to_txt(clip_p.mean(axis=0),f'midres/clip_p_{human_readable_time}.txt')
        # --------------------------------------------------------------------------------
        # log rewards and images
        if accelerator.is_local_main_process and config.intrinsic_reward_fn != 'baseline':
            print('rewards.shape:', rewards.shape)  # rewards.shape: torch.Size([ngpus * sample.batch_size * sample.num_batches_per_epoch])
            # print('rewards:', rewards)
            print('zero rewards:', len(rewards[np.isclose(rewards, 0.0)]))
            # print('intrinsic_rewards.shape:', intrinsic_rewards.shape)
            # print('intrinsic_rewards:', intrinsic_rewards)
            # print('intrinsic_rewards.mean:', intrinsic_rewards.mean(axis=0))
            # print('intrinsic_advantages:', repr(intrinsic_advantages))
        # per-prompt mean/std tracking
        if config.per_prompt_stat_tracking:
            # gather the prompts across processes
            prompt_ids = accelerator.gather(samples["prompt_ids"]).cpu().numpy()
            prompts = pipeline.tokenizer.batch_decode(
                prompt_ids, skip_special_tokens=True
            )
            if config.p_loc==1:
                advantages = stat_tracker.update(prompts, rewards_ada)
            else:
                advantages = stat_tracker.update(prompts, rewards)
            # --------------------------------------------------------------------------------
            intrinsic_advantages = intrinsic_stat_tracker.update(prompts, intrinsic_rewards)
            # --------------------------------------------------------------------------------
        else:
            if config.p_loc==1:
                advantages = (rewards_ada - rewards_ada.mean(axis=0)) / (rewards_ada.std(axis=0) + 1e-8)
                print(advantages.shape)
                print(rewards.shape)
            else:
                advantages = (rewards - rewards) / (rewards + 1e-8)
            # --------------------------------------------------------------------------------
            # [[1, 2, 3, 4],
            #  [2, 3, 4, 5]]
            # mean = [1.5, 2.5, 3.5, 4.5]
            # std = [...]
            intrinsic_advantages = (intrinsic_rewards - intrinsic_rewards.mean(axis=0)) / (intrinsic_rewards.std(axis=0) + 1e-8)
            # --------------------------------------------------------------------------------
        if epoch>config.cut_warmup and config.step_cut>0:
            if disnet_p:
                num_train_timesteps, raw_array = step_predictor_v2(pred_p=pred_p)
            else:
                num_train_timesteps,raw_array=step_predictor_empty(clip_p=clip_p)
        accelerator.log(
            {
                "reward": rewards,
                "intrinsic_rewards": intrinsic_rewards,
                "epoch": epoch,
                "reward_mean": rewards.mean(),
                "reward_std": rewards.std(),
                "intrinsic_reward_mean": intrinsic_rewards.mean(axis=0),
                "intrinsic_rewards.all_step_mean":intrinsic_rewards.mean(),
                "intrinsic_reward_std": intrinsic_rewards.std(axis=0),
                # 'intrinsic_advantages': intrinsic_advantages,
                # 'intrinsic_advantages_mean': intrinsic_advantages.mean(axis=0),
                'advantages': advantages,
            },
            step=global_step,
        )
        if accelerator.is_local_main_process and reward_count!=-1:
            if rewards.mean()>config.cut_all_score:
                reward_count+=1
                if reward_count>1:
                    elapsed_time=clock.end()
                    accelerator.log(
                        {
                            'elapsed_time': elapsed_time
                        },
                        step=global_step,
                    )
                    reward_count=-1
        # ungather advantages; we only need to keep the entries corresponding to the samples on this process
        if config.p_loc==1:
            samples["advantages"] = (
                torch.as_tensor(advantages)
                    .reshape(accelerator.num_processes, -1,advantages.shape[-1])[accelerator.process_index]
                    .to(accelerator.device)
            )
        else:
            samples["advantages"] = (
                torch.as_tensor(advantages)
                .reshape(accelerator.num_processes, -1)[accelerator.process_index]
                .to(accelerator.device)
            )
        samples["intrinsic_advantages"] = (
            torch.as_tensor(intrinsic_advantages)
            .reshape(accelerator.num_processes, -1, intrinsic_advantages.shape[-1])[accelerator.process_index]
            .to(accelerator.device)
        )
        if config.step_cut==2:
            samples["raw_array"] = (
                torch.as_tensor(raw_array)
                .reshape(accelerator.num_processes, -1)[accelerator.process_index]
            .to(accelerator.device)
            )
        torch.set_printoptions(threshold=float('inf'))
        for each in samples["intrinsic_advantages"]:
            save_tensor_to_txt(each,
                               f'midres/intrinsic_clip_advantages_{human_readable_time}.txt')
        del samples["rewards"]
        del samples["intrinsic_rewards"]
        del samples["prompt_ids"]

        total_batch_size, num_timesteps = samples["timesteps"].shape
        assert (
            total_batch_size
            == config.sample.batch_size * config.sample.num_batches_per_epoch
        )
        assert num_timesteps == config.sample.num_steps

        #################### TRAINING ####################
        for inner_epoch in range(config.train.num_inner_epochs):
            # shuffle samples along batch dimension
            # cjk_note: useless since num_in_epoch=1
            if config.step_cut==2:
                # cjk_note: reorder sample by step_num, a light-weight method to approximately align step number of each gpu/batch
                perm = torch.argsort(samples['raw_array'], descending=True).to(accelerator.device)
            else:
                perm = torch.randperm(total_batch_size, device=accelerator.device)
            samples = {k: v[perm] for k, v in samples.items()}
            # if config.step_cut==2:
            #     print(f'Step array on Rank-{accelerator.state.local_process_index}: {samples["raw_array"]}')

            if config.step_cut<2:
                # shuffle along time dimension independently for each sample
                # cjk_note: it breaks the order of denoising, unexpected for step cut. (We want drop last)
                perms = torch.stack(
                    [
                        torch.randperm(num_timesteps, device=accelerator.device)
                        for _ in range(total_batch_size)
                    ]
                )
                for key in ["timesteps", "latents", "next_latents", "log_probs"]:
                    samples[key] = samples[key][
                        torch.arange(total_batch_size, device=accelerator.device)[:, None],
                        perms,
                    ]
                if disnet_clip:
                    for key in ['clip_p']:
                        samples[key] = samples[key][
                            torch.arange(total_batch_size, device=accelerator.device)[:, None],
                            perms,
                        ]
                if disnet_p:
                    for key in ['disnet_output']:
                        samples[key] = samples[key][
                            torch.arange(total_batch_size, device=accelerator.device)[:, None],
                            perms,
                        ]
            # rebatch for training
            samples_batched = {
                k: v.reshape(-1, config.train.batch_size, *v.shape[1:])
                for k, v in samples.items()
            }

            # dict of lists -> list of dicts for easier iteration
            samples_batched = [
                dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())
            ]

            # train
            pipeline.unet.train()
            info = defaultdict(list)
            # if config.step_cut == 2:
            #     each_accu_steps_arr = \
            #         samples['raw_array'].cpu().detach().numpy().reshape(-1,config.train.gradient_accumulation_steps*config.train.batch_size).sum(axis=1)
            #     # b2_res=each_accu_steps_arr.reshape(-1, 2).max(axis=1)
            #     print(f'Each-step array on Rank-{accelerator.state.local_process_index}: {each_accu_steps_arr}')
            accumulate_count=0
            inner_counter=0
            avg_step_num_array=[]
            avg_step_num_array_v2 = []
            for i, sample in tqdm(
                list(enumerate(samples_batched)),
                desc=f"Epoch {epoch}.{inner_epoch}: training",
                position=0,
                disable=not accelerator.is_local_main_process,
            ):
                if config.train.cfg:
                    # concat negative prompts to sample prompts to avoid two forward passes
                    embeds = torch.cat(
                        [train_neg_prompt_embeds, sample["prompt_embeds"]]
                    )
                else:
                    embeds = sample["prompt_embeds"]
                # cjk_note: get the avg step in one forward across num_gpu * batch_size
                if config.step_cut == 2:
                    num_train_timesteps=round(sample['raw_array'].cpu().detach().numpy().mean())
                    # accelerator.gradient_accumulation_steps = config.train.gradient_accumulation_steps * num_train_timesteps
                if len(sample['raw_array'])==1:
                    num_train_timesteps_gather = accelerator.gather(sample['raw_array'])
                else:
                    num_train_timesteps_gather = accelerator.gather(sample['raw_array'].float().mean().long())
                total = num_train_timesteps_gather.sum().item() # total sample number for gradient accumulation
                try:
                    avg_train_steps=int(num_train_timesteps_gather.cpu().detach().numpy().mean())
                    # print(f"avg_train_steps {avg_train_steps} at i={i}")
                    avg_step_num_array.append(avg_train_steps)
                    avg_step_num_array_v2.append(avg_train_steps)
                    num_train_timesteps=avg_train_steps
                except Exception as e:
                    print(num_train_timesteps_gather)
                # current_log_prob=None
                for j in tqdm(
                    range(num_train_timesteps),
                    desc="Timestep",
                    position=1,
                    leave=False,
                    disable=not accelerator.is_local_main_process,
                ):
                    if config.step_cut == 0:
                        context =  accelerator.accumulate(unet)
                    else:
                        context = contextlib.nullcontext()

                    with context:
                        with autocast():
                            noise_pred = unet_predict_noise(config,
                                                            unet,
                                                            sample["latents"][:, j],
                                                            sample["timesteps"][:, j],
                                                            embeds)
                            # compute the log prob of next_latents given latents under the current model
                            # if j==0:
                            _, log_prob = ddim_step_with_logprob(
                                pipeline.scheduler,
                                noise_pred,
                                sample["timesteps"][:, j],
                                sample["latents"][:, j],
                                eta=config.sample.eta,
                                prev_sample=sample["next_latents"][:, j],
                            )
                            # else:
                            #     log_prob=current_log_prob
                            if j==num_train_timesteps-1:
                                last=True
                            else:
                                last=False
                                _, next_log_prob = ddim_step_with_logprob(
                                    pipeline.scheduler,
                                    noise_pred,
                                    sample["timesteps"][:, j+1],
                                    sample["latents"][:, j+1],
                                    eta=config.sample.eta,
                                    prev_sample=sample["next_latents"][:, j+1],
                                )
                                # current_log_prob=clone_tensor_with_grad(next_log_prob)
                                # xt xt+1
                                current_log_prob=log_prob
                        inner_counter+=1
                        if config.p_round:
                            p = accelerator.unwrap_model(disnet_p).P_round(sample["disnet_output"][:, j].detach())
                            if not last:
                                p_next=accelerator.unwrap_model(disnet_p).P_round(sample["disnet_output"][:, j].detach())
                        else:
                            p = sample["disnet_output"][:, j].detach()
                            if not last:
                                p = sample["disnet_output"][:, j+1].detach()
                        if config.reward_fn.startswith("dummy"):
                            advantages = config.intrinsic_reward_weight * sample["intrinsic_advantages"][:, j]
                        elif config.reward_fn.startswith("extrinsic"):
                            if j == num_train_timesteps - 1:
                                advantages = sample["advantages"]
                            else:
                                advantages = config.intrinsic_reward_weight * sample["intrinsic_advantages"][:, j]
                        else:
                            if config.p_loc==2:
                                if config.wo_p:
                                    advantages = sample["advantages"] + config.intrinsic_reward_weight * sample["intrinsic_advantages"][:, j]
                                else:
                                    if config.remain_05:
                                        advantages = sample["advantages"] * (p+0.5)+ sample["intrinsic_advantages"][:, j] *  (1.5 - p)
                                    elif config.intri_only:
                                        advantages = sample["advantages"] + sample["intrinsic_advantages"][
                                                                                        :, j] * (1 - p) * config.advantage_weight
                                    else:
                                        advantages = sample["advantages"]*p + sample["intrinsic_advantages"][:, j]*(1-p)
                            else:
                                # print(sample["advantages"])
                                advantages = sample["advantages"][:, j] + config.intrinsic_reward_weight * sample["intrinsic_advantages"][:, j]
                        if not last:
                            if config.reward_fn.startswith("dummy"):
                                advantages_next = config.intrinsic_reward_weight * sample["intrinsic_advantages"][:, j+1]
                            elif config.reward_fn.startswith("extrinsic"):
                                if j == num_train_timesteps - 1:
                                    advantages_next = sample["advantages"]
                                else:
                                    advantages_next = config.intrinsic_reward_weight * sample["intrinsic_advantages"][:, j+1]
                            else:
                                if config.p_loc == 2:
                                    if config.wo_p:
                                        advantages_next = sample["advantages"] + config.intrinsic_reward_weight * sample[
                                                                                                                 "intrinsic_advantages"][
                                                                                                             :, j+1]
                                    else:
                                        if config.remain_05:
                                            advantages_next = sample["advantages"] * (p_next + 0.5) + sample["intrinsic_advantages"][
                                                                                            :, j+1] * (1.5 - p_next)
                                        elif config.intri_only:
                                            advantages_next = sample["advantages"] + sample["intrinsic_advantages"][
                                                                                :, j+1] * (1 - p_next) * config.advantage_weight_next
                                        else:
                                            advantages_next = sample["advantages"] * p_next + sample["intrinsic_advantages"][:, j+1] * (
                                                        1 - p_next) # intristic_reward*(1-p) -> intri*0.0045
                                else:
                                    # print(sample["advantages"])
                                    advantages_next = sample["advantages"][:, j+1] + config.intrinsic_reward_weight * sample[
                                                                                                                   "intrinsic_advantages"][
                                                                                                               :, j+1]


                        # ----------------------------------------------------------------------
                        advantages = torch.clamp(
                            advantages,
                            -config.train.adv_clip_max,
                            config.train.adv_clip_max,
                        )
                        ratio = torch.exp(log_prob - sample["log_probs"][:, j])
                        unclipped_loss = -advantages * ratio
                        clipped_loss = -advantages * torch.clamp(
                            ratio,
                            1.0 - config.train.clip_range,
                            1.0 + config.train.clip_range,
                        )
                        if not last:
                            advantages_next = torch.clamp(
                                advantages_next,
                                -config.train.adv_clip_max,
                                config.train.adv_clip_max,
                            )
                            ratio_next = torch.exp(next_log_prob - sample["log_probs"][:, j+1]) # TODO: Detach here
                            unclipped_loss_next = -advantages_next *ratio_next* ratio
                            clipped_loss_next = -advantages_next * torch.clamp(
                                ratio_next,
                                1.0 - config.train.clip_range,
                                1.0 + config.train.clip_range,
                            )* torch.clamp(
                            ratio,
                            1.0 - config.train.clip_range,
                            1.0 + config.train.clip_range,)

                            loss_next = torch.mean(torch.maximum(unclipped_loss_next, clipped_loss_next))
                        else:
                            loss_next=0
                        loss = config.reward_weight * torch.mean(torch.maximum(unclipped_loss, clipped_loss))+loss_next*config.rpo_beta

                        if config.train.use_kl:
                            with autocast():
                                noise_pred_old = unet_predict_noise(config,
                                                                    unet_copy,
                                                                    sample["latents"][:, j],
                                                                    sample["timesteps"][:, j],
                                                                    embeds)
                            kl_regularizer = (noise_pred - noise_pred_old) ** 2
                            if i >= config.train.kl_warmup:
                                loss += config.train.kl_weight * kl_regularizer.mean()
                                info["kl_regularizer"].append(kl_regularizer.mean())
                        # cjk_note: manually adjust loss for gradient accumulation without accelerator
                        if config.step_cut==2:
                            loss = (loss * config.train.gradient_accumulation_steps * accelerator.num_processes) / total
                        # debugging values
                        # John Schulman says that (ratio - 1) - log(ratio) is a better
                        # estimator, but most existing code uses this so...
                        # http://joschu.net/blog/kl-approx.html
                        info["approx_kl"].append(
                            0.5
                            * torch.mean((log_prob - sample["log_probs"][:, j]) ** 2)
                        )
                        info["clipfrac"].append(
                            torch.mean(
                                (
                                    torch.abs(ratio - 1.0) > config.train.clip_range
                                ).float()
                            )
                        )
                        info["loss"].append(loss)
                        #print('performing backward')
                        # backward pass
                        accelerator.backward(loss)
                        # cjk_note: preserving old code
                        if config.step_cut==0 or config.step_cut==1:
                            if accelerator.sync_gradients:
                                accelerator.clip_grad_norm_(
                                    unet.parameters(), config.train.max_grad_norm
                                )
                            # if config.train.gradient_accumulation_steps* num_train_timesteps
                            optimizer.step()
                            optimizer.zero_grad()
                        # if check_all_gradients_unupdated(unet):
                        #     print(f'rank-{accelerator.state.local_process_index} updated at {i}-{j}')
                    # cjk_note: preserving old code
                    if accelerator.sync_gradients and config.step_cut < 2:
                        # assert (j == num_train_timesteps - 1) and (
                        #     i + 1
                        # ) % config.train.gradient_accumulation_steps == 0
                        # log training-related stuff
                        info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                        info = accelerator.reduce(info, reduction="mean")
                        info.update({"epoch": epoch, "inner_epoch": inner_epoch})
                        accelerator.log(info, step=global_step)

                        global_step += 1
                        info = defaultdict(list)
                accumulate_count+=1
                # cjk_note:  manual gradient accumulation
                if accumulate_count==config.train.gradient_accumulation_steps and config.step_cut==2:
                    # print(f'Before Accumulating final on rank{accelerator.state.local_process_index}: sample_No={i},Step_total={j} accelerator.sync_gradients={accelerator.sync_gradients}, check:{check_all_gradients_unupdated(unet)}',flush=True)
                    # old=accelerator.gradient_accumulation_steps
                    #accelerator.gradient_accumulation_steps=1
                    if accelerator.sync_gradients:
                        accelerator.clip_grad_norm_(
                            unet.parameters(), config.train.max_grad_norm
                        )
                    optimizer.step()
                    optimizer.zero_grad()
                    accumulate_count=0
                    # print(f'After Accumulating final on rank{accelerator.state.local_process_index}: sample_No={i}, accelerator.sync_gradients={accelerator.sync_gradients}, check:{check_all_gradients_unupdated(unet)}',flush=True)
                    info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                    info = accelerator.reduce(info, reduction="mean")
                    info.update({"epoch": epoch, "inner_epoch": inner_epoch})
                    accelerator.log({"step_num": wandb.Histogram(np.array(avg_step_num_array))}, step=global_step)
                    avg_step_num_array=[]
                    accelerator.log(info, step=global_step)
                    global_step += 1
                    info = defaultdict(list)


            accelerator.log({"step_num_v2": wandb.Histogram(np.array(avg_step_num_array_v2,dtype=float))}, step=global_step)
            # save_tensor_to_txt(np.array(avg_step_num_array_v2),'midres/step_num_v1.txt')
            # make sure we did an optimization step at the end of the inner epoch
            print(f"sync_gradient on {accelerator.state.local_process_index} (before):", accelerator.sync_gradients, flush=True)
            assert accelerator.sync_gradients
        # print(f"---->accelerator.save_state in {accelerator.state.local_process_index}<----")
        if epoch !=0 and epoch%99==0:
            accelerator.save_state()
        if config.validation and epoch % config.vali_interval == 0:
            if accelerator.is_main_process and reward_count != -1:
                clock.pause()
            score_dict = eval_all(pipeline, config, accelerator=accelerator, logger=logger, test_names=config.test_prompts)
            if accelerator.is_main_process:
                accelerator.log(
                    score_dict,
                    step=global_step,
                )
            # if config.test_prompts != 'all' and accelerator.is_main_process:
            #     score_dict = eval_all(pipeline, config, logger=logger,test_names=config.test_prompts)
            #
            #     accelerator.log(
            #         score_dict,
            #         step=global_step,
            #     )
            #     logger.info(f"  Validation Success on GPU:{accelerator.process_index}")
            # elif config.test_prompts == 'all':
            #     for idx, test_prompts in enumerate(['imagenet_animals', 'simple_animals', 'activities_animals']):
            #         if accelerator.process_index == idx + 1:
            #             score_dict = eval_all(pipeline, config, logger=logger, test_names=test_prompts)
            #             accelerator.log(
            #                 score_dict,
            #                 step=global_step,
            #             )
            #             logger.info(f"  Validation Success on GPU:{accelerator.process_index}")
            if reward_count != -1 and accelerator.is_main_process:
                clock.restart()
            torch.cuda.empty_cache()
            del samples
            del samples_batched
        if accelerator.is_local_main_process and reward_count != -1:
            elapsed_time = clock.end()
            accelerator.log(
                {
                    'elapsed_time': elapsed_time
                },
                step=global_step,
            )
        #if epoch != 0 and epoch % config.save_freq == 0:
            # checkpoint_dir = f"logs/intrinsic_layered42_0112_{int(time.time())}_{rank}"
            # accelerator.save_state(output_dir=checkpoint_dir)
            # try:

            #accelerator.save_state()
            # except Exception as e:
            #     print(f"---->error occur in {accelerator.state.local_process_index}:{e}<----")


if __name__ == "__main__":
    # Script entry point: absl's app.run parses the command-line flags first,
    # then invokes main(argv) with the remaining positional arguments.
    app.run(main=main)