import argparse
import math
import os
from pathlib import Path
from fastvideo.utils.parallel_states import (
    initialize_sequence_parallel_state,
    destroy_sequence_parallel_group,
    get_sequence_parallel_state,
    nccl_info,
)
from fastvideo.utils.communications_flux import sp_parallel_dataloader_wrapper
from torch.utils.data import DataLoader
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.checkpoint.state_dict import get_model_state_dict, set_model_state_dict, StateDictOptions

from torch.utils.data.distributed import DistributedSampler
from fastvideo.utils.dataset_utils import LengthGroupedSampler
from accelerate.utils import set_seed
from tqdm.auto import tqdm
from fastvideo.utils.fsdp_util import get_dit_fsdp_kwargs, apply_fsdp_checkpointing
from fastvideo.utils.load import load_transformer
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from fastvideo.dataset.latent_flux_rl_datasets import LatentDataset, latent_collate_function
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from fastvideo.utils.checkpoint import (
    save_checkpoint,
    save_lora_checkpoint,
)
from fastvideo.utils.logging_ import main_print
from diffusers.image_processor import VaeImageProcessor

# Raises an error if the minimum required version of diffusers is not installed. Remove at your own risk.
check_min_version("0.31.0")
import time
from collections import deque
import numpy as np
from einops import rearrange
import torch.distributed as dist
from torch.nn import functional as F
from typing import List
from PIL import Image
from diffusers import FluxTransformer2DModel, AutoencoderKL
from contextlib import contextmanager
from safetensors.torch import save_file

# --- WandB ---
import wandb
import warnings

import tempfile
import json
# Suppress warnings whose message contains the specific keywords below
warnings.filterwarnings("ignore", message=".*cuDNN SDPA backward got grad_output.strides.*")
# --- Helper Functions ---

class FSDP_EMA:
    """Exponential moving average of a model's weights, stored on rank 0 only.

    Every method requests the full, CPU-offloaded state dict on *all* ranks
    (the ``rank`` check only gates what is stored or updated), so each method
    has to be called by every rank together.
    """

    def __init__(self, model, decay, rank):
        """Snapshot the current weights as the initial EMA.

        Args:
            model: Module whose full state dict is gathered.
            decay: EMA decay factor; the shadow keeps ``decay`` of its old value.
            rank: Rank of the calling process; only rank 0 keeps a copy.
        """
        self.decay = decay
        self.rank = rank
        self.ema_state_dict_rank0 = {}

        gathered = get_model_state_dict(
            model,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )

        if self.rank != 0:
            return

        # Clone so the shadow weights do not alias the gathered tensors.
        self.ema_state_dict_rank0 = {name: tensor.clone() for name, tensor in gathered.items()}
        main_print("--> Modern EMA handler initialized on rank 0.")

    def update(self, model):
        """Blend the model's current weights into the rank-0 EMA copy in place."""
        gathered = get_model_state_dict(
            model,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )

        if self.rank != 0:
            return

        for name, shadow in self.ema_state_dict_rank0.items():
            # Keys that are absent from the freshly gathered dict are left untouched.
            if name not in gathered:
                continue
            shadow.copy_(self.decay * shadow + (1 - self.decay) * gathered[name])

    @contextmanager
    def use_ema_weights(self, model):
        """Context manager that swaps the EMA weights into ``model`` temporarily.

        The current weights are backed up first and are restored on exit, even
        if the body raises.
        """
        saved_weights = get_model_state_dict(
            model,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )

        set_model_state_dict(
            model,
            model_state_dict=self.ema_state_dict_rank0,
            options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
        )

        try:
            yield
        finally:
            set_model_state_dict(
                model,
                model_state_dict=saved_weights,
                options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
            )

def save_ema_checkpoint(ema_handler, rank, output_dir, step, epoch, config_dict):
    """Write the EMA weights and model config to ``checkpoint-ema-{step}-{epoch}``.

    Does nothing unless ``rank == 0`` and ``ema_handler`` is not ``None``.

    Args:
        ema_handler: Object exposing ``ema_state_dict_rank0`` (name -> tensor).
        rank: Rank of the calling process.
        output_dir: Directory under which the checkpoint folder is created.
        step: Training step used in the folder name.
        epoch: Epoch used in the folder name.
        config_dict: Model configuration to dump as ``config.json``. The
            ``"dtype"`` entry is left out of the JSON; the caller's mapping is
            not modified.
    """
    if rank != 0 or ema_handler is None:
        return

    ema_checkpoint_path = os.path.join(output_dir, f"checkpoint-ema-{step}-{epoch}")
    os.makedirs(ema_checkpoint_path, exist_ok=True)

    weight_path = os.path.join(ema_checkpoint_path, "diffusion_pytorch_model.safetensors")
    save_file(ema_handler.ema_state_dict_rank0, weight_path)

    # TODO: "dtype" is dropped from the saved config. Filter into a copy so the
    # caller's dict keeps its entry instead of being mutated.
    serializable_config = {key: value for key, value in config_dict.items() if key != "dtype"}

    config_path = os.path.join(ema_checkpoint_path, "config.json")
    with open(config_path, "w") as f:
        json.dump(serializable_config, f, indent=4)

    main_print(f"--> EMA checkpoint saved at {ema_checkpoint_path}")

def sd3_time_shift(shift, t):
    """Apply the time-shift mapping ``shift * t / (1 + (shift - 1) * t)`` to ``t``.

    Works on anything supporting arithmetic (Python numbers or tensors).
    """
    scaled = shift * t
    normalizer = 1 + (shift - 1) * t
    return scaled / normalizer

def pack_latents(latents, batch_size, num_channels_latents, height, width):
    """Fold every 2x2 spatial patch of ``latents`` into the channel axis.

    Args:
        latents: Tensor viewable as ``(batch_size, num_channels_latents, height, width)``.
        batch_size: Leading dimension of ``latents``.
        num_channels_latents: Channel count of ``latents``.
        height: Spatial height of ``latents``.
        width: Spatial width of ``latents``.

    Returns:
        Tensor of shape ``(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)``.
    """
    patch_rows = height // 2
    patch_cols = width // 2

    patched = latents.view(batch_size, num_channels_latents, patch_rows, 2, patch_cols, 2)
    # Reorder to (batch, patch_row, patch_col, channel, dy, dx).
    patch_major = patched.permute(0, 2, 4, 1, 3, 5)
    return patch_major.reshape(batch_size, patch_rows * patch_cols, num_channels_latents * 4)

def unpack_latents(latents, height, width, vae_scale_factor):
    """Unfold packed ``(batch, patches, channels)`` latents back to a 4-D grid.

    Args:
        latents: Packed tensor of shape ``(batch, num_patches, channels)``.
        height: Image height; the latent height is ``2 * (height // (vae_scale_factor * 2))``.
        width: Image width; converted the same way as ``height``.
        vae_scale_factor: Downsampling factor between image and latent grid.

    Returns:
        Tensor of shape ``(batch, channels // 4, latent_height, latent_width)``.
    """
    batch_size, _num_patches, channels = latents.shape
    latent_h = 2 * (int(height) // (vae_scale_factor * 2))
    latent_w = 2 * (int(width) // (vae_scale_factor * 2))

    grid = latents.view(batch_size, latent_h // 2, latent_w // 2, channels // 4, 2, 2)
    # Reorder to (batch, channel, patch_row, dy, patch_col, dx).
    grid = grid.permute(0, 3, 1, 4, 2, 5)
    return grid.reshape(batch_size, channels // (2 * 2), latent_h, latent_w)

def assert_eq(x, y, msg=None):
    """Assert ``x == y``; the failure message is ``"<msg>: <x> != <y>"``.

    ``msg`` defaults to ``'Assertion failed'`` when empty or ``None``.
    """
    label = msg or 'Assertion failed'
    assert x == y, f"{label}: {x} != {y}"

def prepare_latent_image_ids(batch_size, height, width, device, dtype):
    """Build a ``(height * width, 3)`` table of per-position ids.

    Column 0 is all zeros, column 1 holds the row index and column 2 the
    column index of each position, in row-major order.

    Args:
        batch_size: Unused; kept for interface compatibility.
        height: Number of rows in the grid.
        width: Number of columns in the grid.
        device: Device of the returned tensor.
        dtype: Dtype of the returned tensor.
    """
    ids = torch.zeros(height, width, 3)
    row_index = torch.arange(height)[:, None]
    col_index = torch.arange(width)[None, :]

    ids[..., 1] += row_index
    ids[..., 2] += col_index

    flat_ids = ids.reshape(height * width, 3)
    return flat_ids.to(device=device, dtype=dtype)

# --------------------- Ours

class StatsTracker:
    """Accumulate per-key running statistics that can be reduced across ranks.

    For every key the tracker keeps three float32 scalar tensors on ``device``:
    the sum of the values, the sum of their squares and the number of values.
    """

    def __init__(self, device):
        self.device = device
        # {key: {'sum': tensor, 'sq_sum': tensor, 'count': tensor}}
        self.metrics = {}

    def update(self, key, value):
        """Add ``value`` to the statistics stored under ``key``.

        Args:
            key: Metric name.
            value: A Python scalar (or anything ``torch.tensor`` accepts) or a
                tensor of any shape. Tensors are detached and converted to
                float32; every element counts as one observation.
        """
        if key not in self.metrics:
            self.metrics[key] = {
                'sum': torch.tensor(0.0, device=self.device, dtype=torch.float32),
                'sq_sum': torch.tensor(0.0, device=self.device, dtype=torch.float32),
                'count': torch.tensor(0.0, device=self.device, dtype=torch.float32),
            }

        if torch.is_tensor(value):
            v = value.detach().to(self.device, dtype=torch.float32)
        else:
            v = torch.tensor(value, device=self.device, dtype=torch.float32)

        # Always reduce to a 0-dim scalar before accumulating. Adding a
        # shape-[1] (or empty) tensor in place into the 0-dim accumulators
        # raises a broadcast error, so one-element tensors must go through
        # sum()/numel() just like larger ones.
        entry = self.metrics[key]
        entry['sum'] += v.sum()
        entry['sq_sum'] += (v ** 2).sum()
        entry['count'] += v.numel()

    def reduce_and_summary(self):
        """Sum the statistics over all ranks and return mean/variance per key.

        When ``torch.distributed`` is initialized, a single ``all_reduce`` is
        issued on one packed buffer. The buffer layout follows the sorted keys,
        so every rank is expected to have recorded the same set of keys.

        Returns:
            dict mapping ``"<key>_mean"`` and ``"<key>_var"`` to floats; empty
            when nothing has been recorded.
        """
        summary = {}
        keys = sorted(self.metrics.keys())

        # Pack everything into one tensor so a single collective is enough:
        # [sum_1, sq_sum_1, cnt_1, sum_2, sq_sum_2, cnt_2, ...]
        flat_data = []
        for k in keys:
            m = self.metrics[k]
            flat_data.extend([m['sum'], m['sq_sum'], m['count']])

        if not flat_data:
            return summary

        buffer = torch.stack(flat_data)

        if dist.is_initialized():
            dist.all_reduce(buffer, op=dist.ReduceOp.SUM)

        for slot, k in enumerate(keys):
            base = 3 * slot
            total_sum = buffer[base].item()
            total_sq_sum = buffer[base + 1].item()
            total_count = buffer[base + 2].item()

            # Guard against division by zero when a key saw no elements.
            safe_count = max(1.0, total_count)
            mean = total_sum / safe_count
            # Var[X] = E[X^2] - (E[X])^2
            var = (total_sq_sum / safe_count) - (mean ** 2)

            summary[f"{k}_mean"] = mean
            # Clamp tiny negative values caused by floating-point error.
            summary[f"{k}_var"] = max(0.0, var)

        return summary

def get_constant_budget_M(base_M, step_idx, total_steps):
    """Return the branching factor for ``step_idx`` under a constant per-step budget.

    The per-step budget is ``base_M * total_steps / 2``. It is divided by the
    number of rollout steps still to run after ``step_idx``
    (``total_steps - 1 - step_idx``, at least 1), e.g. with ``total_steps=16``
    step 0 has 15 remaining steps and step 14 has 1.

    The quotient is rounded to an integer, odd results are bumped up by one so
    the value is always even, and the result is never below 2.
    """
    steps_left = max(1, total_steps - 1 - step_idx)

    # Float arithmetic on purpose, to avoid losing precision before rounding.
    budget_per_step = base_M * total_steps / 2.0

    branches = int(round(budget_per_step / steps_left))
    # Odd -> next even (rounding up slightly overspends rather than underspends).
    branches += branches % 2

    return max(2, branches)

# --- Core Logic: Flux Step (SDE & ODE) ---
def flux_step(
    model_output: torch.Tensor,
    latents: torch.Tensor,
    eta: float,
    sigma: float,
    next_sigma: float,
    prev_sample: torch.Tensor = None, # If None, we sample. If provided, we compute log_prob.
    sde_solver: bool = True,
    noise: torch.Tensor = None, # Optional externally supplied noise
):
    """Take one Euler step from ``sigma`` to ``next_sigma`` (ODE or SDE).

    Args:
        model_output: Model prediction at the current state.
        latents: Current state.
        eta: Noise strength; the step's std is ``eta * sqrt(sigma - next_sigma)``
            when ``sigma > next_sigma`` and 0 otherwise.
        sigma: Current noise level.
        next_sigma: Noise level after the step.
        prev_sample: If ``None`` a new sample is drawn. Otherwise it is treated
            as the already-taken next state and its log-probability is scored.
        sde_solver: When true and the step std is positive, a score-based
            correction is added to the step mean.
        noise: Noise to use when sampling; drawn with ``randn_like`` if ``None``.

    Returns:
        ``(prev_sample, pred_original_sample, log_prob)``. ``log_prob`` is
        ``None`` when sampling; when scoring it is the Gaussian log-density
        averaged over all non-batch dimensions, or zeros of shape ``(batch,)``
        for a step with zero std.
    """
    step = next_sigma - sigma

    # Deterministic Euler mean: x_next = x + (sigma_next - sigma) * v
    mean = latents + step * model_output

    # x0 estimate: x - sigma * v
    x0_pred = latents - sigma * model_output

    gap = sigma - next_sigma  # positive for a forward denoising step
    noise_scale = eta * math.sqrt(gap) if gap > 0 else 0.0
    stochastic = noise_scale > 0

    if sde_solver and stochastic:
        # Score-based correction of the mean (DanceGRPO / Flow-GRPO style
        # formula, per the original author's notes).
        score = -(latents - x0_pred * (1 - sigma)) / (sigma**2 + 1e-6)
        correction = -0.5 * (eta**2) * score
        mean = mean + correction * step

    if prev_sample is None:
        # Sampling mode: no log-probability is produced.
        if not stochastic:
            return mean, x0_pred, None
        if noise is None:
            noise = torch.randn_like(mean)
        return mean + noise * noise_scale, x0_pred, None

    # Scoring mode: evaluate the given next state under N(mean, noise_scale^2).
    if not stochastic:
        # A deterministic step has no density; zeros are returned instead.
        flat_zero = torch.zeros((prev_sample.shape[0],), device=prev_sample.device)
        return prev_sample, x0_pred, flat_zero

    variance = noise_scale ** 2
    elementwise = (
        -((prev_sample.detach() - mean) ** 2) / (2 * variance)
        - math.log(noise_scale)
        - 0.5 * math.log(2 * math.pi)
    )
    # Average over every non-batch dimension.
    log_prob = elementwise.mean(dim=tuple(range(1, elementwise.ndim)))
    return prev_sample, x0_pred, log_prob

def get_reverse_sde_noise(args, current_beam_width, M_branch, latent_dim_size, data_dtype, device):
    """Draw branching noise for every beam and concatenate it along dim 0.

    ``args.noise_type`` selects the scheme used for each beam:

    * ``"orth"``: ``M_branch // 2`` Gaussian directions are orthogonalized with
      a reduced QR (in float32), rescaled to the norms of the original draws,
      and paired with their negations.
    * ``"flip"``: ``M_branch // 2`` Gaussian draws paired with their negations.
    * anything else: ``M_branch`` independent Gaussian draws.

    Returns:
        Tensor whose leading dimension stacks the per-beam noise blocks and
        whose trailing dimensions are ``latent_dim_size``.
    """
    flat_dim = math.prod(latent_dim_size)

    def _with_mirror(half):
        # Antithetic pairing: the block followed by its negation.
        return torch.cat([half, -half], dim=0)

    per_beam = []
    for _beam in range(current_beam_width):
        mode = args.noise_type

        if mode == "orth":
            n_half = M_branch // 2
            gaussian = torch.randn((flat_dim, n_half), device=device, dtype=torch.float32)
            basis, _r = torch.linalg.qr(gaussian, mode='reduced')
            # Restore each direction's original length after orthonormalization.
            column_norms = torch.linalg.vector_norm(gaussian, dim=0)  # [n_half]
            scaled = (basis * column_norms.unsqueeze(0)).t()
            half = scaled.view(n_half, *latent_dim_size).to(dtype=data_dtype)
            block = _with_mirror(half)
        elif mode == "flip":
            n_half = M_branch // 2
            half = torch.randn((n_half, *latent_dim_size), device=device, dtype=data_dtype)
            block = _with_mirror(half)
        else:
            block = torch.randn((M_branch, *latent_dim_size), device=device, dtype=data_dtype)

        per_beam.append(block)

    return torch.cat(per_beam, dim=0)  # [total candidates, ...]

def decode_and_reward(args, vae, image_processor, reward_model, scout_latents, orig_caption, device):
    """Decode packed latents to images and score them against one caption.

    Args:
        args: Provides ``h`` and ``w`` plus the flags read by ``compute_rewards``.
        vae: Decoder exposing ``decode(latents, return_dict=False)``.
        image_processor: Exposes ``postprocess`` to turn decoded output into images.
        reward_model: Reward model container forwarded to ``compute_rewards``.
        scout_latents: Packed latents to decode.
        orig_caption: Caption string; wrapped in a one-element list for scoring.
        device: Device forwarded to ``compute_rewards``.

    Returns:
        ``(rewards, images)``.
    """
    with torch.autocast("cuda", dtype=torch.bfloat16):
        vae_input = unpack_latents(scout_latents, args.h, args.w, 8)
        # NOTE(review): 0.3611 / 0.1159 look like the VAE scaling and shift
        # factors - confirm against the VAE config.
        vae_input = (vae_input / 0.3611) + 0.1159
        decoded_pixels = vae.decode(vae_input, return_dict=False)[0]
        pil_images = image_processor.postprocess(decoded_pixels)

    # Scoring happens outside the autocast region.
    rewards = compute_rewards(args, reward_model, pil_images, [orig_caption], device)
    return rewards, pil_images

def compute_rewards(args, reward_models, pil_images, captions, device):
    """Score images against captions with the reward model enabled in ``args``.

    Precedence when several flags are set: HPSv3, then PickScore, then HPSv2
    (the HPSv2 score is still computed first and then overwritten).

    Args:
        args: Namespace with ``use_hpsv2``, ``use_hpsv3`` and ``use_pickscore``
            (plus ``hpsv3_save_dir`` when HPSv3 is used).
        reward_models: Dict of loaded reward models keyed by ``'hpsv2'``,
            ``'hpsv3'`` or ``'pickscore'``.
        pil_images: Images to score. The HPSv3 path scores only the first
            image, which it first writes to ``args.hpsv3_save_dir``.
        captions: List of caption strings.
        device: Device for the model inputs and the returned scores.

    Returns:
        Float tensor of scores.

    Raises:
        ValueError: If a requested model is missing from ``reward_models``, or
            if none of the supported reward flags is enabled.
    """
    scores = None

    # --- HPSv2 ---
    if args.use_hpsv2:
        if 'hpsv2' not in reward_models:
            raise ValueError("HPSv2 requested but not initialized in reward_models")

        hps_model = reward_models['hpsv2']['model']
        hps_preprocess = reward_models['hpsv2']['preprocess']
        hps_tokenizer = reward_models['hpsv2']['tokenizer']

        with torch.no_grad():
            processed_images = torch.stack([hps_preprocess(img) for img in pil_images]).to(device)
            text_inputs = hps_tokenizer(captions).to(device)

            with torch.amp.autocast('cuda'):
                outputs = hps_model(processed_images, text_inputs)
                logits = outputs["image_features"] @ outputs["text_features"].T
                # The diagonal pairs image i with caption i.
                scores = torch.diagonal(logits)

    # --- HPSv3 ---
    if args.use_hpsv3:
        if 'hpsv3' not in reward_models:
            raise ValueError("HPSv3 requested but not initialized in reward_models")

        hps_model = reward_models['hpsv3']['model']

        save_dir = args.hpsv3_save_dir

        # The inferencer takes image paths, so the image is written to disk
        # under a name made unique by rank and millisecond timestamp.
        rank = dist.get_rank() if dist.is_initialized() else 0
        ts = int(time.time() * 1000)
        unique_filename = f"rk{rank}_ts{ts}.jpg"

        os.makedirs(save_dir, exist_ok=True)
        img_save_path = os.path.join(save_dir, unique_filename)

        pil_images[0].save(img_save_path)
        with torch.no_grad():
            rewards = hps_model.reward(captions, image_paths=[img_save_path])
            scores = torch.tensor([rewards[0][0].item()], dtype=torch.float32, device=device)

    # --- PickScore ---
    elif args.use_pickscore:
        if 'pickscore' not in reward_models:
            raise ValueError("PickScore requested but not initialized in reward_models")

        pick_model = reward_models['pickscore']['model']
        pick_proc = reward_models['pickscore']['processor']

        image_inputs = pick_proc(images=pil_images, padding=True, truncation=True, max_length=77, return_tensors="pt").to(device)
        text_inputs = pick_proc(text=captions, padding=True, truncation=True, max_length=77, return_tensors="pt").to(device)

        with torch.no_grad():
            img_emb = pick_model.get_image_features(**image_inputs)
            img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)

            txt_emb = pick_model.get_text_features(**text_inputs)
            txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)

            # Cosine similarity of caption i with image i.
            scores = (txt_emb @ img_emb.T).diagonal()

    if scores is None:
        # Without this check the function would fail later with an opaque
        # AttributeError on an empty list.
        raise ValueError(
            "No supported reward model is enabled: set one of "
            "use_hpsv2, use_hpsv3 or use_pickscore"
        )

    return scores.float()  # [B] tensor


def get_reward_models(args, device):
    """Load every reward model enabled in ``args`` and return them in a dict.

    Args:
        args: Namespace with the boolean flags ``use_hpsv2``, ``use_hpsv3``,
            ``use_pickscore`` and ``use_video_align``.
        device: Target device. May be an integer CUDA index, a device string
            or a ``torch.device``.

    Returns:
        Dict keyed by ``'hpsv2'``, ``'hpsv3'``, ``'pickscore'`` and/or
        ``'video_align'``; each value is a dict holding the model and its
        helpers. A PickScore loading failure is logged and that key is simply
        left out.
    """
    reward_models = dict()

    if args.use_hpsv2:
        from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer

        # Build the base model structure.
        model, _preprocess_train, preprocess_val = create_model_and_transforms(
            'ViT-H-14',
            './hps_ckpt/open_clip_pytorch_model.bin',
            precision='amp',
            device=device,
            jit=False,
            force_quick_gelu=False,
            force_custom_text=False,
            force_patch_dropout=False,
            force_image_size=None,
            pretrained_image=False,
            image_mean=None,
            image_std=None,
            light_augmentation=True,
            aug_cfg={},
            output_dict=True,
            with_score_predictor=False,
            with_region_predictor=False
        )

        # Load the HPS weights. An integer device is a CUDA index and needs
        # the "cuda:" prefix; strings and torch.device objects are passed on
        # unchanged (prefixing them would produce e.g. "cuda:cuda:0").
        cp = "./hps_ckpt/HPS_v2.1_compressed.pt"
        map_location = f'cuda:{device}' if isinstance(device, int) else device
        checkpoint = torch.load(cp, map_location=map_location)
        model.load_state_dict(checkpoint['state_dict'])

        hps_tokenizer = get_tokenizer('ViT-H-14')

        model = model.to(device).eval()

        reward_models['hpsv2'] = {
            'model': model,
            'preprocess': preprocess_val,
            'tokenizer': hps_tokenizer
        }
        main_print("--> HPSv2 loaded")

    if args.use_hpsv3:
        from hpsv3 import HPSv3RewardInferencer
        # NOTE(review): config_path / checkpoint_path are empty strings here -
        # presumably placeholders to be filled in per deployment; confirm.
        inferencer = HPSv3RewardInferencer(
                        config_path="",
                        checkpoint_path="",
                        device=device)
        reward_models['hpsv3'] = {
            'model': inferencer,
            'preprocess': None,
            'tokenizer': None
        }
        main_print("--> HPSv3 loaded")

    # PickScore uses the same dict layout as the other reward models.
    if args.use_pickscore:
        from transformers import AutoProcessor, AutoModel
        processor_name_or_path = "./pretrained_models/CLIP-ViT-H-14-laion2B-s32B-b79K"
        model_pretrained_name_or_path = "./pretrained_models/PickScore_v1"

        try:
            processor = AutoProcessor.from_pretrained(processor_name_or_path)
            model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(device)

            reward_models['pickscore'] = {
                'model': model,
                'processor': processor
            }
            main_print("--> PickScore loaded")
        except Exception as e:
            # Best effort: the failure is logged and 'pickscore' stays absent,
            # which compute_rewards reports when the model is requested.
            main_print(f"Failed to load PickScore: {e}")

    if args.use_video_align:
        from utils.rm_video_align_wrapper import VideoAlignWrapper
        # NOTE(review): empty path - presumably a placeholder; confirm.
        video_align_path = ""
        reward_model = VideoAlignWrapper(video_align_path, device)
        reward_model.inferencer.model.requires_grad_(False)
        reward_models['video_align']  = {
            'model': reward_model,
        }

    return reward_models


def rollout_and_reward(
    # --- Context Args ---
    args, device, transformer, vae, image_processor, reward_model,
    # --- Rollout State ---
    latents, # starting latents (cloned, never modified in place)
    start_step_idx, # schedule index to start from (0 ~ sampling_steps-1)
    full_sigma_schedule,
    # --- Condition ---
    encoder_hidden_states,
    txt_ids,
    pooled_projections,
    img_ids,
    caption, # String
    sampling_steps
):
    """Run a deterministic Euler (ODE) rollout to the end, then decode and score.

    Args:
        start_step_idx: Schedule index that ``latents`` corresponds to. ``0``
            runs the whole schedule; ``t`` runs steps ``t .. sampling_steps-1``;
            a value ``>= sampling_steps`` skips the rollout entirely.
        full_sigma_schedule: Indexable with at least ``sampling_steps + 1``
            entries (step ``i`` reads entries ``i`` and ``i + 1``).

    Returns:
        ``(reward, pil_imgs)`` as produced by ``decode_and_reward``.
    """
    state = latents.clone()

    # 1. ODE rollout. An empty range (start >= sampling_steps) means no steps.
    for step in range(start_step_idx, sampling_steps):
        sigma_now = full_sigma_schedule[step]
        sigma_next = full_sigma_schedule[step + 1]

        # The timestep is truncated to an integer number of 1/1000ths and
        # scaled back to [0, 1] when passed to the model.
        timestep_milli = torch.full([1], int(sigma_now * 1000), device=device, dtype=torch.long)

        with torch.autocast("cuda", torch.bfloat16):
            velocity = transformer(
                hidden_states=state,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep_milli/1000,
                guidance=torch.tensor([3.5], device=device, dtype=torch.bfloat16),
                txt_ids=txt_ids,
                pooled_projections=pooled_projections,
                img_ids=img_ids,
                return_dict=False,
            )[0]

        # Euler step.
        state = state + (sigma_next - sigma_now) * velocity

    # 2. Decode and score the final latents.
    reward, pil_imgs = decode_and_reward(args, vae, image_processor, reward_model, state, caption, device)

    return reward, pil_imgs


def select_next_candidates(
    all_rewards: torch.Tensor,  # [Total_Candidates]
    target_K: int,              # how many to select
    strategy: str = "best",     # "best", "softmax", "tournament", "random"
    temperature: float = 1.0,   # only used by "softmax"
):
    """Pick ``target_K`` candidate indices to act as parents of the next generation.

    If there are at most ``target_K`` candidates, all indices are returned in
    order and ``strategy`` is not consulted.

    Strategies:
        * ``"best"``: top-K by reward.
        * ``"softmax"``: sample without replacement with probabilities
          ``softmax(reward / temperature)``.
        * ``"tournament"``: K rounds; each round draws up to 3 remaining
          candidates at random, keeps the best and removes it from the pool.
        * ``"random"``: uniform sample without replacement.

    Returns:
        1-D index tensor on the same device as ``all_rewards``.

    Raises:
        ValueError: For an unknown ``strategy``.
    """
    n_candidates = all_rewards.shape[0]
    dev = all_rewards.device

    # Not enough candidates to choose from: keep everything.
    if n_candidates <= target_K:
        return torch.arange(n_candidates, device=dev)

    if strategy == "best":
        return torch.topk(all_rewards, target_K).indices

    if strategy == "softmax":
        weights = F.softmax(all_rewards / temperature, dim=0)
        # replacement=False guarantees K distinct indices.
        return torch.multinomial(weights, target_K, replacement=False)

    if strategy == "tournament":
        bracket_size = 3  # classic tournament size
        remaining_ids = torch.arange(n_candidates, device=dev)
        remaining_scores = all_rewards.clone()
        winners = []

        for _round in range(target_K):
            # Random bracket drawn from what is still in the pool.
            draw = torch.randperm(len(remaining_ids), device=dev)[:bracket_size]
            best_in_draw = torch.argmax(remaining_scores[draw])
            winners.append(remaining_ids[draw][best_in_draw])

            # Drop the winner from the pool so it cannot be picked again.
            keep = torch.ones(len(remaining_ids), dtype=torch.bool, device=dev)
            keep[draw[best_in_draw]] = False
            remaining_ids = remaining_ids[keep]
            remaining_scores = remaining_scores[keep]

        return torch.stack(winners)

    if strategy == "random":
        return torch.randperm(n_candidates, device=dev)[:target_K]

    raise ValueError(f"Unknown selection strategy: {strategy}")

@torch.no_grad()
def log_validation(args, step, valid_dataloader, transformer, vae, reward_models, device):
    """Generate one image per validation prompt, score it and log the results.

    Each rank runs a 25-step deterministic rollout for every batch of
    ``valid_dataloader`` (batch size must be 1), saves the image under
    ``<args.valid_save_path>/step_<step>_images`` and shares its reward with
    the other ranks. Rank 0 additionally writes
    ``val_results_step_<step>.json`` with the aggregated statistics.

    Args:
        args: Provides ``w``, ``h``, ``shift``, ``valid_save_path`` and the
            reward flags used downstream.
        step: Training step, used in output names.
        valid_dataloader: Yields ``(encoder_hidden_states, pooled_prompt_embeds,
            text_ids, captions)``.
        transformer: Denoising model used for the rollout.
        vae: Decoder used to turn latents into images.
        reward_models: Reward model container passed to the scoring code.
        device: Device on which latents are created.

    Returns:
        The mean reward over all gathered samples on rank 0 (``0.0`` when
        nothing was scored); always ``0.0`` on every other rank.
    """
    image_processor = VaeImageProcessor(16)
    w, h = args.w, args.h
    SPATIAL_DOWNSAMPLE = 8
    IN_CHANNELS = 16
    latent_w, latent_h = w // SPATIAL_DOWNSAMPLE, h // SPATIAL_DOWNSAMPLE

    valid_sampling_steps = 25
    sigma_schedule = torch.linspace(1, 0, valid_sampling_steps + 1)
    sigma_schedule = sd3_time_shift(args.shift, sigma_schedule)

    # Take the world size from the process group itself: the gather list below
    # must have exactly one slot per rank, and the WORLD_SIZE environment
    # variable is not guaranteed to be set or to match.
    distributed = dist.is_initialized()
    rank = dist.get_rank() if distributed else 0
    world_size = dist.get_world_size() if distributed else 1

    all_rewards_this_validation = []  # only filled on rank 0
    validation_start_time = time.time()
    total_prompts_analyzed = 0

    valid_save_path = args.valid_save_path
    step_save_dir = os.path.join(valid_save_path, f"step_{step}_images")
    # Every rank saves images into this directory, so every rank must make
    # sure it exists; relying on rank 0 alone races with the other ranks.
    os.makedirs(step_save_dir, exist_ok=True)

    for batch_idx, batch in enumerate(tqdm(valid_dataloader, desc=f"Valid Step {step}", disable=(rank != 0))):
        (encoder_hidden_states, pooled_prompt_embeds, text_ids, captions) = batch
        current_bs = encoder_hidden_states.shape[0]
        assert current_bs == 1

        # Prepare latents and position ids.
        input_latents = torch.randn(
            (current_bs, IN_CHANNELS, latent_h, latent_w),
            device=device,
            dtype=torch.bfloat16,
        )
        input_latents_packed = pack_latents(input_latents, current_bs, IN_CHANNELS, latent_h, latent_w)
        image_ids = prepare_latent_image_ids(current_bs, latent_h // 2, latent_w // 2, device, torch.bfloat16)
        cond_txt = text_ids.repeat(encoder_hidden_states.shape[1], 1)

        reward_val, pil_imgs = rollout_and_reward(
            args, device, transformer, vae, image_processor, reward_models,
            input_latents_packed,
            0, # start_step_idx
            sigma_schedule,
            encoder_hidden_states,
            cond_txt,
            pooled_prompt_embeds,
            image_ids,
            captions[0], # single caption string
            valid_sampling_steps
        )

        step_reward_scalar = reward_val.item()

        img_save_path = os.path.join(step_save_dir, f"rank{rank}_batch{batch_idx}_{step_reward_scalar:.4f}.jpg")
        pil_imgs[0].save(img_save_path)

        # Collect this batch's reward from every rank.
        if distributed:
            all_rewards_list = [None for _ in range(world_size)]
            dist.all_gather_object(all_rewards_list, step_reward_scalar)
        else:
            all_rewards_list = [step_reward_scalar]

        if rank == 0:
            all_rewards_this_validation.extend(all_rewards_list)

        total_prompts_analyzed += world_size

    # Statistics and logging happen on rank 0 only.
    if rank != 0:
        return 0.0

    validation_duration = time.time() - validation_start_time

    avg_reward = 0.0
    std_reward = 0.0
    if all_rewards_this_validation:
        avg_reward = sum(all_rewards_this_validation) / len(all_rewards_this_validation)
        std_reward = float(np.std(all_rewards_this_validation))

        results_data = {
            "step": step,
            "avg_reward": avg_reward,
            "std_reward": std_reward,
            "total_prompts": total_prompts_analyzed,
            "duration": validation_duration,
            "rewards": all_rewards_this_validation
        }
        with open(os.path.join(valid_save_path, f"val_results_step_{step}.json"), "w") as f:
            json.dump(results_data, f, indent=4)

    return avg_reward


def _append_advantage_pair_log(log_path, adv_group, m_branch, global_step_idx, step_idx, noise_type):
    """Append one block of paired advantage values to a per-rank text log.

    Sample ``k`` of the group is paired with sample ``k + m_branch // 2``; each
    line holds the pair index, both advantages and their difference.
    """
    # ``view`` also checks that the group holds exactly ``m_branch`` values.
    adv_values = adv_group.view(1, m_branch)[0].tolist()
    half_m = m_branch // 2
    with open(log_path, "a") as f_adv:
        f_adv.write(f"GlobalStep: {global_step_idx}, T: {step_idx}, Type: {noise_type}\n")
        f_adv.write("Pair_Idx, Adv_Pos, Adv_Neg, Diff\n")
        for k in range(half_m):
            p_val = adv_values[k]
            n_val = adv_values[k + half_m]
            f_adv.write(f"{k}, {p_val:.4f}, {n_val:.4f}, {p_val-n_val:.4f}\n")
        f_adv.write("-" * 20 + "\n")


def train_dense_reward_step(
    args,
    device,
    transformer,
    vae,
    reward_model,
    optimizer,
    lr_scheduler,
    loader,
    max_grad_norm,
    ema_handler,
    global_step_idx=0,
):
    """Run one optimizer update using beam-search rollouts over the sampling schedule.

    For a single prompt, every sampling step (except the last) expands each
    surviving latent ("parent") into ``M_branch`` SDE candidates, scores every
    candidate by rolling it out to the end of the schedule, converts the rewards
    into advantages, and back-propagates a policy-gradient loss through the
    parent's transformer prediction. Gradients are accumulated across all steps
    and parents; a single clip + optimizer/scheduler step is taken at the end.

    Args:
        args: Parsed CLI namespace. Reads ``training_save_path``, ``w``, ``h``,
            ``sampling_steps``, ``shift``, ``M``, ``eta``, ``decay_factor_norm``,
            ``advantage_type``, ``clip_adv``, ``adv_clip_max``, ``noise_type``,
            ``save_images``, ``group_num``, ``selection_strategy`` and the
            optional ``selection_temp``.
        device: Device on which latents and timesteps are created.
        transformer: Trainable denoiser; must expose ``clip_grad_norm_``.
        vae: Forwarded to ``rollout_and_reward``.
        reward_model: Forwarded to ``rollout_and_reward``.
        optimizer: Optimizer stepped once at the end of the call.
        lr_scheduler: Scheduler stepped once at the end of the call.
        loader: Iterator yielding ``(encoder_hidden_states,
            pooled_prompt_embeds, text_ids, captions)`` with batch size 1.
        max_grad_norm: Maximum gradient norm used for clipping.
        ema_handler: Unused here; kept for interface compatibility.
        global_step_idx: Outer training step, used for logging and file names.

    Returns:
        The dict produced by ``tracker.reduce_and_summary()`` with an extra
        ``"grad_norm_mean"`` entry.

    Raises:
        ValueError: If the batch size is not 1 or ``args.advantage_type`` is
            neither ``"rank"`` nor ``"group"``.
    """
    image_processor = VaeImageProcessor(16)
    save_root = os.path.join(args.training_save_path, "images_vis")
    os.makedirs(save_root, exist_ok=True)

    rank = dist.get_rank()
    adv_pair_log_path = f"{args.training_save_path}/advantage_pairs_rank_{rank}.txt"

    # Accumulates the statistics reported back to the caller.
    tracker = StatsTracker(device)

    # 1. Load data
    encoder_hidden_states, pooled_prompt_embeds, text_ids, captions = next(loader)
    B = encoder_hidden_states.shape[0]
    if B != 1:
        # Only captions[0] is used below, so larger batches would be silently truncated.
        raise ValueError(f"train_dense_reward_step expects batch size 1, got {B}")

    # 2. Prepare initial latents and the sigma schedule
    w, h = args.w, args.h
    latents = torch.randn((B, 16, h // 8, w // 8), device=device, dtype=torch.bfloat16)
    packed_latents = pack_latents(latents, B, 16, h // 8, w // 8)
    image_ids = prepare_latent_image_ids(B, h // 16, w // 16, device, torch.bfloat16)

    full_sigma_schedule = torch.linspace(1, 0, args.sampling_steps + 1)
    full_sigma_schedule = sd3_time_shift(args.shift, full_sigma_schedule)

    # Total branching budget per step (the original note mentions a value of 24).
    BASE_M = args.M

    transformer.train()

    # Conditioning for the single prompt (batch size 1).
    orig_enc_hidden = encoder_hidden_states
    orig_pooled = pooled_prompt_embeds
    orig_txt_ids = text_ids
    orig_caption = captions[0]

    # Loop-invariant inputs: text ids repeated once per text token, and the guidance scale.
    cond_txt_input = orig_txt_ids.repeat(orig_enc_hidden.shape[1], 1)
    guidance = torch.tensor([3.5], device=device, dtype=torch.bfloat16)

    # Current beam; starts with the single initial latent and is replaced by
    # the selected candidates after every step.
    curr_latents = packed_latents
    current_beam_width = 1

    # ============================================================
    # Baseline ODE scout: one standard rollout from step 0 before any
    # branching, logged as a reference reward.
    # ============================================================
    with torch.no_grad():
        rewards_baseline, _ = rollout_and_reward(
            args, device, transformer, vae, image_processor, reward_model,
            curr_latents,
            0,  # start_step_idx
            full_sigma_schedule,
            orig_enc_hidden,
            cond_txt_input,
            orig_pooled,
            image_ids,
            orig_caption,
            args.sampling_steps,
        )
        tracker.update("ode_sample_reward", rewards_baseline.sum().item())

    # ============================================================
    # Main loop: beam-search trajectory optimization
    # ============================================================
    for step_idx in range(args.sampling_steps - 1):
        total_budget_M = get_constant_budget_M(BASE_M, step_idx, args.sampling_steps)
        # At least 2 children per parent, rounded up to an even count so the
        # advantage-pair log can split the group in half.
        M_branch = max(2, total_budget_M // current_beam_width)
        if M_branch % 2 != 0:
            M_branch += 1

        if rank == 0 and global_step_idx % 100 == 0:
            print(f"Step {step_idx}: M={M_branch} current_beam_width={current_beam_width}")

        sigma = full_sigma_schedule[step_idx]
        next_sigma = full_sigma_schedule[step_idx + 1]

        timestep_val = int(sigma * 1000)
        t_tensor_single = torch.full([1], timestep_val, device=device, dtype=torch.long)

        # Per-step containers feeding the selection in Step D.
        all_candidates_list = []
        all_rewards_list = []

        # Eta decays from args.eta towards 10% of it as sampling progresses.
        progress = step_idx / max(1, args.sampling_steps)
        min_eta_ratio = 0.1
        decay_factor = (1.0 - progress) * (1.0 - min_eta_ratio) + min_eta_ratio
        current_eta = decay_factor ** args.decay_factor_norm * args.eta

        # Iterate over every parent in the current beam.
        for beam_i in range(current_beam_width):
            latent_parent = curr_latents[beam_i:beam_i + 1]

            # drift_parent carries gradients; it is the root of the autograd graph.
            with torch.autocast("cuda", torch.bfloat16):
                drift_parent = transformer(
                    hidden_states=latent_parent,
                    encoder_hidden_states=orig_enc_hidden,
                    timestep=t_tensor_single / 1000,
                    guidance=guidance,
                    txt_ids=cond_txt_input,
                    pooled_projections=orig_pooled,
                    img_ids=image_ids,
                    return_dict=False,
                )[0]

            # Replicate the parent M_branch times (drift_expanded keeps gradients).
            drift_expanded = drift_parent.repeat_interleave(M_branch, dim=0)
            latent_expanded = latent_parent.repeat_interleave(M_branch, dim=0)

            # Noise for the M_branch children of this single parent (beam=1).
            noise_b = get_reverse_sde_noise(
                args, 1, M_branch, latent_parent.shape[1:], drift_expanded.dtype, device
            )

            # Sampling the next state and scoring it need no gradients.
            with torch.no_grad():
                candidates_group, _, _ = flux_step(
                    drift_expanded,
                    latent_expanded,
                    current_eta,
                    sigma,
                    next_sigma,
                    prev_sample=None,
                    sde_solver=True,
                    noise=noise_b,
                )

                # Score each candidate by rolling it out serially to the end.
                rewards_group_list = []
                next_start_idx = step_idx + 1

                for k in range(M_branch):
                    curr_candidate = candidates_group[k:k + 1]
                    r_val, pil_imgs = rollout_and_reward(
                        args, device, transformer, vae, image_processor, reward_model,
                        curr_candidate,
                        next_start_idx,  # continue from the following step
                        full_sigma_schedule,
                        orig_enc_hidden,
                        cond_txt_input,
                        orig_pooled,
                        image_ids,
                        orig_caption,
                        args.sampling_steps,
                    )
                    rewards_group_list.append(r_val)

                    # Optional debug dump of the decoded image.
                    if args.save_images:
                        step_dir = os.path.join(save_root, f"step_{global_step_idx:03d}_t{step_idx}")
                        os.makedirs(step_dir, exist_ok=True)
                        save_path = os.path.join(step_dir, f"rank{rank}_b{beam_i}_k{k}_{r_val.item():.4f}.jpg")
                        pil_imgs[0].save(save_path, quality=95)

                rewards_group = torch.cat(rewards_group_list, dim=0)

                # Keep detached copies for the selection in Step D.
                all_candidates_list.append(candidates_group.detach())
                all_rewards_list.append(rewards_group)
                tracker.update("sample_reward", rewards_group)

            # ------------------------------------------------------------
            # Step C: RL update (gradient computation)
            # ------------------------------------------------------------
            if args.advantage_type == "rank":
                # Double argsort yields each sample's rank; map ranks to [-0.5, 0.5].
                ranks = torch.argsort(torch.argsort(rewards_group))
                adv_group = ((ranks.float() / (M_branch - 1)) - 0.5)
            elif args.advantage_type == "group":
                mean = rewards_group.mean()
                std = rewards_group.std() + 1e-8
                adv_group = (rewards_group - mean) / std
            else:
                raise ValueError(f"Unsupported advantage_type: {args.advantage_type!r}. ")

            if args.clip_adv:
                adv_group = torch.clamp(adv_group, -args.adv_clip_max, args.adv_clip_max)

            _append_advantage_pair_log(
                adv_pair_log_path, adv_group, M_branch, global_step_idx, step_idx, args.noise_type
            )

            # Re-evaluate the log-probability of the (detached) candidates with
            # gradients flowing through drift_expanded.
            _, _, log_prob_group = flux_step(
                drift_expanded,
                latent_expanded,
                current_eta,
                sigma,
                next_sigma,
                prev_sample=candidates_group,
                sde_solver=True,
            )

            adv_group = adv_group.to(log_prob_group.device, dtype=log_prob_group.dtype)
            # The ratio equals 1 in value; only its gradient w.r.t. log_prob matters.
            loss_group = -adv_group * torch.exp(log_prob_group - log_prob_group.detach())
            loss_group = loss_group.mean()

            # Average over all (step, parent) pairs contributing to this update.
            scale_factor = 1.0 / ((args.sampling_steps - 1) * current_beam_width)
            loss_to_backward = loss_group * scale_factor

            # Backward now so this parent's graph can be freed before the next one.
            loss_to_backward.backward()

            # Note: the recorded value is the scaled loss that was back-propagated.
            tracker.update("loss", loss_to_backward.item())
            tracker.update("log_p", log_prob_group.mean())

            # Drop graph-holding references explicitly to release memory early.
            del drift_parent, drift_expanded, log_prob_group, loss_group, loss_to_backward

        # ------------------------------------------------------------
        # Step D: beam-search selection
        # ------------------------------------------------------------
        all_candidates = torch.cat(all_candidates_list, dim=0)
        all_rewards = torch.cat(all_rewards_list, dim=0)

        total_budget_M_next_step = get_constant_budget_M(BASE_M, step_idx + 1, args.sampling_steps)
        target_K = max(1, int(total_budget_M_next_step / args.group_num))

        selected_indices = select_next_candidates(
            all_rewards,
            target_K,
            strategy=args.selection_strategy,  # "best", "softmax", "random"
            temperature=getattr(args, 'selection_temp', 1.0),
        )

        curr_latents = all_candidates[selected_indices].detach()
        # Per the original note this should be target_K or fewer.
        current_beam_width = len(selected_indices)

    # ============================================================
    # Optimizer step and aggregation (once per call, after all backward passes)
    # ============================================================
    grad_norm_val = transformer.clip_grad_norm_(max_grad_norm).item()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

    results = tracker.reduce_and_summary()
    results["grad_norm_mean"] = grad_norm_val

    return results


def main(args):
    """Set up distributed training and run the reward-driven training loop.

    Reads ``LOCAL_RANK``, ``RANK`` and ``WORLD_SIZE`` from the environment,
    initializes the NCCL process group, wraps the Flux transformer in FSDP,
    builds the optimizer, scheduler and data loaders, then alternates periodic
    checkpointing / validation with ``train_dense_reward_step``. Rank 0 logs
    metrics to Weights & Biases.

    Args:
        args: Parsed command-line namespace (see the argument parser at the
            bottom of this file).

    Raises:
        NotImplementedError: If ``args.resume_from_checkpoint`` is set;
            resuming is not supported.
    """
    # Fail fast, before any process group or model is created. The previous
    # `assert NotImplementedError(...)` was always truthy and never fired.
    if args.resume_from_checkpoint:
        raise NotImplementedError("resume_from_checkpoint is not supported now.")

    torch.backends.cuda.matmul.allow_tf32 = True

    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)
    device = torch.cuda.current_device()
    initialize_sequence_parallel_state(args.sp_size)

    # Each rank gets its own seed so that the noise differs across processes.
    if args.seed is not None:
        # TODO: t within the same seq parallel group should be the same. Noise should be different.
        set_seed(args.seed + rank)

    # Create the output directory once, on the main process.
    if rank <= 0 and args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    # Reward models used for both training and validation.
    reward_models = get_reward_models(args, device)

    main_print(f"--> loading model from {args.pretrained_model_name_or_path}")
    # Keep the master weights in float32.
    transformer = FluxTransformer2DModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="transformer",
        torch_dtype=torch.float32,
    )

    main_print("--> loading model done, begin to fsdp")
    fsdp_kwargs, no_split_modules = get_dit_fsdp_kwargs(
        transformer,
        args.fsdp_sharding_startegy,
        False,
        args.use_cpu_offload,
        args.master_weight_type,
    )
    transformer = FSDP(transformer, **fsdp_kwargs)
    main_print(
        f"--> fsdp done, Initializing FSDP with sharding strategy: {args.fsdp_sharding_startegy}"
    )

    ema_handler = None
    if args.use_ema:
        ema_handler = FSDP_EMA(transformer, args.ema_decay, rank)
    main_print("--> ema done")

    if args.gradient_checkpointing:
        apply_fsdp_checkpointing(
            transformer, no_split_modules, args.selective_checkpointing
        )
    main_print("--> apply_fsdp_checkpointing done")

    # The VAE is loaded in bfloat16 and is not passed to the optimizer.
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        torch_dtype=torch.bfloat16,
    ).to(device)
    main_print("--> vae done")

    # Set model as trainable.
    transformer.train()
    params_to_optimize = [p for p in transformer.parameters() if p.requires_grad]

    optimizer = torch.optim.AdamW(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(0.9, 0.999),
        weight_decay=args.weight_decay,
        eps=1e-8,
    )

    init_steps = 0
    main_print(f"optimizer: {optimizer}")

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=1000000,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
        last_epoch=init_steps - 1,
    )

    train_dataset = LatentDataset(args.data_json_path, args.num_latent_t, args.cfg)
    sampler = DistributedSampler(
        train_dataset, rank=rank, num_replicas=world_size, shuffle=True, seed=args.sampler_seed
    )
    train_dataloader = DataLoader(
        train_dataset,
        sampler=sampler,
        collate_fn=latent_collate_function,
        pin_memory=True,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
        drop_last=True,
    )

    # Validation uses a fixed order and one prompt per batch.
    valid_dataset = LatentDataset(args.valid_data_json_path, args.num_latent_t, args.cfg)
    valid_sampler = DistributedSampler(
        valid_dataset, rank=rank, num_replicas=world_size, shuffle=False
    )
    valid_dataloader = DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        collate_fn=latent_collate_function,
        pin_memory=True,
        batch_size=1,
        num_workers=1,
        drop_last=False,
    )

    if rank <= 0:
        wandb.init(project=args.wandb_project, config=args, name=args.wandb_name)

    # Train!
    main_print("***** Running training *****")
    main_print(f"  Num examples = {len(train_dataset)}")
    main_print(f"  Dataloader size = {len(train_dataloader)}")
    main_print(f"  Resume training from step {init_steps}")
    main_print(f"  Instantaneous batch size per device = {args.train_batch_size}")
    main_print(f"  Total optimization steps per epoch = {args.max_train_steps}")
    main_print(
        f"  Total training parameters per FSDP shard = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e9} B"
    )
    main_print(f"  Master weight dtype: {next(transformer.parameters()).dtype}")

    # NOTE(review): the bar total is hard-coded to 100000 rather than
    # args.max_train_steps, so the displayed percentage may be misleading.
    progress_bar = tqdm(
        range(0, 100000),
        initial=init_steps,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=local_rank > 0,
    )

    loader = sp_parallel_dataloader_wrapper(
        train_dataloader,
        device,
        args.train_batch_size,
        args.sp_size,
        args.train_sp_batch_size,
    )

    # Sliding window of recent step durations for the averaged timing metric.
    step_times = deque(maxlen=100)

    # A single pass is run; the step budget is controlled by args.max_train_steps.
    for epoch in range(1):
        # Needed so that distributed shuffling differs per epoch.
        sampler.set_epoch(epoch)

        for step in range(init_steps + 1, args.max_train_steps + 1):
            start_time = time.time()
            if step % args.checkpointing_steps == 0:
                save_checkpoint(transformer, rank, args.output_dir, step, epoch)
                if args.use_ema:
                    save_ema_checkpoint(ema_handler, rank, args.output_dir, step, epoch, dict(transformer.config))
                dist.barrier()

            # Validation runs before the training step, on step 1 and then
            # every args.validation_steps steps.
            if step % args.validation_steps == 0 or step == 1:
                valid_reward_avg = log_validation(
                    args,
                    step,
                    valid_dataloader,
                    transformer,
                    vae,
                    reward_models,
                    device,
                )
                if rank <= 0:
                    wandb.log({"valid_reward_avg": valid_reward_avg}, step=step)
                dist.barrier()

            log_res = train_dense_reward_step(
                args,
                device,
                transformer,
                vae,
                reward_models,
                optimizer,
                lr_scheduler,
                loader,
                args.max_grad_norm,
                ema_handler,
                step,
            )

            if args.use_ema and ema_handler:
                ema_handler.update(transformer)

            step_time = time.time() - start_time
            step_times.append(step_time)
            avg_step_time = sum(step_times) / len(step_times)

            progress_bar.set_postfix(log_res)
            progress_bar.update(1)
            if rank <= 0:
                wandb.log(
                    {
                        "train_loss": log_res['loss_mean'],
                        "learning_rate": lr_scheduler.get_last_lr()[0],
                        "step_time": step_time,
                        "avg_step_time": avg_step_time,
                        "grad_norm": log_res['grad_norm_mean'],
                        "reward": log_res['ode_sample_reward_mean'],
                        "sample_reward": log_res['sample_reward_mean'],
                        "var": log_res['sample_reward_var'],
                        "log_p_old": log_res['log_p_mean'],
                    },
                    step=step,
                )

    if get_sequence_parallel_state():
        destroy_sequence_parallel_group()


def build_arg_parser():
    """Build the command-line parser for the training script.

    Returns:
        An ``argparse.ArgumentParser`` with every training, validation,
        optimizer and reward-model option registered.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--validation_steps", type=int, default=10, help="validation every steps")
    parser.add_argument("--selection_strategy", type=str, required=True)
    parser.add_argument("--valid_data_json_path", type=str, required=True)
    parser.add_argument("--valid_save_path", type=str, required=True)
    parser.add_argument("--training_save_path", type=str, required=True)
    parser.add_argument("--hpsv3_save_dir", type=str, required=True)

    # dataset & dataloader
    parser.add_argument("--data_json_path", type=str, required=True)
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=10,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=16,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--num_latent_t",
        type=int,
        default=1,
        help="number of latent frames",
    )

    # text encoder & vae & diffusion model
    parser.add_argument("--pretrained_model_name_or_path", type=str)
    parser.add_argument("--dit_model_name_or_path", type=str, default=None)
    parser.add_argument("--vae_model_path", type=str, default=None, help="vae model.")
    parser.add_argument("--cache_dir", type=str, default="./cache_dir")

    # diffusion setting
    parser.add_argument("--ema_decay", type=float, default=0.995)
    parser.add_argument("--ema_start_step", type=int, default=0)
    parser.add_argument("--cfg", type=float, default=0.0)

    # validation & logs
    parser.add_argument(
        "--seed", type=int, default=None, help="A seed for reproducible training."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )

    # optimizer & scheduler & Training
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=10,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--max_grad_norm", default=2.0, type=float, help="Max gradient norm."
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument("--selective_checkpointing", type=float, default=1.0)
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--use_cpu_offload",
        action="store_true",
        help="Whether to use CPU offload for param & gradient & optimizer states.",
    )

    parser.add_argument("--sp_size", type=int, default=1, help="For sequence parallel")
    parser.add_argument(
        "--train_sp_batch_size",
        type=int,
        default=1,
        help="Batch size for sequence parallel training",
    )

    # The flag name keeps its historical spelling so existing launch scripts still work.
    parser.add_argument("--fsdp_sharding_startegy", default="full")

    # lr_scheduler
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant_with_warmup",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of cycles in the learning rate scheduler.",
    )
    parser.add_argument(
        "--lr_power",
        type=float,
        default=1.0,
        help="Power factor of the polynomial scheduler.",
    )
    parser.add_argument(
        "--weight_decay", type=float, default=0.01, help="Weight decay to apply."
    )
    parser.add_argument(
        "--master_weight_type",
        type=str,
        default="fp32",
        help="Weight type to use - fp32 or bf16.",
    )

    # GRPO training
    parser.add_argument(
        "--h",
        type=int,
        default=None,
        help="video height",
    )
    parser.add_argument(
        "--w",
        type=int,
        default=None,
        help="video width",
    )
    parser.add_argument(
        "--t",
        type=int,
        default=None,
        help="video length",
    )
    parser.add_argument(
        "--sampling_steps",
        type=int,
        default=None,
        help="sampling steps",
    )
    parser.add_argument(
        "--eta",
        type=float,
        default=None,
        help="noise eta",
    )
    parser.add_argument(
        "--sampler_seed",
        type=int,
        default=None,
        help="seed of sampler",
    )
    parser.add_argument(
        "--use_hpsv2",
        action="store_true",
        default=False,
        help="whether use hpsv2 as reward model",
    )
    parser.add_argument(
        "--use_hpsv3",
        action="store_true",
        default=False,
        help="whether use hpsv3 as reward model",
    )
    parser.add_argument(
        "--use_pickscore",
        action="store_true",
        default=False,
        help="whether use pickscore as reward model",
    )
    parser.add_argument(
        "--use_video_align",
        action="store_true",
        default=False,
        # Previously a copy-paste of the pickscore help text.
        help="whether use video align as reward model",
    )
    parser.add_argument(
        "--shift",
        type=float,
        default=1.0,
        help="shift for timestep scheduler",
    )
    parser.add_argument(
        "--adv_clip_max",
        type=float,
        default=5.0,
        help="clipping advantage",
    )
    parser.add_argument(
        "--use_ema",
        action="store_true",
        help="Enable Exponential Moving Average of model weights.",
    )
    parser.add_argument(
        "--wandb_name",
        type=str,
        default="flux_grpo",
    )

    parser.add_argument("--decay_factor_norm", type=float, default=1.0)
    parser.add_argument("--noise_type", type=str, default="plain")
    # The training step only implements these two advantage types.
    parser.add_argument("--advantage_type", type=str, default="group", choices=["rank", "group"])
    parser.add_argument("--save_images", action="store_true", help="Whether to save intermediate decoded images to disk.")
    parser.add_argument("--clip_adv", action="store_true")
    parser.add_argument("--M", type=int, default=None)
    parser.add_argument("--group_num", type=int, default=14)
    parser.add_argument("--wandb_project", type=str, default="flux")

    return parser


def parse_args(argv=None):
    """Parse ``argv`` (defaults to ``sys.argv[1:]``) and validate it.

    Options that default to ``None`` but are used unconditionally by the
    training loop are checked here, so a missing flag is reported immediately
    instead of surfacing as a ``TypeError`` after models have been loaded.

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        SystemExit: Via ``parser.error`` when a needed option is missing.
    """
    parser = build_arg_parser()
    args = parser.parse_args(argv)

    needed = ("max_train_steps", "h", "w", "sampling_steps", "eta")
    missing = [f"--{name}" for name in needed if getattr(args, name) is None]
    if missing:
        parser.error("the following arguments must be provided: " + ", ".join(missing))
    return args


if __name__ == "__main__":
    main(parse_args())