import argparse
import math
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from pathlib import Path
from fastvideo.utils.parallel_states import (
    initialize_sequence_parallel_state,
    destroy_sequence_parallel_group,
    get_sequence_parallel_state,
    nccl_info,
)
from fastvideo.utils.communications_flux import sp_parallel_dataloader_wrapper
import time
from torch.utils.data import DataLoader
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.checkpoint.state_dict import get_model_state_dict, set_model_state_dict, StateDictOptions

from torch.utils.data.distributed import DistributedSampler
from fastvideo.utils.dataset_utils import LengthGroupedSampler
import wandb
from accelerate.utils import set_seed
from tqdm.auto import tqdm
from fastvideo.utils.fsdp_util import get_dit_fsdp_kwargs, apply_fsdp_checkpointing
from fastvideo.utils.load import load_transformer
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from fastvideo.dataset.latent_flux_rl_datasets import LatentDataset, latent_collate_function
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from fastvideo.utils.checkpoint import (
    save_checkpoint,
    save_lora_checkpoint,
)
from fastvideo.utils.logging_ import main_print
import cv2
from diffusers.image_processor import VaeImageProcessor

# Raises an error if the minimum required version of diffusers is not installed. Remove at your own risk.
check_min_version("0.31.0")
import time
from collections import deque
import numpy as np
from einops import rearrange
import torch.distributed as dist
from torch.nn import functional as F
from typing import List
from PIL import Image
from diffusers import FluxTransformer2DModel, AutoencoderKL
from contextlib import contextmanager
from safetensors.torch import save_file
import json

class FSDP_EMA:
    """Exponential moving average of a model's full state dict, kept on rank 0.

    Every method fetches the model's state dict through
    ``get_model_state_dict`` with ``full_state_dict=True`` and
    ``cpu_offload=True``. Only the process constructed with ``rank == 0``
    stores and updates the averaged copy; on every other rank
    ``ema_state_dict_rank0`` stays an empty dict.
    """

    def __init__(self, model, decay, rank):
        """Snapshot the current weights of ``model`` as the initial average.

        Args:
            model: Module handed to ``get_model_state_dict``.
            decay: Weight of the running average in ``update``.
            rank: Rank of this process; the snapshot is stored only when 0.
        """
        self.decay = decay
        self.rank = rank
        self.ema_state_dict_rank0 = {}
        gather_opts = StateDictOptions(full_state_dict=True, cpu_offload=True)
        full_state = get_model_state_dict(model, options=gather_opts)

        # The gather above runs on every rank; only rank 0 keeps a copy.
        if self.rank != 0:
            return
        self.ema_state_dict_rank0 = {name: tensor.clone() for name, tensor in full_state.items()}
        main_print("--> Modern EMA handler initialized on rank 0.")

    def update(self, model):
        """Blend the current weights of ``model`` into the stored average.

        For each stored key that is also present in the model state dict:
        ``ema = decay * ema + (1 - decay) * current``, written in place.
        """
        gather_opts = StateDictOptions(full_state_dict=True, cpu_offload=True)
        current_state = get_model_state_dict(model, options=gather_opts)

        if self.rank != 0:
            return
        for name, averaged in self.ema_state_dict_rank0.items():
            if name not in current_state:
                continue
            averaged.copy_(
                self.decay * averaged + (1 - self.decay) * current_state[name]
            )

    @contextmanager
    def use_ema_weights(self, model):
        """Context manager that runs its body with the EMA weights loaded.

        The model's current weights are fetched first, the EMA state dict is
        loaded with ``broadcast_from_rank0=True``, and on exit (including on
        exceptions) the fetched weights are loaded back the same way.
        """
        snapshot_opts = StateDictOptions(full_state_dict=True, cpu_offload=True)
        original_weights = get_model_state_dict(model, options=snapshot_opts)

        set_model_state_dict(
            model,
            model_state_dict=self.ema_state_dict_rank0,
            options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
        )

        try:
            yield
        finally:
            # Always put the training weights back, even if the body raised.
            set_model_state_dict(
                model,
                model_state_dict=original_weights,
                options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
            )

def save_ema_checkpoint(ema_handler, rank, output_dir, step, epoch, config_dict):
    """Write the EMA weights and a JSON config into a checkpoint directory.

    Does nothing unless ``rank`` is 0 and ``ema_handler`` is not None.
    Otherwise ``<output_dir>/checkpoint-ema-<step>-<epoch>/`` is created and
    receives ``diffusion_pytorch_model.safetensors`` (the contents of
    ``ema_handler.ema_state_dict_rank0``) and ``config.json``.

    Args:
        ema_handler: Object exposing ``ema_state_dict_rank0``, or None.
        rank: Rank of the calling process.
        output_dir: Parent directory for the checkpoint folder.
        step: Step number used in the folder name.
        epoch: Epoch number used in the folder name.
        config_dict: Mapping dumped to ``config.json``. Its ``"dtype"`` entry
            is left out of the file; the mapping itself is not modified.
    """
    if rank != 0 or ema_handler is None:
        return

    ema_checkpoint_path = os.path.join(output_dir, f"checkpoint-ema-{step}-{epoch}")
    os.makedirs(ema_checkpoint_path, exist_ok=True)

    weight_path = os.path.join(ema_checkpoint_path, "diffusion_pytorch_model.safetensors")
    save_file(ema_handler.ema_state_dict_rank0, weight_path)

    # Build a filtered copy instead of deleting the key in place, so the
    # caller's config keeps its "dtype" entry for later use.  # TODO
    serialisable_config = {key: value for key, value in config_dict.items() if key != "dtype"}

    config_path = os.path.join(ema_checkpoint_path, "config.json")
    with open(config_path, "w") as f:
        json.dump(serialisable_config, f, indent=4)
    main_print(f"--> EMA checkpoint saved at {ema_checkpoint_path}")

# --- CHANGED: flowGRPO shift calculation ---
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Map an image sequence length to a shift value by linear interpolation.

    The line passes through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``; lengths outside that range are
    extrapolated along the same line.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
    
def flow_grpo_step(
    model_output: torch.Tensor,
    latents: torch.Tensor,
    noise_level: float,
    sigmas: torch.Tensor,
    index: int,
    prev_sample: torch.Tensor,
    grpo: bool,
):
    """Take one stochastic (SDE) sampler step, optionally with its log-probability.

    Follows the ``sde_step_with_logprob`` logic from flowGRPO.

    Args:
        model_output: Model prediction for ``latents`` at step ``index``.
        latents: Current sample.
        noise_level: Scale of the injected noise.
        sigmas: Noise schedule; entries ``index``, ``index + 1`` and ``1`` are read.
        index: Position of the current step inside ``sigmas``.
        prev_sample: Next sample whose log-probability is evaluated. When it is
            None and ``grpo`` is true, a new sample is drawn instead.
        grpo: Selects the return layout (see below).

    Returns:
        When ``grpo`` is true:
        ``(prev_sample, pred_original_sample, log_prob, prev_sample_mean, sample_std_dev)``
        with ``log_prob`` averaged over every non-batch dimension.
        Otherwise ``(prev_sample, pred_original_sample)`` with a freshly drawn
        ``prev_sample``.
    """
    # 1. Current / next sigma and the step size.
    # Note: Flux sigmas run from 1.0 down to 0.0, so the next step's sigma is
    # smaller than the current one and the step size is negative.
    sigma_cur = sigmas[index]
    sigma_next = sigmas[index + 1]
    step = sigma_next - sigma_cur

    # 2. Boundary case sigma == 1 (avoids a division by zero).
    # flowGRPO substitutes sigma_max, i.e. sigmas[1], in that case.
    sigma_max = sigmas[1]
    if sigma_cur < 1.0:
        sigma_eff = sigma_cur
    else:
        sigma_eff = sigma_max

    # 3. SDE coefficient (std_dev_t in the reference code):
    #    noise_level * sqrt(sigma / (1 - sigma)), with a small epsilon in the
    #    denominator to keep NaNs away.
    noise_coef = noise_level * (sigma_eff / (1.0 - sigma_eff + 1e-6)) ** 0.5

    # 4. Drift (mean), including the SDE correction terms. Reference formula:
    #    mean = sample * (1 + std_dev_t**2 / (2*sigma) * dt)
    #         + model_output * (1 + std_dev_t**2 * (1 - sigma) / (2*sigma)) * dt
    #    Both corrections share the factor std_dev_t**2 / (2*sigma) * dt, so the
    #    mean is the plain Euler update plus two correction terms.
    shared_factor = (noise_coef ** 2 / (2 * sigma_eff)) * step
    latents_correction = latents * shared_factor
    output_correction = model_output * shared_factor * (1.0 - sigma_eff)

    # Plain Euler term (latents + model_output * dt) plus the SDE corrections.
    prev_sample_mean = latents + model_output * step + latents_correction + output_correction

    # 5. Standard deviation of the noise actually injected: std_dev_t * sqrt(-dt).
    sample_std_dev = noise_coef * (-step) ** 0.5

    # Predicted x0; it does not influence the SDE path.
    pred_original_sample = latents - sigma_cur * model_output

    # 6. Make sure the std is a tensor living next to the latents.
    if torch.is_tensor(sample_std_dev):
        sample_std_dev = sample_std_dev.to(device=latents.device, dtype=latents.dtype)
    else:
        sample_std_dev = torch.tensor(sample_std_dev, device=latents.device, dtype=latents.dtype)

    if not grpo:
        noise = torch.randn_like(prev_sample_mean)
        return prev_sample_mean + noise * sample_std_dev, pred_original_sample

    if prev_sample is None:
        # Sampling phase: draw the next sample.
        prev_sample = prev_sample_mean + torch.randn_like(prev_sample_mean) * sample_std_dev

    # Training phase: Gaussian log-density of prev_sample under N(mean, std^2),
    #   -0.5 * ((x - mu) / std)^2 - log(std) - 0.5 * log(2*pi)
    # The 1e-12 offset guards against log(0).
    safe_std = sample_std_dev + 1e-12
    residual = prev_sample.detach().float() - prev_sample_mean.float()
    log_prob = (
        -(residual ** 2) / (2 * (safe_std ** 2))
        - torch.log(safe_std)
        - math.log(math.sqrt(2 * math.pi))
    )

    # Average over every dimension except the batch dimension.
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

    # Mean and std are returned for the KL term; the std is the per-step
    # sampling std (sample_std_dev), not the coefficient from step 3.
    return prev_sample, pred_original_sample, log_prob, prev_sample_mean, sample_std_dev


def assert_eq(x, y, msg=None):
    """Assert that ``x == y``; the failure message shows both values.

    ``msg`` prefixes the message; a generic prefix is used when it is falsy.
    """
    prefix = msg or 'Assertion failed'
    assert x == y, f"{prefix}: {x} != {y}"


def prepare_latent_image_ids(batch_size, height, width, device, dtype):
    """Build flattened (row, column) position ids for a ``height x width`` grid.

    Returns a ``(height * width, 3)`` tensor in row-major order: channel 0 is
    zero, channel 1 holds the row index and channel 2 the column index. The
    result is moved to ``device`` / ``dtype``. ``batch_size`` is accepted but
    not used.
    """
    grid = torch.zeros(height, width, 3)
    row_index = torch.arange(height).unsqueeze(1)  # broadcast down the columns
    col_index = torch.arange(width).unsqueeze(0)   # broadcast across the rows
    grid[..., 1] = grid[..., 1] + row_index
    grid[..., 2] = grid[..., 2] + col_index

    n_rows, n_cols, n_channels = grid.shape
    flat_ids = grid.reshape(n_rows * n_cols, n_channels)
    return flat_ids.to(device=device, dtype=dtype)

def pack_latents(latents, batch_size, num_channels_latents, height, width):
    """Fold each 2x2 spatial patch of ``latents`` into the channel axis.

    ``(batch, channels, height, width)`` becomes
    ``(batch, (height // 2) * (width // 2), channels * 4)``.
    """
    half_h, half_w = height // 2, width // 2
    patches = latents.view(batch_size, num_channels_latents, half_h, 2, half_w, 2)
    # Bring the patch-grid axes forward and keep (channel, dy, dx) together.
    patches = patches.permute(0, 2, 4, 1, 3, 5)
    return patches.reshape(batch_size, half_h * half_w, num_channels_latents * 4)

def unpack_latents(latents, height, width, vae_scale_factor):
    """Turn packed ``(batch, patches, channels)`` latents back into a 4-D grid.

    ``height`` and ``width`` are image-space sizes. The latent grid is
    ``2 * (size // (vae_scale_factor * 2))`` along each side, and the result
    has shape ``(batch, channels // 4, latent_height, latent_width)``.
    """
    batch_size, num_patches, channels = latents.shape

    # The VAE compresses the image spatially, and packing additionally needs
    # the latent height and width to be divisible by 2.
    patch_rows = int(height) // (vae_scale_factor * 2)
    patch_cols = int(width) // (vae_scale_factor * 2)
    latent_h = 2 * patch_rows
    latent_w = 2 * patch_cols

    grid = latents.view(batch_size, patch_rows, patch_cols, channels // 4, 2, 2)
    # Interleave each patch's (dy, dx) offsets with its grid position.
    grid = grid.permute(0, 3, 1, 4, 2, 5)
    return grid.reshape(batch_size, channels // (2 * 2), latent_h, latent_w)

def run_sample_step(
        args,
        z,
        progress_bar,
        sigma_schedule,
        transformer,
        encoder_hidden_states,
        pooled_prompt_embeds,
        text_ids,
        image_ids,
        grpo_sample,
    ):
    """Run the full stochastic sampling loop and record the trajectory.

    Args:
        args: Namespace providing ``noise_level``.
        z: Initial packed latents.
        progress_bar: Iterable of step indices into ``sigma_schedule``.
        sigma_schedule: Noise schedule passed on to ``flow_grpo_step``.
        transformer: Denoising model; switched to eval mode.
        encoder_hidden_states: Prompt embeddings (first dim is the batch).
        pooled_prompt_embeds: Pooled prompt embeddings.
        text_ids: Text position ids, repeated along dim 0 by
            ``encoder_hidden_states.shape[1]`` before being passed to the model.
        image_ids: Image position ids.
        grpo_sample: Must be truthy; otherwise nothing is sampled and None is
            returned (unchanged from the previous behaviour).

    Returns:
        ``(z, latents, all_latents, all_log_probs, all_means)`` where ``z`` is
        the final sample, ``latents`` the predicted clean sample of the last
        step, and the remaining three are stacked along dim 1: the initial
        latents plus each step's sample, the per-step log-probabilities and
        the per-step means.
    """
    if not grpo_sample:
        # Only the GRPO path is implemented.
        return None

    all_latents = [z]
    all_log_probs = []
    all_means = []  # Store means for reference in KL

    batch = encoder_hidden_states.shape[0]
    # These do not change between steps, so build them once.
    guidance = torch.tensor([3.5], device=z.device, dtype=torch.bfloat16)
    txt_ids = text_ids.repeat(encoder_hidden_states.shape[1], 1)  # B, L
    transformer.eval()

    for i in progress_bar:
        sigma = sigma_schedule[i]
        timestep_value = int(sigma * 1000)
        timesteps = torch.full([batch], timestep_value, device=z.device, dtype=torch.long)
        with torch.autocast("cuda", torch.bfloat16):
            pred = transformer(
                hidden_states=z,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timesteps / 1000,
                guidance=guidance,
                txt_ids=txt_ids,
                pooled_projections=pooled_prompt_embeds,
                img_ids=image_ids,
                joint_attention_kwargs=None,
                return_dict=False,
            )[0]

        # The std returned by the step is not needed for the stored reference.
        z, pred_original, log_prob, mean, _ = flow_grpo_step(
            pred,
            z.to(torch.float32),
            args.noise_level,  # Use noise_level instead of eta
            sigmas=sigma_schedule,
            index=i,
            prev_sample=None,
            grpo=True,
        )

        # The old code called `z.to(torch.bfloat16)` here and dropped the
        # result, so z has always stayed in the dtype returned by
        # flow_grpo_step. The dead cast is removed; the stored trajectory is
        # unchanged.
        all_latents.append(z)
        all_log_probs.append(log_prob)
        all_means.append(mean)

    latents = pred_original
    all_latents = torch.stack(all_latents, dim=1)  # (batch_size, num_steps + 1, ...)
    all_log_probs = torch.stack(all_log_probs, dim=1)  # (batch_size, num_steps)
    all_means = torch.stack(all_means, dim=1)  # (batch_size, num_steps, ...)

    return z, latents, all_latents, all_log_probs, all_means

        
def grpo_one_step(
            args,
            latents,
            pre_latents,
            encoder_hidden_states,
            pooled_prompt_embeds,
            text_ids,
            image_ids,
            transformer,
            timesteps,
            i,
            sigma_schedule,
):
    """Re-evaluate one stored transition with the model in train mode.

    Runs the transformer on ``latents`` under bfloat16 autocast, then asks
    ``flow_grpo_step`` for the log-probability of the stored next sample
    ``pre_latents`` at schedule index ``i``.

    Returns:
        ``(log_prob, mean, std)`` as produced by ``flow_grpo_step``.
    """
    transformer.train()
    with torch.autocast("cuda", torch.bfloat16):
        guidance_scale = torch.tensor(
            [3.5],
            device=latents.device,
            dtype=torch.bfloat16,
        )
        repeated_txt_ids = text_ids.repeat(encoder_hidden_states.shape[1], 1)  # B, L
        model_outputs = transformer(
            hidden_states=latents,
            encoder_hidden_states=encoder_hidden_states,
            timestep=timesteps / 1000,
            guidance=guidance_scale,
            txt_ids=repeated_txt_ids,
            pooled_projections=pooled_prompt_embeds,
            img_ids=image_ids.squeeze(0),
            joint_attention_kwargs=None,
            return_dict=False,
        )
        prediction = model_outputs[0]

    # Only the log-probability, mean and std are needed for the loss; the
    # sample and predicted clean sample returned by the step are dropped.
    step_result = flow_grpo_step(
        prediction,
        latents.to(torch.float32),
        args.noise_level,
        sigma_schedule,
        i,
        prev_sample=pre_latents.to(torch.float32),
        grpo=True,
    )
    log_prob, mean, std = step_result[2], step_result[3], step_result[4]
    return log_prob, mean, std


def sample_reference_model(
    args,
    device,
    transformer,
    vae,
    encoder_hidden_states,
    pooled_prompt_embeds,
    text_ids,
    reward_model,
    tokenizer,
    caption,
    preprocess_val,
):
    """Sample one image per prompt, decode it, save it and score it.

    Each prompt is processed on its own (batch size 1): latents are sampled
    with ``run_sample_step``, decoded with ``vae``, written to
    ``<args.training_save_path>/<rank>_<index>.png`` and scored by whichever
    of ``args.use_hpsv2`` / ``args.use_hpsv3`` / ``args.use_pickscore`` is set.

    Args:
        args: Namespace providing ``w``, ``h``, ``sampling_steps``,
            ``noise_level``, ``training_save_path`` and the reward flags.
        device: Device for the initial noise and the reward inputs.
        transformer: Denoising model.
        vae: Decoder; tiling is enabled on it.
        encoder_hidden_states: Prompt embeddings, one row per sample.
        pooled_prompt_embeds: Pooled prompt embeddings, one row per sample.
        text_ids: Text position ids, one row per sample.
        reward_model: Scoring model matching the enabled reward flag.
        tokenizer: Tokenizer (HPSv2) or processor (PickScore).
        caption: Sequence of prompt strings, one per sample.
        preprocess_val: Image preprocessing callable used for HPSv2.

    Returns:
        ``(all_rewards, all_latents, all_log_probs, sigma_schedule,
        all_image_ids, all_means_ref)``; the per-sample tensors are
        concatenated along dim 0, except the image ids, which are stacked.
    """
    w, h = args.w, args.h
    sample_steps = args.sampling_steps

    # --- CHANGED: flowGRPO shift logic ---
    sigma_schedule = torch.linspace(1, 0, args.sampling_steps + 1)
    image_seq_len = (w // 16) * (h // 16)
    # NOTE(review): mu is computed but never applied to sigma_schedule, so the
    # schedule stays linear. Confirm whether the shift was meant to be used.
    mu = calculate_shift(image_seq_len)

    assert_eq(
        len(sigma_schedule),
        sample_steps + 1,
        "sigma_schedule must have length sample_steps + 1",
    )

    SPATIAL_DOWNSAMPLE = 8
    IN_CHANNELS = 16
    latent_w, latent_h = w // SPATIAL_DOWNSAMPLE, h // SPATIAL_DOWNSAMPLE

    batch_size = 1
    num_samples = encoder_hidden_states.shape[0]
    batch_indices = torch.chunk(torch.arange(num_samples), num_samples // batch_size)

    # Per-call setup that used to be repeated for every sample.
    rank = int(os.environ["RANK"])
    image_processor = VaeImageProcessor(16)
    vae.enable_tiling()
    if args.training_save_path:
        os.makedirs(args.training_save_path, exist_ok=True)

    def calc_probs(processor, model, prompt, images, device):
        """Similarity between normalised text and image embeddings (first prompt row)."""
        # preprocess
        image_inputs = processor(
            images=images,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        ).to(device)
        text_inputs = processor(
            text=prompt,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            # embed
            image_embs = model.get_image_features(**image_inputs)
            image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)

            text_embs = model.get_text_features(**text_inputs)
            text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)

            # score
            scores = (text_embs @ image_embs.T)[0]

        return scores

    all_latents = []
    all_log_probs = []
    all_rewards = []
    all_image_ids = []
    all_means_ref = []  # Store means from sampling as reference

    for index, batch_idx in enumerate(batch_indices):
        batch_encoder_hidden_states = encoder_hidden_states[batch_idx]
        batch_pooled_prompt_embeds = pooled_prompt_embeds[batch_idx]
        batch_text_ids = text_ids[batch_idx]
        batch_caption = [caption[i] for i in batch_idx]

        input_latents = torch.randn(
            (len(batch_idx), IN_CHANNELS, latent_h, latent_w),
            device=device,
            dtype=torch.bfloat16,
        )

        input_latents_new = pack_latents(input_latents, len(batch_idx), IN_CHANNELS, latent_h, latent_w)
        image_ids = prepare_latent_image_ids(len(batch_idx), latent_h // 2, latent_w // 2, device, torch.bfloat16)
        with torch.no_grad():
            z, latents, batch_latents, batch_log_probs, batch_means = run_sample_step(
                args,
                input_latents_new,
                range(0, sample_steps),
                sigma_schedule,
                transformer,
                batch_encoder_hidden_states,
                batch_pooled_prompt_embeds,
                batch_text_ids,
                image_ids,
                True,
            )

        all_image_ids.append(image_ids)
        all_latents.append(batch_latents)
        all_log_probs.append(batch_log_probs)
        all_means_ref.append(batch_means)

        with torch.inference_mode():
            with torch.autocast("cuda", dtype=torch.bfloat16):
                latents = unpack_latents(latents, h, w, 8)
                # presumably the VAE scaling / shift factors — TODO confirm
                latents = (latents / 0.3611) + 0.1159
                image = vae.decode(latents, return_dict=False)[0]
                decoded_image = image_processor.postprocess(image)

        pil_img_obj = decoded_image[0]
        img_save_path = f"{args.training_save_path}/{rank}_{index}.png"
        pil_img_obj.save(img_save_path)

        if args.use_hpsv2:
            with torch.no_grad():
                image = preprocess_val(pil_img_obj).unsqueeze(0).to(device=device, non_blocking=True)
                # Process the prompt
                text = tokenizer([batch_caption[0]]).to(device=device, non_blocking=True)
                # Calculate the HPS
                with torch.amp.autocast('cuda'):
                    outputs = reward_model(image, text)
                    image_features, text_features = outputs["image_features"], outputs["text_features"]
                    logits_per_image = image_features @ text_features.T
                    hps_score = torch.diagonal(logits_per_image)
                all_rewards.append(hps_score)

        if args.use_hpsv3:
            with torch.no_grad():
                # This reward model is given the path of the image saved above.
                rewards = reward_model.reward([batch_caption[0]], image_paths=[img_save_path])
                scores = [reward[0].item() for reward in rewards]
                score_tensor = torch.tensor([scores[0]], dtype=torch.float32, device=device)
                all_rewards.append(score_tensor)

        if args.use_pickscore:
            score = calc_probs(tokenizer, reward_model, batch_caption, [pil_img_obj], device)
            all_rewards.append(score)

    all_latents = torch.cat(all_latents, dim=0)
    all_log_probs = torch.cat(all_log_probs, dim=0)
    all_rewards = torch.cat(all_rewards, dim=0)
    all_image_ids = torch.stack(all_image_ids, dim=0)
    all_means_ref = torch.cat(all_means_ref, dim=0)

    return all_rewards, all_latents, all_log_probs, sigma_schedule, all_image_ids, all_means_ref


def gather_tensor(tensor):
    """All-gather ``tensor`` from every rank and concatenate along dim 0.

    When no process group is initialised the input is returned unchanged.
    """
    if dist.is_initialized():
        buffers = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(buffers, tensor)
        return torch.cat(buffers, dim=0)
    return tensor

def train_one_step(
    args,
    device,
    transformer,
    vae,
    reward_model,
    tokenizer,
    optimizer,
    lr_scheduler,
    loader,
    noise_scheduler,
    max_grad_norm,
    preprocess_val,
    ema_handler,
):
    """Run one GRPO iteration: sample, score, compute advantages, update.

    A batch of prompts is drawn from ``loader`` (repeated
    ``args.num_generations`` times when ``args.use_group`` is set), images are
    sampled and scored by ``sample_reference_model``, advantages are
    normalised per group or over the gathered rewards, and the clipped policy
    loss (plus an optional KL term weighted by ``args.beta``) is
    back-propagated for a random subset of the timesteps of every sample.
    The optimizer steps every ``args.gradient_accumulation_steps`` samples.

    ``noise_scheduler`` and ``ema_handler`` are accepted but not used here.

    Returns:
        ``(total_loss, grad_norm, reward_mean, reward_var, avg_log_p_old,
        avg_log_p_new, avg_ratio, avg_kl)``. ``grad_norm`` is 0.0 when no
        optimizer step was taken in this call.
    """
    total_loss = 0.0
    avg_log_p_old = 0.0
    avg_log_p_new = 0.0
    avg_ratio = 0.0
    avg_kl = 0.0  # Track KL
    total_steps = 0
    # Stays None when the batch never reaches an accumulation boundary; the
    # return statement below used to raise UnboundLocalError in that case.
    grad_norm = None
    optimizer.zero_grad()
    (
        encoder_hidden_states,
        pooled_prompt_embeds,
        text_ids,
        caption,
    ) = next(loader)

    if args.use_group:
        def repeat_tensor(tensor):
            if tensor is None:
                return None
            return torch.repeat_interleave(tensor, args.num_generations, dim=0)

        encoder_hidden_states = repeat_tensor(encoder_hidden_states)
        pooled_prompt_embeds = repeat_tensor(pooled_prompt_embeds)
        text_ids = repeat_tensor(text_ids)

        if isinstance(caption, str):
            caption = [caption] * args.num_generations
        elif isinstance(caption, list):
            caption = [item for item in caption for _ in range(args.num_generations)]
        else:
            raise ValueError(f"Unsupported caption type: {type(caption)}")

    reward, all_latents, all_log_probs, sigma_schedule, all_image_ids, all_means_ref = sample_reference_model(
        args,
        device,
        transformer,
        vae,
        encoder_hidden_states,
        pooled_prompt_embeds,
        text_ids,
        reward_model,
        tokenizer,
        caption,
        preprocess_val,
    )
    batch_size = all_latents.shape[0]
    timestep_value = [int(sigma * 1000) for sigma in sigma_schedule][:args.sampling_steps]
    timestep_values = [timestep_value[:] for _ in range(batch_size)]
    device = all_latents.device
    timesteps = torch.tensor(timestep_values, device=all_latents.device, dtype=torch.long)

    # The last sampling step is dropped from every per-step entry.
    samples = {
        "timesteps": timesteps.detach().clone()[:, :-1],
        "latents": all_latents[:, :-1][:, :-1],
        "next_latents": all_latents[:, 1:][:, :-1],
        "log_probs": all_log_probs[:, :-1],
        "means_ref": all_means_ref[:, :-1],  # Reference means for KL
        "rewards": reward.to(torch.float32),
        "image_ids": all_image_ids,
        "text_ids": text_ids,
        "encoder_hidden_states": encoder_hidden_states,
        "pooled_prompt_embeds": pooled_prompt_embeds,
    }
    gathered_reward = gather_tensor(samples["rewards"])

    # Advantage calculation
    if args.use_group:
        n = len(samples["rewards"]) // (args.num_generations)
        advantages = torch.zeros_like(samples["rewards"])
        for group in range(n):
            start_idx = group * args.num_generations
            end_idx = (group + 1) * args.num_generations
            group_rewards = samples["rewards"][start_idx:end_idx]
            group_mean = group_rewards.mean()
            group_std = group_rewards.std() + 1e-8
            advantages[start_idx:end_idx] = (group_rewards - group_mean) / group_std
        samples["advantages"] = advantages
    else:
        advantages = (samples["rewards"] - gathered_reward.mean()) / (gathered_reward.std() + 1e-8)
        samples["advantages"] = advantages

    # Independent random order of the timesteps for every sample.
    perms = torch.stack(
        [
            torch.randperm(len(samples["timesteps"][0]))
            for _ in range(batch_size)
        ]
    ).to(device)
    for key in ["timesteps", "latents", "next_latents", "log_probs", "means_ref"]:
        samples[key] = samples[key][
            torch.arange(batch_size).to(device)[:, None],
            perms,
        ]
    samples_batched = {
        k: v.unsqueeze(1)
        for k, v in samples.items()
    }

    # One dict per sample, each value keeping a leading dimension of size 1.
    samples_batched_list = [
        dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())
    ]
    train_timesteps = int(len(samples["timesteps"][0]) * args.timestep_fraction)

    clip_range = args.clip_range
    adv_clip_max = args.adv_clip_max

    for i, sample in enumerate(samples_batched_list):
        # The clipped advantage is the same for every timestep of a sample.
        advantages = torch.clamp(
            sample["advantages"],
            -adv_clip_max,
            adv_clip_max,
        )

        for t in range(train_timesteps):
            # perms[i][t] is the original schedule index of the t-th shuffled step.
            new_log_probs, prev_sample_mean, std_dev_t = grpo_one_step(
                args,
                sample["latents"][:, t],
                sample["next_latents"][:, t],
                sample["encoder_hidden_states"],
                sample["pooled_prompt_embeds"],
                sample["text_ids"],
                sample["image_ids"],
                transformer,
                sample["timesteps"][:, t],
                perms[i][t],
                sigma_schedule,
            )

            old_log_probs = sample["log_probs"][:, t]
            ratio = torch.exp(new_log_probs - old_log_probs)

            with torch.no_grad():
                avg_log_p_old += old_log_probs.mean().item()
                avg_log_p_new += new_log_probs.detach().mean().item()
                avg_ratio += ratio.detach().mean().item()
                total_steps += 1

            unclipped_loss = -advantages * ratio
            clipped_loss = -advantages * torch.clamp(
                ratio,
                1.0 - clip_range,
                1.0 + clip_range,
            )
            policy_loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss))

            kl_loss = 0.0
            if args.beta > 0:
                # Reference mean recorded during the sampling phase.
                prev_sample_mean_ref = sample["means_ref"][:, t].to(torch.float32)

                # KL(P || Q) = (mu_p - mu_q)^2 / (2 * sigma^2), assuming the
                # same fixed sigma (std_dev_t) for both distributions.
                kl_term = ((prev_sample_mean.to(torch.float32) - prev_sample_mean_ref) ** 2).mean(dim=(1, 2), keepdim=True) / (2 * (std_dev_t ** 2 + 1e-12))
                kl_loss = torch.mean(kl_term)
                avg_kl += kl_loss.item()

            loss = (policy_loss + args.beta * kl_loss) / (args.gradient_accumulation_steps * train_timesteps)

            loss.backward()
            avg_loss = loss.detach().clone()
            dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
            total_loss += avg_loss.item()

        # NOTE(review): gradients of samples after the last full accumulation
        # group are never applied; the zero_grad() at the top of the next call
        # discards them. Confirm this is intended.
        if (i + 1) % args.gradient_accumulation_steps == 0:
            grad_norm = transformer.clip_grad_norm_(max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

    if total_steps > 0:
        avg_log_p_old /= total_steps
        avg_log_p_new /= total_steps
        avg_ratio /= total_steps
        avg_kl /= total_steps

    grad_norm_value = grad_norm.item() if grad_norm is not None else 0.0
    return total_loss, grad_norm_value, gathered_reward.mean().item(), gathered_reward.var(unbiased=False).item(), avg_log_p_old, avg_log_p_new, avg_ratio, avg_kl


# --- Modified Sampling Function with TopK mechanism ---
def run_ode_sampling_and_decode_final(
        args,
        progress_bar,
        sigma_schedule,
        transformer,
        encoder_hidden_states,
        pooled_prompt_embeds,
        text_ids,
        image_ids,
        vae,
        reward_model, tokenizer, preprocess_val,
        captions,
        device,
        global_start_idx,
        save_dir=None,
        topk=1
    ):
    """Sample ``topk`` images with a deterministic Euler loop and keep the best.

    Every run starts from fresh noise, integrates
    ``x <- x + (sigma_next - sigma) * model_output`` over ``progress_bar``,
    decodes the first image of the batch and scores it with each reward
    enabled in ``args``. The run with the highest score of the first enabled
    reward (HPSv2, then HPSv3, then PickScore) is selected; with no reward
    enabled the first run is selected.

    Args:
        args: Namespace providing ``w``, ``h`` and the reward flags.
        progress_bar: Iterable of step indices; iterated once per run, so it
            must be re-iterable when ``topk > 1``.
        sigma_schedule: Noise schedule indexed by ``i`` and ``i + 1``.
        transformer: Denoising model; switched to eval mode.
        encoder_hidden_states: Prompt embeddings (first dim is the batch).
        pooled_prompt_embeds: Pooled prompt embeddings.
        text_ids: Text position ids.
        image_ids: Image position ids.
        vae: Decoder for the final latents.
        reward_model: Scoring model matching the enabled reward flag.
        tokenizer: Tokenizer (HPSv2) or processor (PickScore).
        preprocess_val: Image preprocessing callable used for HPSv2.
        captions: Prompt strings for this batch.
        device: Device for the noise and the reward inputs.
        global_start_idx: Offset used in the saved file names.
        save_dir: Directory for images and the captions file; nothing is
            written when it is falsy.
        topk: Number of independent runs.

    Returns:
        A list with the selected run's scores as Python numbers, ``[]`` when
        no reward is enabled, or ``{}`` when ``topk`` produced no runs.

    Raises:
        ValueError: If ``args.use_hpsv3`` is set without ``save_dir``; that
            reward model is given the path of the saved image.
    """
    IN_CHANNELS = 16
    w, h = args.w, args.h
    SPATIAL_DOWNSAMPLE = 8
    latent_w, latent_h = w // SPATIAL_DOWNSAMPLE, h // SPATIAL_DOWNSAMPLE

    rank = dist.get_rank()
    B = encoder_hidden_states.shape[0]

    # Per-call setup that used to be repeated for every run / step.
    image_processor = VaeImageProcessor(16)
    guidance = torch.tensor([3.5], device=device, dtype=torch.bfloat16)
    txt_ids = text_ids.repeat(encoder_hidden_states.shape[1], 1)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    def calc_probs(processor, model, prompt, images, device):
        """Similarity between normalised text and image embeddings (first prompt row)."""
        image_inputs = processor(
            images=images,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        ).to(device)
        text_inputs = processor(
            text=prompt,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            image_embs = model.get_image_features(**image_inputs)
            image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
            text_embs = model.get_text_features(**text_inputs)
            text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
            scores = (text_embs @ image_embs.T)[0]
        return scores

    all_final_results = []

    for k in range(topk):
        input_latents = torch.randn(
            (B, IN_CHANNELS, latent_h, latent_w),
            device=device,
            dtype=torch.bfloat16,
        )
        input_latents_packed = pack_latents(input_latents, B, IN_CHANNELS, latent_h, latent_w)

        transformer.eval()
        for i in progress_bar:
            sigma = sigma_schedule[i]
            next_sigma = sigma_schedule[i + 1]

            timestep_value = int(sigma * 1000)
            timesteps = torch.full([B], timestep_value, device=device, dtype=torch.long)

            with torch.no_grad():
                with torch.autocast("cuda", torch.bfloat16):
                    model_output = transformer(
                        hidden_states=input_latents_packed,
                        encoder_hidden_states=encoder_hidden_states,
                        timestep=timesteps / 1000,
                        guidance=guidance,
                        txt_ids=txt_ids,
                        pooled_projections=pooled_prompt_embeds,
                        img_ids=image_ids,
                        return_dict=False,
                    )[0]

                # Plain deterministic Euler (ODE) step.
                dsigma = next_sigma - sigma
                input_latents_packed = input_latents_packed + dsigma * model_output

        with torch.inference_mode():
            with torch.autocast("cuda", dtype=torch.bfloat16):
                unpacked_latents = unpack_latents(input_latents_packed, args.h, args.w, 8)
                # presumably the VAE scaling / shift factors — TODO confirm
                unpacked_latents = (unpacked_latents / 0.3611) + 0.1159
                decoded_images_tensor = vae.decode(unpacked_latents, return_dict=False)[0]
                decoded_image = image_processor.postprocess(decoded_images_tensor)
        pil_img_obj = decoded_image[0]

        # Previously the path was joined even when save_dir was None, which
        # raised TypeError; now the per-run image is only written when a
        # directory was given.
        img_save_path = None
        if save_dir:
            img_save_path = os.path.join(save_dir, f"{global_start_idx + rank}_{rank}.jpg")
            pil_img_obj.save(img_save_path)

        step_rewards = {}
        batch_caption = captions

        if args.use_hpsv2:
            with torch.no_grad():
                image = preprocess_val(pil_img_obj).unsqueeze(0).to(device=device, non_blocking=True)
                text = tokenizer([batch_caption[0]]).to(device=device, non_blocking=True)
                with torch.amp.autocast('cuda'):
                    outputs = reward_model(image, text)
                    image_features, text_features = outputs["image_features"], outputs["text_features"]
                    logits_per_image = image_features @ text_features.T
                    hps_score = torch.diagonal(logits_per_image)
            step_rewards['hpsv2'] = [hps_score]

        if args.use_hpsv3:
            if img_save_path is None:
                raise ValueError("use_hpsv3 requires save_dir: the reward model is given the saved image path")
            with torch.no_grad():
                rewards = reward_model.reward([batch_caption[0]], image_paths=[img_save_path])
                scores = [reward[0].item() for reward in rewards]
            step_rewards['hpsv3'] = [scores[0]]

        if args.use_pickscore:
            score = calc_probs(tokenizer, reward_model, batch_caption, [pil_img_obj], device)
            step_rewards['pickscore'] = [score]

        all_final_results.append({
            'pil_images': [pil_img_obj],
            'rewards': step_rewards,
            'run_id': k
        })

    if not all_final_results:
        return {}

    # The first enabled reward decides which run is kept.
    selection_model = None
    if args.use_hpsv2 and 'hpsv2' in all_final_results[0]['rewards']:
        selection_model = 'hpsv2'
    elif args.use_hpsv3 and 'hpsv3' in all_final_results[0]['rewards']:
        selection_model = 'hpsv3'
    elif args.use_pickscore and 'pickscore' in all_final_results[0]['rewards']:
        selection_model = 'pickscore'

    if selection_model is not None:
        def get_avg_score(result):
            return sum(result['rewards'][selection_model]) / len(result['rewards'][selection_model])
        all_final_results.sort(key=get_avg_score, reverse=True)
    selected_result = all_final_results[0]

    if save_dir:
        for b_idx, img in enumerate(selected_result['pil_images']):
            local_id = global_start_idx + b_idx
            save_path = os.path.join(save_dir, f"{local_id:05d}_rank{rank}.jpg")
            img.save(save_path, quality=95)
        captions_file = os.path.join(save_dir, f"captions_rank{rank}.txt")
        with open(captions_file, 'a', encoding='utf-8') as f:
            for b_idx, caption in enumerate(captions):
                local_id = global_start_idx + b_idx
                f.write(f"{local_id:05d}: {caption}\n")

    if selection_model is None:
        # No reward was computed; the old code raised KeyError(None) here.
        return []
    return [r.cpu().item() if torch.is_tensor(r) else r for r in selected_result['rewards'][selection_model]]


@torch.no_grad()
def log_validation(args, step, valid_dataloader, transformer, vae, reward_model, processor, preprocess_val, device):
    """Generate one image per validation prompt, score it and summarise the rewards.

    Every rank walks its own ``valid_dataloader`` (batch size must be 1),
    draws fresh Gaussian latents, runs ``run_ode_sampling_and_decode_final``
    with a fixed 25-step linear sigma schedule going from 1 to 0, and then
    gathers the per-rank rewards and captions from all ranks after each batch.
    Rank 0 accumulates the rewards and, if any were collected, writes
    ``validation_results_step_{step}.json`` into ``args.valid_save_path``.

    Args:
        args: Parsed CLI namespace. ``w``, ``h`` and ``valid_save_path`` are
            read here; the namespace is also forwarded to the sampling helper.
        step: Current training step; used in the progress-bar text and in the
            names of the image directory and the results file.
        valid_dataloader: Iterable yielding
            ``(encoder_hidden_states, pooled_prompt_embeds, text_ids, captions)``.
        transformer: Forwarded unchanged to the sampling helper.
        vae: Forwarded unchanged to the sampling helper.
        reward_model: Forwarded unchanged to the sampling helper.
        processor: Forwarded unchanged to the sampling helper.
        preprocess_val: Forwarded unchanged to the sampling helper.
        device: Device on which the initial latents and latent image ids are
            created.

    Returns:
        On rank 0 the mean of all gathered rewards (``0`` when none were
        collected); ``None`` on every other rank.
    """
    w, h = args.w, args.h
    SPATIAL_DOWNSAMPLE = 8
    IN_CHANNELS = 16
    latent_w, latent_h = w // SPATIAL_DOWNSAMPLE, h // SPATIAL_DOWNSAMPLE
    valid_sampling_steps = 25
    sigma_schedule = torch.linspace(1, 0, valid_sampling_steps + 1)

    world_size = int(os.environ["WORLD_SIZE"])
    rank = dist.get_rank()

    # The image directory does not depend on the batch, so build the path once.
    save_dir = os.path.join(args.valid_save_path, f"step_{step}_decoded_images")

    # Reward bookkeeping and timing only exist on rank 0, which does the reporting.
    if rank == 0:
        all_rewards_this_validation = []
        validation_start_time = time.time()

    total_prompts_analyzed = 0
    for batch_idx, batch in enumerate(valid_dataloader):
        (encoder_hidden_states, pooled_prompt_embeds, text_ids, captions) = batch
        current_bs = encoder_hidden_states.shape[0]
        assert current_bs == 1
        input_latents = torch.randn(
            (current_bs, IN_CHANNELS, latent_h, latent_w),
            device=device,
            dtype=torch.bfloat16,
        )

        # The packed latents are computed but not passed to the sampling helper
        # below; kept as in the original code.
        input_latents_packed = pack_latents(input_latents, current_bs, IN_CHANNELS, latent_h, latent_w)
        # Latent image ids are built on a grid of half the latent resolution.
        image_ids = prepare_latent_image_ids(current_bs, latent_h // 2, latent_w // 2, device, torch.bfloat16)

        progress_bar = tqdm(
            range(0, valid_sampling_steps),
            desc=f"Validation step {step} - Batch {batch_idx} Sampling",
            disable=rank != 0,
        )
        step_rewards = run_ode_sampling_and_decode_final(
            args,
            progress_bar,
            sigma_schedule,
            transformer,
            encoder_hidden_states,
            pooled_prompt_embeds,
            text_ids,
            image_ids,
            vae,
            reward_model, processor, preprocess_val,
            captions,
            device,
            global_start_idx=total_prompts_analyzed,
            save_dir=save_dir,
        )

        # Collect every rank's rewards and captions. Both gathers are
        # collectives, so all ranks must reach them in the same order.
        all_step_rewards_list = [None for _ in range(world_size)]
        all_captions_list = [None for _ in range(world_size)]
        if dist.is_initialized():
            dist.all_gather_object(all_step_rewards_list, step_rewards)
            dist.all_gather_object(all_captions_list, captions)
        else:
            all_step_rewards_list = [step_rewards]
            all_captions_list = [captions]

        if rank == 0:
            # Only lists and plain numbers count as rewards; any other payload
            # gathered from a rank is skipped.
            batch_rewards = []
            for rank_rewards in all_step_rewards_list:
                if isinstance(rank_rewards, list):
                    batch_rewards.extend(rank_rewards)
                elif isinstance(rank_rewards, (int, float)):
                    batch_rewards.append(rank_rewards)
            all_rewards_this_validation.extend(batch_rewards)

        # Advance the global prompt counter by the number of captions seen on
        # all ranks, so the next batch's global_start_idx is the same everywhere.
        total_prompts_analyzed += sum(
            len(caps) if isinstance(caps, list) else 1
            for caps in all_captions_list
            if isinstance(caps, (list, str))
        )

        if dist.is_initialized():
            dist.barrier()

    if rank != 0:
        return None

    validation_end_time = time.time()
    validation_duration = validation_end_time - validation_start_time

    print(f"Validation at step {step} completed in {validation_duration:.2f} seconds.")

    avg_reward = 0
    if all_rewards_this_validation:
        avg_reward = sum(all_rewards_this_validation) / len(all_rewards_this_validation)
        std_reward = np.std(all_rewards_this_validation) if len(all_rewards_this_validation) > 1 else 0.0

        print(f"  Average Reward: {avg_reward:.4f} ± {std_reward:.4f}")
        print(f"  Total prompts analyzed: {total_prompts_analyzed}")

        results_data = {
            "step": step,
            "avg_reward": avg_reward,
            "std_reward": std_reward,
            "total_prompts": total_prompts_analyzed,
            "duration_seconds": validation_duration,
            "all_rewards": all_rewards_this_validation,
        }

        # Do not rely on the sampling helper having created the output
        # directory as a side effect of saving images.
        os.makedirs(args.valid_save_path, exist_ok=True)
        results_file = os.path.join(args.valid_save_path, f"validation_results_step_{step}.json")
        with open(results_file, "w") as f:
            json.dump(results_data, f, indent=4, ensure_ascii=False)
        print(f"Saved validation results to {results_file}")
    else:
        print("  No rewards were collected during validation.")

    return avg_reward
        

def main(args):
    """Run distributed flow-GRPO fine-tuning of the Flux transformer.

    Sets up the NCCL process group and sequence-parallel state, loads the
    reward model selected by the ``--use_*`` flags, wraps the Flux transformer
    in FSDP, builds the optimizer, LR scheduler and train/validation
    dataloaders, and then runs the training loop with periodic checkpointing,
    validation and (on rank 0) Weights & Biases logging.

    Args:
        args: Parsed command-line namespace produced by the argument parser at
            the bottom of this file.

    Environment:
        ``LOCAL_RANK``, ``RANK`` and ``WORLD_SIZE`` must be set.
    """
    # NOTE(review): TF32 matmuls are enabled unconditionally; `args.allow_tf32`
    # is not consulted in this function.
    torch.backends.cuda.matmul.allow_tf32 = True

    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)
    # `device` is the integer index of the current CUDA device.
    device = torch.cuda.current_device()
    initialize_sequence_parallel_state(args.sp_size)

    # Seed each rank with a different value (seed + rank) when a seed is given.
    if args.seed is not None:
        set_seed(args.seed + rank)

    # Only rank 0 creates the output directory.
    if rank <= 0 and args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    # Reward-model handles; they stay None unless one of the --use_* flags
    # below fills them in.
    preprocess_val = None
    reward_model = None
    processor = None

    # --- Load Reward Models ---
    # The three blocks below are independent `if`s that all assign
    # `reward_model`, so when several flags are set the last enabled one wins.
    if args.use_hpsv2:
        from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
        from typing import Union
        import huggingface_hub
        from hpsv2.utils import root_path, hps_version_map
        def initialize_model():
            # Build the ViT-H-14 open_clip model from the local ./hps_ckpt
            # weights and keep only the model and its validation transform.
            model_dict = {}
            model, preprocess_train, preprocess_val = create_model_and_transforms(
                'ViT-H-14',
                './hps_ckpt/open_clip_pytorch_model.bin',
                precision='amp',
                device=device,
                jit=False,
                force_quick_gelu=False,
                force_custom_text=False,
                force_patch_dropout=False,
                force_image_size=None,
                pretrained_image=False,
                image_mean=None,
                image_std=None,
                light_augmentation=True,
                aug_cfg={},
                output_dict=True,
                with_score_predictor=False,
                with_region_predictor=False
            )
            model_dict['model'] = model
            model_dict['preprocess_val'] = preprocess_val
            return model_dict
        model_dict = initialize_model()
        model = model_dict['model']
        preprocess_val = model_dict['preprocess_val']
        # Overwrite the base weights with the HPS v2.1 checkpoint.
        cp = "./hps_ckpt/HPS_v2.1_compressed.pt"
        checkpoint = torch.load(cp, map_location=f'cuda:{device}')
        model.load_state_dict(checkpoint['state_dict'])
        processor = get_tokenizer('ViT-H-14')
        reward_model = model.to(device)
        reward_model.eval()

    if args.use_hpsv3:
        from hpsv3 import HPSv3RewardInferencer
        # NOTE(review): config_path and checkpoint_path are empty strings here
        # and must be filled in for the local environment.
        inferencer = HPSv3RewardInferencer( 
            config_path="", 
            checkpoint_path="", 
            device=device
            )
        processor = None
        preprocess_val = None
        reward_model = inferencer

    if args.use_pickscore:
        from transformers import AutoProcessor, AutoModel
        processor_name_or_path = "./pretrained_models/CLIP-ViT-H-14-laion2B-s32B-b79K"
        model_pretrained_name_or_path = "./pretrained_models/PickScore_v1"
        processor = AutoProcessor.from_pretrained(processor_name_or_path)
        reward_model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(device)

    main_print(f"--> loading model from {args.pretrained_model_name_or_path}")
    
    # Load the Flux transformer in float32 from the "transformer" subfolder.
    transformer = FluxTransformer2DModel.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="transformer",
            torch_dtype = torch.float32
    )

    # FSDP setup: the helper returns the wrapper kwargs and the module classes
    # later handed to activation checkpointing.
    fsdp_kwargs, no_split_modules = get_dit_fsdp_kwargs(
        transformer,
        args.fsdp_sharding_startegy,
        False,
        args.use_cpu_offload,
        args.master_weight_type,
    )
    
    transformer = FSDP(transformer, **fsdp_kwargs,)

    # EMA setup: created from the FSDP-wrapped model, before activation
    # checkpointing is applied.
    ema_handler = None
    if args.use_ema:
        ema_handler = FSDP_EMA(transformer, args.ema_decay, rank)

    if args.gradient_checkpointing:
        apply_fsdp_checkpointing(
            transformer, no_split_modules, args.selective_checkpointing
        )
    
    # Load the VAE in bfloat16; it is not wrapped in FSDP.
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        torch_dtype = torch.bfloat16,
    ).to(device)

    main_print(
        f"--> Initializing FSDP with sharding strategy: {args.fsdp_sharding_startegy}"
    )
    main_print(f"--> model loaded")

    # Set model as trainable
    transformer.train()

    noise_scheduler = None # Flux uses flow matching, handled manually in loop

    # Optimize only the parameters that require gradients.
    params_to_optimize = transformer.parameters()
    params_to_optimize = list(filter(lambda p: p.requires_grad, params_to_optimize))

    optimizer = torch.optim.AdamW(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(0.9, 0.999),
        weight_decay=args.weight_decay,
        eps=1e-8,
    )

    # Training always starts from step 0 here; `args.resume_from_checkpoint`
    # is not read in this function.
    init_steps = 0
    main_print(f"optimizer: {optimizer}")

    # NOTE(review): num_training_steps is hard-coded to 1,000,000 rather than
    # derived from args.max_train_steps — confirm this is intended for
    # schedules that decay over the training horizon.
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=1000000,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
        last_epoch=init_steps - 1,
    )

    # --- Dataloaders ---
    # Training data is shuffled and sharded across ranks; incomplete final
    # batches are dropped.
    train_dataset = LatentDataset(args.data_json_path, args.num_latent_t, args.cfg)
    sampler = DistributedSampler(
            train_dataset, rank=rank, num_replicas=world_size, shuffle=True, seed=args.sampler_seed
        )
    train_dataloader = DataLoader(
        train_dataset,
        sampler=sampler,
        collate_fn=latent_collate_function,
        pin_memory=True,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
        drop_last=True,
    )

    # Validation data keeps its order and uses batch_size=1, which
    # log_validation asserts.
    valid_dataset = LatentDataset(args.valid_data_json_path, args.num_latent_t, args.cfg)
    valid_sampler = DistributedSampler(
        valid_dataset, rank=rank, num_replicas=world_size, shuffle=False
    )
    valid_dataloader = DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        collate_fn=latent_collate_function,
        pin_memory=True,
        batch_size=1,
        num_workers=1,
        drop_last=False,
    )

    # WandB setup: only rank 0 opens a run.
    if rank <= 0:
        wandb.init(project=args.wandb_project, config=args, name=args.wandb_name)


    # Train stats. The value below is only printed; because of the true
    # division it is a float.
    total_batch_size = (
        world_size
        * args.gradient_accumulation_steps
        / args.sp_size
        * args.train_sp_batch_size
    )
    main_print("***** Running training *****")
    main_print(f"  Num examples = {len(train_dataset)}")
    main_print(f"  Dataloader size = {len(train_dataloader)}")
    main_print(f"  Instantaneous batch size per device = {args.train_batch_size}")
    main_print(f"  Total train batch size (w. accumulation) = {total_batch_size}")
    main_print(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    main_print(f"  Total optimization steps = {args.max_train_steps}")
    main_print(f"  Beta (KL Weight) = {args.beta}")
    main_print(f"  Noise Level = {args.noise_level}")

    # NOTE(review): the bar's total is a fixed 100000, independent of
    # args.max_train_steps. It is shown on local rank 0 of every node.
    progress_bar = tqdm(
        range(0, 100000),
        initial=init_steps,
        desc="Steps",
        disable=local_rank > 0,
    )

    # Wrap the training dataloader for sequence parallelism; the training step
    # pulls its batches from this object.
    loader = sp_parallel_dataloader_wrapper(
        train_dataloader,
        device,
        args.train_batch_size,
        args.sp_size,
        args.train_sp_batch_size,
    )

    # Rolling window of the last 100 step durations for the average step time.
    step_times = deque(maxlen=100)

    # NOTE(review): the inner loop restarts at init_steps + 1 on every epoch, so
    # up to 1000 * max_train_steps training steps are run, and the step numbers
    # passed to checkpointing, validation and wandb repeat from epoch to epoch.
    # Confirm whether the run is expected to stop after max_train_steps.
    for epoch in range(1000): # Infinite loop essentially
        if isinstance(sampler, DistributedSampler):
            sampler.set_epoch(epoch)

        for step in range(init_steps+1, args.max_train_steps+1):
            start_time = time.time()
            
            # Save a checkpoint (and the EMA weights, if enabled) before
            # training on this step, then synchronise all ranks.
            if step % args.checkpointing_steps == 0:
                save_checkpoint(transformer, rank, args.output_dir, step, epoch)
                if args.use_ema:
                    save_ema_checkpoint(ema_handler, rank, args.output_dir, step, epoch, dict(transformer.config))
                dist.barrier()
            
            # Validation runs on the first step and then every
            # args.validation_steps steps; log_validation returns the average
            # reward on rank 0 and None elsewhere.
            if step % args.validation_steps == 0 or step == 1:
                valid_reward_avg = log_validation(
                    args,
                    step,
                    valid_dataloader,
                    transformer,
                    vae,
                    reward_model,
                    processor,
                    preprocess_val,
                    device
                )
                dist.barrier()

            # Train step: returns the loss, gradient norm and the statistics
            # logged below.
            loss, grad_norm, log_reward, log_var, log_p_old, log_p_new, log_ratio, log_kl = train_one_step(
                args,
                device, 
                transformer,
                vae,
                reward_model,
                processor,
                optimizer,
                lr_scheduler,
                loader,
                noise_scheduler,
                args.max_grad_norm,
                preprocess_val,
                ema_handler,
            )

            # Fold the freshly updated weights into the EMA copy.
            if args.use_ema and ema_handler:
                ema_handler.update(transformer)
    
            # The measured step time includes any checkpointing and validation
            # done at the top of this iteration.
            step_time = time.time() - start_time
            step_times.append(step_time)
            avg_step_time = sum(step_times) / len(step_times)
    
            progress_bar.set_postfix(
                {
                    "loss": f"{loss:.4f}",
                    "grad_norm": f"{grad_norm:.4f}",
                    "reward": f"{log_reward:.4f}",
                    "kl": f"{log_kl:.4f}",
                }
            )
            progress_bar.update(1)
            
            if rank <= 0:
                log_dict = {
                    "train_loss": loss,
                    "learning_rate": lr_scheduler.get_last_lr()[0],
                    "step_time": step_time,
                    "avg_step_time": avg_step_time,
                    "grad_norm": grad_norm,
                    "reward": log_reward,
                    "var": log_var,
                    "log_p_old": log_p_old,
                    "log_p_new": log_p_new,
                    "ratio": log_ratio,
                    "kl_divergence": log_kl, # Log KL for flowGRPO
                }
                # Attach the validation reward only on steps where validation
                # actually ran (same condition as above).
                if (step % args.validation_steps == 0 or step == 1) and valid_reward_avg is not None:
                    log_dict["valid_reward_avg"] = valid_reward_avg

                wandb.log(log_dict, step=step)

    # Tear down the sequence-parallel group if one is active.
    if get_sequence_parallel_state():
        destroy_sequence_parallel_group()


if __name__ == "__main__":
    # Script entry point: declare every command-line option, parse the
    # command line and hand the resulting namespace to main().
    parser = argparse.ArgumentParser()
    add_arg = parser.add_argument

    # --- Dataset, dataloader and output locations ---
    add_arg("--wandb_project", type=str, required=True)
    add_arg("--data_json_path", type=str, required=True)
    add_arg("--valid_data_json_path", type=str, required=True)
    add_arg("--training_save_path", type=str, required=True)
    add_arg("--valid_save_path", type=str, required=True)
    add_arg("--output_dir", type=str, default=None,
            help="The output directory where the model predictions and checkpoints will be written.")
    add_arg("--dataloader_num_workers", type=int, default=10,
            help="Number of subprocesses to use for data loading.")
    add_arg("--train_batch_size", type=int, default=16,
            help="Batch size (per device) for the training dataloader.")
    add_arg("--num_latent_t", type=int, default=1, help="number of latent frames")

    # --- Text encoder, VAE and diffusion model ---
    add_arg("--pretrained_model_name_or_path", type=str)
    add_arg("--dit_model_name_or_path", type=str, default=None)
    add_arg("--vae_model_path", type=str, default=None, help="vae model.")
    add_arg("--cache_dir", type=str, default="./cache_dir")

    # --- Diffusion settings ---
    add_arg("--ema_decay", type=float, default=0.995)
    add_arg("--ema_start_step", type=int, default=0)
    add_arg("--cfg", type=float, default=0.0)
    add_arg("--precondition_outputs", action="store_true",
            help="Whether to precondition the outputs of the model.")

    # --- Validation and logging ---
    add_arg("--seed", type=int, default=None, help="A seed for reproducible training.")
    add_arg("--checkpointing_steps", type=int, default=500,
            help="Save a checkpoint every X updates.")
    add_arg("--resume_from_checkpoint", type=str, default=None)
    add_arg("--logging_dir", type=str, default="logs")

    # --- Optimizer, scheduler and training ---
    add_arg("--max_train_steps", type=int, default=10000)
    add_arg("--gradient_accumulation_steps", type=int, default=1)
    add_arg("--learning_rate", type=float, default=1e-4)
    add_arg("--lr_warmup_steps", type=int, default=10)
    add_arg("--max_grad_norm", default=2.0, type=float, help="Max gradient norm.")
    add_arg("--gradient_checkpointing", action="store_true")
    add_arg("--selective_checkpointing", type=float, default=1.0)
    add_arg("--allow_tf32", action="store_true")
    add_arg("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"])
    add_arg("--use_cpu_offload", action="store_true")

    # --- Sequence parallelism and FSDP ---
    add_arg("--sp_size", type=int, default=1, help="For sequence parallel")
    add_arg("--train_sp_batch_size", type=int, default=1)
    add_arg("--fsdp_sharding_startegy", default="full")

    # --- LR scheduler ---
    add_arg("--lr_scheduler", type=str, default="constant_with_warmup")
    add_arg("--lr_num_cycles", type=int, default=1)
    add_arg("--lr_power", type=float, default=1.0)
    add_arg("--weight_decay", type=float, default=0.01)
    add_arg("--master_weight_type", type=str, default="fp32")

    # --- GRPO / flowGRPO sampling ---
    add_arg("--h", type=int, default=512)
    add_arg("--w", type=int, default=512)
    add_arg("--t", type=int, default=1)
    add_arg("--sampling_steps", type=int, default=28)
    add_arg("--sampler_seed", type=int, default=42)
    add_arg("--use_group", action="store_true", default=False)
    add_arg("--num_generations", type=int, default=16)

    # --- Reward model selection ---
    add_arg("--use_hpsv2", action="store_true", default=False)
    add_arg("--use_hpsv3", action="store_true", default=False)
    add_arg("--use_pickscore", action="store_true", default=False)

    # --- flowGRPO-specific knobs ---
    add_arg("--beta", type=float, default=0.04,
            help="Coefficient for KL divergence penalty")
    add_arg("--noise_level", type=float, default=1.0,
            help="Noise level for flow matching SDE (replaces eta)")

    # --- Miscellaneous ---
    add_arg("--timestep_fraction", type=float, default=1.0)
    add_arg("--clip_range", type=float, default=1e-4)
    add_arg("--adv_clip_max", type=float, default=5.0)
    add_arg("--use_ema", action="store_true")
    add_arg("--wandb_name", type=str, default="flux_flow_grpo")
    add_arg("--validation_steps", type=int, default=100)

    args = parser.parse_args()
    main(args)