
import functools
import os, sys, json
import yaml
from copy import deepcopy
from dataclasses import dataclass, field
import time
from concurrent import futures
from PIL import Image
import subprocess
from datetime import timedelta  
import torch
import torch.nn.functional as F 
import numpy as np
from typing import List
from torch.utils.data import DistributedSampler
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed._composable.fsdp import register_fsdp_forward_method
import torch.distributed as dist
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.utils.data import DataLoader
from transformers import HfArgumentParser, set_seed
from transformers.optimization import (
    get_constant_schedule_with_warmup,
    get_cosine_with_min_lr_schedule_with_warmup,
)

from data.data_utils import add_special_tokens
from data.interleave_datasets.edit_dataset import CAPTION2EDIT_SYSTEM_PROMPT_BENCH
from modeling.autoencoder import load_ae
from modeling.bagel import (
    BagelConfig,
    Bagel,
    Qwen2Config,
    Qwen2ForCausalLM,
    SiglipVisionConfig,
    SiglipVisionModel,
)
from modeling.qwen2 import Qwen2Tokenizer
from modeling.qwen2.modeling_qwen2 import Qwen2RMSNorm
from torch import nn
from train.train_utils import create_logger, get_latest_ckpt
from train.fsdp_utils import (
    FSDPCheckpoint,
    FSDPConfig,
    grad_checkpoint_check_fn,
    fsdp_wrapper,
    fsdp_ema_setup,
    fsdp_ema_update,
)
from train.rollout_controller import MultiRoundRolloutController,RolloutStepResult
from train.rl_utils import (
    geneval_score,
    hps_score,
    GenevalPromptDataset,
    DistributedKRepeatSampler,
    PromptDataset,
    plot_rewards_histogram,
    compute_advantages_for_multi_round_rollout,
    EvalStats,
    TIIFDataset,
)
from data.transforms import ImageTransform
from modeling.diffusion.sde_sampler import SDESampler
from train.visualize import (
    visualize_packed_input,
    visualize_attention_mask_pattern,
    save_visualization,
    save_result_as_html,
)
from train.unified_reward import UnifiedReward, Qwen_Geneval
from train.geneval_plus_reward import geneval_plus_reward, yn_reward_fn
import train.tensorboard_wandb as wandb
from tqdm import tqdm


@dataclass
class ModelArguments:
    """Options describing which pretrained parts make up the model and how.

    Parsed from the command line by ``HfArgumentParser``. Field order is
    kept identical to the historical definition.
    """

    # Qwen2 config is read from here (Qwen2Config.from_pretrained in main).
    llm_path: str = "hf/Qwen2.5-0.5B-Instruct/"
    llm_qk_norm: bool = True  # copied to llm_config.qk_norm
    tie_word_embeddings: bool = False  # copied to llm_config.tie_word_embeddings
    layer_module: str = "Qwen2MoTDecoderLayer"  # copied to llm_config.layer_module
    vae_path: str = "flux/vae/ae.safetensors"  # passed to load_ae(local_path=...)
    # SigLIP config location (SiglipVisionConfig.from_pretrained in main).
    vit_path: str = "hf/siglip-so400m-14-980-flash-attn2-navit/"
    max_latent_size: int = 64  # forwarded to BagelConfig
    latent_patch_size: int = 2  # forwarded to BagelConfig
    vit_patch_size: int = 14  # not read in the visible part of this file
    vit_max_num_patch_per_side: int = 70  # forwarded to BagelConfig
    connector_act: str = "gelu_pytorch_tanh"  # forwarded to BagelConfig
    interpolate_pos: bool = False  # forwarded to BagelConfig
    # main sets ViT depth to num_hidden_layers + 1 + vit_select_layer,
    # so the default -2 drops the final encoder layer.
    vit_select_layer: int = -2
    vit_rope: bool = False  # copied to vit_config.rope

    # Conditioning dropout probabilities (not read in the visible part of this file).
    text_cond_dropout_prob: float = 0.1
    vae_cond_dropout_prob: float = 0.3
    vit_cond_dropout_prob: float = 0.3


@dataclass
class DataArguments:
    """Dataset, image-preprocessing, and rank-grouping options.

    Parsed from the command line by ``HfArgumentParser``. Field order is
    kept identical to the historical definition.
    """

    dataset_config_file: str = "data/configs/online_rl_example.yaml"
    prefetch_factor: int = 2
    num_workers: int = 4
    max_num_tokens_per_sample: int = 16384
    max_num_tokens: int = 36864
    prefer_buffer_before: int = 16384
    max_buffer_size: int = 50
    data_seed: int = 42
    online_batch_size: int = 1

    # Ranks are split into consecutive groups of this size (world size must
    # be divisible by it); 1 means every rank is its own group.
    policy_group_size: int = 1

    # Bounds/stride for the VAE-side ImageTransform built in main.
    vae_max_image_size: int = 512
    vae_min_image_size: int = 256
    vae_image_stride: int = 16
    # Bounds/stride for the ViT-side ImageTransform built in main.
    vit_max_image_size: int = 504
    vit_min_image_size: int = 252
    vit_image_stride: int = 14
    data_path: str = "/data/geneval"
    # main accepts "hps", "geneval", "geneval_plus" or "tiif".
    dataset_name: str = "hps"
    train_data_path: str = "/data/"
    val_data_path: str = "/data/"
    understand_val_path: str = "/data/vqa/"


@dataclass
class TrainingArguments:
    """Training, rollout, reward and sampling options for the online RL loop.

    Parsed from the command line by ``HfArgumentParser``; field order is kept
    identical to the historical definition. ``main`` overwrites several fields
    after parsing: ``commit_id`` (from ``git rev-parse HEAD``),
    ``cfg_interval`` (rebuilt from ``cfg_interval_low``/``cfg_interval_high``),
    ``num_replicate`` (world size // ``num_shard``) and
    ``checkpoint_dir``/``results_dir`` (made absolute under ``mydir``).
    """

    # In main: 1-layer LLM, checkpoint loading skipped, tiny geneval test set.
    debug: bool = False
    exp_name: str = ""
    commit_id: str = ""  # overwritten in main

    visual_gen: bool = True  # build the VAE / generation branch
    visual_und: bool = True  # build the ViT / understanding branch

    mydir: str = ""  # prefix joined in front of checkpoint_dir and results_dir
    results_dir: str = "results"
    checkpoint_dir: str = "results/checkpoints"
    wandb_project: str = "bagel"
    wandb_name: str = "run"
    wandb_runid: str = "trial"  # run id becomes f"{wandb_name}-run{wandb_runid}"
    wandb_resume: str = "allow"
    wandb_offline: bool = False
    global_seed: int = 4396  # per-rank seed = global_seed * world_size + rank
    auto_resume: bool = False  # prefer the latest checkpoint found in checkpoint_dir
    resume_from: str = None  # None: no explicit checkpoint
    resume_model_only: bool = False
    finetune_from_ema: bool = False  # only honored together with resume_model_only
    log_every: int = 5
    save_every: int = 50
    total_steps: int = 10000

    warmup_steps: int = 50
    lr_scheduler: str = "constant"  # "constant" or "cosine"; others raise in main
    lr: float = 5e-6
    min_lr: float = 1e-7  # used by the cosine schedule only
    beta1: float = 0.9
    beta2: float = 0.95
    weight_decay: float = 0.0
    eps: float = 1e-15
    ema: float = 0.0  # > 0 creates an EMA copy of the model
    # Annotated as float: HfArgumentParser uses the annotation as the argparse
    # type, so an `int` annotation rejected fractional values like 0.5.
    max_grad_norm: float = 1.0
    mse_weight: float = 1.0
    ce_weight: float = 1.0
    ce_loss_reweighting: bool = False
    expected_num_tokens: int = 32768

    num_replicate: int = 1  # overwritten in main from world size / num_shard
    num_shard: int = 8  # world size must be divisible by this
    sharding_strategy: str = "SHARD_GRAD_OP"
    backward_prefetch: str = "BACKWARD_PRE"
    cpu_offload: bool = False

    freeze_llm: bool = False
    freeze_vit: bool = False
    freeze_vae: bool = True
    freeze_und: bool = False  # copied to llm_config.freeze_und
    freeze_text_transformer: bool = False
    freeze_diffusion_transformer: bool = False
    copy_init_moe: bool = True  # calls language_model.init_moe() in main
    use_flex: bool = False
    eval_only: bool = False

    think_mode: bool = False
    tune_text_cot: bool = True
    tune_image_cot: bool = True  # also builds the SDE samplers in main
    rounds: int = 1  # a fresh prompt is drawn when curr_step % rounds == 0
    group_size: int = 8
    rollout_times: int = 1
    kl_weight_text: float = 0.001  # either KL weight > 0 builds a frozen reference model
    kl_weight_image: float = 0.001
    rollout_with_ema: bool = False
    rollout_vis_step: int = 10
    batch_size: int = 1
    test_batch_size: int = 8
    eval_freq: int = 5
    reward_fn: str = "hps"
    advantage_fn: str = ""
    format_reward_weight: float = 0.2
    pack_start_round_idx: int = 0
    reflection_prompt_type: str = "type_0"
    enable_sde_gen: bool = False
    enable_sde_edit: bool = False

    format_reward_add_type: str = ""
    optimize_freq: int = 1
    reward_server_urls: str = ""  # comma-separated hosts
    # Comma-separated ports; "-1" means reward_server_urls is used verbatim.
    reward_server_port: str = ""
    reward_model_name: str = "Qwen2.5-VL-72B-Instruct-AWQ"
    reward_api_key: str = "EMPTY"
    client_type: str = "openai"
    reflection_system_prompt_key: str = "v0"
    rollout_on_same_noise: bool = True
    timestep_sample_ratio: float = 0.1
    max_output_token_n_gen: int = 196
    max_output_token_n_edit: int = 384
    do_sample: bool = True
    text_temperature: float = 1.0
    top_k: int = -1
    cfg_text_scale: float = 1.0
    cfg_img_scale: float = 1.0
    cfg_interval: tuple = (0.4, 1.0)  # overwritten in main from the low/high fields
    cfg_interval_low: float = 0.4
    cfg_interval_high: float = 1.0
    timestep_shift: float = 3.0
    num_timesteps_gen: int = 30  # should be fixed
    num_timesteps_edit: int = 30  # should be fixed
    cfg_renorm_min: float = 0.0
    cfg_renorm_type_gen: str = "global"
    # NOTE(review): "text_channe" looks like a typo for "text_channel";
    # confirm against the sampler's accepted values before changing it.
    cfg_renorm_type_edit: str = "text_channe"
    image_size: int = 512
    eta_mode: str = 'monotonic'
    constant_eta_gen: float = 0.2
    constant_eta_edit: float = 0.2
    max_sde_timestep_idx_for_edit: int = 2
    ref_model_path: str = None  # None: reference model starts from the policy weights
    resume_step: int = -1  # with resume_model_only, -1 means start at step 0
    

def save_inference_results(inference_results, curr_step, results_dir, mode="train", num_samples_to_save=2, timestep_idx=None, last_reward=None, change_summary=None):
    """Write a subset of rollout results as an HTML report via ``save_result_as_html``.

    Args:
        inference_results: dict with ``"per_sample_results"`` (per sample, a
            sequence of lists of step results exposing ``to_dict()``,
            ``sample_idx`` and ``round_idx``) and ``"per_sample_rewards"``
            (indexed ``[sample_idx][round_idx]``); an optional ``"summary"``
            sequence is recorded as ``summary_<i>`` entries.
        curr_step: training step, forwarded to ``save_result_as_html``.
        results_dir: output directory, forwarded to ``save_result_as_html``.
        mode: report label, forwarded as ``mode=``.
        num_samples_to_save: number of leading samples whose steps are included.
        timestep_idx: optional SDE timestep index recorded as a header entry.
        last_reward: optional header entry.
        change_summary: optional header entry, recorded only when non-empty.

    Steps whose dict has ``is_cfg_text`` or ``is_cfg_img`` set are excluded;
    every kept step gets a ``"reward"`` key looked up with that step's own
    ``sample_idx`` / ``round_idx``.
    """
    results = inference_results["per_sample_results"]
    rewards = inference_results["per_sample_rewards"]

    records = []
    if last_reward is not None:
        records.append({"last_reward": last_reward})
    if change_summary is not None and len(change_summary) > 0:
        records.append({"change_summary": change_summary})
    if timestep_idx is not None:
        records.append({"sde_timestep_idx": timestep_idx})
    if "summary" in inference_results:
        summaries = inference_results["summary"]
        records.extend({f"summary_{i}": summaries[i]} for i in range(len(summaries)))

    # `outer_idx` only selects which samples to show; the reward lookup uses
    # the index stored on each step result.
    for outer_idx in range(min(num_samples_to_save, len(results))):
        for step_group in results[outer_idx]:
            for step in step_group:
                record = step.to_dict()
                if record['is_cfg_text'] or record['is_cfg_img']:
                    continue
                record["reward"] = rewards[step.sample_idx][step.round_idx]
                records.append(record)

    save_result_as_html(records, curr_step, results_dir, mode=mode)


def compute_entropy(logits):
    """Return the Shannon entropy (in nats) of softmax(logits) over the last axis.

    ``logits`` is detached first, so the result never carries gradient. The
    1e-8 added inside the log makes zero-probability entries (e.g. from
    ``-inf`` logits) contribute 0 instead of ``0 * -inf = nan``.
    """
    p = logits.detach().softmax(dim=-1)
    log_p = (p + 1e-8).log()
    return -(p * log_p).sum(dim=-1)


def compute_entropy_batched(logits, batch_size=512):
    """Return the mean per-position entropy (in nats) of softmax(logits).

    The softmax is taken over the last axis. All leading axes are flattened
    into positions, so 1-D, 2-D ``(seq, vocab)`` and higher-rank inputs are
    accepted; for 2-D input this is exactly the previous behavior. Positions
    are processed ``batch_size`` rows at a time so each softmax temporary is
    bounded. When autograd is recording, the per-chunk intermediates are still
    kept for backward. ``logits`` is NOT detached, so the result stays in the
    graph if ``logits`` requires grad.

    Args:
        logits: tensor whose last dimension is the vocabulary.
        batch_size: number of positions per chunk; must be positive.

    Returns:
        0-dim tensor: entropy averaged over all positions.

    Raises:
        ValueError: if ``batch_size`` is not positive.
        RuntimeError: (from ``torch.cat``) if ``logits`` has no positions.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    # Collapse every leading dim into one "position" axis; a no-op view for 2-D input.
    flat = logits.reshape(-1, logits.shape[-1])
    per_chunk = []
    for start in range(0, flat.shape[0], batch_size):
        chunk = flat[start:start + batch_size]  # slicing clamps at the end
        probs = F.softmax(chunk, dim=-1)
        log_probs = F.log_softmax(chunk, dim=-1)
        per_chunk.append(-(probs * log_probs).sum(dim=-1))
    return torch.cat(per_chunk, dim=0).mean()


class RolloutBuffer:
    """Rollout samples split into a "perfect" pool (reward == 1) and a "bad" pool.

    Entries are ``(metadata, reward, img)`` tuples as passed to :meth:`update`.
    :meth:`sample` draws one entry deterministically from a seed, picking the
    perfect pool with probability ``perfect_ratio``.

    NOTE(review): ``max_size`` is stored but no method here enforces it; the
    pools grow until :meth:`clear` is called.
    """

    def __init__(self, max_size=10000):
        self.perfect_buffer = []
        self.bad_buffer = []
        self.max_size = max_size  # not enforced (see class docstring)

    def update(self, metadata, reward, img):
        """Store one rollout; a reward exactly equal to 1 goes to the perfect pool."""
        if reward == 1:
            self.perfect_buffer.append((metadata, reward, img))
        else:
            self.bad_buffer.append((metadata, reward, img))

    def all_gather(self):
        """Replace both local pools with the concatenation of every rank's pools.

        This is a collective over the default process group and must be called
        on all ranks. Results are concatenated in rank order, so every rank
        ends up with identical pools.
        """
        all_perfect_buffer = [None] * dist.get_world_size()
        all_bad_buffer = [None] * dist.get_world_size()
        dist.all_gather_object(all_perfect_buffer, self.perfect_buffer)
        dist.all_gather_object(all_bad_buffer, self.bad_buffer)
        dist.barrier()
        gathered_perfect_buffer = []
        gathered_bad_buffer = []
        for perfect_buffer in all_perfect_buffer:
            gathered_perfect_buffer.extend(perfect_buffer)
        for bad_buffer in all_bad_buffer:
            gathered_bad_buffer.extend(bad_buffer)
        self.perfect_buffer = gathered_perfect_buffer
        self.bad_buffer = gathered_bad_buffer
        torch.cuda.empty_cache()

    def clear(self):
        """Drop every stored entry from both pools."""
        self.perfect_buffer = []
        self.bad_buffer = []

    def __len__(self):
        """Return the total number of stored entries.

        The previous implementation returned a tuple, which made ``len(buf)``
        and ``bool(buf)`` raise TypeError; use :meth:`counts` for the pair.
        """
        return len(self.perfect_buffer) + len(self.bad_buffer)

    def counts(self):
        """Return ``(num_perfect, num_bad)``."""
        return len(self.perfect_buffer), len(self.bad_buffer)

    def sample(self, seed: int = 0, perfect_ratio=0.15, device="cuda"):
        """Draw one ``(metadata, reward, img)`` entry, deterministic in ``seed``.

        A Bernoulli(``perfect_ratio``) draw selects the perfect pool (1) or the
        bad pool (0). If the selected pool is empty, the other pool is used.
        The index inside the pool is then drawn uniformly from the same
        generator.

        Args:
            seed: seed for a fresh ``torch.Generator``.
            perfect_ratio: probability of drawing from the perfect pool.
            device: device of the generator and random tensors. The default
                ``"cuda"`` keeps the previous behavior.

        Raises:
            RuntimeError: if both pools are empty.
        """
        perfect_len, bad_len = self.counts()
        if perfect_len == 0 and bad_len == 0:
            raise RuntimeError("cannot sample from an empty RolloutBuffer")
        g = torch.Generator(device=device)
        g.manual_seed(seed)
        prob = torch.tensor([perfect_ratio], device=device)
        mode = torch.bernoulli(prob, generator=g).item()
        if perfect_len == 0:
            mode = 0
        if bad_len == 0:
            mode = 1
        pool = self.bad_buffer if mode == 0 else self.perfect_buffer
        # Pass the device explicitly. Otherwise randint allocates on the global
        # default device, which must match the generator's device or it errors.
        idx = torch.randint(0, len(pool), (1,), generator=g, device=device).item()
        return pool[idx]



def main():
    dist.init_process_group("nccl", timeout=timedelta(seconds=1000))
    device = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(device)
    if torch.cuda.is_available():
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True
        if hasattr(torch, "set_float32_matmul_precision"):
            torch.set_float32_matmul_precision("high")

    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    args: tuple[ModelArguments, DataArguments, TrainingArguments] = (
        parser.parse_args_into_dataclasses()
    )
    model_args, data_args, training_args = args
    training_args.commit_id = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('utf-8').strip()

    training_args.cfg_interval = (training_args.cfg_interval_low, training_args.cfg_interval_high)
    if dist.get_rank() == 0:
        logger = create_logger(training_args.results_dir, dist.get_rank())
        wandb.init(
            project=training_args.wandb_project,
            id=f"{training_args.wandb_name}-run{training_args.wandb_runid}",
            name=training_args.wandb_name,
            resume=training_args.wandb_resume,
            mode="offline" if training_args.wandb_offline else "online",
        )
        wandb.config.update(training_args, allow_val_change=True)
        wandb.config.update(model_args, allow_val_change=True)
        wandb.config.update(data_args, allow_val_change=True)
    else:
        logger = create_logger(None, dist.get_rank())
    logger.info(f"Training arguments {training_args}")
    logger.info(f"Model arguments {model_args}")
    logger.info(f"Data arguments {data_args}")
    training_args.checkpoint_dir = os.path.abspath(
        os.path.join(training_args.mydir, training_args.checkpoint_dir)
    )
    training_args.results_dir = os.path.abspath(
        os.path.join(training_args.mydir, training_args.results_dir)
    )
    if dist.get_rank() == 0:
        os.makedirs(training_args.checkpoint_dir, exist_ok=True)
        os.makedirs(training_args.results_dir, exist_ok=True)
    logger.info(f"checkpoint_dir: {training_args.checkpoint_dir}")
    logger.info(f"results_dir: {training_args.results_dir}")

    if training_args.reward_server_port == "-1":
        reward_urls = [training_args.reward_server_urls]
    else:
        reward_server_urls = training_args.reward_server_urls.split(",")
        reward_server_port = training_args.reward_server_port.split(",")
        reward_urls = []
        for url in reward_server_urls:
            for port in reward_server_port:
                reward_urls.append(f"http://{url}:{port}/v1")
    logger.info(f"reward_urls: {reward_urls}")

    if dist.get_world_size() % training_args.num_shard != 0:
        raise ValueError(
            f"World size {dist.get_world_size()} must be divisible by num_shard {training_args.num_shard}"
        )
    training_args.num_replicate = dist.get_world_size() // training_args.num_shard

    policy_groups = None
    if data_args.policy_group_size > 1:
        world_size = dist.get_world_size()
        if world_size % data_args.policy_group_size != 0:
            raise ValueError(
                f"World size {world_size} must be divisible by policy_group_size {data_args.policy_group_size}"
            )

        policy_group_id = dist.get_rank() // data_args.policy_group_size
        policy_group_rank = dist.get_rank() % data_args.policy_group_size

        policy_groups = []
        for i in range(world_size // data_args.policy_group_size):
            group_ranks = list(
                range(
                    i * data_args.policy_group_size,
                    (i + 1) * data_args.policy_group_size,
                )
            )
            group = dist.new_group(ranks=group_ranks)
            policy_groups.append(group)

        current_policy_group = policy_groups[policy_group_id]

        if dist.get_rank() == 0:
            logger.info(
                f"Created {len(policy_groups)} policy groups with size {data_args.policy_group_size}"
            )
            logger.info(f"Rank {dist.get_rank()} is in policy group {policy_group_id}")
    else:
        current_policy_group = None
        policy_group_id = dist.get_rank()
        policy_group_rank = 0

    if training_args.auto_resume:
        resume_from = get_latest_ckpt(training_args.checkpoint_dir)
        if resume_from is None:
            resume_from = training_args.resume_from
            resume_model_only = training_args.resume_model_only
            if resume_model_only:
                finetune_from_ema = training_args.finetune_from_ema
            else:
                finetune_from_ema = False
        else:
            resume_model_only = False
            finetune_from_ema = False
    else:
        resume_from = training_args.resume_from
        resume_model_only = training_args.resume_model_only
        if resume_model_only:
            finetune_from_ema = training_args.finetune_from_ema
        else:
            finetune_from_ema = False

    seed = training_args.global_seed * dist.get_world_size() + dist.get_rank()
    set_seed(seed)

    llm_config:Qwen2Config = Qwen2Config.from_pretrained(model_args.llm_path)
    if training_args.debug:
        llm_config.num_hidden_layers = 1
    llm_config.layer_module = model_args.layer_module
    llm_config.qk_norm = model_args.llm_qk_norm
    llm_config.tie_word_embeddings = model_args.tie_word_embeddings
    llm_config.freeze_und = training_args.freeze_und
    language_model = Qwen2ForCausalLM(llm_config)
    if training_args.copy_init_moe:
        language_model.init_moe()

    if training_args.visual_und:
        vit_config = SiglipVisionConfig.from_pretrained(model_args.vit_path)
        vit_config.num_hidden_layers = (
            vit_config.num_hidden_layers + 1 + model_args.vit_select_layer
        )
        vit_config.rope = model_args.vit_rope
        vit_model = SiglipVisionModel(vit_config)

    if training_args.visual_gen:
        vae_model, vae_config = load_ae(
            local_path=model_args.vae_path,
        )

    config = BagelConfig(
        visual_gen=training_args.visual_gen,
        visual_und=training_args.visual_und,
        llm_config=llm_config,
        vit_config=vit_config if training_args.visual_und else None,
        vae_config=vae_config if training_args.visual_gen else None,
        latent_patch_size=model_args.latent_patch_size,
        max_latent_size=model_args.max_latent_size,
        vit_max_num_patch_per_side=model_args.vit_max_num_patch_per_side,
        connector_act=model_args.connector_act,
        interpolate_pos=model_args.interpolate_pos,
        timestep_shift=training_args.timestep_shift,
    )
    model = Bagel(
        language_model, vit_model if training_args.visual_und else None, config
    )

    patch_latent_dim = model.patch_latent_dim

    if training_args.visual_und:
        model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config)

    # Setup tokenizer for model:
    tokenizer = Qwen2Tokenizer.from_pretrained(os.path.dirname(model_args.llm_path))
    tokenizer, new_token_ids, num_new_tokens = add_special_tokens(tokenizer)
    if num_new_tokens > 0:
        model.language_model.resize_token_embeddings(len(tokenizer))
        model.config.llm_config.vocab_size = len(tokenizer)
        model.language_model.config.vocab_size = len(tokenizer)

    if training_args.freeze_vae and training_args.visual_gen:
        for param in vae_model.parameters():
            param.requires_grad = False
    if training_args.freeze_llm:
        model.language_model.eval()
        for param in model.language_model.parameters():
            param.requires_grad = False
    if training_args.freeze_vit:
        model.vit_model.eval()
        for param in model.vit_model.parameters():
            param.requires_grad = False

    fsdp_config = FSDPConfig(
        sharding_strategy=training_args.sharding_strategy,
        backward_prefetch=training_args.backward_prefetch,
        cpu_offload=training_args.cpu_offload,
        num_replicate=training_args.num_replicate,
        num_shard=training_args.num_shard,
    )
    ema_model = None
    if not training_args.debug:
        model, ema_model = FSDPCheckpoint.try_load_ckpt(
            resume_from, logger, model, ema_model, resume_from_ema=finetune_from_ema
        )
    if training_args.kl_weight_text > 0 or training_args.kl_weight_image > 0:
        ref_model = deepcopy(model)
        if not training_args.debug and training_args.ref_model_path is not None and training_args.ref_model_path != resume_from:
            ref_model, _ = FSDPCheckpoint.try_load_ckpt(
                training_args.ref_model_path, logger, ref_model, None, resume_from_ema=True,
            )
        ref_model.requires_grad_(False)
        ref_model.eval()
        ref_model = fsdp_wrapper(ref_model, fsdp_config)
    else:
        ref_model = None
    if training_args.ema > 0:
        ema_model = deepcopy(model)
        ema_model.requires_grad_(False)
        ema_model.eval()
        ema_model = fsdp_ema_setup(ema_model, fsdp_config)
    else:
        ema_model = None

    fsdp_model = fsdp_wrapper(model, fsdp_config)
    apply_activation_checkpointing(
        fsdp_model,
        checkpoint_wrapper_fn=functools.partial(
            checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
        ),
        check_fn=grad_checkpoint_check_fn,
    )
    params_to_optimize = list(fsdp_model.parameters())
    logger.info(f"Params to optimize length before filter: {len(params_to_optimize)}")
    params_to_optimize = list(filter(lambda p: p.requires_grad, params_to_optimize))
    logger.info(f"Params to optimize length after filter: {len(params_to_optimize)}")
    optimizer = torch.optim.AdamW(
        params_to_optimize,
        lr=training_args.lr,
        betas=(training_args.beta1, training_args.beta2),
        eps=training_args.eps,
        weight_decay=training_args.weight_decay,
    )

    if training_args.lr_scheduler == "cosine":
        scheduler = get_cosine_with_min_lr_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=training_args.warmup_steps,
            num_training_steps=training_args.total_steps,
            min_lr=training_args.min_lr,
        )
    elif training_args.lr_scheduler == "constant":
        scheduler = get_constant_schedule_with_warmup(
            optimizer=optimizer, num_warmup_steps=training_args.warmup_steps
        )
    else:
        raise ValueError

    if resume_model_only:
        train_step = 0 if training_args.resume_step == -1 else training_args.resume_step
        data_status = None
    else:
        optimizer, scheduler, train_step, data_status = (
            FSDPCheckpoint.try_load_train_state(
                resume_from,
                optimizer,
                scheduler,
                fsdp_config,
            )
        )

    if data_args.dataset_name == "hps":
        train_dataset = PromptDataset(data_args.train_data_path)
        test_dataset = PromptDataset(data_args.val_data_path)
    elif data_args.dataset_name in ["geneval", "geneval_plus"]:
        train_dataset = GenevalPromptDataset(data_args.train_data_path,split="")
        test_dataset = GenevalPromptDataset(data_args.val_data_path, split="", num_samples=8 if training_args.debug else None)
    elif data_args.dataset_name == "tiif":
        train_dataset = TIIFDataset(data_args.train_data_path)
        test_dataset = TIIFDataset(data_args.val_data_path)
    else:
        raise ValueError(f"Invalid dataset name: {data_args.dataset_name}")
    # 创建无限循环的DataLoader
    train_sampler = DistributedKRepeatSampler(
        dataset=train_dataset,
        batch_size=1,
        k=1,
        world_size=dist.get_world_size(),
        rank=policy_group_id,
        seed=42,
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=1,
        collate_fn=train_dataset.collate_fn,
    )
    train_iter = iter(train_dataloader)

    vae_transform = ImageTransform(
        max_image_size=data_args.vae_max_image_size,
        min_image_size=data_args.vae_min_image_size,
        image_stride=data_args.vae_image_stride,
    )
    vit_transform = ImageTransform(
        max_image_size=data_args.vit_max_image_size,
        min_image_size=data_args.vit_min_image_size,
        image_stride=data_args.vit_image_stride,
    )

    if training_args.visual_gen:
        vae_model.to(device).eval()

    torch.set_default_device("cuda")

    executor = futures.ThreadPoolExecutor(max_workers=8)
    if training_args.reward_fn == "geneval":
        reward_fn = geneval_score(reward_urls[dist.get_rank() % len(reward_urls)])
        eval_reward_fn = geneval_score(reward_urls[dist.get_rank() % len(reward_urls)])
    elif training_args.reward_fn == "hps":
        reward_fn = hps_score(reward_urls[dist.get_rank() % len(reward_urls)])
    elif training_args.reward_fn == "unified_reward":
        reward_fn = functools.partial(UnifiedReward, url=reward_urls[dist.get_rank() % len(reward_urls)]+"/v1")
    elif training_args.reward_fn == "qwen_geneval":
        reward_fn = functools.partial(Qwen_Geneval, url=reward_urls[dist.get_rank() % len(reward_urls)]+"/v1", model_name=training_args.reward_model_name, api_key=training_args.reward_api_key)
    elif training_args.reward_fn == "geneval_plus":
        reward_fn = geneval_plus_reward(url=reward_urls[dist.get_rank() % len(reward_urls)], model_name=training_args.reward_model_name, api_key=training_args.reward_api_key, client_type=training_args.client_type)
        eval_reward_fn = reward_fn
    elif training_args.reward_fn == "yn_reward":
        reward_fn = yn_reward_fn(url=reward_urls[dist.get_rank() % len(reward_urls)], model_name=training_args.reward_model_name, api_key=training_args.reward_api_key, client_type=training_args.client_type)
        eval_reward_fn = reward_fn
    else:
        raise ValueError(f"Invalid reward function: {training_args.reward_fn}")

    if data_args.dataset_name == "geneval_plus":
        eval_reward_fn = geneval_plus_reward(url=reward_urls[dist.get_rank() % len(reward_urls)], model_name=training_args.reward_model_name, api_key=training_args.reward_api_key, client_type=training_args.client_type)

    fsdp_model.train()
    if ema_model is not None:
        ema_model.eval()

    logger.info(
        f"Starting online RL training for {training_args.total_steps} steps, starting at {train_step}..."
    )

    image_shape = (training_args.image_size, training_args.image_size)
    generation_kwargs = dict(
        think=training_args.think_mode,
        max_output_token_n_gen=training_args.max_output_token_n_gen,
        max_output_token_n_edit=training_args.max_output_token_n_edit,
        do_sample=training_args.do_sample,
        text_temperature=training_args.text_temperature,
        topk=training_args.top_k,
        cfg_text_scale=training_args.cfg_text_scale,
        cfg_img_scale=training_args.cfg_img_scale,
        cfg_interval=training_args.cfg_interval,
        timestep_shift=training_args.timestep_shift,
        num_timesteps_gen=training_args.num_timesteps_gen,
        num_timesteps_edit=training_args.num_timesteps_edit,
        cfg_renorm_min=training_args.cfg_renorm_min,
        cfg_renorm_type_gen=training_args.cfg_renorm_type_gen,
        cfg_renorm_type_edit=training_args.cfg_renorm_type_edit,
        image_shapes = image_shape,
        executor=executor,
        reward_fn=reward_fn,
        reward_fn_type=training_args.reward_fn,
        enable_sde=training_args.tune_image_cot,
    )
    if training_args.rollout_with_ema and ema_model is not None:
        infer_model = ema_model
    else:
        infer_model = fsdp_model
    if training_args.tune_image_cot:    
        sde_sampler_gen = SDESampler(eta_mode=training_args.eta_mode, constant_eta=training_args.constant_eta_gen, model_output_type="velocity")
        sde_sampler_edit = SDESampler(eta_mode=training_args.eta_mode, constant_eta=training_args.constant_eta_edit, model_output_type="velocity")
    else:
        sde_sampler_gen = None
        sde_sampler_edit = None
    logger.info(f"Using SDESampler Gen: {sde_sampler_gen} \n SDESampler Edit: {sde_sampler_edit}")
    rollout_controller = MultiRoundRolloutController(
        model=infer_model,
        vae_model=vae_model,
        tokenizer=tokenizer,
        vae_transform=vae_transform,
        vit_transform=vit_transform,
        new_token_ids=new_token_ids,
        sde_sampler=sde_sampler_gen,
    )
    rollout_buffer = RolloutBuffer()

    for curr_step in range(train_step, 100000):
        torch.cuda.synchronize()
        start_time = time.time()
        group_seed = (policy_group_id + 1) * training_args.total_steps + curr_step
        gen_seed = (dist.get_rank() + 1) * training_args.total_steps + curr_step
        group_generator = torch.Generator(device="cuda")
        group_generator.manual_seed(group_seed)
        num_image_tokens = image_shape[0] * image_shape[1] // 16**2
        sample_generator = torch.Generator(device="cuda")
        sample_generator.manual_seed(gen_seed)
        rollout_batch_size = max(training_args.group_size // data_args.policy_group_size, 1)

        generation_kwargs["initial_noise"] = [torch.randn(
            num_image_tokens, patch_latent_dim, dtype=torch.bfloat16, generator=group_generator, device="cuda"
        ) for _ in range(5)]

        with rollout_controller.context():
            # t2i, recaption
            if curr_step % training_args.rounds == 0:
                rollout_buffer.clear()
                train_sampler.set_epoch(curr_step)
                if data_args.dataset_name == "hps":
                    prompts = next(train_iter)
                    metadatas = None
                elif data_args.dataset_name in ["geneval", "geneval_plus", "tiif"]:
                    prompts, metadatas = next(train_iter)
                metadata = metadatas[0]
                generation_kwargs["enable_sde"] = training_args.enable_sde_gen
                if training_args.enable_sde_gen:
                    timestep_idx = torch.randint(0, training_args.num_timesteps_gen - 1, (1,), device="cuda", dtype=torch.int, generator=group_generator)
                    generation_kwargs["sde_timestep_idx"] = timestep_idx.tolist()
                    rollout_controller.sde_sampler = sde_sampler_gen
                inference_results = rollout_controller.rollout_v2(
                    rounds=1,
                    prompts = prompts * rollout_batch_size,
                    prompt_metadata = metadatas * rollout_batch_size,
                    generator=sample_generator,
                    **generation_kwargs,
                )
            else:
                metadata, reward, img = rollout_buffer.sample(seed=group_seed)
                last_reward = reward
                prompt = metadata['prompt']
                reflection_system_prompt = '''The description of the target image is: {prompt}\nHow to further edit the provided image to make it consistent with the target description? Please provide concrete editing instructions in a single sentence. If no editing operation is needed, answer: no further edit needed.'''
                reflection_system_prompt = reflection_system_prompt.format(prompt=prompt)
                rollout_batch_size = max(training_args.group_size // data_args.policy_group_size, 1)
                generation_kwargs["enable_sde"] = training_args.enable_sde_edit
                if training_args.enable_sde_edit:
                    timestep_idx = torch.randint(0, training_args.max_sde_timestep_idx_for_edit, (1,), device="cuda", dtype=torch.int, generator=group_generator)
                    generation_kwargs["sde_timestep_idx"] = timestep_idx.tolist()
                    rollout_controller.sde_sampler = sde_sampler_edit
                inference_results = rollout_controller.rollout_v2(
                    rounds=2,
                    start_round_idx=1,
                    prompts = [prompt] * rollout_batch_size,
                    reflection_prompts = [reflection_system_prompt] * rollout_batch_size,
                    prompt_metadata = [metadata] * rollout_batch_size,
                    input_imgs = [img] * rollout_batch_size,
                    generator=sample_generator,
                    **generation_kwargs,
                )
        inference_results['per_sample_total_rewards'] = [[] for _ in range(rollout_batch_size)]
        edit_img_adv_multiplier = []
        per_round_format_rewards = [[] for _ in range(1)]
        edit_train_stats = {
            "improved_cnt": 0,
            "no_change_cnt": 0,
            "worse_cnt": 0,
            "mean_improvement": 0,
        }
        change_summary = ""
        for sample_idx, per_sample_reward in enumerate(inference_results["per_sample_rewards"]):
            for round_idx, per_round_reward in enumerate(per_sample_reward):
                if training_args.reward_fn == "unified_reward":
                    round_results = per_round_reward.result()
                elif training_args.reward_fn == "geneval":
                    all_scores, all_rewards, all_strict_rewards, all_group_rewards_dict, all_group_strict_rewards_dict = per_round_reward.result()
                    round_results = all_scores
                elif training_args.reward_fn == "hps":
                    round_results = per_round_reward.result()
                elif training_args.reward_fn == "qwen_geneval":
                    round_results = per_round_reward.result()
                elif training_args.reward_fn == "geneval_plus":
                    round_results = per_round_reward.result()
                elif training_args.reward_fn == "yn_reward":
                    round_results = (per_round_reward.result(), "")
                per_sample_reward[round_idx] = round_results[0]
                per_round_format_rewards[round_idx].append(inference_results['per_sample_format_rewards'][sample_idx][round_idx])
                if curr_step % training_args.rounds == 0:
                    format_reward = inference_results['per_sample_format_rewards'][sample_idx][round_idx]
                    total_reward = round_results[0] * format_reward + training_args.format_reward_weight * format_reward
                    inference_results['per_sample_total_rewards'][sample_idx].append(total_reward)
                    rollout_buffer.update(metadata, round_results[0], inference_results['output_imgs'][sample_idx])
                else:
                    this_reward = round_results[0]
                    improvement = this_reward - last_reward
                    if improvement > 0:
                        edit_train_stats["improved_cnt"] += 1
                        change_summary += "improved,"
                    elif improvement == 0:
                        edit_train_stats["no_change_cnt"] += 1
                        change_summary += "no_change,"
                    else:
                        edit_train_stats["worse_cnt"] += 1
                        change_summary += "worse,"

                    format_reward = inference_results['per_sample_format_rewards'][sample_idx][round_idx]
                    if last_reward == 1 and "no further edit" in inference_results['per_sample_edit_operations'][sample_idx].lower():
                        improvement = 1
                    if "no further edit" in inference_results['per_sample_edit_operations'][sample_idx].lower():
                        edit_img_adv_multiplier.append(0)
                    else:
                        edit_img_adv_multiplier.append(1)

                    total_reward = format_reward * training_args.format_reward_weight + improvement * format_reward
                    inference_results['per_sample_total_rewards'][sample_idx].append(total_reward)
                    edit_train_stats["mean_improvement"] += ((this_reward - last_reward)/rollout_batch_size)
                    
                    # rollout_buffer.update(metadata, round_results[0], inference_results['output_imgs'][sample_idx])
        if curr_step % training_args.rounds == 0:
            rollout_buffer.all_gather()
        logger.info(f"Rank {dist.get_rank()}: Rollout buffer length: {rollout_buffer.__len__()}")
        # save rollout results
        if curr_step % training_args.rollout_vis_step <= training_args.rounds-1:
            save_inference_results(inference_results, curr_step, training_args.results_dir, mode="train", num_samples_to_save=4, last_reward=last_reward if curr_step % training_args.rounds != 0 else None, change_summary=change_summary)
        # compute advantages
        text_advantages, additional_stats = compute_advantages_for_multi_round_rollout(
            inference_results["per_sample_total_rewards"], # with text format reward
            current_policy_group,
            first_subtract_mean=True,
            first_divide_std=True,
            later_operations="",
            first_scaler=1.0,
            later_scaler=1.0,
        )
        image_advantages, additional_stats = compute_advantages_for_multi_round_rollout(
            inference_results["per_sample_rewards"],
            current_policy_group,
            first_subtract_mean=True,
            first_divide_std=True,
            later_operations="",
            first_scaler=1.0,
            later_scaler=1.0,
        )

        if curr_step % training_args.rounds != 0:
            for sample_idx, multiplier in enumerate(edit_img_adv_multiplier):
                image_advantages[sample_idx][0] = image_advantages[sample_idx][0] * multiplier
        logger.info(f"Rank {dist.get_rank()}: Text advantages: {text_advantages}")
        logger.info(f"Rank {dist.get_rank()}: Image advantages: {image_advantages}")

        # count number of gt 0 elements in text_advantages
        text_advantages_gt = torch.tensor(text_advantages, device="cuda", dtype=torch.float)
        text_advantages_lt = torch.tensor(text_advantages, device="cuda", dtype=torch.float)
        text_advantages_zero = torch.tensor(text_advantages, device="cuda", dtype=torch.float)
        text_advantages_gt_0 = (text_advantages_gt > 0).sum()
        text_advantages_lt_0 = (text_advantages_lt < 0).sum()
        text_advantages_zero = (text_advantages_zero == 0).sum()
        dist.all_reduce(text_advantages_gt_0, op=dist.ReduceOp.SUM)
        dist.all_reduce(text_advantages_lt_0, op=dist.ReduceOp.SUM)
        dist.all_reduce(text_advantages_zero, op=dist.ReduceOp.SUM)

        # update rollout stats
        round_idx = curr_step % training_args.rounds
        mode = "recaption" if curr_step % training_args.rounds == 0 else "edit"
        if mode == "edit": mode = mode + "_" + str(round_idx)
        # process cot length
        round_mean_cot_length = inference_results[f'Round_0/mean_cot_length']
        round_mean_cot_length = torch.tensor(round_mean_cot_length, device="cuda", dtype=torch.float).mean()
        dist.all_reduce(round_mean_cot_length, op=dist.ReduceOp.AVG)
        additional_stats[f'{mode}/mean_cot_length'] = round_mean_cot_length.item()
        # process format reward
        round_mean_format_reward = torch.tensor(per_round_format_rewards[0], device="cuda", dtype=torch.float).mean()
        dist.all_reduce(round_mean_format_reward, op=dist.ReduceOp.AVG)
        additional_stats[f'{mode}/mean_format_reward'] = round_mean_format_reward.item()
        additional_stats[f'{mode}/mean_reward'] = additional_stats['round_mean_rewards'].pop("round_0")
        additional_stats[f'{mode}/global_mean_std'] = additional_stats.pop("round_0_global_mean_std")
        if curr_step % training_args.rounds != 0:
            for key, value in edit_train_stats.items():
                edit_train_stats[key] = torch.tensor(value, device="cuda", dtype=torch.float).mean()
                if key != "mean_improvement":
                    dist.all_reduce(edit_train_stats[key], op=dist.ReduceOp.SUM)
                else:
                    dist.all_reduce(edit_train_stats[key], op=dist.ReduceOp.AVG)
            additional_stats[f"{mode}"] = edit_train_stats
        additional_stats[f"{mode}/text_advantages_gt_0"] = text_advantages_gt_0.item()
        additional_stats[f"{mode}/text_advantages_lt_0"] = text_advantages_lt_0.item()
        additional_stats[f"{mode}/text_advantages_zero"] = text_advantages_zero.item()


        dist.barrier()
        loss_dict = {
            "text_policy_loss": [],
            "img_policy_loss": [],
            "text_kl_loss": [],
            "img_kl_loss": [],
        }
        fsdp_model.train()
        optimizer.zero_grad()

        # NOTE: We collect all batches to ensure synchronized iteration counts across ranks.
        # This uses more memory but avoids doing the expensive pack_iterator work twice.
        # The pack_iterator does heavy tensor operations, tokenization, and data packing,
        # so materializing results once is more efficient than iterating twice.
        (
            all_data_points,
            total_text_grad_tokens,
            total_image_num,
        ) = rollout_controller.prepack_v2(
            inference_results,
            with_text_loss=training_args.tune_text_cot,
            with_img_loss=training_args.tune_image_cot,
            timestep_ratio=training_args.timestep_sample_ratio,
            pack_start_round_idx=training_args.pack_start_round_idx,
            pack_end_round_idx=training_args.rounds-1,
        )

        logger.info(
            f"Rank {dist.get_rank()}: Processing {len(all_data_points)} iterations"
        )
        logger.info(f"Rank {dist.get_rank()}: total text grad tokens: {total_text_grad_tokens}, total image num: {total_image_num}")

        # Process the synchronized number of batches

        for iteration_count, (packed_input, loss_info) in enumerate(
            tqdm(
                rollout_controller.pack_iterator_v2(
                    all_data_points,
                    text_advantages=text_advantages,
                    image_advantages=image_advantages,
                    cfg_interval=training_args.cfg_interval,
                ),
                total=1,
                desc=f"Gradient Accumulation [{total_text_grad_tokens} Text Tokens, {total_image_num} Images]",
            )
        ):
            if curr_step == 0 and iteration_count == 0:
                for key, value in packed_input.items():
                    if isinstance(value, torch.Tensor):
                        logger.info(f"Rank {dist.get_rank()} Packed Input: {key} shape: {value.shape}")
                    else:
                        logger.info(f"Rank {dist.get_rank()} Packed Input: {key} is {value}")
                for key, value in loss_info.items():
                    if isinstance(value, torch.Tensor):
                        logger.info(f"Rank {dist.get_rank()} Loss Info: {key} shape: {value.shape}")
                    else:
                        logger.info(f"Rank {dist.get_rank()} Loss Info: {key} is {value}")
            
            # repeat advantages to match the length of label_ids
            cur_text_advantages = loss_info.get("text_advantages", None)
            cur_image_advantages = loss_info.get("image_advantages", None)
            label_ids = loss_info.get("packed_label_ids", None)
            # Step 4: training model with the packed data
            # Forward pass with mixed precision
            loss = 0
            with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
                # Compute ref model logits
                with torch.inference_mode():
                    if (training_args.kl_weight_text > 0 or training_args.kl_weight_image > 0): # forward ref model for text kl loss
                        ref_loss_dict = ref_model(**packed_input, ref_forward=True)
                        ref_per_token_logps = (
                            ref_loss_dict["logits"]
                            .log_softmax(dim=-1)
                            .gather(dim=-1, index=label_ids.unsqueeze(-1))
                            .squeeze(-1)
                        )
                        logger.info(
                            f"Rank {dist.get_rank()}: Ref per token logps shape: {ref_per_token_logps.shape}"
                        )
                    if (training_args.kl_weight_image > 0 and training_args.tune_image_cot):
                        ref_v_pred = ref_loss_dict["packed_mse_preds"]

                output_dict = fsdp_model(**packed_input)
            if training_args.tune_text_cot:
                per_token_logps = (
                    output_dict["logits"]
                    .log_softmax(dim=-1)
                    .gather(dim=-1, index=label_ids.unsqueeze(-1))
                    .squeeze(-1)
                )
                mean_entropy = compute_entropy_batched(output_dict["logits"]).mean()
                per_token_loss = -(
                    torch.exp(per_token_logps - per_token_logps.detach())
                    * cur_text_advantages
                ).sum() / (
                    1e-4 + total_text_grad_tokens
                )  # online policy loss
                loss_dict["text_policy_loss"].append(per_token_loss.item())
                if training_args.kl_weight_text > 0:
                    per_token_kl = (
                        torch.exp(ref_per_token_logps - per_token_logps)
                        - (ref_per_token_logps - per_token_logps)
                        - 1
                    )  # k3 estimation
                    per_token_kl = per_token_kl.sum() / (1e-4 + total_text_grad_tokens)
                    logger.info(
                        f"Rank {dist.get_rank()}: Text kl loss: {per_token_kl.item()}"
                    )
                    loss_dict["text_kl_loss"].append(per_token_kl.item())
                    # ignore unexpected large kl loss
                    if per_token_kl.item() > 10:
                        per_token_kl = per_token_logps.mean() * 0.0
                    loss = per_token_loss + training_args.kl_weight_text * per_token_kl
                else:
                    loss = per_token_loss
                loss = loss * training_args.ce_weight
                logger.info(f"Rank {dist.get_rank()}: Text Loss: {loss.item()}")

            v_pred = output_dict["packed_mse_preds"]
            if training_args.tune_image_cot and v_pred is not None:
                if curr_step % training_args.rounds == 0:
                    sde_sampler = sde_sampler_gen
                else:
                    sde_sampler = sde_sampler_edit
                v_pred = rollout_controller.apply_cfg(
                        v_pred,
                        loss_info["mse_condition_token_indexes"],
                        training_args.cfg_text_scale,
                        training_args.cfg_img_scale,
                        training_args.cfg_interval,
                        training_args.cfg_renorm_min,
                        training_args.cfg_renorm_type_gen,
                        training_args.cfg_renorm_type_edit,
                    )
                logger.info(f"Rank {dist.get_rank()}: v_pred shape: {v_pred.shape}")
                x_curr = sde_sampler.get_x_t_distribution(
                    loss_info["packed_noisy_latent_for_loss"].float(),
                    loss_info["packed_timesteps_for_loss"].float(),
                    loss_info["dts"].float(),
                    v_pred.float(),
                )
                x_prev = loss_info["packed_prev_latents"]
                log_prob = x_curr.log_prob(x_prev)
                # convert it to image level loss
                skip_interval = 1
                img_latent_lengths = [h * w for h, w in loss_info['patchified_vae_latent_shapes_for_loss'][::skip_interval]] # ::2 to skip unconditioned images
                assert sum(img_latent_lengths) == log_prob.shape[0]
                log_prob_per_image = torch.stack([x.mean() for x in torch.split(log_prob, img_latent_lengths)]) # (num_images,)
                logger.info(f"Rank {dist.get_rank()}: log_prob_per_image: {log_prob_per_image}")
                img_per_token_loss = -(
                    torch.exp(log_prob_per_image - log_prob_per_image.detach()) * cur_image_advantages.to(log_prob_per_image.dtype)
                ).sum() / (1e-4 + total_image_num)
                loss_dict["img_policy_loss"].append(img_per_token_loss.item())

                # TODO: add kl loss for the image model
                img_loss = img_per_token_loss
                kl_img_loss = 0
                if training_args.kl_weight_image > 0:
                    ref_v_pred = rollout_controller.apply_cfg(
                        ref_v_pred,
                        loss_info["mse_condition_token_indexes"],
                        training_args.cfg_text_scale,
                        training_args.cfg_img_scale,
                        training_args.cfg_interval,
                        training_args.cfg_renorm_min,
                        training_args.cfg_renorm_type_gen,
                        training_args.cfg_renorm_type_edit,
                    )
                    ref_x_curr = sde_sampler.get_x_t_distribution(
                        loss_info["packed_noisy_latent_for_loss"].float(),
                        loss_info["packed_timesteps_for_loss"].float(),
                        loss_info["dts"].float(),
                        ref_v_pred.float(),
                    )
                    kl_img_loss = x_curr.kl_divergence(ref_x_curr).mean()
                    loss_dict["img_kl_loss"].append(kl_img_loss.item())

                logger.info(f"Rank {dist.get_rank()}: Image Loss: {img_loss.item():.4f}, KL: {kl_img_loss:.4f}")
                loss = loss + img_loss * training_args.mse_weight + kl_img_loss * training_args.kl_weight_image
            # Backward pass
            loss.backward()
        # Gradient clipping and optimization step
        total_norm = fsdp_model.clip_grad_norm_(training_args.max_grad_norm)
        optimizer.step()
        scheduler.step()
        logger.info(f"Rank {dist.get_rank()}: Total norm: {total_norm.item()}")
        for k in list(loss_dict.keys()):
            v = loss_dict[k]
            if len(v) == 0:
                loss_dict.pop(k)
                continue
            v_sum = torch.tensor(sum(v), device=device)
            v_size = torch.tensor(len(v), device=device)
            dist.all_reduce(v_sum, op=dist.ReduceOp.AVG)
            dist.all_reduce(v_size, op=dist.ReduceOp.AVG)
            loss_dict[k] = v_sum / v_size
        if ema_model is not None:
            fsdp_ema_update(ema_model, fsdp_model, decay=training_args.ema)

        # Log loss values for this batch
        if curr_step % training_args.log_every == 0:
            # Measure training speed
            torch.cuda.synchronize()
            end_time = time.time()
            steps_per_sec = training_args.log_every / (end_time - start_time)
            message = f"(step={curr_step:07d})  step time: {end_time - start_time:.2f}s"
            wandb_log = {}
            wandb_log["step_time"] = end_time - start_time
            # all gather rewards
            if training_args.tune_text_cot:
                dist.all_reduce(mean_entropy, op=dist.ReduceOp.AVG)
                mean_entropy = mean_entropy.item()
                wandb_log["mean_entropy"] = mean_entropy
                message += f"Mean Entropy: {mean_entropy:.4f}, "
            for info_key, info_value in additional_stats.items():
                if isinstance(info_value, dict):
                    for key_, value_ in info_value.items():
                        wandb_log[f"{info_key}/{key_}"] = value_
                        message += f"{info_key}/{key_}: {value_:.4f}, "
                else:
                    wandb_log[info_key] = info_value
                    message += f"{info_key}: {info_value:.4f}, "

            message += f"Train Steps/Sec: {steps_per_sec:.2f}, "
            logger.info(message)

            wandb_log["lr"] = optimizer.param_groups[0]["lr"]
            wandb_log["total_norm"] = total_norm.item()

            mem_allocated = torch.tensor(
                torch.cuda.max_memory_allocated() / 1024**3, device=device
            )
            dist.all_reduce(mem_allocated, op=dist.ReduceOp.MAX)
            wandb_log["mem_allocated"] = mem_allocated
            mem_cache = torch.tensor(
                torch.cuda.max_memory_reserved() / 1024**3, device=device
            )
            dist.all_reduce(mem_cache, op=dist.ReduceOp.MAX)
            wandb_log["mem_cache"] = mem_cache
            wandb_log.update({f"train/{k}": v for k, v in loss_dict.items()})
            if dist.get_rank() == 0:
                wandb.log(wandb_log, step=curr_step)

        if curr_step > 0 and (curr_step % training_args.save_every == 0 or curr_step == training_args.total_steps - 1):
            try:
                FSDPCheckpoint.fsdp_save_ckpt(
                    ckpt_dir=training_args.checkpoint_dir,
                    train_steps=curr_step,
                    model=fsdp_model,
                    ema_model=ema_model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    logger=logger,
                    fsdp_config=fsdp_config,
                    data_status=None,  # online training不需要data_status
                )
            except Exception as e:
                logger.error(f"Rank {dist.get_rank()}: Save checkpoint failed: {e}")

    logger.info("Done!")
    if dist.get_rank() == 0:
        wandb.finish()
    dist.destroy_process_group()


if __name__ == "__main__":
    from pathlib import Path

    # Sentinel file that exists while main() is running.
    # Presumably an external launcher/watchdog polls it to tell whether
    # training is still alive. NOTE(review): confirm the consumer.
    # Creating and removing it is best-effort. This mirrors the previous
    # `touch` / `rm -rf` shell calls, which never raised into Python.
    wait_marker = Path("/root/wait1")
    try:
        wait_marker.touch()
    except OSError:
        pass  # best-effort: a missing sentinel must not prevent training
    try:
        main()
    finally:
        # Runs on normal completion and on any exception (including
        # KeyboardInterrupt). A pending exception keeps propagating afterwards
        # with its original traceback, so no separate `except` is needed.
        try:
            wait_marker.unlink(missing_ok=True)
        except OSError:
            pass  # best-effort cleanup; never mask the real exit status