# -*- coding: utf-8 -*-
"""
GRPO training for OpenCUA VLM policy model.

- Input: policy_optimization_data.json
  [
    {
      "domain": "...",
      "example": "...",
      "run_id": 0,
      "step": 0,
      "prompt_messages": [...],   # system/user 多模态，格式同 serve_opencua.py
      "response": "...",          # assistant 文本
      "reward": 0.0               # 标量（本脚本直接当 advantage 用）
    },
    ...
  ]

- Multi-modal handling strictly follows serve_opencua.py style:
  - prompt_messages -> tokenizer.apply_chat_template(..., add_generation_prompt=True)
  - images in prompt_messages -> AutoImageProcessor.preprocess -> pixel_values + grid_thws
  - response -> text-only tokens, appended after the prompt

- GRPO style:
  - Precompute old token log-probs under frozen policy (logp_old_tok)
  - Train with PPO-like objective on response tokens only:
      ratio = exp(logp_new - logp_old)
      clipped surrogate with per-sample scalar advantages (= reward here)
      optional KL penalty vs old policy (config.training.beta)
"""

import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
os.environ["TOKENIZERS_PARALLELISM"] = "true"

import io
import re
import json
import math
import time
import shutil
import base64
import logging
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader

from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

from omegaconf import OmegaConf

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed

from transformers import (
    AutoTokenizer,
    AutoImageProcessor,
    AutoModelForCausalLM,
    AutoModel,
    AutoProcessor,
)

from train.utils import get_config, flatten_omega_conf, AverageMeter
from train.lr_schedulers import get_scheduler
from train.log_utils import set_verbosity_info, set_verbosity_error

try:
    import wandb
except ImportError:
    wandb = None

logger = get_logger(__name__, log_level="INFO")


# ==================== Dataset ====================

class VLMTrainDataset(Dataset):
    """
    In-memory dataset of padded multimodal GRPO samples.

    Stored per sample:
      - input_ids: (N, L)
      - p_mask:    (N, L) bool, True on response tokens
      - labels:    (N, L)
      - advantage: (N,) float
      - pixel_values_list: one Tensor or None per sample
      - grid_thws_list:    one Tensor or None per sample

    Additionally:
      - logp_old_tok: (N, L) float32 cache for the precomputed old-policy
        log-probs. It is created filled with -inf; once filled, only
        columns [1:] are meaningful.
    """

    def __init__(
        self,
        input_ids: torch.Tensor,
        p_mask: torch.Tensor,
        labels: torch.Tensor,
        advantage: torch.Tensor,
        pixel_values_list: List[Any],
        grid_thws_list: List[Any],
    ):
        # Token-level tensors must share one shape, and every per-sample
        # container must have one entry per row.
        assert input_ids.shape == p_mask.shape == labels.shape
        sample_count = input_ids.size(0)
        assert sample_count == advantage.size(0) == len(pixel_values_list) == len(grid_thws_list)

        self.input_ids = input_ids
        self.p_mask = p_mask
        self.labels = labels
        self.advantage = advantage

        self.pixel_values_list = pixel_values_list
        self.grid_thws_list = grid_thws_list

        # Placeholder cache; -inf marks entries that have not been computed.
        num_rows, seq_len = input_ids.shape
        self.logp_old_tok = torch.full(
            (num_rows, seq_len),
            float("-inf"),
            dtype=torch.float32,
        )

    def __len__(self):
        return self.input_ids.shape[0]

    def __getitem__(self, idx):
        """Return ``(index, input_ids, p_mask, labels, advantage, pixel_values, grid_thws)``."""
        sample_index = int(idx)
        token_ids = self.input_ids[idx]
        response_mask = self.p_mask[idx]
        label_ids = self.labels[idx]
        advantage_value = float(self.advantage[idx].item())
        pixels = self.pixel_values_list[idx]
        grid = self.grid_thws_list[idx]
        return (
            sample_index,
            token_ids,
            response_mask,
            label_ids,
            advantage_value,
            pixels,
            grid,
        )


def simple_collate(batch):
    """
    Collate ``VLMTrainDataset`` items into one batch dict.

    Token-level tensors are stacked into (B, L) tensors. ``pixel_values`` and
    ``grid_thws`` are returned as plain per-sample lists (each entry a Tensor
    or None), because they are consumed one sample at a time.
    """
    columns = list(zip(*batch))
    ids, input_ids, p_mask, labels, adv, pixel_values, grid_thws = columns

    collated = {}
    collated["ids"] = torch.tensor(ids, dtype=torch.long)
    collated["input_ids"] = torch.stack(input_ids)                    # (B, L)
    collated["p_mask"] = torch.stack(p_mask)                          # (B, L)
    collated["labels"] = torch.stack(labels)                          # (B, L)
    collated["advantage"] = torch.tensor(adv, dtype=torch.float32)    # (B,)
    collated["pixel_values"] = list(pixel_values)                     # len B
    collated["grid_thws"] = list(grid_thws)                           # len B
    return collated


# ==================== Utils for images / prompts ====================

def _decode_image(src: str) -> Image.Image:
    """
    Load an image reference as an RGB PIL image.

    ``src`` (after stripping whitespace) is interpreted, in order, as:
      1. an http(s) URL, downloaded with a 10 second timeout;
      2. an existing local file path;
      3. otherwise base64 data: either a raw base64 string or a data URI,
         in which case everything after the first comma is decoded.
    """
    location = src.strip()

    if location.startswith(("http://", "https://")):
        # Training data is normally local; remote URLs are supported for convenience.
        import requests
        response = requests.get(location, timeout=10)
        response.raise_for_status()
        payload = response.content
    elif os.path.exists(location):
        return Image.open(location).convert("RGB")
    else:
        # Not a file on disk: treat as base64, dropping a "data:...;base64," prefix if present.
        _, comma, tail = location.partition(",")
        encoded = tail if comma else location
        payload = base64.b64decode(encoded, validate=False)

    return Image.open(io.BytesIO(payload)).convert("RGB")


def collect_images_from_messages(messages: List[Dict[str, Any]]) -> List[Image.Image]:
    """
    Decode every image referenced by a list of chat messages, in order.

    Only messages whose ``content`` is a list are inspected. Within such a
    list, two part layouts are recognised:
      - ``{"type": "image", "image": <src>}``
      - ``{"type": "image_url", "image_url": {"url": <src>}}``

    Parts that are not dicts (e.g. bare strings mixed into the content list)
    carry no image and are skipped instead of raising ``AttributeError``.

    Decoding is best-effort: a source that fails to decode is reported with a
    ``[WARN]`` line on stdout and omitted from the result.

    Returns:
        The successfully decoded images as RGB PIL images.
    """
    pil_images: List[Image.Image] = []
    for message in messages:
        content = message.get("content")
        if not isinstance(content, list):
            continue

        for part in content:
            if not isinstance(part, dict):
                continue

            kind = part.get("type")
            if kind == "image":
                src = part.get("image")
            elif kind == "image_url":
                src = (part.get("image_url") or {}).get("url")
            else:
                continue

            if not src:
                continue

            try:
                pil_images.append(_decode_image(src))
            except Exception as e:
                # Deliberately broad: one bad image must not abort dataset construction.
                print(f"[WARN] decode {kind} failed: {src} | {e}", flush=True)
    return pil_images


def _normalize_qwen_content(content: Any) -> List[Dict[str, Any]]:
    """Convert one message's ``content`` into the list-of-typed-parts form passed to the processor."""
    # Case 1: plain string content.
    if isinstance(content, str):
        return [{"type": "text", "text": content}]
    # Anything that is neither a string nor a list falls back to its text form.
    if not isinstance(content, list):
        return [{"type": "text", "text": str(content)}]

    # Case 2: a list that may mix strings and dicts.
    parts: List[Dict[str, Any]] = []
    for part in content:
        if isinstance(part, str):
            parts.append({"type": "text", "text": part})
        elif not isinstance(part, dict):
            parts.append({"type": "text", "text": str(part)})
        else:
            p_type = part.get("type")
            if p_type in ("text", "paragraph"):
                parts.append({"type": "text", "text": part.get("text", "")})
            elif p_type in ("image", "video"):
                # Already in the expected layout; pass through untouched.
                parts.append(part)
            elif p_type == "image_url":
                # Legacy layout: {"type": "image_url", "image_url": {"url": ...}}
                url = (part.get("image_url") or {}).get("url") or part.get("url")
                if url:
                    parts.append({"type": "image", "image": url})
            else:
                # Missing or unknown type: fall back to text.
                txt = part.get("text", None)
                parts.append({"type": "text", "text": txt if txt is not None else str(part)})
    return parts


def _first_not_none(mapping: Any, keys: Tuple[str, ...]) -> Any:
    """Return the value of the first key in ``keys`` whose entry in ``mapping`` is not None, else None."""
    for key in keys:
        value = mapping.get(key, None)
        if value is not None:
            return value
    return None


def build_vlm_grpo_dataset(
    data: List[Dict[str, Any]],
    tokenizer,
    image_processor,
    max_prompt_len: int,
    max_gen_length: int,
    is_qwen3_vl: bool = False,
    processor=None,
) -> VLMTrainDataset:
    """
    Build a padded ``VLMTrainDataset`` from raw optimization records.

    For each record:
      input_ids = chat_template(prompt_messages, add_generation_prompt=True) + tokenize(response)
      p_mask    = False on prompt tokens, True on response tokens
      advantage = the record's ``reward`` (used as-is; defaults to 0.0)
      pixel_values / grid_thws come from the images in ``prompt_messages``

    Args:
        data: Records with keys ``prompt_messages``, ``response`` and optionally ``reward``.
        tokenizer: Tokenizer used for the response and, in the non-Qwen3-VL
            branch, for the chat template. If it has no pad token, one is set
            here (the EOS token, or a newly added ``[PAD]`` token).
            NOTE(review): when ``[PAD]`` is newly added the vocabulary grows;
            confirm the caller resizes the model embeddings.
        image_processor: Used only when ``is_qwen3_vl`` is False.
        max_prompt_len: If > 0, prompts longer than this are left-truncated
            (text-only samples) or the sample is skipped (samples with images,
            since truncating would break placeholder/feature alignment).
        max_gen_length: If > 0, responses are right-truncated to this length.
        is_qwen3_vl: Use ``processor.apply_chat_template`` to produce prompt
            ids and image features together.
        processor: Required when ``is_qwen3_vl`` is True.

    Returns:
        A ``VLMTrainDataset`` with all sequences right-padded with the pad id
        to the longest kept sequence.

    Raises:
        ValueError: If ``is_qwen3_vl`` is set without a ``processor``, or if no
            sample survives filtering.
    """
    # Loop-invariant precondition: fail before any tokenization work is done.
    if is_qwen3_vl and processor is None:
        raise ValueError("Qwen3-VL requires `processor`, but got None.")

    # Ensure a pad token exists.
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    pad_id = tokenizer.pad_token_id

    samples: List[Dict[str, Any]] = []
    max_len = 0

    for item in data:
        prompt_messages = item["prompt_messages"]
        response_text = item["response"]
        reward = float(item.get("reward", 0.0))

        if is_qwen3_vl:
            # ---------- Qwen3-VL: let the processor keep placeholders and features aligned ----------
            qwen_messages = [
                {
                    "role": m.get("role", "user"),
                    "content": _normalize_qwen_content(m.get("content", "")),
                }
                for m in prompt_messages
            ]

            proc_inputs = processor.apply_chat_template(
                qwen_messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
                truncation=False,
                max_length=100000,
            )

            prompt_ids = proc_inputs["input_ids"][0].tolist()

            pixel_values = proc_inputs.get("pixel_values", None)
            if pixel_values is not None:
                pixel_values = pixel_values.clone()

            grid = proc_inputs.get("image_grid_thw", None)
            grid_thws = grid.clone().long() if grid is not None else None
        else:
            # ---------- Non-Qwen3-VL: original OpenCUA / policy behaviour ----------
            prompt_ids = tokenizer.apply_chat_template(
                prompt_messages,
                tokenize=True,
                add_generation_prompt=True,
            )

            # Images are taken from prompt_messages only.
            pil_images = collect_images_from_messages(prompt_messages)
            pixel_values = None
            grid_thws = None
            if pil_images:
                info = image_processor.preprocess(images=pil_images)

                # The result keys differ between processors; take the first one present.
                pixel = _first_not_none(info, ("pixel_values", "pixel_values_videos"))
                grid = _first_not_none(info, ("image_grid_thw", "images_grid_thw", "grid_thws"))

                if pixel is not None:
                    pixel_values = torch.as_tensor(pixel)
                if grid is not None:
                    grid_thws = torch.as_tensor(grid, dtype=torch.long)

        # Response token ids (text only, no special tokens).
        resp_ids = tokenizer(
            response_text,
            add_special_tokens=False,
        )["input_ids"]

        # Truncation.
        has_image = (pixel_values is not None) or (grid_thws is not None)

        if max_prompt_len > 0 and len(prompt_ids) > max_prompt_len:
            if has_image:
                # Never cut an image prompt (it would break alignment): skip the sample instead.
                print(
                    f"[WARN] skip sample: image prompt too long "
                    f"({len(prompt_ids)} > {max_prompt_len}). Increase max_prompt_len.",
                    flush=True,
                )
                continue
            # Text-only prompts may be left-truncated at token level.
            prompt_ids = prompt_ids[-max_prompt_len:]

        if max_gen_length > 0 and len(resp_ids) > max_gen_length:
            resp_ids = resp_ids[:max_gen_length]

        input_ids = prompt_ids + resp_ids
        if len(input_ids) == 0:
            continue

        max_len = max(max_len, len(input_ids))
        samples.append(
            {
                "input_ids": input_ids,
                "start_pos": len(prompt_ids),
                "advantage": reward,
                "pixel_values": pixel_values,
                "grid_thws": grid_thws,
            }
        )

    if not samples:
        raise ValueError("No valid samples found in optimization data.")

    # Pack the variable-length samples into right-padded tensors.
    N = len(samples)
    input_ids_tensor = torch.full((N, max_len), pad_id, dtype=torch.long)
    p_mask_tensor = torch.zeros((N, max_len), dtype=torch.bool)
    labels_tensor = torch.full((N, max_len), pad_id, dtype=torch.long)
    adv_tensor = torch.zeros(N, dtype=torch.float32)
    pixel_values_list: List[Any] = []
    grid_thws_list: List[Any] = []

    for i, s in enumerate(samples):
        row = torch.tensor(s["input_ids"], dtype=torch.long)
        seq_len = row.numel()
        start_pos = s["start_pos"]

        input_ids_tensor[i, :seq_len] = row
        labels_tensor[i, :seq_len] = row
        if seq_len > start_pos:
            p_mask_tensor[i, start_pos:seq_len] = True

        adv_tensor[i] = float(s["advantage"])
        pixel_values_list.append(s["pixel_values"])
        grid_thws_list.append(s["grid_thws"])

    return VLMTrainDataset(
        input_ids=input_ids_tensor,
        p_mask=p_mask_tensor,
        labels=labels_tensor,
        advantage=adv_tensor,
        pixel_values_list=pixel_values_list,
        grid_thws_list=grid_thws_list,
    )


# ==================== Core train ====================

def make_attn_and_pos(input_ids: torch.Tensor, pad_id: int):
    """
    Build the attention mask and position ids for a batch of token ids.

    Returns:
        attention_mask: int64, 1 where the token differs from ``pad_id``, else 0.
        position_ids:   running count of non-pad tokens along dim 1 (zero-based);
                        pad positions are set to 0.
    """
    is_real_token = input_ids != pad_id
    attention_mask = is_real_token.long()
    running_index = torch.cumsum(attention_mask, dim=1) - 1
    position_ids = running_index.masked_fill(~is_real_token, 0)
    return attention_mask, position_ids

def make_attention_mask(input_ids: torch.Tensor, pad_id: int):
    """Return an int64 mask shaped like ``input_ids``: 1 where the token differs from ``pad_id``, else 0."""
    is_real_token = input_ids != pad_id
    return is_real_token.long()

def make_position_ids(attention_mask: torch.Tensor):
    """
    Derive position ids from an attention mask.

    Each attended position gets the zero-based count of attended positions
    before it along dim 1; positions where the mask is 0 are set to 0.
    """
    running_index = torch.cumsum(attention_mask, dim=1) - 1
    return running_index.masked_fill(attention_mask == 0, 0)


@torch.no_grad()
def compute_logp_old_tok_parallel(
    accelerator: Accelerator,
    model,
    dataset: VLMTrainDataset,
    train_dataloader: DataLoader,
    pad_id: int,
    is_qwen3_vl: bool,
):
    """
    Fill ``dataset.logp_old_tok`` with the current model's per-token log-probs.

    The model is run one sample at a time (not as a batch) so that each sample
    can carry its own ``pixel_values`` / grid tensor of arbitrary shape.

    For a sequence of length L, column ``t`` (t >= 1) of the cache receives
    ``log p(token_t | tokens_<t)``; column 0 has no predecessor and stays -inf.
    Values are written for every position, not only response tokens; callers
    intersect with ``p_mask`` themselves.

    Side effects:
        - switches ``model`` to eval mode for the pass and to train mode afterwards;
        - writes into ``dataset.logp_old_tok`` at the indices yielded by
          ``train_dataloader`` on this process.
          NOTE(review): in a multi-process run each process only fills the
          rows its own dataloader yields — confirm the training loop reads the
          same rows on the same rank.
    """
    model.eval()
    from tqdm.auto import tqdm

    device = accelerator.device
    iterator = tqdm(
        train_dataloader,
        desc="Precomputing old token log-probs",
        dynamic_ncols=True,
        disable=not accelerator.is_local_main_process,
        leave=True,
    )

    for batch in iterator:
        ids = batch["ids"]                                # (B,)
        input_ids_batch = batch["input_ids"].to(device)   # (B, L)
        pixel_values_list = batch["pixel_values"]
        grid_thws_list = batch["grid_thws"]

        B, L = input_ids_batch.shape
        full_batch = torch.full(
            (B, L),
            float("-inf"),
            device=device,
            dtype=torch.float32,
        )

        for i, (pv, g) in enumerate(zip(pixel_values_list, grid_thws_list)):
            seq = input_ids_batch[i : i + 1]              # (1, L)
            attention_mask = make_attention_mask(seq, pad_id)
            kwargs = dict(input_ids=seq, attention_mask=attention_mask)

            # Explicit position ids are only passed on the non-Qwen3-VL path.
            if not is_qwen3_vl:
                kwargs["position_ids"] = make_position_ids(attention_mask)

            if pv is not None:
                kwargs["pixel_values"] = pv.to(device)

            if g is not None:
                # The grid tensor is passed under a different keyword per model family.
                grid_key = "image_grid_thw" if is_qwen3_vl else "grid_thws"
                kwargs[grid_key] = g.to(device)

            out = model(**kwargs)
            logits = out.logits                           # (1, L, V)

            # Next-token log-prob: logits at t-1 score the token at t.
            log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
            logp_tok = log_probs.gather(
                dim=-1,
                index=seq[:, 1:].unsqueeze(-1),
            ).squeeze(-1)                                 # (1, L-1)

            # Align to length L: column 0 keeps -inf, columns 1..L-1 are written.
            full_batch[i, 1:] = logp_tok[0].float()

        # Write back into the dataset-wide cache (kept on CPU).
        dataset.logp_old_tok[ids] = full_batch.detach().cpu()

    accelerator.wait_for_everyone()
    model.train()


def forward_process_vlm(
    accelerator: Accelerator,
    model,
    batch: Dict[str, Any],
    dataset: VLMTrainDataset,
    tokenizer,
    config,
    is_qwen3_vl: bool,
):
    """
    Compute the GRPO / PPO-style loss for one batch (multimodal version).

    Each sample is forwarded on its own so it can carry its own
    ``pixel_values`` / grid tensor. For every sample, over response tokens only
    (``p_mask``, shifted by one to align with next-token prediction):

        ratio     = exp(clamp(logp_new - logp_old, -10, 10))
        surrogate = min(ratio * adv, clip(ratio, 1 - eps, 1 + eps) * adv)
        policy    = -mean_over_response_tokens(surrogate)

    If ``config.training.beta > 0`` a penalty ``beta * mean(kl)`` is added,
    where ``kl`` is ``logp_new - logp_old`` or, when
    ``config.training.use_kl_estimator_k3`` is truthy,
    ``exp(logp_old - logp_new) - 1 + (logp_new - logp_old)``.

    ``logp_old`` is read from ``dataset.logp_old_tok`` at the batch's indices.

    Returns:
        A scalar tensor: the mean over the batch of the per-sample losses.
        No backward pass happens here, so the autograd graphs of all
        per-sample forwards stay alive until the caller backpropagates.
    """
    device = accelerator.device

    ids = batch["ids"].cpu()
    input_ids_batch = batch["input_ids"].to(device)      # (B, L)
    p_mask_batch = batch["p_mask"].to(device)            # (B, L)
    adv_batch = batch["advantage"].to(device)            # (B,)
    pixel_values_list = batch["pixel_values"]
    grid_thws_list = batch["grid_thws"]

    old_lp_batch = dataset.logp_old_tok[ids].to(device)  # (B, L)

    pad_id = tokenizer.pad_token_id
    B = input_ids_batch.size(0)

    # Loop-invariant hyper-parameters, read once.
    eps = config.training.eps
    beta = config.training.beta
    use_k3 = beta > 0 and getattr(config.training, "use_kl_estimator_k3", False)

    total_loss = torch.tensor(0.0, device=device)

    for i in range(B):
        input_ids = input_ids_batch[i : i + 1]      # (1, L)
        p_mask = p_mask_batch[i : i + 1]            # (1, L)
        adv = adv_batch[i : i + 1]                  # (1,)
        old_lp = old_lp_batch[i : i + 1]            # (1, L)

        attention_mask = make_attention_mask(input_ids, pad_id)
        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)

        # Explicit position ids are only passed on the non-Qwen3-VL path.
        if not is_qwen3_vl:
            kwargs["position_ids"] = make_position_ids(attention_mask)

        pv = pixel_values_list[i]
        if pv is not None:
            kwargs["pixel_values"] = pv.to(device)

        g = grid_thws_list[i]
        if g is not None:
            # The grid tensor is passed under a different keyword per model family.
            grid_key = "image_grid_thw" if is_qwen3_vl else "grid_thws"
            kwargs[grid_key] = g.to(device)

        out = model(**kwargs)
        logits = out.logits  # (1, L, V)

        # Shift: logits at t-1 score the token at t.
        logits_shifted = logits[:, :-1, :]
        labels_shifted = input_ids[:, 1:]

        log_probs = F.log_softmax(logits_shifted, dim=-1)
        logp_new_tok = log_probs.gather(
            dim=-1,
            index=labels_shifted.unsqueeze(-1),
        ).squeeze(-1)  # (1, L-1)

        p_mask_shift = p_mask[:, 1:]          # (1, L-1)
        old_lp_shift = old_lp[:, 1:]          # (1, L-1)

        # Ratio on response tokens only; elsewhere the log-ratio is forced to 0
        # (ratio 1). The clamp bounds the exponent for numerical safety.
        diff = (logp_new_tok - old_lp_shift)
        diff = torch.where(p_mask_shift, diff, torch.zeros_like(diff))
        diff = diff.clamp(-10.0, 10.0)
        ratio = torch.exp(diff)

        clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps)

        adv_tok = adv.unsqueeze(1)  # (1, 1), broadcast over tokens
        surrogate = torch.min(ratio * adv_tok, clipped * adv_tok)
        surrogate = surrogate * p_mask_shift

        # Number of response tokens, floored at 1 to avoid division by zero.
        denom = torch.clamp(p_mask_shift.sum(dim=1), min=1.0)  # (1,)
        policy_loss = - (surrogate.sum(dim=1) / denom).mean()

        # Optional KL penalty against the old policy.
        kl_loss = torch.tensor(0.0, device=device)
        if beta > 0:
            kl_seq = (logp_new_tok - old_lp_shift)
            kl_seq = torch.where(p_mask_shift, kl_seq, torch.zeros_like(kl_seq))
            if use_k3:
                t = (-kl_seq).clamp(-10.0, 10.0)
                kl_seq = t.exp() - 1.0 + kl_seq
            kl_seq = (kl_seq * p_mask_shift).sum(dim=1) / denom
            kl_loss = beta * kl_seq.mean()

        total_loss = total_loss + (policy_loss + kl_loss)

    return total_loss / max(B, 1)


# ==================== Checkpoint ====================

def save_checkpoint(model, tokenizer, config, accelerator, name: str):
    """
    Save model weights, tokenizer and auxiliary config files as checkpoint ``name``.

    Target directory: ``<config.system.rl_base_dir>/<config.experiment.project>/ckpt/<name>``.

    The model is unwrapped and its state dict fetched on every process; all
    file writes happen on the main process only:
      1. model weights via ``save_pretrained`` (safetensors);
      2. the tokenizer;
      3. extra files copied from the base model directory when that directory
         exists locally and the file is not already in the checkpoint. The
         OpenCUA case copies custom code plus processor configs and patches
         ``tokenizer_config.json`` to point at ``tokenization_opencua.TikTokenV3``;
         every other case copies the processor configs only;
      4. ``metadata.json`` with the save time, written to the ``ckpt`` directory
         itself (not the per-name sub-directory), so each save overwrites it.
    """
    import json
    import shutil
    import time
    from pathlib import Path

    save_base = Path(config.system.rl_base_dir) / config.experiment.project / "ckpt"
    save_dir = save_base / name

    # Executed on all processes, before the main-process guard
    # (presumably because fetching the state dict can be a collective operation).
    model_to_save = accelerator.unwrap_model(model)
    state_dict = accelerator.get_state_dict(model)

    if not accelerator.is_main_process:
        return

    save_dir.mkdir(parents=True, exist_ok=True)

    # 1) Model weights.
    model_to_save.save_pretrained(
        save_dir,
        save_function=accelerator.save,
        state_dict=state_dict,
        safe_serialization=True,
    )

    # 2) Tokenizer (common to all models).
    tokenizer.save_pretrained(str(save_dir))

    # ---- Locate the base model directory to copy config files from ----
    init_kwargs = getattr(tokenizer, "init_kwargs", {})
    tok_name = getattr(tokenizer, "name_or_path", "") or init_kwargs.get("_name_or_path", "")
    target = getattr(config.training, "target", None)

    candidate_base = [Path(tok_name)] if tok_name else []
    if target == "policy":
        candidate_base.append(Path(config.model.policy_model))
    elif target == "reward":
        candidate_base.append(Path(config.model.reward_model))

    # First candidate that exists on disk wins; None if there is none.
    base_dir = next((p for p in candidate_base if p and p.exists()), None)

    # ---- Detect OpenCUA (TikTokenV3 tokenizer) ----
    tok_class_name = tokenizer.__class__.__name__
    tok_cfg_class = init_kwargs.get("tokenizer_class", "")
    is_opencua = any(
        (
            "TikTokenV3" in tok_class_name,
            "TikTokenV3" in str(tok_cfg_class),
            "OpenCUA" in str(tok_name),
        )
    )

    if base_dir is not None:
        if is_opencua:
            # Custom code and multimodal configs required by OpenCUA.
            files_to_sync = [
                "tokenization_opencua.py",
                "configuration_opencua.py",
                "modeling_opencua.py",
                "tiktoken.model",
                "preprocessor_config.json",
                "image_processor_config.json",
                "processor_config.json",
                "generation_config.json",
            ]
        else:
            # Generic multimodal fallback (Qwen3-VL etc.): processor configs,
            # so that loading a processor from the checkpoint works.
            files_to_sync = [
                "preprocessor_config.json",
                "processor_config.json",
                "image_processor_config.json",
            ]

        for fn in files_to_sync:
            src = base_dir / fn
            dst = save_dir / fn
            # Never overwrite what save_pretrained already produced.
            if src.exists() and not dst.exists():
                shutil.copy(src, dst)

        if is_opencua:
            # Patch tokenizer_config.json to point at tokenization_opencua.TikTokenV3.
            tok_cfg_path = save_dir / "tokenizer_config.json"
            if tok_cfg_path.exists():
                with tok_cfg_path.open("r", encoding="utf-8") as f:
                    cfg = json.load(f)
                auto_map = cfg.get("auto_map", {})
                auto_map["AutoTokenizer"] = ["tokenization_opencua.TikTokenV3", None]
                cfg["tokenizer_class"] = "tokenization_opencua.TikTokenV3"
                cfg["auto_map"] = auto_map
                with tok_cfg_path.open("w", encoding="utf-8") as f:
                    json.dump(cfg, f, indent=2, ensure_ascii=False)

    # 3) Metadata (optional).
    metadata = {"save_time": time.strftime("%Y-%m-%d %H:%M:%S")}
    with (save_base / "metadata.json").open("w") as f:
        json.dump(metadata, f, indent=2)

    logger.info(f"Saved model + tokenizer to {save_dir}")



import torch
from pathlib import Path
from typing import Any, Dict, List, Tuple

def load_merged_preproc_dataset(project_name: str, optimization_data: str):
    """
    Load the merged, preprocessed dataset saved by the merge script.

    Reads ``<project_name>/temp_data/<optimization_data>_preproc_merged.pt``
    onto the CPU and returns
    ``(input_ids, p_mask, labels, advantage, pixel_values_list, grid_thws_list, meta)``,
    where ``meta`` defaults to ``{}`` when the file has no ``"meta"`` entry.

    NOTE(review): ``torch.load`` deserialises the file; only load files
    produced by this project.

    Raises:
        FileNotFoundError: If the merged file does not exist.
        KeyError: If a required entry is missing from the file.
        AssertionError: If the loaded tensors / lists disagree in shape or length.
    """
    merged_file = Path(project_name).joinpath(
        "temp_data", f"{optimization_data}_preproc_merged.pt"
    )
    if not merged_file.exists():
        raise FileNotFoundError(f"Missing merged dataset: {merged_file}. Run merge script first.")

    payload = torch.load(merged_file, map_location="cpu")

    required_keys = (
        "input_ids",
        "p_mask",
        "labels",
        "advantage",
        "pixel_values_list",
        "grid_thws_list",
    )
    input_ids, p_mask, labels, advantage, pixel_values_list, grid_thws_list = (
        payload[key] for key in required_keys
    )
    meta = payload.get("meta", {})

    # Same consistency checks the dataset constructor applies.
    assert input_ids.shape == p_mask.shape == labels.shape
    num_samples = input_ids.size(0)
    assert num_samples == advantage.size(0) == len(pixel_values_list) == len(grid_thws_list)

    return input_ids, p_mask, labels, advantage, pixel_values_list, grid_thws_list, meta


# ==================== Main ====================

def main():
    """Run one GRPO optimization round for the configured training target.

    Steps, in order:
      1. Read the config and resolve per-target settings (``policy`` or
         ``reward``): which checkpoint to start from, batch size, etc.
      2. Load the merged preprocessed dataset and wrap it in ``VLMTrainDataset``.
      3. Derive the gradient-accumulation factor from the dataset size.
      4. Load tokenizer / image processor, create the ``Accelerator``,
         configure logging and (optionally) wandb, and save the config.
      5. Load the model, optionally enable gradient checkpointing, and build
         the AdamW optimizer, LR scheduler and dataloader.
      6. Precompute old-policy token log-probs, then run the training loop.
      7. Save the final checkpoint (plus a periodic epoch-tagged copy).

    Raises:
        ValueError: If ``config.training.target`` or ``config.optimizer.name``
            is not supported.
    """
    config = get_config()

    project_name = config.experiment.project

    # Per-target settings. On the first outer epoch training starts from the
    # base model; afterwards it resumes from the previously optimized ckpt.
    if config.training.target == "policy":
        if config.experiment.current_epoch == 1:
            pretrained_model = config.model.policy_model
        else:
            pretrained_model = (
                f"{config.system.rl_base_dir}/{project_name}/ckpt/{config.model.optimized_name}"
            )
        optimized_name = config.model.optimized_name
        max_prompt_len = config.training.policy.max_prompt_len
        max_gen_length = config.training.policy.max_gen_length
        optimization_data = "policy_optimization_data"
        update_per_step = config.training.policy.update_per_step
        batch_size_lm = config.training.policy.batch_size_lm
        gradient_checkpointing_enable = config.training.policy.gradient_checkpointing_enable
    elif config.training.target == "reward":
        # Reward-model branch, mirroring the policy branch.
        if config.experiment.current_epoch == 1:
            pretrained_model = config.model.reward_model
        else:
            pretrained_model = (
                f"{config.system.rl_base_dir}/{project_name}/ckpt/{config.model.optimized_reward_name}"
            )
        optimized_name = config.model.optimized_reward_name
        max_prompt_len = config.training.reward.max_prompt_len
        max_gen_length = config.training.reward.max_gen_length
        optimization_data = "reward_optimization_data"
        update_per_step = config.training.reward.update_per_step
        batch_size_lm = config.training.reward.batch_size_lm
        gradient_checkpointing_enable = config.training.reward.gradient_checkpointing_enable
    else:
        raise ValueError(f"Unknown training.target = {config.training.target}")

    # TF32
    if config.training.enable_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    # ========== Load optimization data ==========
    (
        input_ids_tensor,
        p_mask_tensor,
        labels_tensor,
        adv_tensor,
        pixel_values_list,
        grid_thws_list,
        meta,
    ) = load_merged_preproc_dataset(project_name, optimization_data)

    dataset_lm = VLMTrainDataset(
        input_ids=input_ids_tensor,
        p_mask=p_mask_tensor,
        labels=labels_tensor,
        advantage=adv_tensor,
        pixel_values_list=pixel_values_list,
        grid_thws_list=grid_thws_list,
    )

    total_n = int(meta.get("merged_kept_N", len(adv_tensor)))

    ws = int(os.environ.get("WORLD_SIZE", "1"))

    # Choose the accumulation factor so that one pass over the data yields
    # roughly `update_per_step` optimizer updates.
    gradient_accumulation_steps = max(
        1,
        math.ceil(total_n / (update_per_step * batch_size_lm * ws)),
    )

    is_qwen3_vl = config.model_type == "qwen3vl"

    # ========== Tokenizer & Image Processor ==========
    if is_qwen3_vl:
        # Qwen3-VL: AutoProcessor bundles the tokenizer and image processor.
        processor = AutoProcessor.from_pretrained(pretrained_model, trust_remote_code=True)
        tokenizer = processor.tokenizer
        image_processor = processor.image_processor
    else:
        # Original path, used for the OpenCUA policy and similar models.
        processor = None
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model, trust_remote_code=True)
        image_processor = AutoImageProcessor.from_pretrained(pretrained_model, trust_remote_code=True)

    # ========== Accelerator ==========
    config.experiment.logging_dir = str(Path(project_name) / "logs")
    # NOTE(review): split_batches=True makes accelerate split each DataLoader
    # batch across processes, while the accumulation / total-batch-size math in
    # this function multiplies batch_size_lm by ws — TODO confirm these agree.
    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=config.training.mixed_precision,
        log_with=None,
        project_dir=config.experiment.logging_dir,
        split_batches=True,
    )
    assert ws == accelerator.num_processes

    # Logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        set_verbosity_info()
    else:
        set_verbosity_error()

    # Wandb
    if accelerator.is_main_process and wandb is not None:
        resume_wandb_run = config.wandb.resume
        run_id = config.wandb.get("run_id", None)
        if run_id is None:
            # No run to resume: mint a fresh id and record it in the config.
            resume_wandb_run = False
            run_id = wandb.util.generate_id()
            config.wandb.run_id = run_id

        wandb_init_kwargs = dict(
            name=project_name,
            id=run_id,
            resume=resume_wandb_run,
            entity=config.wandb.get("entity", None),
            config_exclude_keys=[],
        )
        wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)}
        wandb_config.pop("experiment.resume_from_checkpoint", None)

        accelerator.init_trackers(
            project_name,
            config=wandb_config,
            init_kwargs={"wandb": wandb_init_kwargs},
        )

    if accelerator.is_main_process:
        os.makedirs(project_name, exist_ok=True)
        config_path = Path(project_name) / "config.yaml"
        logger.info(f"Saving config to {config_path}")
        OmegaConf.save(config, config_path)

    # Seed
    if config.training.seed is not None:
        set_seed(config.training.seed)

    # ========== Model & Optimizer ==========
    logger.info("Loading VLM model and optimizer")

    if is_qwen3_vl:
        # Imported lazily so the script still imports when this class is not
        # needed (non-Qwen3-VL runs).
        from transformers import Qwen3VLForConditionalGeneration
        model = Qwen3VLForConditionalGeneration.from_pretrained(
            pretrained_model, trust_remote_code=True, torch_dtype="auto"
        )
    else:
        model = AutoModel.from_pretrained(pretrained_model, trust_remote_code=True, torch_dtype="auto")

    def enable_gc(model):
        """Turn on gradient checkpointing with one unified procedure.

        Handles:
        - wrappers that hold the real LM in ``language_model``
          (e.g. OpenCUAForConditionalGeneration);
        - models exposing ``gradient_checkpointing_enable`` directly
          (Qwen3-VL / Qwen2.x-VL and similar);
        - as a last resort, any submodule with a ``gradient_checkpointing`` flag.
        """
        # Disable every visible use_cache first so it cannot conflict with GC.
        if hasattr(model, "config") and hasattr(model.config, "use_cache"):
            model.config.use_cache = False
        if hasattr(model, "language_model") and hasattr(model.language_model, "config"):
            if hasattr(model.language_model.config, "use_cache"):
                model.language_model.config.use_cache = False

        # Case 1: wrapper models whose real LM lives in `language_model`.
        if hasattr(model, "language_model") and hasattr(model.language_model, "gradient_checkpointing_enable"):
            model.language_model.gradient_checkpointing_enable()
            # Some implementations additionally read a per-module flag.
            for m in model.language_model.modules():
                if hasattr(m, "gradient_checkpointing"):
                    m.gradient_checkpointing = True
            logger.info("[GC] Enabled on model.language_model (OpenCUA/Qwen-style)")
            return

        # Case 2: models that support GC directly.
        if hasattr(model, "gradient_checkpointing_enable"):
            try:
                model.gradient_checkpointing_enable()
                for m in model.modules():
                    if hasattr(m, "gradient_checkpointing"):
                        m.gradient_checkpointing = True
                logger.info("[GC] Enabled on model directly")
                return
            except ValueError as e:
                logger.warning(f"[GC] Direct enable failed: {e}")

        # Case 3: fallback — flip the flag on any submodule that has one.
        any_flag = False
        for m in model.modules():
            if hasattr(m, "gradient_checkpointing"):
                m.gradient_checkpointing = True
                any_flag = True
        if any_flag:
            logger.info("[GC] Enabled by setting gradient_checkpointing=True on submodules")
        else:
            logger.warning("[GC] No supported gradient checkpointing hooks found; running without GC.")

    if gradient_checkpointing_enable:
        enable_gc(model)
    else:
        # NOTE(review): only this branch moves the model explicitly; the GC
        # branch presumably relies on accelerator.prepare() for placement.
        model = model.to(accelerator.device)

    pad_id = tokenizer.pad_token_id

    optimizer_config = config.optimizer.params

    # Parameters whose names contain any of these substrings get no weight decay.
    no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if p.requires_grad and not any(nd in n for nd in no_decay)
            ],
            "weight_decay": optimizer_config.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if p.requires_grad and any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    if config.optimizer.name == "adamw":
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=optimizer_config.learning_rate,
            betas=(optimizer_config.beta1, optimizer_config.beta2),
            weight_decay=optimizer_config.weight_decay,
            eps=optimizer_config.epsilon,
        )
    else:
        raise ValueError(f"Optimizer {config.optimizer.name} not supported")

    total_batch_size_lm = batch_size_lm * ws * gradient_accumulation_steps
    num_update_steps_per_epoch = max(
        1,
        math.ceil(total_n / total_batch_size_lm),
    )
    num_train_epochs = config.training.num_train_epochs
    max_train_steps = num_update_steps_per_epoch * num_train_epochs

    lr_scheduler = get_scheduler(
        config.lr_scheduler.scheduler,
        optimizer=optimizer,
        num_training_steps=max_train_steps,
        num_warmup_steps=config.lr_scheduler.params.warmup_steps,
        min_lr_scale=config.lr_scheduler.params.min_lr_scale,
    )

    train_dataloader_lm = DataLoader(
        dataset_lm,
        batch_size=batch_size_lm,
        sampler=None,
        collate_fn=simple_collate,
        num_workers=0,
    )

    # Prepare with accelerator
    logger.info("Preparing model, optimizer, scheduler and dataloader")
    model, optimizer, lr_scheduler, train_dataloader_lm = accelerator.prepare(
        model, optimizer, lr_scheduler, train_dataloader_lm
    )

    # ========== Precompute old logp ==========
    logger.info("***** Running inference (precompute old logp) *****")
    compute_logp_old_tok_parallel(
        accelerator,
        model,
        dataset_lm,
        train_dataloader_lm,
        pad_id=pad_id,
        is_qwen3_vl=is_qwen3_vl,
    )

    # ========== Training ==========
    logger.info("***** Running training *****")
    logger.info(f"  Num training data = {len(dataset_lm)}")
    logger.info(f"  Num training steps = {max_train_steps}")
    logger.info(f"  Instantaneous batch size per device = {batch_size_lm}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size_lm}")
    logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")

    from tqdm.auto import tqdm

    for epoch in range(num_train_epochs):
        model.train()
        progress = tqdm(
            train_dataloader_lm,
            disable=not accelerator.is_local_main_process,
            dynamic_ncols=True,
        )

        # NOTE(review): `step` restarts at 1 every epoch and gradients are only
        # cleared after an optimizer step, so when the number of micro-batches
        # is not a multiple of gradient_accumulation_steps the trailing
        # gradients carry into the next epoch (or are never applied after the
        # last one) — confirm this is intended.
        for step, batch in enumerate(progress, start=1):
            loss = forward_process_vlm(
                accelerator=accelerator,
                model=model,
                batch=batch,
                dataset=dataset_lm,
                tokenizer=tokenizer,
                config=config,
                is_qwen3_vl=is_qwen3_vl,
            )

            # Do NOT divide by gradient_accumulation_steps here: the
            # Accelerator was constructed with gradient_accumulation_steps, so
            # accelerator.backward() already applies that scaling. Dividing
            # manually as well would shrink gradients by the factor squared.
            accelerator.backward(loss)

            if step % gradient_accumulation_steps == 0:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

                # Reported value is the unscaled loss of the last micro-batch,
                # averaged over processes.
                loss_val = float(accelerator.gather(loss.detach()).mean().item())
                progress.set_postfix(loss=loss_val)

    torch.cuda.empty_cache()
    accelerator.wait_for_everyone()

    # Save the final weights under the rolling name, plus an epoch-tagged copy
    # every `save_every` outer epochs.
    save_checkpoint(model, tokenizer, config, accelerator, optimized_name)
    if config.experiment.current_epoch % config.experiment.save_every == 0:
        save_checkpoint(model, tokenizer, config, accelerator, f"epoch-{config.experiment.current_epoch}-{config.training.target}")

    accelerator.end_training()


# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
