# -*- coding: utf-8 -*-
"""
Distributed multi-turn runner:
- OpenAI SDK calls (no requests wrapper)
- timeout fixed to 1200s
- 1 process per IP
- each process keeps 16 concurrent trajectories (threads)
- a trajectory is pinned to the IP for all rounds (KV reuse on server side)
- if an IP becomes unavailable, the whole trajectory is re-queued and restarted on another IP (no partial save)
- output is a single JSONL file (append one line per finished trajectory, order not guaranteed)

关键修复：
1) 修复 JoinableQueue + requeue 导致 task_q.join() 永远等不到 unfinished_tasks 归零（死锁）
2) 不再依赖 None sentinel；stop_event 后 worker 会 drain 队列并 task_done
3) 默认优先 fork（Linux）避免 spawn 导致子进程重复模块级大加载；不支持 fork 则 fallback spawn
4) 大文件/缓存加载从模块级挪到进程内一次性初始化（每个进程只做一次）
5) print flush=True，避免输出被缓冲看起来“卡住”
"""

import copy
import sys
import os
import json
import re
import time
import queue
import threading
import traceback
import multiprocessing as mp
from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime

import yaml
from openai import OpenAI

# ====== PATH fix: make the project root importable ======
# ROOT is the parent of the directory that contains this file.
ROOT = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)
print("Added ROOT:", ROOT, flush=True)

# ====== 引入工具 ======
from Tools.General.General_tools import GeneralTool
from Tools.Attraction.Attraction_tools import AttractionTool
from Tools.Flight.Flight_tools import FlightTool
from Tools.Train.Train_tools import TrainTool
from Tools.Restaurant.Restaurant_tools import RestaurantTool
from Tools.Hotel.Hotel_tools import HotelTool

from Tools.tool_description import (
    Flight_tool_description,
    Train_tool_description,
    Hotel_tool_description,
    Attraction_tool_description,
    Restaurant_tool_description,
    General_tool_description,
)

from interact.prompt.agent_system_prompt import system_prompt_en,system_prompt_en_single_turn
from interact.prompt.user_prompt import user_prompt_easy_en, user_prompt_easy_en_no_issue
from interact.user_agent import UserSimulator
from interact.evaluation_agent import TripPlanEvaluator

from Evaluation.Attraction.val_attractions import AttractionEvaluator
from Evaluation.Restaurant.val_restaurants import RestaurantEvaluator
from Evaluation.Hotel.val_hotels import HotelEvaluator
from Evaluation.Transportation.val_transports import TransportationEvaluator
from Evaluation.General.val_general import GeneralEvaluator, load_or_build_cache


# ============================================================
#                    CONFIG / RUNNER CONFIG
# ============================================================

# NOTE(review): the original path literal on this line was truncated (an
# unterminated string, i.e. a SyntaxError on import). Fill in the real config
# path here, or supply it through the TRAVEL_CONFIG_PATH environment variable.
CONFIG_PATH = os.environ.get("TRAVEL_CONFIG_PATH", "")

# Model names: use whatever names the serving endpoint actually exposes.
AGENT_MODEL_NAME = "deepseek-v32"
USER_MODEL_NAME = "deepseek-v32"
TEMPERATURE = 1.0
MAX_TOKENS = 32768

# Multi-turn settings
DIALOGUE_TURNS = 15
KEEP_LAST_ASSISTANT = 3

# IP list (one worker process per IP)
IPS = [
    # shxs
    "33.32.14.88",
    "33.32.11.158",
    "33.18.255.113",
    "33.32.14.105",
    "33.32.22.48",
    "33.32.32.77",
    "33.18.242.10",
    "33.18.235.74",
]

# Concurrency: number of trajectories run at the same time against each IP.
REQUESTS_PER_IP = 16

# IP cool-down period (seconds).
COOLDOWN_SECONDS = 1 * 60

# Output JSONL: one line appended per finished trajectory; order not guaranteed.
# NOTE(review): the original path literal was truncated as well (unterminated
# string). Fill in the real path, or supply TRAVEL_OUTPUT_JSONL_PATH.
OUTPUT_JSONL_PATH = os.environ.get("TRAVEL_OUTPUT_JSONL_PATH", "")

# ============================================================
#   Per-process resource initialisation (avoids repeating the
#   heavy module-level loads under the spawn start method)
# ============================================================
GLOBAL_RESOURCES = None  # inherited by child processes after fork (copy-on-write)

def _load_resources_for_process(config_path: str) -> Dict[str, Any]:
    """
    Build the heavy per-process resources in one call.

    Reads the YAML config at ``config_path`` and constructs the attraction,
    restaurant, hotel and transportation evaluators from its ``data_path``
    section. It then builds the general evaluator on top of
    ``load_or_build_cache(config_path)``. Meant to be called inside each worker
    process, so the import stage stays cheap (e.g. under the spawn start method).

    Returns:
        dict with keys ``config``, ``attraction_evaluator``,
        ``restaurant_evaluator``, ``hotel_evaluator``,
        ``transportation_evaluator`` and ``general_evaluator``.
    """
    with open(config_path, "r", encoding="utf-8") as cfg_fh:
        config = yaml.safe_load(cfg_fh)

    # Resolve every expected data path up front so a missing key fails fast.
    # The "train" entry must exist but is not consumed by anything below.
    paths = config["data_path"]
    attraction_path, restaurant_path, hotel_path, flight_path, _ = (
        paths[name] for name in ("attraction", "restaurant", "hotel", "flight", "train")
    )

    with open(attraction_path, "r", encoding="utf-8") as attr_fh:
        attraction_records = json.load(attr_fh)

    resources: Dict[str, Any] = {"config": config}
    resources["attraction_evaluator"] = AttractionEvaluator(attraction_records)
    resources["restaurant_evaluator"] = RestaurantEvaluator(restaurants_path=restaurant_path)
    resources["hotel_evaluator"] = HotelEvaluator(hotel_path=hotel_path)
    resources["transportation_evaluator"] = TransportationEvaluator(flights_path=flight_path)
    # Building the data cache is usually the expensive step; done once per call.
    resources["general_evaluator"] = GeneralEvaluator(load_or_build_cache(config_path))
    return resources


# ============================================================
# 工具调用封装类（保持你原来的结构）
# ============================================================

class FunctionCall:
    """Attribute view over the ``function`` part of a tool-call dict."""

    def __init__(self, data: Dict[str, Any]):
        lookup = data.get
        # Tool name chosen by the model, or None when absent.
        self.name = lookup("name")
        # Raw arguments payload exactly as received (not parsed here).
        self.arguments = lookup("arguments")


class ToolCall:
    """Attribute view over one entry of an assistant message's ``tool_calls`` list."""

    def __init__(self, data: Dict[str, Any]):
        self.id = data.get("id")
        self.type = data.get("type")
        # A missing/None "function" entry becomes an empty FunctionCall.
        function_payload = data.get("function") or {}
        self.function = FunctionCall(function_payload)


class AssistantMessage:
    """Attribute view over an assistant message dict (as built from ``message.model_dump()`` in ``chat``)."""

    def __init__(self, data: Dict[str, Any]):
        self.role = data.get("role")
        self.content = data.get("content")
        # A missing/None tool_calls field is treated as "no tool calls".
        raw_calls = data.get("tool_calls") or []
        self.tool_calls = list(map(ToolCall, raw_calls))


def serialize_tool_calls(tool_calls):
    """
    Convert ToolCall-like objects back into plain, JSON-friendly dicts.

    Returns None when ``tool_calls`` is empty or None. Otherwise returns one
    ``{"id", "type", "function": {"name", "arguments"}}`` dict per call,
    preserving order.
    """
    if tool_calls:
        return [
            {
                "id": tc.id,
                "type": tc.type,
                "function": {
                    "name": tc.function.name,
                    "arguments": tc.function.arguments,
                },
            }
            for tc in tool_calls
        ]
    return None


# ============================================================
# 每轮用户状态记录
# ============================================================

@dataclass
class RoundUserState:
    """Per-round record of one multi-turn trajectory, built by ``TravelAgent.run_multiturn_trip``.

    Field order matters: it defines the positional constructor order and the
    key order produced by ``dataclasses.asdict``.
    """
    # 1-based round number.
    round_idx: int
    # Active instruction ids before this round's simulator proposals were applied.
    id_list_before: List[str]
    # Active instruction ids after this round (copy of the working list).
    id_list_after: List[str]
    blocks: Dict[str, str]                # prompt blocks keyed HISTORY/NEW/MODIFY/ISSUE
    # Prompt sent to the user simulator; None in round 1 (no simulator call).
    sim_prompt: Optional[str]
    # Output returned by the user simulator; None in round 1.
    sim_output: Optional[Dict[str, Any]]
    # Base ids added to the active list this round.
    newly_added_ids: List[str]
    # Raw (prefixed) instruction_ids proposed by the simulator when they were accepted; else [].
    all_added_ids: List[str]
    # Agent's final assistant text for this round.
    final_answer: str
    # Result dict returned by TripPlanEvaluator.evaluate for this round.
    eval_result: Dict[str, Any]
    # True when the simulator proposed no ids or only HIS_-prefixed ones; False in round 1.
    use_next_turn_reward: Optional[bool]
    # ISO timestamps of the agent's chat loop for this round.
    start_time: Optional[str]
    end_time: Optional[str]
    # Tool-call log entries recorded by TravelAgent.chat during this round.
    tool_calls: List[Any]


# ============================================================
#                  TravelAgent 主类（OpenAI SDK 版本）
# ============================================================

class TravelAgent:
    """
    Tool-calling travel-planning agent built on the OpenAI SDK.

    Intended to be created once per process. For concurrent trajectory threads,
    the OpenAI client and the user simulator are cached per thread in a
    ``threading.local`` (one client per thread). The tool objects, the
    tool-name mapping and the trip-plan evaluator are shared by every thread
    that uses the single instance.
    """

    def __init__(
        self,
        *,
        config_path: str,
        base_url: str,
        api_key: str = "EMPTY",
        timeout_sec: float = 1200.0,
        # Evaluators are injected (built once per process by the caller).
        attraction_evaluator: AttractionEvaluator,
        restaurant_evaluator: RestaurantEvaluator,
        hotel_evaluator: HotelEvaluator,
        transportation_evaluator: TransportationEvaluator,
        general_evaluator: GeneralEvaluator,
    ):
        """
        Build the tool backends, the tool-name mapping and the evaluator.

        Args:
            config_path: config file path handed to each tool constructor
                (``GeneralTool`` takes no arguments).
            base_url: OpenAI-compatible endpoint used by the per-thread clients.
            api_key: API key for the OpenAI client.
            timeout_sec: request timeout (seconds) for the OpenAI client.
            attraction_evaluator, restaurant_evaluator, hotel_evaluator,
            transportation_evaluator, general_evaluator: pre-built evaluators,
                combined into a single ``TripPlanEvaluator``.
        """
        self._base_url = base_url
        self._api_key = api_key
        self._timeout_sec = timeout_sec
        # Per-thread cache for the OpenAI client and the user simulator.
        self._tlocal = threading.local()

        # Tool backends
        self.general_tool = GeneralTool()
        self.attraction_tool = AttractionTool(config_path)
        self.flight_tool = FlightTool(config_path)
        self.train_tool = TrainTool(config_path)
        self.restaurant_tool = RestaurantTool(config_path)
        self.hotel_tool = HotelTool(config_path)

        # Tool name -> bound method. ``chat`` looks each call up here by the
        # function name the model returns; unknown names are reported back.
        self.tools = {
            "search_flights": self.flight_tool.search_flights,
            "get_flight_detail_with_products": self.flight_tool.get_flight_detail_with_products,
            "get_airport_coordinates": self.flight_tool.get_airport_coordinates,

            "search_trains": self.train_tool.search_trains,
            "get_train_detail_with_products": self.train_tool.get_train_detail_with_products,
            "get_station_coordinates": self.train_tool.get_station_coordinates,

            "search_hotels": self.hotel_tool.search_hotels,
            "get_hotel_detail_with_products": self.hotel_tool.get_hotel_detail_with_products,
            "get_hotel_coordinates": self.hotel_tool.get_hotel_coordinates,

            "search_attractions": self.attraction_tool.search_attractions,
            "get_attraction_detail_with_products": self.attraction_tool.get_attraction_detail_with_products,
            "get_attraction_coordinates": self.attraction_tool.get_attraction_coordinates,

            "search_restaurants": self.restaurant_tool.search_restaurants,
            "get_restaurant_detail_with_products": self.restaurant_tool.get_restaurant_detail_with_products,
            "get_restaurant_coordinates": self.restaurant_tool.get_restaurant_coordinates,

            "get_route_estimate": self.general_tool.get_route_estimate,
            "get_city_center_coords": self.general_tool.get_city_center_coords,
            "get_date_after": self.general_tool.get_date_after,
        }

        # Tool schemas sent as ``tools=`` on every completion request.
        self.tool_descriptions = (
            Flight_tool_description
            + Train_tool_description
            + Hotel_tool_description
            + Attraction_tool_description
            + Restaurant_tool_description
            + General_tool_description
        )

        # Scores the trip plan produced in each round.
        self.trip_plan_evaluator = TripPlanEvaluator(
            attraction_evaluator=attraction_evaluator,
            restaurant_evaluator=restaurant_evaluator,
            hotel_evaluator=hotel_evaluator,
            transportation_evaluator=transportation_evaluator,
            general_evaluator=general_evaluator,
        )

    def _get_client(self) -> OpenAI:
        """Return the calling thread's OpenAI client, creating it on first use."""
        if not hasattr(self._tlocal, "client"):
            self._tlocal.client = OpenAI(
                api_key=self._api_key,
                base_url=self._base_url,
                timeout=self._timeout_sec,
            )
        return self._tlocal.client

    def _get_user_simulator(self) -> UserSimulator:
        """Return the calling thread's UserSimulator (built on that thread's client), creating it lazily."""
        if not hasattr(self._tlocal, "user_sim"):
            self._tlocal.user_sim = UserSimulator(
                client=self._get_client(),
                model_name=USER_MODEL_NAME,
                temperature=TEMPERATURE,
                max_tokens=16384,
            )
        return self._tlocal.user_sim

    # ======================== Core chat loop ========================
    def chat(self, history: Optional[List[Dict[str, str]]] = None, max_rounds=1000):
        """
        Run the tool-calling loop until the model replies without tool calls.

        Each iteration sends the accumulated messages to ``AGENT_MODEL_NAME``.
        Every tool call in the reply is executed locally, and its result is
        appended as a ``role="tool"`` message. ``history`` is deep-copied and
        never mutated.

        Args:
            history: prior chat messages (system/user/assistant/tool dicts).
            max_rounds: maximum number of model requests before giving up.

        Returns:
            dict with ``start_time``/``end_time`` (ISO strings) and the
            following entries:
            - ``messages``: the full message list exchanged with the model.
            - ``trace_log_messages``: the history plus simplified
              assistant/tool entries.
            - ``tool_calls``: a per-call log of tool name, args and result.
            - ``final_answer``: the last assistant text, or
              ``"Too many rounds."`` when ``max_rounds`` is exhausted.

        Raises:
            Exceptions from the OpenAI client propagate. Tool exceptions are
            caught and reported back to the model as ``{"error": ...}``.
        """
        if history is None:
            history = []

        trace_log = {
            "start_time": datetime.now().isoformat(),
            "messages": [],
            "trace_log_messages": [],
            "tool_calls": [],
            "final_answer": None,
        }

        messages = copy.deepcopy(history)
        trace_log["trace_log_messages"].extend(messages)

        # Thread-local and cached, so one lookup serves every round.
        client = self._get_client()

        for _ in range(max_rounds):
            resp = client.chat.completions.create(
                model=AGENT_MODEL_NAME,
                messages=messages,
                tools=self.tool_descriptions,
                tool_choice="auto",
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS,
                extra_body={"chat_template_kwargs": {"thinking": False}},
            )
            # NOTE: `resp` can be very large; avoid printing it (stdout I/O stalls workers).

            assistant_msg_dict = resp.choices[0].message.model_dump()
            messages.append(assistant_msg_dict)

            assistant_msg = AssistantMessage(assistant_msg_dict)
            trace_log["trace_log_messages"].append({
                "role": "assistant",
                "content": assistant_msg.content,
                "tool_calls": serialize_tool_calls(assistant_msg.tool_calls),
            })

            tool_calls = assistant_msg.tool_calls

            # ---------- No tool calls: this is the final answer ----------
            if not tool_calls:
                final_answer = assistant_msg.content or ""
                trace_log["messages"] = messages
                trace_log["final_answer"] = final_answer
                trace_log["end_time"] = datetime.now().isoformat()
                return trace_log

            # ---------- Execute tool calls ----------
            for call in tool_calls:
                tool_name = call.function.name
                raw_args = call.function.arguments or "{}"

                # Malformed JSON arguments fall back to no kwargs; the tool call
                # then typically fails and that error is reported to the model.
                try:
                    args = json.loads(raw_args)
                except Exception:
                    args = {}

                tool_func = self.tools.get(tool_name)

                if tool_func is None:
                    result = {"error": f"Unknown tool {tool_name}"}
                    trace_log["tool_calls"].append({
                        "tool": tool_name,
                        "args": args,
                        "result": {"success": False, "error": result["error"]},
                    })
                else:
                    try:
                        output = tool_func(**args)
                        result = output
                        trace_log["tool_calls"].append({
                            "tool": tool_name,
                            "args": args,
                            "result": {"success": True, "data": output},
                        })
                    except Exception as e:
                        # Includes TypeError from non-dict/mismatched args.
                        err = str(e)
                        result = {"error": err}
                        trace_log["tool_calls"].append({
                            "tool": tool_name,
                            "args": args,
                            "result": {"success": False, "error": err},
                        })

                if isinstance(result, str):
                    content = result
                else:
                    content = json.dumps(result, ensure_ascii=False)

                tool_msg = {
                    "role": "tool",
                    "tool_call_id": call.id,
                    "name": tool_name,
                    "content": content,
                }

                messages.append(tool_msg)
                trace_log["trace_log_messages"].append(tool_msg)

        trace_log["messages"] = messages
        trace_log["final_answer"] = "Too many rounds."
        trace_log["end_time"] = datetime.now().isoformat()
        return trace_log

    # ============================================================
    #                  Multi-turn helpers (static methods)
    # ============================================================

    @staticmethod
    def _safe_format(desc: str, slot: str) -> str:
        """Fill ``{slot}`` in ``desc``; fall back to plain replacement if ``str.format`` fails (e.g. stray braces)."""
        try:
            return desc.format(slot=slot)
        except Exception:
            return (desc or "").replace("{slot}", slot)

    @staticmethod
    def _build_id2instruction(meta_dict: Dict[str, Any]) -> Dict[str, str]:
        """
        Map every rubric label ``_id`` to its instruction text.

        Walks ``meta_dict["rubric_results"][group][rubric]``. For each entry in
        its ``result.all_labels_and_ranges``, the rubric ``description``
        template is formatted with the label text.
        """
        id2instr = {}
        rubric_results = meta_dict.get("rubric_results", {}) or {}

        for _, rubric_group in rubric_results.items():
            for _, cfg in (rubric_group or {}).items():
                desc_tmpl = cfg.get("description", "") or ""
                result = cfg.get("result", {}) or {}
                all_labels = result.get("all_labels_and_ranges", {}) or {}

                for label_text, payload in all_labels.items():
                    _id = (payload or {}).get("_id")
                    if not _id:
                        continue
                    instr = TravelAgent._safe_format(desc_tmpl, label_text).strip()
                    id2instr[_id] = instr
        return id2instr

    @staticmethod
    def _flatten_applied_chains(meta_dict: Dict[str, Any]) -> List[str]:
        """Return the last id of every non-empty modification chain (the fully-modified constraint set)."""
        chains = meta_dict.get("applied_modification_chains") or {}
        out = []
        for _, chain in chains.items():
            if chain:
                out.append(chain[-1])
        return out

    @staticmethod
    def _build_blocks(meta_dict: Dict[str, Any], id_list: List[str], rubric_progress: Dict[str, int]) -> Dict[str, str]:
        """
        Build the HISTORY / NEW / MODIFY prompt blocks for the user simulator.

        - HISTORY: instructions already active (``id_list``), tagged ``HIS_<id>``.
        - NEW: the first element of every chain not yet started (progress 0),
          tagged ``NEW_<id>``.
        - MODIFY: the next element of every started but unfinished chain,
          tagged ``MOD_<id>``.
        Ids without instruction text are skipped.
        """
        id2instr = TravelAgent._build_id2instruction(meta_dict)

        # ===== HISTORY =====
        history_lines = []
        for cid in id_list:
            txt = id2instr.get(cid)
            if txt:
                history_lines.append(f"[ID: HIS_{cid}] {txt}")

        chains = meta_dict.get("applied_modification_chains", {}) or {}
        new_candidates = []
        modify_items = []  # (rubric_key, expected_id)

        for rubric_key in sorted(chains.keys()):
            chain = chains[rubric_key] or []
            if not chain:
                continue

            prog = rubric_progress.get(rubric_key, 0)

            if prog >= len(chain):
                continue

            if prog == 0:
                new_candidates.append(chain[0])
            else:
                modify_items.append((rubric_key, chain[prog]))

        # ===== NEW =====
        new_lines = []
        for cid in new_candidates:
            txt = id2instr.get(cid)
            if txt:
                new_lines.append(f"[ID: NEW_{cid}] {txt}")

        # ===== MODIFY =====
        modify_lines = []
        for rubric_key, expected_id in modify_items:
            chain = chains.get(rubric_key, [])
            prog = rubric_progress.get(rubric_key, 0)
            expected_txt = id2instr.get(expected_id)
            if not expected_txt:
                continue

            # expected_id == chain[prog], so this only holds when the chain
            # revisits an earlier id (e.g. A -> B -> A); the note then tells the
            # user that the intermediate requirement is dropped.
            if expected_id in chain[:prog]:
                removed_id = chain[prog - 1]
                removed_txt = id2instr.get(removed_id, "")
                line = (
                    f"[ID: MOD_{expected_id}] {expected_txt} "
                    f"(Note: the previous requirement {removed_txt} is no longer needed)"
                )
            else:
                line = f"[ID: MOD_{expected_id}] {expected_txt}"

            modify_lines.append(line)

        return {
            "HISTORY": "\n".join(history_lines).strip(),
            "NEW": "\n".join(new_lines).strip(),
            "MODIFY": "\n".join(modify_lines).strip(),
        }

    @staticmethod
    def _history_text_for_sim(
        conv_simple: List[Dict[str, str]],
        keep_last_assistant: int = KEEP_LAST_ASSISTANT,
    ) -> str:
        """
        Render the conversation as ``User: ...`` / ``Assistant: ...`` lines.

        Only the last ``keep_last_assistant`` assistant replies are kept
        verbatim; earlier ones are replaced by a fixed omission notice. A
        value <= 0 omits every assistant reply. Messages with other roles are
        skipped.
        """
        OMIT_TEXT = "Earlier assistant responses are omitted due to context length limits."
        assistant_indices = [i for i, m in enumerate(conv_simple) if m["role"] == "assistant"]
        # Guard: ``xs[-0:]`` is the WHOLE list, so without this check a request
        # to keep 0 replies would keep all of them.
        if keep_last_assistant > 0:
            keep_indices = set(assistant_indices[-keep_last_assistant:])
        else:
            keep_indices = set()

        lines = []
        for i, m in enumerate(conv_simple):
            if m["role"] == "user":
                lines.append(f"User: {m['content']}")
            elif m["role"] == "assistant":
                if i in keep_indices:
                    lines.append(f"Assistant: {m['content']}")
                else:
                    lines.append(f"Assistant: {OMIT_TEXT}")
        return "\n".join(lines).strip()

    def _apply_proposed_ids_with_prefix(
        self,
        meta_dict: Dict[str, Any],
        current_id_list: List[str],
        rubric_progress: Dict[str, int],
        proposed_ids: List[Any],
    ) -> Tuple[bool, List[str], List[str], List[str]]:
        """
        Apply the simulator's prefixed instruction ids to the working state.

        - ``NEW_<id>`` adds ``<id>`` to ``current_id_list``.
        - ``MOD_<id>`` removes the previous element of ``<id>``'s chain and
          adds ``<id>``.
        - Any other prefix (e.g. ``HIS_``) is ignored.

        In both NEW and MOD cases the owning rubric's progress becomes at
        least the id's chain position + 1. If an id occurs in several chain
        positions, the last occurrence wins.

        NOTE: mutates ``current_id_list`` and ``rubric_progress`` in place; the
        caller passes scratch copies and only commits them when ``ok`` is True.

        Returns:
            (ok, newly_added, removed, errors). ``ok`` is True iff ``errors``
            is empty.
        """
        chains = meta_dict.get("applied_modification_chains", {}) or {}
        id2pos: Dict[str, Tuple[str, int]] = {}
        for rubric_key, chain in chains.items():
            # A None chain is treated as empty, consistent with _build_blocks.
            for idx, cid in enumerate(chain or []):
                id2pos[cid] = (rubric_key, idx)

        errors: List[str] = []
        newly_added: List[str] = []
        removed: List[str] = []

        if not isinstance(proposed_ids, list):
            return False, [], [], [f"instruction_ids is not a list: {type(proposed_ids)}"]

        for raw in proposed_ids:
            if not isinstance(raw, str) or not raw:
                errors.append(f"invalid instruction_id (not str/empty): {raw!r}")
                continue

            if raw.startswith("NEW_"):
                op = "NEW"
                base_id = raw[len("NEW_"):]
            elif raw.startswith("MOD_"):
                op = "MOD"
                base_id = raw[len("MOD_"):]
            else:
                continue

            if base_id not in id2pos:
                errors.append(f"{raw} -> {base_id} not found in applied_modification_chains")
                continue

            rubric_key, idx = id2pos[base_id]
            chain = chains.get(rubric_key, []) or []

            if op == "NEW":
                if base_id not in current_id_list:
                    current_id_list.append(base_id)
                    newly_added.append(base_id)
                rubric_progress[rubric_key] = max(int(rubric_progress.get(rubric_key, 0) or 0), idx + 1)

            else:  # MOD
                if idx <= 0:
                    errors.append(f"{raw}: target is first element of chain, no previous to remove")
                    continue

                prev_id = chain[idx - 1]
                if prev_id not in current_id_list:
                    errors.append(f"{raw}: expected to remove prev_id={prev_id}, but it's not in current_id_list")
                    continue

                try:
                    current_id_list.remove(prev_id)
                    removed.append(prev_id)
                except ValueError:
                    errors.append(f"{raw}: failed to remove prev_id={prev_id} (ValueError)")
                    continue

                if base_id not in current_id_list:
                    current_id_list.append(base_id)
                    newly_added.append(base_id)

                rubric_progress[rubric_key] = max(int(rubric_progress.get(rubric_key, 0) or 0), idx + 1)

        ok = (len(errors) == 0)
        return ok, newly_added, removed, errors

    def rubric_all_satisfied(self, rubric_progress: dict, applied_chains: dict) -> bool:
        """Return True when every modification chain has been fully proposed (progress >= chain length).

        A None ``applied_chains`` or a None chain is treated as empty.
        """
        return all(
            rubric_progress.get(rubric, 0) >= len(chain or [])
            for rubric, chain in (applied_chains or {}).items()
        )

    def select_raw_answer_for_eval(self, current_raw_answer: str, conv_simple: List[Dict]) -> str:
        """
        Choose which answer text to evaluate.

        Returns ``current_raw_answer`` if it contains a non-empty
        ``trip_plan.daily_schedule`` list, either as plain JSON or inside a
        fenced json code block. Otherwise returns the most recent assistant
        message in ``conv_simple`` that does. If none does, returns
        ``current_raw_answer`` unchanged.
        """
        def extract_daily_schedule(raw_text: str) -> Optional[list]:
            # Non-text content (e.g. None) cannot contain a plan; also keeps
            # re.search below from raising TypeError.
            if not isinstance(raw_text, str):
                return None

            try:
                data = json.loads(raw_text)
                if isinstance(data, dict):
                    tp = data.get("trip_plan", {})
                    if isinstance(tp, dict):
                        return tp.get("daily_schedule")
            except Exception:
                pass

            match = re.search(r"```json\s*(\{.*?\})\s*```", raw_text, re.S)
            if match:
                try:
                    data = json.loads(match.group(1))
                    tp = data.get("trip_plan", {})
                    if isinstance(tp, dict):
                        return tp.get("daily_schedule")
                except Exception:
                    pass
            return None

        schedule = extract_daily_schedule(current_raw_answer)
        if isinstance(schedule, list) and len(schedule) > 0:
            return current_raw_answer

        for msg in reversed(conv_simple):
            if msg.get("role") != "assistant":
                continue
            raw_text = msg.get("content", "")
            schedule = extract_daily_schedule(raw_text)
            if isinstance(schedule, list) and len(schedule) > 0:
                return raw_text

        return current_raw_answer

    def run_multiturn_trip(
        self,
        meta_dict: Dict[str, Any],
        dialog_turns: int = DIALOGUE_TURNS,
        keep_last_assistant: int = KEEP_LAST_ASSISTANT,
        extra_save_dir: str = "__unused__",
    ) -> Dict[str, Any]:
        """
        Run one multi-turn trajectory for ``meta_dict``.

        Round 1 sends the basic query. When ``dialog_turns == 1``, it instead
        sends ``query_with_constraints`` with the last id of every
        modification chain. From round 2 on, the user simulator writes the
        next user message and proposes prefixed instruction ids. Accepted ids
        (up to 3 simulator attempts) update the active id list and the
        per-rubric progress.

        Each answer is scored by ``TripPlanEvaluator``. Its ``error_text`` is
        stored as the next round's ISSUE block; whether the simulator prompt
        template uses it depends on ``user_prompt_easy_en_no_issue``. The loop
        stops early once every modification chain has been fully proposed.

        A round that raises is rolled back to the pre-round snapshot. If at
        least one round already finished, a partial summary with
        ``terminate_reason`` is returned; otherwise the exception is re-raised.

        Args:
            meta_dict: trip metadata (query texts, instruction ids, rubric
                results and ``applied_modification_chains``). It is not mutated.
            dialog_turns: maximum number of rounds.
            keep_last_assistant: assistant replies kept verbatim in the
                simulator's history text (<= 0 keeps none).
            extra_save_dir: unused; kept for interface compatibility.

        Returns:
            summary dict with ``trip_id``, ``dialog_turns``,
            ``keep_last_assistant``, ``final_id_list``, ``messages``,
            ``rounds`` (``RoundUserState`` as dicts) and ``terminate_reason``
            (None on normal completion).
        """
        trip_id = meta_dict.get("trip_id", "unknown_trip")
        # A None value is treated like a missing key.
        rubric_progress = copy.deepcopy(meta_dict.get("rubric_progress") or {})

        conv_full: List[Dict[str, str]] = []
        conv_simple: List[Dict[str, str]] = []
        rounds: List[RoundUserState] = []

        if dialog_turns == 1:
            current_user_query = meta_dict.get("query_with_constraints", meta_dict.get("query_basic", ""))
            current_id_list = self._flatten_applied_chains(meta_dict)
        else:
            current_user_query = meta_dict.get("query_basic", "")
            # Copy so the working list never aliases meta_dict's own list.
            current_id_list = list(meta_dict.get("instruction_ids_basic") or [])

        issue_text = ""

        for r in range(1, dialog_turns + 1):
            # ====== Snapshot: a failed round is never persisted; roll back to the previous round ======
            snap_conv_full = copy.deepcopy(conv_full)
            snap_conv_simple = copy.deepcopy(conv_simple)
            snap_id_list = list(current_id_list)
            snap_progress = copy.deepcopy(rubric_progress)
            snap_issue_text = issue_text

            try:
                id_before = list(current_id_list)

                blocks = self._build_blocks(meta_dict, current_id_list, rubric_progress)
                blocks["ISSUE"] = issue_text or ""

                sim_prompt = None
                sim_output = None
                newly_added = []
                all_added = []
                use_next_turn_reward = False

                if r >= 2:
                    history_text = self._history_text_for_sim(conv_simple, keep_last_assistant)
                    max_retry = 3
                    retry = 0

                    while True:
                        sim_prompt = user_prompt_easy_en_no_issue \
                            .replace("{{HISTORY}}", blocks.get("HISTORY", "")) \
                            .replace("{{NEW}}", blocks.get("NEW", "")) \
                            .replace("{{MODIFY}}", blocks.get("MODIFY", "")) \
                            .replace("{{ISSUE}}", blocks.get("ISSUE", "")) \
                            .replace("{{HISTORY_MESSAGES}}", history_text)

                        user_sim = self._get_user_simulator()
                        # Retries use a lower temperature for a more conservative proposal.
                        if retry > 0:
                            sim_output = user_sim.generate(sim_prompt, temperature=0.7)
                        else:
                            sim_output = user_sim.generate(sim_prompt)

                        proposed_ids = sim_output.get("instruction_ids", []) or []
                        # Nothing new/modified proposed (only HIS_ ids, or none).
                        use_next_turn_reward = (
                            not proposed_ids or
                            all(isinstance(pid, str) and pid.startswith("HIS_") for pid in proposed_ids)
                        )

                        # Apply on scratch copies; commit only when every id is valid.
                        tmp_id_list = list(current_id_list)
                        tmp_progress = copy.deepcopy(rubric_progress)

                        ok, newly_added_tmp, _, errors = self._apply_proposed_ids_with_prefix(
                            meta_dict=meta_dict,
                            current_id_list=tmp_id_list,
                            rubric_progress=tmp_progress,
                            proposed_ids=proposed_ids,
                        )

                        if ok:
                            current_id_list = tmp_id_list
                            rubric_progress = tmp_progress
                            newly_added = newly_added_tmp
                            current_user_query = sim_output.get("user_query", "")
                            all_added = proposed_ids
                            break

                        retry += 1
                        if retry >= max_retry:
                            # Give up on the ids but still use the last user message.
                            newly_added = []
                            current_user_query = sim_output.get("user_query", "")
                            break

                if r == 1:
                    if dialog_turns == 1:
                        conv_full.append({"role": "system", "content": system_prompt_en_single_turn})
                    else:
                        conv_full.append({"role": "system", "content": system_prompt_en})

                conv_full.append({"role": "user", "content": current_user_query})
                conv_simple.append({"role": "user", "content": current_user_query})

                trace_log = self.chat(history=conv_full)
                conv_full = copy.deepcopy(trace_log["messages"])
                conv_simple.append({"role": "assistant", "content": trace_log["final_answer"]})

                raw_answer_for_eval = self.select_raw_answer_for_eval(
                    current_raw_answer=trace_log["final_answer"],
                    conv_simple=conv_simple
                )

                eval_result = self.trip_plan_evaluator.evaluate(
                    meta_dict=meta_dict,
                    raw_text=raw_answer_for_eval,
                    id_list=current_id_list
                )
                issue_text = eval_result.get("error_text", "") or ""

                round_state = RoundUserState(
                    round_idx=r,
                    id_list_before=id_before,
                    id_list_after=list(current_id_list),
                    blocks=blocks,
                    sim_prompt=sim_prompt,
                    sim_output=sim_output,
                    newly_added_ids=newly_added,
                    all_added_ids=all_added,
                    final_answer=trace_log["final_answer"],
                    tool_calls=trace_log["tool_calls"],
                    eval_result=eval_result,
                    use_next_turn_reward=use_next_turn_reward,
                    start_time=trace_log["start_time"],
                    end_time=trace_log["end_time"],
                )
                rounds.append(round_state)

                applied_chains = meta_dict.get("applied_modification_chains", {})
                if self.rubric_all_satisfied(rubric_progress, applied_chains):
                    print(f"All rubrics are proposed at round {r}", flush=True)
                    break

            except Exception as e:
                # ====== This round failed: roll back ======
                conv_full = snap_conv_full
                conv_simple = snap_conv_simple
                current_id_list = snap_id_list
                rubric_progress = snap_progress
                issue_text = snap_issue_text

                # ====== Key rule: with >= 1 finished round return a partial
                # summary; otherwise re-raise so the caller handles it as before ======
                if len(rounds) >= 1:
                    summary = {
                        "trip_id": trip_id,
                        "dialog_turns": dialog_turns,
                        "keep_last_assistant": keep_last_assistant,
                        "final_id_list": list(current_id_list),
                        "messages": conv_full,
                        "rounds": [asdict(x) for x in rounds],
                        "terminate_reason": {
                            "round_failed": r,
                            "error": str(e),
                        },
                    }
                    return summary
                else:
                    raise

        summary = {
            "trip_id": trip_id,
            "dialog_turns": dialog_turns,
            "keep_last_assistant": keep_last_assistant,
            "final_id_list": list(current_id_list),
            "messages": conv_full,
            "rounds": [asdict(x) for x in rounds],
            "terminate_reason": None,
        }
        return summary


# ============================================================
#                    DISTRIBUTED RUNNER
# ============================================================
def canonical_final_id_list(final_id_list):
    """
    Normalise an id list into an order-insensitive, hashable key.

    Ids are stringified, None entries are dropped, duplicates are removed and
    the result is sorted into a tuple. Anything that is not a ``list`` yields
    an empty tuple.
    """
    if isinstance(final_id_list, list):
        unique_ids = {str(item) for item in final_id_list if item is not None}
        return tuple(sorted(unique_ids))
    return ()

def make_sample_key(trip_id, final_id_list):
    """Return the dedup key ``(str(trip_id), canonical id tuple)`` for one sample."""
    trip_key = str(trip_id)
    return trip_key, canonical_final_id_list(final_id_list)

def final_id_list_from_meta(meta: dict):
    """Return the id list that identifies a test item (a trip meta dict).

    - Preferred: ``applied_modification_chains``. When it is a non-empty dict,
      take the last id of every chain that is a non-empty list and ignore the
      rest. The result may be empty, and the fallback below is then NOT used.
    - Fallback: ``meta["final_id_list"]`` when it is a list.
    - Otherwise: ``[]``.
    """
    chains = meta.get("applied_modification_chains")
    if isinstance(chains, dict) and chains:
        return [
            chain[-1]
            for chain in chains.values()
            if isinstance(chain, list) and chain
        ]

    fallback = meta.get("final_id_list")
    return fallback if isinstance(fallback, list) else []

def final_id_list_from_output_record(rec: dict):
    """Extract the id list used for dedup from one output JSONL record.

    The number of turns is read from ``rec["dialog_turns"]``, falling back to
    ``rec["summary"]["dialog_turns"]``. A missing or unparsable value counts
    as 0.

    Multi-turn (``dialog_turns > 1``) tries these in order:
      1. ``summary["rounds"][0]["id_list_before"]``, the id list before round 1.
         It does not drift when a run stops early, unlike ``final_id_list``.
      2. ``summary["instruction_ids_basic"]``
      3. ``summary["final_id_list"]``
      4. ``rec["final_id_list"]``

    Single-turn tries only 3 and 4. The first candidate that is a ``list`` is
    returned. If none qualifies, or ``rec`` is not a dict, the result is ``[]``.
    """
    if not isinstance(rec, dict):
        return []

    summary = rec.get("summary") or {}

    raw_turns = rec.get("dialog_turns")
    if raw_turns is None:
        raw_turns = summary.get("dialog_turns")
    try:
        turns = int(raw_turns or 0)
    except Exception:
        turns = 0

    candidates = []
    if turns > 1:
        rounds = summary.get("rounds")
        if isinstance(rounds, list) and rounds:
            first_round = rounds[0] or {}
            candidates.append(first_round.get("id_list_before"))
        candidates.append(summary.get("instruction_ids_basic"))
    candidates.append(summary.get("final_id_list"))
    candidates.append(rec.get("final_id_list"))

    for candidate in candidates:
        if isinstance(candidate, list):
            return candidate
    return []


def load_done_keys_from_output_jsonl(output_jsonl_path: str):
    """Scan the output JSONL and return the set of sample keys already run.

    Every line that parses as a JSON object counts as done, whether or not it
    carries eval results or finished all rounds. The key is
    ``make_sample_key(trip_id, id_list)``:
      - ``trip_id`` comes from the record, falling back to ``summary.trip_id``.
      - the id list comes from ``final_id_list_from_output_record``, which in
        multi-turn mode prefers ``rounds[0].id_list_before``.

    Blank lines, lines that are not valid JSON (e.g. a write truncated by a
    crash) and JSON values that are not objects are ignored.

    Args:
        output_jsonl_path: path of the results file. An empty path or a
            missing file yields an empty set.

    Returns:
        set of ``(str(trip_id), tuple_of_ids)`` keys.
    """
    done = set()
    if not output_jsonl_path or not os.path.exists(output_jsonl_path):
        return done

    with open(output_jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                continue
            if not isinstance(rec, dict):
                # Valid JSON but not a record (list/str/number). Without this
                # check `rec.get` raises AttributeError and aborts the scan.
                continue

            trip_id = rec.get("trip_id") or (rec.get("summary") or {}).get("trip_id")
            fil = final_id_list_from_output_record(rec)
            done.add(make_sample_key(trip_id, fil))

    return done



def get_resources_for_process(config_path: str):
    """Return the evaluator resources to use inside a worker process.

    - fork: the parent already filled ``GLOBAL_RESOURCES`` via
      ``preload_resources_in_parent``. The child inherits it and returns it
      without loading anything.
    - spawn: ``GLOBAL_RESOURCES`` is still ``None`` in the child, so a fresh
      copy is loaded from ``config_path``. This function does not store that
      copy back into ``GLOBAL_RESOURCES``.
    """
    inherited = GLOBAL_RESOURCES
    if inherited is None:
        return _load_resources_for_process(config_path)
    return inherited

def preload_resources_in_parent(config_path: str):
    """Load the shared resources once into ``GLOBAL_RESOURCES`` (parent only, before fork).

    Forked children inherit ``GLOBAL_RESOURCES`` and share its memory pages
    copy-on-write. A child that writes to these objects forces the touched
    pages to be copied, which raises its memory use. Nothing happens if the
    global is already populated.

    After loading, ``gc.freeze()`` is attempted (best effort; any failure is
    ignored). It moves the current objects out of the collector's generations
    so later collections do not rescan them.
    """
    global GLOBAL_RESOURCES
    if GLOBAL_RESOURCES is not None:
        return

    print("[runner] preload resources in parent ...", flush=True)
    GLOBAL_RESOURCES = _load_resources_for_process(config_path)
    print("[runner] preload resources done.", flush=True)

    try:
        import gc
        gc.freeze()
        print("[runner] gc.freeze() done.", flush=True)
    except Exception:
        pass  # optional optimization; failure is not an error


def _append_jsonl(file_lock: mp.Lock, output_jsonl_path: str, obj: dict):
    line = json.dumps(obj, ensure_ascii=False)
    with file_lock:
        with open(output_jsonl_path, "a", encoding="utf-8") as f:
            f.write(line + "\n")

def _ip_process_main(
    ip: str,
    task_q,                 # mp.JoinableQueue
    done_counter,           # mp.Value('i', 0)
    total_tasks: int,
    stop_event,             # mp.Event
    file_lock,              # mp.Lock
    output_jsonl_path: str,
    dialog_turns: int,
    keep_last_assistant: int,
):
    """Entry point of the worker process pinned to the model server at ``ip``.

    Builds one ``TravelAgent`` that talks to ``http://{ip}:8000/v1`` (timeout
    1200 s) and starts ``REQUESTS_PER_IP`` threads that share it. Each thread:
      - takes a trip meta dict from ``task_q``;
      - runs the whole trajectory with ``agent.run_multiturn_trip``;
      - appends the result as one JSON line to ``output_jsonl_path``.

    Args:
        ip: host of the model server this process is pinned to.
        task_q: shared joinable queue of pending trip meta dicts. Every
            successful ``get`` is matched by exactly one ``task_done``, so the
            parent's ``task_q.join()`` returns once all tasks are settled.
        done_counter: shared count of settled tasks (written or dropped).
        total_tasks: number of tasks the parent enqueued. When
            ``done_counter`` reaches it, ``stop_event`` is set.
        stop_event: shared stop flag. Once set, threads stop starting new
            trajectories, drain leftover items without running them, and exit.
        file_lock: lock that serializes appends to ``output_jsonl_path``.
        output_jsonl_path: JSONL file that finished trajectories are appended to.
        dialog_turns: forwarded to ``run_multiturn_trip``.
        keep_last_assistant: forwarded to ``run_multiturn_trip``.

    Error handling inside a thread:
      - An exception whose text contains "max_tokens must be at least 1" is
        non-retryable. The task is dropped (no record is written) but still
        counted in ``done_counter``.
      - Any other exception re-queues the item, bumping ``item["_retry"]``,
        unless ``stop_event`` is already set. It also makes every thread of
        this process wait ``COOLDOWN_SECONDS`` before taking its next item.

    NOTE(review): ``_retry`` is never compared against a limit. A task that
    always fails is retried forever, so ``total_tasks`` is never reached.
    NOTE(review): if ``TravelAgent(...)`` raises, this process exits without
    consuming any task. If that happens for every IP, nothing is left to
    drain ``task_q`` for the parent.
    """
    # Epoch-seconds deadline before which threads must not pull new work.
    # Shared by all worker threads of this process via `nonlocal`.
    cooldown_until = 0.0

    print(f"[{ip}] before get resources", flush=True)
    resources = get_resources_for_process(CONFIG_PATH)
    print(f"[{ip}] after get resources (shared if fork)", flush=True)

    agent = TravelAgent(
        config_path=CONFIG_PATH,
        base_url=f"http://{ip}:8000/v1",
        api_key="EMPTY",
        timeout_sec=1200.0,
        attraction_evaluator=resources["attraction_evaluator"],
        restaurant_evaluator=resources["restaurant_evaluator"],
        hotel_evaluator=resources["hotel_evaluator"],
        transportation_evaluator=resources["transportation_evaluator"],
        general_evaluator=resources["general_evaluator"],
    )

    def set_cooldown():
        """Push the process-wide cooldown deadline to now + COOLDOWN_SECONDS (never earlier)."""
        # NOTE(review): not lock-protected; concurrent callers race on
        # cooldown_until. This looks benign because all of them compute
        # roughly the same deadline.
        nonlocal cooldown_until
        cooldown_until = max(cooldown_until, time.time() + COOLDOWN_SECONDS)
        print(
            f"[{ip}] ERROR -> cooldown {COOLDOWN_SECONDS}s until "
            f"{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(cooldown_until))}",
            flush=True
        )

    def wait_cooldown():
        """Sleep in 1 s steps until the cooldown expires or stop_event is set."""
        while not stop_event.is_set() and time.time() < cooldown_until:
            time.sleep(1.0)

    def _is_max_tokens_bad_request(e: Exception) -> bool:
        """
        Match only this one class of error:
        Error code: 400 - {'error': {'message': 'max_tokens must be at least 1, got -793.' ...}}
        On a match the caller applies no cooldown, no retry, and drops the item.
        """
        try:
            s = str(e)
        except Exception:
            s = repr(e)

        # Most robust check: match the key phrase of the server message directly.
        if "max_tokens must be at least 1" in s:
            return True

        # Possible fallback: some servers report every 400/BadRequest the same
        # way. The scope is deliberately limited to the max_tokens case, so
        # nothing broader is matched here.
        return False

    def worker_loop(worker_id: int):
        """Thread body: process items until stop_event is set and the queue stays empty for 1 s.

        worker_id is currently unused.
        """
        while True:
            wait_cooldown()

            if stop_event.is_set():
                # Shutdown: discard leftover items without running them, but
                # still call task_done() so the parent's task_q.join() can
                # finish. Exit once the queue has been empty for 1 s.
                try:
                    item = task_q.get(timeout=1.0)
                except queue.Empty:
                    return
                task_q.task_done()
                continue

            # Timeout so the loop re-checks stop_event / cooldown regularly.
            try:
                item = task_q.get(timeout=1.0)
            except queue.Empty:
                continue

            try:
                summary = agent.run_multiturn_trip(
                    meta_dict=item,
                    dialog_turns=dialog_turns,
                    keep_last_assistant=keep_last_assistant,
                    extra_save_dir="__unused__",
                )

                # One JSON line per finished trajectory.
                record = {
                    "ip": ip,
                    "trip_id": summary.get("trip_id"),
                    "dialog_turns": summary.get("dialog_turns"),
                    "summary": summary,
                }
                _append_jsonl(file_lock, output_jsonl_path, record)

                # Count the task as settled. The thread that settles the last
                # task signals shutdown for every process.
                with done_counter.get_lock():
                    done_counter.value += 1
                    done = done_counter.value

                if done >= total_tasks:
                    stop_event.set()

            except Exception as e:
                # max_tokens 400 error -> no cooldown, no requeue; drop the task.
                if _is_max_tokens_bad_request(e):
                    print(f"[{ip}] NON-RETRYABLE(max_tokens) -> drop: {e}", flush=True)

                    # A drop still counts as settling a task; otherwise
                    # total_tasks could never be reached.
                    with done_counter.get_lock():
                        done_counter.value += 1
                        done = done_counter.value
                    if done >= total_tasks:
                        stop_event.set()

                    # Deliberately: no requeue, no cooldown.
                else:
                    # Any other exception: requeue + cooldown.
                    # put() runs before the finally-block task_done(), so the
                    # queue's unfinished-task count never reaches zero while
                    # this item is still pending.
                    if not stop_event.is_set():
                        item["_retry"] = int(item.get("_retry", 0) or 0) + 1
                        task_q.put(item)

                    # NOTE(review): when stop_event is already set the item is
                    # NOT re-queued, yet this message still says "requeue".
                    print(f"[{ip}] trajectory failed -> requeue: {e}", flush=True)
                    # print(traceback.format_exc(), flush=True)
                    set_cooldown()

            finally:
                # Exactly one task_done() per successful get(), on every path.
                task_q.task_done()

    threads = []
    for wid in range(REQUESTS_PER_IP):
        t = threading.Thread(target=worker_loop, args=(wid,), daemon=True)
        t.start()
        threads.append(t)

    # The process ends only after every worker thread has returned (shutdown drain).
    for t in threads:
        t.join()

    print(f"[{ip}] process exit", flush=True)


def _pick_mp_context():
    """
    Linux 优先 fork（避免 spawn 导致每个子进程重复 import / 大加载）
    不支持 fork 的环境自动 fallback 到 spawn
    """
    try:
        methods = mp.get_all_start_methods()
    except Exception:
        methods = ["spawn"]

    if "fork" in methods:
        return mp.get_context("fork")
    return mp.get_context("spawn")


def _filter_pending_items(items, dialog_turns, done_keys):
    """Drop items that are already done or duplicated, keeping first occurrences in order.

    A sample is identified by ``make_sample_key(trip_id, id_list)``. The id
    list is ``instruction_ids_basic`` in multi-turn mode (``dialog_turns > 1``)
    and ``final_id_list_from_meta(item)`` otherwise.

    Returns:
        ``(pending, skipped_done, skipped_dup)``:
          - the items still to run;
          - how many keys were already in ``done_keys``;
          - how many repeated an earlier key within ``items``.
    """
    pending = []
    seen_in_this_run = set()
    skipped_done = 0
    skipped_dup = 0

    for it in items:
        trip_id = it.get("trip_id")
        # Multi-turn keys on instruction_ids_basic so it lines up with the id
        # list read back from finished multi-turn records
        # (final_id_list_from_output_record). Single-turn keeps the
        # final_id_list_from_meta convention.
        if dialog_turns > 1:
            fil = it.get("instruction_ids_basic", [])
        else:
            fil = final_id_list_from_meta(it)

        sk = make_sample_key(trip_id, fil)
        if sk in done_keys:
            skipped_done += 1
        elif sk in seen_in_this_run:
            skipped_dup += 1
        else:
            seen_in_this_run.add(sk)
            pending.append(it)

    return pending, skipped_done, skipped_dup


def _wait_for_queue_or_dead_workers(task_q, procs, poll_sec=5.0):
    """Block until ``task_q.join()`` returns or every process in ``procs`` has exited.

    A joinable queue has no ``join(timeout)``, so the join runs in a daemon
    thread while this function polls process liveness.

    Returns:
        True if all queued tasks were settled; False if every worker process
        exited (or none existed) while tasks were still outstanding.
    """
    joiner = threading.Thread(target=task_q.join, daemon=True)
    joiner.start()
    while True:
        joiner.join(timeout=poll_sec)
        if not joiner.is_alive():
            return True
        if not any(p.is_alive() for p in procs):
            # Grace period: the last worker may have called task_done() just
            # before exiting, and the joiner may not have woken up yet.
            joiner.join(timeout=poll_sec)
            return not joiner.is_alive()


def run_multiturn_file_distributed(
    test_path: str,
    dialog_turns: int,
    keep_last_assistant: int,
    output_jsonl_path: str,
):
    """Run every not-yet-done item of ``test_path`` across ``IPS``; append results to ``output_jsonl_path``.

    Resumable:
      - the output file is never truncated;
      - samples whose key is already recorded there are skipped;
      - duplicates inside ``test_path`` run only once.

    One worker process per IP (``_ip_process_main``) consumes a shared
    JoinableQueue. Under ``fork`` the resources are loaded once in the parent
    and inherited by the children.

    Args:
        test_path: JSON file containing a list of trip meta dicts.
        dialog_turns: turns per trajectory (> 1 means multi-turn).
        keep_last_assistant: forwarded to the workers.
        output_jsonl_path: JSONL results file. Its directory is created if
            the path has one.

    If every worker process exits while tasks are still outstanding (e.g. a
    start-up crash), the runner stops waiting instead of hanging forever. The
    final log then reports ``finished`` below ``total``.
    """
    ctx = _pick_mp_context()
    print(f"[runner] mp start method = {ctx.get_start_method()}", flush=True)

    if ctx.get_start_method() == "fork":
        preload_resources_in_parent(CONFIG_PATH)

    with open(test_path, "r", encoding="utf-8") as f:
        items = json.load(f)

    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the output path actually has one (e.g. not for "out.jsonl").
    out_dir = os.path.dirname(output_jsonl_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # Never overwrite: create an empty file only if it does not exist yet, so
    # results from earlier runs are kept and can be skipped below.
    if not os.path.exists(output_jsonl_path):
        with open(output_jsonl_path, "w", encoding="utf-8"):
            pass

    done_keys = load_done_keys_from_output_jsonl(output_jsonl_path)
    print(f"[runner] already_done={len(done_keys)} (from {output_jsonl_path})", flush=True)

    filtered, skipped_done, skipped_dup = _filter_pending_items(items, dialog_turns, done_keys)

    total = len(filtered)
    print(
        f"[runner] loaded={len(items)} queued={total} "
        f"skipped_done={skipped_done} skipped_dup_in_test={skipped_dup}",
        flush=True
    )

    if total == 0:
        print(f"[Done] nothing to run. output={output_jsonl_path}", flush=True)
        return

    task_q = ctx.JoinableQueue()
    done_counter = ctx.Value("i", 0)
    stop_event = ctx.Event()
    file_lock = ctx.Lock()

    procs = []
    for ip in IPS:
        p = ctx.Process(
            target=_ip_process_main,
            args=(
                ip,
                task_q,
                done_counter,
                total,          # the filtered count, not len(items)
                stop_event,
                file_lock,
                output_jsonl_path,
                dialog_turns,
                keep_last_assistant,
            ),
            daemon=False,
        )
        p.start()
        procs.append(p)

    for it in filtered:
        task_q.put(it)

    if not _wait_for_queue_or_dead_workers(task_q, procs):
        print(
            "[runner] all worker processes exited with tasks still pending; stop waiting.",
            flush=True,
        )
        # Nobody is left to read the queue. Don't let interpreter exit block
        # while flushing the buffered items into the pipe.
        task_q.cancel_join_thread()

    stop_event.set()
    for p in procs:
        p.join()

    print(f"[Done] total={total}, finished={done_counter.value}, output={output_jsonl_path}", flush=True)


# ============================================================
#                         MAIN
# ============================================================

if __name__ == "__main__":
    # Paths come from the command line. The previous hard-coded assignment had
    # collapsed into a single string literal, which left OUTPUT_JSONL_PATH
    # undefined.
    import argparse

    parser = argparse.ArgumentParser(description="Distributed multi-turn runner")
    parser.add_argument("test_path", help="input JSON file (a list of trip meta dicts)")
    parser.add_argument(
        "output_jsonl_path",
        help="output JSONL file (appended to, never truncated)",
    )
    parser.add_argument(
        "--dialog-turns",
        type=int,
        default=10,
        help="dialog turns per trajectory; 1 = single-turn (default: 10)",
    )
    parser.add_argument(
        "--keep-last-assistant",
        type=int,
        default=KEEP_LAST_ASSISTANT,
        help="forwarded to the agent (default: KEEP_LAST_ASSISTANT)",
    )
    args = parser.parse_args()

    run_multiturn_file_distributed(
        test_path=args.test_path,
        dialog_turns=args.dialog_turns,
        keep_last_assistant=args.keep_last_assistant,
        output_jsonl_path=args.output_jsonl_path,
    )
    