# -*- coding: utf-8 -*-
"""
SFT training for OpenCUA / Qwen3-VL VLM policy model.

不同于 GRPO 版：
- 输入仍然是 policy_optimization_data.json / reward_optimization_data.json
  [
    {
      "domain": "...",
      "example": "...",
      "run_id": 0,
      "step": 0,
      "prompt_messages": [...],   # system/user 多模态，格式同 serve_opencua.py
      "response": "...",          # assistant 文本
      "reward": 0.0               # 会被忽略，在 SFT 中不用
    },
    ...
  ]

- 多模态处理与 serve_opencua / GRPO 版完全对齐：
  - prompt_messages -> tokenizer/processor.apply_chat_template(..., add_generation_prompt=True)
  - images in prompt_messages -> AutoImageProcessor.preprocess -> pixel_values + grid_thws

- SFT 风格：
  - 对 response tokens 做标准自回归交叉熵：
      - labels 在 prompt 部分设为 -100（不参与 loss）
      - labels 在 response 部分 = 对应 token id
      - 使用 ignore_index=-100 的 cross_entropy
"""

import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
os.environ["TOKENIZERS_PARALLELISM"] = "true"

import io
import json
import math
import time
import shutil
import base64
import logging
from pathlib import Path
from typing import Any, Dict, List

import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader

from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

from omegaconf import OmegaConf

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed

from transformers import (
    AutoTokenizer,
    AutoImageProcessor,
    AutoModel,
    AutoProcessor,
)

from train.utils import get_config, flatten_omega_conf, AverageMeter
from train.lr_schedulers import get_scheduler
from train.log_utils import set_verbosity_info, set_verbosity_error

try:
    import wandb
except ImportError:
    wandb = None

logger = get_logger(__name__, log_level="INFO")


# ==================== Dataset ====================

class VLMSFTDataset(Dataset):
    """
    Map-style dataset over pre-tokenized, right-padded SFT samples.

    Stored fields (N samples, padded length L):
      - input_ids:         (N, L) token ids
      - labels:            (N, L) targets; -100 at prompt / pad positions so the
                           loss ignores them, the real token id at response positions
      - pixel_values_list: one entry per sample, a Tensor or None
      - grid_thws_list:    one entry per sample, a Tensor or None

    ``ds[i]`` yields the tuple
    ``(input_ids[i], labels[i], pixel_values_list[i], grid_thws_list[i])``.
    """

    def __init__(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor,
        pixel_values_list: List[Any],
        grid_thws_list: List[Any],
    ):
        # labels must mirror input_ids token-for-token, and every row needs one
        # (possibly None) entry in each of the per-sample image lists.
        assert input_ids.shape == labels.shape
        n_rows = input_ids.size(0)
        assert n_rows == len(pixel_values_list) == len(grid_thws_list)

        self.input_ids, self.labels = input_ids, labels
        self.pixel_values_list = pixel_values_list
        self.grid_thws_list = grid_thws_list

    def __len__(self):
        return self.input_ids.shape[0]

    def __getitem__(self, idx):
        columns = (
            self.input_ids,
            self.labels,
            self.pixel_values_list,
            self.grid_thws_list,
        )
        return tuple(column[idx] for column in columns)


def sft_collate(batch):
    """
    Collate ``(input_ids, labels, pixel_values, grid_thws)`` samples into one batch.

    Token tensors are stacked into (B, L) tensors. The image entries stay plain
    Python lists of length B (each item a Tensor or None), so every sample's
    images can be handled individually downstream.
    """
    ids_col, labels_col, pixels_col, grids_col = zip(*batch)

    collated = {}
    collated["input_ids"] = torch.stack(ids_col)   # (B, L)
    collated["labels"] = torch.stack(labels_col)   # (B, L)
    collated["pixel_values"] = [*pixels_col]       # len=B, Tensor or None each
    collated["grid_thws"] = [*grids_col]           # len=B, Tensor or None each
    return collated


# ==================== Utils for images / prompts ====================

def _decode_image(src: str) -> Image.Image:
    """
    Decode a single image reference into an RGB PIL image.

    Accepted forms of ``src`` (surrounding whitespace is ignored):
      - ``http://`` / ``https://`` URL: downloaded with ``requests`` (10 s timeout);
      - ``file://`` URI: the scheme prefix is stripped and the rest is used as a path;
      - an existing local file path;
      - anything else is treated as base64, optionally wrapped as a data URI
        (everything up to and including the first comma is dropped).

    Every opened image is closed via ``with``; the returned image is an
    independent converted copy, so no file handle outlives this call.

    Raises:
        Whatever ``requests``, ``base64`` or PIL raise on unreachable, malformed
        or undecodable input (callers catch and log these).
    """
    s = src.strip()

    if s.startswith(("http://", "https://")):
        # Training data is normally local; http(s) is supported as a convenience.
        import requests
        resp = requests.get(s, timeout=10)
        resp.raise_for_status()
        with Image.open(io.BytesIO(resp.content)) as img:
            return img.convert("RGB")

    if s.startswith("file://"):
        s = s[len("file://"):]

    if os.path.exists(s):
        # Local path.
        with Image.open(s) as img:
            return img.convert("RGB")

    # Not a file: raw base64 or a data URI ("data:image/png;base64,<payload>").
    b64 = s.split(",", 1)[1] if "," in s else s
    raw = base64.b64decode(b64, validate=False)
    with Image.open(io.BytesIO(raw)) as img:
        return img.convert("RGB")


def collect_images_from_messages(messages: List[Dict[str, Any]]) -> List[Image.Image]:
    """
    Decode every image referenced in a chat message list, in message order.

    Only list-valued ``content`` is scanned. Within it, dict parts of type
    ``"image"`` (source in ``part["image"]``) and ``"image_url"`` (source in
    ``part["image_url"]["url"]``, falling back to ``part["url"]``) are decoded
    with ``_decode_image``. Bare-string parts and other part types carry no
    image and are skipped.

    Images that fail to decode are reported with a ``[WARN]`` line and skipped.
    The returned list can then be shorter than the number of image references
    in the prompt.
    """
    pil_images: List[Image.Image] = []
    for m in messages:
        content = m.get("content")
        if not isinstance(content, list):
            # Plain-string content holds no images.
            continue
        for part in content:
            if not isinstance(part, dict):
                # Content lists may mix bare strings with typed dicts; strings
                # are text only.
                continue
            t = part.get("type")
            if t == "image":
                src = part.get("image")
                if src:
                    try:
                        pil_images.append(_decode_image(src))
                    except Exception as e:
                        print(f"[WARN] decode image failed: {src} | {e}", flush=True)
            elif t == "image_url":
                # Same URL lookup as the Qwen3-VL message normalization.
                url = (part.get("image_url") or {}).get("url") or part.get("url")
                if url:
                    try:
                        pil_images.append(_decode_image(url))
                    except Exception as e:
                        print(f"[WARN] decode image_url failed: {url} | {e}", flush=True)
    return pil_images


def _to_qwen3_vl_messages(prompt_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Normalize chat messages into ``{"role", "content": [typed parts]}`` form for Qwen3-VL.

    - string content -> one ``{"type": "text"}`` part;
    - list content   -> each part mapped individually:
        * bare strings and ``text`` / ``paragraph`` dicts -> text parts,
        * ``image`` / ``video`` dicts are passed through unchanged,
        * ``image_url`` dicts -> ``{"type": "image", "image": url}``
          (dropped when no URL is present),
        * any other dict -> its ``text`` field, else ``str(part)``, as text;
    - any other content -> ``str(content)`` as a text part.
    A missing role defaults to ``"user"``.
    """
    qwen_messages: List[Dict[str, Any]] = []
    for m in prompt_messages:
        role = m.get("role", "user")
        content = m.get("content", "")
        new_content: List[Dict[str, Any]] = []

        if isinstance(content, str):
            new_content.append({"type": "text", "text": content})
        elif isinstance(content, list):
            for part in content:
                if isinstance(part, str):
                    new_content.append({"type": "text", "text": part})
                elif isinstance(part, dict):
                    p_type = part.get("type")
                    if p_type in ("text", "paragraph"):
                        new_content.append({"type": "text", "text": part.get("text", "")})
                    elif p_type in ("image", "video"):
                        new_content.append(part)
                    elif p_type == "image_url":
                        url = (part.get("image_url") or {}).get("url") or part.get("url")
                        if url:
                            new_content.append({"type": "image", "image": url})
                    else:
                        txt = part.get("text", None)
                        new_content.append(
                            {"type": "text", "text": txt if txt is not None else str(part)}
                        )
                else:
                    new_content.append({"type": "text", "text": str(part)})
        else:
            # Unexpected content type: fall back to its string form.
            new_content.append({"type": "text", "text": str(content)})

        qwen_messages.append({"role": role, "content": new_content})
    return qwen_messages


def _as_id_list(encoded) -> List[int]:
    """
    Coerce an ``apply_chat_template(tokenize=True)`` result into a flat list of ints.

    Handles a plain list, a tensor, a mapping with ``"input_ids"`` (e.g. a
    BatchEncoding), and a batched ``[[...]]`` form holding a single row.
    """
    if hasattr(encoded, "keys") and "input_ids" in encoded:
        encoded = encoded["input_ids"]
    if isinstance(encoded, torch.Tensor):
        encoded = encoded.tolist()
    encoded = list(encoded)
    if encoded and isinstance(encoded[0], list):
        encoded = encoded[0]
    return encoded


def _encode_prompt_qwen3_vl(processor, prompt_messages: List[Dict[str, Any]]):
    """
    Tokenize a prompt together with its images using the Qwen3-VL processor.

    Going through the processor rather than the bare tokenizer keeps the image
    placeholder tokens consistent with the produced image features.

    Returns:
        ``(prompt_ids, pixel_values, grid_thws)``. The last two are None when the
        processor output has no ``pixel_values`` / ``image_grid_thw``.
    """
    proc_inputs = processor.apply_chat_template(
        _to_qwen3_vl_messages(prompt_messages),
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
        truncation=False,
        max_length=100000,
    )
    prompt_ids = proc_inputs["input_ids"][0].tolist()

    pixel_values = proc_inputs.get("pixel_values", None)
    if pixel_values is not None:
        pixel_values = pixel_values.clone()

    grid = proc_inputs.get("image_grid_thw", None)
    grid_thws = grid.clone().long() if grid is not None else None
    return prompt_ids, pixel_values, grid_thws


def _encode_prompt_default(tokenizer, image_processor, prompt_messages: List[Dict[str, Any]]):
    """
    Encode a prompt on the OpenCUA / policy path.

    The text goes through the tokenizer's chat template. The images found in
    ``prompt_messages`` go through ``image_processor.preprocess``.

    Returns:
        ``(prompt_ids, pixel_values, grid_thws)``. The last two are None when no
        image decodes or the preprocessor output lacks the corresponding key.
    """
    prompt_ids = _as_id_list(
        tokenizer.apply_chat_template(
            prompt_messages,
            tokenize=True,
            add_generation_prompt=True,
        )
    )

    pixel_values = None
    grid_thws = None
    pil_images = collect_images_from_messages(prompt_messages)
    if pil_images:
        info = image_processor.preprocess(images=pil_images)

        pixel = info.get("pixel_values", None)
        if pixel is None:
            pixel = info.get("pixel_values_videos", None)

        # Image processors disagree on the grid key name; take the first present.
        grid = None
        for key in ("image_grid_thw", "images_grid_thw", "grid_thws"):
            grid = info.get(key, None)
            if grid is not None:
                break

        if pixel is not None:
            pixel_values = torch.as_tensor(pixel)
        if grid is not None:
            grid_thws = torch.as_tensor(grid, dtype=torch.long)
    return prompt_ids, pixel_values, grid_thws


def build_vlm_sft_dataset(
    data: List[Dict[str, Any]],
    tokenizer,
    image_processor,
    max_prompt_len: int,
    max_gen_length: int,
    is_qwen3_vl: bool = False,
    processor=None,
) -> VLMSFTDataset:
    """
    Build the in-memory SFT dataset from raw optimization records.

    For each record (keys ``prompt_messages`` and ``response``):

      input_ids = chat_template(prompt_messages, add_generation_prompt=True)
                  + tokenize(response, add_special_tokens=False)
      labels    = same shape as input_ids, with -100 on prompt and padding
                  positions and the token id on response positions.

    ``pixel_values`` / ``grid_thws`` come from the images in ``prompt_messages``.
    On the Qwen3-VL path they come from ``processor``; otherwise from
    ``image_processor``.

    Truncation:
      - the prompt is left-truncated to its last ``max_prompt_len`` tokens;
      - the response is right-truncated to ``max_gen_length`` tokens.
      A value <= 0 disables the corresponding limit.

    Records whose response yields no tokens are skipped. Every label would be
    -100, and a mean cross-entropy over zero targets is NaN.

    NOTE(review): no end-of-turn / EOS token is appended to the response. The
    model only learns to stop if the response text already contains one.

    Raises:
        ValueError: if ``is_qwen3_vl`` is set without ``processor``, or no
            record yields a usable sample.
    """
    if is_qwen3_vl and processor is None:
        raise ValueError("Qwen3-VL requires `processor`, but got None.")

    # pad_token
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            # NOTE(review): this can add a brand-new token id, and nothing in this
            # file resizes the model's embeddings to cover it. Confirm it is never
            # hit for the supported models.
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    pad_id = tokenizer.pad_token_id

    samples: List[Dict[str, Any]] = []
    max_len = 0
    n_no_response = 0
    n_truncated_with_images = 0

    for item in data:
        prompt_messages = item["prompt_messages"]
        # A null response is treated like an empty one (skipped below).
        response_text = item["response"] or ""

        # Response token ids (text only). Tokenized first so that unusable
        # records are skipped before the comparatively expensive image work.
        resp_ids = tokenizer(
            response_text,
            add_special_tokens=False,
        )["input_ids"]
        if max_gen_length > 0 and len(resp_ids) > max_gen_length:
            resp_ids = resp_ids[:max_gen_length]
        if not resp_ids:
            n_no_response += 1
            continue

        if is_qwen3_vl:
            prompt_ids, pixel_values, grid_thws = _encode_prompt_qwen3_vl(
                processor, prompt_messages
            )
        else:
            prompt_ids, pixel_values, grid_thws = _encode_prompt_default(
                tokenizer, image_processor, prompt_messages
            )

        if max_prompt_len > 0 and len(prompt_ids) > max_prompt_len:
            # Left truncation keeps the end of the prompt (generation prompt
            # included). It can, however, cut image placeholder tokens and desync
            # them from pixel_values, so such cases are counted and reported.
            if pixel_values is not None:
                n_truncated_with_images += 1
            prompt_ids = prompt_ids[-max_prompt_len:]

        input_ids = prompt_ids + resp_ids
        max_len = max(max_len, len(input_ids))
        samples.append(
            {
                "input_ids": input_ids,
                "start_pos": len(prompt_ids),
                "pixel_values": pixel_values,
                "grid_thws": grid_thws,
            }
        )

    # Plain print, consistent with the other data-loading warnings in this file.
    if n_no_response:
        print(f"[WARN] skipped {n_no_response} sample(s) with no response tokens", flush=True)
    if n_truncated_with_images:
        print(
            f"[WARN] {n_truncated_with_images} prompt(s) with images were left-truncated "
            f"to max_prompt_len={max_prompt_len}; image placeholder tokens may have been cut",
            flush=True,
        )

    if not samples:
        raise ValueError("No valid samples found in optimization data.")

    n_samples = len(samples)
    input_ids_tensor = torch.full((n_samples, max_len), pad_id, dtype=torch.long)
    # All -100 initially; only the response span of each row is filled in below.
    labels_tensor = torch.full((n_samples, max_len), -100, dtype=torch.long)

    for i, s in enumerate(samples):
        ids = torch.tensor(s["input_ids"], dtype=torch.long)
        length, start = ids.numel(), s["start_pos"]
        input_ids_tensor[i, :length] = ids
        labels_tensor[i, start:length] = ids[start:]

    return VLMSFTDataset(
        input_ids=input_ids_tensor,
        labels=labels_tensor,
        pixel_values_list=[s["pixel_values"] for s in samples],
        grid_thws_list=[s["grid_thws"] for s in samples],
    )


# ==================== Core train ====================

def make_attn_and_pos(input_ids: torch.Tensor, pad_id: int):
    """
    Derive an attention mask and position ids from (B, L) token ids.

    Returns:
        attention_mask: long tensor, 1 wherever the token differs from
            ``pad_id`` and 0 on padding.
        position_ids: running count of non-pad tokens minus one (0, 1, 2, ...
            over real tokens), forced to 0 at padding positions.
    """
    is_token = input_ids != pad_id
    attention_mask = is_token.to(torch.long)
    running = torch.cumsum(attention_mask, dim=1) - 1
    position_ids = running.masked_fill(~is_token, 0)
    return attention_mask, position_ids


def forward_sft_vlm(
    accelerator: Accelerator,
    model,
    batch: Dict[str, Any],
    tokenizer,
    is_qwen3_vl: bool,
):
    """
    Compute the SFT loss for one batch, running the model one sample at a time.

    Per-sample forwards let every sample carry its own ``pixel_values`` /
    ``grid_thws`` shapes. A sample's loss is the mean next-token cross-entropy
    over positions whose (shifted) label is not -100, i.e. the response tokens.
    A sample with no such position contributes 0 instead of NaN. The returned
    loss is the mean over samples.

    Args:
        accelerator: supplies the target device.
        model: the (possibly wrapped) VLM. It is called with keyword inputs and
            its output must expose ``.logits`` of shape (1, L, V).
        batch: a batch produced by ``sft_collate``.
        tokenizer: its ``pad_token_id`` marks padding for the attention mask.
        is_qwen3_vl: selects the Qwen3-VL input convention (``image_grid_thw``,
            positions computed by the model) over the OpenCUA one
            (``grid_thws``, explicit 1-D ``position_ids``).

    Returns:
        A scalar loss tensor.
    """
    device = accelerator.device

    input_ids_batch = batch["input_ids"].to(device)  # (B, L)
    labels_batch = batch["labels"].to(device)        # (B, L)
    pixel_values_list = batch["pixel_values"]
    grid_thws_list = batch["grid_thws"]

    pad_id = tokenizer.pad_token_id
    batch_size = input_ids_batch.size(0)

    total_loss = torch.tensor(0.0, device=device)

    for i in range(batch_size):
        input_ids = input_ids_batch[i : i + 1]  # (1, L)
        labels = labels_batch[i : i + 1]        # (1, L)

        attention_mask, position_ids = make_attn_and_pos(input_ids, pad_id)

        kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        if not is_qwen3_vl:
            kwargs["position_ids"] = position_ids
        # Qwen3-VL: position_ids are deliberately omitted. The model then builds
        # its multi-axis (M-RoPE) positions from input_ids + image_grid_thw +
        # attention_mask; explicit 1-D positions would bypass that computation.

        pv = pixel_values_list[i]
        if pv is not None:
            kwargs["pixel_values"] = pv.to(device)

        g = grid_thws_list[i]
        if g is not None:
            if is_qwen3_vl:
                kwargs["image_grid_thw"] = g.to(device)
            else:
                kwargs["grid_thws"] = g.to(device)

        out = model(**kwargs)
        logits = out.logits  # (1, L, V)

        # Token t predicts token t+1 (same shift as the GRPO version).
        logits_shifted = logits[:, :-1, :]   # (1, L-1, V)
        labels_shifted = labels[:, 1:]       # (1, L-1)

        # Sum then divide by the number of supervised targets. This equals the
        # mean reduction whenever targets exist, and yields 0 (not 0/0 = NaN)
        # while keeping the graph when every label is -100.
        loss_sum = F.cross_entropy(
            logits_shifted.reshape(-1, logits_shifted.size(-1)),
            labels_shifted.reshape(-1),
            ignore_index=-100,
            reduction="sum",
        )
        n_targets = (labels_shifted != -100).sum()
        loss = loss_sum / n_targets.clamp(min=1)

        total_loss = total_loss + loss

    return total_loss / max(batch_size, 1)


# ==================== Checkpoint ====================

def _copy_missing_files(src_dir: Path, dst_dir: Path, filenames: List[str]) -> None:
    """Copy each name in ``filenames`` from ``src_dir`` into ``dst_dir`` when it exists in the source and is not already present in the destination."""
    for fn in filenames:
        src = src_dir / fn
        dst = dst_dir / fn
        if src.exists() and not dst.exists():
            shutil.copy(src, dst)


def save_checkpoint(model, tokenizer, config, accelerator, name: str):
    """
    Save weights, tokenizer and auxiliary config files under
    ``<system.rl_base_dir>/<experiment.project>/ckpt/<name>``.

    ``accelerator.unwrap_model`` / ``accelerator.get_state_dict`` are called on
    every process; only the main process writes files. Call this from all
    ranks.

    Besides the weights (safetensors) and the tokenizer, it:
      - locates a local base model directory: the tokenizer's ``name_or_path``
        first, then ``model.policy_model`` / ``model.reward_model`` depending on
        ``training.target``;
      - for OpenCUA (TikTokenV3) tokenizers, copies the custom code and
        multimodal configs and rewrites ``tokenizer_config.json`` to point at
        ``tokenization_opencua.TikTokenV3``;
      - otherwise (e.g. Qwen3-VL), copies the processor-related configs;
      - records the save time and checkpoint name in ``ckpt/metadata.json``,
        which is shared by all checkpoints of the project.
    Files already present in the checkpoint directory are never overwritten by
    the copies.
    """
    project_name = config.experiment.project
    save_base = Path(config.system.rl_base_dir) / project_name / "ckpt"
    save_dir = save_base / name

    # unwrap & gather (on every rank)
    model_to_save = accelerator.unwrap_model(model)
    state_dict = accelerator.get_state_dict(model)

    if not accelerator.is_main_process:
        return

    save_dir.mkdir(parents=True, exist_ok=True)

    # 1) model weights
    model_to_save.save_pretrained(
        save_dir,
        save_function=accelerator.save,
        state_dict=state_dict,
        safe_serialization=True,
    )

    # 2) tokenizer (all model types)
    tokenizer.save_pretrained(str(save_dir))

    # ====== Locate the base model directory to copy extra files from ======
    init_kwargs = getattr(tokenizer, "init_kwargs", {})
    tok_name = (
        getattr(tokenizer, "name_or_path", "")
        or init_kwargs.get("_name_or_path", "")
    )

    target = getattr(config.training, "target", None)
    candidate_base = []
    if tok_name:
        candidate_base.append(Path(tok_name))
    if target == "policy":
        candidate_base.append(Path(config.model.policy_model))
    elif target == "reward":
        candidate_base.append(Path(config.model.reward_model))
    # First candidate that exists locally (hub ids are not local paths).
    base_dir = next((p for p in candidate_base if p.exists()), None)

    # ====== OpenCUA (TikTokenV3) detection ======
    tok_cfg_class = init_kwargs.get("tokenizer_class", "")
    is_opencua = (
        "TikTokenV3" in tokenizer.__class__.__name__
        or "TikTokenV3" in str(tok_cfg_class)
        or "OpenCUA" in str(tok_name)
    )

    if is_opencua and base_dir is not None:
        # Custom code and multimodal configs OpenCUA needs at load time.
        _copy_missing_files(
            base_dir,
            save_dir,
            [
                "tokenization_opencua.py",
                "configuration_opencua.py",
                "modeling_opencua.py",
                "tiktoken.model",
                "preprocessor_config.json",
                "image_processor_config.json",
                "processor_config.json",
                "generation_config.json",
            ],
        )

        # Point tokenizer_config.json at tokenization_opencua.TikTokenV3.
        tok_cfg_path = save_dir / "tokenizer_config.json"
        if tok_cfg_path.exists():
            with tok_cfg_path.open("r", encoding="utf-8") as f:
                cfg = json.load(f)
            cfg["tokenizer_class"] = "tokenization_opencua.TikTokenV3"
            auto_map = cfg.get("auto_map", {})
            auto_map["AutoTokenizer"] = ["tokenization_opencua.TikTokenV3", None]
            cfg["auto_map"] = auto_map
            with tok_cfg_path.open("w", encoding="utf-8") as f:
                json.dump(cfg, f, indent=2, ensure_ascii=False)

    # ====== Generic multimodal fallback (Qwen3-VL etc.) ======
    if base_dir is not None and not is_opencua:
        _copy_missing_files(
            base_dir,
            save_dir,
            [
                "preprocessor_config.json",
                "processor_config.json",
                "image_processor_config.json",
                # Also read by processor loading when present in the base model.
                "video_preprocessor_config.json",
                "chat_template.json",
            ],
        )

    # 3) metadata (shared file under ckpt/, so record which checkpoint it describes)
    metadata = {
        "save_time": time.strftime("%Y-%m-%d %H:%M:%S"),
        "checkpoint": name,
    }
    with (save_base / "metadata.json").open("w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    logger.info(f"Saved model + tokenizer to {save_dir}")


# ==================== Main ====================

def main():
    """
    Run SFT end to end: config -> data -> model -> training loop -> checkpoint.

    All settings come from ``get_config()``. The training records are the JSON
    list at ``<system.rl_base_dir>/data/<dataset>.json``. The final weights are
    written by ``save_checkpoint`` under ``config.model.optimized_name``.
    """
    config = get_config()

    project_name = config.experiment.project

    pretrained_model = config.model.pretrained_model
    optimized_name = config.model.optimized_name
    max_prompt_len = config.training.max_prompt_len
    max_gen_length = config.training.max_gen_length
    optimization_data = config.dataset
    update_per_step = config.training.update_per_step
    batch_size_lm = config.training.batch_size_lm
    gradient_checkpointing_enable = config.training.gradient_checkpointing_enable

    # TF32
    if config.training.enable_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    # ========== Load optimization data ==========
    opt_path = Path(config.system.rl_base_dir) / "data" / f"{optimization_data}.json"
    with opt_path.open("r", encoding="utf-8") as f:
        dataset_load = json.load(f)

    ws = int(os.environ.get("WORLD_SIZE", "1"))
    total_n = len(dataset_load)

    # Auto-derived gradient accumulation (formula kept from the original script).
    gradient_accumulation_steps = max(
        1,
        math.ceil(total_n / (update_per_step * batch_size_lm * ws)),
    )

    # policy uses OpenCUA, reward uses Qwen3-VL (same as the GRPO version)
    is_qwen3_vl = config.model.model_type == "qwen3vl"

    # ========== Tokenizer & Image Processor ==========
    if is_qwen3_vl:
        processor = AutoProcessor.from_pretrained(pretrained_model, trust_remote_code=True)
        tokenizer = processor.tokenizer
        image_processor = processor.image_processor
    else:
        processor = None
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model, trust_remote_code=True)
        image_processor = AutoImageProcessor.from_pretrained(pretrained_model, trust_remote_code=True)

    # ========== Build SFT dataset ==========
    dataset_lm = build_vlm_sft_dataset(
        dataset_load,
        tokenizer=tokenizer,
        image_processor=image_processor,
        max_prompt_len=max_prompt_len,
        max_gen_length=max_gen_length,
        is_qwen3_vl=is_qwen3_vl,
        processor=processor,
    )

    # ========== Accelerator ==========
    config.experiment.logging_dir = str(Path(project_name) / "logs")
    # NOTE(review): with log_with=None accelerate registers no tracker, so the
    # wandb init/log calls below appear to be no-ops. Confirm, and pass
    # log_with="wandb" if metrics should actually be logged.
    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=config.training.mixed_precision,
        log_with=None,
        project_dir=config.experiment.logging_dir,
        # Each dataloader batch of `batch_size_lm` samples is split across processes.
        split_batches=True,
    )
    assert ws == accelerator.num_processes

    # Logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        set_verbosity_info()
    else:
        set_verbosity_error()

    # Wandb
    if accelerator.is_main_process and wandb is not None:
        resume_wandb_run = config.wandb.resume
        run_id = config.wandb.get("run_id", None)
        if run_id is None:
            resume_wandb_run = False
            run_id = wandb.util.generate_id()
            config.wandb.run_id = run_id

        wandb_init_kwargs = dict(
            name=project_name,
            id=run_id,
            resume=resume_wandb_run,
            entity=config.wandb.get("entity", None),
            config_exclude_keys=[],
        )
        wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)}
        wandb_config.pop("experiment.resume_from_checkpoint", None)

        accelerator.init_trackers(
            project_name,
            config=wandb_config,
            init_kwargs={"wandb": wandb_init_kwargs},
        )

    if accelerator.is_main_process:
        os.makedirs(project_name, exist_ok=True)
        config_path = Path(project_name) / "config.yaml"
        logging.info(f"Saving config to {config_path}")
        OmegaConf.save(config, config_path)

    # Seed
    if config.training.seed is not None:
        set_seed(config.training.seed)

    # ========== Model & Optimizer ==========
    logger.info("Loading VLM model and optimizer")

    if is_qwen3_vl:
        from transformers import Qwen3VLForConditionalGeneration
        model = Qwen3VLForConditionalGeneration.from_pretrained(
            pretrained_model,
            trust_remote_code=True,
            torch_dtype="auto",
        )
    else:
        model = AutoModel.from_pretrained(
            pretrained_model,
            trust_remote_code=True,
            torch_dtype="auto",
        )

    def enable_gc(model):
        """
        Enable gradient checkpointing through whichever hook the model exposes.
        """
        # Turn off every visible use_cache first so it does not conflict with GC.
        if hasattr(model, "config") and hasattr(model.config, "use_cache"):
            model.config.use_cache = False
        if hasattr(model, "language_model") and hasattr(model.language_model, "config"):
            if hasattr(model.language_model.config, "use_cache"):
                model.language_model.config.use_cache = False

        # Case 1: wrapper models (OpenCUA-style) whose LM lives in .language_model
        if hasattr(model, "language_model") and hasattr(model.language_model, "gradient_checkpointing_enable"):
            model.language_model.gradient_checkpointing_enable()
            for m in model.language_model.modules():
                if hasattr(m, "gradient_checkpointing"):
                    m.gradient_checkpointing = True
            logger.info("[GC] Enabled on model.language_model (OpenCUA/Qwen-style)")
            return

        # Case 2: models that support GC directly
        if hasattr(model, "gradient_checkpointing_enable"):
            try:
                model.gradient_checkpointing_enable()
                for m in model.modules():
                    if hasattr(m, "gradient_checkpointing"):
                        m.gradient_checkpointing = True
                logger.info("[GC] Enabled on model directly")
                return
            except ValueError as e:
                logger.warning(f"[GC] Direct enable failed: {e}")

        # Case 3: fallback — flip the flag on any submodule that has one
        any_flag = False
        for m in model.modules():
            if hasattr(m, "gradient_checkpointing"):
                m.gradient_checkpointing = True
                any_flag = True
        if any_flag:
            logger.info("[GC] Enabled by setting gradient_checkpointing=True on submodules")
        else:
            logger.warning("[GC] No supported gradient checkpointing hooks found; running without GC.")

    if gradient_checkpointing_enable:
        enable_gc(model)
    else:
        model = model.to(accelerator.device)

    optimizer_config = config.optimizer.params

    no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if p.requires_grad and not any(nd in n for nd in no_decay)
            ],
            "weight_decay": optimizer_config.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if p.requires_grad and any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    if config.optimizer.name == "adamw":
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=optimizer_config.learning_rate,
            betas=(optimizer_config.beta1, optimizer_config.beta2),
            weight_decay=optimizer_config.weight_decay,
            eps=optimizer_config.epsilon,
        )
    else:
        raise ValueError(f"Optimizer {config.optimizer.name} not supported")

    train_dataloader_lm = DataLoader(
        dataset_lm,
        batch_size=batch_size_lm,
        sampler=None,
        collate_fn=sft_collate,
        num_workers=0,
    )

    # The LR schedule must span the optimizer updates actually performed.
    # With split_batches=True every process iterates the same number of
    # (global) batches as this unprepared dataloader. accelerator.accumulate()
    # steps once per `gradient_accumulation_steps` batches, plus once for a
    # trailing partial window at the end of each epoch.
    num_update_steps_per_epoch = max(
        1,
        math.ceil(len(train_dataloader_lm) / gradient_accumulation_steps),
    )
    num_train_epochs = config.training.num_train_epochs
    max_train_steps = num_update_steps_per_epoch * num_train_epochs

    lr_scheduler = get_scheduler(
        config.lr_scheduler.scheduler,
        optimizer=optimizer,
        num_training_steps=max_train_steps,
        num_warmup_steps=config.lr_scheduler.params.warmup_steps,
        min_lr_scale=config.lr_scheduler.params.min_lr_scale,
    )

    # Prepare with accelerator
    logger.info("Preparing model, optimizer, scheduler and dataloader")
    model, optimizer, lr_scheduler, train_dataloader_lm = accelerator.prepare(
        model, optimizer, lr_scheduler, train_dataloader_lm
    )

    # With split_batches=True, `batch_size_lm` is the global batch per dataloader step.
    per_device_batch_size = batch_size_lm // accelerator.num_processes
    total_batch_size_lm = batch_size_lm * gradient_accumulation_steps

    # ========== Training ==========
    logger.info("***** Running SFT training *****")
    logger.info(f"  Num training data = {len(dataset_lm)}")
    logger.info(f"  Num training steps = {max_train_steps}")
    logger.info(f"  Instantaneous batch size per device = {per_device_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size_lm}")
    logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")

    from tqdm.auto import tqdm

    global_step = 0

    for epoch in range(num_train_epochs):
        model.train()
        progress = tqdm(
            train_dataloader_lm,
            disable=not accelerator.is_local_main_process,
            dynamic_ncols=True,
        )
        # Per-batch losses accumulated since the last optimizer update (logging only).
        window_loss = torch.zeros((), device=accelerator.device)
        window_batches = 0

        for batch in progress:
            # accumulate() decides when gradients are synchronized: every
            # `gradient_accumulation_steps` batches and at the end of the
            # dataloader. On other batches the prepared optimizer / scheduler
            # calls below do not update parameters or the learning rate.
            with accelerator.accumulate(model):
                loss = forward_sft_vlm(
                    accelerator=accelerator,
                    model=model,
                    batch=batch,
                    tokenizer=tokenizer,
                    is_qwen3_vl=is_qwen3_vl,
                )
                # Accelerator.backward already scales the loss by
                # 1/gradient_accumulation_steps (DeepSpeed applies the same
                # scaling itself), so no manual division here.
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            window_loss += loss.detach()
            window_batches += 1

            if accelerator.sync_gradients:
                global_step += 1
                loss_val = float(accelerator.gather(window_loss / window_batches).mean().item())
                window_loss.zero_()
                window_batches = 0
                progress.set_postfix(loss=loss_val)

                if accelerator.is_main_process and wandb is not None:
                    accelerator.log({"train/loss": loss_val, "train/step": global_step})

    torch.cuda.empty_cache()
    accelerator.wait_for_everyone()

    # save final
    save_checkpoint(model, tokenizer, config, accelerator, optimized_name)

    accelerator.end_training()


if __name__ == "__main__":
    # Script entry point: run SFT training only when executed directly, not on import.
    main()
