from collections import defaultdict
import contextlib
import os
import datetime
from concurrent import futures
import time
import json
import hashlib
import math
from typing import Any, Dict, List, Tuple
import copy
from absl import app, flags
from ml_collections import config_flags
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data.distributed import DistributedSampler
import logging

# Module-level logger used for the error, warning and info messages below.
logger = logging.getLogger(__name__)
from diffusers import DiffusionPipeline, QwenImageTransformer2DModel
from diffusers.utils.torch_utils import is_compiled_module

from flow_grpo.fsdp_utils import (
    FSDPConfig,
    fsdp_wrapper,
    init_distributed,
    save_fsdp_checkpoint,
    register_optimizer_offload_hooks,
    offload_fsdp_model_to_cpu,
    load_fsdp_model_to_gpu
)
import numpy as np
import flow_grpo.prompts
import flow_grpo.rewards
from flow_grpo.stat_tracking import PerPromptStatTracker
from flow_grpo.diffusers_patch.qwenimage_pipeline_with_logprob import pipeline_with_logprob
from flow_grpo.diffusers_patch.sd3_sde_with_logprob import sde_step_with_logprob
import torch
import wandb
import requests
from functools import partial
import tqdm
import tempfile
from PIL import Image
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict, PeftModel
import random
from torch.utils.data import Dataset, DataLoader, Sampler
from flow_grpo.ema import EMAModuleWrapper
from transformers import Qwen2_5_VLForConditionalGeneration
from signal_config import SignalConfig

# ``tqdm`` is rebound from the module to a progress-bar factory whose bars
# resize with the terminal (dynamic_ncols); later code calls ``tqdm(...)`` directly.
tqdm = partial(tqdm.tqdm, dynamic_ncols=True)

# absl flags; ``--config`` points at an ml_collections config file
# (default "config/base.py").
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file("config", "config/base.py", "Training configuration.")

def gather_tensor(tensor, world_size):
    """Concatenate ``tensor`` from every rank along dim 0.

    With ``world_size == 1`` the input tensor itself is returned. Otherwise this
    is a collective (``dist.all_gather``): every rank must call it, and each
    receive buffer is shaped like the local ``tensor``.
    """
    if world_size != 1:
        buffers = [torch.zeros_like(tensor) for _ in range(world_size)]
        dist.all_gather(buffers, tensor)
        return torch.cat(buffers)
    return tensor


def _gather_all_objects(obj: Any, rank: int, world_size: int) -> List[Any]:
    if not dist.is_available() or not dist.is_initialized():
        return obj
    gathered: List[Any] = [None] * world_size
    dist.all_gather_object(gathered, obj)
    combined: List[Any] = []
    for part in gathered:
        if part:
            combined.extend(part)
    return combined

def set_seed(seed, device_specific=True):
    """Seed Python's ``random``, NumPy and PyTorch with ``seed``.

    When ``device_specific`` is true and CUDA is available, all CUDA generators
    are then re-seeded with ``seed + rank``. The rank comes from the initialized
    process group; the offset is 0 when no group is initialized.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not (device_specific and torch.cuda.is_available()):
        return
    rank_offset = dist.get_rank() if dist.is_initialized() else 0
    torch.cuda.manual_seed_all(seed + rank_offset)

class TextPromptDataset(Dataset):
    """Prompts stored one per line in ``{dataset}/{split}.txt``.

    Items are ``{"prompt": str, "metadata": {}}``; ``collate_fn`` turns a list
    of items into ``(prompts, metadatas)``.
    """

    def __init__(self, dataset, split='train'):
        self.file_path = os.path.join(dataset, f'{split}.txt')
        with open(self.file_path, 'r', encoding='utf-8') as f:
            self.prompts = [line.strip() for line in f]
        # Qwen-Image is large, so for speed only the first 512 pickscore test
        # prompts are used. ``dataset`` is a directory path (it is joined with
        # the file name above), so compare its final component; comparing the
        # whole string only matched a bare relative "pickscore".
        dataset_name = os.path.basename(os.path.normpath(dataset))
        if split == 'test' and dataset_name == 'pickscore':
            self.prompts = self.prompts[:512]

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return {"prompt": self.prompts[idx], "metadata": {}}

    @staticmethod
    def collate_fn(examples):
        """Split a list of items into parallel ``(prompts, metadatas)`` lists."""
        prompts = [example["prompt"] for example in examples]
        metadatas = [example["metadata"] for example in examples]
        return prompts, metadatas

class GenevalPromptDataset(Dataset):
    """Prompts with per-item metadata read from ``{dataset}/{split}_metadata.jsonl``.

    Each non-blank line is a JSON object that must contain a ``"prompt"`` key.
    The whole object is kept as that item's metadata.
    """

    def __init__(self, dataset, split='train'):
        self.file_path = os.path.join(dataset, f'{split}_metadata.jsonl')
        with open(self.file_path, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing empty line), which json.loads rejects.
            self.metadatas = [json.loads(line) for line in f if line.strip()]
        self.prompts = [item['prompt'] for item in self.metadatas]

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return {"prompt": self.prompts[idx], "metadata": self.metadatas[idx]}

    @staticmethod
    def collate_fn(examples):
        """Split a list of items into parallel ``(prompts, metadatas)`` lists."""
        prompts = [example["prompt"] for example in examples]
        metadatas = [example["metadata"] for example in examples]
        return prompts, metadatas


class InMemoryPromptDataset(Dataset):
    """Map-style dataset over prompt samples held in memory (e.g. broadcast from the signal service).

    A sample may be a dict with optional ``"prompt"`` (default ``""``) and
    ``"metadata"`` (default ``{}``) keys. Any other object is stringified into
    the prompt and gets empty metadata. A falsy ``samples`` argument yields an
    empty dataset.
    """

    def __init__(self, samples: List[Any]):
        self.samples = samples if samples else []

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        item = self.samples[idx]
        if not isinstance(item, dict):
            return {"prompt": str(item), "metadata": {}}
        return {"prompt": item.get("prompt", ""), "metadata": item.get("metadata", {})}

    @staticmethod
    def collate_fn(examples: List[Dict[str, Any]]):
        """Split a list of items into parallel ``(prompts, metadatas)`` lists."""
        prompts: List[Any] = []
        metadatas: List[Any] = []
        for example in examples:
            prompts.append(example["prompt"])
            metadatas.append(example["metadata"])
        return prompts, metadatas


def _broadcast_object(obj: Any, rank: int, src: int = 0) -> Any:
    if not dist.is_available() or not dist.is_initialized():
        return obj
    payload: List[Any] = [obj if rank == src else None]
    dist.broadcast_object_list(payload, src=src)
    return payload[0]


def _fetch_signal_payload(base_url: str, key: str, timeout: float, rank: int) -> Dict[str, Any]:
    """Wait on the signal service for ``key`` (rank 0 only) and share the result with every rank.

    Rank 0 issues ``GET {base_url}/wait/{key}?timeout=<timeout>``, with the HTTP
    timeout padded by 5 seconds. It takes the response JSON's ``"data"`` field
    (``{}`` if missing or falsy). Any error is logged and yields ``{}``. Other
    ranks do no I/O and receive rank 0's payload through ``_broadcast_object``.
    """
    if rank != 0:
        return _broadcast_object({}, rank)

    try:
        print(f"Fetching signal payload for key: {key} from {base_url}")
        response = requests.get(
            f"{base_url}/wait/{key}",
            params={"timeout": timeout},
            timeout=timeout + 5,
        )
        response.raise_for_status()
        payload: Dict[str, Any] = response.json().get("data", {}) or {}
    except Exception as exc:  # noqa: BLE001 - best effort: never abort training here
        logger.error("Failed to fetch signal payload for key %s: %s", key, exc)
        payload = {}

    print(f"Get payload for key: {key}")
    return _broadcast_object(payload, rank)


def _submit_signal_payload(base_url: str, key: str, payload: Dict[str, Any], rank: int, timeout: float) -> None:
    if rank != 0:
        return
    try:
        print(f"Submitting signal payload for key: {key} to {base_url}")
        requests.post(
            f"{base_url}/submit/{key}",
            json=payload,
            timeout=timeout,
        )
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to submit payload for key %s: %s", key, exc)


def _move_pipeline_to(device: torch.device, pipeline, inference_dtype: torch.dtype, optimizer=None) -> None:
    """Move the pipeline's VAE, text encoder and FSDP transformer to ``device``.

    The transformer goes through ``load_fsdp_model_to_gpu`` for CUDA devices and
    through ``offload_fsdp_model_to_cpu`` otherwise. After an offload the CUDA
    cache is emptied. ``inference_dtype`` and ``optimizer`` are accepted but not
    used here (optimizer-state movement is currently disabled).
    """
    to_gpu = device.type == "cuda"

    pipeline.vae = pipeline.vae.to(device)
    pipeline.text_encoder = pipeline.text_encoder.to(device)
    if to_gpu:
        load_fsdp_model_to_gpu(pipeline.transformer, device)
    else:
        offload_fsdp_model_to_cpu(pipeline.transformer)

    first_param = next(pipeline.transformer.parameters())
    print(f"Move pipeline to {device}, with dtype={first_param.dtype}")

    if not to_gpu:
        # Release cached CUDA blocks that the offload just freed.
        torch.cuda.empty_cache()


def _load_text_encoder_weights(pipeline, text_encoder_path: str, inference_dtype: torch.dtype) -> None:
    if not text_encoder_path:
        return
    new_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        text_encoder_path,
        torch_dtype=inference_dtype,
    )
    old_encoder = pipeline.text_encoder
    pipeline.text_encoder = new_encoder
    del old_encoder
    pipeline.text_encoder.requires_grad_(False)

class DistributedKRepeatSampler(Sampler):
    """Infinite batch sampler that repeats each drawn sample ``k`` times across all replicas.

    Each iteration draws ``m = num_replicas * batch_size // k`` distinct dataset
    indices from a permutation seeded with ``seed + epoch``. It repeats each
    index ``k`` times, shuffles the result with the same generator, and yields
    this replica's contiguous ``batch_size`` slice. Because every replica uses
    the same seed, they agree on the global layout.

    NOTE(review): if ``len(dataset) < m`` fewer indices are drawn and the
    yielded batches come out shorter than ``batch_size``.
    """

    def __init__(self, dataset, batch_size, k, num_replicas, rank, seed=0):
        self.dataset = dataset
        self.batch_size = batch_size  # Batch size per replica
        self.k = k                    # Number of repetitions per sample
        self.num_replicas = num_replicas  # Total number of replicas
        self.rank = rank              # Current replica rank
        self.seed = seed              # Shared base seed keeps replicas in lock-step

        self.total_samples = self.num_replicas * self.batch_size
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if self.total_samples % self.k != 0:
            raise ValueError(
                f"k can not divide n*b, k{k}-num_replicas{num_replicas}-batch_size{batch_size}"
            )
        self.m = self.total_samples // self.k  # Number of unique samples per iteration
        self.epoch = 0

    def __iter__(self):
        while True:
            # Same seed on every replica -> identical global permutation everywhere.
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)

            # Draw m unique samples and repeat each k times (n*b indices in total).
            indices = torch.randperm(len(self.dataset), generator=g)[:self.m].tolist()
            repeated_indices = [idx for idx in indices for _ in range(self.k)]

            # Shuffle so the k copies spread across replicas.
            shuffled_order = torch.randperm(len(repeated_indices), generator=g).tolist()
            shuffled_samples = [repeated_indices[i] for i in shuffled_order]

            # Only this replica's slice is needed; no need to build every replica's list.
            start = self.rank * self.batch_size
            yield shuffled_samples[start:start + self.batch_size]

    def set_epoch(self, epoch):
        """Set the epoch mixed into the seed (keeps all replicas synchronized)."""
        self.epoch = epoch


def compute_text_embeddings(prompt, text_encoders, tokenizers, max_sequence_length, device):
    """Encode ``prompt`` without gradients and move the embeddings to ``device``.

    NOTE(review): relies on an ``encode_prompt`` function that is neither defined
    nor imported in the visible part of this file. Confirm it exists before
    calling this helper.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)``. The third value returned by
        ``encode_prompt`` (text ids) is discarded.
    """
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds, _text_ids = encode_prompt(
            text_encoders, tokenizers, prompt, max_sequence_length
        )
        # The text ids are not returned, so no device transfer is spent on them.
        prompt_embeds = prompt_embeds.to(device)
        pooled_prompt_embeds = pooled_prompt_embeds.to(device)
    return prompt_embeds, pooled_prompt_embeds

def calculate_zero_std_ratio(prompts, gathered_rewards):
    """
    Calculate the proportion of unique prompts whose rewards are all identical.

    Args:
        prompts: List of prompts (one per reward, repeats mark the same group).
        gathered_rewards: Dictionary containing rewards, must include the key
            'ori_avg' (array-like aligned with ``prompts``).

    Returns:
        zero_std_ratio: Proportion of unique prompts whose rewards have zero spread.
        mean_std: Mean of the per-prompt reward standard deviations.
        Both are 0.0 when ``prompts`` is empty.
    """
    if len(prompts) == 0:
        return 0.0, 0.0

    rewards = np.asarray(gathered_rewards['ori_avg'])

    # Group ids per prompt, plus the size of each group.
    _, inverse_indices, counts = np.unique(
        np.array(prompts),
        return_inverse=True,
        return_counts=True,
    )
    # Sort rewards by group id so each group is contiguous, then split at group boundaries.
    order = np.argsort(inverse_indices.reshape(-1), kind="stable")
    reward_groups = np.split(rewards[order], np.cumsum(counts)[:-1])

    prompt_std_devs = np.array([np.std(group) for group in reward_groups])

    # Test "all rewards equal" exactly via max - min. ``np.std`` of identical
    # non-representable values (e.g. three 0.1s) can be ~1e-17 rather than 0,
    # which would undercount zero-variance groups.
    zero_std_count = sum(1 for group in reward_groups if np.ptp(group) == 0)
    zero_std_ratio = zero_std_count / len(reward_groups)

    return zero_std_ratio, prompt_std_devs.mean()

def create_generator(prompts, base_seed):
    """Build one ``torch.Generator`` per prompt, seeded deterministically from the prompt text.

    Seed = ``(base_seed + big-endian int of the first 4 bytes of SHA-256(prompt)) % 2**31``.
    A stable hash is used, so identical prompts get identical seeds in every process.
    """
    def _seed_for(text):
        digest = hashlib.sha256(text.encode()).digest()
        return (base_seed + int.from_bytes(digest[:4], 'big')) % (2**31)

    return [torch.Generator().manual_seed(_seed_for(text)) for text in prompts]


def _normalize_signal_samples(samples: List[Any]) -> List[Dict[str, Any]]:
    normalized: List[Dict[str, Any]] = []
    for sample in samples:
        if isinstance(sample, dict):
            prompt = sample.get("prompt", "")
            metadata = copy.deepcopy(sample.get("metadata", {}))
        else:
            prompt = str(sample)
            metadata = {}
        normalized.append({"prompt": prompt, "metadata": metadata})
    return normalized


def _ensure_rollout_id(metadata: Dict[str, Any], prompt: str, fallback_idx: int) -> str:
    rollout_id = metadata.get("rollout_id")
    if rollout_id:
        return rollout_id
    fallback_source = metadata.get("improved_prompt") or metadata.get("original_prompt") or prompt
    digest = hashlib.sha1(f"{fallback_source}_{fallback_idx}".encode("utf-8")).hexdigest()[:10]
    metadata["rollout_id"] = digest
    return digest


def _instruction_group_key(metadata: Dict[str, Any], fallback_idx: int) -> Tuple[str, str]:
    data_source = metadata.get("data_source") or metadata.get("source") or metadata.get("category") or "unknown"
    sample_idx = metadata.get("sample_idx")
    if sample_idx is None:
        sample_idx = metadata.get("idx", fallback_idx)
        metadata.setdefault("sample_idx", sample_idx)
    return str(data_source), str(sample_idx)


def _as_score_list(value: Any) -> Any:
    """Normalize a per-batch ``"avg"`` reward entry into a plain list of scores.

    Tensors and ndarrays go through ``tolist()``. They must not hit ``or``,
    because the truth value of a multi-element array is ambiguous and raises.
    Other values keep the "falsy -> no scores" behaviour.
    """
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().tolist()
    if isinstance(value, np.ndarray):
        return value.tolist()
    return value or []


def perform_prompt_selection(
    samples: List[Dict[str, Any]],
    pipeline,
    config,
    device,
    rank: int,
    local_rank: int,
    world_size: int,
    selection_reward_fn,
    autocast,
    global_step: int,
    is_distributed: bool,
    executor,
    images_save_dir: str,
):
    """Run a cheap sampling pass and keep the best-scoring rollout per instruction.

    Each sample is rendered ``max(1, config.sample.num_image_per_prompt // 4)``
    times. Images are scored asynchronously on ``executor`` with
    ``selection_reward_fn(images, prompts, metadata, only_strict=True)``. The
    per-rollout mean of the ``"avg"`` score is aggregated across ranks. Within
    each ``(data_source, sample_idx)`` instruction group the rollout with the
    highest mean wins: ties go to the earliest rollout, and unscored rollouts
    rank lowest.

    Side effects:
        - assigns ``rollout_id`` / ``sample_idx`` into each input sample's metadata;
        - saves the selection images via ``save_images_and_metadata`` under
          ``images_save_dir`` (split ``'selection'``);
        - switches ``pipeline.transformer`` to eval mode while sampling and
          restores training mode afterwards, also if sampling or scoring raises.

    Returns:
        ``(selected_samples, selection_reward_entries)``.
        ``selected_samples`` holds deep copies of the chosen samples, or of all
        samples if nothing could be chosen or no reward fn was given.
        ``selection_reward_entries`` has one ``{"rollout_id", "score", "metadata"}``
        dict per scored rollout on rank 0, and is an empty list on other ranks.
    """
    if not samples or selection_reward_fn is None:
        fallback_samples = [copy.deepcopy(sample) for sample in samples or []]
        return fallback_samples, []

    selection_images_per_prompt = max(1, config.sample.num_image_per_prompt // 4)
    candidate_samples: Dict[str, Dict[str, Any]] = {}
    group_to_rollouts: Dict[Tuple[str, str], List[str]] = defaultdict(list)
    instruction_order: List[Tuple[str, str]] = []
    candidate_order: Dict[str, int] = {}

    # Assign rollout ids / group keys. They are written into each sample's own
    # metadata, so the deep copies below carry them. Also index candidates per
    # instruction, keeping first-seen order.
    for idx, sample in enumerate(samples):
        metadata = sample.setdefault("metadata", {})
        rollout_id = _ensure_rollout_id(metadata, sample.get("prompt", ""), idx)
        group_key = _instruction_group_key(metadata, idx)
        if group_key not in group_to_rollouts:
            instruction_order.append(group_key)
        group_to_rollouts[group_key].append(rollout_id)
        candidate_samples.setdefault(rollout_id, sample)
        candidate_order.setdefault(rollout_id, len(candidate_order))

    selection_sample_list: List[Dict[str, Any]] = [
        copy.deepcopy(sample)
        for sample in samples
        for _ in range(selection_images_per_prompt)
    ]
    if not selection_sample_list:
        fallback_samples = [copy.deepcopy(sample) for sample in samples]
        return fallback_samples, []

    if rank == 0:
        logger.info(
            "Step %s: Selection stage evaluating %s prompts (%s instructions) with %s images each",
            global_step,
            len(samples),
            len(instruction_order),
            selection_images_per_prompt,
        )

    selection_stats_local: Dict[str, Dict[str, float]] = defaultdict(lambda: {"sum": 0.0, "count": 0.0})

    transformer_training = pipeline.transformer.training
    pipeline.transformer.eval()
    try:
        selection_dataset = InMemoryPromptDataset(selection_sample_list)
        if is_distributed:
            # NOTE: with drop_last=False DistributedSampler pads with repeated
            # indices, so a few rollouts may be scored slightly more often.
            selection_sampler: Sampler = DistributedSampler(
                selection_dataset,
                num_replicas=world_size,
                rank=rank,
                shuffle=True,
                drop_last=False,
            )
            selection_sampler.set_epoch(global_step)
        else:
            selection_sampler = None

        selection_dataloader = DataLoader(
            selection_dataset,
            batch_size=config.sample.train_batch_size,
            sampler=selection_sampler,
            shuffle=selection_sampler is None,
            collate_fn=InMemoryPromptDataset.collate_fn,
            num_workers=1,
            drop_last=False,
        )

        selection_futures: List[futures.Future] = []
        selection_batches: List[Tuple[torch.Tensor, List[str], List[Dict[str, Any]]]] = []

        # Phase 1: generate images; reward scoring runs concurrently on the executor.
        with torch.no_grad():
            for i, (prompts, prompt_metadata) in enumerate(
                tqdm(
                    selection_dataloader,
                    desc="Selection: sampling",
                    disable=local_rank != 0,
                    position=0,
                )
            ):
                if config.sample.same_latent:
                    generator = create_generator(prompts, base_seed=global_step * 10000 + 100000 + i)
                else:
                    generator = None

                with autocast():
                    collected_data = pipeline_with_logprob(
                        pipeline,
                        prompts,
                        negative_prompt=[" "] * len(prompts),
                        num_inference_steps=config.sample.num_steps,
                        true_cfg_scale=config.sample.guidance_scale,
                        output_type="pt",
                        height=config.resolution,
                        width=config.resolution,
                        noise_level=config.sample.noise_level,
                        generator=generator,
                        sde_window_size=config.sample.sde_window_size,
                        sde_window_range=config.sample.sde_window_range,
                    )

                images_cpu = collected_data["images"].detach().cpu()
                future = executor.submit(
                    selection_reward_fn,
                    images_cpu,
                    prompts,
                    prompt_metadata,
                    only_strict=True,
                )
                # Presumably yields the GIL so the executor thread can start on the job.
                time.sleep(0)
                selection_futures.append(future)
                selection_batches.append((images_cpu, prompts, prompt_metadata))

        selection_all_images: List[torch.Tensor] = []
        selection_all_prompts: List[str] = []
        selection_all_metadatas: List[Dict[str, Any]] = []
        selection_rewards_local: Dict[str, List[np.ndarray]] = defaultdict(list)

        # Phase 2: collect rewards batch by batch (progress bar only on local rank 0).
        # Batches whose reward computation failed are skipped entirely.
        for (images_cpu, prompts, prompt_metadata), reward_future in tqdm(
            zip(selection_batches, selection_futures),
            desc="Selection: scoring",
            disable=local_rank != 0,
            position=0,
        ):
            try:
                score_dict, _ = reward_future.result()
            except Exception as exc:  # noqa: BLE001
                if rank == 0:
                    logger.error("Selection reward computation failed: %s", exc)
                continue

            selection_all_images.append(images_cpu)
            selection_all_prompts.extend(prompts)
            selection_all_metadatas.extend(prompt_metadata)

            for key, value in score_dict.items():
                if isinstance(value, torch.Tensor):
                    arr = value.detach().cpu().numpy()
                else:
                    arr = np.asarray(value)
                selection_rewards_local[key].append(arr)

            batch_scores = _as_score_list(score_dict.get("avg"))
            for sample_idx, metadata in enumerate(prompt_metadata):
                rollout_id = metadata.get("rollout_id")
                if rollout_id is None or sample_idx >= len(batch_scores):
                    continue
                stats = selection_stats_local[rollout_id]
                stats["sum"] += float(batch_scores[sample_idx])
                stats["count"] += 1.0

        if selection_all_images:
            selection_images_tensor = torch.cat(selection_all_images, dim=0)
            selection_rewards_np = {key: np.concatenate(value) for key, value in selection_rewards_local.items() if value}
            save_images_and_metadata(
                selection_images_tensor,
                selection_all_prompts,
                selection_all_metadatas,
                selection_rewards_np,
                images_save_dir,
                global_step,
                split='selection',
                rank=rank,
                world_size=world_size,
            )
    finally:
        # Restore the caller's mode even if sampling or scoring raised.
        if transformer_training:
            pipeline.transformer.train()

    # Phase 3: aggregate per-rollout (sum, count) across ranks into mean scores.
    selection_stats_payload = [
        {"rollout_id": rollout_id, "sum": stats["sum"], "count": stats["count"]}
        for rollout_id, stats in selection_stats_local.items()
    ]
    aggregated_selection_stats = _gather_all_objects(selection_stats_payload, rank, world_size)

    aggregated_map: Dict[str, Dict[str, float]] = defaultdict(lambda: {"sum": 0.0, "count": 0.0})
    for item in aggregated_selection_stats:
        rollout_id = item.get("rollout_id")
        if not rollout_id:
            continue
        aggregated_entry = aggregated_map[rollout_id]
        aggregated_entry["sum"] += float(item.get("sum", 0.0))
        aggregated_entry["count"] += float(item.get("count", 0.0))

    prompt_mean_scores: Dict[str, float] = {}
    for rollout_id, stats in aggregated_map.items():
        count = stats.get("count", 0.0)
        if count <= 0:
            continue
        prompt_mean_scores[rollout_id] = stats["sum"] / count

    # Record the first-stage score of every prompt (rank 0 only) for the upstream reward service.
    selection_reward_entries: List[Dict[str, Any]] = []
    if rank == 0:
        for rollout_id, score in prompt_mean_scores.items():
            sample = candidate_samples.get(rollout_id)
            if sample is None:
                continue
            meta = copy.deepcopy(sample.get("metadata", {}))
            meta.pop("images", None)
            meta["selection_images_per_prompt"] = selection_images_per_prompt
            selection_reward_entries.append(
                {
                    "rollout_id": rollout_id,
                    "score": float(score),
                    "metadata": meta,
                }
            )

    # Unscored candidates rank below every scored one.
    candidate_mean_scores: Dict[str, float] = {
        rollout_id: prompt_mean_scores.get(rollout_id, float("-inf"))
        for rollout_id in candidate_samples
    }

    # Per instruction: highest mean score wins; ties go to the earliest rollout.
    selected_rollout_ids: List[str] = []
    for group_key in instruction_order:
        candidate_ids = group_to_rollouts.get(group_key, [])
        if not candidate_ids:
            continue
        best_rollout = max(
            candidate_ids,
            key=lambda rid: (
                candidate_mean_scores.get(rid, float("-inf")),
                -candidate_order.get(rid, 0),
            ),
        )
        selected_rollout_ids.append(best_rollout)

    selected_samples = [copy.deepcopy(candidate_samples[rid]) for rid in selected_rollout_ids if rid in candidate_samples]
    if not selected_samples:
        selected_samples = [copy.deepcopy(sample) for sample in samples]
        if rank == 0:
            logger.warning("Selection stage failed to pick candidates; falling back to all prompts.")
    elif rank == 0:
        logger.info(
            "Step %s: Selected %s prompts for unified reward stage", global_step, len(selected_samples)
        )

    return selected_samples, selection_reward_entries


def _to_uint8_image(img: np.ndarray):
    """Convert a float image in [0, 1] (CHW, HW or HWC) into a uint8 array PIL accepts; ``None`` if unsupported."""
    if img.ndim == 3 and img.shape[0] in (1, 3, 4):
        arr = img.transpose(1, 2, 0)  # CHW -> HWC
    elif img.ndim in (2, 3):
        arr = img  # HW, or assume channels are already last
    else:
        return None
    arr = (np.clip(arr, 0.0, 1.0) * 255).astype(np.uint8)
    if arr.ndim == 3 and arr.shape[-1] == 1:
        # PIL cannot build an image from an (H, W, 1) array; drop the channel axis.
        arr = arr[..., 0]
    return arr


def save_images_and_metadata(images, prompts, metadatas, rewards, save_dir, global_step, split='train', rank=0, world_size=1):
    """
    Save images and metadata with structure:
      - metadata: {save_dir}/{split}_{global_step}/{data_source}_{sample_idx}/{rollout_id}.json
      - images: {save_dir}/{split}_{global_step}/{data_source}_{sample_idx}/{rollout_id}_images/
    Multiple images can belong to the same rollout_id, and multiple rollout_ids can belong to the same {data_source}_{sample_idx}.

    Args:
        images: Tensor or array of images with values in [0, 1], one per entry in ``prompts``.
        prompts / metadatas: Per-image prompt strings and metadata dicts (may be shorter than ``images``).
        rewards: Mapping of reward name -> per-image values aligned with ``images``.
        save_dir, global_step, split: Select the output directory ``{save_dir}/{split}_{global_step}``.
        rank: Embedded in image file names and the JSON, so ranks don't clobber each other's images.
        world_size: Currently unused.

    Images without a ``rollout_id`` each get a random 8-character id (their own group).
    Images with an unsupported number of dimensions are skipped.

    NOTE(review): each rank writes ``{rollout_id}.json`` for the groups it holds.
    If the same rollout_id lands on several ranks, the last writer's JSON wins;
    the image files survive because they are rank-suffixed.
    """
    import uuid

    step_dir = os.path.join(save_dir, f"{split}_{global_step}")
    os.makedirs(step_dir, exist_ok=True)

    # Normalize tensors/arrays
    if isinstance(images, torch.Tensor):
        images_np = images.detach().cpu().float().numpy()
    else:
        images_np = np.asarray(images)

    rewards_np = {k: np.asarray(v) for k, v in rewards.items()}

    # Group image indices by (data_source, sample_idx, rollout_id).
    groups: Dict[Tuple[str, str, str], List[int]] = defaultdict(list)
    for i in range(len(images_np)):
        meta = metadatas[i] if i < len(metadatas) else {}
        if isinstance(meta, dict):
            rollout_id = meta.get("rollout_id")
            data_source = meta.get("data_source")
            sample_idx = meta.get("sample_idx")
        else:
            rollout_id = None
            data_source = "unknown"
            sample_idx = i

        if not rollout_id:
            rollout_id = str(uuid.uuid4())[:8]

        groups[(str(data_source), str(sample_idx), str(rollout_id))].append(i)

    for (data_source, sample_idx, rollout_id), indices in groups.items():
        # Folder: {data_source}_{sample_idx}, with path-unsafe characters replaced.
        safe_data_source = data_source.replace("/", "_").replace(" ", "_")
        sample_dir = os.path.join(step_dir, f"{safe_data_source}_{sample_idx}")
        images_dir = os.path.join(sample_dir, f"{rollout_id}_images")
        os.makedirs(images_dir, exist_ok=True)

        image_paths = []
        all_rewards_for_group = defaultdict(list)
        # The first image of the group supplies the prompt and metadata for the JSON.
        rep_idx = indices[0]
        rep_meta = metadatas[rep_idx] if rep_idx < len(metadatas) else {}
        rep_prompt = prompts[rep_idx] if rep_idx < len(prompts) else ""

        for seq, i in enumerate(indices):
            img_arr = _to_uint8_image(images_np[i])
            if img_arr is None:
                continue

            image_filename = f"image_{seq}_rank{rank}.jpg"
            Image.fromarray(img_arr).save(os.path.join(images_dir, image_filename))
            image_paths.append(f"{rollout_id}_images/{image_filename}")

            for k in rewards_np.keys():
                all_rewards_for_group[k].append(float(rewards_np[k][i]))

        metadata_obj = {
            "prompt": rep_prompt,
            "rollout_id": rollout_id,
            "data_source": data_source,
            "sample_idx": sample_idx,
            "rank": int(rank),
            "global_step": int(global_step),
            "split": split,
            "num_images": len(image_paths),
            "image_paths": image_paths,
            "rewards": dict(all_rewards_for_group),
            "original_metadata": rep_meta,
        }

        json_path = os.path.join(sample_dir, f"{rollout_id}.json")
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(metadata_obj, f, indent=2, ensure_ascii=False)

def compute_log_prob(transformer, pipeline, sample, j, config, rank):
    """Recompute the log-probability of the stored transition at timestep index ``j``.

    Runs ``transformer`` once on a doubled batch (prompt + negative prompt).
    The two predictions are combined as ``neg + guidance_scale * (pos - neg)``,
    and the result is rescaled so its norm over the last dimension matches the
    conditional prediction's. ``sample["next_latents"][:, j]`` is then scored
    with ``sde_step_with_logprob``.

    Side effect: ``prompt_embeds``, ``negative_prompt_embeds`` and their masks in
    ``sample`` are truncated in place to the longest real text length.
    ``rank`` is unused.

    Returns:
        ``(prev_sample, log_prob, prev_sample_mean, std_dev_t)`` from ``sde_step_with_logprob``.
    """
    latents_j = sample["latents"][:, j]
    timesteps_j = sample["timesteps"][:, j]
    grid = config.resolution // pipeline.vae_scale_factor // 2
    img_shapes = [[(1, grid, grid)]] * len(latents_j)

    txt_seq_lens = sample["prompt_embeds_mask"].sum(dim=1).tolist()
    negative_txt_seq_lens = sample["negative_prompt_embeds_mask"].sum(dim=1).tolist()
    all_seq_lens = txt_seq_lens + negative_txt_seq_lens

    # The max over all_seq_lens is the longest real text; the stored embeddings
    # and masks may carry unnecessary padding beyond it, so trim them.
    max_len = max(all_seq_lens)
    for key in ("prompt_embeds_mask", "negative_prompt_embeds_mask", "prompt_embeds", "negative_prompt_embeds"):
        sample[key] = sample[key][:, :max_len]

    stacked_pred = transformer(
        hidden_states=torch.cat([latents_j, latents_j], dim=0),
        timestep=torch.cat([timesteps_j, timesteps_j], dim=0) / 1000,
        guidance=None,
        encoder_hidden_states_mask=torch.cat([sample["prompt_embeds_mask"], sample["negative_prompt_embeds_mask"]], dim=0),
        encoder_hidden_states=torch.cat([sample["prompt_embeds"], sample["negative_prompt_embeds"]], dim=0),
        img_shapes=img_shapes * 2,
        txt_seq_lens=all_seq_lens,
    )[0]
    cond_pred, uncond_pred = stacked_pred.chunk(2, dim=0)
    guided_pred = uncond_pred + config.sample.guidance_scale * (cond_pred - uncond_pred)

    # Rescale the guided prediction to the conditional prediction's norm (over the last dim).
    norm_ratio = torch.norm(cond_pred, dim=-1, keepdim=True) / torch.norm(guided_pred, dim=-1, keepdim=True)
    rescaled_pred = guided_pred * norm_ratio

    # Log prob of next_latents given latents under the current model.
    prev_sample, log_prob, prev_sample_mean, std_dev_t = sde_step_with_logprob(
        pipeline.scheduler,
        rescaled_pred.float(),
        timesteps_j,
        latents_j.float(),
        prev_sample=sample["next_latents"][:, j].float(),
        noise_level=config.sample.noise_level,
    )
    return prev_sample, log_prob, prev_sample_mean, std_dev_t

def _aggregate_scores_by_rollout(metadatas, scores):
    """Group per-image scores by rollout id and average them.

    Args:
        metadatas: Per-image metadata, aligned index-by-index with ``scores``.
            Only dict entries carrying a non-``None`` ``rollout_id`` (top-level,
            or inside a nested ``"metadata"`` dict) are kept; all others are
            skipped because they cannot be attributed to a rollout.
        scores: Per-image scalar scores.

    Returns:
        One ``{"rollout_id", "score", "metadata"}`` dict per rollout id, in
        first-seen order. ``score`` is the mean over that rollout's images and
        ``metadata`` is a deep copy of the first metadata seen for it, without
        the ``"images"`` key. The input metadata is never mutated.
    """
    rollout_to_scores: Dict[Any, List[float]] = defaultdict(list)
    rollout_to_meta: Dict[Any, Dict[str, Any]] = {}
    for meta, score in zip(metadatas, scores):
        if not isinstance(meta, dict):
            continue
        rollout_id = meta.get("rollout_id")
        if rollout_id is None:
            nested = meta.get("metadata")
            # The nested value may be absent, None or a non-dict; only a dict can hold an id.
            rollout_id = nested.get("rollout_id") if isinstance(nested, dict) else None
        if rollout_id is None:
            continue
        rollout_to_scores[rollout_id].append(float(score))
        if rollout_id not in rollout_to_meta:
            # Filter "images" out before copying so (potentially large) image data is never deep-copied.
            rollout_to_meta[rollout_id] = copy.deepcopy(
                {key: value for key, value in meta.items() if key != "images"}
            )
    return [
        {
            "rollout_id": rollout_id,
            "score": float(np.mean(values)),  # mean over this rollout's images
            "metadata": rollout_to_meta[rollout_id],
        }
        for rollout_id, values in rollout_to_scores.items()
    ]


def eval(pipeline, test_dataloader, config, rank, local_rank, world_size, device, global_step, reward_fn, executor, autocast, ema, transformer_trainable_parameters, save_dir):
    """Run one evaluation pass and aggregate its rewards.

    For every ``(prompts, prompt_metadata)`` batch from ``test_dataloader`` this
    generates ``config.sample.eval_num_image_per_prompt`` images per prompt
    (default 1) with ``noise_level=0``, scores them with
    ``reward_fn(images, prompts, metadatas, only_strict=False)`` and
    accumulates the results. Afterwards it:

    * saves this rank's images/prompts/metadata/rewards under ``save_dir``
      via ``save_images_and_metadata``;
    * gathers the last batch across ranks and, on rank 0, logs up to 15
      captioned images plus per-key mean rewards to wandb;
    * averages the ``config.train.submit_metric`` scores (default ``"avg"``,
      falling back to the first reward key) per rollout id.

    When ``config.train.ema`` is set, the EMA weights are swapped in for the
    pass and the original weights are restored afterwards, even if
    evaluation raises.

    ``gather_tensor`` is a collective call: this assumes every rank iterates
    the same number of equally sized batches (as with the
    ``DistributedSampler(drop_last=False)`` used by ``validate``).

    Returns:
        ``(eval_reward_entries, reward_summary)``: a list of
        ``{"rollout_id", "score", "metadata"}`` dicts for this rank, and a dict
        mapping each reward key to the mean of this rank's rewards.
    """
    if config.train.ema:
        ema.copy_ema_to(transformer_trainable_parameters, store_temp=True)
    try:
        # Rewards gathered across ranks (for wandb) vs. this rank's own rewards
        # (for local saving and the returned summary; no collectives needed).
        all_rewards_gather = defaultdict(list)
        all_rewards_local = defaultdict(list)
        all_images = []
        all_prompts = []
        all_metadatas = []

        # Number of images generated per eval prompt (at least 1).
        eval_num_images = int(max(1, getattr(config.sample, "eval_num_image_per_prompt", 1)))

        # Most recent batch's expanded prompts, images and rewards (reused for wandb logging).
        batch_prompts = None
        images = None
        rewards = None

        for prompts, prompt_metadata in tqdm(
            test_dataloader,
            desc="Eval: ",
            disable=local_rank != 0,
            position=0,
        ):
            # Repeat every prompt (and its metadata) eval_num_images times in a row.
            if eval_num_images > 1:
                batch_prompts = [p for p in prompts for _ in range(eval_num_images)]
                batch_metadatas = [m for m in prompt_metadata for _ in range(eval_num_images)]
            else:
                batch_prompts = prompts
                batch_metadatas = prompt_metadata

            with autocast(), torch.no_grad():
                collected_data = pipeline_with_logprob(
                    pipeline,
                    batch_prompts,
                    negative_prompt=[" "] * len(batch_prompts),
                    num_inference_steps=config.sample.eval_num_steps,
                    true_cfg_scale=config.sample.guidance_scale,
                    output_type="pt",
                    height=config.resolution,
                    width=config.resolution,
                    noise_level=0,
                    sde_window_size=0,
                )
            images = collected_data["images"]
            # Score with the expanded prompts/metadata so they line up with the images.
            reward_future = executor.submit(reward_fn, images, batch_prompts, batch_metadatas, only_strict=False)
            time.sleep(0)
            rewards, _ = reward_future.result()

            all_images.append(images)
            all_prompts.extend(batch_prompts)
            all_metadatas.extend(batch_metadatas)
            for key, value in rewards.items():
                all_rewards_local[key].append(np.asarray(value))
                all_rewards_gather[key].append(
                    gather_tensor(torch.as_tensor(value, device=device), world_size).cpu().float().numpy()
                )

        # Save this rank's results locally (grouped by prompt, one json per prompt).
        all_images_local = torch.cat(all_images, dim=0) if all_images else torch.empty(0)
        all_rewards_local = {key: np.concatenate(value) for key, value in all_rewards_local.items()}
        save_images_and_metadata(
            all_images_local,
            all_prompts,
            all_metadatas,
            all_rewards_local,
            save_dir,
            global_step,
            split='eval',
            rank=rank,
            world_size=world_size,
        )

        # With no batches there is nothing to log. All ranks see the same batch count,
        # so they all skip the collectives below together.
        if images is not None:
            last_batch_images_gather = gather_tensor(
                torch.as_tensor(images, device=device), world_size
            ).cpu().float().numpy()
            # Captions are built from the expanded prompts so their count matches the images.
            last_batch_prompt_ids = pipeline.tokenizer(
                batch_prompts,
                padding="max_length",
                max_length=256,
                truncation=True,
                return_tensors="pt",
            ).input_ids.to(device)
            # Token ids stay integral: decoding expects integer ids, not floats.
            last_batch_prompt_ids_gather = gather_tensor(last_batch_prompt_ids, world_size).cpu().numpy()
            last_batch_prompts_gather = pipeline.tokenizer.batch_decode(
                last_batch_prompt_ids_gather, skip_special_tokens=True
            )
            # The last batch's rewards were already all-gathered inside the loop; reuse them.
            last_batch_rewards_gather = {key: all_rewards_gather[key][-1] for key in rewards}

            if rank == 0:
                all_rewards = {key: np.concatenate(value) for key, value in all_rewards_gather.items()}
                with tempfile.TemporaryDirectory() as tmpdir:
                    num_samples = min(15, len(last_batch_images_gather))
                    for idx in range(num_samples):
                        image = last_batch_images_gather[idx]
                        pil = Image.fromarray(
                            (image.transpose(1, 2, 0) * 255).astype(np.uint8)
                        )
                        pil = pil.resize((config.resolution, config.resolution))
                        pil.save(os.path.join(tmpdir, f"{idx}.jpg"))
                    sampled_prompts = last_batch_prompts_gather[:num_samples]
                    sampled_rewards = [
                        {k: last_batch_rewards_gather[k][idx] for k in last_batch_rewards_gather}
                        for idx in range(num_samples)
                    ]
                    for key, value in all_rewards.items():
                        print(key, value.shape)
                    wandb.log(
                        {
                            "eval_images": [
                                wandb.Image(
                                    os.path.join(tmpdir, f"{idx}.jpg"),
                                    caption=f"{prompt:.1000} | " + " | ".join(f"{k}: {v:.2f}" for k, v in reward.items() if v != -10),
                                )
                                for idx, (prompt, reward) in enumerate(zip(sampled_prompts, sampled_rewards))
                            ],
                            # -10 appears to be a "not applicable" sentinel; it is excluded from the means.
                            **{f"eval_reward_{key}": np.mean(value[value != -10]) for key, value in all_rewards.items()},
                        },
                        step=global_step,
                    )

        # Per-rollout submission scores: prefer config.train.submit_metric, else the first reward key.
        submit_metric = getattr(config.train, "submit_metric", "avg")
        eval_reward_entries: List[Dict[str, Any]] = []
        if all_rewards_local:
            submit_scores = all_rewards_local.get(submit_metric)
            if submit_scores is None:
                submit_scores = next(iter(all_rewards_local.values()))
            eval_reward_entries = _aggregate_scores_by_rollout(
                all_metadatas, np.asarray(submit_scores).tolist()
            )

        # NOTE(review): unlike the wandb means above, this summary does not exclude -10 values.
        reward_summary = {key: float(np.mean(values)) for key, values in all_rewards_local.items()}
        return eval_reward_entries, reward_summary
    finally:
        if config.train.ema:
            ema.copy_temp_to(transformer_trainable_parameters)


def _merge_eval_entries_by_rollout(entries: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Collapse eval entries that share a ``rollout_id`` into a single entry.

    Duplicates arise when the same rollout is scored on more than one rank,
    e.g. for the samples ``DistributedSampler(drop_last=False)`` repeats to pad
    the dataset to a multiple of ``world_size``. The merged entry keeps the
    first entry's fields, with ``score`` replaced by the mean of all the
    duplicates' scores (a missing score counts as 0.0). Output order follows
    the first occurrence of each rollout id.
    """
    scores_by_rollout: Dict[Any, List[float]] = defaultdict(list)
    first_entry: Dict[Any, Dict[str, Any]] = {}
    for entry in entries:
        rollout_id = entry.get("rollout_id")
        scores_by_rollout[rollout_id].append(float(entry.get("score", 0.0)))
        first_entry.setdefault(rollout_id, entry)
    return [
        {**first_entry[rollout_id], "score": float(np.mean(values))}
        for rollout_id, values in scores_by_rollout.items()
    ]


def validate(
    global_step: int,
    signal_base_url: str,
    signal_timeout: float,
    signal_config: SignalConfig,  # the whole config object is passed in
    rank: int,
    world_size: int,
    local_rank: int,
    pipeline,
    config,
    device,
    eval_reward_fn,
    executor,
    autocast,
    ema,
    transformer_trainable_parameters,
    images_save_dir: str,
    current_text_encoder_path: str,
    is_distributed: bool,
):
    """Run validation for ``global_step``: fetch eval data, evaluate, submit rewards.

    1. Fetch the eval payload for ``global_step`` from the signal server and
       return early if it holds no samples.
    2. Move the pipeline to ``device``, shard the samples across ranks with a
       ``DistributedSampler`` (when distributed) and run :func:`eval`.
    3. Move the pipeline back to CPU, gather every rank's per-rollout entries
       (merging duplicates caused by sampler padding) and reward summaries,
       and on rank 0 submit the entries to the signal server.
    """
    eval_key = signal_config.get_key("eval", "data", global_step)
    eval_payload = _fetch_signal_payload(signal_base_url, eval_key, signal_timeout, rank)
    payload_eval_step = eval_payload.get("global_step")
    if payload_eval_step is not None and payload_eval_step != global_step and rank == 0:
        logger.warning(
            "Eval payload step mismatch: expected %s, received %s", global_step, payload_eval_step
        )

    eval_samples = eval_payload.get("eval_data") or eval_payload.get("data") or []
    if not eval_samples:
        if rank == 0:
            logger.info("Step %s received empty eval payload; skipping validation.", global_step)
        return

    # Move the model to the GPU only once there is data to evaluate.
    inference_dtype = torch.float32
    if config.mixed_precision == "fp16":
        inference_dtype = torch.float16
    elif config.mixed_precision == "bf16":
        inference_dtype = torch.bfloat16
    if rank == 0:
        logger.info(f"Step {global_step}: Moving pipeline to {device} for validation")
    _move_pipeline_to(device, pipeline, inference_dtype, None)
    if is_distributed:
        dist.barrier()

    test_dataset = InMemoryPromptDataset(eval_samples)
    if is_distributed:
        # drop_last=False pads with repeated samples so every rank gets the same
        # number of batches (required by the collectives inside eval).
        eval_sampler = DistributedSampler(
            test_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=False,
            drop_last=False,
        )
    else:
        eval_sampler = None

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=config.sample.test_batch_size,
        sampler=eval_sampler,
        shuffle=False,
        collate_fn=InMemoryPromptDataset.collate_fn,
        num_workers=1,
        drop_last=False,
    )

    if rank == 0:
        logger.info(f"Step {global_step}: Running evaluation on {len(test_dataloader)} batches")
    eval_entries_local, eval_reward_summary_local = eval(
        pipeline, test_dataloader, config, rank, local_rank, world_size, device,
        global_step, eval_reward_fn, executor, autocast, ema,
        transformer_trainable_parameters, images_save_dir
    )

    # Move the model back to CPU before submitting rewards.
    cpu_device = torch.device("cpu")
    if rank == 0:
        logger.info(f"Step {global_step}: Moving pipeline to CPU after evaluation")
    _move_pipeline_to(cpu_device, pipeline, inference_dtype, None)

    if is_distributed:
        dist.barrier()

    # Sampler padding can put the same rollout on several ranks; submit each rollout once.
    eval_entries_global = _merge_eval_entries_by_rollout(
        _gather_all_objects(eval_entries_local, rank, world_size)
    )
    summary_list = _gather_all_objects(
        [eval_reward_summary_local] if eval_reward_summary_local else [],
        rank,
        world_size,
    )

    # Per-key mean of the per-rank summaries, falling back to the mean entry score.
    eval_reward_mean: Dict[str, float] = {}
    if summary_list:
        merged_summary: Dict[str, List[float]] = defaultdict(list)
        for summary in summary_list:
            for key, value in summary.items():
                merged_summary[key].append(value)
        eval_reward_mean = {key: float(np.mean(values)) for key, values in merged_summary.items()}
    if not eval_reward_mean and eval_entries_global:
        eval_reward_mean = {
            "score": float(np.mean([entry.get("score", 0.0) for entry in eval_entries_global]))
        }

    if rank == 0 and eval_entries_global:
        reward_key = signal_config.get_key("eval", "reward", global_step)
        logger.info(f"Step {global_step}: Submitting {len(eval_entries_global)} eval rewards with key: {reward_key}")
        logger.info(f"Step {global_step}: Eval reward mean: {eval_reward_mean}")
        eval_submission_payload = {
            "global_step": int(global_step),
            "mode": "eval",
            "rewards": eval_entries_global,
            "text_encoder_path": current_text_encoder_path,
        }
        _submit_signal_payload(signal_base_url, reward_key, eval_submission_payload, rank, signal_timeout)


def get_transformer_layer_cls():
    """Return the set of layer classes handed to ``fsdp_wrapper`` in this script.

    Contains the QwenImage transformer block and the Qwen2.5-VL vision block
    and decoder layer. Presumably these are the units FSDP wraps individually
    (auto-wrap policy); confirm against ``flow_grpo.fsdp_utils.fsdp_wrapper``.
    The imports are deferred to call time.
    """
    from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformerBlock
    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
        Qwen2_5_VLDecoderLayer,
        Qwen2_5_VLVisionBlock,
    )

    wrapped_layer_types = (
        QwenImageTransformerBlock,
        Qwen2_5_VLVisionBlock,
        Qwen2_5_VLDecoderLayer,
    )
    return set(wrapped_layer_types)

def main(_):
    # basic Accelerate and logging setup
    config = FLAGS.config

    # Initialize distributed training
    is_distributed, rank, world_size, local_rank = init_distributed()
    device = torch.device(f'cuda:{local_rank}') if torch.cuda.is_available() else torch.device('cpu')
    
    # unique_id = datetime.datetime.now().strftime("%Y.%m.%d_%H.%M.%S")
    # if not config.run_name:
    #     config.run_name = unique_id
    # else:
    #     config.run_name += "_" + unique_id

    # number of timesteps within each trajectory to train on
    if config.sample.sde_window_size > 0:
        num_train_timesteps = config.sample.sde_window_size
    else:
        num_train_timesteps = config.sample.num_steps - 1

    # Create project directory
    project_dir = os.path.join(config.logdir, config.run_name)
    os.makedirs(project_dir, exist_ok=True)
    
    # Create save directory for images
    images_save_dir = os.path.join(project_dir, "saved_images")
    os.makedirs(images_save_dir, exist_ok=True)
    
    if rank == 0:
        wandb.init(
            project="flow_grpo",
            # mode="disabled"
        )
    logger.info(f"\n{config}")

    # set seed (device_specific is very important to get different prompts on different devices)
    set_seed(config.seed, device_specific=True)

    # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora transformer) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    inference_dtype = torch.float32
    if config.mixed_precision == "fp16":
        inference_dtype = torch.float16
    elif config.mixed_precision == "bf16":
        inference_dtype = torch.bfloat16

    # load scheduler, tokenizer and models.
    pipeline = DiffusionPipeline.from_pretrained(
        config.pretrained.model,
        torch_dtype=inference_dtype,
        low_cpu_mem_usage=True,  # 减少加载峰值 CPU 占用
    )
    
    # Switch Text Encoder
    text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        config.pretrained.te_model,
        torch_dtype=inference_dtype,
        low_cpu_mem_usage=True
    )
    old_te = pipeline.text_encoder
    pipeline.text_encoder = text_encoder
    del old_te
    
    # freeze parameters of models to save more memory
    pipeline.vae.requires_grad_(False)
    pipeline.text_encoder.requires_grad_(False)
    pipeline.transformer.requires_grad_(not config.use_lora)
    # disable safety checker
    pipeline.safety_checker = None
    # make the progress bar nicer
    pipeline.set_progress_bar_config(
        position=1,
        disable=local_rank != 0,
        leave=False,
        desc="Timestep",
        dynamic_ncols=True,
    )

    if config.use_lora:
        # Set correct lora layers
        target_modules = [
            "attn.to_k",
            "attn.to_q",
            "attn.to_v",
            "attn.to_out.0",
            "attn.add_k_proj",
            "attn.add_q_proj",
            "attn.add_v_proj",
            "attn.to_add_out",
            "img_mlp.net.0.proj",
            "img_mlp.net.2",
            "txt_mlp.net.0.proj",
            "txt_mlp.net.2",
        ]
        transformer_lora_config = LoraConfig(
            r=64,
            lora_alpha=128,
            init_lora_weights="gaussian",
            target_modules=target_modules,
        )
        if config.train.lora_path:
            pipeline.transformer = PeftModel.from_pretrained(pipeline.transformer, config.train.lora_path)
            # After loading with PeftModel.from_pretrained, all parameters have requires_grad set to False. You need to call set_adapter to enable gradients for the adapter parameters.
            pipeline.transformer.set_adapter("default")
        else:
            pipeline.transformer = get_peft_model(pipeline.transformer, transformer_lora_config)
    
    transformer = pipeline.transformer

    # Setup FSDP configuration
    fsdp_config = FSDPConfig(
        sharding_strategy="FULL_SHARD",
        backward_prefetch="BACKWARD_PRE",
        cpu_offload=False,  # Set to True if memory is limited
        num_replicate=1,
        num_shard=world_size,
        mixed_precision_dtype=inference_dtype,
        use_activation_checkpointing=config.activation_checkpointing,
        use_device_mesh=False, 
    )
    # Wrap language model with FSDP
    transformer.cpu().to(dtype=torch.float32)
    transformer = fsdp_wrapper(transformer, fsdp_config, get_transformer_layer_cls)
    pipeline.transformer = transformer

    if config.train.beta > 0:
        transformer_ref = QwenImageTransformer2DModel.from_pretrained(
            config.pretrained.model,
            subfolder="transformer",
            torch_dtype=inference_dtype,
        )
        transformer_ref.eval()
        transformer_ref.requires_grad_(False)
        transformer_ref.cpu().to(dtype=torch.float32)
        transformer_ref = fsdp_wrapper(transformer_ref, fsdp_config, get_transformer_layer_cls)
    
    # pipeline.text_encoder.cpu().to(dtype=torch.float32)
    # pipeline.text_encoder = fsdp_wrapper(pipeline.text_encoder, fsdp_config, get_transformer_layer_cls)

    transformer_trainable_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
    # This ema setting affects the previous 20 × 8 = 160 steps on average.
    # ema = EMAModuleWrapper(transformer_trainable_parameters, decay=0.9, update_step_interval=8, device=device)
    ema = None
    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if config.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # Initialize the optimizer
    if config.train.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        transformer_trainable_parameters,
        lr=config.train.learning_rate,
        betas=(config.train.adam_beta1, config.train.adam_beta2),
        weight_decay=config.train.adam_weight_decay,
        eps=config.train.adam_epsilon,
    )

    if config.fsdp_optimizer_offload:
        register_optimizer_offload_hooks(optimizer)
    
    train_dataset = None
    test_dataset = None
    train_dataloader = None
    test_dataloader = None

    if config.sample.num_image_per_prompt == 1:
        config.per_prompt_stat_tracking = False
    # initialize stat tracker
    if config.per_prompt_stat_tracking:
        stat_tracker = PerPromptStatTracker(config.sample.global_std)

    # for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
    # more memory
    if config.mixed_precision == "fp16":
        autocast = lambda: torch.amp.autocast('cuda', dtype=torch.float16)
    elif config.mixed_precision == "bf16":
        autocast = lambda: torch.amp.autocast('cuda', dtype=torch.bfloat16)
    else:
        autocast = contextlib.nullcontext

    # FSDP doesn't need deepspeed configuration
    # prepare prompt and reward fn
    reward_fn = getattr(flow_grpo.rewards, 'multi_score')(device, config.reward_fn)
    eval_reward_fn = getattr(flow_grpo.rewards, 'multi_score')(device, {"wise": 0.0, "unifiedreward": 1.0})
    selection_reward_config = getattr(config, "selection_reward_fn", None)
    if selection_reward_config is None or len(selection_reward_config.items()) == 0:
        selection_reward_config = {"wise": 1.0}
    selection_reward_fn = getattr(flow_grpo.rewards, 'multi_score')(device, selection_reward_config)
    
    # FSDP setup completed above
    # executor to perform callbacks asynchronously. this is beneficial for the llava callbacks which makes a request to a
    # remote server running llava inference.
    executor = futures.ThreadPoolExecutor(max_workers=4)

    try:
        signal_config = SignalConfig()
    except ValueError as exc:
        raise ValueError("Signal base URL is required. Set TRAIN_SIGNAL_URL or config.signal.url.") from exc

    signal_base_url = signal_config.url
    signal_timeout = signal_config.timeout

    cpu_device = torch.device("cpu")
    current_text_encoder_path = config.pretrained.te_model
    
    global_step = 0
    
    vlm_train_step = os.getenv("VLM_TRAIN_STEP", None)
    dit_train_step = os.getenv("DIT_TRAIN_STEP", None)
    
    frozen = False
    
    if not config.train.get("enable", True):
        frozen = True

    # 在训练循环开始前进行一次验证
    if config.eval_freq > 0:
        validate(
            global_step=global_step,
            signal_base_url=signal_base_url,
            signal_timeout=signal_timeout,
            signal_config=signal_config,
            rank=rank,
            world_size=world_size,
            local_rank=local_rank,
            pipeline=pipeline,
            config=config,
            device=device,
            eval_reward_fn=eval_reward_fn,
            executor=executor,
            autocast=autocast,
            ema=ema,
            transformer_trainable_parameters=transformer_trainable_parameters,
            images_save_dir=images_save_dir,
            current_text_encoder_path=current_text_encoder_path,
            is_distributed=is_distributed,
        )

    
    while global_step < 1000:
        
        global_step += 1
        
        pipeline.transformer.eval()

        train_key = signal_config.get_key("train", "data", global_step)
        if rank == 0:
            logger.info(f"Step {global_step}: Fetching training data with key: {train_key}")
        train_payload = _fetch_signal_payload(signal_base_url, train_key, signal_timeout, rank)
        raw_train_samples = train_payload.get("train_data") or train_payload.get("data") or []
        if rank == 0:
            logger.info(f"Step {global_step}: Received {len(raw_train_samples)} train samples")
        train_samples = _normalize_signal_samples(raw_train_samples)
        if rank == 0:
            logger.info(f"Step {global_step}: Normalized {len(train_samples)} prompts before selection")
        
        if not train_samples:
            if rank == 0:
                logger.warning("Step %s received empty train payload; skipping.", global_step)
            if is_distributed:
                dist.barrier()
            continue

        payload_train_step = train_payload.get("global_step")
        if payload_train_step is not None and payload_train_step != global_step and rank == 0:
            logger.warning(
                "Train payload step mismatch: expected %s, received %s", global_step, payload_train_step
            )

        # 获取训练数据后将模型移到GPU
        if rank == 0:
            logger.info(f"Step {global_step}: Moving pipeline to {device}")
        _move_pipeline_to(device, pipeline, inference_dtype, optimizer)
        if is_distributed:
            dist.barrier()

        selected_samples, selection_reward_entries = perform_prompt_selection(
            train_samples,
            pipeline,
            config,
            device,
            rank,
            local_rank,
            world_size,
            selection_reward_fn,
            autocast,
            global_step,
            is_distributed,
            executor,
            images_save_dir,
        )
        reward_entries_local_for_signal = selection_reward_entries
        
        if not frozen:
            if not selected_samples:
                selected_samples = [copy.deepcopy(sample) for sample in train_samples]
            if rank == 0:
                logger.info(
                    "Step %s: Proceeding with %s prompts for unified reward training",
                    global_step,
                    len(selected_samples),
                )
            train_samples = [
                copy.deepcopy(sample)
                for sample in selected_samples
                for _ in range(config.sample.num_image_per_prompt)
            ]
            if rank == 0:
                logger.info(
                    "Step %s: Expanded selected prompts to %s samples (num_image_per_prompt=%s)",
                    global_step,
                    len(train_samples),
                    config.sample.num_image_per_prompt,
                )

            train_dataset = InMemoryPromptDataset(train_samples)
            if is_distributed:
                train_sampler = DistributedSampler(
                    train_dataset,
                    num_replicas=world_size,
                    rank=rank,
                    shuffle=True,
                    drop_last=False,
                )
            else:
                train_sampler = None

            train_dataloader = DataLoader(
                train_dataset,
                batch_size=config.sample.train_batch_size,
                sampler=train_sampler,
                shuffle=train_sampler is None,
                collate_fn=InMemoryPromptDataset.collate_fn,
                num_workers=1,
                drop_last=False,
            )

            if train_sampler is not None:
                train_sampler.set_epoch(global_step)

            num_batches_per_epoch = len(train_dataloader)
            samples_per_epoch = len(train_dataset)
            total_train_batch_size = (
                config.train.batch_size * world_size * config.train.gradient_accumulation_steps
            )

            if rank == 0:
                logger.info("***** Running global step %s *****", global_step)
                logger.info("  Train samples per step: %s", samples_per_epoch)
                logger.info("  Train batches per step: %s", num_batches_per_epoch)
                logger.info("  Train batch size per device = %s", config.train.batch_size)
                logger.info("  Sample batch size per device = %s", config.sample.train_batch_size)
                logger.info("  Gradient Accumulation steps = %s", config.train.gradient_accumulation_steps)

            
            #################### SAMPLING ####################
            if rank == 0:
                logger.info(f"Step {global_step}: Starting sampling phase")
            pipeline.transformer.eval()
            
            # 根据 frozen 决定是否需要梯度计算
            gradient_context = torch.no_grad() if frozen else contextlib.nullcontext()
            
            samples = []
            all_train_images = []
            all_train_prompts = []
            all_train_metadatas = []
            all_train_rewards = defaultdict(list)

            with gradient_context:
                for i, (prompts, prompt_metadata) in enumerate(tqdm(
                    train_dataloader,
                    desc=f"Step {global_step}: sampling",
                    disable=local_rank != 0,
                    position=0,
                )):
                    # sample
                    if config.sample.same_latent:
                        generator = create_generator(prompts, base_seed=global_step*10000+i)
                    else:
                        generator = None
                    with autocast():
                        with torch.no_grad():
                            collected_data = pipeline_with_logprob(
                                pipeline,
                                prompts,
                                negative_prompt=[" "]*len(prompts),
                                num_inference_steps=config.sample.num_steps,
                                true_cfg_scale=config.sample.guidance_scale,
                                output_type="pt",
                                height=config.resolution,
                                width=config.resolution, 
                                noise_level=config.sample.noise_level,
                                generator=generator,
                                sde_window_size=config.sample.sde_window_size,
                                sde_window_range=config.sample.sde_window_range,
                        )

                    images = collected_data["images"]
                    total_images = images.shape[0]

                    prompt_ids = pipeline.tokenizer(
                        prompts,
                        padding="max_length",
                        max_length=256,
                        truncation=True,
                        return_tensors="pt",
                    ).input_ids.to(device)

                    latents = torch.stack(collected_data["all_latents"], dim=1)
                    log_probs = torch.stack(collected_data["all_log_probs"], dim=1)
                    timesteps = torch.stack(collected_data["all_timesteps"]).unsqueeze(0).repeat(total_images, 1)
                    # compute rewards asynchronously
                    rewards = executor.submit(
                        reward_fn,
                        images,
                        prompts,
                        prompt_metadata,
                        only_strict=True,
                    )
                    # yield to to make sure reward computation starts
                    time.sleep(0)

                    # Collect training data for saving (local only)
                    all_train_images.append(images)
                    all_train_prompts.extend(prompts)
                    all_train_metadatas.extend(prompt_metadata)
                    
                    if rank == 0 and i == 0:
                        logger.info(f"Step {global_step}: First batch - generated {images.shape[0]} images")
                    
                    samples.append(
                        {
                            "prompt_ids": prompt_ids,
                            "prompt_embeds": collected_data["prompt_embeds"],
                            "prompt_embeds_mask": collected_data["prompt_embeds_mask"],
                            "negative_prompt_embeds": collected_data["negative_prompt_embeds"],
                            "negative_prompt_embeds_mask": collected_data["negative_prompt_embeds_mask"],
                            "timesteps": timesteps,
                            "latents": latents[
                                :, :-1
                            ],  # each entry is the latent before timestep t
                            "next_latents": latents[
                                :, 1:
                            ],  # each entry is the latent after timestep t
                            "log_probs": log_probs,
                            "rewards": rewards,
                        }
                    )
            
            max_prompt_embeds_len = max([sample["prompt_embeds_mask"].shape[1] for sample in samples])
            
            # wait for all rewards to be computed
            if rank == 0:
                logger.info(f"Step {global_step}: Waiting for {len(samples)} batch rewards to be computed")
            for idx, sample in enumerate(tqdm(
                samples,
                desc="Waiting for rewards",
                disable=local_rank!=0,
                position=0,
            )):
                # pad prompt embeds and mask
                seq_pad_len = max_prompt_embeds_len - sample["prompt_embeds"].shape[1]
                sample["prompt_embeds"] = torch.nn.functional.pad(
                    sample["prompt_embeds"],  # [B, L, D]
                    (0, 0, 0, seq_pad_len),   # pad dim=1 (L)
                    value=0,
                )
                sample["prompt_embeds_mask"] = torch.nn.functional.pad(
                    sample["prompt_embeds_mask"],  # [B, L]
                    (0, seq_pad_len),              # pad dim=1 (L)
                    value=0,
                )
                sample["negative_prompt_embeds"] = torch.nn.functional.pad(
                    sample["negative_prompt_embeds"],  # [B, L, D]
                    (0, 0, 0, seq_pad_len),            # pad dim=1 (L)
                    value=0,
                )
                sample["negative_prompt_embeds_mask"] = torch.nn.functional.pad(
                    sample["negative_prompt_embeds_mask"],  # [B, L]
                    (0, seq_pad_len),                       # pad dim=1 (L)
                    value=0,
                )

                rewards, reward_metadata = sample["rewards"].result()
                # print(reward_metadata)
                sample["rewards"] = {
                    key: torch.as_tensor(value, device=device).float()
                    for key, value in rewards.items()
                }
                # Accumulate local rewards for saving
                for key, value in rewards.items():
                    all_train_rewards[key].extend(value)

            # Save all training images and metadata (group by prompt, single json per prompt)
            if rank == 0:
                logger.info(f"Step {global_step}: Saving {len(all_train_images)} image batches")
            all_train_images = torch.cat(all_train_images, dim=0) if len(all_train_images) > 0 else torch.empty(0)
            all_train_rewards = {key: np.asarray(all_train_rewards[key]) for key in all_train_rewards.keys()}
            if rank == 0:
                logger.info(f"Step {global_step}: Total images to save: {all_train_images.shape[0]}")

            global_reward_entries = _gather_all_objects(reward_entries_local_for_signal, rank, world_size)
            reward_submission_payload = None

            save_images_and_metadata(
                all_train_images,
                all_train_prompts,
                all_train_metadatas,
                all_train_rewards,
                images_save_dir,
                global_step,
                split='train',
                rank=rank,
                world_size=world_size
            )

            # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
            samples = {
                k: torch.cat([s[k] for s in samples], dim=0)
                if not isinstance(samples[0][k], dict)
                else {
                    sub_key: torch.cat([s[k][sub_key] for s in samples], dim=0)
                    for sub_key in samples[0][k]
                }
                for k in samples[0].keys()
            }

            # if global_step % 10 == 0 and rank == 0:
            #     # this is a hack to force wandb to log the images as JPEGs instead of PNGs
            #     with tempfile.TemporaryDirectory() as tmpdir:
            #         num_samples = min(15, len(images))
            #         sample_indices = random.sample(range(len(images)), num_samples)

            #         for idx, i in enumerate(sample_indices):
            #             image = images[i]
            #             pil = Image.fromarray(
            #                 (image.cpu().float().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
            #             )
            #             pil = pil.resize((config.resolution, config.resolution))
            #             pil.save(os.path.join(tmpdir, f"{idx}.jpg"))

            #         sampled_prompts = [prompts[i] for i in sample_indices]
            #         sampled_rewards = [rewards['avg'][i] for i in sample_indices]

            #         wandb.log(
            #             {
            #                 "images": [
            #                     wandb.Image(
            #                         os.path.join(tmpdir, f"{idx}.jpg"),
            #                         caption=f"{prompt:.100} | avg: {avg_reward:.2f}",
            #                     )
            #                     for idx, (prompt, avg_reward) in enumerate(zip(sampled_prompts, sampled_rewards))
            #                 ],
            #             },
            #             step=global_step,
            #         )
            samples["rewards"]["ori_avg"] = samples["rewards"]["avg"]
            # The purpose of repeating `adv` along the timestep dimension here is to make it easier to introduce timestep-dependent advantages later, such as adding a KL reward.
            samples["rewards"]["avg"] = samples["rewards"]["avg"].unsqueeze(1).repeat(1, num_train_timesteps)
            # gather rewards across processes
            if rank == 0:
                logger.info(f"Step {global_step}: Gathering rewards across {world_size} processes")
            gathered_rewards = {key: gather_tensor(value, world_size) for key, value in samples["rewards"].items()}
            gathered_rewards = {key: value.cpu().float().numpy() for key, value in gathered_rewards.items()}
            # log rewards and images
            if rank == 0:
                logger.info(f"Step {global_step}: Reward stats - " + ", ".join([f"{k}: {v.mean():.4f}" for k, v in gathered_rewards.items() if '_strict_accuracy' not in k and '_accuracy' not in k]))
                wandb.log(
                    {
                        "global_step": global_step,
                        **{f"reward_{key}": value.mean() for key, value in gathered_rewards.items() if '_strict_accuracy' not in key and '_accuracy' not in key},
                    },
                    step=global_step,
                )

            # per-prompt mean/std tracking
            if config.per_prompt_stat_tracking:
                # gather the prompts across processes
                prompt_ids = gather_tensor(samples["prompt_ids"], world_size).cpu().float().numpy()
                prompts = pipeline.tokenizer.batch_decode(
                    prompt_ids, skip_special_tokens=True
                )
                advantages = stat_tracker.update(prompts, gathered_rewards['avg'])
                if local_rank == 0:
                    print("len(prompts)", len(prompts))
                    print("len unique prompts", len(set(prompts)))

                group_size, trained_prompt_num = stat_tracker.get_stats()

                zero_std_ratio, reward_std_mean = calculate_zero_std_ratio(prompts, gathered_rewards)

                if rank == 0:
                    wandb.log(
                        {
                            "group_size": group_size,
                            "trained_prompt_num": trained_prompt_num,
                            "zero_std_ratio": zero_std_ratio,
                            "reward_std_mean": reward_std_mean,
                        },
                        step=global_step,
                    )
                stat_tracker.clear()
            else:
                advantages = (gathered_rewards['avg'] - gathered_rewards['avg'].mean()) / (gathered_rewards['avg'].std() + 1e-4)

            # ungather advantages; we only need to keep the entries corresponding to the samples on this process
            advantages = torch.as_tensor(advantages)
            samples["advantages"] = (
                advantages.reshape(world_size, -1, advantages.shape[-1])[rank]
                .to(device)
            )
            if local_rank == 0:
                print("advantages: ", samples["advantages"].abs().mean())

            del samples["rewards"]
            del samples["prompt_ids"]

            total_batch_size, num_timesteps = samples["timesteps"].shape
            gradient_accumulation_steps = config.train.gradient_accumulation_steps * num_train_timesteps

            #################### TRAINING ####################
        
            if rank == 0:
                logger.info(f"Step {global_step}: Starting training phase with {config.train.num_inner_epochs} inner epochs")
            for inner_epoch in range(config.train.num_inner_epochs):
                if rank == 0:
                    logger.info(f"Step {global_step}.{inner_epoch}: Starting inner epoch")
                # rebatch for training
                samples_batched = {
                    k: v.reshape(-1, config.sample.train_batch_size, *v.shape[1:])
                    for k, v in samples.items()
                }

                # dict of lists -> list of dicts for easier iteration
                samples_batched = [
                    dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())
                ]

                # train
                pipeline.transformer.train()
                info = defaultdict(list)
                for i, sample in tqdm(
                    list(enumerate(samples_batched)),
                    desc=f"Step {global_step}.{inner_epoch}: training",
                    position=0,
                    disable=local_rank != 0,
                ):
                    train_timesteps = [step_index  for step_index in range(num_train_timesteps)]
                    for j in tqdm(
                        train_timesteps,
                        desc="Timestep",
                        position=1,
                        leave=False,
                        disable=local_rank != 0,
                    ):
                        # Manual gradient accumulation for FSDP
                        if (i * num_train_timesteps + j + 1) % gradient_accumulation_steps == 0:
                            should_sync = True
                        else:
                            should_sync = False
                        
                        with autocast():
                            prev_sample, log_prob, prev_sample_mean, std_dev_t = compute_log_prob(transformer, pipeline, sample, j, config, rank)
                            if config.train.beta > 0:
                                with torch.no_grad():
                                    _, _, prev_sample_mean_ref, _ = compute_log_prob(transformer_ref, pipeline, sample, j, config, rank)
                        # grpo logic
                        advantages = torch.clamp(
                            sample["advantages"][:, j],
                            -config.train.adv_clip_max,
                            config.train.adv_clip_max,
                        )
                        ratio = torch.exp(log_prob - sample["log_probs"][:, j])
                        # print("ratio", ratio)
                        unclipped_loss = -advantages * ratio
                        clipped_loss = -advantages * torch.clamp(
                            ratio,
                            1.0 - config.train.clip_range,
                            1.0 + config.train.clip_range,
                        )
                        policy_loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss))
                        policy_loss = policy_loss / gradient_accumulation_steps
                        if config.train.beta > 0:
                            kl_loss = ((prev_sample_mean - prev_sample_mean_ref) ** 2).mean(dim=(1,2), keepdim=True) / (2 * std_dev_t ** 2)
                            kl_loss = torch.mean(kl_loss)
                            kl_loss = kl_loss / gradient_accumulation_steps
                            loss = policy_loss + config.train.beta * kl_loss
                        else:
                            loss = policy_loss

                        info["approx_kl"].append(
                            0.5
                            * torch.mean((log_prob - sample["log_probs"][:, j]) ** 2)
                        )
                        info["clipfrac"].append(
                            torch.mean(
                                (
                                    torch.abs(ratio - 1.0) > config.train.clip_range
                                ).float()
                            )
                        )
                        info["clipfrac_gt_one"].append(
                            torch.mean(
                                (
                                    ratio - 1.0 > config.train.clip_range
                                ).float()
                            )
                        )
                        info["clipfrac_lt_one"].append(
                            torch.mean(
                                (
                                    1.0 - ratio > config.train.clip_range
                                ).float()
                            )
                        )
                        info["policy_loss"].append(policy_loss)
                        if config.train.beta > 0:
                            info["kl_loss"].append(kl_loss)

                        info["loss"].append(loss)

                        # backward pass
                        loss.backward()
                        if should_sync:
                            # Clip gradients
                            torch.nn.utils.clip_grad_norm_(transformer.parameters(), config.train.max_grad_norm)
                            optimizer.step()
                            optimizer.zero_grad()

                        # if should_sync:
                        #     # log training-related stuff
                        #     info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                        #     # Reduce info across processes
                        #     if is_distributed:
                        #         for k, v in info.items():
                        #             dist.all_reduce(v, op=dist.ReduceOp.SUM)
                        #             info[k] = v / world_size
                        #     info.update({"global_step": global_step, "inner_epoch": inner_epoch})
                        #     if rank == 0:
                        #         wandb.log(info, step=global_step)
                        #     global_step += 1
                        #     info = defaultdict(list)
                    
                    if config.train.ema:
                        ema.step(transformer_trainable_parameters, global_step)
                        
                # log training-related stuff
                info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                # Reduce info across processes
                if is_distributed:
                    for k, v in info.items():
                        dist.all_reduce(v, op=dist.ReduceOp.SUM)
                        info[k] = v / world_size
                info.update({"global_step": global_step, "inner_epoch": inner_epoch})
                if rank == 0:
                    logger.info(f"Step {global_step}.{inner_epoch}: Training metrics - loss: {info.get('loss', 0):.4f}, policy_loss: {info.get('policy_loss', 0):.4f}, approx_kl: {info.get('approx_kl', 0):.4f}")
                    wandb.log(info, step=global_step)
                
        # 保存模型
        if global_step % config.save_freq == 0:
            if rank == 0:
                logger.info(f"Step {global_step}: Saving FSDP checkpoint to {config.save_dir}")
            save_fsdp_checkpoint(config.save_dir, transformer, global_step, rank)
                
        reward_key = signal_config.get_key("train", "reward", global_step)

        if rank == 0:
            aggregated_rewards: Dict[str, Dict[str, Any]] = {}
            for entry in global_reward_entries:
                aggregated_rewards[entry["rollout_id"]] = entry

            reward_results: List[Dict[str, Any]] = []
            for rollout_id, info in aggregated_rewards.items():
                mean_score = info.get("score", 0.0)
                payload_entry: Dict[str, Any] = {"rollout_id": rollout_id, "score": mean_score}
                if info.get("metadata"):
                    payload_entry["metadata"] = info["metadata"]
                reward_results.append(payload_entry)

            # reward_results.sort(
            #     key=lambda item: (
            #         item.get("metadata", {}).get("data_source", ""),
            #         item.get("metadata", {}).get("sample_idx", item.get("metadata", {}).get("idx", 0)),
            #     )
            # )

            reward_submission_payload = {
                "global_step": int(global_step),
                "mode": "train",
                "rewards": reward_results,
                "text_encoder_path": current_text_encoder_path,
            }

        # 提交reward前将模型移到CPU
        if rank == 0:
            logger.info(f"Step {global_step}: Moving pipeline to CPU for checkpoint/eval")
        _move_pipeline_to(cpu_device, pipeline, inference_dtype, optimizer)
        if is_distributed:
            dist.barrier()

        # 提交奖励
        if rank == 0:
            logger.info(f"Step {global_step}: Submitting rewards to signal service with key: {reward_key}")
        if reward_submission_payload is not None:
            _submit_signal_payload(signal_base_url, reward_key, reward_submission_payload, rank, signal_timeout)

        # 更新text_encoder
        text_encoder_key = signal_config.text_encoder_key_template.format(epoch=global_step)
        if rank == 0:
            logger.info(f"Step {global_step}: Fetching text encoder update with key: {text_encoder_key}")
        te_payload = _fetch_signal_payload(signal_base_url, text_encoder_key, signal_timeout, rank)
        payload_step = te_payload.get("global_step")
        if payload_step is not None and payload_step != global_step and rank == 0:
            logger.warning(
                "Text encoder payload step mismatch: expected %s, received %s", global_step, payload_step
            )
        new_te_path = te_payload.get("text_encoder_path") or te_payload.get("path")
        if new_te_path:
            if rank == 0:
                logger.info("Loading new text encoder weights from %s", new_te_path)
            _load_text_encoder_weights(pipeline, new_te_path, inference_dtype)
            current_text_encoder_path = new_te_path
        elif rank == 0 and not new_te_path:
            logger.info("Step %s received no new text encoder path; reusing %s", global_step, current_text_encoder_path)

        # 训练完成后进行验证
        if config.eval_freq > 0 and global_step % config.eval_freq == 0:
            if rank == 0:
                logger.info(f"Step {global_step}: Starting validation")
            validate(
                global_step=global_step,
                signal_base_url=signal_base_url,
                signal_timeout=signal_timeout,
                signal_config=signal_config,
                rank=rank,
                world_size=world_size,
                local_rank=local_rank,
                pipeline=pipeline,
                config=config,
                device=device,
                eval_reward_fn=eval_reward_fn,
                executor=executor,
                autocast=autocast,
                ema=ema,
                transformer_trainable_parameters=transformer_trainable_parameters,
                images_save_dir=images_save_dir,
                current_text_encoder_path=current_text_encoder_path,
                is_distributed=is_distributed,
            )
            
        # For scheduler training
        if dit_train_step and global_step % (vlm_train_step + dit_train_step) == vlm_train_step:
            frozen = False
        
        if dit_train_step and global_step % (vlm_train_step + dit_train_step) == 0:
            frozen = True

        
if __name__ == "__main__":
    # Hand control to absl: it parses the command-line flags (including the
    # --config file declared via config_flags above) and then invokes `main`.
    app.run(main=main)

