# -*- coding: utf-8 -*-
"""
SFT training for OpenCUA / Qwen3-VL VLM policy model (lazy dataset version).

- 输入：policy_optimization_data.json / reward_optimization_data.json
  [
    {
      "domain": "...",
      "example": "...",
      "run_id": 0,
      "step": 0,
      "prompt_messages": [...],   # system/user 多模态，格式同 serve_opencua.py
      "response": "...",          # assistant 文本
      "reward": 0.0               # 在 SFT 中不用
    },
    ...
  ]

- 多模态处理与 serve_opencua / GRPO 版完全对齐：
  - prompt_messages -> tokenizer/processor.apply_chat_template(..., add_generation_prompt=True)
  - images in prompt_messages -> AutoImageProcessor.preprocess / Qwen3-VL processor 提供的接口

- SFT 风格：
  - 对 response tokens 做标准自回归交叉熵：
      - labels 在 prompt 部分设为 -100（不参与 loss）
      - labels 在 response 部分 = 对应 token id
      - 使用 ignore_index=-100 的 cross_entropy

- Lazy 特点：
  - 不预先把所有样本展开成 (N, L) 的大 tensor
  - 每次 __getitem__ 动态构造并 pad 到统一 max_seq_len
"""

import os
import sys
# Put the directory above this file's directory at the front of sys.path,
# presumably so the `train.*` imports below resolve when run as a script.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Environment flag read by the HF `tokenizers` library to allow parallelism.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

import io
import json
import math
import time
import shutil
import base64
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader

from PIL import Image, ImageFile
# Let PIL load truncated image files instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True

from omegaconf import OmegaConf

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed

from transformers import (
    AutoTokenizer,
    AutoImageProcessor,
    AutoModel,
    AutoProcessor,
)

from train.utils import get_config, flatten_omega_conf, AverageMeter
from train.lr_schedulers import get_scheduler
from train.log_utils import set_verbosity_info, set_verbosity_error

# wandb is optional; when it is missing `wandb is None` and the wandb-specific
# branches in main() are skipped.
try:
    import wandb
except ImportError:
    wandb = None

# accelerate's multi-process-aware logger.
# NOTE: accelerate raises RuntimeError if this logger is used before an
# Accelerator() / PartialState() has been created.
logger = get_logger(__name__, log_level="INFO")


# ==================== Utils for images / prompts ====================

def _decode_image(src: str) -> Image.Image:
    """
    Decode an image reference into an RGB PIL image.

    Accepted forms:
      - ``http://`` / ``https://`` URLs (fetched with ``requests``, 10 s timeout);
      - ``file://`` URIs and plain local filesystem paths;
      - ``data:<mime>;base64,<payload>`` URIs or bare base64 strings (used for
        anything that is not an existing local path).

    Raises whatever ``requests``, ``base64`` or PIL raise on bad input.
    """
    s = src.strip()
    if s.startswith(("http://", "https://")):
        # Training data is normally local; remote URLs are supported as a convenience.
        import requests
        resp = requests.get(s, timeout=10)
        resp.raise_for_status()
        return Image.open(io.BytesIO(resp.content)).convert("RGB")

    # Qwen-style messages often reference local files as file:// URIs.
    if s.startswith("file://"):
        s = s[len("file://"):]

    if os.path.exists(s):
        # Context manager closes the file handle; convert() returns a loaded copy.
        with Image.open(s) as img:
            return img.convert("RGB")

    # Not a local path: treat as a data URI (strip the "data:...;base64," prefix) or raw base64.
    b64 = s.split(",", 1)[1] if "," in s else s
    raw = base64.b64decode(b64, validate=False)
    return Image.open(io.BytesIO(raw)).convert("RGB")


def collect_images_from_messages(messages: List[Dict[str, Any]]) -> List[Image.Image]:
    """
    Decode every image referenced in chat ``messages``, in order of appearance.

    Only list-valued ``content`` is scanned. Within it:
      - ``{"type": "image", "image": <src>}`` parts are decoded from ``<src>``;
      - ``{"type": "image_url", "image_url": {"url": <src>}}`` parts (or a
        top-level ``"url"`` key) are decoded from ``<src>``;
      - plain-string parts and any other part types are ignored.

    Decoding failures are printed as warnings and the image is skipped
    (best effort). NOTE(review): a skipped image can leave the prompt's image
    placeholders out of sync with the returned list.
    """
    pil_images: List[Image.Image] = []
    for m in messages:
        content = m.get("content")
        if not isinstance(content, list):
            continue
        for part in content:
            if not isinstance(part, dict):
                # Plain-string parts are text; calling .get() on them would raise.
                continue
            p_type = part.get("type")
            if p_type == "image":
                src = part.get("image")
            elif p_type == "image_url":
                # Same lookup order as the Qwen3-VL message normalisation.
                src = (part.get("image_url") or {}).get("url") or part.get("url")
            else:
                continue
            if not src:
                continue
            try:
                pil_images.append(_decode_image(src))
            except Exception as e:
                print(f"[WARN] decode {p_type} failed: {src} | {e}", flush=True)
    return pil_images


# ==================== Lazy Dataset ====================

class VLMSFTLazyDataset(Dataset):
    """
    Lazily built SFT dataset for multimodal (VLM) models.

    Nothing is pre-tokenized into one big (N, L) tensor. Each ``__getitem__``
    call runs the chat template / image processing for one record and
    right-pads the result to ``self.max_seq_len``.

    Each item is a tuple ``(input_ids, labels, pixel_values, grid_thws)``:
      - ``input_ids``: 1-D LongTensor of length ``max_seq_len`` (padded with ``pad_id``);
      - ``labels``: 1-D LongTensor of the same length, -100 on prompt and padding
        positions and the token id on response positions;
      - ``pixel_values`` / ``grid_thws``: tensors from the processor, or None
        when no image data was produced.
    """

    def __init__(
        self,
        raw_data: List[Dict[str, Any]],
        tokenizer,
        image_processor,
        max_prompt_len: int,
        max_gen_length: int,
        is_qwen3_vl: bool = False,
        processor=None,
        fallback_max_seq_len: int = 4096,
    ):
        self.raw_data = raw_data
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.is_qwen3_vl = is_qwen3_vl
        self.processor = processor

        # Ensure a pad token exists: reuse EOS when available, otherwise add one.
        if tokenizer.pad_token_id is None:
            if tokenizer.eos_token_id is not None:
                tokenizer.pad_token = tokenizer.eos_token
            else:
                # NOTE(review): this grows the vocabulary, but the model's embedding
                # matrix is not resized here; confirm the new id is in range.
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        self.pad_id = tokenizer.pad_token_id

        self.max_prompt_len = max_prompt_len
        self.max_gen_length = max_gen_length

        # Every sample is padded to this common length so the collate fn can stack.
        if max_prompt_len > 0 and max_gen_length > 0:
            self.max_seq_len = max_prompt_len + max_gen_length
        else:
            # Conservative default; could instead come from config.model.max_position_embeddings.
            self.max_seq_len = fallback_max_seq_len
            # Stdlib logger on purpose: accelerate's get_logger() raises RuntimeError
            # when used before Accelerator()/PartialState() exists, and this
            # dataset may be constructed before the Accelerator.
            logging.getLogger(__name__).warning(
                f"[VLMSFTLazyDataset] max_prompt_len/max_gen_length <= 0, "
                f"fallback max_seq_len = {self.max_seq_len}. "
                f"建议在 config 里显式设置这两个值。"
            )

    def __len__(self) -> int:
        return len(self.raw_data)

    # ---------- Qwen3-VL prompt handling ----------

    @staticmethod
    def _to_qwen_part(part: Any) -> Optional[Dict[str, Any]]:
        """Convert one content part to Qwen3-VL format; returns None to drop the part."""
        if isinstance(part, str):
            return {"type": "text", "text": part}
        if not isinstance(part, dict):
            return {"type": "text", "text": str(part)}
        p_type = part.get("type")
        if p_type in ("text", "paragraph"):
            return {"type": "text", "text": part.get("text", "")}
        if p_type in ("image", "video"):
            return part
        if p_type == "image_url":
            url = (part.get("image_url") or {}).get("url") or part.get("url")
            return {"type": "image", "image": url} if url else None
        # Unknown dict part: keep its "text" field if present, else its repr.
        txt = part.get("text", None)
        return {"type": "text", "text": txt if txt is not None else str(part)}

    @classmethod
    def _to_qwen_messages(cls, prompt_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Normalise chat messages so every ``content`` is a list of typed parts.

        String content becomes one text part. Missing or ``None`` content becomes
        one empty text part. List content is converted part by part. Any other
        content type is stringified into a text part.
        """
        qwen_messages = []
        for m in prompt_messages:
            content = m.get("content", "")
            if content is None:
                # Avoid injecting the literal text "None" into the prompt.
                content = ""
            if isinstance(content, str):
                parts = [{"type": "text", "text": content}]
            elif isinstance(content, list):
                parts = []
                for part in content:
                    converted = cls._to_qwen_part(part)
                    if converted is not None:
                        parts.append(converted)
            else:
                parts = [{"type": "text", "text": str(content)}]
            qwen_messages.append({"role": m.get("role", "user"), "content": parts})
        return qwen_messages

    def _encode_qwen3_vl_prompt(self, prompt_messages: List[Dict[str, Any]]):
        """
        Tokenize the prompt with the Qwen3-VL processor, so image placeholder tokens
        and ``pixel_values`` / ``image_grid_thw`` come from the same call.

        Returns ``(prompt_ids, pixel_values, grid_thws)``. The last two are None when
        the processor output has no such entries.
        """
        processor = self.processor
        if processor is None:
            raise ValueError("Qwen3-VL requires `processor`, but got None.")

        template_kwargs: Dict[str, Any] = dict(
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )
        if self.max_prompt_len > 0:
            # Only truncate when a positive limit is configured (a limit <= 0 would
            # truncate the prompt to nothing).
            # NOTE(review): truncation can drop image placeholder tokens or the
            # generation prompt; prompts are expected to fit in max_prompt_len.
            template_kwargs.update(truncation=True, max_length=self.max_prompt_len)

        proc_inputs = processor.apply_chat_template(
            self._to_qwen_messages(prompt_messages), **template_kwargs
        )

        prompt_ids = proc_inputs["input_ids"][0].tolist()

        pixel_values = proc_inputs.get("pixel_values", None)
        if pixel_values is not None:
            pixel_values = pixel_values.clone()

        grid = proc_inputs.get("image_grid_thw", None)
        grid_thws = grid.clone().long() if grid is not None else None
        return prompt_ids, pixel_values, grid_thws

    # ---------- OpenCUA / generic prompt handling ----------

    def _encode_generic_prompt(self, prompt_messages: List[Dict[str, Any]]):
        """
        Tokenize the prompt with ``tokenizer.apply_chat_template``. Images found in
        the messages are run through ``image_processor.preprocess``.

        Returns ``(prompt_ids, pixel_values, grid_thws)``. The prompt is left-truncated
        to the last ``max_prompt_len`` tokens when a positive limit is set.
        NOTE(review): left-truncation can cut image placeholder tokens while
        ``pixel_values`` still holds every image.
        """
        prompt_ids = self.tokenizer.apply_chat_template(
            prompt_messages,
            tokenize=True,
            add_generation_prompt=True,
        )

        pixel_values = None
        grid_thws = None
        pil_images = collect_images_from_messages(prompt_messages)
        if pil_images:
            info = self.image_processor.preprocess(images=pil_images)

            pixel = info.get("pixel_values", None)
            if pixel is None:
                pixel = info.get("pixel_values_videos", None)

            # Image processors differ in the key they use for the grid sizes.
            grid = None
            for key in ("image_grid_thw", "images_grid_thw", "grid_thws"):
                grid = info.get(key, None)
                if grid is not None:
                    break

            if pixel is not None:
                pixel_values = torch.as_tensor(pixel)
            if grid is not None:
                grid_thws = torch.as_tensor(grid, dtype=torch.long)

        if self.max_prompt_len > 0 and len(prompt_ids) > self.max_prompt_len:
            prompt_ids = prompt_ids[-self.max_prompt_len:]
        return prompt_ids, pixel_values, grid_thws

    # ---------- Padding / labels ----------

    def _pad_with_labels(
        self, prompt_ids: List[int], resp_ids: List[int]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Concatenate prompt + response, right-pad to ``max_seq_len`` and build labels
        (-100 everywhere except response positions).
        """
        input_ids = prompt_ids + resp_ids
        start_pos = len(prompt_ids)
        max_len = self.max_seq_len

        overflow = len(input_ids) - max_len
        if overflow > 0:
            # Should not happen when both limits are set (they sum to max_len); as a
            # safety net drop tokens from the left and shift the response start.
            input_ids = input_ids[overflow:]
            start_pos = max(0, start_pos - overflow)
        n = len(input_ids)

        input_ids_tensor = torch.full((max_len,), self.pad_id, dtype=torch.long)
        labels_tensor = torch.full((max_len,), -100, dtype=torch.long)
        if n > 0:
            ids_tensor = torch.tensor(input_ids, dtype=torch.long)
            input_ids_tensor[:n] = ids_tensor
            if n > start_pos:
                labels_tensor[start_pos:n] = ids_tensor[start_pos:n]
        return input_ids_tensor, labels_tensor

    def _build_one_sample(
        self,
        item: Dict[str, Any],
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Build ``(input_ids, labels, pixel_values, grid_thws)`` for one raw record."""
        prompt_messages = item["prompt_messages"]
        response_text = item["response"]

        if self.is_qwen3_vl:
            prompt_ids, pixel_values, grid_thws = self._encode_qwen3_vl_prompt(prompt_messages)
        else:
            prompt_ids, pixel_values, grid_thws = self._encode_generic_prompt(prompt_messages)

        # Response tokens are plain text: no special tokens are added.
        # NOTE(review): no EOS / end-of-turn token is appended; confirm whether
        # `response` already contains one, otherwise the model never learns to stop.
        resp_ids = self.tokenizer(
            response_text,
            add_special_tokens=False,
        )["input_ids"]
        if self.max_gen_length > 0 and len(resp_ids) > self.max_gen_length:
            resp_ids = resp_ids[:self.max_gen_length]

        input_ids_tensor, labels_tensor = self._pad_with_labels(list(prompt_ids), list(resp_ids))
        return input_ids_tensor, labels_tensor, pixel_values, grid_thws

    def __getitem__(self, idx: int):
        return self._build_one_sample(self.raw_data[idx])


# ==================== Collate ====================

def sft_collate(batch):
    """
    Collate samples produced by ``VLMSFTLazyDataset``.

    Every sample's ``input_ids`` / ``labels`` already share the same length, so
    they are stacked into (B, L) tensors. Image tensors can differ in shape between
    samples, so ``pixel_values`` and ``grid_thws`` stay as per-sample Python lists
    (entries may be None).
    """
    ids_seq, label_seq, pixel_seq, grid_seq = zip(*batch)

    collated = {}
    collated["input_ids"] = torch.stack(ids_seq)      # (B, L)
    collated["labels"] = torch.stack(label_seq)       # (B, L)
    collated["pixel_values"] = list(pixel_seq)        # length B: Tensor or None
    collated["grid_thws"] = list(grid_seq)            # length B: Tensor or None
    return collated


# ==================== Core train ====================

def make_attn_and_pos(input_ids: torch.Tensor, pad_id: int):
    """
    Build an attention mask and position ids from padded ``input_ids``.

    Returns:
        attention_mask: long tensor, 1 where ``input_ids != pad_id`` and 0 elsewhere.
        position_ids:   running count of non-pad tokens along dim 1, minus one;
                        pad positions are set to 0.

    NOTE: if ``pad_id`` also appears as a real token (e.g. pad == eos), those
    positions are masked out as well.
    """
    attention_mask = (input_ids != pad_id).to(torch.long)
    position_ids = attention_mask.cumsum(dim=1) - 1
    position_ids.masked_fill_(attention_mask == 0, 0)
    return attention_mask, position_ids


def forward_sft_vlm(
    accelerator: "Accelerator",
    model,
    batch: Dict[str, Any],
    tokenizer,
    is_qwen3_vl: bool,
):
    """
    Compute the SFT loss for one batch.

    Samples are forwarded one at a time, because ``pixel_values`` / grid tensors can
    differ in shape between samples. For each sample the loss is the mean
    next-token cross-entropy over positions whose label is not -100 (the response
    tokens). The returned loss is the average of the per-sample losses over the
    batch size.

    A sample with no supervised tokens (e.g. an empty response) contributes 0
    instead of NaN, and stays attached to the autograd graph.

    For Qwen3-VL no ``position_ids`` are passed, so the model builds its own
    multimodal rotary positions from ``input_ids``, ``image_grid_thw`` and the
    attention mask. Passing plain 1-D positions would overwrite them.
    """
    device = accelerator.device

    input_ids_batch = batch["input_ids"].to(device)  # (B, L)
    labels_batch = batch["labels"].to(device)        # (B, L)
    pixel_values_list = batch["pixel_values"]
    grid_thws_list = batch["grid_thws"]

    pad_id = tokenizer.pad_token_id
    num_samples = input_ids_batch.size(0)

    total_loss = torch.tensor(0.0, device=device)

    for i in range(num_samples):
        input_ids = input_ids_batch[i: i + 1]  # (1, L)
        labels = labels_batch[i: i + 1]        # (1, L)

        attention_mask, position_ids = make_attn_and_pos(input_ids, pad_id)

        kwargs: Dict[str, Any] = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        if not is_qwen3_vl:
            kwargs["position_ids"] = position_ids

        pv = pixel_values_list[i]
        if pv is not None:
            kwargs["pixel_values"] = pv.to(device)

        g = grid_thws_list[i]
        if g is not None:
            # The two model families use different keyword names for the grid.
            if is_qwen3_vl:
                kwargs["image_grid_thw"] = g.to(device)
            else:
                kwargs["grid_thws"] = g.to(device)

        logits = model(**kwargs).logits  # (1, L, V)

        # Standard causal shift: the logit at t predicts the token at t + 1.
        logits_shifted = logits[:, :-1, :]   # (1, L-1, V)
        labels_shifted = labels[:, 1:]       # (1, L-1)
        flat_logits = logits_shifted.reshape(-1, logits_shifted.size(-1))
        flat_labels = labels_shifted.reshape(-1)

        if bool((flat_labels != -100).any()):
            loss = F.cross_entropy(flat_logits, flat_labels, ignore_index=-100)
        else:
            # Mean over zero targets is 0/0 = NaN. Summing over zero targets gives 0
            # while keeping a grad_fn, so backward() still works.
            loss = F.cross_entropy(flat_logits, flat_labels, ignore_index=-100, reduction="sum")

        total_loss = total_loss + loss

    return total_loss / max(num_samples, 1)


# ==================== Checkpoint ====================

# Files an OpenCUA checkpoint needs next to the weights: custom code, the tiktoken
# vocabulary, and multimodal / generation configs.
_OPENCUA_EXTRA_FILES = (
    "tokenization_opencua.py",
    "configuration_opencua.py",
    "modeling_opencua.py",
    "tiktoken.model",
    "preprocessor_config.json",
    "image_processor_config.json",
    "processor_config.json",
    "generation_config.json",
)

# Processor / chat-template files copied for other multimodal models (e.g. Qwen3-VL),
# so the saved checkpoint can be reloaded with AutoProcessor.
_GENERIC_MM_EXTRA_FILES = (
    "preprocessor_config.json",
    "processor_config.json",
    "image_processor_config.json",
    "video_preprocessor_config.json",
    "chat_template.json",
    "chat_template.jinja",
)


def _copy_missing_files(base_dir: Path, save_dir: Path, filenames) -> None:
    """Copy each file in *filenames* from *base_dir* to *save_dir* if it exists there and not yet here."""
    for fn in filenames:
        src = base_dir / fn
        dst = save_dir / fn
        if src.exists() and not dst.exists():
            shutil.copy(src, dst)


def _tokenizer_name(tokenizer) -> str:
    """Return the tokenizer's ``name_or_path`` (falling back to ``init_kwargs['_name_or_path']``)."""
    init_kwargs = getattr(tokenizer, "init_kwargs", {})
    return getattr(tokenizer, "name_or_path", "") or init_kwargs.get("_name_or_path", "")


def _resolve_base_dir(tokenizer, config) -> Optional[Path]:
    """
    Pick the first existing local directory among the tokenizer's name/path and
    the policy/reward model path selected by ``config.training.target``.
    """
    candidates = []
    tok_name = _tokenizer_name(tokenizer)
    if tok_name:
        candidates.append(Path(tok_name))
    target = getattr(config.training, "target", None)
    if target == "policy":
        candidates.append(Path(config.model.policy_model))
    elif target == "reward":
        candidates.append(Path(config.model.reward_model))
    return next((p for p in candidates if p.exists()), None)


def _is_opencua_tokenizer(tokenizer) -> bool:
    """Heuristic: the tokenizer class, its configured class, or its path mentions TikTokenV3 / OpenCUA."""
    init_kwargs = getattr(tokenizer, "init_kwargs", {})
    return (
        "TikTokenV3" in tokenizer.__class__.__name__
        or "TikTokenV3" in str(init_kwargs.get("tokenizer_class", ""))
        or "OpenCUA" in str(_tokenizer_name(tokenizer))
    )


def _patch_opencua_tokenizer_config(save_dir: Path) -> None:
    """Point ``tokenizer_config.json`` at ``tokenization_opencua.TikTokenV3`` so AutoTokenizer can load it."""
    tok_cfg_path = save_dir / "tokenizer_config.json"
    if not tok_cfg_path.exists():
        return
    with tok_cfg_path.open("r", encoding="utf-8") as f:
        cfg = json.load(f)
    cfg["tokenizer_class"] = "tokenization_opencua.TikTokenV3"
    auto_map = cfg.get("auto_map", {})
    auto_map["AutoTokenizer"] = ["tokenization_opencua.TikTokenV3", None]
    cfg["auto_map"] = auto_map
    with tok_cfg_path.open("w", encoding="utf-8") as f:
        json.dump(cfg, f, indent=2, ensure_ascii=False)


def save_checkpoint(model, tokenizer, config, accelerator, name: str):
    """
    Save model weights, the tokenizer and auxiliary model files to
    ``<rl_base_dir>/<project>/ckpt/<name>``.

    Call this on every rank: ``accelerator.get_state_dict`` runs outside the
    main-process guard (it may need to gather parameters across ranks). Only the
    main process writes files.

    Besides the weights (safetensors) and tokenizer, files missing from the output
    are copied from the original model directory:
      - OpenCUA tokenizers: custom code, tiktoken vocab and configs, plus a patched
        ``tokenizer_config.json``;
      - other models: processor / chat-template configs.
    ``<rl_base_dir>/<project>/ckpt/metadata.json`` is (over)written with the save
    time and checkpoint name.
    """
    project_name = config.experiment.project
    save_base = Path(config.system.rl_base_dir) / project_name / "ckpt"
    save_dir = save_base / name

    # Unwrap & gather on all ranks.
    model_to_save = accelerator.unwrap_model(model)
    state_dict = accelerator.get_state_dict(model)

    if not accelerator.is_main_process:
        return

    save_dir.mkdir(parents=True, exist_ok=True)

    # 1) model weights
    model_to_save.save_pretrained(
        save_dir,
        save_function=accelerator.save,
        state_dict=state_dict,
        safe_serialization=True,
    )

    # 2) tokenizer (all models)
    tokenizer.save_pretrained(str(save_dir))

    # 3) auxiliary files from the original model directory
    base_dir = _resolve_base_dir(tokenizer, config)
    if base_dir is not None:
        if _is_opencua_tokenizer(tokenizer):
            _copy_missing_files(base_dir, save_dir, _OPENCUA_EXTRA_FILES)
            _patch_opencua_tokenizer_config(save_dir)
        else:
            _copy_missing_files(base_dir, save_dir, _GENERIC_MM_EXTRA_FILES)

    # 4) metadata (shared file; records which checkpoint was written last)
    metadata = {
        "save_time": time.strftime("%Y-%m-%d %H:%M:%S"),
        "checkpoint": name,
    }
    with (save_base / "metadata.json").open("w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)

    logger.info(f"Saved model + tokenizer to {save_dir}")


# ==================== Main ====================

def _enable_gradient_checkpointing(model_) -> None:
    """
    Turn on gradient checkpointing for ``model_`` using the first strategy that applies.

    Any visible ``use_cache`` flag (on ``model_.config`` and
    ``model_.language_model.config``) is first set to False, so it does not
    conflict with checkpointing. Then, in order:
      1. if ``model_.language_model`` has ``gradient_checkpointing_enable`` (wrapper
         models such as OpenCUA), enable it there only;
      2. else if ``model_`` has ``gradient_checkpointing_enable``, call it (a
         ValueError is logged and we fall through);
      3. else set ``gradient_checkpointing = True`` on every submodule that has the
         attribute, warning if there are none.
    """
    if hasattr(model_, "config") and hasattr(model_.config, "use_cache"):
        model_.config.use_cache = False
    if hasattr(model_, "language_model") and hasattr(model_.language_model, "config"):
        if hasattr(model_.language_model.config, "use_cache"):
            model_.language_model.config.use_cache = False

    # Case 1: wrapper model whose real LM lives in `language_model`.
    if hasattr(model_, "language_model") and hasattr(model_.language_model, "gradient_checkpointing_enable"):
        model_.language_model.gradient_checkpointing_enable()
        for m in model_.language_model.modules():
            if hasattr(m, "gradient_checkpointing"):
                m.gradient_checkpointing = True
        logger.info("[GC] Enabled on model.language_model (OpenCUA/Qwen-style)")
        return

    # Case 2: model supports gradient checkpointing directly.
    if hasattr(model_, "gradient_checkpointing_enable"):
        try:
            model_.gradient_checkpointing_enable()
            for m in model_.modules():
                if hasattr(m, "gradient_checkpointing"):
                    m.gradient_checkpointing = True
            logger.info("[GC] Enabled on model directly")
            return
        except ValueError as e:
            logger.warning(f"[GC] Direct enable failed: {e}")

    # Case 3: fallback, flip the flag on any submodule that exposes it.
    any_flag = False
    for m in model_.modules():
        if hasattr(m, "gradient_checkpointing"):
            m.gradient_checkpointing = True
            any_flag = True
    if any_flag:
        logger.info("[GC] Enabled by setting gradient_checkpointing=True on submodules")
    else:
        logger.warning("[GC] No supported gradient checkpointing hooks found; running without GC.")


def _build_param_groups(model, weight_decay: float) -> List[Dict[str, Any]]:
    """
    Split trainable parameters into a weight-decay group and a no-decay group
    (biases, layer-norm weights, embeddings), keeping ``named_parameters`` order.
    """
    no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"]
    decay_params, no_decay_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if any(nd in n for nd in no_decay):
            no_decay_params.append(p)
        else:
            decay_params.append(p)
    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]


def main():
    """
    Run multimodal SFT as described by the project config.

    Steps: read the config, load the JSON optimisation data, build the
    tokenizer/processors, create the Accelerator, build the lazy dataset, load the
    model and optimizer, train for ``num_train_epochs``, then save the final
    checkpoint as ``config.model.optimized_name``.
    """
    config = get_config()

    project_name = config.experiment.project

    pretrained_model = config.model.pretrained_model
    optimized_name = config.model.optimized_name
    max_prompt_len = config.training.max_prompt_len
    max_gen_length = config.training.max_gen_length
    optimization_data = config.dataset
    update_per_step = config.training.update_per_step
    batch_size_lm = config.training.batch_size_lm
    gradient_checkpointing_enable = config.training.gradient_checkpointing_enable

    # TF32
    if config.training.enable_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    # ========== Load optimization data ==========
    opt_path = Path(config.system.rl_base_dir) / "data" / f"{optimization_data}.json"
    with opt_path.open("r", encoding="utf-8") as f:
        dataset_load = json.load(f)

    ws_env = int(os.environ.get("WORLD_SIZE", "1"))
    total_n = len(dataset_load)

    # Gradient accumulation derived from the requested number of updates (original formula).
    # NOTE(review): this treats batch_size_lm as a per-device batch, but the
    # Accelerator below uses split_batches=True (batch_size_lm is the global batch
    # per step). Confirm which interpretation is intended.
    gradient_accumulation_steps = max(
        1,
        math.ceil(total_n / (update_per_step * batch_size_lm * ws_env)),
    )

    # policy uses OpenCUA, reward uses Qwen3-VL (same as the GRPO version)
    is_qwen3_vl = config.model.model_type == "qwen3vl"

    # ========== Tokenizer & Image Processor ==========
    if is_qwen3_vl:
        processor = AutoProcessor.from_pretrained(pretrained_model, trust_remote_code=True)
        tokenizer = processor.tokenizer
        image_processor = processor.image_processor
    else:
        processor = None
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model, trust_remote_code=True)
        image_processor = AutoImageProcessor.from_pretrained(pretrained_model, trust_remote_code=True)

    # ========== Accelerator ==========
    # Created before the dataset: the dataset may log, and accelerate's get_logger()
    # raises RuntimeError before an Accelerator exists.
    config.experiment.logging_dir = str(Path(project_name) / "logs")
    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=config.training.mixed_precision,
        # NOTE(review): with log_with=None no tracker is registered, so the
        # init_trackers()/log() calls below are no-ops even when wandb is installed.
        log_with=None,
        project_dir=config.experiment.logging_dir,
        split_batches=True,
    )
    assert ws_env == accelerator.num_processes

    # Logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        set_verbosity_info()
    else:
        set_verbosity_error()

    # Wandb
    if accelerator.is_main_process and wandb is not None:
        resume_wandb_run = config.wandb.resume
        run_id = config.wandb.get("run_id", None)
        if run_id is None:
            resume_wandb_run = False
            run_id = wandb.util.generate_id()
            config.wandb.run_id = run_id

        wandb_init_kwargs = dict(
            name=project_name,
            id=run_id,
            resume=resume_wandb_run,
            entity=config.wandb.get("entity", None),
            config_exclude_keys=[],
        )
        wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)}
        wandb_config.pop("experiment.resume_from_checkpoint", None)

        accelerator.init_trackers(
            project_name,
            config=wandb_config,
            init_kwargs={"wandb": wandb_init_kwargs},
        )

    if accelerator.is_main_process:
        os.makedirs(project_name, exist_ok=True)
        config_path = Path(project_name) / "config.yaml"
        logging.info(f"Saving config to {config_path}")
        OmegaConf.save(config, config_path)

    # Seed
    if config.training.seed is not None:
        set_seed(config.training.seed)

    # ========== Build SFT dataset (lazy) ==========
    dataset_lm = VLMSFTLazyDataset(
        raw_data=dataset_load,
        tokenizer=tokenizer,
        image_processor=image_processor,
        max_prompt_len=max_prompt_len,
        max_gen_length=max_gen_length,
        is_qwen3_vl=is_qwen3_vl,
        processor=processor,
    )

    # ========== Model & Optimizer ==========
    logger.info("Loading VLM model and optimizer")

    if is_qwen3_vl:
        from transformers import Qwen3VLForConditionalGeneration
        model = Qwen3VLForConditionalGeneration.from_pretrained(
            pretrained_model,
            trust_remote_code=True,
            torch_dtype="auto",
        )
    else:
        model = AutoModel.from_pretrained(
            pretrained_model,
            trust_remote_code=True,
            torch_dtype="auto",
        )

    if gradient_checkpointing_enable:
        _enable_gradient_checkpointing(model)
    else:
        model = model.to(accelerator.device)

    optimizer_config = config.optimizer.params
    if config.optimizer.name == "adamw":
        optimizer = AdamW(
            _build_param_groups(model, optimizer_config.weight_decay),
            lr=optimizer_config.learning_rate,
            betas=(optimizer_config.beta1, optimizer_config.beta2),
            weight_decay=optimizer_config.weight_decay,
            eps=optimizer_config.epsilon,
        )
    else:
        raise ValueError(f"Optimizer {config.optimizer.name} not supported")

    train_dataloader_lm = DataLoader(
        dataset_lm,
        batch_size=batch_size_lm,
        sampler=None,
        collate_fn=sft_collate,
        num_workers=0,
    )

    # With split_batches=True each global batch of batch_size_lm samples is split
    # across processes, so every rank iterates len(train_dataloader_lm) times per
    # epoch, and the prepared scheduler advances once per optimizer step.
    num_update_steps_per_epoch = max(
        1,
        math.ceil(len(train_dataloader_lm) / gradient_accumulation_steps),
    )
    num_train_epochs = config.training.num_train_epochs
    max_train_steps = num_update_steps_per_epoch * num_train_epochs
    total_batch_size_lm = batch_size_lm * gradient_accumulation_steps

    lr_scheduler = get_scheduler(
        config.lr_scheduler.scheduler,
        optimizer=optimizer,
        num_training_steps=max_train_steps,
        num_warmup_steps=config.lr_scheduler.params.warmup_steps,
        min_lr_scale=config.lr_scheduler.params.min_lr_scale,
    )

    # Prepare with accelerator
    logger.info("Preparing model, optimizer, scheduler and dataloader")
    model, optimizer, lr_scheduler, train_dataloader_lm = accelerator.prepare(
        model, optimizer, lr_scheduler, train_dataloader_lm
    )

    # ========== Training ==========
    logger.info("***** Running SFT training *****")
    logger.info(f"  Num training data (total) = {len(dataset_lm)}")
    logger.info(f"  Approx num data for this rank = {math.ceil(len(dataset_lm) / accelerator.num_processes)}")
    logger.info(f"  Num training steps = {max_train_steps}")
    logger.info(f"  Instantaneous batch size per device = {batch_size_lm // accelerator.num_processes}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size_lm}")
    logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")

    from tqdm.auto import tqdm

    global_step = 0
    num_batches_per_epoch = len(train_dataloader_lm)

    for epoch in range(num_train_epochs):
        model.train()
        progress = tqdm(
            train_dataloader_lm,
            disable=not accelerator.is_local_main_process,
            dynamic_ncols=True,
        )

        for step, batch in enumerate(progress, start=1):
            loss = forward_sft_vlm(
                accelerator=accelerator,
                model=model,
                batch=batch,
                tokenizer=tokenizer,
                is_qwen3_vl=is_qwen3_vl,
            )

            # accelerator.backward() already scales the loss by
            # 1 / gradient_accumulation_steps (DeepSpeed scales internally), so no
            # manual division here.
            accelerator.backward(loss)

            # Also step on the last batch, so a trailing partial accumulation is not
            # carried into the next epoch or dropped.
            if step % gradient_accumulation_steps == 0 or step == num_batches_per_epoch:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

                global_step += 1
                loss_val = float(accelerator.gather(loss.detach()).mean().item())
                progress.set_postfix(loss=loss_val)

                if accelerator.is_main_process and wandb is not None:
                    accelerator.log({"train/loss": loss_val, "train/step": global_step})

    torch.cuda.empty_cache()
    accelerator.wait_for_everyone()

    # save final
    save_checkpoint(model, tokenizer, config, accelerator, optimized_name)

    accelerator.end_training()


if __name__ == "__main__":
    # Script entry point: run the full SFT pipeline.
    main()
