# -*- coding: utf-8 -*-
"""
GRPO training for OpenCUA / Qwen3-VL VLM policy model (LAZY dataset).

✅ Only changes vs your original RL script:
1) Dataset is LAZY: no pre-build (N, L) input_ids/p_mask/labels huge tensors.
   __getitem__ builds one sample on-the-fly and pads to fixed max_seq_len.
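   (To avoid repeated CPU preprocessing, a per-idx memoization cache is also added: each
    sample is built once and then reused. This matches the effect of the original full
    preprocessing, but nothing is built upfront. Set cache_enable=False to disable caching.)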

2) Old-logp is still computed BEFORE training using the SAME model (frozen snapshot at init):
   - run one precompute pass
   - store dataset.logp_old_tok[idx, :] on CPU
   - then start training with PPO-style ratio/clipping/KL (same as old script)

3) For ZeRO-3 safety: we DO NOT use accelerator.accumulate(model).
   We keep the classic manual grad accumulation:
       loss /= grad_acc
       backward every step
       optimizer.step() only when step % grad_acc == 0
   (This avoids DeepSpeed ZeRO-3 no_sync assertion.)
"""

import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
os.environ["TOKENIZERS_PARALLELISM"] = "true"

import io
import json
import math
import time
import shutil
import base64
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader

from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

from omegaconf import OmegaConf

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed

from transformers import (
    AutoTokenizer,
    AutoImageProcessor,
    AutoModel,
    AutoModelForCausalLM,
    AutoProcessor,
)

from train.utils import get_config, flatten_omega_conf
from train.lr_schedulers import get_scheduler
from train.log_utils import set_verbosity_info, set_verbosity_error

try:
    import wandb
except ImportError:
    wandb = None

logger = get_logger(__name__, log_level="INFO")


# ==================== Utils for images / prompts ====================

def _decode_image(src: str) -> Image.Image:
    """
    Decode an image reference into an RGB PIL image.

    Accepted forms, checked in this order:
      1. ``http://`` / ``https://`` URL (fetched with a 10 s timeout),
      2. an existing local file path,
      3. otherwise a base64 payload, optionally prefixed like a data URL
         (everything up to the first comma is dropped).

    Raises:
        ValueError: if ``src`` is empty / whitespace only.
        Network, base64 and PIL decoding errors propagate unchanged.
    """
    s = (src or "").strip()
    if not s:
        raise ValueError("empty image src")

    if s.startswith(("http://", "https://")):
        import requests
        resp = requests.get(s, timeout=10)
        resp.raise_for_status()
        payload = io.BytesIO(resp.content)
    elif os.path.exists(s):
        payload = s
    else:
        # base64 / data url
        b64 = s.split(",", 1)[1] if "," in s else s
        payload = io.BytesIO(base64.b64decode(b64, validate=False))

    # The context manager closes the underlying file handle deterministically.
    # Pillow keeps it open after load() for multi-frame formats such as GIF/TIFF.
    # convert() returns a new, fully loaded image, so it stays valid after close.
    with Image.open(payload) as img:
        return img.convert("RGB")


def collect_images_from_messages(messages: List[Dict[str, Any]]) -> List[Image.Image]:
    """
    Decode every image referenced in a list of chat messages.

    Only list-valued ``content`` is inspected. Parts of type ``"image"`` use
    their ``"image"`` field, and parts of type ``"image_url"`` use
    ``part["image_url"]["url"]`` or ``part["url"]``. A part that fails to decode
    is reported on stdout and skipped rather than aborting the whole message list.
    """
    decoded: List[Image.Image] = []
    for message in messages:
        parts = message.get("content")
        if not isinstance(parts, list):
            continue
        for part in parts:
            if not isinstance(part, dict):
                continue
            kind = part.get("type")
            if kind == "image":
                src = part.get("image")
                if not src:
                    continue
                try:
                    decoded.append(_decode_image(src))
                except Exception as e:
                    print(f"[WARN] decode image failed: {src} | {e}", flush=True)
            elif kind == "image_url":
                url = (part.get("image_url") or {}).get("url") or part.get("url")
                if not url:
                    continue
                try:
                    decoded.append(_decode_image(url))
                except Exception as e:
                    print(f"[WARN] decode image_url failed: {url} | {e}", flush=True)
    return decoded


def _normalize_qwen_messages(prompt_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Convert stored messages into Qwen3-VL / AutoProcessor expected format:
      {"role": "...", "content": [{"type":"text","text":...} / {"type":"image","image":...}, ...]}
    """
    qwen_messages = []
    for m in prompt_messages:
        role = m.get("role", "user")
        content = m.get("content", "")

        new_m: Dict[str, Any] = {"role": role}
        new_content: List[Dict[str, Any]] = []

        if isinstance(content, str):
            new_content.append({"type": "text", "text": content})
        elif isinstance(content, list):
            for part in content:
                if isinstance(part, str):
                    new_content.append({"type": "text", "text": part})
                elif isinstance(part, dict):
                    p_type = part.get("type")
                    if p_type in ("text", "paragraph"):
                        new_content.append({"type": "text", "text": part.get("text", "")})
                    elif p_type in ("image", "video"):
                        new_content.append(part)
                    elif p_type == "image_url":
                        url = (part.get("image_url") or {}).get("url") or part.get("url")
                        if url:
                            new_content.append({"type": "image", "image": url})
                    else:
                        txt = part.get("text", None)
                        new_content.append({"type": "text", "text": txt if txt is not None else str(part)})
                else:
                    new_content.append({"type": "text", "text": str(part)})
        else:
            new_content.append({"type": "text", "text": str(content)})

        new_m["content"] = new_content
        qwen_messages.append(new_m)
    return qwen_messages


# ==================== Lazy Dataset ====================

class VLMGRPOLazyDataset(Dataset):
    """
    Lazy RL dataset with an optional per-index cache.

    Each raw item is tokenized on first access. The prompt goes through the chat
    template (``processor.apply_chat_template`` for Qwen3-VL, otherwise
    ``tokenizer.apply_chat_template``) and the response through the plain
    tokenizer. The concatenation is right-padded to ``max_seq_len``.

    Returns (per sample):
      - idx: int
      - input_ids: (max_seq_len,) long, padded with the tokenizer pad id
      - p_mask:    (max_seq_len,) bool, True on response-token positions
      - labels:    (max_seq_len,) long, same values as input_ids (for convenience)
      - advantage: float, read from the raw item's "reward" field (default 0.0)
      - pixel_values: Tensor or None
      - grid_thws:    Tensor or None

    Old-logp cache:
      - logp_old_tok: (N, max_seq_len) float32, initialised to -inf. Only
        positions [1:] can hold values after the precompute pass.
    """

    def __init__(
        self,
        raw_data: List[Dict[str, Any]],
        tokenizer,
        image_processor,
        max_prompt_len: int,
        max_gen_length: int,
        is_qwen3_vl: bool = False,
        processor=None,
        fallback_max_seq_len: int = 4096,
        cache_enable: bool = True,
    ):
        self.raw_data = raw_data
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.is_qwen3_vl = is_qwen3_vl
        self.processor = processor
        self.cache_enable = bool(cache_enable)

        if tokenizer.pad_token_id is None:
            if tokenizer.eos_token_id is not None:
                tokenizer.pad_token = tokenizer.eos_token
            else:
                # NOTE(review): adding a new token grows the vocab, but nothing here
                # resizes the model's embeddings. Confirm the model can embed this id.
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        self.pad_id = tokenizer.pad_token_id

        self.max_prompt_len = int(max_prompt_len)
        self.max_gen_length = int(max_gen_length)

        if self.max_prompt_len > 0 and self.max_gen_length > 0:
            self.max_seq_len = self.max_prompt_len + self.max_gen_length
        else:
            self.max_seq_len = int(fallback_max_seq_len)
            # Use the stdlib logger here. accelerate's `get_logger` adapter raises
            # RuntimeError when called before an Accelerator/PartialState exists,
            # and this dataset may be constructed before the Accelerator is created.
            logging.getLogger(__name__).warning(
                f"[VLMGRPOLazyDataset] max_prompt_len/max_gen_length <= 0, "
                f"fallback max_seq_len={self.max_seq_len}. Recommend setting both explicitly."
            )

        n = len(self.raw_data)
        # -inf marks "not precomputed yet".
        self.logp_old_tok = torch.full((n, self.max_seq_len), float("-inf"), dtype=torch.float32)

        # cache: store fully padded tensors + vision tensors per idx
        self._cache = [None] * n if self.cache_enable else None

    def __len__(self) -> int:
        return len(self.raw_data)

    def _build_one(self, item: Dict[str, Any]):
        """
        Tokenize one raw item into padded tensors.

        Reads ``item["prompt_messages"]``, ``item["response"]`` and optionally
        ``item["reward"]``. Returns
        ``(input_ids, p_mask, labels, advantage, pixel_values, grid_thws)``.
        """
        tokenizer = self.tokenizer
        image_processor = self.image_processor
        is_qwen3_vl = self.is_qwen3_vl
        processor = self.processor
        pad_id = self.pad_id

        prompt_messages = item["prompt_messages"]
        response_text = item["response"]
        advantage = float(item.get("reward", 0.0))

        # ---------- prompt ids + vision ----------
        if is_qwen3_vl:
            if processor is None:
                raise ValueError("Qwen3-VL requires `processor`, but got None.")
            qwen_messages = _normalize_qwen_messages(prompt_messages)
            # NOTE(review): truncating a prompt that contains image placeholder tokens
            # may desync them from pixel_values/image_grid_thw. Confirm prompts fit.
            proc_inputs = processor.apply_chat_template(
                qwen_messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_prompt_len if self.max_prompt_len > 0 else 100000,
            )
            prompt_ids = proc_inputs["input_ids"][0].tolist()

            pixel_values = proc_inputs.get("pixel_values", None)
            if pixel_values is not None:
                pixel_values = pixel_values.clone()

            grid = proc_inputs.get("image_grid_thw", None)
            grid_thws = grid.clone().long() if grid is not None else None

        else:
            prompt_ids = tokenizer.apply_chat_template(
                prompt_messages,
                tokenize=True,
                add_generation_prompt=True,
            )
            if self.max_prompt_len > 0 and len(prompt_ids) > self.max_prompt_len:
                # keep the tail so the generation prompt survives
                prompt_ids = prompt_ids[-self.max_prompt_len:]

            pil_images = collect_images_from_messages(prompt_messages)
            pixel_values = None
            grid_thws = None
            if pil_images:
                info = image_processor.preprocess(images=pil_images)

                # different image processors use different key names
                pixel = info.get("pixel_values", None)
                if pixel is None:
                    pixel = info.get("pixel_values_videos", None)

                grid = info.get("image_grid_thw", None)
                if grid is None:
                    grid = info.get("images_grid_thw", None)
                if grid is None:
                    grid = info.get("grid_thws", None)

                if pixel is not None:
                    pixel_values = torch.as_tensor(pixel)
                if grid is not None:
                    grid_thws = torch.as_tensor(grid, dtype=torch.long)

        # ---------- response ids ----------
        resp_ids = tokenizer(response_text, add_special_tokens=False)["input_ids"]
        if self.max_gen_length > 0 and len(resp_ids) > self.max_gen_length:
            resp_ids = resp_ids[: self.max_gen_length]

        input_ids = prompt_ids + resp_ids
        start_pos = len(prompt_ids)
        L = len(input_ids)

        max_len = self.max_seq_len
        if L > max_len:
            # safety truncate from left; keep tail
            offset = L - max_len
            input_ids = input_ids[offset:]
            start_pos = max(0, start_pos - offset)
            L = max_len

        input_ids_tensor = torch.full((max_len,), pad_id, dtype=torch.long)
        labels_tensor = torch.full((max_len,), pad_id, dtype=torch.long)
        p_mask_tensor = torch.zeros((max_len,), dtype=torch.bool)

        if L > 0:
            ids_tensor = torch.tensor(input_ids, dtype=torch.long)
            input_ids_tensor[:L] = ids_tensor
            labels_tensor[:L] = ids_tensor
            if L > start_pos:
                # response tokens occupy [start_pos, L)
                p_mask_tensor[start_pos:L] = True

        return input_ids_tensor, p_mask_tensor, labels_tensor, advantage, pixel_values, grid_thws

    def __getitem__(self, idx: int):
        if self.cache_enable and self._cache[idx] is not None:
            input_ids_tensor, p_mask_tensor, labels_tensor, advantage, pv, g = self._cache[idx]
        else:
            input_ids_tensor, p_mask_tensor, labels_tensor, advantage, pv, g = self._build_one(self.raw_data[idx])
            if self.cache_enable:
                self._cache[idx] = (input_ids_tensor, p_mask_tensor, labels_tensor, advantage, pv, g)

        return (
            int(idx),
            input_ids_tensor,
            p_mask_tensor,
            labels_tensor,
            float(advantage),
            pv,
            g,
        )


def grpo_collate(batch):
    """
    Collate dataset tuples ``(idx, input_ids, p_mask, labels, advantage, pixel_values, grid_thws)``.

    Fixed-length tensors are stacked along a new batch dimension. Vision tensors
    stay as per-sample lists because their shapes (and presence) vary per sample.
    """
    (sample_ids, token_rows, mask_rows, label_rows,
     advantages, pixels, grids) = zip(*batch)

    collated = {}
    collated["ids"] = torch.tensor(sample_ids, dtype=torch.long)
    collated["input_ids"] = torch.stack(token_rows)                      # (B, L)
    collated["p_mask"] = torch.stack(mask_rows)                          # (B, L)
    collated["labels"] = torch.stack(label_rows)                         # (B, L)
    collated["advantage"] = torch.tensor(advantages, dtype=torch.float32)  # (B,)
    collated["pixel_values"] = list(pixels)                              # len B
    collated["grid_thws"] = list(grids)                                  # len B
    return collated


# ==================== Core utils ====================

def make_attn_and_pos(input_ids: torch.Tensor, pad_id: int):
    """
    Derive an attention mask and position ids from padded token ids of shape (B, L).

    The mask is 1 where ``input_ids != pad_id``. Positions count the non-pad
    tokens seen so far (starting at 0), and pad positions are reset to 0.
    Both outputs are long tensors shaped like ``input_ids``.
    """
    valid = input_ids.ne(pad_id).long()
    running = torch.cumsum(valid, dim=1).sub(1)
    positions = running.masked_fill(valid.eq(0), 0)
    return valid, positions


def _build_forward_kwargs(
    seq: torch.Tensor,
    pad_id: int,
    pixel_values: Optional[torch.Tensor],
    grid_thws: Optional[torch.Tensor],
    is_qwen3_vl: bool,
    device,
) -> Dict[str, Any]:
    """
    Build ``model(**kwargs)`` inputs for a single sequence ``seq`` of shape (1, L).

    Shared by the old-logp precompute and the training forward so both passes
    feed the model identically. Otherwise the PPO ratio would not start at 1.
    """
    attention_mask, position_ids = make_attn_and_pos(seq, pad_id)
    kwargs: Dict[str, Any] = dict(input_ids=seq, attention_mask=attention_mask)

    if not is_qwen3_vl:
        kwargs["position_ids"] = position_ids
    # For Qwen3-VL, position_ids is deliberately omitted. In transformers' Qwen3-VL
    # implementation the multimodal (M-RoPE) positions are computed from input_ids +
    # image_grid_thw only when position_ids is None. Passing 2-D text positions
    # would give image tokens plain sequential positions (verify against the
    # installed transformers version). The dataset right-pads sequences, so the
    # attention mask alone is sufficient here.

    if pixel_values is not None:
        kwargs["pixel_values"] = pixel_values.to(device)
    if grid_thws is not None:
        grid_key = "image_grid_thw" if is_qwen3_vl else "grid_thws"
        kwargs[grid_key] = grid_thws.to(device)
    return kwargs


def _response_token_logps(
    logits: torch.Tensor,
    labels: torch.Tensor,
    p_mask: torch.Tensor,
) -> torch.Tensor:
    """
    Log-probabilities of the response tokens under ``logits``.

    Args:
        logits: (1, L, V) model outputs.
        labels: (1, L) token ids.
        p_mask: (1, L) bool, True on response tokens.

    Token t is predicted from position t-1, so the shifted mask ``p_mask[:, 1:]``
    selects which predictions to score.

    Returns:
        1-D float32 tensor with one entry per True in ``p_mask[:, 1:]``, in position order.

    Only the selected rows go through float32 log_softmax. This avoids
    materialising an (L-1, V) float32 copy of the full logits.
    """
    mask_shift = p_mask[:, 1:]
    sel_logits = logits[:, :-1, :][mask_shift]   # (n, V)
    sel_labels = labels[:, 1:][mask_shift]       # (n,)
    log_probs = F.log_softmax(sel_logits.float(), dim=-1)
    return log_probs.gather(dim=-1, index=sel_labels.unsqueeze(-1)).squeeze(-1)


@torch.no_grad()
def compute_logp_old_tok_parallel(
    accelerator: Accelerator,
    model,
    dataset: VLMGRPOLazyDataset,
    train_dataloader: DataLoader,
    pad_id: int,
    is_qwen3_vl: bool,
):
    """
    Precompute old token log-probs under the initial (frozen) policy model.

    Results are written into ``dataset.logp_old_tok`` (CPU). Only response-token
    positions (p_mask True, shifted by one) are filled, and all other entries
    stay -inf. The training loss reads only those positions.

    Each process writes only the ids it iterates over. This assumes training
    iterates the same ids on the same process; TODO confirm if the sampler
    or sharding is ever changed.
    """
    model.eval()
    from tqdm.auto import tqdm

    device = accelerator.device
    iterator = tqdm(
        train_dataloader,
        desc="Precomputing old token log-probs",
        dynamic_ncols=True,
        disable=not accelerator.is_local_main_process,
        leave=True,
    )

    for batch in iterator:
        ids = batch["ids"].cpu()                          # (B,)
        input_ids_batch = batch["input_ids"].to(device)   # (B, L)
        labels_batch = batch["labels"].to(device)         # (B, L)
        p_mask_batch = batch["p_mask"].to(device)         # (B, L)
        pixel_values_list = batch["pixel_values"]
        grid_thws_list = batch["grid_thws"]

        B, L = input_ids_batch.shape
        full_batch = torch.full((B, L), float("-inf"), device=device, dtype=torch.float32)

        for i in range(B):
            seq = input_ids_batch[i:i + 1]  # (1, L)
            kwargs = _build_forward_kwargs(
                seq, pad_id, pixel_values_list[i], grid_thws_list[i], is_qwen3_vl, device
            )
            logits = model(**kwargs).logits  # (1, L, V)

            tok_lp = _response_token_logps(logits, labels_batch[i:i + 1], p_mask_batch[i:i + 1])
            # full_batch[i, 1:] is a view, so the masked assignment writes in place.
            # Column 0 never has a prediction.
            full_batch[i, 1:][p_mask_batch[i, 1:]] = tok_lp

        # write back to dataset cache (CPU)
        dataset.logp_old_tok[ids] = full_batch.cpu()

    accelerator.wait_for_everyone()
    model.train()


def forward_grpo_vlm(
    accelerator: Accelerator,
    model,
    batch: Dict[str, Any],
    dataset: VLMGRPOLazyDataset,
    tokenizer,
    config,
    is_qwen3_vl: bool,
):
    """
    PPO/GRPO loss on response tokens only, using precomputed ``dataset.logp_old_tok``.

    For every sample, the loss is the negative clipped surrogate
    ``min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)`` averaged over its response
    tokens. It is optionally combined with ``beta *`` a per-token KL term against
    the old log-probs (the k3 estimator when ``config.training.use_kl_estimator_k3``).
    The result is the mean over the B samples in the batch.
    """
    device = accelerator.device
    pad_id = tokenizer.pad_token_id

    ids = batch["ids"].cpu()                         # (B,)
    input_ids_batch = batch["input_ids"].to(device)  # (B, L)
    p_mask_batch = batch["p_mask"].to(device)        # (B, L)
    labels_batch = batch["labels"].to(device)        # (B, L)
    adv_batch = batch["advantage"].to(device)        # (B,)
    pixel_values_list = batch["pixel_values"]
    grid_thws_list = batch["grid_thws"]

    old_lp_batch = dataset.logp_old_tok[ids].to(device)  # (B, L)

    eps = float(config.training.eps)
    beta = float(getattr(config.training, "beta", 0.0))
    use_k3 = beta > 0 and bool(getattr(config.training, "use_kl_estimator_k3", False))

    B = input_ids_batch.shape[0]
    total_loss = torch.tensor(0.0, device=device)

    for i in range(B):
        seq = input_ids_batch[i:i + 1]  # (1, L)
        kwargs = _build_forward_kwargs(
            seq, pad_id, pixel_values_list[i], grid_thws_list[i], is_qwen3_vl, device
        )
        logits = model(**kwargs).logits  # (1, L, V)

        new_lp = _response_token_logps(logits, labels_batch[i:i + 1], p_mask_batch[i:i + 1])  # (n,)
        old_lp = old_lp_batch[i, 1:][p_mask_batch[i, 1:]]                                   # (n,)
        n_tok = max(new_lp.numel(), 1)
        adv = adv_batch[i]

        # ratio on response tokens; the log-ratio is clamped for numerical safety
        diff = (new_lp - old_lp).clamp(-10.0, 10.0)
        ratio = torch.exp(diff)
        clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps)
        surrogate = torch.min(ratio * adv, clipped * adv)
        sample_loss = -surrogate.sum() / n_tok

        # optional KL vs old (meaningful because old is a real precomputed snapshot)
        if beta > 0:
            kl_tok = new_lp - old_lp
            if use_k3:
                t = (-kl_tok).clamp(-10.0, 10.0)
                kl_tok = t.exp() - 1.0 + kl_tok
            sample_loss = sample_loss + beta * (kl_tok.sum() / n_tok)

        total_loss = total_loss + sample_loss

    return total_loss / max(B, 1)


# ==================== Checkpoint (same as your original) ====================

def _copy_missing_files(base_dir: Path, save_dir: Path, filenames: List[str]) -> None:
    """Copy each existing ``base_dir/<fn>`` into ``save_dir`` unless it is already there."""
    for fn in filenames:
        src = base_dir / fn
        dst = save_dir / fn
        if src.exists() and not dst.exists():
            shutil.copy(src, dst)


def save_checkpoint(model, tokenizer, config, accelerator, name: str):
    """
    Save model weights and tokenizer to ``<rl_base_dir>/<project>/ckpt/<name>``.

    ``accelerator.get_state_dict`` runs on every process because it may need a
    collective gather (e.g. under DeepSpeed ZeRO-3). All file writes happen on
    the main process only. Extra files from the base model directory are copied
    in so that the checkpoint can be reloaded on its own:
      - OpenCUA (TikTokenV3) tokenizers get the remote-code files, and
        tokenizer_config.json is patched to point at them.
      - Other models get processor / preprocessor / chat-template configs.
    A ``metadata.json`` with the save time is written next to the checkpoint dirs.
    """
    project_name = config.experiment.project
    rl_base_dir = Path(config.system.rl_base_dir)
    save_base = rl_base_dir / project_name / "ckpt"
    save_dir = save_base / name

    model_to_save = accelerator.unwrap_model(model)
    state_dict = accelerator.get_state_dict(model)

    if not accelerator.is_main_process:
        return

    save_dir.mkdir(parents=True, exist_ok=True)

    model_to_save.save_pretrained(
        save_dir,
        save_function=accelerator.save,
        state_dict=state_dict,
        safe_serialization=True,
    )
    tokenizer.save_pretrained(str(save_dir))

    # infer base dir for extra files
    init_kwargs = getattr(tokenizer, "init_kwargs", {})
    tok_name = getattr(tokenizer, "name_or_path", "") or init_kwargs.get("_name_or_path", "")
    target = getattr(config.training, "target", None)

    candidate_base = []
    if tok_name:
        candidate_base.append(Path(tok_name))
    if target == "policy":
        candidate_base.append(Path(config.model.policy_model))
    elif target == "reward":
        candidate_base.append(Path(config.model.reward_model))

    base_dir = None
    for p in candidate_base:
        if p and p.exists():
            base_dir = p
            break

    tok_class_name = tokenizer.__class__.__name__
    tok_cfg_class = init_kwargs.get("tokenizer_class", "")

    is_opencua = (
        "TikTokenV3" in tok_class_name
        or "TikTokenV3" in str(tok_cfg_class)
        or "OpenCUA" in str(tok_name)
    )

    if is_opencua and base_dir is not None:
        _copy_missing_files(base_dir, save_dir, [
            "tokenization_opencua.py",
            "configuration_opencua.py",
            "modeling_opencua.py",
            "tiktoken.model",
            "preprocessor_config.json",
            "image_processor_config.json",
            "processor_config.json",
            "generation_config.json",
        ])

        tok_cfg_path = save_dir / "tokenizer_config.json"
        if tok_cfg_path.exists():
            with tok_cfg_path.open("r", encoding="utf-8") as f:
                cfg = json.load(f)
            cfg["tokenizer_class"] = "tokenization_opencua.TikTokenV3"
            auto_map = cfg.get("auto_map", {})
            auto_map["AutoTokenizer"] = ["tokenization_opencua.TikTokenV3", None]
            cfg["auto_map"] = auto_map
            with tok_cfg_path.open("w", encoding="utf-8") as f:
                json.dump(cfg, f, indent=2, ensure_ascii=False)

    # Generic multimodal fallback. Only the tokenizer is saved above, so copy the
    # processor-side configs too, including the processor chat template and video
    # preprocessor config. Without them, AutoProcessor.from_pretrained(ckpt) in
    # the next round may lack the template that apply_chat_template needs.
    if base_dir is not None and not is_opencua:
        _copy_missing_files(base_dir, save_dir, [
            "preprocessor_config.json",
            "processor_config.json",
            "image_processor_config.json",
            "video_preprocessor_config.json",
            "chat_template.json",
            "chat_template.jinja",
        ])

    metadata = {"save_time": time.strftime("%Y-%m-%d %H:%M:%S")}
    with (save_base / "metadata.json").open("w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)

    logger.info(f"Saved model + tokenizer to {save_dir}")


# ==================== Model helpers ====================

def enable_gc(model):
    """
    Turn off KV caching and enable gradient checkpointing on a VLM.

    ``use_cache`` is set to False on ``model.config`` and on
    ``model.language_model.config`` when those attributes exist. Gradient
    checkpointing is enabled on ``model.language_model`` if it supports
    ``gradient_checkpointing_enable``, otherwise on ``model`` itself. Every
    submodule exposing a ``gradient_checkpointing`` flag is also forced to
    True. Nothing is enabled if neither object supports it.
    """
    def _disable_cache(owner):
        cfg = getattr(owner, "config", None)
        if hasattr(cfg, "use_cache"):
            cfg.use_cache = False

    _disable_cache(model)
    inner = getattr(model, "language_model", None)
    if inner is not None:
        _disable_cache(inner)

    candidates = (
        (inner, "[GC] Enabled on model.language_model"),
        (model, "[GC] Enabled on model directly"),
    )
    for target, message in candidates:
        if not hasattr(target, "gradient_checkpointing_enable"):
            continue
        target.gradient_checkpointing_enable()
        for sub in target.modules():
            if hasattr(sub, "gradient_checkpointing"):
                sub.gradient_checkpointing = True
        logger.info(message)
        return


def load_vlm_model(pretrained_model: str, is_qwen3_vl: bool):
    """
    Load the policy/reward VLM from ``pretrained_model`` with remote code and native dtype.

    Qwen3-VL uses ``Qwen3VLForConditionalGeneration``. Any other model is tried
    with ``AutoModelForCausalLM`` first, and any exception from that attempt
    falls back to ``AutoModel``.
    """
    load_kwargs = dict(trust_remote_code=True, torch_dtype="auto")

    if is_qwen3_vl:
        from transformers import Qwen3VLForConditionalGeneration
        return Qwen3VLForConditionalGeneration.from_pretrained(pretrained_model, **load_kwargs)

    try:
        return AutoModelForCausalLM.from_pretrained(pretrained_model, **load_kwargs)
    except Exception:
        # Some remote-code VLMs register only an AutoModel mapping.
        return AutoModel.from_pretrained(pretrained_model, **load_kwargs)


# ==================== Main ====================

def _resolve_target_settings(config, project_name: str) -> Dict[str, Any]:
    """
    Resolve per-target (policy / reward) training settings from the config.

    Returns a dict with keys: pretrained_model, optimized_name, max_prompt_len,
    max_gen_length, optimization_data, update_per_step, batch_size_lm,
    gradient_checkpointing_enable.

    Raises:
        ValueError: for an unknown ``config.training.target``.
    """
    target = config.training.target
    if target == "policy":
        section = config.training.policy
        base_model_key = "policy_model"
        name_key = "optimized_name"
        optimization_data = "policy_optimization_data"
    elif target == "reward":
        section = config.training.reward
        base_model_key = "reward_model"
        name_key = "optimized_reward_name"
        optimization_data = "reward_optimization_data"
    else:
        raise ValueError(f"Unknown training.target = {config.training.target}")

    optimized_name = getattr(config.model, name_key)
    if config.experiment.current_epoch == 1:
        # the first RL round starts from the configured base model
        pretrained_model = getattr(config.model, base_model_key)
    else:
        # later rounds resume from the checkpoint save_checkpoint wrote under ckpt/<optimized_name>
        pretrained_model = config.system.rl_base_dir + "/" + project_name + "/ckpt/" + optimized_name

    return {
        "pretrained_model": pretrained_model,
        "optimized_name": optimized_name,
        "max_prompt_len": section.max_prompt_len,
        "max_gen_length": section.max_gen_length,
        "optimization_data": optimization_data,
        "update_per_step": section.update_per_step,
        "batch_size_lm": section.batch_size_lm,
        "gradient_checkpointing_enable": section.gradient_checkpointing_enable,
    }


def _build_optimizer(model, config):
    """
    Create AdamW with two parameter groups: decayed params, and no-decay params
    (bias / norm / embedding weights, matched by name substring).

    Raises:
        ValueError: if ``config.optimizer.name`` is not "adamw".
    """
    optimizer_config = config.optimizer.params
    no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"]

    decay_params, no_decay_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if any(nd in n for nd in no_decay):
            no_decay_params.append(p)
        else:
            decay_params.append(p)

    optimizer_grouped_parameters = [
        {"params": decay_params, "weight_decay": optimizer_config.weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

    if config.optimizer.name != "adamw":
        raise ValueError(f"Optimizer {config.optimizer.name} not supported")
    return AdamW(
        optimizer_grouped_parameters,
        lr=optimizer_config.learning_rate,
        betas=(optimizer_config.beta1, optimizer_config.beta2),
        weight_decay=optimizer_config.weight_decay,
        eps=optimizer_config.epsilon,
    )


def main():
    """
    Run one GRPO round: load config and RL data, precompute old log-probs with
    the initial model, train with manual gradient accumulation, and save the
    checkpoint under ``optimized_name`` (plus an ``epoch-<n>`` copy when
    ``experiment.save_every`` divides the current epoch).
    """
    config = get_config()
    project_name = config.experiment.project

    # -------- target selection (policy / reward) --------
    settings = _resolve_target_settings(config, project_name)
    pretrained_model = settings["pretrained_model"]
    optimized_name = settings["optimized_name"]
    max_prompt_len = settings["max_prompt_len"]
    max_gen_length = settings["max_gen_length"]
    optimization_data = settings["optimization_data"]
    update_per_step = settings["update_per_step"]
    batch_size_lm = settings["batch_size_lm"]
    gradient_checkpointing_enable = settings["gradient_checkpointing_enable"]

    # TF32
    if config.training.enable_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    # -------- load data --------
    opt_path = Path(project_name) / "temp_data" / f"{optimization_data}.json"
    with opt_path.open("r", encoding="utf-8") as f:
        dataset_load = json.load(f)

    ws = int(os.environ.get("WORLD_SIZE", "1"))
    total_n = len(dataset_load)

    # NOTE(review): with split_batches=True (below), batch_size_lm is the GLOBAL
    # per-step batch (each process gets batch_size_lm // ws samples). The `* ws`
    # factor therefore gives roughly ws-times more optimizer updates per epoch
    # than update_per_step, if that is meant as "updates per epoch". Kept as-is
    # to preserve the training recipe. Confirm intent.
    gradient_accumulation_steps = max(
        1,
        math.ceil(total_n / (update_per_step * batch_size_lm * ws)),
    )

    # detect model type (keep your old logic)
    is_qwen3_vl = (str(getattr(config, "model_type", "")).lower() == "qwen3vl")

    # -------- tokenizer / processor --------
    if is_qwen3_vl:
        processor = AutoProcessor.from_pretrained(pretrained_model, trust_remote_code=True)
        tokenizer = processor.tokenizer
        image_processor = processor.image_processor
    else:
        processor = None
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model, trust_remote_code=True)
        image_processor = AutoImageProcessor.from_pretrained(pretrained_model, trust_remote_code=True)

    # -------- accelerator --------
    config.experiment.logging_dir = str(Path(project_name) / "logs")
    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=config.training.mixed_precision,
        # NOTE(review): with log_with=None accelerate registers no tracker, so the
        # init_trackers()/log() calls below likely do nothing. Confirm, and pass
        # log_with="wandb" to actually log to W&B.
        log_with=None,
        project_dir=config.experiment.logging_dir,
        # Keep the original RL script's sharding so the precompute pass and the
        # training pass see the same ids on each process.
        split_batches=True,
    )
    assert ws == accelerator.num_processes

    # logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        set_verbosity_info()
    else:
        set_verbosity_error()

    # -------- lazy dataset --------
    # Built after the Accelerator exists, so any accelerate-logger call made during
    # construction works. accelerate's logger raises RuntimeError before its state
    # is initialised.
    dataset_lm = VLMGRPOLazyDataset(
        raw_data=dataset_load,
        tokenizer=tokenizer,
        image_processor=image_processor,
        max_prompt_len=max_prompt_len,
        max_gen_length=max_gen_length,
        is_qwen3_vl=is_qwen3_vl,
        processor=processor,
        cache_enable=True,  # reuse each built sample; set False to disable caching
    )

    # wandb
    if accelerator.is_main_process and wandb is not None:
        resume_wandb_run = config.wandb.resume
        run_id = config.wandb.get("run_id", None)
        if run_id is None:
            resume_wandb_run = False
            run_id = wandb.util.generate_id()
            config.wandb.run_id = run_id

        wandb_init_kwargs = dict(
            name=project_name,
            id=run_id,
            resume=resume_wandb_run,
            entity=config.wandb.get("entity", None),
            config_exclude_keys=[],
        )
        wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)}
        wandb_config.pop("experiment.resume_from_checkpoint", None)

        accelerator.init_trackers(
            project_name,
            config=wandb_config,
            init_kwargs={"wandb": wandb_init_kwargs},
        )

    if accelerator.is_main_process:
        os.makedirs(project_name, exist_ok=True)
        config_path = Path(project_name) / "config.yaml"
        logging.info(f"Saving config to {config_path}")
        OmegaConf.save(config, config_path)

    if config.training.seed is not None:
        set_seed(config.training.seed)

    # -------- model --------
    logger.info("Loading VLM model")
    model = load_vlm_model(pretrained_model, is_qwen3_vl=is_qwen3_vl)

    if gradient_checkpointing_enable:
        enable_gc(model)

    # read after dataset construction, which may have assigned a pad token
    pad_id = tokenizer.pad_token_id

    # -------- optimizer --------
    optimizer = _build_optimizer(model, config)

    train_dataloader_lm = DataLoader(
        dataset_lm,
        batch_size=batch_size_lm,
        sampler=None,
        collate_fn=grpo_collate,
        num_workers=0,
        pin_memory=True,
    )

    # With split_batches=True each process receives a slice of every global batch,
    # so the number of micro-steps per epoch equals len() of this loader. The
    # trailing partial accumulation window is also stepped (see the loop), hence ceil.
    num_update_steps_per_epoch = max(1, math.ceil(len(train_dataloader_lm) / gradient_accumulation_steps))
    total_batch_size_lm = batch_size_lm * gradient_accumulation_steps
    num_train_epochs = config.training.num_train_epochs
    max_train_steps = num_update_steps_per_epoch * num_train_epochs

    lr_scheduler = get_scheduler(
        config.lr_scheduler.scheduler,
        optimizer=optimizer,
        num_training_steps=max_train_steps,
        num_warmup_steps=config.lr_scheduler.params.warmup_steps,
        min_lr_scale=config.lr_scheduler.params.min_lr_scale,
    )

    # prepare
    logger.info("Preparing model/optimizer/scheduler/dataloader")
    model, optimizer, lr_scheduler, train_dataloader_lm = accelerator.prepare(
        model, optimizer, lr_scheduler, train_dataloader_lm
    )

    # -------- precompute old logp (frozen snapshot) --------
    logger.info("***** Precomputing old logp (frozen snapshot) *****")
    compute_logp_old_tok_parallel(
        accelerator=accelerator,
        model=model,
        dataset=dataset_lm,
        train_dataloader=train_dataloader_lm,
        pad_id=pad_id,
        is_qwen3_vl=is_qwen3_vl,
    )

    # -------- training --------
    logger.info("***** Running GRPO training (lazy) *****")
    logger.info(f"  Num training data = {len(dataset_lm)}")
    logger.info(f"  Num training steps = {max_train_steps}")
    logger.info(f"  Instantaneous batch size per device = {batch_size_lm // ws}")
    logger.info(f"  Total train batch size (parallel * accum) = {total_batch_size_lm}")
    logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")

    from tqdm.auto import tqdm

    global_step = 0
    for epoch in range(num_train_epochs):
        model.train()
        steps_in_epoch = len(train_dataloader_lm)
        progress = tqdm(
            train_dataloader_lm,
            disable=not accelerator.is_local_main_process,
            dynamic_ncols=True,
        )
        window_loss = torch.zeros((), device=accelerator.device)
        window_steps = 0

        for step, batch in enumerate(progress, start=1):
            loss = forward_grpo_vlm(
                accelerator=accelerator,
                model=model,
                batch=batch,
                dataset=dataset_lm,
                tokenizer=tokenizer,
                config=config,
                is_qwen3_vl=is_qwen3_vl,
            )

            # No manual `/ gradient_accumulation_steps`. accelerator.backward already
            # divides by it (DeepSpeed's engine.backward applies the same scaling),
            # so dividing here too would scale gradients by 1/gas^2.
            accelerator.backward(loss)
            window_loss += loss.detach()
            window_steps += 1

            # Step at every accumulation boundary AND at the end of the epoch. A
            # trailing partial window is then applied instead of leaking into the
            # next epoch (or being dropped after the last one).
            if step % gradient_accumulation_steps == 0 or step == steps_in_epoch:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

                global_step += 1
                loss_val = float(accelerator.gather(window_loss / window_steps).mean().item())
                window_loss = torch.zeros((), device=accelerator.device)
                window_steps = 0
                progress.set_postfix(loss=loss_val)

                if accelerator.is_main_process and wandb is not None:
                    accelerator.log({"train/loss": loss_val, "train/step": global_step})

    torch.cuda.empty_cache()
    accelerator.wait_for_everyone()

    # save
    save_checkpoint(model, tokenizer, config, accelerator, optimized_name)
    if getattr(config.experiment, "save_every", 0) and (config.experiment.current_epoch % config.experiment.save_every == 0):
        save_checkpoint(model, tokenizer, config, accelerator, f"epoch-{config.experiment.current_epoch}")

    accelerator.end_training()


if __name__ == "__main__":
    # Script entry point: main() reads its settings via get_config() and runs one training round.
    main()
