from collections import defaultdict
import contextlib
import os
import datetime
from concurrent import futures
import time
import json
import hashlib
from typing import Any, Dict, List, Tuple
import copy
from absl import app, flags
from ml_collections import config_flags
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data.distributed import DistributedSampler
import logging

# Setup logger
logger = logging.getLogger(__name__)
from diffusers import QwenImageEditPipeline, QwenImageTransformer2DModel
from diffusers.utils.torch_utils import is_compiled_module
from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit import calculate_shift, calculate_dimensions

from flow_grpo.fsdp_utils import (
    FSDPConfig,
    fsdp_wrapper,
    init_distributed,
    save_fsdp_checkpoint,
    register_optimizer_offload_hooks,
    offload_fsdp_model_to_cpu,
    load_fsdp_model_to_gpu
)
import numpy as np
import flow_grpo.prompts
import flow_grpo.rewards
from flow_grpo.stat_tracking import PerPromptStatTracker
from flow_grpo.diffusers_patch.qwenimage_edit_pipeline_with_logprob import pipeline_with_logprob
from flow_grpo.diffusers_patch.sd3_sde_with_logprob import sde_step_with_logprob
import torch
import wandb
import requests
from functools import partial
import tqdm
import tempfile
from PIL import Image
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict, PeftModel
import random
from torch.utils.data import Dataset, DataLoader, Sampler
from flow_grpo.ema import EMAModuleWrapper
from transformers import Qwen2_5_VLForConditionalGeneration
from signal_config import SignalConfig

# Rebind the name ``tqdm`` from the imported module to a pre-configured
# progress-bar factory. After this line ``tqdm(...)`` is directly callable and
# the ``tqdm`` module itself is no longer reachable under this name.
tqdm = partial(tqdm.tqdm, dynamic_ncols=True)

# absl flag registry plus the config-file flag named "config"
# (default file: config/base.py).
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file("config", "config/base.py", "Training configuration.")

def gather_tensor(tensor, world_size):
    """Collect ``tensor`` from every rank and concatenate the pieces along dim 0.

    With ``world_size == 1`` the input tensor is returned unchanged and no
    collective call is issued.
    """
    if world_size != 1:
        # all_gather fills one pre-allocated receive buffer per rank.
        buffers = [torch.zeros_like(tensor) for _ in range(world_size)]
        dist.all_gather(buffers, tensor)
        return torch.cat(buffers)
    return tensor


def _gather_all_objects(obj: Any, rank: int, world_size: int) -> List[Any]:
    """Gather a per-rank list from all ranks and flatten it into one list.

    When torch.distributed is unavailable or not initialised, ``obj`` is
    returned as-is. Otherwise every rank's contribution is collected with
    ``all_gather_object`` and the non-empty contributions are concatenated in
    rank order. ``rank`` is accepted but not used by this function.
    """
    distributed_ready = dist.is_available() and dist.is_initialized()
    if not distributed_ready:
        return obj

    per_rank: List[Any] = [None for _ in range(world_size)]
    dist.all_gather_object(per_rank, obj)
    # Falsy contributions (None / empty list) are skipped entirely.
    return [entry for chunk in per_rank if chunk for entry in chunk]

def set_seed(seed, device_specific=True):
    """Seed Python's ``random``, NumPy and torch RNGs with ``seed``.

    If ``device_specific`` is true and CUDA is available, the CUDA generators
    are additionally seeded with ``seed + rank`` when torch.distributed is
    initialised, or with ``seed`` otherwise.
    """
    import random

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not (device_specific and torch.cuda.is_available()):
        return

    # Offset by the process rank so each device draws a different stream.
    cuda_seed = seed + dist.get_rank() if dist.is_initialized() else seed
    torch.cuda.manual_seed_all(cuda_seed)

class GenevalPromptImageDataset(Dataset):
    """Prompt/image dataset read from ``<dataset>/<split>_metadata.jsonl``.

    Every non-blank line of the JSONL file is parsed as one JSON record. The
    code reads a ``prompt`` field from each record at construction time and an
    ``image`` field (a path joined onto the dataset directory) when an item is
    fetched.
    """

    def __init__(self, dataset, split='train'):
        """Load all metadata records for ``split`` from the ``dataset`` directory."""
        self.dataset = dataset
        self.file_path = os.path.join(dataset, f'{split}_metadata.jsonl')
        with open(self.file_path, 'r', encoding='utf-8') as f:
            # Skip blank / whitespace-only lines so stray empty lines in the
            # file do not abort loading with a JSONDecodeError.
            self.metadatas = [json.loads(line) for line in f if line.strip()]
        self.prompts = [item['prompt'] for item in self.metadatas]

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        """Return the prompt, its metadata record and the RGB image for ``idx``."""
        prompt = self.prompts[idx]
        metadata = self.metadatas[idx]
        # The 'image' field of the metadata is used as a path relative to the dataset directory.
        image_path = metadata['image']
        item = {
            "prompt": prompt,
            "metadata": metadata,
            "prompt_with_image_path": f"{prompt}_{image_path}",
        }
        # convert() returns a fully loaded copy, so the underlying file can be
        # closed immediately instead of lingering until garbage collection.
        with Image.open(os.path.join(self.dataset, image_path)) as raw_image:
            item["image"] = raw_image.convert('RGB')
        return item

    @staticmethod
    def collate_fn(examples):
        """Transpose a list of items into (prompts, metadatas, images, prompt_with_image_paths)."""
        prompts = [example["prompt"] for example in examples]
        metadatas = [example["metadata"] for example in examples]
        images = [example["image"] for example in examples]
        prompt_with_image_paths = [example["prompt_with_image_path"] for example in examples]
        return prompts, metadatas, images, prompt_with_image_paths


class InMemoryPromptImageDataset(Dataset):
    """Dataset backed by in-memory prompt+image samples broadcast from the signal service."""

    def __init__(self, samples: List[Any]):
        # A falsy argument (None / empty list) becomes an empty dataset.
        self.samples = samples or []

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return ``{"prompt", "metadata", "image"}`` for sample ``idx``.

        Dict samples supply ``prompt`` and ``metadata`` directly; any other
        sample is stringified into the prompt with empty metadata. ``image``
        is the RGB image loaded from ``metadata["ori_image"]``, or ``None``
        when that path is absent or cannot be loaded.
        """
        sample = self.samples[idx]
        if isinstance(sample, dict):
            prompt = sample.get("prompt")
            metadata = sample.get("metadata")
        else:
            prompt = str(sample)
            metadata = None
        if metadata is None:
            # A dict sample without a "metadata" entry used to crash on the
            # .get() below; treat it as empty metadata instead.
            metadata = {}

        # Take the image path from the metadata and load that image.
        image = None
        image_path = metadata.get("ori_image")
        if image_path:
            try:
                # convert() yields a loaded copy, so the file handle can be
                # released right away.
                with Image.open(image_path) as raw_image:
                    image = raw_image.convert('RGB')
            except Exception as e:
                # Best effort: a broken reference image must not abort the batch.
                logger.warning(f"Failed to load image {image_path}: {e}")
                image = None

        return {"prompt": prompt, "metadata": metadata, "image": image}

    @staticmethod
    def collate_fn(examples: List[Dict[str, Any]]):
        """Transpose a list of items into (prompts, metadatas, ref_images)."""
        prompts = [example["prompt"] for example in examples]
        metadatas = [example["metadata"] for example in examples]
        ref_images = [example["image"] for example in examples]
        return prompts, metadatas, ref_images

class DistributedKRepeatSampler(Sampler):
    """Endless batch sampler that repeats each drawn dataset index ``k`` times.

    Per iteration, ``m = num_replicas * batch_size // k`` indices are drawn
    from a permutation of the dataset, each is repeated ``k`` times, the
    repeated list is shuffled, and the result is cut into ``num_replicas``
    consecutive slices of ``batch_size``. The slice belonging to ``rank`` is
    yielded. The generator is seeded with ``seed + epoch``, so replicas built
    with the same seed and epoch compute the same global ordering.
    """

    def __init__(self, dataset, batch_size, k, num_replicas, rank, seed=0):
        self.dataset = dataset
        self.batch_size = batch_size      # samples per replica per batch
        self.k = k                        # repetitions of every drawn index
        self.num_replicas = num_replicas  # number of participating replicas
        self.rank = rank                  # which replica this sampler serves
        self.seed = seed                  # base seed shared by all replicas

        # Global batch size; it must split evenly into groups of k.
        self.total_samples = self.num_replicas * self.batch_size
        assert self.total_samples % self.k == 0, f"k can not divide n*b, k{k}-num_replicas{num_replicas}-batch_size{batch_size}"
        self.m = self.total_samples // self.k  # distinct indices per batch
        self.epoch = 0

    def __iter__(self):
        while True:
            # Fresh generator each round: identical (seed, epoch) gives an identical batch.
            rng = torch.Generator()
            rng.manual_seed(self.seed + self.epoch)

            chosen = torch.randperm(len(self.dataset), generator=rng)[:self.m].tolist()
            repeated = [sample_idx for sample_idx in chosen for _ in range(self.k)]

            order = torch.randperm(len(repeated), generator=rng).tolist()
            shuffled = [repeated[pos] for pos in order]

            # One contiguous slice per replica; hand back this replica's slice.
            per_replica = [
                shuffled[replica * self.batch_size:replica * self.batch_size + self.batch_size]
                for replica in range(self.num_replicas)
            ]
            yield per_replica[self.rank]

    def set_epoch(self, epoch):
        """Set the epoch that is added to the seed for subsequent batches."""
        self.epoch = epoch


def compute_text_embeddings(prompt, text_encoders, tokenizers, max_sequence_length, device):
    """Encode ``prompt`` without gradients and return embeddings moved to ``device``.

    Returns ``(prompt_embeds, pooled_prompt_embeds)``. The third value produced
    by ``encode_prompt`` (text ids) is moved to ``device`` as well but is not
    returned.

    NOTE(review): ``encode_prompt`` is not defined or imported in the visible
    part of this file -- confirm it is available at module scope.
    """
    with torch.no_grad():
        encoded = encode_prompt(text_encoders, tokenizers, prompt, max_sequence_length)
        prompt_embeds, pooled_prompt_embeds, text_ids = encoded
        prompt_embeds, pooled_prompt_embeds, text_ids = (
            prompt_embeds.to(device),
            pooled_prompt_embeds.to(device),
            text_ids.to(device),
        )
    return prompt_embeds, pooled_prompt_embeds

def calculate_zero_std_ratio(prompts, gathered_rewards):
    """
    Calculate the proportion of unique prompts whose rewards show no spread.

    Args:
        prompts: List of prompts, one per reward (repeated prompts form a group).
        gathered_rewards: Dictionary containing rewards, must include the key 'ori_avg'
            holding one reward per prompt entry (array or list).

    Returns:
        zero_std_ratio: Proportion of unique prompts whose rewards are all identical.
        mean_std: Mean of the per-prompt reward standard deviations.
        For empty input both values are 0.0.
    """
    prompt_array = np.asarray(prompts)
    rewards = np.asarray(gathered_rewards['ori_avg'])
    if prompt_array.size == 0:
        # Nothing to group; avoid producing NaN statistics.
        return 0.0, 0.0

    # Map every entry to the index of its unique prompt and count group sizes.
    _, inverse_indices, counts = np.unique(
        prompt_array,
        return_inverse=True,
        return_counts=True,
    )

    # Sorting by group index makes each prompt's rewards contiguous, so they
    # can be cut apart at the cumulative group sizes.
    grouped_rewards = rewards[np.argsort(inverse_indices, kind='stable')]
    reward_groups = np.split(grouped_rewards, np.cumsum(counts)[:-1])

    prompt_std_devs = np.array([np.std(group) for group in reward_groups])

    # Detect "no spread" via max == min rather than std == 0: the computed mean
    # of identical floats can differ from the value by rounding (e.g. three
    # rewards of 0.1), which makes the standard deviation tiny but non-zero.
    zero_std_count = sum(1 for group in reward_groups if group.max() == group.min())
    zero_std_ratio = zero_std_count / len(reward_groups)

    return zero_std_ratio, prompt_std_devs.mean()

def create_generator(prompts, base_seed):
    """Build one seeded ``torch.Generator`` per prompt.

    The seed for a prompt is ``base_seed`` plus the first four bytes of the
    prompt's SHA256 digest (big-endian), reduced modulo 2**31. Equal prompts
    therefore receive equal seeds for the same ``base_seed``.
    """
    def _seed_for(text):
        # SHA256 is stable across processes, unlike Python's built-in hash().
        digest = hashlib.sha256(text.encode()).digest()
        return (base_seed + int.from_bytes(digest[:4], 'big')) % (2**31)

    return [torch.Generator().manual_seed(_seed_for(text)) for text in prompts]


def _normalize_signal_samples(samples: List[Any]) -> List[Dict[str, Any]]:
    """Convert raw signal samples into ``{"prompt", "metadata"}`` dicts.

    Non-dict samples are stringified into the prompt with empty metadata.
    For dict samples the metadata is deep-copied (non-dict metadata is wrapped
    as ``{"metadata": value}``). A top-level ``image`` is placed at the front
    of ``metadata["images"]`` unless already present, and a top-level
    ``images`` list is used only when the metadata still has no images.
    """
    def _normalize_one(sample: Any) -> Dict[str, Any]:
        if not isinstance(sample, dict):
            return {"prompt": str(sample), "metadata": {}}

        metadata = copy.deepcopy(sample.get("metadata", {}))
        if not isinstance(metadata, dict):
            metadata = {"metadata": metadata}

        image = sample.get("image")
        if image:
            current = metadata.setdefault("images", [])
            if not isinstance(current, list):
                # An existing non-list value is replaced outright.
                metadata["images"] = [image]
            elif image not in current:
                current.insert(0, image)

        images = sample.get("images")
        if images and not metadata.get("images"):
            metadata["images"] = images

        return {"prompt": sample.get("prompt", ""), "metadata": metadata}

    return [_normalize_one(sample) for sample in samples]


def _ensure_rollout_id(metadata: Dict[str, Any], prompt: str, fallback_idx: int) -> str:
    """Return ``metadata["rollout_id"]``, creating and storing one if it is missing or falsy.

    A generated id is the first 10 hex characters of the SHA1 of
    ``"<text>_<fallback_idx>"``, where ``<text>`` is the first truthy value of
    ``metadata["improved_prompt"]``, ``metadata["original_prompt"]`` and
    ``prompt``. The generated id is written back into ``metadata``.
    """
    existing = metadata.get("rollout_id")
    if not existing:
        seed_text = metadata.get("improved_prompt") or metadata.get("original_prompt") or prompt
        existing = hashlib.sha1(f"{seed_text}_{fallback_idx}".encode("utf-8")).hexdigest()[:10]
        metadata["rollout_id"] = existing
    return existing


def _instruction_group_key(metadata: Dict[str, Any], fallback_idx: int) -> Tuple[str, str]:
    """Return the ``(data_source, sample_idx)`` pair, both as strings, used to group rollouts.

    The data source is the first truthy value of ``data_source``, ``source``
    and ``category`` in ``metadata`` (``"unknown"`` if none). The sample index
    is ``metadata["sample_idx"]``; when that is missing or ``None`` it falls
    back to ``metadata["idx"]`` and then ``fallback_idx``, and the chosen value
    is stored under ``"sample_idx"`` if that key is not already present.
    """
    source = "unknown"
    for field in ("data_source", "source", "category"):
        candidate = metadata.get(field)
        if candidate:
            source = candidate
            break

    index = metadata.get("sample_idx")
    if index is None:
        index = metadata.get("idx", fallback_idx)
        metadata.setdefault("sample_idx", index)

    return str(source), str(index)


def perform_prompt_selection(
    samples: List[Dict[str, Any]],
    pipeline,
    config,
    device,
    rank: int,
    local_rank: int,
    world_size: int,
    selection_reward_fn,
    autocast,
    global_step: int,
    is_distributed: bool,
    executor,
    images_save_dir: str,
):
    """Score candidate prompts with a small rollout and keep the best one per instruction.

    Every sample is rendered ``max(1, config.sample.num_image_per_prompt // 4)``
    times, the images are scored with ``selection_reward_fn`` (submitted to
    ``executor``), and the per-rollout mean of the ``"avg"`` score is
    aggregated across ranks. Samples are grouped by
    ``_instruction_group_key``; within each group the rollout with the highest
    mean score is selected (ties go to the candidate seen first).

    Args:
        samples: Candidate samples (dicts with ``prompt`` / ``metadata``). Their
            metadata is mutated in place: ``rollout_id`` and ``sample_idx`` are
            filled in when missing.
        pipeline: Pipeline passed to ``pipeline_with_logprob``; its transformer
            is switched to eval mode for the duration and restored afterwards.
        config: Training config; reads ``config.sample.*`` and ``config.resolution``.
        device: Accepted for interface compatibility; not used in this function.
        rank, local_rank, world_size: Distributed coordinates; ``rank`` gates
            logging, ``local_rank`` gates progress bars.
        selection_reward_fn: Reward callable; ``None`` disables selection.
        autocast: Context-manager factory wrapped around sampling.
        global_step: Current step, used for seeding, logging and saving.
        is_distributed: Whether to shard the selection set with a DistributedSampler.
        executor: Executor exposing ``submit`` that runs the reward function.
        images_save_dir: Directory handed to ``save_images_and_metadata``.

    Returns:
        ``(selected_samples, selection_reward_entries)``. ``selected_samples``
        are deep copies of the winning samples (or of all samples when
        selection is skipped or yields nothing). ``selection_reward_entries``
        is populated on rank 0 only and lists rollout id, mean score and
        metadata (without ``images``) for every scored rollout.
    """
    if not samples or selection_reward_fn is None:
        fallback_samples = [copy.deepcopy(sample) for sample in samples]
        return fallback_samples, []

    # Remember the training flag so it can be restored after sampling.
    transformer_training = pipeline.transformer.training
    pipeline.transformer.eval()

    selection_images_per_prompt = max(1, config.sample.num_image_per_prompt // 4)
    candidate_samples: Dict[str, Dict[str, Any]] = {}
    group_to_rollouts: Dict[Tuple[str, str], List[str]] = defaultdict(list)
    instruction_order: List[Tuple[str, str]] = []
    candidate_order: Dict[str, int] = {}

    # Index candidates by rollout id and by instruction group, keeping first-seen order.
    for idx, sample in enumerate(samples):
        metadata = sample.setdefault("metadata", {})
        rollout_id = _ensure_rollout_id(metadata, sample.get("prompt", ""), idx)
        group_key = _instruction_group_key(metadata, idx)
        if group_key not in group_to_rollouts:
            instruction_order.append(group_key)
        group_to_rollouts[group_key].append(rollout_id)
        candidate_samples.setdefault(rollout_id, sample)
        candidate_order.setdefault(rollout_id, len(candidate_order))

    # Each candidate is rendered selection_images_per_prompt times.
    selection_sample_list: List[Dict[str, Any]] = []
    for sample in samples:
        for _ in range(selection_images_per_prompt):
            selection_sample_list.append(copy.deepcopy(sample))

    if not selection_sample_list:
        fallback_samples = [copy.deepcopy(sample) for sample in samples]
        return fallback_samples, []

    if rank == 0:
        logger.info(
            "Step %s: Selection stage evaluating %s prompts (%s instructions) with %s images each",
            global_step,
            len(samples),
            len(instruction_order),
            selection_images_per_prompt,
        )

    selection_dataset = InMemoryPromptImageDataset(selection_sample_list)
    if is_distributed:
        selection_sampler: Sampler = DistributedSampler(
            selection_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
            drop_last=False,
        )
        selection_sampler.set_epoch(global_step)
    else:
        selection_sampler = None

    selection_dataloader = DataLoader(
        selection_dataset,
        batch_size=config.sample.train_batch_size,
        sampler=selection_sampler,
        shuffle=selection_sampler is None,
        collate_fn=InMemoryPromptImageDataset.collate_fn,
        num_workers=1,
        drop_last=False,
    )

    selection_stats_local: Dict[str, Dict[str, float]] = defaultdict(lambda: {"sum": 0.0, "count": 0.0})
    selection_futures: List[futures.Future] = []
    selection_batches: List[Tuple[torch.Tensor, List[str], List[Dict[str, Any]]]] = []

    # Phase 1: sample images and submit reward computation without waiting for it.
    with torch.no_grad():
        for i, (prompts, prompt_metadata, ref_images) in enumerate(
            tqdm(
                selection_dataloader,
                desc="Selection: sampling",
                disable=local_rank != 0,
                position=0,
            )
        ):
            ref_images = [img.resize((config.resolution, config.resolution)) if img else None for img in ref_images]
            if config.sample.same_latent:
                generator = create_generator(prompts, base_seed=global_step * 10000 + 100000 + i)
            else:
                generator = None

            with autocast():
                collected_data = pipeline_with_logprob(
                    pipeline,
                    ref_images,
                    prompts,
                    negative_prompt=[" "] * len(prompts),
                    num_inference_steps=config.sample.num_steps,
                    true_cfg_scale=config.sample.guidance_scale,
                    output_type="pt",
                    height=config.resolution,
                    width=config.resolution,
                    noise_level=config.sample.noise_level,
                    generator=generator,
                    sde_window_size=config.sample.sde_window_size,
                    sde_window_range=config.sample.sde_window_range,
                )

            images = collected_data["images"]
            images_cpu = images.detach().cpu()
            future = executor.submit(
                selection_reward_fn,
                images_cpu,
                prompts,
                prompt_metadata,
                ref_images,
                only_strict=True,
            )
            # Yield so the executor's worker gets a chance to start.
            time.sleep(0)
            selection_futures.append(future)
            selection_batches.append((images_cpu, prompts, prompt_metadata))

    selection_all_images: List[torch.Tensor] = []
    selection_all_prompts: List[str] = []
    selection_all_metadatas: List[Dict[str, Any]] = []
    selection_rewards_local: Dict[str, List[np.ndarray]] = defaultdict(list)

    # Phase 2: collect rewards; a failed batch is logged and skipped.
    batch_iter = zip(selection_batches, selection_futures)
    for (images_cpu, prompts, prompt_metadata), reward_future in tqdm(
        batch_iter,
        desc="Selection: scoring",
        disable=local_rank != 0,
        position=0,
    ):
        try:
            score_dict, _ = reward_future.result()
        except Exception as exc:  # noqa: BLE001
            if rank == 0:
                logger.error("Selection reward computation failed: %s", exc)
            continue

        selection_all_images.append(images_cpu)
        selection_all_prompts.extend(prompts)
        selection_all_metadatas.extend(prompt_metadata)

        for key, value in score_dict.items():
            if isinstance(value, torch.Tensor):
                arr = value.detach().cpu().numpy()
            else:
                arr = np.asarray(value)
            selection_rewards_local[key].append(arr)

        # Compare against None explicitly: truth-testing a multi-element
        # tensor/ndarray (as ``... or []`` would) raises instead of falling back.
        batch_scores = score_dict.get("avg")
        if batch_scores is None:
            batch_scores = []
        elif isinstance(batch_scores, torch.Tensor):
            batch_scores = batch_scores.detach().cpu().tolist()
        elif isinstance(batch_scores, np.ndarray):
            batch_scores = batch_scores.tolist()

        for sample_idx, metadata in enumerate(prompt_metadata):
            rollout_id = metadata.get("rollout_id")
            if rollout_id is None or sample_idx >= len(batch_scores):
                continue
            score_val = float(batch_scores[sample_idx])
            stats = selection_stats_local[rollout_id]
            stats["sum"] += score_val
            stats["count"] += 1.0

    if selection_all_images:
        selection_images_tensor = torch.cat(selection_all_images, dim=0)
        selection_rewards_np = {key: np.concatenate(value) for key, value in selection_rewards_local.items() if value}
        save_images_and_metadata(
            selection_images_tensor,
            selection_all_prompts,
            selection_all_metadatas,
            selection_rewards_np,
            images_save_dir,
            global_step,
            split='selection',
            rank=rank,
            world_size=world_size,
        )

    if transformer_training:
        pipeline.transformer.train()

    # Share per-rollout (sum, count) pairs so all ranks compute the same means.
    selection_stats_payload = [
        {"rollout_id": rollout_id, "sum": stats["sum"], "count": stats["count"]}
        for rollout_id, stats in selection_stats_local.items()
    ]
    aggregated_selection_stats = _gather_all_objects(selection_stats_payload, rank, world_size)

    prompt_mean_scores: Dict[str, float] = {}
    aggregated_map: Dict[str, Dict[str, float]] = defaultdict(lambda: {"sum": 0.0, "count": 0.0})
    for item in aggregated_selection_stats:
        rollout_id = item.get("rollout_id")
        if not rollout_id:
            continue
        aggregated_entry = aggregated_map[rollout_id]
        aggregated_entry["sum"] += float(item.get("sum", 0.0))
        aggregated_entry["count"] += float(item.get("count", 0.0))

    for rollout_id, stats in aggregated_map.items():
        count = stats.get("count", 0.0)
        if count <= 0:
            continue
        prompt_mean_scores[rollout_id] = stats["sum"] / count

    # Only rank 0 builds the report entries.
    selection_reward_entries: List[Dict[str, Any]] = []
    if rank == 0:
        for rollout_id, score in prompt_mean_scores.items():
            sample = candidate_samples.get(rollout_id)
            if sample is None:
                continue
            meta = copy.deepcopy(sample.get("metadata", {}))
            meta.pop("images", None)
            meta["selection_images_per_prompt"] = selection_images_per_prompt
            selection_reward_entries.append(
                {
                    "rollout_id": rollout_id,
                    "score": float(score),
                    "metadata": meta,
                }
            )

    # Candidates without any score rank below every scored candidate.
    candidate_mean_scores: Dict[str, float] = {}
    for rollout_id in candidate_samples.keys():
        if rollout_id in prompt_mean_scores:
            candidate_mean_scores[rollout_id] = prompt_mean_scores[rollout_id]
        else:
            candidate_mean_scores[rollout_id] = float("-inf")

    selected_rollout_ids: List[str] = []
    for group_key in instruction_order:
        candidate_ids = group_to_rollouts.get(group_key, [])
        if not candidate_ids:
            continue
        # Highest mean score wins; the negated order breaks ties in favour of
        # the earliest candidate.
        best_rollout = max(
            candidate_ids,
            key=lambda rid: (
                candidate_mean_scores.get(rid, float("-inf")),
                -candidate_order.get(rid, 0),
            ),
        )
        selected_rollout_ids.append(best_rollout)

    selected_samples = [copy.deepcopy(candidate_samples[rid]) for rid in selected_rollout_ids if rid in candidate_samples]
    if not selected_samples:
        selected_samples = [copy.deepcopy(sample) for sample in samples]
        if rank == 0:
            logger.warning("Selection stage failed to pick candidates; falling back to all prompts.")
    elif rank == 0:
        logger.info(
            "Step %s: Selected %s prompts for unified reward stage", global_step, len(selected_samples)
        )

    return selected_samples, selection_reward_entries

        
def compute_log_prob(transformer, pipeline, sample, j, config, rank):
    """Re-evaluate the log-probability of a stored denoising step under ``transformer``.

    Runs the transformer once on a doubled batch (prompt half, then
    negative-prompt half) for step ``j`` of ``sample``, combines the two
    predictions with ``config.sample.guidance_scale``, rescales the result to
    the norm of the prompt-conditioned prediction, and passes it to
    ``sde_step_with_logprob`` together with the stored ``next_latents``.

    Side effect: the four prompt embedding/mask entries of ``sample`` are
    truncated in place to the longest sequence length found in either mask.
    ``rank`` is accepted but not used here.

    Returns:
        ``(prev_sample, log_prob, prev_sample_mean, std_dev_t)`` exactly as
        returned by ``sde_step_with_logprob``.
    """
    calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, 1)
    latents_j = sample["latents"][:, j]
    latent_side = config.resolution // pipeline.vae_scale_factor // 2
    # One [generated-latent shape, second shape from calculate_dimensions] entry per batch element.
    img_shapes = [
        [
            (1, latent_side, latent_side),
            (1, calculated_height // pipeline.vae_scale_factor // 2, calculated_width // pipeline.vae_scale_factor // 2),
        ]
    ] * len(latents_j)

    txt_seq_lens = sample["prompt_embeds_mask"].sum(dim=1).tolist()
    negative_txt_seq_lens = sample["negative_prompt_embeds_mask"].sum(dim=1).tolist()
    all_seq_lens = txt_seq_lens + negative_txt_seq_lens

    # The embeddings and masks may carry more padding than needed; cut them
    # down to the longest actual sequence length.
    longest = max(all_seq_lens)
    for field in ("prompt_embeds_mask", "negative_prompt_embeds_mask", "prompt_embeds", "negative_prompt_embeds"):
        sample[field] = sample[field][:, :longest]

    latent_model_input = latents_j
    if sample["image_latents"] is not None:
        latent_model_input = torch.cat([latent_model_input, sample["image_latents"]], dim=1)

    timesteps_j = sample["timesteps"][:, j]
    # Predict the noise residual for both halves in a single forward pass.
    model_out = transformer(
        hidden_states=torch.cat([latent_model_input, latent_model_input], dim=0),
        timestep=torch.cat([timesteps_j, timesteps_j], dim=0) / 1000,
        guidance=None,
        encoder_hidden_states_mask=torch.cat([sample["prompt_embeds_mask"], sample["negative_prompt_embeds_mask"]], dim=0),
        encoder_hidden_states=torch.cat([sample["prompt_embeds"], sample["negative_prompt_embeds"]], dim=0),
        img_shapes=img_shapes * 2,
        txt_seq_lens=all_seq_lens,
    )[0]
    cond_pred, neg_pred = model_out.chunk(2, dim=0)

    # Keep only the positions that belong to the generated latents.
    keep = latents_j.size(1)
    cond_pred = cond_pred[:, :keep]
    neg_pred = neg_pred[:, :keep]
    comb_pred = neg_pred + config.sample.guidance_scale * (cond_pred - neg_pred)

    # Rescale the guided prediction to the norm of the conditional one.
    cond_norm = torch.norm(cond_pred, dim=-1, keepdim=True)
    noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
    noise_pred = comb_pred * (cond_norm / noise_norm)

    # Log prob of next_latents given latents under the current model.
    prev_sample, log_prob, prev_sample_mean, std_dev_t = sde_step_with_logprob(
        pipeline.scheduler,
        noise_pred.float(),
        timesteps_j,
        latents_j.float(),
        prev_sample=sample["next_latents"][:, j].float(),
        noise_level=config.sample.noise_level,
    )

    return prev_sample, log_prob, prev_sample_mean, std_dev_t

def eval(pipeline, test_dataloader, config, rank, local_rank, world_size, device, global_step, reward_fn, executor, autocast, ema, transformer_trainable_parameters, save_dir):
    """Run evaluation over ``test_dataloader``, save/log the results and return per-rollout scores.

    NOTE: this function shadows the builtin ``eval`` within this module.

    For every batch, each prompt is rendered ``config.sample.eval_num_image_per_prompt``
    times (at least once; default 1) with ``noise_level=0`` and
    ``sde_window_size=0``, and the images are scored with ``reward_fn`` via
    ``executor``. Local results are written with ``save_images_and_metadata``;
    gathered rewards and up to 15 images from the last batch are logged to
    wandb on rank 0.

    When ``config.train.ema`` is set, the EMA weights are copied into
    ``transformer_trainable_parameters`` for the evaluation and the original
    weights are restored afterwards, even if evaluation raises.

    Returns:
        ``(eval_reward_entries, reward_summary)``. ``eval_reward_entries`` holds
        one dict per ``rollout_id`` with the mean of the submit metric
        (``config.train.submit_metric``, default ``"avg"``; first available
        metric as fallback) over that rollout's images. ``reward_summary`` maps
        each reward key to its mean over this rank's samples. If the dataloader
        yields no batches, ``([], {})`` is returned without saving or logging.
    """
    if config.train.ema:
        ema.copy_ema_to(transformer_trainable_parameters, store_temp=True)

    try:
        # Separate containers for local saving and global logging
        all_rewards_gather = defaultdict(list)  # for wandb/logging
        all_rewards_local = defaultdict(list)   # for saving (no collectives)
        all_images = []
        all_prompts = []
        all_metadatas = []

        # Number of images generated per sample during evaluation (defaults to 1).
        try:
            eval_num_images = getattr(config.sample, "eval_num_image_per_prompt", 1)
        except Exception:
            eval_num_images = 1
        eval_num_images = int(max(1, eval_num_images))

        last_used_prompts = None  # prompts of the last batch, for caption alignment in logs

        for test_batch in tqdm(
                test_dataloader,
                desc="Eval: ",
                disable=local_rank != 0,
                position=0,
            ):
            prompts, prompt_metadata, ref_images = test_batch
            ref_images = [ref_image.resize((config.resolution, config.resolution)) if ref_image else None for ref_image in ref_images]

            # Expand the batch so every sample appears eval_num_images times.
            if eval_num_images > 1:
                batch_prompts = [p for p in prompts for _ in range(eval_num_images)]
                batch_metadatas = [m for m in prompt_metadata for _ in range(eval_num_images)]
                batch_ref_images = [img for img in ref_images for _ in range(eval_num_images)]
            else:
                batch_prompts = prompts
                batch_metadatas = prompt_metadata
                batch_ref_images = ref_images

            with autocast():
                with torch.no_grad():
                    collected_data = pipeline_with_logprob(
                            pipeline,
                            batch_ref_images,
                            batch_prompts,
                            negative_prompt=[" "]*len(batch_prompts),
                            num_inference_steps=config.sample.eval_num_steps,
                            true_cfg_scale=config.sample.guidance_scale,
                            output_type="pt",
                            height=config.resolution,
                            width=config.resolution,
                            noise_level=0,
                            sde_window_size=0,
                    )
            images = collected_data["images"]
            # Rewards are computed on the expanded prompts and metadatas.
            rewards = executor.submit(reward_fn, images, batch_prompts, batch_metadatas, batch_ref_images, only_strict=False)
            time.sleep(0)
            rewards, reward_metadata = rewards.result()

            # Accumulate local data for saving
            all_images.append(images)
            all_prompts.extend(batch_prompts)
            all_metadatas.extend(batch_metadatas)
            for key, value in rewards.items():
                all_rewards_local[key].append(np.asarray(value))

            # Accumulate gathered rewards for logging
            for key, value in rewards.items():
                rewards_gather = gather_tensor(torch.as_tensor(value, device=device), world_size).cpu().float().numpy()
                all_rewards_gather[key].append(rewards_gather)

            # Remember the (expanded) prompts of the most recent batch for log captions.
            last_used_prompts = batch_prompts

        if not all_images:
            # No batch was processed, so there is no "last batch" to gather or
            # log; bail out instead of failing on undefined loop variables.
            if rank == 0:
                logger.warning("Step %s: eval dataloader produced no batches; skipping evaluation.", global_step)
            return [], {}

        # Concatenate all images and local rewards, then save locally
        all_images_local = torch.cat(all_images, dim=0)
        all_rewards_local = {key: np.concatenate(value) for key, value in all_rewards_local.items()}
        save_images_and_metadata(
            all_images_local,
            all_prompts,
            all_metadatas,
            all_rewards_local,
            save_dir,
            global_step,
            split='eval',
            rank=rank,
            world_size=world_size
        )

        # Existing logging (uses gathered rewards)
        last_batch_images_gather = gather_tensor(torch.as_tensor(images, device=device), world_size).cpu().float().numpy()
        # Captions are built from the expanded prompts so counts match the images.
        last_batch_prompt_ids = pipeline.tokenizer(
            last_used_prompts,
            padding="max_length",
            max_length=256,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(device)
        # Keep token ids as integers: decoding expects ids, not floats.
        last_batch_prompt_ids_gather = gather_tensor(last_batch_prompt_ids, world_size).cpu().numpy()
        last_batch_prompts_gather = pipeline.tokenizer.batch_decode(
            last_batch_prompt_ids_gather, skip_special_tokens=True
        )
        last_batch_rewards_gather = {}
        for key, value in rewards.items():
            last_batch_rewards_gather[key] = gather_tensor(torch.as_tensor(value, device=device), world_size).cpu().float().numpy()

        all_rewards = {key: np.concatenate(value) for key, value in all_rewards_gather.items()}
        if rank == 0:
            with tempfile.TemporaryDirectory() as tmpdir:
                num_samples = min(15, len(last_batch_images_gather))
                sample_indices = range(num_samples)
                for idx, index in enumerate(sample_indices):
                    image = last_batch_images_gather[index]
                    pil = Image.fromarray(
                        (image.transpose(1, 2, 0) * 255).astype(np.uint8)
                    )
                    pil = pil.resize((config.resolution, config.resolution))
                    pil.save(os.path.join(tmpdir, f"{idx}.jpg"))
                sampled_prompts = [last_batch_prompts_gather[index] for index in sample_indices]
                sampled_rewards = [{k: last_batch_rewards_gather[k][index] for k in last_batch_rewards_gather} for index in sample_indices]
                for key, value in all_rewards.items():
                    print(key, value.shape)
                # Reward values equal to -10 are excluded from captions and means.
                wandb.log(
                    {
                        "eval_images": [
                            wandb.Image(
                                os.path.join(tmpdir, f"{idx}.jpg"),
                                caption=f"{prompt:.1000} | " + " | ".join(f"{k}: {v:.2f}" for k, v in reward.items() if v != -10),
                            )
                            for idx, (prompt, reward) in enumerate(zip(sampled_prompts, sampled_rewards))
                        ],
                        **{f"eval_reward_{key}": np.mean(value[value != -10]) for key, value in all_rewards.items()},
                    },
                    step=global_step,
                )

        eval_reward_entries: List[Dict[str, Any]] = []
        # Metric used for submission: config.train.submit_metric if set, else "avg".
        submit_metric = getattr(config.train, "submit_metric", "avg")
        avg_scores = None
        if all_rewards_local:
            avg_scores = all_rewards_local.get(submit_metric)
            if avg_scores is None:
                first_key = next(iter(all_rewards_local.keys()))
                avg_scores = all_rewards_local[first_key]

        # Aggregate by rollout_id (mean over the images of each sample) before submitting.
        if avg_scores is not None:
            avg_array = np.asarray(avg_scores)
            rollout_to_scores: Dict[str, List[float]] = defaultdict(list)
            rollout_to_meta: Dict[str, Dict[str, Any]] = {}

            for meta, score in zip(all_metadatas, avg_array.tolist()):
                if isinstance(meta, dict):
                    rollout_id = meta.get("rollout_id") or meta.get("metadata", {}).get("rollout_id")
                    payload_meta = copy.deepcopy(meta)
                else:
                    rollout_id = None
                    payload_meta = {"metadata": meta}
                if rollout_id is None:
                    continue
                payload_meta.pop("images", None)
                rollout_to_scores[rollout_id].append(float(score))
                if rollout_id not in rollout_to_meta:
                    rollout_to_meta[rollout_id] = payload_meta

            for rollout_id, values in rollout_to_scores.items():
                eval_reward_entries.append(
                    {
                        "rollout_id": rollout_id,
                        "score": float(np.mean(values)),  # mean over this sample's images
                        "metadata": rollout_to_meta[rollout_id],
                    }
                )

        reward_summary = {
            key: float(np.mean(np.asarray(values))) for key, values in all_rewards_local.items()
        } if all_rewards_local else {}

        return eval_reward_entries, reward_summary
    finally:
        # Always swap the original (non-EMA) weights back, including on errors.
        if config.train.ema:
            ema.copy_temp_to(transformer_trainable_parameters)


def validate(
    global_step: int,
    signal_base_url: str,
    signal_timeout: float,
    signal_config: SignalConfig,
    rank: int,
    world_size: int,
    local_rank: int,
    pipeline,
    config,
    device,
    eval_reward_fn,
    executor,
    autocast,
    ema,
    transformer_trainable_parameters,
    images_save_dir: str,
    is_distributed: bool,
):
    """Run one validation round: fetch eval data, evaluate, and submit rewards.

    Steps, in order:
      1. Fetch the eval payload for ``global_step`` from the signal service
         (``_fetch_signal_payload`` does the request on rank 0 and broadcasts).
      2. Return early when the payload carries no samples.
      3. Move the pipeline to ``device``, build a dataloader over the samples
         (sharded with ``DistributedSampler`` when ``is_distributed``) and call
         ``eval``.
      4. Move the pipeline back to CPU, gather per-rank entries and reward
         summaries, and let rank 0 submit the gathered entries.

    Returns:
        None. Results are only logged and submitted to the signal service.
    """
    on_main_rank = rank == 0

    data_key = signal_config.get_key("eval", "data", global_step)
    payload = _fetch_signal_payload(signal_base_url, data_key, signal_timeout, rank)

    # Warn (rank 0 only) when the service answered for a different step.
    reported_step = payload.get("global_step")
    step_mismatch = reported_step is not None and reported_step != global_step
    if step_mismatch and on_main_rank:
        logger.warning(
            "Eval payload step mismatch: expected %s, received %s", global_step, reported_step
        )

    samples = payload.get("eval_data") or payload.get("data") or []
    if not samples:
        if on_main_rank:
            logger.info("Step %s received empty eval payload; skipping validation.", global_step)
        return

    # Only after data has arrived is the model moved onto the accelerator.
    dtype_by_precision = {"fp16": torch.float16, "bf16": torch.bfloat16}
    inference_dtype = dtype_by_precision.get(config.mixed_precision, torch.float32)
    if on_main_rank:
        logger.info(f"Step {global_step}: Moving pipeline to {device} for validation")
    _move_pipeline_to(device, pipeline, inference_dtype, None)
    if is_distributed:
        dist.barrier()

    test_dataset = InMemoryPromptImageDataset(samples)
    eval_sampler = (
        DistributedSampler(
            test_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=False,
            drop_last=False,
        )
        if is_distributed
        else None
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=config.sample.test_batch_size,
        sampler=eval_sampler,
        shuffle=False,
        collate_fn=InMemoryPromptImageDataset.collate_fn,
        num_workers=1,
        drop_last=False,
    )

    if on_main_rank:
        logger.info(f"Step {global_step}: Running evaluation on {len(test_dataloader)} batches")
    eval_entries_local, eval_reward_summary_local = eval(
        pipeline, test_dataloader, config, rank, local_rank, world_size, device,
        global_step, eval_reward_fn, executor, autocast, ema,
        transformer_trainable_parameters, images_save_dir
    )

    # Move the model back to CPU before the rewards are submitted.
    if on_main_rank:
        logger.info(f"Step {global_step}: Moving pipeline to CPU after evaluation")
    _move_pipeline_to(torch.device("cpu"), pipeline, inference_dtype, None)
    if is_distributed:
        dist.barrier()

    eval_entries_global = _gather_all_objects(eval_entries_local, rank, world_size)
    local_summaries = [eval_reward_summary_local] if eval_reward_summary_local else []
    summary_list = _gather_all_objects(local_summaries, rank, world_size)

    # Average each metric over the per-rank summaries (unweighted mean of means).
    eval_reward_mean: Dict[str, float] = {}
    if summary_list:
        per_metric: Dict[str, List[float]] = defaultdict(list)
        for rank_summary in summary_list:
            for metric, value in rank_summary.items():
                per_metric[metric].append(value)
        eval_reward_mean = {
            metric: float(np.mean(collected)) for metric, collected in per_metric.items()
        }
    if not eval_reward_mean and eval_entries_global:
        # No summaries available: fall back to the mean of the entry scores.
        entry_scores = [entry.get("score", 0.0) for entry in eval_entries_global]
        eval_reward_mean = {"score": float(np.mean(entry_scores))}

    if not (on_main_rank and eval_entries_global):
        return

    reward_key = signal_config.get_key("eval", "reward", global_step)
    logger.info(f"Step {global_step}: Submitting {len(eval_entries_global)} eval rewards with key: {reward_key}")
    logger.info(f"Step {global_step}: Eval reward mean: {eval_reward_mean}")
    eval_submission_payload = {
        "global_step": int(global_step),
        "mode": "eval",
        "rewards": eval_entries_global,
    }
    _submit_signal_payload(signal_base_url, reward_key, eval_submission_payload, rank, signal_timeout)


def _gather_all_objects(obj: Any, rank: int, world_size: int) -> List[Any]:
    if not dist.is_available() or not dist.is_initialized():
        return obj
    gathered: List[Any] = [None] * world_size
    dist.all_gather_object(gathered, obj)
    combined: List[Any] = []
    for part in gathered:
        if part:
            combined.extend(part)
    return combined

def _gather_all_objects(obj: Any, rank: int, world_size: int) -> List[Any]:
    if not dist.is_available() or not dist.is_initialized():
        return obj
    gathered: List[Any] = [None] * world_size
    dist.all_gather_object(gathered, obj)
    combined: List[Any] = []
    for part in gathered:
        if part:
            combined.extend(part)
    return combined

def _broadcast_object(obj: Any, rank: int, src: int = 0) -> Any:
    if not dist.is_available() or not dist.is_initialized():
        return obj
    payload: List[Any] = [obj if rank == src else None]
    dist.broadcast_object_list(payload, src=src)
    return payload[0]


def _fetch_signal_payload(base_url: str, key: str, timeout: float, rank: int) -> Dict[str, Any]:
    """Fetch the payload stored under ``key`` and share it with all ranks.

    Rank 0 issues ``GET {base_url}/wait/{key}`` (passing ``timeout`` as a
    query parameter and allowing 5 extra seconds on the client side), takes
    the ``"data"`` field of the JSON response, and the result is distributed
    through ``_broadcast_object``.

    Any failure (request error, HTTP error status, malformed JSON, or a
    ``"data"`` field that is not a JSON object) is logged and replaced by an
    empty dict, so callers can always use ``.get`` on the result.

    Returns:
        The payload dict, or ``{}`` when nothing usable was received.
    """
    payload: Dict[str, Any] = {}
    if rank == 0:
        try:
            print(f"Fetching signal payload for key: {key} from {base_url}")
            response = requests.get(
                f"{base_url}/wait/{key}",
                params={"timeout": timeout},
                timeout=timeout + 5,
            )
            response.raise_for_status()
            data = response.json().get("data", {}) or {}
            if not isinstance(data, dict):
                # Callers index the payload with ``.get``; a list or scalar
                # here would crash every rank after the broadcast.
                raise TypeError(
                    f"expected a JSON object under 'data', got {type(data).__name__}"
                )
            payload = data
        except Exception as exc:  # noqa: BLE001
            logger.error("Failed to fetch signal payload for key %s: %s", key, exc)
            payload = {}

        print(f"Get payload for key: {key}")
    return _broadcast_object(payload, rank)


def _submit_signal_payload(base_url: str, key: str, payload: Dict[str, Any], rank: int, timeout: float) -> None:
    if rank != 0:
        return
    try:
        print(f"Submitting signal payload for key: {key} to {base_url}")
        requests.post(
            f"{base_url}/submit/{key}",
            json=payload,
            timeout=timeout,
        )
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to submit payload for key %s: %s", key, exc)


def _move_pipeline_to(device: torch.device, pipeline, inference_dtype: torch.dtype, optimizer=None) -> None:
    """Move the pipeline's VAE, text encoder and transformer to ``device``.

    The VAE and text encoder are moved with ``.to(device)``. The transformer
    goes through ``load_fsdp_model_to_gpu`` for CUDA targets and
    ``offload_fsdp_model_to_cpu`` otherwise. After a non-CUDA move the CUDA
    caching allocator is emptied.

    ``inference_dtype`` and ``optimizer`` are accepted but not used here.
    """
    target_is_cuda = device.type == "cuda"

    # These two moves are the same for both target kinds.
    pipeline.vae = pipeline.vae.to(device)
    pipeline.text_encoder = pipeline.text_encoder.to(device)

    if target_is_cuda:
        load_fsdp_model_to_gpu(pipeline.transformer, device)
    else:
        offload_fsdp_model_to_cpu(pipeline.transformer)

    print(f"Move pipeline to {device}, with dtype={next(pipeline.transformer.parameters()).dtype}")

    if not target_is_cuda:
        torch.cuda.empty_cache()
        

def _lookup_meta_field(meta, keys, default=None):
    """Return the first usable value for ``keys`` in ``meta`` or its nested ``"metadata"`` dict.

    Every key is tried on the top-level dict before the nested one. Only
    ``None`` and the empty string count as missing, so legitimate falsy
    values such as an index of ``0`` are kept.
    """
    nested = meta.get("metadata")
    scopes = (meta, nested) if isinstance(nested, dict) else (meta,)
    for scope in scopes:
        for key in keys:
            value = scope.get(key)
            if value is None or (isinstance(value, str) and not value):
                continue
            return value
    return default


def _image_array_to_uint8(img):
    """Convert one image array with values in [0, 1] to uint8, or return ``None`` if unsupported.

    2-D arrays are treated as grayscale. 3-D arrays whose first axis has
    size 1, 3 or 4 are treated as channel-first and transposed to
    channel-last; other 3-D arrays are used as they are. A single-channel
    result is squeezed to 2-D because ``Image.fromarray`` rejects (H, W, 1).
    """
    if img.ndim == 2:
        pixels = img
    elif img.ndim == 3:
        pixels = img.transpose(1, 2, 0) if img.shape[0] in (1, 3, 4) else img
        if pixels.shape[-1] == 1:
            pixels = pixels[..., 0]
    else:
        return None
    return (np.clip(pixels, 0.0, 1.0) * 255).astype(np.uint8)


def save_images_and_metadata(images, prompts, metadatas, rewards, save_dir, global_step, split='train', rank=0, world_size=1):
    """
    Save images and metadata with structure:
      - metadata: {save_dir}/{split}_{global_step}/{data_source}_{sample_idx}/{rollout_id}.json
      - images: {save_dir}/{split}_{global_step}/{data_source}_{sample_idx}/{rollout_id}_images/
    Multiple images can belong to the same rollout_id, and multiple rollout_ids can belong to the same {data_source}_{sample_idx}.

    Args:
        images: ``torch.Tensor`` or array-like; indexed along the first axis,
            one image per entry, with values expected in [0, 1].
        prompts: Sequence of prompts aligned with ``images``.
        metadatas: Sequence aligned with ``images``. Dict entries may carry
            ``rollout_id``, ``data_source`` and ``sample_idx``/``idx`` either
            at the top level or under a nested ``"metadata"`` dict.
        rewards: Mapping of reward name to a per-image sequence of values.
        save_dir: Root output directory.
        global_step: Step number used in the step directory name.
        split: Prefix of the step directory name.
        rank: Process rank; embedded in image file names and the JSON.
        world_size: Unused; kept for interface compatibility.

    Images lacking a ``rollout_id`` each get a random 8-character id. Images
    with an unsupported number of dimensions are skipped together with their
    rewards. The prompt and metadata of the first image in a group are
    stored as that group's representative.
    """
    import uuid

    step_dir = os.path.join(save_dir, f"{split}_{global_step}")
    os.makedirs(step_dir, exist_ok=True)

    # Normalize tensors/arrays
    if isinstance(images, torch.Tensor):
        images_np = images.detach().cpu().float().numpy()
    else:
        images_np = np.asarray(images)
    rewards_np = {name: np.asarray(values) for name, values in rewards.items()}

    # Group image indices by (data_source, sample_idx, rollout_id)
    groups = defaultdict(list)
    for i in range(len(images_np)):
        meta = metadatas[i] if i < len(metadatas) else {}
        if isinstance(meta, dict):
            rollout_id = _lookup_meta_field(meta, ("rollout_id",))
            data_source = _lookup_meta_field(meta, ("data_source",), "unknown")
            sample_idx = _lookup_meta_field(meta, ("sample_idx", "idx"), i)
        else:
            rollout_id, data_source, sample_idx = None, "unknown", i

        # Generate rollout_id if not present
        if rollout_id is None:
            rollout_id = str(uuid.uuid4())[:8]

        groups[(str(data_source), str(sample_idx), str(rollout_id))].append(i)

    for (data_source, sample_idx, rollout_id), indices in groups.items():
        # Folder name: {data_source}_{sample_idx}
        safe_data_source = data_source.replace("/", "_").replace(" ", "_")
        sample_dir = os.path.join(step_dir, f"{safe_data_source}_{sample_idx}")
        os.makedirs(sample_dir, exist_ok=True)

        # Images directory: {rollout_id}_images
        images_dir = os.path.join(sample_dir, f"{rollout_id}_images")
        os.makedirs(images_dir, exist_ok=True)

        rep_idx = indices[0]
        rep_meta = metadatas[rep_idx] if rep_idx < len(metadatas) else {}
        rep_prompt = prompts[rep_idx] if rep_idx < len(prompts) else ""

        image_paths = []
        all_rewards_for_group = defaultdict(list)
        for seq, i in enumerate(indices):
            img_arr = _image_array_to_uint8(images_np[i])
            if img_arr is None:
                # Unsupported dimensionality: skip the image and its rewards
                # so image_paths and rewards stay aligned.
                continue

            pil_image = Image.fromarray(img_arr)
            if pil_image.mode not in ("L", "RGB"):
                # JPEG cannot store an alpha channel (e.g. RGBA input).
                pil_image = pil_image.convert("RGB")

            image_filename = f"image_{seq}_rank{rank}.jpg"
            pil_image.save(os.path.join(images_dir, image_filename))
            image_paths.append(f"{rollout_id}_images/{image_filename}")

            for name, values in rewards_np.items():
                all_rewards_for_group[name].append(float(values[i]))

        # One metadata JSON per rollout_id, listing all of its images
        metadata_obj = {
            "prompt": rep_prompt,
            "rollout_id": rollout_id,
            "data_source": data_source,
            "sample_idx": sample_idx,
            "rank": int(rank),
            "global_step": int(global_step),
            "split": split,
            "num_images": len(image_paths),
            "image_paths": image_paths,
            "rewards": dict(all_rewards_for_group),
            "original_metadata": rep_meta,
        }

        # NOTE(review): the JSON name carries no rank, so ranks holding images
        # of the same rollout_id would overwrite each other's file - confirm
        # whether that can happen with the samplers in use.
        json_path = os.path.join(sample_dir, f"{rollout_id}.json")
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(metadata_obj, f, indent=2, ensure_ascii=False)


def get_transformer_layer_cls():
    """Return the set of transformer block classes used when wrapping models.

    The classes are imported lazily inside the function. The result contains
    the QwenImage transformer block plus the Qwen2.5-VL vision block and
    decoder layer.
    """
    from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformerBlock
    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
        Qwen2_5_VLDecoderLayer,
        Qwen2_5_VLVisionBlock,
    )

    # Deliberately not included (disabled in the original list):
    # QwenImageResidualBlock, QwenImageResample, QwenImageMidBlock,
    # QwenImageAttentionBlock.
    layer_classes = {
        QwenImageTransformerBlock,
        Qwen2_5_VLVisionBlock,
        Qwen2_5_VLDecoderLayer,
    }
    return layer_classes




def main(_):
    # basic Accelerate and logging setup
    config = FLAGS.config

    # Initialize distributed training
    is_distributed, rank, world_size, local_rank = init_distributed()
    device = torch.device(f'cuda:{local_rank}') if torch.cuda.is_available() else torch.device('cpu')
    
    # number of timesteps within each trajectory to train on
    if config.sample.sde_window_size > 0:
        num_train_timesteps = config.sample.sde_window_size
    else:
        num_train_timesteps = config.sample.num_steps - 1

    # Create project directory
    project_dir = os.path.join(config.logdir, config.run_name)
    os.makedirs(project_dir, exist_ok=True)
    
    # Create save directory for images
    images_save_dir = os.path.join(project_dir, "saved_images")
    os.makedirs(images_save_dir, exist_ok=True)
    
    if rank == 0:
        wandb.init(
            project="flow_grpo",
            # mode="disabled"
        )
    logger.info(f"\n{config}")

    # set seed (device_specific is very important to get different prompts on different devices)
    set_seed(config.seed, device_specific=True)

    # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora transformer) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    inference_dtype = torch.float32
    if config.mixed_precision == "fp16":
        inference_dtype = torch.float16
    elif config.mixed_precision == "bf16":
        inference_dtype = torch.bfloat16

    # load scheduler, tokenizer and models.
    pipeline = QwenImageEditPipeline.from_pretrained(
        config.pretrained.model,
        torch_dtype=inference_dtype,
        low_cpu_mem_usage=True,
    )
    
    # Switch Text Encoder
    text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        config.pretrained.te_model,
        torch_dtype=inference_dtype,
        low_cpu_mem_usage=True
    )
    old_te = pipeline.text_encoder
    pipeline.text_encoder = text_encoder
    del old_te
    
    # freeze parameters of models to save more memory
    pipeline.vae.requires_grad_(False)
    pipeline.text_encoder.requires_grad_(False)
    pipeline.transformer.requires_grad_(not config.use_lora)
    # disable safety checker
    pipeline.safety_checker = None
    # make the progress bar nicer
    pipeline.set_progress_bar_config(
        position=1,
        disable=local_rank != 0,
        leave=False,
        desc="Timestep",
        dynamic_ncols=True,
    )

    if config.use_lora:
        # Set correct lora layers
        target_modules = [
            "attn.to_k",
            "attn.to_q",
            "attn.to_v",
            "attn.to_out.0",
            "attn.add_k_proj",
            "attn.add_q_proj",
            "attn.add_v_proj",
            "attn.to_add_out",
            "img_mlp.net.0.proj",
            "img_mlp.net.2",
            "txt_mlp.net.0.proj",
            "txt_mlp.net.2",
        ]
        transformer_lora_config = LoraConfig(
            r=64,
            lora_alpha=128,
            init_lora_weights="gaussian",
            target_modules=target_modules,
        )
        if config.train.lora_path:
            pipeline.transformer = PeftModel.from_pretrained(pipeline.transformer, config.train.lora_path)
            # After loading with PeftModel.from_pretrained, all parameters have requires_grad set to False. You need to call set_adapter to enable gradients for the adapter parameters.
            pipeline.transformer.set_adapter("default")
        else:
            pipeline.transformer = get_peft_model(pipeline.transformer, transformer_lora_config)
    
    transformer = pipeline.transformer

    # Setup FSDP configuration
    fsdp_config = FSDPConfig(
        sharding_strategy="FULL_SHARD",
        backward_prefetch="BACKWARD_PRE",
        cpu_offload=False,  # Set to True if memory is limited
        num_replicate=1,
        num_shard=world_size,
        mixed_precision_dtype=inference_dtype,
        use_activation_checkpointing=config.activation_checkpointing,
        use_device_mesh=False, 
    )
    # Wrap language model with FSDP
    transformer.cpu().to(dtype=torch.float32)
    transformer = fsdp_wrapper(transformer, fsdp_config, get_transformer_layer_cls)
    pipeline.transformer = transformer

    if config.train.beta > 0:
        transformer_ref = QwenImageTransformer2DModel.from_pretrained(
            config.pretrained.model,
            subfolder="transformer",
            torch_dtype=inference_dtype
        )
        transformer_ref.eval()
        transformer_ref.requires_grad_(False)
        transformer_ref.cpu().to(dtype=torch.float32)
        transformer_ref = fsdp_wrapper(transformer_ref, fsdp_config, get_transformer_layer_cls)
    
    # pipeline.text_encoder.cpu().to(dtype=torch.float32)
    # pipeline.text_encoder = fsdp_wrapper(pipeline.text_encoder, fsdp_config, get_transformer_layer_cls)

    transformer_trainable_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
    # This ema setting affects the previous 20 × 8 = 160 steps on average.
    # ema = EMAModuleWrapper(transformer_trainable_parameters, decay=0.9, update_step_interval=8, device=device)
    ema = None
    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if config.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # Initialize the optimizer
    if config.train.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        transformer_trainable_parameters,
        lr=config.train.learning_rate,
        betas=(config.train.adam_beta1, config.train.adam_beta2),
        weight_decay=config.train.adam_weight_decay,
        eps=config.train.adam_epsilon,
    )

    if config.fsdp_optimizer_offload:
        register_optimizer_offload_hooks(optimizer)
    
    train_dataset = None
    test_dataset = None
    train_dataloader = None
    test_dataloader = None

    if config.sample.num_image_per_prompt == 1:
        config.per_prompt_stat_tracking = False
    # initialize stat tracker
    if config.per_prompt_stat_tracking:
        stat_tracker = PerPromptStatTracker(config.sample.global_std)

    # for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
    # more memory
    if config.mixed_precision == "fp16":
        autocast = lambda: torch.amp.autocast('cuda', dtype=torch.float16)
    elif config.mixed_precision == "bf16":
        autocast = lambda: torch.amp.autocast('cuda', dtype=torch.bfloat16)
    else:
        autocast = contextlib.nullcontext

    # FSDP doesn't need deepspeed configuration
    # prepare prompt and reward fn
    reward_fn = getattr(flow_grpo.rewards, 'multi_score')(device, config.reward_fn)
    eval_reward_fn = getattr(flow_grpo.rewards, 'multi_score')(device, config.reward_fn)
    selection_reward_config = getattr(config, "selection_reward_fn", None)
    selection_reward_fn = getattr(flow_grpo.rewards, 'multi_score')(device, selection_reward_config)
    
    # FSDP setup completed above
    # executor to perform callbacks asynchronously. this is beneficial for the llava callbacks which makes a request to a
    # remote server running llava inference.
    executor = futures.ThreadPoolExecutor(max_workers=4)

    try:
        signal_config = SignalConfig()
    except ValueError as exc:
        raise ValueError("Signal base URL is required. Set TRAIN_SIGNAL_URL or config.signal.url.") from exc

    signal_base_url = signal_config.url
    signal_timeout = signal_config.timeout

    cpu_device = torch.device("cpu")
    
    global_step = 0
    
    vlm_train_step = os.getenv("VLM_TRAIN_STEP", None)
    dit_train_step = os.getenv("DIT_TRAIN_STEP", None)
    
    frozen = False
    
    if not config.train.get("enable", True):
        frozen = True

    # 在训练循环开始前进行一次验证
    if config.eval_freq > 0:
        validate(
            global_step=0,
            signal_base_url=signal_base_url,
            signal_timeout=signal_timeout,
            signal_config=signal_config,
            rank=rank,
            world_size=world_size,
            local_rank=local_rank,
            pipeline=pipeline,
            config=config,
            device=device,
            eval_reward_fn=eval_reward_fn,
            executor=executor,
            autocast=autocast,
            ema=ema,
            transformer_trainable_parameters=transformer_trainable_parameters,
            images_save_dir=images_save_dir,
            is_distributed=is_distributed,
        )

    
    while global_step < 1000:
        
        global_step += 1
        
        pipeline.transformer.eval()

        train_key = signal_config.get_key("train", "data", global_step)
        if rank == 0:
            logger.info(f"Step {global_step}: Fetching training data with key: {train_key}")
        train_payload = _fetch_signal_payload(signal_base_url, train_key, signal_timeout, rank)
        raw_train_samples = train_payload.get("train_data") or train_payload.get("data") or []
        if rank == 0:
            logger.info(f"Step {global_step}: Received {len(raw_train_samples)} train samples")
        train_samples = _normalize_signal_samples(raw_train_samples)
        if rank == 0:
            logger.info(f"Step {global_step}: Normalized {len(train_samples)} prompts before selection")
        
        if not train_samples:
            if rank == 0:
                logger.warning("Step %s received empty train payload; skipping.", global_step)
            if is_distributed:
                dist.barrier()
            continue

        payload_train_step = train_payload.get("global_step")
        if payload_train_step is not None and payload_train_step != global_step and rank == 0:
            logger.warning(
                "Train payload step mismatch: expected %s, received %s", global_step, payload_train_step
            )

        # 获取训练数据后将模型移到GPU
        if rank == 0:
            logger.info(f"Step {global_step}: Moving pipeline to {device}")
        _move_pipeline_to(device, pipeline, inference_dtype, optimizer)
        if is_distributed:
            dist.barrier()

        selected_samples, selection_reward_entries = perform_prompt_selection(
            train_samples,
            pipeline,
            config,
            device,
            rank,
            local_rank,
            world_size,
            selection_reward_fn,
            autocast,
            global_step,
            is_distributed,
            executor,
            images_save_dir,
        )
        reward_entries_local_for_signal = selection_reward_entries
        
        global_reward_entries = _gather_all_objects(reward_entries_local_for_signal, rank, world_size)
        
        if not frozen:
            if not selected_samples:
                selected_samples = [copy.deepcopy(sample) for sample in train_samples]
            if rank == 0:
                logger.info(
                    "Step %s: Proceeding with %s prompts for unified reward training",
                    global_step,
                    len(selected_samples),
                )
            train_samples = [
                copy.deepcopy(sample)
                for sample in selected_samples
                for _ in range(config.sample.num_image_per_prompt)
            ]
            if rank == 0:
                logger.info(
                    "Step %s: Expanded selected prompts to %s samples (num_image_per_prompt=%s)",
                    global_step,
                    len(train_samples),
                    config.sample.num_image_per_prompt,
                )

            train_dataset = InMemoryPromptImageDataset(train_samples)
            if is_distributed:
                train_sampler = DistributedSampler(
                    train_dataset,
                    num_replicas=world_size,
                    rank=rank,
                    shuffle=True,
                    drop_last=False,
                )
            else:
                train_sampler = None

            train_dataloader = DataLoader(
                train_dataset,
                batch_size=config.sample.train_batch_size,
                sampler=train_sampler,
                shuffle=train_sampler is None,
                collate_fn=InMemoryPromptImageDataset.collate_fn,
                num_workers=1,
                drop_last=False,
            )

            if train_sampler is not None:
                train_sampler.set_epoch(global_step)

            num_batches_per_epoch = len(train_dataloader)
            samples_per_epoch = len(train_dataset)
            total_train_batch_size = (
                config.train.batch_size * world_size * config.train.gradient_accumulation_steps
            )

            if rank == 0:
                logger.info("***** Running global step %s *****", global_step)
                logger.info("  Train samples per step: %s", samples_per_epoch)
                logger.info("  Train batches per step: %s", num_batches_per_epoch)
                logger.info("  Train batch size per device = %s", config.train.batch_size)
                logger.info("  Sample batch size per device = %s", config.sample.train_batch_size)
                logger.info("  Gradient Accumulation steps = %s", config.train.gradient_accumulation_steps)

            #################### SAMPLING ####################
            if rank == 0:
                logger.info(f"Step {global_step}: Starting sampling phase")
            pipeline.transformer.eval()

            gradient_context = torch.no_grad() if frozen else contextlib.nullcontext()
            samples = []
            all_train_images = []
            all_train_prompts = []
            all_train_metadatas = []
            all_train_rewards = defaultdict(list)

            with gradient_context:
                for i, (prompts, prompt_metadata, ref_images) in enumerate(
                    tqdm(
                        train_dataloader,
                        desc=f"Step {global_step}: sampling",
                        disable=local_rank != 0,
                        position=0,
                    )
                ):
                    ref_images = [
                        ref_image.resize((config.resolution, config.resolution)) if ref_image else None
                        for ref_image in ref_images
                    ]
                    if config.sample.same_latent:
                        generator = create_generator(prompts, base_seed=global_step * 10000 + i)
                    else:
                        generator = None
                    with autocast():
                        with torch.no_grad():
                            collected_data = pipeline_with_logprob(
                                pipeline,
                                ref_images,
                                prompts,
                                negative_prompt=[" "] * len(prompts),
                                num_inference_steps=config.sample.num_steps,
                                true_cfg_scale=config.sample.guidance_scale,
                                output_type="pt",
                                height=config.resolution,
                                width=config.resolution,
                                noise_level=config.sample.noise_level,
                                generator=generator,
                                sde_window_size=config.sample.sde_window_size,
                                sde_window_range=config.sample.sde_window_range,
                            )

                    images = collected_data["images"]
                    total_images = images.shape[0]

                    prompt_ids = pipeline.tokenizer(
                        prompts,
                        padding="max_length",
                        max_length=256,
                        truncation=True,
                        return_tensors="pt",
                    ).input_ids.to(device)

                    latents = torch.stack(collected_data["all_latents"], dim=1)
                    log_probs = torch.stack(collected_data["all_log_probs"], dim=1)
                    timesteps = (
                        torch.stack(collected_data["all_timesteps"]).unsqueeze(0).repeat(total_images, 1)
                    )
                    rewards = executor.submit(
                        reward_fn,
                        images,
                        prompts,
                        prompt_metadata,
                        ref_images,
                        only_strict=True,
                    )
                    time.sleep(0)

                    all_train_images.append(images)
                    all_train_prompts.extend(prompts)
                    all_train_metadatas.extend(prompt_metadata)

                    if rank == 0 and i == 0:
                        logger.info(
                            f"Step {global_step}: First batch - generated {images.shape[0]} images"
                        )

                    samples.append(
                        {
                            "prompt_ids": prompt_ids,
                            "prompt_embeds": collected_data["prompt_embeds"],
                            "prompt_embeds_mask": collected_data["prompt_embeds_mask"],
                            "negative_prompt_embeds": collected_data["negative_prompt_embeds"],
                            "negative_prompt_embeds_mask": collected_data["negative_prompt_embeds_mask"],
                            "image_latents": collected_data.get("image_latents"),
                            "timesteps": timesteps,
                            "latents": latents[:, :-1],
                            "next_latents": latents[:, 1:],
                            "log_probs": log_probs,
                            "rewards": rewards,
                        }
                    )

            max_prompt_embeds_len = max(
                [sample["prompt_embeds_mask"].shape[1] for sample in samples]
            )

            if rank == 0:
                logger.info(
                    f"Step {global_step}: Waiting for {len(samples)} batch rewards to be computed"
                )
            for sample in tqdm(
                samples,
                desc="Waiting for rewards",
                disable=local_rank != 0,
                position=0,
            ):
                seq_pad_len = max_prompt_embeds_len - sample["prompt_embeds"].shape[1]
                sample["prompt_embeds"] = torch.nn.functional.pad(
                    sample["prompt_embeds"],
                    (0, 0, 0, seq_pad_len),
                    value=0,
                )
                sample["prompt_embeds_mask"] = torch.nn.functional.pad(
                    sample["prompt_embeds_mask"],
                    (0, seq_pad_len),
                    value=0,
                )
                sample["negative_prompt_embeds"] = torch.nn.functional.pad(
                    sample["negative_prompt_embeds"],
                    (0, 0, 0, seq_pad_len),
                    value=0,
                )
                sample["negative_prompt_embeds_mask"] = torch.nn.functional.pad(
                    sample["negative_prompt_embeds_mask"],
                    (0, seq_pad_len),
                    value=0,
                )

                rewards, reward_metadata = sample["rewards"].result()
                sample["rewards"] = {
                    key: torch.as_tensor(value, device=device).float()
                    for key, value in rewards.items()
                }
                for key, value in rewards.items():
                    all_train_rewards[key].extend(value)

            all_train_images_tensor = (
                torch.cat(all_train_images, dim=0) if len(all_train_images) > 0 else torch.empty(0)
            )
            all_train_rewards_np = {
                key: np.asarray(all_train_rewards[key]) for key in all_train_rewards.keys()
            }
            save_images_and_metadata(
                all_train_images_tensor,
                all_train_prompts,
                all_train_metadatas,
                all_train_rewards_np,
                images_save_dir,
                global_step,
                split='train',
                rank=rank,
                world_size=world_size,
            )

            samples = {
                k: torch.cat([s[k] for s in samples], dim=0)
                if not isinstance(samples[0][k], dict)
                else {
                    sub_key: torch.cat([s[k][sub_key] for s in samples], dim=0)
                    for sub_key in samples[0][k]
                }
                for k in samples[0].keys()
            }

            samples["rewards"]["ori_avg"] = samples["rewards"]["avg"]
            samples["rewards"]["avg"] = samples["rewards"]["avg"].unsqueeze(1).repeat(
                1, num_train_timesteps
            )
            if rank == 0:
                logger.info(
                    f"Step {global_step}: Gathering rewards across {world_size} processes"
                )
            gathered_rewards = {
                key: gather_tensor(value, world_size) for key, value in samples["rewards"].items()
            }
            gathered_rewards = {
                key: value.cpu().float().numpy() for key, value in gathered_rewards.items()
            }
            if rank == 0:
                logger.info(
                    "Step %s: Reward stats - %s",
                    global_step,
                    ", ".join(
                        [
                            f"{k}: {v.mean():.4f}"
                            for k, v in gathered_rewards.items()
                            if "_strict_accuracy" not in k and "_accuracy" not in k
                        ]
                    ),
                )
                wandb.log(
                    {
                        "global_step": global_step,
                        **{
                            f"reward_{key}": value.mean()
                            for key, value in gathered_rewards.items()
                            if "_strict_accuracy" not in key and "_accuracy" not in key
                        },
                    },
                    step=global_step,
                )

            if config.per_prompt_stat_tracking:
                prompt_ids = gather_tensor(samples["prompt_ids"], world_size).cpu().float().numpy()
                prompts = pipeline.tokenizer.batch_decode(
                    prompt_ids, skip_special_tokens=True
                )
                advantages = stat_tracker.update(prompts, gathered_rewards["avg"])
                group_size, trained_prompt_num = stat_tracker.get_stats()
                zero_std_ratio, reward_std_mean = calculate_zero_std_ratio(
                    prompts, gathered_rewards
                )
                if rank == 0:
                    wandb.log(
                        {
                            "group_size": group_size,
                            "trained_prompt_num": trained_prompt_num,
                            "zero_std_ratio": zero_std_ratio,
                            "reward_std_mean": reward_std_mean,
                        },
                        step=global_step,
                    )
                stat_tracker.clear()
            else:
                advantages = (
                    gathered_rewards["avg"] - gathered_rewards["avg"].mean()
                ) / (gathered_rewards["avg"].std() + 1e-4)

            advantages = torch.as_tensor(advantages)
            samples["advantages"] = (
                advantages.reshape(world_size, -1, advantages.shape[-1])[rank].to(device)
            )

            del samples["rewards"]
            del samples["prompt_ids"]

            total_batch_size, num_timesteps = samples["timesteps"].shape
            gradient_accumulation_steps = (
                config.train.gradient_accumulation_steps * num_train_timesteps
            )

            #################### TRAINING ####################
            if rank == 0:
                logger.info(
                    f"Step {global_step}: Starting training phase with {config.train.num_inner_epochs} inner epochs"
                )
            for inner_epoch in range(config.train.num_inner_epochs):
                if rank == 0:
                    logger.info(f"Step {global_step}.{inner_epoch}: Starting inner epoch")
                samples_batched = {
                    k: v.reshape(-1, config.sample.train_batch_size, *v.shape[1:])
                    for k, v in samples.items()
                }
                samples_batched = [
                    dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())
                ]

                pipeline.transformer.train()
                info = defaultdict(list)
                for i, sample in tqdm(
                    list(enumerate(samples_batched)),
                    desc=f"Step {global_step}.{inner_epoch}: training",
                    position=0,
                    disable=local_rank != 0,
                ):
                    train_timesteps = [step_index for step_index in range(num_train_timesteps)]
                    for j in tqdm(
                        train_timesteps,
                        desc="Timestep",
                        position=1,
                        leave=False,
                        disable=local_rank != 0,
                    ):
                        if (i * num_train_timesteps + j + 1) % gradient_accumulation_steps == 0:
                            should_sync = True
                        else:
                            should_sync = False

                        with autocast():
                            prev_sample, log_prob, prev_sample_mean, std_dev_t = compute_log_prob(
                                transformer, pipeline, sample, j, config, rank
                            )
                            if config.train.beta > 0:
                                with torch.no_grad():
                                    _, _, prev_sample_mean_ref, _ = compute_log_prob(
                                        transformer_ref, pipeline, sample, j, config, rank
                                    )
                        advantages = torch.clamp(
                            sample["advantages"][:, j],
                            -config.train.adv_clip_max,
                            config.train.adv_clip_max,
                        )
                        ratio = torch.exp(log_prob - sample["log_probs"][:, j])
                        unclipped_loss = -advantages * ratio
                        clipped_loss = -advantages * torch.clamp(
                            ratio,
                            1.0 - config.train.clip_range,
                            1.0 + config.train.clip_range,
                        )
                        policy_loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss))
                        policy_loss = policy_loss / gradient_accumulation_steps
                        if config.train.beta > 0:
                            kl_loss = (
                                (prev_sample_mean - prev_sample_mean_ref) ** 2
                            ).mean(dim=(1, 2), keepdim=True) / (2 * std_dev_t ** 2)
                            kl_loss = torch.mean(kl_loss)
                            kl_loss = kl_loss / gradient_accumulation_steps
                            loss = policy_loss + config.train.beta * kl_loss
                        else:
                            loss = policy_loss

                        info["approx_kl"].append(
                            0.5 * torch.mean((log_prob - sample["log_probs"][:, j]) ** 2)
                        )
                        info["clipfrac"].append(
                            torch.mean(
                                (
                                    torch.abs(ratio - 1.0) > config.train.clip_range
                                ).float()
                            )
                        )
                        info["clipfrac_gt_one"].append(
                            torch.mean(
                                (
                                    ratio - 1.0 > config.train.clip_range
                                ).float()
                            )
                        )
                        info["clipfrac_lt_one"].append(
                            torch.mean(
                                (
                                    1.0 - ratio > config.train.clip_range
                                ).float()
                            )
                        )
                        info["policy_loss"].append(policy_loss)
                        if config.train.beta > 0:
                            info["kl_loss"].append(kl_loss)

                        info["loss"].append(loss)

                        loss.backward()
                        if should_sync:
                            torch.nn.utils.clip_grad_norm_(
                                transformer.parameters(), config.train.max_grad_norm
                            )
                            optimizer.step()
                            optimizer.zero_grad()

                    if config.train.ema:
                        ema.step(transformer_trainable_parameters, global_step)

                info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                if is_distributed:
                    for k, v in info.items():
                        dist.all_reduce(v, op=dist.ReduceOp.SUM)
                        info[k] = v / world_size
                info.update({"global_step": global_step, "inner_epoch": inner_epoch})
                if rank == 0:
                    logger.info(
                        f"Step {global_step}.{inner_epoch}: Training metrics - loss: {info.get('loss', 0):.4f}, policy_loss: {info.get('policy_loss', 0):.4f}, approx_kl: {info.get('approx_kl', 0):.4f}"
                    )
                    wandb.log(info, step=global_step)

        # Save the model checkpoint
        if global_step % config.save_freq == 0:
            if rank == 0:
                logger.info(f"Step {global_step}: Saving FSDP checkpoint to {config.save_dir}")
            save_fsdp_checkpoint(config.save_dir, transformer, global_step, rank)
                
        reward_key = signal_config.get_key("train", "reward", global_step)
        reward_submission_payload = None

        if rank == 0:
            aggregated_rewards: Dict[str, Dict[str, Any]] = {}
            for entry in global_reward_entries:
                aggregated_rewards[entry["rollout_id"]] = entry

            reward_results: List[Dict[str, Any]] = []
            for rollout_id, info in aggregated_rewards.items():
                mean_score = info.get("score", 0.0)
                payload_entry: Dict[str, Any] = {"rollout_id": rollout_id, "score": mean_score}
                if info.get("metadata"):
                    payload_entry["metadata"] = info["metadata"]
                reward_results.append(payload_entry)

            reward_submission_payload = {
                "global_step": int(global_step),
                "mode": "train",
                "rewards": reward_results,
            }

        # Move the model to the CPU before submitting rewards
        if rank == 0:
            logger.info(f"Step {global_step}: Moving pipeline to CPU for checkpoint/eval")
        _move_pipeline_to(cpu_device, pipeline, inference_dtype, optimizer)
        if is_distributed:
            dist.barrier()

        # Submit rewards to the signal service
        if rank == 0:
            logger.info(f"Step {global_step}: Submitting rewards to signal service with key: {reward_key}")
        if reward_submission_payload is not None:
            _submit_signal_payload(signal_base_url, reward_key, reward_submission_payload, rank, signal_timeout)

        # Run validation after training
        if config.eval_freq > 0 and global_step % config.eval_freq == 0:
            if rank == 0:
                logger.info(f"Step {global_step}: Starting validation")
            validate(
                global_step=global_step,
                signal_base_url=signal_base_url,
                signal_timeout=signal_timeout,
                signal_config=signal_config,
                rank=rank,
                world_size=world_size,
                local_rank=local_rank,
                pipeline=pipeline,
                config=config,
                device=device,
                eval_reward_fn=eval_reward_fn,
                executor=executor,
                autocast=autocast,
                ema=ema,
                transformer_trainable_parameters=transformer_trainable_parameters,
                images_save_dir=images_save_dir,
                is_distributed=is_distributed,
            )

        
if __name__ == "__main__":
    # Script entry point: hand control to absl, which parses the command-line
    # flags (including the `config` flag registered at the top of this file)
    # before invoking `main`.
    app.run(main)

