# -*- coding: utf-8 -*-
"""
train/osworld_vlm_preprocess_shards.py

每个结点处理全量 json 的一段（平均分），做“重活”预处理：
- prompt -> prompt_ids（含多模态占位符对齐）
- response -> resp_ids
- image -> pixel_values / grid_thws（与原训练 build_vlm_grpo_dataset 一致）
- truncate: max_prompt_len / max_gen_length（仅截断，不 padding）

输出（每个 node 一个文件）：
  <project>/temp_data/<optimization_data>_preproc_node{node_idx}_of{num_nodes}.pt

保存为“变长 packed”：
  flat_input_ids:  (sum_L,)
  offsets:         (N+1,)  prefix sum
  start_pos:       (N,)    = len(prompt_ids)
  advantage:       (N,)
  pixel_values_list / grid_thws_list: list len=N
  meta: 记录 pad_id, range, total_size 等
"""

import os
import io
import json
import base64
from pathlib import Path
from typing import Any, Dict, List, Tuple, Optional

import torch
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

from omegaconf import OmegaConf
from transformers import AutoTokenizer, AutoImageProcessor, AutoProcessor

from train.utils import get_config
from train.osworld_rl_preload import collect_images_from_messages  # 复用你原函数（若路径不同就改一下）


def _decode_image(src: str) -> Image.Image:
    """Load an image from a URL, a local file path, or a base64 payload.

    The (whitespace-stripped) source is interpreted in this order:
      1. ``http://`` / ``https://`` prefix -> downloaded with ``requests``
         (10 second timeout, HTTP errors are raised).
      2. An existing filesystem path -> opened directly.
      3. Anything else -> treated as base64; when the string contains a
         comma (e.g. a ``data:...;base64,<payload>`` URI) only the part
         after the first comma is decoded.

    Returns:
        The decoded image converted to RGB mode.
    """
    source = src.strip()

    if source.startswith(("http://", "https://")):
        # Imported lazily so the dependency is only needed for remote images.
        import requests
        reply = requests.get(source, timeout=10)
        reply.raise_for_status()
        payload = reply.content
    elif os.path.exists(source):
        return Image.open(source).convert("RGB")
    else:
        # Drop an optional "<header>," prefix before decoding the base64 body.
        encoded = source.split(",", 1)[1] if "," in source else source
        payload = base64.b64decode(encoded, validate=False)

    return Image.open(io.BytesIO(payload)).convert("RGB")


def _normalize_qwen_messages(prompt_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    和你训练脚本里 Qwen3-VL 分支一致：把 message/content 规范成 processor.apply_chat_template 可吃的格式
    """
    qwen_messages = []
    for m in prompt_messages:
        role = m.get("role", "user")
        content = m.get("content", "")

        new_m = {"role": role}
        new_content = []

        if isinstance(content, str):
            new_content.append({"type": "text", "text": content})
        elif isinstance(content, list):
            for part in content:
                if isinstance(part, str):
                    new_content.append({"type": "text", "text": part})
                elif isinstance(part, dict):
                    p_type = part.get("type")
                    if p_type in ("text", "paragraph"):
                        new_content.append({"type": "text", "text": part.get("text", "")})
                    elif p_type in ("image", "video"):
                        new_content.append(part)
                    elif p_type == "image_url":
                        url = (part.get("image_url") or {}).get("url") or part.get("url")
                        if url:
                            new_content.append({"type": "image", "image": url})
                    else:
                        txt = part.get("text", None)
                        new_content.append({"type": "text", "text": txt if txt is not None else str(part)})
                else:
                    new_content.append({"type": "text", "text": str(part)})
        else:
            new_content.append({"type": "text", "text": str(content)})

        new_m["content"] = new_content
        qwen_messages.append(new_m)

    return qwen_messages


def _pack_1d_long_tensors(seqs: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    seqs: list of 1D LongTensor with variable length
    returns:
      flat: (sum_L,)
      offsets: (N+1,) int64, offsets[0]=0, offsets[i+1]=offsets[i]+len_i
    """
    if not seqs:
        flat = torch.empty((0,), dtype=torch.long)
        offsets = torch.zeros((1,), dtype=torch.long)
        return flat, offsets

    lengths = torch.tensor([int(s.numel()) for s in seqs], dtype=torch.long)
    offsets = torch.zeros((len(seqs) + 1,), dtype=torch.long)
    offsets[1:] = torch.cumsum(lengths, dim=0)

    flat = torch.empty((int(offsets[-1].item()),), dtype=torch.long)
    cur = 0
    for s in seqs:
        L = int(s.numel())
        flat[cur:cur + L] = s
        cur += L
    return flat, offsets


def main():
    """Preprocess this node's slice of the optimization data and save it as a packed shard.

    Steps:
      1. Read the config and pick the model path and length caps for the
         ``policy`` or ``reward`` target.
      2. Resolve ``node_idx`` / ``num_nodes`` and take the matching contiguous
         slice of ``<project>/temp_data/<optimization_data>.json``.
      3. For every item, tokenize the prompt (with chat template) and the
         response, run image preprocessing, and apply truncation (no padding).
      4. Pack all sequences into a flat tensor plus offsets and save them with
         per-sample vision tensors and metadata via ``torch.save``.

    Recognized CLI overrides (read through ``OmegaConf.from_cli``):
      * ``cast_pixel_fp16`` (default 1): cast float32 pixel values to fp16.
      * ``save_dir`` (default ``<project>/temp_data``): output directory.

    Raises:
        ValueError: if ``cfg.training.target`` is neither ``policy`` nor ``reward``.
        AssertionError: if the node index / node count are inconsistent.
    """
    cfg = get_config()
    project_name = cfg.experiment.project

    # -------- choose branch (policy/reward) --------
    # In the first epoch the base model is used; afterwards the checkpoint
    # under <rl_base_dir>/<project>/ckpt/ is loaded instead.
    if cfg.training.target == "policy":
        if cfg.experiment.current_epoch == 1:
            pretrained_model = cfg.model.policy_model
        else:
            pretrained_model = cfg.system.rl_base_dir + "/" + project_name + "/ckpt/" + cfg.model.optimized_name
        max_prompt_len = int(cfg.training.policy.max_prompt_len)
        max_gen_length = int(cfg.training.policy.max_gen_length)
        optimization_data = "policy_optimization_data"
    elif cfg.training.target == "reward":
        if cfg.experiment.current_epoch == 1:
            pretrained_model = cfg.model.reward_model
        else:
            pretrained_model = cfg.system.rl_base_dir + "/" + project_name + "/ckpt/" + cfg.model.optimized_reward_name
        max_prompt_len = int(cfg.training.reward.max_prompt_len)
        max_gen_length = int(cfg.training.reward.max_gen_length)
        optimization_data = "reward_optimization_data"
    else:
        raise ValueError(f"Unknown training.target = {cfg.training.target}")

    # -------- node split --------
    # Taken from dataset.node_rank / dataset.num_nodes (expected to be passed
    # by the dispatcher), falling back to the NODE_RANK / NUM_NODES env vars.
    node_idx = int(OmegaConf.select(cfg, "dataset.node_rank", default=os.environ.get("NODE_RANK", 0)))
    num_nodes = int(OmegaConf.select(cfg, "dataset.num_nodes", default=os.environ.get("NUM_NODES", 1)))
    assert num_nodes >= 1
    assert 0 <= node_idx < num_nodes

    cli = OmegaConf.from_cli()
    cast_pixel_fp16 = int(getattr(cli, "cast_pixel_fp16", 1))
    save_dir = getattr(cli, "save_dir", None)
    if save_dir is None:
        save_dir = str(Path(project_name) / "temp_data")
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    is_qwen3_vl = (getattr(cfg, "model_type", "") == "qwen3vl")

    # -------- load and slice --------
    opt_path = Path(project_name) / "temp_data" / f"{optimization_data}.json"
    with opt_path.open("r", encoding="utf-8") as f:
        data_all = json.load(f)
    total_n = len(data_all)
    # Integer arithmetic gives contiguous, non-overlapping ranges that cover
    # the whole dataset across all nodes.
    start = (total_n * node_idx) // num_nodes
    end = (total_n * (node_idx + 1)) // num_nodes
    data = data_all[start:end]

    print(f"[Node {node_idx}/{num_nodes}] range [{start}, {end}) size={len(data)} total={total_n}", flush=True)

    # -------- tokenizer/processor --------
    if is_qwen3_vl:
        processor = AutoProcessor.from_pretrained(pretrained_model, trust_remote_code=True)
        tokenizer = processor.tokenizer
        image_processor = processor.image_processor
    else:
        processor = None
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model, trust_remote_code=True)
        image_processor = AutoImageProcessor.from_pretrained(pretrained_model, trust_remote_code=True)

    # Fallback when the tokenizer has no pad token (per the original note,
    # this mirrors the original dataset build): reuse EOS, else add "[PAD]".
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    pad_id = int(tokenizer.pad_token_id)

    # -------- preprocess (NO padding) --------
    seq_tensors: List[torch.Tensor] = []
    start_pos_list: List[int] = []
    adv_list: List[float] = []
    pixel_values_list: List[Any] = []
    grid_thws_list: List[Any] = []

    skipped = 0

    for item in data:
        prompt_messages = item["prompt_messages"]
        response_text = item["response"]
        reward = float(item.get("reward", 0.0))

        # ----- prompt ids + vision -----
        if is_qwen3_vl:
            # The processor tokenizes the prompt and preprocesses the images
            # in one call.
            qwen_messages = _normalize_qwen_messages(prompt_messages)
            proc_inputs = processor.apply_chat_template(
                qwen_messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
                truncation=False,
                max_length=100000,
            )
            prompt_ids = proc_inputs["input_ids"][0].tolist()

            pixel_values = proc_inputs.get("pixel_values", None)
            if pixel_values is not None:
                pixel_values = pixel_values.clone().cpu()

            grid = proc_inputs.get("image_grid_thw", None)
            grid_thws = grid.clone().long().cpu() if grid is not None else None

        else:
            prompt_ids = tokenizer.apply_chat_template(
                prompt_messages,
                tokenize=True,
                add_generation_prompt=True,
            )

            # images
            pil_images = collect_images_from_messages(prompt_messages)
            pixel_values, grid_thws = None, None
            if pil_images:
                info = image_processor.preprocess(images=pil_images)

                # Use explicit `is None` checks: truth-testing a multi-element
                # array/tensor with `or` raises "truth value ... is ambiguous".
                pixel = info.get("pixel_values", None)
                if pixel is None:
                    pixel = info.get("pixel_values_videos", None)

                # The grid key name differs between image processors.
                grid = info.get("image_grid_thw", None)
                if grid is None:
                    grid = info.get("images_grid_thw", None)
                if grid is None:
                    grid = info.get("grid_thws", None)

                if pixel is not None:
                    pixel_values = torch.as_tensor(pixel).cpu()
                if grid is not None:
                    grid_thws = torch.as_tensor(grid, dtype=torch.long).cpu()

        # ----- response ids -----
        resp_ids = tokenizer(response_text, add_special_tokens=False)["input_ids"]

        # ----- truncate rules (per the original note, same as the original build) -----
        has_image = (pixel_values is not None) or (grid_thws is not None)

        if max_prompt_len > 0 and len(prompt_ids) > max_prompt_len:
            if has_image:
                # A prompt with images must not be truncated (it would break
                # the placeholder/vision alignment) -> skip the sample.
                skipped += 1
                continue
            else:
                # Text-only prompt: keep the last max_prompt_len tokens.
                prompt_ids = prompt_ids[-max_prompt_len:]

        if max_gen_length > 0 and len(resp_ids) > max_gen_length:
            resp_ids = resp_ids[:max_gen_length]

        input_ids = prompt_ids + resp_ids
        if len(input_ids) == 0:
            skipped += 1
            continue

        # Index of the first response token inside input_ids.
        start_pos = len(prompt_ids)

        # Optionally store float32 pixel values as fp16.
        if cast_pixel_fp16 and pixel_values is not None and pixel_values.dtype == torch.float32:
            pixel_values = pixel_values.half().contiguous()
        elif pixel_values is not None:
            pixel_values = pixel_values.contiguous()

        if grid_thws is not None:
            grid_thws = grid_thws.contiguous()

        seq_tensors.append(torch.tensor(input_ids, dtype=torch.long))
        start_pos_list.append(int(start_pos))
        # The per-sample "reward" field is stored under the "advantage" key.
        adv_list.append(float(reward))
        pixel_values_list.append(pixel_values)
        grid_thws_list.append(grid_thws)

    N = len(seq_tensors)
    flat_input_ids, offsets = _pack_1d_long_tensors(seq_tensors)

    pack = {
        "flat_input_ids": flat_input_ids,
        "offsets": offsets,
        "start_pos": torch.tensor(start_pos_list, dtype=torch.long),
        "advantage": torch.tensor(adv_list, dtype=torch.float32),
        "pixel_values_list": pixel_values_list,
        "grid_thws_list": grid_thws_list,
        "meta": {
            "project": project_name,
            "optimization_data": optimization_data,
            "pretrained_model": str(pretrained_model),
            "is_qwen3_vl": bool(is_qwen3_vl),
            "pad_token_id": int(pad_id),
            "node_idx": int(node_idx),
            "num_nodes": int(num_nodes),
            "range_start": int(start),
            "range_end": int(end),
            "shard_size": int(N),
            "total_size": int(total_n),
            "skipped": int(skipped),
            "max_prompt_len_cap": int(max_prompt_len),
            "max_gen_length_cap": int(max_gen_length),
            "cast_pixel_fp16": int(cast_pixel_fp16),
        },
    }

    out_path = save_dir / f"{optimization_data}_preproc_node{node_idx}_of{num_nodes}.pt"
    torch.save(pack, out_path)
    print(
        f"[Node {node_idx}/{num_nodes}] saved: {out_path} "
        f"(N={N} skipped={skipped} flat={int(flat_input_ids.numel())})",
        flush=True,
    )


# Script entry point: run the shard preprocessing only when executed directly.
if __name__ == "__main__":
    main()
