# -*- coding: utf-8 -*-
"""
osworld_rl_lazy.py  (PREPROC-LAZY VERSION)

✅ 算法/目标函数/训练逻辑不改（GRPO/PPO-style 完全一致）
✅ 仅工程改动：
   1) 不再 torch.load 一个巨型 *_preproc_merged.pt 到内存
   2) 改为 lazy 从“预处理 shard pt”按需读取（LRU 缓存 1~K 个 shard）
   3) 不再分配 N×L 的 logp_old_tok 常驻内存；改为按样本保存 response tokens 的 old logp（数学等价）

预期你已经跑过：
  python -m train.osworld_vlm_preprocess_shards ...
并且在 <project>/temp_data/ 下能看到若干个预处理 shard pt 文件（不是 merged）。

如果你当前 pipeline 会先 merge，再 train：
  - merge 可以保留（方便别的脚本）
  - 训练脚本会优先使用 shard pt，不会去读 merged pt
"""

import os
import sys
from pathlib import Path
import re
import json
import math
import time
import logging
from bisect import bisect_right
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader

from omegaconf import OmegaConf
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed

from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoProcessor,
)

from train.utils import get_config, flatten_omega_conf
from train.lr_schedulers import get_scheduler
from train.log_utils import set_verbosity_info, set_verbosity_error

# wandb is an optional dependency: when it is not installed the name is bound to
# None, and main() checks `wandb is not None` before using it.
try:
    import wandb
except ImportError:
    wandb = None

# Put the parent of this file's directory at the front of sys.path.
# NOTE(review): this statement runs after the `from train.* import ...` lines
# above, so it cannot influence those imports. Presumably it is meant to make
# the `train` package importable -- confirm whether it should run before them.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# setdefault only applies when the environment has not already set the variable.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "true")

# Module-level logger created through accelerate.logging.get_logger.
logger = get_logger(__name__, log_level="INFO")


# ============================================================
# Utilities: discover preprocessed shard files
# ============================================================

def _unique_paths(paths: List[Path]) -> List[Path]:
    """Drop duplicate paths, keeping the first occurrence of each.

    Two entries are duplicates when the string form of their ``Path.resolve()``
    result is equal. The returned entries are the original (unresolved) objects,
    in input order.
    """
    # dict preserves insertion order, so setdefault keeps the first path seen
    # for every resolved location.
    first_seen = {}
    for candidate in paths:
        first_seen.setdefault(str(candidate.resolve()), candidate)
    return list(first_seen.values())


def discover_preproc_shards(
    project_name: str,
    optimization_data: str,
    override_glob: Optional[str] = None,
) -> List[Path]:
    """
    Locate the preprocessed shard ``.pt`` files under ``<project_name>/temp_data/``.

    Args:
        project_name: Project directory; shards are searched in its ``temp_data``
            subdirectory.
        optimization_data: Dataset name the shard file names are expected to contain.
        override_glob: Optional glob pattern, relative to ``temp_data``. When truthy
            it replaces all of the built-in naming guesses.

    Returns:
        De-duplicated shard paths, sorted by file name.

    Raises:
        FileNotFoundError: If ``temp_data`` is missing or no shard file is found.

    Files whose name contains ``merged`` or ``old_logp`` are always excluded,
    including files matched by ``override_glob``.
    """
    temp_dir = Path(project_name) / "temp_data"
    if not temp_dir.exists():
        raise FileNotFoundError(f"Missing temp_data dir: {temp_dir}")

    def _is_excluded(filename: str) -> bool:
        # merged dumps are too large; old_logp files are caches, not shards
        return "merged" in filename or "old_logp" in filename

    found: List[Path] = []
    if override_glob:
        # user-supplied pattern takes precedence over every guess below
        found += sorted(temp_dir.glob(override_glob))
    else:
        # common naming guesses: strict prefixes first, then looser wildcards
        name_patterns = [
            f"{optimization_data}_preproc_{kind}*.pt"
            for kind in ("shard", "node", "rank", "part")
        ]
        name_patterns += [
            f"*{optimization_data}*preproc*{kind}*.pt"
            for kind in ("shard", "node", "rank")
        ]
        for pattern in name_patterns:
            found += sorted(temp_dir.glob(pattern))

        # last resort: any .pt mentioning both the dataset name and "preproc"
        if not found:
            found += [
                p
                for p in sorted(temp_dir.glob("*.pt"))
                if optimization_data in p.name
                and "preproc" in p.name
                and not _is_excluded(p.name)
            ]

    found = _unique_paths(
        [p for p in found if p.exists() and not _is_excluded(p.name)]
    )

    if not found:
        # list what is actually there to make the failure easier to debug
        existing = sorted(temp_dir.glob("*.pt"))
        hint = "\n".join(f"  - {x.name}" for x in existing[:50])
        raise FileNotFoundError(
            f"Cannot find preproc shard pt under {temp_dir} for optimization_data={optimization_data}.\n"
            f"Existing *.pt (first 50):\n{hint}"
        )

    return sorted(found, key=lambda p: p.name)


# ============================================================
# Lazy dataset reading preprocessed shard pt
# ============================================================

def _get_shard_length(obj: Dict[str, Any]) -> int:
    """Return the number of samples held by one loaded shard dict.

    Formats are probed in this order:
      * padded: first dimension of the ``input_ids`` tensor, else of ``labels``;
      * packed: ``offsets.numel() - 1``;
      * list fallback: ``len(obj["items"])``.

    Raises:
        KeyError: If none of the known layouts matches.
    """
    # padded format
    for key in ("input_ids", "labels"):
        candidate = obj.get(key)
        if torch.is_tensor(candidate):
            return int(candidate.size(0))

    # packed format: offsets has one more entry than there are samples
    offsets = obj.get("offsets")
    if torch.is_tensor(offsets):
        return int(offsets.numel() - 1)

    # fallback list format
    items = obj.get("items")
    if isinstance(items, list):
        return len(items)

    raise KeyError(f"Unknown shard format keys={list(obj.keys())[:30]}")



def _get_pixel_at(pv: Any, i: int):
    """Pick entry ``i`` from a per-sample pixel container.

    ``None`` stays ``None``; tensors and lists are indexed with ``i``; any other
    object is returned unchanged.
    """
    if pv is None:
        return None
    indexable = torch.is_tensor(pv) or isinstance(pv, list)
    return pv[i] if indexable else pv


def _get_grid_at(g: Any, i: int):
    """Pick entry ``i`` from a per-sample grid container.

    Tensors and lists are indexed with ``i``; ``None`` and every other object
    are handed back as they are.
    """
    if g is not None and (isinstance(g, list) or torch.is_tensor(g)):
        return g[i]
    return g


class PreprocShardLazyDataset(Dataset):
    """
    Lazy dataset backed by multiple preprocessed shard pt files.

    Global sample indices run over the shards in the order given; a prefix-sum
    table maps a global index to ``(shard index, index inside the shard)``.
    Only ``lru_shards`` loaded shard objects are kept in RAM per process (LRU
    eviction). This matters because every DDP process owns its own Dataset object.

    ``__getitem__`` returns one of two tuple layouts depending on the shard format:
      * 7-tuple ``(id, input_ids, p_mask, labels, advantage, pixel_values, grid_thws)``
        for padded shards and for the ``items`` list fallback;
      * 6-tuple ``(id, seq_1d, start_pos, advantage, pixel_values, grid_thws)``
        for packed shards (``flat_input_ids`` + ``offsets``).

    Attributes:
        old_logp_resp: Maps a sample id to a 1-D tensor of old log-probs for the
            response tokens only; filled by the precompute pass.
    """

    def __init__(
        self,
        shard_paths: List[Path],
        lru_shards: int = 1,
    ):
        """
        Args:
            shard_paths: Shard files, in the order that defines global indexing.
            lru_shards: Maximum number of shards kept loaded at once (at least 1).
        """
        self.shard_paths = [Path(p) for p in shard_paths]
        self.lru_shards = int(max(1, lru_shards))

        # compute shard lengths and prefix sums
        self.shard_lengths: List[int] = []
        self.prefix: List[int] = []

        total = 0
        for p in self.shard_paths:
            # NOTE: torch.load unpickles the file; shards are assumed to be
            # trusted artifacts produced by the project's own preprocessing.
            obj = torch.load(p, map_location="cpu")
            n = _get_shard_length(obj)
            # release the shard right away: only its length is needed here
            del obj
            self.shard_lengths.append(n)
            total += n
            self.prefix.append(total)

        self.total_n = total

        # LRU cache: shard_idx -> loaded_obj
        self._cache: "OrderedDict[int, Dict[str, Any]]" = OrderedDict()

        # old logp store: only response tokens
        # sample_id -> 1D float32 tensor (len = #response_tokens_in_shift_space)
        self.old_logp_resp: Dict[int, torch.Tensor] = {}

    def __len__(self):
        """Total number of samples across all shards."""
        return self.total_n

    def _normalize_index(self, idx: int) -> int:
        """Map a possibly negative index to ``[0, total_n)``.

        Raises:
            IndexError: If the index falls outside the dataset.
        """
        n = self.total_n
        pos = idx + n if idx < 0 else idx
        if not 0 <= pos < n:
            raise IndexError(f"sample index {idx} out of range for dataset of size {n}")
        return pos

    def _locate(self, idx: int) -> Tuple[int, int]:
        """Translate a global index into ``(shard_idx, local_idx)``.

        Negative indices count from the end. Previously a negative index was
        silently resolved to an element of the first shard.

        Raises:
            IndexError: If the index falls outside the dataset.
        """
        idx = self._normalize_index(int(idx))
        # first shard whose cumulative count is > idx; empty shards are skipped
        # automatically because their prefix equals the previous one
        sidx = bisect_right(self.prefix, idx)
        start = 0 if sidx == 0 else self.prefix[sidx - 1]
        return sidx, idx - start

    def _load_shard(self, shard_idx: int) -> Dict[str, Any]:
        """Return the loaded shard, reading it from disk on a cache miss."""
        if shard_idx in self._cache:
            # mark as most recently used
            self._cache.move_to_end(shard_idx)
            return self._cache[shard_idx]

        obj = torch.load(self.shard_paths[shard_idx], map_location="cpu")

        self._cache[shard_idx] = obj
        self._cache.move_to_end(shard_idx)

        # evict least recently used shards until the cache fits the budget
        while len(self._cache) > self.lru_shards:
            self._cache.popitem(last=False)

        return obj

    def __getitem__(self, idx: int):
        """Fetch one sample; see the class docstring for the tuple layouts.

        The first tuple element is the normalized (non-negative) global index,
        which is also the key used in ``old_logp_resp``.

        Raises:
            IndexError: If ``idx`` is out of range.
            KeyError: If the shard layout is not recognised.
        """
        sample_id = self._normalize_index(int(idx))
        shard_idx, local = self._locate(sample_id)
        shard = self._load_shard(shard_idx)

        # -------- padded/merged format --------
        if "input_ids" in shard:
            input_ids = shard["input_ids"][local]
            p_mask = shard["p_mask"][local]
            labels = shard["labels"][local]
            advantage = shard["advantage"][local]
            if torch.is_tensor(advantage):
                advantage = float(advantage.item())
            else:
                advantage = float(advantage)

            # both the "*_list" and the plain key names are accepted
            pv = shard.get("pixel_values_list", shard.get("pixel_values", None))
            gt = shard.get("grid_thws_list", shard.get("grid_thws", None))
            pixel_values = _get_pixel_at(pv, local)
            grid_thws = _get_grid_at(gt, local)

            # 7-tuple: (id, input_ids, p_mask, labels, adv, pv, gt)
            return (
                sample_id,
                input_ids.long(),
                p_mask.bool(),
                labels.long(),
                advantage,
                pixel_values,
                grid_thws,
            )

        # -------- packed format (preprocess_shards output) --------
        if "flat_input_ids" in shard and "offsets" in shard:
            flat = shard["flat_input_ids"]
            offsets = shard["offsets"]
            a = int(offsets[local].item())
            b = int(offsets[local + 1].item())
            # clone so the returned tensor does not alias the cached shard storage
            seq = flat[a:b].clone()  # 1D (Li,)

            start_pos = int(shard["start_pos"][local].item())
            advantage = float(shard["advantage"][local].item())

            pv_list = shard.get("pixel_values_list", None)
            gt_list = shard.get("grid_thws_list", None)
            pixel_values = pv_list[local] if pv_list is not None else None
            grid_thws = gt_list[local] if gt_list is not None else None

            # 6-tuple: (id, seq_1d, start_pos, adv, pv, gt)
            return (
                sample_id,
                seq.long(),
                start_pos,
                advantage,
                pixel_values,
                grid_thws,
            )

        # fallback list
        if "items" in shard:
            it = shard["items"][local]
            return (
                sample_id,
                it["input_ids"].long(),
                it["p_mask"].bool(),
                it["labels"].long(),
                float(it["advantage"]),
                it.get("pixel_values", None),
                it.get("grid_thws", None),
            )

        raise KeyError(f"Unsupported shard keys: {list(shard.keys())[:30]}")



def collate_preproc_batch(batch, pad_id: int):
    """
    Collate dataset tuples into a dict of batch tensors.

    Args:
        batch: List of tuples that all share one layout -- either the 7-tuple
            (padded format) or the 6-tuple (packed format) of the dataset.
        pad_id: Token id used to right-pad packed sequences.

    Returns:
        Dict with ``ids``, ``input_ids``, ``p_mask``, ``labels``, ``advantage``,
        ``pixel_values`` and ``grid_thws``; packed batches additionally carry
        ``start_pos`` and ``lengths``. The two image entries stay Python lists.

    Raises:
        RuntimeError: If the tuples have an unexpected length.
    """
    width = len(batch[0])

    # padded/merged format: rows already share one length, so stacking suffices
    if width == 7:
        ids, input_ids, p_mask, labels, adv, pixel_values, grid_thws = zip(*batch)
        collated = {}
        collated["ids"] = torch.tensor(ids, dtype=torch.long)
        collated["input_ids"] = torch.stack(input_ids)
        collated["p_mask"] = torch.stack(p_mask)
        collated["labels"] = torch.stack(labels)
        collated["advantage"] = torch.tensor(adv, dtype=torch.float32)
        collated["pixel_values"] = list(pixel_values)
        collated["grid_thws"] = list(grid_thws)
        return collated

    if width != 6:
        raise RuntimeError(f"Unknown batch tuple len={width}")

    # packed format: right-pad every sequence to the longest one in the batch
    ids, seqs, start_pos, adv, pixel_values, grid_thws = zip(*batch)
    lens = [int(seq.numel()) for seq in seqs]
    max_len = max(lens) if lens else 1
    shape = (len(seqs), max_len)

    input_ids = torch.full(shape, pad_id, dtype=torch.long)
    p_mask = torch.zeros(shape, dtype=torch.bool)

    for row, (seq, n_tok, resp_start) in enumerate(zip(seqs, lens, start_pos)):
        resp_start = int(resp_start)
        if n_tok > 0:
            input_ids[row, :n_tok] = seq
        # response tokens occupy [start_pos, length); padding stays False
        if n_tok > resp_start:
            p_mask[row, resp_start:n_tok] = True

    # labels carry the same token ids and padding as input_ids
    labels = input_ids.clone()

    return {
        "ids": torch.tensor(ids, dtype=torch.long),
        "input_ids": input_ids,
        "p_mask": p_mask,
        "labels": labels,
        "advantage": torch.tensor(adv, dtype=torch.float32),
        "pixel_values": list(pixel_values),
        "grid_thws": list(grid_thws),
        "start_pos": torch.tensor(start_pos, dtype=torch.long),
        "lengths": torch.tensor(lens, dtype=torch.long),
    }



# ============================================================
# Mask/pos helpers (unchanged)
# ============================================================

def make_attention_mask(input_ids: torch.Tensor, pad_id: int):
    """Return an int64 mask: 1 where ``input_ids`` differs from ``pad_id``, else 0."""
    is_real_token = input_ids.ne(pad_id)
    return is_real_token.long()

def make_position_ids(attention_mask: torch.Tensor):
    """Build position ids from an attention mask.

    Each position gets the running count of mask entries along dim 1, minus one;
    positions where the mask is 0 are set to 0. The input is not modified.
    """
    running_index = attention_mask.cumsum(dim=1) - 1
    return running_index.masked_fill(attention_mask == 0, 0)


# ============================================================
# Precompute old logp (no N×L tensor) - store response tokens only
# ============================================================

@torch.no_grad()
def compute_logp_old_tok_parallel_lazy(
    accelerator: Accelerator,
    model,
    dataset: PreprocShardLazyDataset,
    train_dataloader: DataLoader,
    pad_id: int,
    is_qwen3_vl: bool,
):
    """
    Run the model over ``train_dataloader`` and cache per-sample "old" log-probs.

    For every sample, the logits at position t are scored against the token at
    t+1, and only the entries selected by the shifted ``p_mask`` are kept. The
    result is stored as a 1-D float32 CPU tensor in
    ``dataset.old_logp_resp[sample_id]``; no N x L tensor is allocated.

    Samples are forwarded one at a time. The model is switched to eval mode for
    the pass and back to train mode afterwards; gradients are disabled by the
    decorator. All processes synchronise before the function returns.
    """
    model.eval()
    from tqdm.auto import tqdm

    device = accelerator.device
    # the image-grid keyword differs between the two supported model families
    grid_kwarg = "image_grid_thw" if is_qwen3_vl else "grid_thws"

    progress = tqdm(
        train_dataloader,
        desc="Precomputing old token log-probs (lazy)",
        dynamic_ncols=True,
        disable=not accelerator.is_local_main_process,
        leave=True,
    )

    for batch in progress:
        sample_ids = batch["ids"].tolist()
        batch_tokens = batch["input_ids"].to(device)      # (B, L)
        batch_resp_mask = batch["p_mask"].to(device)      # (B, L)
        images = batch["pixel_values"]
        grids = batch["grid_thws"]

        n_rows, _ = batch_tokens.shape

        for row in range(n_rows):
            tokens = batch_tokens[row : row + 1]           # (1, L)
            attn = make_attention_mask(tokens, pad_id)

            model_inputs = {"input_ids": tokens, "attention_mask": attn}
            if not is_qwen3_vl:
                model_inputs["position_ids"] = make_position_ids(attn)
            if images[row] is not None:
                model_inputs["pixel_values"] = images[row].to(device)
            if grids[row] is not None:
                model_inputs[grid_kwarg] = grids[row].to(device)

            logits = model(**model_inputs).logits          # (1, L, V)

            # next-token scoring: logits[t] predicts tokens[t + 1]
            next_token_logp = F.log_softmax(logits[:, :-1, :], dim=-1)
            targets = tokens[:, 1:].unsqueeze(-1)
            token_logp = next_token_logp.gather(dim=-1, index=targets).squeeze(-1)  # (1, L-1)

            # keep response positions only, expressed in the shifted index space
            keep = batch_resp_mask[row, 1:].bool()
            dataset.old_logp_resp[int(sample_ids[row])] = (
                token_logp[0, keep].float().detach().cpu()
            )

    accelerator.wait_for_everyone()
    model.train()


# ============================================================
# forward_process (loss) - same formula, old logp comes from dataset.old_logp_resp
# ============================================================

def forward_process_vlm_lazy(
    accelerator: Accelerator,
    model,
    batch: Dict[str, Any],
    dataset: PreprocShardLazyDataset,
    tokenizer,
    config,
    is_qwen3_vl: bool,
):
    """
    Compute the clipped-surrogate policy loss (plus optional KL term) for a batch.

    Samples are forwarded one at a time. For each sample the new per-token
    log-probs are compared against the cached old ones from
    ``dataset.old_logp_resp`` on the response positions (shifted ``p_mask``):

      * ratio = exp(clamp(logp_new - logp_old, -10, 10)), clipped to
        ``1 +/- config.training.eps``;
      * policy loss = -mean over response tokens of
        min(ratio * adv, clipped * adv);
      * if ``config.training.beta > 0``, a KL term weighted by beta is added,
        using the k3 estimator when ``config.training.use_kl_estimator_k3`` is set.

    Returns:
        Scalar tensor: the per-sample losses averaged over the batch.

    Raises:
        RuntimeError: If a sample has no cached old log-probs, or the cached
            length differs from the number of response positions.
    """
    device = accelerator.device

    sample_ids = batch["ids"].tolist()
    batch_tokens = batch["input_ids"].to(device)       # (B, L)
    batch_resp_mask = batch["p_mask"].to(device)       # (B, L)
    batch_adv = batch["advantage"].to(device)          # (B,)
    images = batch["pixel_values"]
    grids = batch["grid_thws"]

    pad_id = tokenizer.pad_token_id
    n_rows, _ = batch_tokens.shape
    # the image-grid keyword differs between the two supported model families
    grid_kwarg = "image_grid_thw" if is_qwen3_vl else "grid_thws"

    loss_sum = torch.tensor(0.0, device=device)

    for row in range(n_rows):
        sample_id = int(sample_ids[row])

        tokens = batch_tokens[row : row + 1]            # (1, L)
        adv = batch_adv[row : row + 1]                  # (1,)
        attn = make_attention_mask(tokens, pad_id)

        model_inputs = {"input_ids": tokens, "attention_mask": attn}
        if not is_qwen3_vl:
            model_inputs["position_ids"] = make_position_ids(attn)
        if images[row] is not None:
            model_inputs["pixel_values"] = images[row].to(device)
        if grids[row] is not None:
            model_inputs[grid_kwarg] = grids[row].to(device)

        logits = model(**model_inputs).logits           # (1, L, V)

        # next-token scoring: logits[t] predicts tokens[t + 1]
        next_token_logp = F.log_softmax(logits[:, :-1, :], dim=-1)
        targets = tokens[:, 1:].unsqueeze(-1)
        logp_new = next_token_logp.gather(dim=-1, index=targets).squeeze(-1)  # (1, L-1)

        mask = batch_resp_mask[row : row + 1][:, 1:]    # (1, L-1)
        flat_mask = mask[0].bool()
        n_resp = int(flat_mask.sum().item())

        cached = dataset.old_logp_resp.get(sample_id)
        if cached is None:
            raise RuntimeError(
                f"[old_logp missing] sample_id={sample_id} not found in dataset.old_logp_resp. "
                f"Did precompute step finish on this rank?"
            )
        if int(cached.numel()) != n_resp:
            raise RuntimeError(
                f"[old_logp mismatch] sample_id={sample_id} need n_resp={n_resp}, cached={cached.numel()}."
            )

        # Rebuild a full-length old log-prob row. Non-response positions copy the
        # (detached) new values, so their difference is zero; they are masked anyway.
        logp_old = logp_new.detach().clone()
        if n_resp > 0:
            logp_old[0, flat_mask] = cached.to(device)

        delta = logp_new - logp_old
        delta = torch.where(mask, delta, torch.zeros_like(delta))
        ratio = torch.exp(delta.clamp(-10.0, 10.0))
        clipped = torch.clamp(
            ratio,
            1.0 - config.training.eps,
            1.0 + config.training.eps,
        )

        adv_tok = adv.unsqueeze(1)
        surrogate = torch.min(ratio * adv_tok, clipped * adv_tok) * mask

        # number of response tokens, floored at 1 to avoid dividing by zero
        denom = torch.clamp(mask.sum(dim=1), min=1.0)
        policy_loss = -(surrogate.sum(dim=1) / denom).mean()

        kl_loss = torch.tensor(0.0, device=device)
        if config.training.beta > 0:
            kl_tok = logp_new - logp_old
            kl_tok = torch.where(mask, kl_tok, torch.zeros_like(kl_tok))
            if getattr(config.training, "use_kl_estimator_k3", False):
                # k3 form: exp(-d) - 1 + d, with d = logp_new - logp_old
                neg = (-kl_tok).clamp(-10.0, 10.0)
                kl_tok = neg.exp() - 1.0 + kl_tok
            kl_per_seq = (kl_tok * mask).sum(dim=1) / torch.clamp(denom, min=1.0)
            kl_loss = config.training.beta * kl_per_seq.mean()

        loss_sum = loss_sum + (policy_loss + kl_loss)

    return loss_sum / max(n_rows, 1)


# ============================================================
# Checkpoint (keep your original save_checkpoint behavior)
# ============================================================

def save_checkpoint(model, tokenizer, config, accelerator, name: str):
    """
    Save model weights, tokenizer and auxiliary files to
    ``<rl_base_dir>/<project>/ckpt/<name>``.

    ``accelerator.unwrap_model`` and ``accelerator.get_state_dict`` are called on
    every process; all file writes happen on the main process only.

    Besides the model and tokenizer, the main process:
      * looks for a local "base" directory (the tokenizer's name/path first, then
        the configured policy or reward model path) and copies processor /
        custom-code files from it when they are not already in the checkpoint;
      * for OpenCUA-style tokenizers, rewrites ``tokenizer_config.json`` so that
        it points at ``tokenization_opencua.TikTokenV3``;
      * writes ``metadata.json`` with the save time next to the checkpoint dirs.
    """
    import json as _json
    import shutil as _shutil
    import time as _time
    from pathlib import Path as _Path

    ckpt_root = _Path(config.system.rl_base_dir) / config.experiment.project / "ckpt"
    save_dir = ckpt_root / name

    unwrapped = accelerator.unwrap_model(model)
    state_dict = accelerator.get_state_dict(model)

    if not accelerator.is_main_process:
        return

    def _copy_missing(source_root, filenames):
        # copy only files that exist in the source and are absent in the target
        for filename in filenames:
            src, dst = source_root / filename, save_dir / filename
            if src.exists() and not dst.exists():
                _shutil.copy(src, dst)

    save_dir.mkdir(parents=True, exist_ok=True)

    unwrapped.save_pretrained(
        save_dir,
        save_function=accelerator.save,
        state_dict=state_dict,
        safe_serialization=True,
    )
    tokenizer.save_pretrained(str(save_dir))

    init_kwargs = getattr(tokenizer, "init_kwargs", {})
    tok_name = getattr(tokenizer, "name_or_path", "") or init_kwargs.get("_name_or_path", "")
    target = getattr(config.training, "target", None)

    # candidate directories holding the original model's side files, in priority order
    base_candidates = [_Path(tok_name)] if tok_name else []
    if target == "policy":
        base_candidates.append(_Path(config.model.policy_model))
    elif target == "reward":
        base_candidates.append(_Path(config.model.reward_model))
    base_dir = next((p for p in base_candidates if p and p.exists()), None)

    tok_class_name = tokenizer.__class__.__name__
    tok_cfg_class = init_kwargs.get("tokenizer_class", "")
    is_opencua = (
        "TikTokenV3" in tok_class_name
        or "TikTokenV3" in str(tok_cfg_class)
        or "OpenCUA" in str(tok_name)
    )

    if base_dir is not None:
        if is_opencua:
            _copy_missing(
                base_dir,
                [
                    "tokenization_opencua.py",
                    "configuration_opencua.py",
                    "modeling_opencua.py",
                    "tiktoken.model",
                    "preprocessor_config.json",
                    "image_processor_config.json",
                    "processor_config.json",
                    "generation_config.json",
                ],
            )

            # point the saved tokenizer config at the custom tokenizer class
            tok_cfg_path = save_dir / "tokenizer_config.json"
            if tok_cfg_path.exists():
                with tok_cfg_path.open("r", encoding="utf-8") as f:
                    tok_cfg = _json.load(f)
                tok_cfg["tokenizer_class"] = "tokenization_opencua.TikTokenV3"
                auto_map = tok_cfg.get("auto_map", {})
                auto_map["AutoTokenizer"] = ["tokenization_opencua.TikTokenV3", None]
                tok_cfg["auto_map"] = auto_map
                with tok_cfg_path.open("w", encoding="utf-8") as f:
                    _json.dump(tok_cfg, f, indent=2, ensure_ascii=False)
        else:
            _copy_missing(
                base_dir,
                ["preprocessor_config.json", "processor_config.json", "image_processor_config.json"],
            )

    metadata = {"save_time": _time.strftime("%Y-%m-%d %H:%M:%S")}
    with (ckpt_root / "metadata.json").open("w") as f:
        _json.dump(metadata, f, indent=2)

    logger.info(f"Saved model + tokenizer to {save_dir}")


def enable_gc(model):
    """
    Turn on gradient checkpointing for ``model`` and switch ``use_cache`` off.

    Strategies are tried in this order, stopping at the first that applies:
      1. ``model.language_model.gradient_checkpointing_enable()``;
      2. ``model.gradient_checkpointing_enable()`` (a ``ValueError`` is logged
         and the next strategy is tried);
      3. setting ``gradient_checkpointing = True`` on every submodule that
         exposes that attribute.
    A warning is logged when nothing could be enabled.
    """

    def _disable_cache(owner) -> None:
        # set config.use_cache = False when the owner exposes such a config field
        cfg = getattr(owner, "config", None)
        if hasattr(cfg, "use_cache"):
            cfg.use_cache = False

    def _flag_submodules(root) -> bool:
        # returns True when at least one submodule carried the flag
        touched = False
        for sub in root.modules():
            if hasattr(sub, "gradient_checkpointing"):
                sub.gradient_checkpointing = True
                touched = True
        return touched

    _disable_cache(model)
    language_model = getattr(model, "language_model", None)
    _disable_cache(language_model)

    if hasattr(language_model, "gradient_checkpointing_enable"):
        language_model.gradient_checkpointing_enable()
        _flag_submodules(language_model)
        logger.info("[GC] Enabled on model.language_model (OpenCUA/Qwen-style)")
        return

    if hasattr(model, "gradient_checkpointing_enable"):
        try:
            model.gradient_checkpointing_enable()
            _flag_submodules(model)
            logger.info("[GC] Enabled on model directly")
            return
        except ValueError as e:
            logger.warning(f"[GC] Direct enable failed: {e}")

    if _flag_submodules(model):
        logger.info("[GC] Enabled by setting gradient_checkpointing=True on submodules")
    else:
        logger.warning("[GC] No supported gradient checkpointing hooks found; running without GC.")


# ============================================================
# Main
# ============================================================

def main():
    """
    Entry point: one PPO/GRPO-style optimization round on preprocessed shards.

    Steps, in order:
      1. Read the config and pick the policy- or reward-specific settings.
      2. Discover the shard files and build the lazy dataset.
      3. Load the tokenizer (needed for the pad id and for checkpoint saving).
      4. Build the Accelerator, DataLoader, model, optimizer and LR scheduler.
      5. Precompute the old per-token log-probs with the initial weights.
      6. Train for ``config.training.num_train_epochs`` epochs with gradient
         accumulation.
      7. Save the checkpoint(s).

    Raises:
        ValueError: For an unknown ``training.target`` or an unsupported
            optimizer name.
    """
    config = get_config()
    project_name = config.experiment.project

    # Per-target settings. On the first epoch training starts from the
    # configured base model, afterwards from the last saved checkpoint.
    if config.training.target == "policy":
        if config.experiment.current_epoch == 1:
            pretrained_model = config.model.policy_model
        else:
            pretrained_model = config.system.rl_base_dir + "/" + project_name + "/ckpt/" + config.model.optimized_name
        optimized_name = config.model.optimized_name
        # max_prompt_len / max_gen_length are read but not used further in this function
        max_prompt_len = config.training.policy.max_prompt_len
        max_gen_length = config.training.policy.max_gen_length
        optimization_data = "policy_optimization_data"
        update_per_step = config.training.policy.update_per_step
        batch_size_lm = config.training.policy.batch_size_lm
        gradient_checkpointing_enable = config.training.policy.gradient_checkpointing_enable
    elif config.training.target == "reward":
        if config.experiment.current_epoch == 1:
            pretrained_model = config.model.reward_model
        else:
            pretrained_model = config.system.rl_base_dir + "/" + project_name + "/ckpt/" + config.model.optimized_reward_name
        optimized_name = config.model.optimized_reward_name
        max_prompt_len = config.training.reward.max_prompt_len
        max_gen_length = config.training.reward.max_gen_length
        optimization_data = "reward_optimization_data"
        update_per_step = config.training.reward.update_per_step
        batch_size_lm = config.training.reward.batch_size_lm
        gradient_checkpointing_enable = config.training.reward.gradient_checkpointing_enable
    else:
        raise ValueError(f"Unknown training.target = {config.training.target}")

    if config.training.enable_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    is_qwen3_vl = (config.model_type == "qwen3vl")

    # ---- Discover preprocessed shards (NO json, NO tokenization) ----
    override_glob = OmegaConf.select(config, "dataset.preproc_shard_glob", default=None)
    lru_shards = int(OmegaConf.select(config, "dataset.preproc_lru_shards", default=1))

    shard_paths = discover_preproc_shards(
        project_name=project_name,
        optimization_data=optimization_data,
        override_glob=override_glob,
    )

    dataset_lm = PreprocShardLazyDataset(shard_paths=shard_paths, lru_shards=lru_shards)
    total_n = len(dataset_lm)

    # ---- tokenizer only (for pad_id + saving ckpt); data is already tokenized ----
    if is_qwen3_vl:
        processor = AutoProcessor.from_pretrained(pretrained_model, trust_remote_code=True)
        tokenizer = processor.tokenizer
    else:
        processor = None
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model, trust_remote_code=True)

    # Make sure a pad token exists: reuse EOS when available, else add "[PAD]".
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    pad_id = tokenizer.pad_token_id

    # Accumulation steps are derived so that one pass over the data yields about
    # `update_per_step` optimizer updates.
    # NOTE(review): the Accelerator below uses split_batches=True, where the
    # DataLoader batch size is the global batch that gets split across
    # processes, while this formula and the log lines further down treat
    # batch_size_lm as per-device -- confirm which is intended.
    ws = int(os.environ.get("WORLD_SIZE", "1"))
    gradient_accumulation_steps = max(
        1,
        math.ceil(total_n / (update_per_step * batch_size_lm * ws)),
    )

    config.experiment.logging_dir = str(Path(project_name) / "logs")
    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=config.training.mixed_precision,
        log_with=None,
        project_dir=config.experiment.logging_dir,
        split_batches=True,
    )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        set_verbosity_info()
    else:
        set_verbosity_error()

    if accelerator.is_main_process and wandb is not None:
        resume_wandb_run = config.wandb.resume
        run_id = config.wandb.get("run_id", None)
        if run_id is None:
            resume_wandb_run = False
            run_id = wandb.util.generate_id()
            config.wandb.run_id = run_id

        wandb_init_kwargs = dict(
            name=project_name,
            id=run_id,
            resume=resume_wandb_run,
            entity=config.wandb.get("entity", None),
            config_exclude_keys=[],
        )
        wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)}
        wandb_config.pop("experiment.resume_from_checkpoint", None)

        accelerator.init_trackers(
            project_name,
            config=wandb_config,
            init_kwargs={"wandb": wandb_init_kwargs},
        )

    if accelerator.is_main_process:
        os.makedirs(project_name, exist_ok=True)
        config_path = Path(project_name) / "config.yaml"
        logging.info(f"Saving config to {config_path}")
        OmegaConf.save(config, config_path)

    if config.training.seed is not None:
        set_seed(config.training.seed)

    # ---- Dataloader ----
    # No sampler and no shuffling: the precompute pass and the training loop
    # iterate the same DataLoader, and the old log-probs are cached per process
    # keyed by sample id.
    train_dataloader_lm = DataLoader(
        dataset_lm,
        batch_size=batch_size_lm,
        sampler=None,
        collate_fn=lambda b: collate_preproc_batch(b, pad_id=pad_id),
        num_workers=0,
    )

    # ---- Model ----
    logger.info("Loading VLM model and optimizer")
    model = AutoModel.from_pretrained(pretrained_model, trust_remote_code=True, torch_dtype="auto")

    if gradient_checkpointing_enable:
        enable_gc(model)

    # Two parameter groups: weight decay is disabled for parameters whose name
    # contains one of the `no_decay` fragments.
    optimizer_config = config.optimizer.params
    no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if p.requires_grad and not any(nd in n for nd in no_decay)
            ],
            "weight_decay": optimizer_config.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if p.requires_grad and any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    if config.optimizer.name == "adamw":
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=optimizer_config.learning_rate,
            betas=(optimizer_config.beta1, optimizer_config.beta2),
            weight_decay=optimizer_config.weight_decay,
            eps=optimizer_config.epsilon,
        )
    else:
        raise ValueError(f"Optimizer {config.optimizer.name} not supported")

    total_batch_size_lm = batch_size_lm * ws * gradient_accumulation_steps
    num_update_steps_per_epoch = max(1, math.ceil(total_n / total_batch_size_lm))
    num_train_epochs = config.training.num_train_epochs
    max_train_steps = num_update_steps_per_epoch * num_train_epochs

    lr_scheduler = get_scheduler(
        config.lr_scheduler.scheduler,
        optimizer=optimizer,
        num_training_steps=max_train_steps,
        num_warmup_steps=config.lr_scheduler.params.warmup_steps,
        min_lr_scale=config.lr_scheduler.params.min_lr_scale,
    )

    logger.info("Preparing model, optimizer, scheduler and dataloader")
    model, optimizer, lr_scheduler, train_dataloader_lm = accelerator.prepare(
        model, optimizer, lr_scheduler, train_dataloader_lm
    )

    # ---- Precompute old logp (lazy) ----
    logger.info("***** Running inference (precompute old logp) *****")
    compute_logp_old_tok_parallel_lazy(
        accelerator=accelerator,
        model=model,
        dataset=dataset_lm,
        train_dataloader=train_dataloader_lm,
        pad_id=pad_id,
        is_qwen3_vl=is_qwen3_vl,
    )

    # ---- Training ----
    logger.info("***** Running training *****")
    logger.info(f"  Num training data = {len(dataset_lm)}")
    logger.info(f"  Num training steps = {max_train_steps}")
    logger.info(f"  Instantaneous batch size per device = {batch_size_lm}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size_lm}")
    logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")

    from tqdm.auto import tqdm

    def _apply_update(last_loss, progress_bar):
        # one optimizer + scheduler step, then report the mean (across
        # processes) of the most recent micro-batch loss
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad(set_to_none=True)

        loss_val = float(accelerator.gather(last_loss.detach()).mean().item())
        progress_bar.set_postfix(loss=loss_val)

    for _epoch in range(num_train_epochs):
        model.train()
        progress = tqdm(
            train_dataloader_lm,
            disable=not accelerator.is_local_main_process,
            dynamic_ncols=True,
        )

        pending = 0   # micro-batches accumulated since the last optimizer step
        loss = None
        for batch in progress:
            loss = forward_process_vlm_lazy(
                accelerator=accelerator,
                model=model,
                batch=batch,
                dataset=dataset_lm,
                tokenizer=tokenizer,
                config=config,
                is_qwen3_vl=is_qwen3_vl,
            )

            # accelerator.backward() already scales the loss by the
            # gradient_accumulation_steps given to the Accelerator, so the loss
            # must not be divided here as well (that would scale by 1/GA^2).
            accelerator.backward(loss)
            pending += 1

            if pending == gradient_accumulation_steps:
                _apply_update(loss, progress)
                pending = 0

        # Flush a trailing partial accumulation window. Otherwise its gradients
        # would leak into the next epoch, or be dropped after the last one.
        if pending > 0:
            _apply_update(loss, progress)

    torch.cuda.empty_cache()
    accelerator.wait_for_everyone()

    # Always overwrite the "latest" checkpoint; also keep a per-epoch snapshot
    # every `save_every` outer epochs.
    save_checkpoint(model, tokenizer, config, accelerator, optimized_name)
    if config.experiment.current_epoch % config.experiment.save_every == 0:
        save_checkpoint(model, tokenizer, config, accelerator, f"epoch-{config.experiment.current_epoch}")

    accelerator.end_training()


# Run the training entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
