# -*- coding: utf-8 -*-
"""
osworld_policy_rollout_single_engine.py

Plan A: single process + single vLLM engine + TP=8 (keeps CUDA graphs; no multiprocessing / multiple engines)

Per run folder outputs (same folder as trajectory.json):
  1) policy_rollout_results.json
  2) policy_rollout_reward_prompts.json

Also create placeholder dirs for future screenshots:
  <run_dir>/policy_aug/step_<k>/uniq_<j>_next_obs.png

This script:
- reads existing trajectory.json
- for each step: sample R policy responses with vLLM (single engine, TP=args.tp_size)
- parse each response into pyautogui action list (STRICTLY aligned with your parse_response_exact)
- AFTER parse: apply scale_scroll_for_windows() exactly like your standard snippet
- dedup action lists within step
- for each rollout: create a patched reward prompt by
    (a) best-effort replacing old policy response text with new response
    (b) replacing LAST image path in reward_messages with future next_obs path

No VM execution here.
"""

import os
# These environment variables are set before the remaining imports below run.
# setdefault keeps any VLLM_WORKER_MULTIPROC_METHOD value the caller already exported.
os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
# Unconditionally overwritten, whatever the caller exported.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import re
import ast
import json
import copy
import argparse
import inspect
from typing import List, Dict, Any, Tuple, Optional

import torch
from PIL import Image
from transformers import AutoProcessor
from qwen_vl_utils import process_vision_info


############################
# Approx-consistency for action list dedup (your exact function)
############################
_CALL_RE = re.compile(r"^\s*pyautogui\.(?P<fn>\w+)\s*\((?P<args>.*)\)\s*$")
# Single- or double-quoted string literal, backslash escapes allowed.
_STR_LITERAL_RE = re.compile(r"'(?:\\.|[^'\\])*'|\"(?:\\.|[^\"\\])*\"")


def _parse_action(s: str) -> Tuple[str, Optional[int], Optional[int], str]:
    """
    Parse an action string like 'pyautogui.click(633, 452)'.

    Returns: (fn_name, x, y, normalized_string)
      - fn_name: lower-cased pyautogui function name, or "__UNKNOWN__" when the
        string is not a single-line ``pyautogui.<fn>(...)`` call (e.g. "WAIT").
      - x, y: the first and second integers found in the argument list, or None
        when absent.
      - normalized_string: the input with all whitespace removed.

    Integers that appear inside quoted string literals are NOT treated as
    coordinates: ``typewrite('10 20')`` or ``hotkey('f1', 'f2')`` carry text and
    key names, not positions, and must only ever be compared exactly.
    """
    s_norm = re.sub(r"\s+", "", s)
    m = _CALL_RE.match(s)
    if not m:
        return ("__UNKNOWN__", None, None, s_norm)

    fn = m.group("fn").lower()
    # Remove quoted literals first so digits in typed text / key names are ignored.
    args = _STR_LITERAL_RE.sub("", m.group("args"))
    nums = re.findall(r"[-+]?\d+", args)
    x = int(nums[0]) if len(nums) >= 1 else None
    y = int(nums[1]) if len(nums) >= 2 else None
    return (fn, x, y, s_norm)


def approx_actionlists_consistent(action_lists: List[List[str]], tol: int = 10) -> bool:
    """
    Decide whether several action lists are "approximately the same".

    A) Zero or one list, or all lists identical once whitespace is removed -> True.
    B) Otherwise every list must have the same length and, position by position,
       the same function name. When every action at a position exposes both an
       x and a y value, each spread (max - min) must be <= tol; if any of them
       lacks a coordinate, the whitespace-stripped strings must match exactly.
    """
    if not action_lists or len(action_lists) == 1:
        return True

    # Fast path: exact equality modulo whitespace.
    stripped = [[re.sub(r"\s+", "", act) for act in seq] for seq in action_lists]
    reference = stripped[0]
    if all(seq == reference for seq in stripped[1:]):
        return True

    width = len(action_lists[0])
    for seq in action_lists:
        if len(seq) != width:
            return False

    # table[k][j] is the parsed j-th action of the k-th list; walk it column-wise.
    table = [[_parse_action(act) for act in seq] for seq in action_lists]
    for column in zip(*table):
        names, xs, ys, texts = zip(*column)

        if len(set(names)) != 1:
            return False

        if None not in xs and None not in ys:
            if max(xs) - min(xs) > tol or max(ys) - min(ys) > tol:
                return False
        elif any(text != texts[0] for text in texts):
            # No full coordinate pair everywhere: fall back to exact comparison.
            return False

    return True


# -------------------- Standard Windows scroll scaling (the version you specified) --------------------
def scale_scroll_for_windows(code: str, platform: str, factor: int = 50) -> str:
    """
    Multiply the integer amount of every ``pyautogui.scroll(N)`` call in ``code``
    by ``factor`` when ``platform`` is "windows" (case-insensitive).

    For any other platform the code is returned untouched. Whitespace around the
    amount is dropped in the rewritten call.
    """
    if platform.lower() == "windows":
        scroll_call = re.compile(r"(pyautogui\.scroll\()\s*([-+]?\d+)\s*\)")

        def _scaled(match):
            amount = int(match.group(2)) * factor
            return f"{match.group(1)}{amount})"

        return scroll_call.sub(_scaled, code)
    return code


# -------------------- VLM input preparation (reused by the single process / single engine) --------------------
def _resolve_local_images(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Return a deep copy of ``messages`` in which every ``{"type": "image",
    "image": "<path>"}`` chunk whose path exists on disk carries an RGB
    ``PIL.Image`` instead of the path. Chunks that cannot be loaded keep their
    original value; the input list is never modified.
    """
    resolved = copy.deepcopy(messages)
    for message in resolved:
        if "content" not in message:
            continue
        parts = message.get("content") or []
        for part in parts:
            if not isinstance(part, dict) or part.get("type") != "image":
                continue
            source = part.get("image")
            if not (isinstance(source, str) and os.path.exists(source)):
                continue
            try:
                with Image.open(source) as handle:
                    part["image"] = handle.convert("RGB")
            except Exception:
                # Best effort: an unreadable image leaves the path in place.
                pass
    return resolved


def prepare_inputs_for_vllm(messages: List[Dict[str, Any]], processor: AutoProcessor) -> Dict[str, Any]:
    """
    Build one multimodal vLLM request from a chat-style message list.

    Returns a dict of the form
      {"prompt": ..., "multi_modal_data": {"image": ..., "video": ...}, "mm_processor_kwargs": ...}
    where "image" / "video" are only present when process_vision_info returned
    something other than None for them.
    """
    resolved = _resolve_local_images(messages)

    # Text prompt rendered through the processor's chat template.
    prompt_text = processor.apply_chat_template(resolved, tokenize=False, add_generation_prompt=True)

    # Vision inputs; the patch size falls back to 14 when the processor does not expose one.
    image_processor = getattr(processor, "image_processor", object())
    patch_size = getattr(image_processor, "patch_size", 14)
    vision_kwargs = {"image_patch_size": patch_size, "return_video_kwargs": True}
    try:
        image_inputs, video_inputs, video_kwargs = process_vision_info(
            resolved, return_video_metadata=True, **vision_kwargs
        )
    except TypeError:
        # Retry without return_video_metadata when the first call raised TypeError.
        image_inputs, video_inputs, video_kwargs = process_vision_info(resolved, **vision_kwargs)

    mm_data: Dict[str, Any] = {}
    for modality, payload in (("image", image_inputs), ("video", video_inputs)):
        if payload is not None:
            mm_data[modality] = payload

    return {"prompt": prompt_text, "multi_modal_data": mm_data, "mm_processor_kwargs": video_kwargs}


def safe_prepare_inputs_for_vllm(messages: List[Dict[str, Any]], processor: AutoProcessor) -> Dict[str, Any]:
    """
    Fault-tolerant wrapper around prepare_inputs_for_vllm.

    If the multimodal preparation raises, fall back to a text-only request so a
    single bad sample cannot stall the whole job. The fallback prompt is the
    chat-template text (or ``str(messages)`` if templating fails too), prefixed
    with a "[WARN] ..." line naming the exception; no multimodal data is attached.
    """
    try:
        return prepare_inputs_for_vllm(messages, processor)
    except Exception as err:
        reason = f"{type(err).__name__}: {err}"

    try:
        fallback = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        fallback = str(messages)

    return {
        "prompt": f"[WARN] multimodal prepare failed: {reason}\n\n{fallback}",
        "multi_modal_data": {},
        "mm_processor_kwargs": {},
    }


# -------------------- policy response -> pyautogui actions (STRICTLY aligned with your parse_response) --------------------
def parse_response_exact(
    response: str,
    coordinate_type: str,
    original_width: int,
    original_height: int,
    processed_width: Optional[int] = None,
    processed_height: Optional[int] = None,
) -> Tuple[str, List[str], Dict[str, Any]]:
    """
    Turn a policy response into pyautogui commands.

    Returns (action_text, [pyautogui_code...], other_info). The Windows scroll
    scaling is intentionally NOT applied here; it is done after parsing.

    The response is read line by line:
      - the first non-empty "Action: ..." line (case-insensitive prefix) provides
        action_text;
      - lines between a line starting with ``<tool_call>`` and a line starting
        with ``</tool_call>`` are joined and parsed as one JSON tool call (any
        text sharing a line with either tag is ignored);
      - outside such a block, a single-line JSON object with "name" and
        "arguments" keys is also accepted as a tool call.

    Only tool calls named "computer_use" produce commands. Every successfully
    decoded tool call is recorded in other_info["tool_calls"]; malformed ones
    are dropped silently. A blank/None response returns early, in which case
    other_info has no "action"/"code" keys.
    """
    instruction = ""
    commands: List[str] = []
    info: Dict[str, Any] = {"raw_response": response, "tool_calls": []}

    if response is None or not response.strip():
        return instruction, commands, info

    def to_screen(x: float, y: float) -> Tuple[int, int]:
        # "relative" coordinates live on a 0..999 grid; "absolute"/"qwen25" are
        # rescaled processed -> original when the processed size is known.
        if coordinate_type not in ("absolute", "qwen25"):
            return int(x * (original_width / 999)), int(y * (original_height / 999))
        if processed_width and processed_height:
            return (
                int(x * (original_width / processed_width)),
                int(y * (original_height / processed_height)),
            )
        return int(x), int(y)

    # action name -> (pyautogui function, command emitted when no coordinate is given)
    pointer_actions = {
        "left_click": ("click", "pyautogui.click()"),
        "right_click": ("rightClick", "pyautogui.rightClick()"),
        "middle_click": ("middleClick", "pyautogui.middleClick()"),
        "double_click": ("doubleClick", "pyautogui.doubleClick()"),
        "mouse_move": ("moveTo", "pyautogui.moveTo(0, 0)"),
        "left_click_drag": ("dragTo", "pyautogui.dragTo(0, 0)"),
    }

    def handle_call(payload: str) -> None:
        # Any failure (bad JSON, unexpected shapes, bad numbers) drops this call.
        try:
            call = json.loads(payload)
            info["tool_calls"].append(call)

            if call.get("name") != "computer_use":
                return
            args = call.get("arguments", {})
            action = args.get("action")

            # --- mouse actions ---
            if action in pointer_actions:
                fn, without_target = pointer_actions[action]
                if "coordinate" not in args:
                    commands.append(without_target)
                    return
                raw_x, raw_y = args["coordinate"]
                px, py = to_screen(float(raw_x), float(raw_y))
                if action == "left_click_drag":
                    duration = args.get("duration", 0.5)
                    commands.append(f"pyautogui.{fn}({px}, {py}, duration={duration})")
                else:
                    commands.append(f"pyautogui.{fn}({px}, {py})")

            # --- keyboard ---
            elif action == "type":
                typed = str(args.get("text", ""))
                typed = typed.replace("\\", "\\\\").replace("'", "\\'")
                commands.append(f"pyautogui.typewrite('{typed}')")

            elif action == "key":
                keys = args.get("keys", [])
                if not isinstance(keys, list):
                    keys = [keys]
                keys = [str(k).strip() for k in keys if k is not None]
                quoted = ", ".join(f"'{k}'" for k in keys)
                if len(keys) > 1:
                    commands.append(f"pyautogui.hotkey({quoted})")
                elif len(keys) == 1:
                    commands.append(f"pyautogui.press({quoted})")

            # --- scroll / wait / terminate ---
            elif action == "scroll":
                commands.append(f"pyautogui.scroll({int(args.get('pixels', 0))})")

            elif action == "wait":
                commands.append("WAIT")

            elif action == "terminate":
                status = (args.get("status") or "success").lower()
                commands.append("DONE" if status == "success" else "FAIL")

        except Exception:
            return

    # ---- walk the response text ----
    in_block = False
    buffered: List[str] = []

    for raw_line in response.split("\n"):
        line = raw_line.strip()
        if not line:
            continue

        if line.lower().startswith("action:"):
            # Only the first non-empty "Action:" payload is kept.
            if not instruction:
                instruction = line.split(":", 1)[-1].strip()
        elif line.startswith("<tool_call>"):
            in_block = True
        elif line.startswith("</tool_call>"):
            in_block = False
            if buffered:
                handle_call("\n".join(buffered))
                buffered = []
        elif in_block:
            buffered.append(line)
        elif line.startswith("{") and line.endswith("}"):
            # Tolerate a bare one-line JSON tool call.
            try:
                candidate = json.loads(line)
                if "name" in candidate and "arguments" in candidate:
                    handle_call(line)
            except Exception:
                pass

    # A tool-call block that was never closed is still processed.
    if buffered:
        handle_call("\n".join(buffered))

    if commands and not instruction:
        instruction = "Execute the tool call"

    info["action"] = instruction
    info["code"] = commands
    return instruction, commands, info


# -------------------- reward prompt patch (replace the response + replace the last image path) --------------------
def patch_reward_messages(
    base_reward_messages: List[Dict[str, Any]],
    old_response: str,
    new_response: str,
    new_last_image_path: str,
) -> List[Dict[str, Any]]:
    """
    Return a patched deep copy of ``base_reward_messages`` (the input is untouched).

    1) Every text chunk containing ``old_response`` has it replaced by
       ``new_response``. If nothing was replaced (or ``old_response`` is empty),
       a "[New policy response]" text chunk is appended to the last user message
       whose content is a list, so the prompt still carries the new response.
    2) The last image chunk whose "image" value is a string gets
       ``new_last_image_path`` as its new value.
    """
    patched = copy.deepcopy(base_reward_messages)

    # 1) swap old response -> new response (best effort)
    replaced = False
    if old_response:
        for message in patched:
            parts = message.get("content")
            if not isinstance(parts, list):
                continue
            for part in parts:
                is_text = (
                    isinstance(part, dict)
                    and part.get("type") == "text"
                    and isinstance(part.get("text"), str)
                )
                if is_text and old_response in part["text"]:
                    part["text"] = part["text"].replace(old_response, new_response)
                    replaced = True

    if not replaced:
        # Nothing matched: append the new response to the last list-content user message.
        last_user = next(
            (
                message
                for message in reversed(patched)
                if message.get("role") == "user" and isinstance(message.get("content"), list)
            ),
            None,
        )
        if last_user is not None:
            last_user["content"].append(
                {"type": "text", "text": f"\n[New policy response]\n{new_response}\n"}
            )

    # 2) repoint the last string-valued image chunk, scanning from the end
    string_images = (
        part
        for message in reversed(patched)
        if isinstance(message.get("content"), list)
        for part in reversed(message["content"])
        if isinstance(part, dict) and part.get("type") == "image" and isinstance(part.get("image"), str)
    )
    final_image = next(string_images, None)
    if final_image is not None:
        final_image["image"] = new_last_image_path

    return patched


# -------------------- Directory and data utilities --------------------
def is_dir(p: str) -> bool:
    """Return True when ``p`` is an existing directory, as reported by os.path.isdir."""
    exists_as_directory = os.path.isdir(p)
    return exists_as_directory


def list_subdirs(p: str) -> List[str]:
    """
    Return the names of the immediate sub-directories of ``p`` in sorted order.
    A missing ``p`` yields an empty list.
    """
    try:
        entries = sorted(os.listdir(p))
    except FileNotFoundError:
        return []
    return [name for name in entries if os.path.isdir(os.path.join(p, name))]


def list_numeric_dirs(p: str) -> List[str]:
    """
    Return the names of the immediate sub-directories of ``p``.

    When every name consists of digits the result is ordered by integer value
    (so "2" precedes "10"); otherwise it is ordered lexicographically. A missing
    ``p`` yields an empty list.
    """
    try:
        names = os.listdir(p)
    except FileNotFoundError:
        return []
    dirs = [name for name in names if os.path.isdir(os.path.join(p, name))]
    all_numeric = bool(dirs) and all(name.isdigit() for name in dirs)
    return sorted(dirs, key=int) if all_numeric else sorted(dirs)


def load_examples_any(path: str) -> set[str]:
    """
    Load the set of wanted example ids from a JSON file.

    Accepted layouts:
      - a list of entries;
      - a dict whose "active" and "temp" values are lists of entries.
    An entry is either a string (used as is) or a dict, from which
    "example_id" (or, failing that, "example") is taken and stringified.
    Entries of any other shape are ignored.

    Raises:
        ValueError: the top-level JSON value is neither a list nor a dict.
    """
    with open(path, "r", encoding="utf-8") as fh:
        payload = json.load(fh)

    if isinstance(payload, list):
        buckets = [payload]
    elif isinstance(payload, dict):
        buckets = [payload.get(key) or [] for key in ("active", "temp")]
    else:
        raise ValueError(f"examples-json 仅支持 list 或 dict，实际是 {type(payload)}")

    wanted: set[str] = set()
    for bucket in buckets:
        for entry in bucket:
            if isinstance(entry, str):
                wanted.add(entry)
            elif isinstance(entry, dict):
                example_id = entry.get("example_id") or entry.get("example")
                if example_id is not None:
                    wanted.add(str(example_id))
    return wanted


# -------------------- Single-engine inference (the key part) --------------------
def generate_results_single_engine(
    llm, sampling_params,
    processor: AutoProcessor,
    messages_all: List[List[Dict[str, Any]]],
    R: int,
    chunk_size: int,
    tag: str,
) -> List[str]:
    """
    Sample R completions for every conversation using one shared engine.

    Args:
        llm: engine object exposing ``generate(inputs, sampling_params=...)``.
        sampling_params: forwarded unchanged to ``llm.generate``.
        processor: processor used to build each request.
        messages_all: one chat message list per request.
        R: number of samples drawn per message list.
        chunk_size: number of message lists sent per ``generate`` call
            (each call therefore carries ``chunk_size * R`` requests).
        tag: label used in the progress log line.

    Returns:
        A flat list of ``len(messages_all) * R`` texts, in the order of
        ``messages_all`` with the R samples of each conversation adjacent.
        This relies on ``llm.generate`` returning outputs in input order.
        A request without any output contributes "".

    Raises:
        ValueError: ``chunk_size`` is not positive (previously 0 surfaced as an
            opaque ``range()`` error and a negative value silently returned []).
    """
    if not messages_all:
        return []
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    out_texts: List[str] = []
    N = len(messages_all)

    for s in range(0, N, chunk_size):
        sub = messages_all[s:s + chunk_size]

        # Prepare each conversation once and reuse the same request dict for
        # all of its R samples.
        batch_inputs: List[Dict[str, Any]] = []
        for m in sub:
            inp = safe_prepare_inputs_for_vllm(m, processor)
            batch_inputs.extend([inp] * R)

        print(f"[GEN:{tag}] batch {s}..{min(s + chunk_size, N)} (expanded={len(batch_inputs)})", flush=True)
        outs = llm.generate(batch_inputs, sampling_params=sampling_params)
        out_texts.extend(o.outputs[0].text if o.outputs else "" for o in outs)

    return out_texts


# -------------------- 主逻辑 --------------------
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--root-dir", required=True)
    parser.add_argument("--examples-json", required=True)

    # policy model
    parser.add_argument("--model", default="Qwen/Qwen3-VL-8B-Thinking")

    # vLLM / sampling
    parser.add_argument("--dtype", default="bfloat16", choices=["float16", "bfloat16", "float32"])
    parser.add_argument("--gpu-mem", type=float, default=0.85)
    parser.add_argument("--max-model-len", type=int, default=40960, help="<=0 使用模型默认")
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-p", type=float, default=1.0)
    parser.add_argument("--max-tokens", type=int, default=1024)
    parser.add_argument("--stop-words", type=str,
                        default="</answer>,User:,Human:,Assistant:,<|im_end|>,<|endoftext|>")

    # 方案A：单 engine TP
    parser.add_argument("--tp-size", type=int, default=8, help="tensor_parallel_size（方案A推荐 8）")

    # CUDA graph：enforce_eager=False 才会启用；默认 False
    parser.add_argument("--enforce-eager", action="store_true",
                        help="强制 eager（会禁用 cudagraph；仅用于排障，不建议）")

    # 稳定优先：默认禁用 custom all-reduce（更少 NCCL/graph 奇怪问题）
    parser.add_argument("--enable-custom-all-reduce", action="store_true",
                        help="开启 custom all-reduce（可能更快，但更容易遇到 graph/NCCL 奇怪问题）")

    # rollout count per step
    parser.add_argument("--num-rollout-per-task", type=int, default=4)

    # parsing
    parser.add_argument("--coordinate-type", default="relative", choices=["relative", "absolute", "qwen25"])
    parser.add_argument("--platform", default="ubuntu", help="ubuntu|windows (for scroll scaling)")

    # approx uniq merge
    parser.add_argument("--approx-uniq-tol", type=int, default=10,
                        help="merge action lists within a step if approx-consistent (coord tol in pixels)")
    parser.add_argument("--no-approx-uniq", action="store_true",
                        help="disable approx merge; unique_actions falls back to exact action list dedup")

    # output names in each run folder
    parser.add_argument("--policy-output-name", default="policy_rollout_results.json")
    parser.add_argument("--reward-prompt-output-name", default="policy_rollout_reward_prompts.json")

    # placeholder screenshot dir name under each run folder
    parser.add_argument("--policy-aug-dirname", default="policy_aug")

    # batching（单 engine generate 分块，避免一次性 batch 太大）
    parser.add_argument("--chunk-size", type=int, default=1024,
                        help="单 engine 每次 generate 的 messages 数量（注意 expanded=chunk_size*R）")

    # 代理（可选）
    parser.add_argument("--download_proxy", type=str)

    args = parser.parse_args()

    if args.download_proxy:
        os.environ["HTTP_PROXY"] = args.download_proxy
        os.environ["HTTPS_PROXY"] = args.download_proxy

    print(f"[ENV] CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}", flush=True)
    print(f"[ENV] torch.cuda.device_count()={torch.cuda.device_count()}", flush=True)

    if torch.cuda.device_count() <= 0:
        raise RuntimeError("No CUDA device visible. Please set CUDA_VISIBLE_DEVICES properly.")
    if int(args.tp_size) > torch.cuda.device_count():
        raise RuntimeError(f"--tp-size={args.tp_size} > visible GPUs={torch.cuda.device_count()}")

    wanted_examples = load_examples_any(args.examples_json)

    # 1) 扫描 trajectory.json
    domains = list_subdirs(args.root_dir)
    run_tasks: List[Tuple[str, str, str, str]] = []  # (domain, example, run, traj_json)

    for domain in domains:
        domain_dir = os.path.join(args.root_dir, domain)
        examples = list_subdirs(domain_dir)
        for ex in examples:
            if ex not in wanted_examples:
                continue
            ex_dir = os.path.join(domain_dir, ex)
            runs = list_numeric_dirs(ex_dir)
            for run in runs:
                traj_json = os.path.join(ex_dir, run, "trajectory.json")
                if os.path.isfile(traj_json):
                    run_tasks.append((domain, ex, run, traj_json))

    if not run_tasks:
        print("[DONE] 没找到任何 trajectory.json（或 examples-json 为空）", flush=True)
        return

    # 2) 读取每个 run 的数据，构建 pending（扁平化 step 推理）
    # pending item:
    #   (run_key, step_pos, step_index, policy_messages, old_response, base_reward_messages, ow, oh, pw, ph)
    pending: List[Tuple[str, int, int, List[Dict[str, Any]], str, List[Dict[str, Any]], int, int, int, int]] = []

    # run_key -> run_context
    runs_ctx: Dict[str, Dict[str, Any]] = {}

    for domain, ex, run, traj_json in run_tasks:
        run_dir = os.path.dirname(traj_json)
        run_id = int(run) if str(run).isdigit() else run
        run_key = f"{domain}///{ex}///{run_id}///{run_dir}"

        try:
            with open(traj_json, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception as e:
            print(f"[WARN] 打开失败，跳过：{traj_json} | 原因：{e}", flush=True)
            continue

        meta_result = (data.get("meta") or {}).get("result")
        traj = data.get("trajectory") or []
        reward_traj = data.get("reward_trajectory") or []

        # step_index -> reward_messages
        reward_map: Dict[int, List[Dict[str, Any]]] = {}
        if isinstance(reward_traj, list):
            for j, item in enumerate(reward_traj):
                if isinstance(item, dict) and isinstance(item.get("reward_messages"), list):
                    si = item.get("step_index")
                    if si is None:
                        si = j
                    try:
                        reward_map[int(si)] = item["reward_messages"]
                    except Exception:
                        continue
                elif isinstance(item, list):
                    reward_map[j] = item

        steps: List[Tuple[int, Dict[str, Any]]] = []
        for j, s in enumerate(traj):
            if not isinstance(s, dict):
                continue
            si = s.get("step_index")
            if si is None:
                si = j
            try:
                si_int = int(si)
            except Exception:
                continue
            steps.append((si_int, s))
        steps.sort(key=lambda x: x[0])

        runs_ctx[run_key] = {
            "domain": domain,
            "example": ex,
            "run_id": run_id,
            "run_dir": run_dir,
            "trajectory_json": traj_json,
            "result": meta_result,
        }

        for step_pos, (step_index, s) in enumerate(steps):
            policy_messages = s.get("messages")
            if not isinstance(policy_messages, list):
                continue

            old_response = str(s.get("response") or "")

            ow, oh = 0, 0
            pw, ph = 0, 0
            try:
                osz = s.get("original_screen_size")
                if isinstance(osz, list) and len(osz) == 2:
                    ow, oh = int(osz[0]), int(osz[1])
            except Exception:
                pass
            try:
                psz = s.get("processed_screen_size")
                if isinstance(psz, list) and len(psz) == 2:
                    pw, ph = int(psz[0]), int(psz[1])
            except Exception:
                pass

            base_reward_messages = reward_map.get(step_index, [])
            if not isinstance(base_reward_messages, list):
                base_reward_messages = []

            pending.append((
                run_key, step_pos, step_index,
                policy_messages, old_response, base_reward_messages,
                ow, oh, pw, ph
            ))

    if not pending:
        print("[DONE] 没有可推理 step（可能 messages 缺失）", flush=True)
        return

    # 3) 单 engine 初始化（关键：CUDA graph）
    from vllm import LLM, SamplingParams

    stop_words = [w for w in args.stop_words.split(",") if w]
    R = max(1, int(args.num_rollout_per_task))

    disable_custom_all_reduce = (not args.enable_custom_all_reduce)
    enforce_eager = bool(args.enforce_eager)  # False => 允许 cudagraph

    print("[INIT] Loading processor...", flush=True)
    processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)

    llm_kwargs = dict(
        model=args.model,
        trust_remote_code=True,
        dtype=args.dtype,
        tensor_parallel_size=int(args.tp_size),
        gpu_memory_utilization=float(args.gpu_mem),
        max_model_len=(int(args.max_model_len) if int(args.max_model_len) > 0 else None),
        enforce_eager=enforce_eager,
        disable_custom_all_reduce=disable_custom_all_reduce,
    )
    sig = inspect.signature(LLM.__init__)
    llm_kwargs = {k: v for k, v in llm_kwargs.items() if k in sig.parameters}

    print(
        f"[INIT] Loading vLLM LLM (TP={args.tp_size}, enforce_eager={enforce_eager}, "
        f"disable_custom_all_reduce={disable_custom_all_reduce}, max_model_len={args.max_model_len}) ...",
        flush=True
    )
    llm = LLM(**llm_kwargs)

    sampling = SamplingParams(
        temperature=float(args.temperature),
        top_p=float(args.top_p),
        max_tokens=int(args.max_tokens),
        stop=stop_words,
    )

    print(f"[INFO] Single-engine mode | R={R} chunk_size={args.chunk_size}", flush=True)

    # 4) 批量 policy inference（单 engine + chunk）
    messages_all = [x[3] for x in pending]  # policy_messages
    flat_texts = generate_results_single_engine(
        llm, sampling, processor,
        messages_all=messages_all, R=R, chunk_size=int(args.chunk_size),
        tag="POL",
    )
    expected = len(messages_all) * R
    if len(flat_texts) != expected:
        raise RuntimeError(f"policy inference mismatch: got={len(flat_texts)} expected={expected}")

    # 5) 聚合输出（每个 run_dir 两个文件）
    for rk in runs_ctx:
        runs_ctx[rk]["policy_rollout_results"] = {
            "meta": {
                "domain": runs_ctx[rk]["domain"],
                "example": runs_ctx[rk]["example"],
                "run_id": runs_ctx[rk]["run_id"],
                "trajectory_json": runs_ctx[rk]["trajectory_json"],
                "result": runs_ctx[rk]["result"],
                "num_rollout_per_step": R,
            },
            "step_rollouts": []
        }
        runs_ctx[rk]["policy_rollout_reward_prompts"] = {
            "meta": {
                "domain": runs_ctx[rk]["domain"],
                "example": runs_ctx[rk]["example"],
                "run_id": runs_ctx[rk]["run_id"],
                "trajectory_json": runs_ctx[rk]["trajectory_json"],
                "result": runs_ctx[rk]["result"],
                "num_rollout_per_step": R,
            },
            "prompts": []
        }

    pos = 0
    for (run_key, _step_pos, step_index,
         _policy_messages, old_response, base_reward_messages,
         ow, oh, pw, ph) in pending:

        texts = flat_texts[pos:pos + R]
        pos += R

        run_dir = runs_ctx[run_key]["run_dir"]
        policy_aug_dir = os.path.join(run_dir, args.policy_aug_dirname, f"step_{step_index}")
        os.makedirs(policy_aug_dir, exist_ok=True)

        #uniq_map: Dict[str, Dict[str, Any]] = {}
        rollout_records: List[Dict[str, Any]] = []
        prompt_records: List[Dict[str, Any]] = []

        # real_uniq_map：完全一致（旧逻辑）的备份 -> 写到 real_unique_actions 里
        real_uniq_map: Dict[str, Dict[str, Any]] = {}

        # uniq_groups：新的 unique_actions（默认启用 approx merge；可用 --no-approx-uniq 关闭）
        uniq_groups: List[Dict[str, Any]] = []
        uniq_groups_exact: Dict[str, Dict[str, Any]] = {}  # 仅在 no-approx-uniq 时用

        tol = int(args.approx_uniq_tol)
        approx_enabled = (not bool(args.no_approx_uniq))

        def _assign_real_exact(action_key: str, actions: List[str], r_idx: int) -> Tuple[int, str]:
            """
            旧逻辑 exact uniq，但 next_obs_path 用 real_uniq_{i}_... 避免与新 uniq_{i}_... 冲突
            """
            if action_key not in real_uniq_map:
                real_idx = len(real_uniq_map)
                real_next_obs = os.path.join(policy_aug_dir, f"real_uniq_{real_idx}_next_obs.png")
                real_uniq_map[action_key] = {
                    "uniq_action_idx": real_idx,
                    "pyautogui_actions": actions,
                    "count": 0,
                    "member_r_idx": [],
                    "next_obs_path": real_next_obs,
                    "approx_uniq_action_idx": None,  # 稍后填
                }
            real_uniq_map[action_key]["count"] += 1
            real_uniq_map[action_key]["member_r_idx"].append(r_idx)
            return int(real_uniq_map[action_key]["uniq_action_idx"]), str(real_uniq_map[action_key]["next_obs_path"])

        def _assign_unique_group(action_key: str, actions: List[str], r_idx: int, real_idx: int) -> Tuple[int, str]:
            """
            返回 (uniq_action_idx, next_obs_path) —— 这就是“新的 unique_actions”的 index/path
            - approx_enabled=True：用 approx_actionlists_consistent([rep, actions], tol) 找可合并的 group
            - approx_enabled=False：退化为 exact (action_key) 去重（但 path 仍然是 uniq_{i}_...）
            """
            nonlocal uniq_groups, uniq_groups_exact

            if not approx_enabled:
                if action_key not in uniq_groups_exact:
                    uniq_idx = len(uniq_groups)
                    next_obs = os.path.join(policy_aug_dir, f"uniq_{uniq_idx}_next_obs.png")
                    g = {
                        "uniq_action_idx": uniq_idx,
                        "pyautogui_actions": actions,
                        "count": 0,
                        "member_r_idx": [],
                        "next_obs_path": next_obs,
                        "member_real_uniq_action_idx": [],
                    }
                    uniq_groups_exact[action_key] = g
                    uniq_groups.append(g)
                g = uniq_groups_exact[action_key]
            else:
                g = None
                for cand in uniq_groups:
                    rep = cand.get("pyautogui_actions") or []
                    if approx_actionlists_consistent([rep, actions], tol=tol):
                        g = cand
                        break
                if g is None:
                    uniq_idx = len(uniq_groups)
                    next_obs = os.path.join(policy_aug_dir, f"uniq_{uniq_idx}_next_obs.png")
                    g = {
                        "uniq_action_idx": uniq_idx,
                        "pyautogui_actions": actions,  # 用首次出现的作为代表
                        "count": 0,
                        "member_r_idx": [],
                        "next_obs_path": next_obs,
                        "member_real_uniq_action_idx": [],
                    }
                    uniq_groups.append(g)

            g["count"] += 1
            g["member_r_idx"].append(r_idx)
            if real_idx not in (g.get("member_real_uniq_action_idx") or []):
                g.setdefault("member_real_uniq_action_idx", []).append(real_idx)

            uniq_idx = int(g["uniq_action_idx"])
            next_obs = str(g["next_obs_path"])
            return uniq_idx, next_obs


        for r_idx, resp in enumerate(texts):
            resp = resp or ""

            # STRICT parse (no windows scale inside)
            low_level, actions, other = parse_response_exact(
                response=resp,
                coordinate_type=args.coordinate_type,
                original_width=ow if ow > 0 else 1920,
                original_height=oh if oh > 0 else 1080,
                processed_width=pw if pw > 0 else None,
                processed_height=ph if ph > 0 else None,
            )

            # EXACT: after parse, apply windows scroll scale snippet
            actions = [scale_scroll_for_windows(a, platform=args.platform) for a in actions]

            parse_ok = bool(actions) and not any(a is None or a == "" for a in actions)

            action_key = json.dumps(actions, ensure_ascii=False)
            '''
            if action_key not in uniq_map:
                uniq_idx = len(uniq_map)
                next_obs_path = os.path.join(policy_aug_dir, f"uniq_{uniq_idx}_next_obs.png")
                uniq_map[action_key] = {
                    "uniq_action_idx": uniq_idx,
                    "pyautogui_actions": actions,
                    "count": 0,
                    "member_r_idx": [],
                    "next_obs_path": next_obs_path,
                }

            uniq_map[action_key]["count"] += 1
            uniq_map[action_key]["member_r_idx"].append(r_idx)

            uniq_idx = int(uniq_map[action_key]["uniq_action_idx"])
            next_obs_path = str(uniq_map[action_key]["next_obs_path"])'''

            # 1) 旧 exact uniq -> real_unique_actions（仅备份/检查用；path 用 real_uniq_ 前缀）
            real_idx, _real_next_obs_path = _assign_real_exact(action_key, actions, r_idx)

            # 2) 新 uniq（默认 approx merge）-> unique_actions（path 用 uniq_ 前缀，给 VM 真跑）
            uniq_idx, next_obs_path = _assign_unique_group(action_key, actions, r_idx, real_idx)

            # 回填：real -> approx 归属（方便你后面 check）
            try:
                real_uniq_map[action_key]["approx_uniq_action_idx"] = int(uniq_idx)
            except Exception:
                pass

            # reward prompt patch
            patched_reward_messages = patch_reward_messages(
                base_reward_messages=base_reward_messages,
                old_response=old_response,
                new_response=resp,
                new_last_image_path=next_obs_path,
            )

            rollout_records.append({
                "r_idx": r_idx,
                "parse_ok": parse_ok,
                "response": resp,
                "low_level_instruction": low_level,
                "pyautogui_actions": actions,
                "uniq_action_idx": uniq_idx,
                "real_uniq_action_idx": real_idx,
                "other_info": other,
            })

            prompt_records.append({
                "step_index": step_index,
                "r_idx": r_idx,
                "uniq_action_idx": uniq_idx,
                "next_obs_path": next_obs_path,
                "real_uniq_action_idx": real_idx,
                "reward_messages": patched_reward_messages,
            })

        #unique_actions = sorted(uniq_map.values(), key=lambda x: int(x["uniq_action_idx"]))
        #unique_action_lists = [ua["pyautogui_actions"] for ua in unique_actions]

        # 新 unique_actions（approx merge 后）
        unique_actions = sorted(uniq_groups, key=lambda x: int(x["uniq_action_idx"]))
        unique_action_lists = [ua.get("pyautogui_actions") or [] for ua in unique_actions]

        # 旧 unique_actions（exact）备份到 real_unique_actions（供你后面 check）
        real_unique_actions = sorted(real_uniq_map.values(), key=lambda x: int(x["uniq_action_idx"]))
        real_unique_action_lists = [ua.get("pyautogui_actions") or [] for ua in real_unique_actions]


        runs_ctx[run_key]["policy_rollout_results"]["step_rollouts"].append({
            "step_index": step_index,
            "num_rollout": R,
            "policy_rollouts": rollout_records,          # 每个 rollout 一个 action list
            "unique_actions": unique_actions,            # 去重后的 action list + next_obs_path
            "unique_action_lists": unique_action_lists,  # “list of action list”
            "num_unique_actions": len(unique_actions),
            "real_unique_actions": real_unique_actions,              # ✅ 旧逻辑备份（exact）
            "real_unique_action_lists": real_unique_action_lists,    # ✅ 旧逻辑备份（exact）
            "num_real_unique_actions": len(real_unique_actions),
            "approx_uniq_enabled": bool(approx_enabled),
            "approx_uniq_tol": int(tol),
        })

        runs_ctx[run_key]["policy_rollout_reward_prompts"]["prompts"].append({
            "step_index": step_index,
            "num_rollout": R,
            "items": prompt_records,  # each rollout has its reward_messages
        })

    # 6) 写回到每个 run_dir（和 trajectory.json 同目录）
    for rk, ctx in runs_ctx.items():
        run_dir = ctx["run_dir"]
        os.makedirs(run_dir, exist_ok=True)

        out_policy = os.path.join(run_dir, args.policy_output_name)
        out_prompt = os.path.join(run_dir, args.reward_prompt_output_name)

        tmp1 = out_policy + ".tmp"
        tmp2 = out_prompt + ".tmp"
        with open(tmp1, "w", encoding="utf-8") as f:
            json.dump(ctx["policy_rollout_results"], f, ensure_ascii=False, indent=2)
        with open(tmp2, "w", encoding="utf-8") as f:
            json.dump(ctx["policy_rollout_reward_prompts"], f, ensure_ascii=False, indent=2)

        os.replace(tmp1, out_policy)
        os.replace(tmp2, out_prompt)

        print(f"[WRITE] {out_policy}", flush=True)
        print(f"[WRITE] {out_prompt}", flush=True)

    print("[DONE] all runs written.", flush=True)


# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
