import copy
import sys
import os
import json
import re
import time
import queue
import threading
import traceback
import multiprocessing as mp
from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime

import yaml
from openai import OpenAI

# ====== PATH fix ======
# ROOT is the parent of the directory containing this file. It is put at the
# front of sys.path (presumably so the project-local `Tools`, `interact` and
# `Evaluation` imports below resolve when this file is run as a script).
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)
# Printed on every import, whether or not ROOT was already on sys.path.
print("Added ROOT:", ROOT, flush=True)

# ====== 引入工具 ======
from Tools.General.General_tools import GeneralTool
from Tools.Attraction.Attraction_tools import AttractionTool
from Tools.Flight.Flight_tools import FlightTool
from Tools.Train.Train_tools import TrainTool
from Tools.Restaurant.Restaurant_tools import RestaurantTool
from Tools.Hotel.Hotel_tools import HotelTool

from Tools.tool_description import (
    Flight_tool_description,
    Train_tool_description,
    Hotel_tool_description,
    Attraction_tool_description,
    Restaurant_tool_description,
    General_tool_description,
)

from interact.prompt.agent_system_prompt import system_prompt_en, system_prompt_en_single_turn
from interact.prompt.user_prompt import user_prompt_easy_en, user_prompt_mid_en, user_prompt_easy_en_no_issue
from interact.user_agent import UserSimulator
from interact.evaluation_agent import TripPlanEvaluator

from Evaluation.Attraction.val_attractions import AttractionEvaluator
from Evaluation.Restaurant.val_restaurants import RestaurantEvaluator
from Evaluation.Hotel.val_hotels import HotelEvaluator
from Evaluation.Transportation.val_transports import TransportationEvaluator
from Evaluation.General.val_general import GeneralEvaluator, load_or_build_cache


# ============================================================
#                    CONFIG / RUNNER CONFIG
# ============================================================

# NOTE(review): the string literal on the next line is never closed in this
# copy of the file (the path value and closing quote appear to have been lost),
# so the module cannot be parsed as-is. TODO: restore the real config path.
CONFIG_PATH = " =========================
# ✅ Independent Agent / User configuration
# =========================

@dataclass
class AgentLLMConfig:
    """Connection and sampling settings for the agent-side LLM.

    ``TravelAgent`` copies ``base_url``, ``api_key`` and ``timeout_sec`` when
    it builds the agent OpenAI client, and reads the remaining fields when it
    issues chat-completion requests in ``TravelAgent.chat``.
    """
    model_name: str = "qwen2.5-32b-instruct"
    temperature: float = 0.7
    max_tokens: int = 32 * 1024
    tool_choice: str = "auto"
    # Only consulted by TravelAgent.chat when model_name == "deepseek-v32".
    thinking: bool = False
    extra_body: Optional[Dict[str, Any]] = None  # extra fields passed through to the server
    timeout_sec: float = 1200.0
    base_url: Optional[str] = None
    # NOTE(review): hard-coded credential default; presumably a placeholder for
    # an internal endpoint — confirm, and prefer loading it from the environment.
    api_key: str = "1737787093780320300"


@dataclass
class UserLLMConfig:
    """Connection and sampling settings for the user-simulator LLM.

    ``TravelAgent`` copies ``base_url``, ``api_key`` and ``timeout_sec`` when
    it builds the user OpenAI client, and passes ``model_name``,
    ``temperature`` and ``max_tokens`` to ``UserSimulator``.
    """
    model_name: str = "deepseek-v32"
    temperature: float = 0.7
    max_tokens: int = 32 * 1024
    # Temperature used by run_multiturn_trip when re-generating a user turn
    # after the previous attempt's instruction ids were rejected.
    retry_temperature: float = 0.7
    # NOTE(review): not read anywhere in the visible part of the file.
    extra_body: Optional[Dict[str, Any]] = None
    timeout_sec: float = 120.0
    base_url: Optional[str] = None
    # NOTE(review): hard-coded credential default; presumably a placeholder for
    # an internal endpoint — confirm, and prefer loading it from the environment.
    api_key: str = "1737787093780320300"



# Hosts for the user-simulator model; each one is expanded below into a
# base URL of the form http://<ip>:8000/v1.
USER_IP_POOL = [ 
    "33.253.151.48",
    "33.253.71.98",
]

DEFAULT_USER_URLS = [f"http://{ip}:8000/v1" for ip in USER_IP_POOL]

# USER_CFG.base_url still holds the first URL; the actual failover happens
# inside TravelAgent, which switches between URLs.
USER_CFG = UserLLMConfig(base_url=DEFAULT_USER_URLS[0])



# NOTE(review): usage is not visible in this part of the file; presumably a
# per-host concurrency or request budget — confirm against the runner code.
REQUESTS_PER_IP = 12

# Agent-side model name and host; the base URL uses the same :8000/v1 layout.
AGENT_MODEL_NAME = "Toucan-Qwen2.5-32B-Instruct-v0.1"
AGENT_IP = "33.253.151.204"
DEFAULT_AGENT_URL = f"http://{AGENT_IP}:8000/v1"

AGENT_CFG = AgentLLMConfig(model_name=AGENT_MODEL_NAME, base_url=DEFAULT_AGENT_URL)


# "LongCat-Flash-Chat"
# "LongCat-Flash-Thinking"
# "qwen2.5-32b-instruct"
# "Qwen2.5-14B-Instruct"
# "Toucan-Qwen2.5-14B-Instruct-v0.1"
# "Toucan-Qwen2.5-32B-Instruct-v0.1"


# IP cooldown
# NOTE(review): usage is not visible in this part of the file; the unit is
# presumably seconds, per the name.
COOLDOWN_SECONDS = 1 * 10

# Multi-turn
# NOTE(review): presumably the module-level counterpart of the
# `keep_last_assistant` parameters used by TravelAgent — confirm at call sites.
KEEP_LAST_ASSISTANT = 3

# NOTE(review): the string literal on the next line is never closed in this
# copy of the file (the directory value and closing quote appear to have been
# lost), so the module cannot be parsed as-is. TODO: restore the real path.
OUTPUT_DIR = " ============================================================
#           In-process resource initialization (avoid repeated heavy module-level loading under spawn)
# ============================================================
GLOBAL_RESOURCES = None  # inherited by child processes after fork (CoW)

def _load_resources_for_process(config_path: str) -> Dict[str, Any]:
    """Load the config, data and evaluators needed by one worker process.

    Intended to be called once in each child process, so that heavy objects
    are not built at import time (or rebuilt repeatedly) under the ``spawn``
    start method.

    Args:
        config_path: Path to the YAML config. Its ``data_path`` mapping must
            contain the keys ``attraction``, ``restaurant``, ``hotel``,
            ``flight`` and ``train``.

    Returns:
        A dict with the parsed ``config`` and the five evaluator instances
        (``attraction_evaluator``, ``restaurant_evaluator``,
        ``hotel_evaluator``, ``transportation_evaluator``,
        ``general_evaluator``).

    Raises:
        KeyError: If ``data_path`` or one of its required keys is missing.
        OSError: If the config or attraction file cannot be opened.
    """
    with open(config_path, "r", encoding="utf-8") as cfg_file:
        config = yaml.safe_load(cfg_file)

    paths = config["data_path"]
    # All five keys are looked up eagerly so a missing entry fails here.
    # NOTE(review): the train path is read but not passed to any evaluator.
    attraction_path, restaurant_path, hotel_path, flight_path, train_path = (
        paths[key] for key in ("attraction", "restaurant", "hotel", "flight", "train")
    )

    resources: Dict[str, Any] = {"config": config}

    # Evaluators
    with open(attraction_path, "r", encoding="utf-8") as attraction_file:
        resources["attraction_evaluator"] = AttractionEvaluator(json.load(attraction_file))

    resources["restaurant_evaluator"] = RestaurantEvaluator(restaurants_path=restaurant_path)
    resources["hotel_evaluator"] = HotelEvaluator(hotel_path=hotel_path)
    resources["transportation_evaluator"] = TransportationEvaluator(flights_path=flight_path)

    # The cache is usually the heavy part: built inside the process, once.
    resources["general_evaluator"] = GeneralEvaluator(load_or_build_cache(config_path))

    return resources


# ============================================================
# 工具调用封装类（保持你原来的结构）
# ============================================================

class FunctionCall:
    """Attribute-style view of the ``function`` part of a tool call dict.

    Exposes ``name`` and ``arguments``; either is ``None`` when the key is
    absent from ``data``.
    """

    def __init__(self, data: Dict[str, Any]):
        for field in ("name", "arguments"):
            setattr(self, field, data.get(field))


class ToolCall:
    """Attribute-style view of one tool call dict.

    Exposes ``id``, ``type`` and ``function`` (a ``FunctionCall``). A missing
    or falsy ``function`` entry is treated as an empty dict.
    """

    def __init__(self, data: Dict[str, Any]):
        call_id = data.get("id")
        call_type = data.get("type")
        function_payload = data.get("function") or {}

        self.id = call_id
        self.type = call_type
        self.function = FunctionCall(function_payload)


class AssistantMessage:
    def __init__(self, data: Dict[str, Any]):
        self.role = data.get("role")
        self.content = data.get("content")
        self.tool_calls = [ToolCall(tc) for tc in (data.get("tool_calls") or [])]


def serialize_tool_calls(tool_calls):
    """Convert tool-call objects back into plain dicts.

    Args:
        tool_calls: Iterable of objects exposing ``id``, ``type`` and
            ``function.name`` / ``function.arguments``.

    Returns:
        A list of dicts, one per call, or ``None`` when ``tool_calls`` is
        empty or ``None``.
    """
    if not tool_calls:
        return None
    return [
        {
            "id": tc.id,
            "type": tc.type,
            "function": {
                "name": tc.function.name,
                "arguments": tc.function.arguments,
            },
        }
        for tc in tool_calls
    ]


# ============================================================
# 每轮用户状态记录
# ============================================================

@dataclass
class RoundUserState:
    """Record of one dialogue round, collected in ``run_multiturn_trip``.

    NOTE(review): the code that fills these fields lies outside the visible
    part of the file; the per-field notes below are inferred from the local
    variables of ``run_multiturn_trip`` and should be confirmed there.
    """
    round_idx: int
    id_list_before: List[str]             # presumably the active instruction ids at round start
    id_list_after: List[str]              # presumably the active instruction ids at round end
    blocks: Dict[str, str]                # HISTORY/NEW/MODIFY/ISSUE
    sim_prompt: Optional[str]             # presumably the prompt sent to the user simulator
    sim_output: Optional[Dict[str, Any]]  # presumably the user simulator's parsed output
    newly_added_ids: List[str]
    all_added_ids: List[str]
    final_answer: str
    eval_result: Dict[str, Any]
    use_next_turn_reward: Optional[bool]
    start_time: Optional[str]
    end_time: Optional[str]
    tool_calls: List[Any]


# ============================================================
#                  TravelAgent 主类（OpenAI SDK 版本）
# ============================================================

class TravelAgent:
    """
    每个进程只初始化一个 TravelAgent。
    为了线程安全与并发，OpenAI client 使用 thread-local（每线程一个 client）。
    """
    def __init__(
        self,
        *,
        config_path: str,
        agent_cfg: AgentLLMConfig,
        user_cfg: UserLLMConfig,
        user_prompt_template: str = user_prompt_easy_en,
        timeout_sec: float = 1200.0,  # kept for backward compatibility; never read
        # Fallback pool of base URLs for the user-side endpoint.
        user_base_urls: Optional[List[str]] = None,
        # Evaluators are injected (built once per process by the caller).
        attraction_evaluator: AttractionEvaluator,
        restaurant_evaluator: RestaurantEvaluator,
        hotel_evaluator: HotelEvaluator,
        transportation_evaluator: TransportationEvaluator,
        general_evaluator: GeneralEvaluator,
    ):
        """Set up endpoints, tools, tool descriptions and the plan evaluator.

        Args:
            config_path: Config file path handed to every data-backed tool.
            agent_cfg: Settings for the agent LLM (endpoint and sampling).
            user_cfg: Settings for the user-simulator LLM.
            user_prompt_template: Template used to prompt the user simulator.
            timeout_sec: Unused; the per-endpoint timeouts come from
                ``agent_cfg.timeout_sec`` / ``user_cfg.timeout_sec``.
            user_base_urls: Optional pool of user-side base URLs used for
                rotation. Falsy entries and duplicates are dropped; when
                omitted, the pool is just ``user_cfg.base_url``.
            attraction_evaluator, restaurant_evaluator, hotel_evaluator,
            transportation_evaluator, general_evaluator: Evaluators forwarded
                to ``TripPlanEvaluator``.
        """
        self.agent_cfg = agent_cfg
        self.user_cfg = user_cfg
        self.user_prompt_template = user_prompt_template

        # Agent and user endpoints keep independent base_url / api_key / timeout.
        self._agent_base_url = agent_cfg.base_url
        self._agent_api_key = agent_cfg.api_key
        self._agent_timeout_sec = agent_cfg.timeout_sec

        self._user_base_url = user_cfg.base_url
        self._user_api_key = user_cfg.api_key
        self._user_timeout_sec = user_cfg.timeout_sec

        # Per-thread storage for the OpenAI clients and the user simulator.
        self._tlocal = threading.local()

        # ---------------------------
        # User base_url pool (primary + fallbacks), order-preserving dedupe.
        # ---------------------------
        primary_url = self._user_base_url
        candidates = user_base_urls or ([primary_url] if primary_url else [])
        pool = list(dict.fromkeys(url for url in candidates if url))
        self._user_base_urls = pool if pool else [primary_url]

        # Start rotation from the primary URL when it is in the pool, else from 0.
        # NOTE(review): when the primary URL is not in the pool, the index is 0
        # while self._user_base_url still holds the primary URL.
        self._user_url_idx = (
            self._user_base_urls.index(primary_url)
            if primary_url in self._user_base_urls
            else 0
        )

        self._user_url_lock = threading.Lock()

        # Tool instances
        self.general_tool = GeneralTool()
        self.attraction_tool = AttractionTool(config_path)
        self.flight_tool = FlightTool(config_path)
        self.train_tool = TrainTool(config_path)
        self.restaurant_tool = RestaurantTool(config_path)
        self.hotel_tool = HotelTool(config_path)

        # Tool name -> bound method. Every exposed name equals the method name.
        tool_table = (
            (self.flight_tool, (
                "search_flights",
                "get_flight_detail_with_products",
                "get_airport_coordinates",
            )),
            (self.train_tool, (
                "search_trains",
                "get_train_detail_with_products",
                "get_station_coordinates",
            )),
            (self.hotel_tool, (
                "search_hotels",
                "get_hotel_detail_with_products",
                "get_hotel_coordinates",
            )),
            (self.attraction_tool, (
                "search_attractions",
                "get_attraction_detail_with_products",
                "get_attraction_coordinates",
            )),
            (self.restaurant_tool, (
                "search_restaurants",
                "get_restaurant_detail_with_products",
                "get_restaurant_coordinates",
            )),
            (self.general_tool, (
                "get_route_estimate",
                "get_city_center_coords",
                "get_date_after",
            )),
        )
        self.tools = {}
        for tool, method_names in tool_table:
            for method_name in method_names:
                self.tools[method_name] = getattr(tool, method_name)

        # Tool descriptions sent with every agent request.
        self.tool_descriptions = (
            Flight_tool_description
            + Train_tool_description
            + Hotel_tool_description
            + Attraction_tool_description
            + Restaurant_tool_description
            + General_tool_description
        )

        # Evaluator
        self.trip_plan_evaluator = TripPlanEvaluator(
            attraction_evaluator=attraction_evaluator,
            restaurant_evaluator=restaurant_evaluator,
            hotel_evaluator=hotel_evaluator,
            transportation_evaluator=transportation_evaluator,
            general_evaluator=general_evaluator,
        )


    def _get_agent_client(self) -> OpenAI:
        if not hasattr(self._tlocal, "agent_client"):
            self._tlocal.agent_client = OpenAI(
                api_key=self._agent_api_key,
                base_url=self._agent_base_url,
                timeout=self._agent_timeout_sec,
            )
        return self._tlocal.agent_client

    def _get_user_client(self) -> OpenAI:
        if not hasattr(self._tlocal, "user_client"):
            self._tlocal.user_client = OpenAI(
                api_key=self._user_api_key,
                base_url=self._user_base_url,
                timeout=self._user_timeout_sec,
            )
        return self._tlocal.user_client

    def _get_user_simulator(self) -> UserSimulator:
        if not hasattr(self._tlocal, "user_sim"):
            self._tlocal.user_sim = UserSimulator(
                client=self._get_user_client(),
                model_name=self.user_cfg.model_name,
                temperature=self.user_cfg.temperature,
                max_tokens=self.user_cfg.max_tokens,
            )
        return self._tlocal.user_sim

    def _invalidate_user_thread_local(self):
        """只清掉当前线程的 user client/sim，让它们按新 base_url 重建。"""
        if hasattr(self._tlocal, "user_client"):
            delattr(self._tlocal, "user_client")
        if hasattr(self._tlocal, "user_sim"):
            delattr(self._tlocal, "user_sim")
    
    def rotate_user_base_url(self, reason: str = "") -> str:
        """
        切到后备 base_url（轮询），并让当前线程下次请求重建 client/sim。
        返回新的 base_url。
        """
        with self._user_url_lock:
            if not self._user_base_urls:
                return self._user_base_url

            old = self._user_base_url
            self._user_url_idx = (self._user_url_idx + 1) % len(self._user_base_urls)
            self._user_base_url = self._user_base_urls[self._user_url_idx]

        self._invalidate_user_thread_local()
        print(f"[user-url] rotate: {old} -> {self._user_base_url}. reason={reason}", flush=True)
        return self._user_base_url


    # ======================== 核心 Chat Loop ========================
    def chat(self, history: Optional[List[Dict[str, str]]] = None, max_rounds=1000):
        """Run the agent's tool-calling loop until it produces a final answer.

        Each round sends the conversation to the agent model. If the reply has
        tool calls, every call is executed and its result appended as a
        ``tool`` message; a reply without tool calls ends the loop.

        Args:
            history: Initial messages. The list is deep-copied, so the
                caller's objects are never mutated. ``None`` means empty.
            max_rounds: Maximum number of model requests before giving up.

        Returns:
            A trace dict with ``start_time``, ``end_time``, ``messages`` (the
            full message list sent to the model), ``trace_log_messages``
            (a condensed log), ``tool_calls`` (one record per executed call)
            and ``final_answer`` (``"Too many rounds."`` when the round limit
            is hit).
        """
        if history is None:
            history = []

        trace_log = {
            "start_time": datetime.now().isoformat(),
            "messages": [],
            "trace_log_messages": [],
            "tool_calls": [],
            "final_answer": None,
        }

        messages = copy.deepcopy(history)
        trace_log["trace_log_messages"].extend(messages)

        cfg = self.agent_cfg
        request_kwargs = {
            "model": cfg.model_name,
            "tools": self.tool_descriptions,
            "tool_choice": cfg.tool_choice,
            "temperature": cfg.temperature,
            "max_tokens": cfg.max_tokens,
        }

        # The configured extra_body is forwarded for every model; previously
        # it was silently dropped unless the model was "deepseek-v32".
        # Entries in cfg.extra_body override the thinking default below.
        extra_body: Dict[str, Any] = {}
        if cfg.model_name == "deepseek-v32":
            extra_body["chat_template_kwargs"] = {"thinking": bool(cfg.thinking)}
        if isinstance(cfg.extra_body, dict):
            extra_body.update(cfg.extra_body)
        if extra_body:
            request_kwargs["extra_body"] = extra_body

        for _ in range(max_rounds):
            client = self._get_agent_client()
            # `messages` is the same list object each round; it grows in place.
            resp = client.chat.completions.create(messages=messages, **request_kwargs)

            assistant_msg_dict = resp.choices[0].message.model_dump()
            messages.append(assistant_msg_dict)

            assistant_msg = AssistantMessage(assistant_msg_dict)
            trace_log["trace_log_messages"].append({
                "role": "assistant",
                "content": assistant_msg.content,
                "tool_calls": serialize_tool_calls(assistant_msg.tool_calls),
            })

            tool_calls = assistant_msg.tool_calls

            # ---------- No tool calls = final answer ----------
            if not tool_calls:
                trace_log["messages"] = messages
                trace_log["final_answer"] = assistant_msg.content or ""
                trace_log["end_time"] = datetime.now().isoformat()
                return trace_log

            # ---------- Tool calls ----------
            for call in tool_calls:
                tool_name = call.function.name
                raw_args = call.function.arguments or "{}"

                # Best effort: unparsable arguments fall back to an empty dict,
                # so the tool itself reports what is missing.
                try:
                    args = json.loads(raw_args)
                except Exception:
                    args = {}

                tool_func = self.tools.get(tool_name)

                if tool_func is None:
                    result = {"error": f"Unknown tool {tool_name}"}
                    outcome = {"success": False, "error": result["error"]}
                else:
                    try:
                        result = tool_func(**args)
                        outcome = {"success": True, "data": result}
                    except Exception as e:
                        # Tool failures are reported back to the model.
                        err = str(e)
                        result = {"error": err}
                        outcome = {"success": False, "error": err}

                trace_log["tool_calls"].append({
                    "tool": tool_name,
                    "args": args,
                    "result": outcome,
                })

                if isinstance(result, str):
                    content = result
                else:
                    try:
                        content = json.dumps(result, ensure_ascii=False)
                    except (TypeError, ValueError):
                        # A tool returned something JSON cannot encode; send its
                        # string form instead of aborting the conversation.
                        content = str(result)

                tool_msg = {
                    "role": "tool",
                    "tool_call_id": call.id,
                    "name": tool_name,
                    "content": content,
                }

                messages.append(tool_msg)
                trace_log["trace_log_messages"].append(tool_msg)

        trace_log["messages"] = messages
        trace_log["final_answer"] = "Too many rounds."
        trace_log["end_time"] = datetime.now().isoformat()
        return trace_log

    # ============================================================
    #                  Multi-turn helpers（静态方法）
    # ============================================================

    @staticmethod
    def _safe_format(desc: str, slot: str) -> str:
        try:
            return desc.format(slot=slot)
        except Exception:
            return (desc or "").replace("{slot}", slot)

    @staticmethod
    def _build_id2instruction(meta_dict: Dict[str, Any]) -> Dict[str, str]:
        id2instr = {}
        rubric_results = meta_dict.get("rubric_results", {}) or {}

        for _, rubric_group in rubric_results.items():
            for _, cfg in (rubric_group or {}).items():
                desc_tmpl = cfg.get("description", "") or ""
                result = cfg.get("result", {}) or {}
                all_labels = result.get("all_labels_and_ranges", {}) or {}

                for label_text, payload in all_labels.items():
                    _id = (payload or {}).get("_id")
                    if not _id:
                        continue
                    instr = TravelAgent._safe_format(desc_tmpl, label_text).strip()
                    id2instr[_id] = instr
        return id2instr

    @staticmethod
    def _flatten_applied_chains(meta_dict: Dict[str, Any]) -> List[str]:
        chains = meta_dict.get("applied_modification_chains") or {}
        out = []
        for _, chain in chains.items():
            if chain:
                out.append(chain[-1])
        return out

    @staticmethod
    def _build_blocks(meta_dict: Dict[str, Any], id_list: List[str], rubric_progress: Dict[str, int]) -> Dict[str, str]:
        id2instr = TravelAgent._build_id2instruction(meta_dict)

        # ===== HISTORY =====
        history_lines = []
        for cid in id_list:
            txt = id2instr.get(cid)
            if txt:
                history_lines.append(f"[ID: HIS_{cid}] {txt}")

        chains = meta_dict.get("applied_modification_chains", {}) or {}
        new_candidates = []
        modify_items = []  # (rubric_key, expected_id)

        for rubric_key in sorted(chains.keys()):
            chain = chains[rubric_key] or []
            if not chain:
                continue

            prog = rubric_progress.get(rubric_key, 0)

            if prog >= len(chain):
                continue

            if prog == 0:
                new_candidates.append(chain[0])
            else:
                modify_items.append((rubric_key, chain[prog]))

        # ===== NEW =====
        new_lines = []
        for cid in new_candidates:
            txt = id2instr.get(cid)
            if txt:
                new_lines.append(f"[ID: NEW_{cid}] {txt}")

        # ===== MODIFY =====
        modify_lines = []
        for rubric_key, expected_id in modify_items:
            chain = chains.get(rubric_key, [])
            prog = rubric_progress.get(rubric_key, 0)
            expected_txt = id2instr.get(expected_id)
            if not expected_txt:
                continue

            if expected_id in chain[:prog]:
                removed_id = chain[prog - 1]
                removed_txt = id2instr.get(removed_id, "")
                line = (
                    f"[ID: MOD_{expected_id}] {expected_txt} "
                    f"(Note: the previous requirement {removed_txt} is no longer needed)"
                )
            else:
                line = f"[ID: MOD_{expected_id}] {expected_txt}"

            modify_lines.append(line)

        return {
            "HISTORY": "\n".join(history_lines).strip(),
            "NEW": "\n".join(new_lines).strip(),
            "MODIFY": "\n".join(modify_lines).strip(),
        }

    @staticmethod
    def _history_text_for_sim(conv_simple: List[Dict[str, str]], keep_last_assistant: int = 3) -> str:
        OMIT_TEXT = "Earlier assistant responses are omitted due to context length limits."
        assistant_indices = [i for i, m in enumerate(conv_simple) if m["role"] == "assistant"]
        keep_indices = set(assistant_indices[-keep_last_assistant:])

        lines = []
        for i, m in enumerate(conv_simple):
            if m["role"] == "user":
                lines.append(f"User: {m['content']}")
            elif m["role"] == "assistant":
                if i in keep_indices:
                    lines.append(f"Assistant: {m['content']}")
                else:
                    lines.append(f"Assistant: {OMIT_TEXT}")
        return "\n".join(lines).strip()

    def _apply_proposed_ids_with_prefix(
        self,
        meta_dict: Dict[str, Any],
        current_id_list: List[str],
        rubric_progress: Dict[str, int],
        proposed_ids: List[Any],
    ) -> Tuple[bool, List[str], List[str], List[str]]:
        chains = meta_dict.get("applied_modification_chains", {}) or {}
        id2pos: Dict[str, Tuple[str, int]] = {}
        for rubric_key, chain in chains.items():
            for idx, cid in enumerate(chain):
                id2pos[cid] = (rubric_key, idx)

        errors: List[str] = []
        newly_added: List[str] = []
        removed: List[str] = []

        if not isinstance(proposed_ids, list):
            return False, [], [], [f"instruction_ids is not a list: {type(proposed_ids)}"]

        for raw in proposed_ids:
            if not isinstance(raw, str) or not raw:
                errors.append(f"invalid instruction_id (not str/empty): {raw!r}")
                continue

            if raw.startswith("NEW_"):
                op = "NEW"
                base_id = raw[len("NEW_"):]
            elif raw.startswith("MOD_"):
                op = "MOD"
                base_id = raw[len("MOD_"):]
            else:
                continue

            if base_id not in id2pos:
                errors.append(f"{raw} -> {base_id} not found in applied_modification_chains")
                continue

            rubric_key, idx = id2pos[base_id]
            chain = chains.get(rubric_key, []) or []

            if op == "NEW":
                if base_id not in current_id_list:
                    current_id_list.append(base_id)
                    newly_added.append(base_id)
                rubric_progress[rubric_key] = max(int(rubric_progress.get(rubric_key, 0) or 0), idx + 1)

            else:  # MOD
                if idx <= 0:
                    errors.append(f"{raw}: target is first element of chain, no previous to remove")
                    continue

                prev_id = chain[idx - 1]
                if prev_id not in current_id_list:
                    errors.append(f"{raw}: expected to remove prev_id={prev_id}, but it's not in current_id_list")
                    continue

                try:
                    current_id_list.remove(prev_id)
                    removed.append(prev_id)
                except ValueError:
                    errors.append(f"{raw}: failed to remove prev_id={prev_id} (ValueError)")
                    continue

                if base_id not in current_id_list:
                    current_id_list.append(base_id)
                    newly_added.append(base_id)

                rubric_progress[rubric_key] = max(int(rubric_progress.get(rubric_key, 0) or 0), idx + 1)

        ok = (len(errors) == 0)
        return ok, newly_added, removed, errors

    def rubric_all_satisfied(self, rubric_progress: dict, applied_chains: dict) -> bool:
        for rubric, chains in applied_chains.items():
            required = len(chains)
            current = rubric_progress.get(rubric, 0)
            if current < required:
                return False
        return True

    def select_raw_answer_for_eval(self, current_raw_answer: str, conv_simple: List[Dict]) -> str:
        def parse_trip_plan_json(raw_text: str) -> Optional[dict]:
            # 1) 纯 JSON
            try:
                data = json.loads(raw_text)
                return data if isinstance(data, dict) else None
            except Exception:
                pass

            # 2) ```json ... ```
            m = re.search(r"```json\s*(\{.*?\})\s*```", raw_text, re.S)
            if m:
                try:
                    data = json.loads(m.group(1))
                    return data if isinstance(data, dict) else None
                except Exception:
                    return None

            return None

        def should_fallback_only_if_empty_schedule(raw_text: str) -> bool:
            data = parse_trip_plan_json(raw_text)
            if not isinstance(data, dict):
                return False

            tp = data.get("trip_plan")
            if not isinstance(tp, dict):
                return False

            # 关键：必须有 daily_schedule 字段，且必须是“空 list”
            if "daily_schedule" not in tp:
                return False

            ds = tp.get("daily_schedule")
            return isinstance(ds, list) and len(ds) == 0

        if not should_fallback_only_if_empty_schedule(current_raw_answer):
            return current_raw_answer

        def extract_nonempty_schedule(raw_text: str) -> bool:
            data = parse_trip_plan_json(raw_text)
            if not isinstance(data, dict):
                return False
            tp = data.get("trip_plan")
            if not isinstance(tp, dict):
                return False
            ds = tp.get("daily_schedule")
            return isinstance(ds, list) and len(ds) > 0

        for msg in reversed(conv_simple):
            if msg.get("role") != "assistant":
                continue
            raw_text = msg.get("content", "")
            if extract_nonempty_schedule(raw_text):
                return raw_text

        return current_raw_answer

    def run_multiturn_trip(
        self,
        meta_dict: Dict[str, Any],
        dialog_turns: int = 1,
        keep_last_assistant: int = 3,
        extra_save_dir: str = "__unused__",
        max_round_retry: int = 3,          # ✅ 新增：单轮失败最多重跑次数（回滚上一轮后重跑当前轮）
        round_retry_sleep: float = 0.0,    # ✅ 新增：每次失败后 sleep（可选）
    ) -> Dict[str, Any]:

        trip_id = meta_dict.get("trip_id", "unknown_trip")
        rubric_progress = copy.deepcopy(meta_dict.get("rubric_progress", {}))

        conv_full: List[Dict[str, str]] = []
        conv_simple: List[Dict[str, str]] = []
        rounds: List[RoundUserState] = []

        if dialog_turns == 1:
            current_user_query = meta_dict.get("query_with_constraints", meta_dict.get("query_basic", ""))
            current_id_list = self._flatten_applied_chains(meta_dict)
        else:
            current_user_query = meta_dict.get("query_basic", "")
            current_id_list = meta_dict.get("instruction_ids_basic", [])

        issue_text = ""

        for r in range(1, dialog_turns + 1):
            round_try = 0

            # ✅ 当前轮失败就回滚到上一轮状态，重跑当前轮
            while True:
                # ====== 快照：用于“本轮失败 -> 回滚到上一轮结束状态” ======
                snap_conv_full = copy.deepcopy(conv_full)
                snap_conv_simple = copy.deepcopy(conv_simple)
                snap_id_list = list(current_id_list)
                snap_progress = copy.deepcopy(rubric_progress)
                snap_issue_text = issue_text
                snap_rounds_len = len(rounds)
                snap_current_user_query = current_user_query

                try:
                    id_before = list(current_id_list)

                    blocks = self._build_blocks(meta_dict, current_id_list, rubric_progress)
                    blocks["ISSUE"] = issue_text or ""

                    sim_prompt = None
                    sim_output = None
                    newly_added = []
                    all_added = []
                    use_next_turn_reward = False

                    # ====== multi-turn: round >= 2 才模拟用户 ======
                    if r >= 2:
                        history_text = self._history_text_for_sim(conv_simple, keep_last_assistant)
                        max_retry = 3
                        retry = 0

                        while True:
                            sim_prompt = self.user_prompt_template \
                                .replace("{{HISTORY}}", blocks.get("HISTORY", "")) \
                                .replace("{{NEW}}", blocks.get("NEW", "")) \
                                .replace("{{MODIFY}}", blocks.get("MODIFY", "")) \
                                .replace("{{ISSUE}}", blocks.get("ISSUE", "")) \
                                .replace("{{HISTORY_MESSAGES}}", history_text)

                            # ✅ user_sim.generate 报错就轮询切 url 重试
                            max_url_switch = max(1, len(getattr(self, "_user_base_urls", []) or []))
                            url_switch_cnt = 0
                            last_err = None

                            while True:
                                user_sim = self._get_user_simulator()
                                try:
                                    if retry > 0:
                                        sim_output = user_sim.generate(sim_prompt, temperature=self.user_cfg.retry_temperature)
                                    else:
                                        sim_output = user_sim.generate(sim_prompt, temperature=self.user_cfg.temperature)
                                    last_err = None
                                    break
                                except Exception as e:
                                    last_err = e
                                    if url_switch_cnt < max_url_switch:
                                        url_switch_cnt += 1
                                        self.rotate_user_base_url(reason=f"user_sim.generate error: {e}")
                                        continue
                                    raise last_err

                            proposed_ids = sim_output.get("instruction_ids", []) or []
                            use_next_turn_reward = (
                                not proposed_ids or
                                all(isinstance(pid, str) and pid.startswith("HIS_") for pid in proposed_ids)
                            )

                            tmp_id_list = list(current_id_list)
                            tmp_progress = copy.deepcopy(rubric_progress)

                            ok, newly_added_tmp, _, errors = self._apply_proposed_ids_with_prefix(
                                meta_dict=meta_dict,
                                current_id_list=tmp_id_list,
                                rubric_progress=tmp_progress,
                                proposed_ids=proposed_ids,
                            )

                            if ok:
                                current_id_list = tmp_id_list
                                rubric_progress = tmp_progress
                                newly_added = newly_added_tmp
                                current_user_query = sim_output.get("user_query", "")
                                all_added = proposed_ids
                                break

                            retry += 1
                            if retry >= max_retry:
                                newly_added = []
                                current_user_query = sim_output.get("user_query", "")
                                break

                    # ====== system prompt 只在第 1 轮加入（失败回滚后重跑也不会重复，因为 conv_full 会回滚） ======
                    if r == 1:
                        if dialog_turns == 1:
                            conv_full.append({"role": "system", "content": system_prompt_en_single_turn})
                        else:
                            conv_full.append({"role": "system", "content": system_prompt_en})

                    # ====== 发送用户 query ======
                    conv_full.append({"role": "user", "content": current_user_query})
                    conv_simple.append({"role": "user", "content": current_user_query})

                    # ====== agent 生成 ======
                    trace_log = self.chat(history=conv_full)
                    conv_full = copy.deepcopy(trace_log["messages"])
                    conv_simple.append({"role": "assistant", "content": trace_log["final_answer"]})

                    # ====== eval ======
                    raw_answer_for_eval = self.select_raw_answer_for_eval(
                        current_raw_answer=trace_log["final_answer"],
                        conv_simple=conv_simple
                    )

                    eval_result = self.trip_plan_evaluator.evaluate(
                        meta_dict=meta_dict,
                        raw_text=raw_answer_for_eval,
                        id_list=current_id_list
                    )
                    issue_text = eval_result.get("error_text", "") or ""

                    round_state = RoundUserState(
                        round_idx=r,
                        id_list_before=id_before,
                        id_list_after=list(current_id_list),
                        blocks=blocks,
                        sim_prompt=sim_prompt,
                        sim_output=sim_output,
                        newly_added_ids=newly_added,
                        all_added_ids=all_added,
                        final_answer=trace_log["final_answer"],
                        tool_calls=trace_log["tool_calls"],
                        eval_result=eval_result,
                        use_next_turn_reward=use_next_turn_reward,
                        start_time=trace_log["start_time"],
                        end_time=trace_log["end_time"],
                    )
                    rounds.append(round_state)

                    applied_chains = meta_dict.get("applied_modification_chains", {})
                    if self.rubric_all_satisfied(rubric_progress, applied_chains):
                        print(f"[trip_id={trip_id}] All rubrics are proposed at round {r}", flush=True)

                        summary = {
                            "trip_id": trip_id,
                            "dialog_turns": dialog_turns,
                            "keep_last_assistant": keep_last_assistant,
                            "final_id_list": list(current_id_list),
                            "messages": conv_full,
                            "rounds": [asdict(x) for x in rounds],
                            "terminate_reason": None,
                        }
                        return summary

                    # ✅ 本轮成功：跳出 while，进入下一轮 r+1
                    break

                except Exception as e:
                    # ====== 本轮失败：回滚到上一轮结束状态 ======
                    conv_full = snap_conv_full
                    conv_simple = snap_conv_simple
                    current_id_list = snap_id_list
                    rubric_progress = snap_progress
                    issue_text = snap_issue_text
                    current_user_query = snap_current_user_query

                    # ✅ rounds 也回滚（避免半截 state 污染）
                    if len(rounds) > snap_rounds_len:
                        rounds[:] = rounds[:snap_rounds_len]

                    round_try += 1
                    print(f"[multiturn] round {r} failed (try {round_try}/{max_round_retry}): {e}", flush=True)

                    # ✅ 本轮重试耗尽：抛到 worker，让 worker requeue（那就会从头）
                    if round_try >= max_round_retry:
                        raise

                    if round_retry_sleep > 0:
                        time.sleep(round_retry_sleep)

                    # ✅ 继续 while：重跑当前轮 r（不从头）
                    continue

        summary = {
            "trip_id": trip_id,
            "dialog_turns": dialog_turns,
            "keep_last_assistant": keep_last_assistant,
            "final_id_list": list(current_id_list),
            "messages": conv_full,
            "rounds": [asdict(x) for x in rounds],
            "terminate_reason": None,  # ✅ 成功才返回，失败都 raise（因此不会落盘 partial）
        }
        return summary

# ============================================================
#                    DISTRIBUTED RUNNER
# ============================================================
def canonical_final_id_list(final_id_list):
    """Return an order-insensitive, duplicate-free canonical form of an id list.

    Anything that is not a ``list`` yields an empty tuple. ``None`` entries
    are dropped; the remaining entries are converted to ``str``,
    de-duplicated and sorted, so ordering never affects equality.
    """
    if isinstance(final_id_list, list):
        unique_ids = set()
        for entry in final_id_list:
            if entry is not None:
                unique_ids.add(str(entry))
        return tuple(sorted(unique_ids))
    return ()

def make_sample_key(trip_id, final_id_list):
    """Build the hashable key used to detect already-processed samples.

    The key is a 2-tuple: the stringified trip id and the canonical
    (sorted, de-duplicated) form of the id list.
    """
    trip_key = str(trip_id)
    ids_key = canonical_final_id_list(final_id_list)
    return trip_key, ids_key

def final_id_list_from_meta(meta: dict):
    """Derive the final id list of a test item from its metadata.

    - Preferred source: ``applied_modification_chains``. When it is a
      non-empty dict, the last id of every non-empty list chain is taken
      (chains that are empty or not lists are ignored).
    - Fallback: ``meta["final_id_list"]`` when it is a list.
    - Otherwise an empty list.
    """
    chains = meta.get("applied_modification_chains")
    if isinstance(chains, dict) and chains:
        return [
            chain[-1]
            for chain in chains.values()
            if isinstance(chain, list) and chain
        ]

    fallback = meta.get("final_id_list")
    return fallback if isinstance(fallback, list) else []

def final_id_list_from_output_record(rec: dict):
    """Extract the id list used for de-duplication from one output-jsonl record.

    - Multi-turn records (``dialog_turns > 1``) try, in order:
      ``summary.rounds[0].id_list_before``, ``summary.instruction_ids_basic``,
      ``summary.final_id_list``, ``rec.final_id_list``. The first-round
      snapshot is preferred so the key does not depend on how far the
      dialogue progressed.
    - Single-turn records (``dialog_turns <= 1`` or unparsable) try
      ``summary.final_id_list`` and then ``rec.final_id_list``.

    The first candidate that is a ``list`` is returned. Malformed input
    (non-dict record, non-dict ``summary``, non-dict first round) never
    raises; such parts are simply ignored. Returns ``[]`` when nothing fits.
    """
    if not isinstance(rec, dict):
        return []

    summary = rec.get("summary")
    if not isinstance(summary, dict):
        # A missing or corrupt summary must not crash the scan of the output file.
        summary = {}

    dt = rec.get("dialog_turns")
    if dt is None:
        dt = summary.get("dialog_turns")
    try:
        dt = int(dt or 0)
    except (TypeError, ValueError, OverflowError):
        dt = 0

    # Candidates in priority order; the first one that is a list wins.
    candidates = []
    if dt > 1:
        rounds = summary.get("rounds")
        if isinstance(rounds, list) and rounds:
            first_round = rounds[0]
            if isinstance(first_round, dict):
                candidates.append(first_round.get("id_list_before"))
        candidates.append(summary.get("instruction_ids_basic"))
    candidates.append(summary.get("final_id_list"))
    candidates.append(rec.get("final_id_list"))

    for candidate in candidates:
        if isinstance(candidate, list):
            return candidate
    return []

def load_done_keys_from_output_jsonl(output_jsonl_path: str):
    """Scan the output jsonl and return the set of sample keys already written.

    Every line that parses to a JSON object contributes one key built from
    its trip id (top-level ``trip_id``, else ``summary.trip_id``) and the id
    list extracted by ``final_id_list_from_output_record``. No evaluation
    result is required for a record to count as done.

    Blank lines, lines that are not valid JSON (e.g. a truncated last line)
    and JSON values that are not objects are skipped. An empty or missing
    path yields an empty set.
    """
    done = set()
    if not output_jsonl_path or not os.path.exists(output_jsonl_path):
        return done

    with open(output_jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except (ValueError, RecursionError):
                # Corrupt line: ignore it so the sample is simply re-run.
                continue
            if not isinstance(rec, dict):
                # `null`, lists, numbers, strings are valid JSON but not records.
                continue

            summary = rec.get("summary")
            nested_trip_id = summary.get("trip_id") if isinstance(summary, dict) else None
            trip_id = rec.get("trip_id") or nested_trip_id

            fil = final_id_list_from_output_record(rec)
            done.add(make_sample_key(trip_id, fil))

    return done


def get_resources_for_process(config_path: str):
    """Return the evaluator resources for the current process.

    - If ``GLOBAL_RESOURCES`` is already populated (e.g. preloaded by the
      parent before a fork), it is returned as-is and nothing is loaded.
    - Otherwise (e.g. under spawn, where the global is empty) the process
      loads its own copy via ``_load_resources_for_process``; the result is
      not stored back into ``GLOBAL_RESOURCES``.
    """
    global GLOBAL_RESOURCES
    inherited = GLOBAL_RESOURCES
    if inherited is None:
        return _load_resources_for_process(config_path)
    return inherited

def preload_resources_in_parent(config_path: str):
    """Load the shared resources once into ``GLOBAL_RESOURCES``.

    Meant to be called in the parent process before forking, so that child
    processes find ``GLOBAL_RESOURCES`` already populated. Does nothing when
    the global is already set. After loading, ``gc.freeze()`` is attempted
    on a best-effort basis; any failure there is ignored.
    """
    global GLOBAL_RESOURCES
    if GLOBAL_RESOURCES is not None:
        return

    print("[runner] preload resources in parent ...", flush=True)
    GLOBAL_RESOURCES = _load_resources_for_process(config_path)
    print("[runner] preload resources done.", flush=True)

    # Best effort only: freezing the GC is an optimisation, never a requirement.
    try:
        import gc
        gc.freeze()
        print("[runner] gc.freeze() done.", flush=True)
    except Exception:
        pass


def _append_jsonl(file_lock: mp.Lock, output_jsonl_path: str, obj: dict):
    """Append ``obj`` as one JSON line to ``output_jsonl_path``.

    Serialization (non-ASCII kept verbatim) happens before the lock is
    taken; the file is opened, written and closed while ``file_lock`` is
    held so concurrent writers do not interleave lines.
    """
    payload = json.dumps(obj, ensure_ascii=False) + "\n"
    with file_lock, open(output_jsonl_path, "a", encoding="utf-8") as sink:
        sink.write(payload)


def _ip_process_main(
    task_q,                 # mp.JoinableQueue of meta dicts
    done_counter,           # mp.Value('i', 0)
    total_tasks: int,
    stop_event,             # mp.Event
    file_lock,              # mp.Lock
    output_jsonl_path: str,
    dialog_turns: int,
    keep_last_assistant: int,
    user_prompt_template: str,
):
    """Worker-process entry point: consume tasks with a pool of threads.

    Builds one ``TravelAgent`` (shared by every thread of this process) and
    starts ``REQUESTS_PER_IP`` threads running ``worker_loop``. For each
    task taken from ``task_q``:

    - success: the summary is appended to ``output_jsonl_path`` and the
      shared ``done_counter`` is incremented;
    - "max_tokens must be at least 1" errors: the item is dropped and
      counted as done, with no retry and no cooldown;
    - any other error: the item is re-queued (at most ``MAX_ITEM_RETRY``
      times, tracked in ``item["_retry"]``) and a process-wide cooldown is
      started; once the limit is reached the item is dropped and counted.

    ``stop_event`` is set as soon as ``done_counter`` reaches
    ``total_tasks``; after that, threads only drain the queue and exit.
    Returns when all threads have finished.
    """
    cooldown_until = 0.0

    print("[worker] before get resources", flush=True)
    resources = get_resources_for_process(CONFIG_PATH)
    print("[worker] after get resources (shared if fork)", flush=True)

    # Agent-side and user-side LLM settings are independent (AGENT_CFG / USER_CFG).
    agent = TravelAgent(
        config_path=CONFIG_PATH,
        agent_cfg=AGENT_CFG,
        user_cfg=USER_CFG,
        user_prompt_template=user_prompt_template,
        # Primary + fallback base URLs for the user simulator.
        user_base_urls=DEFAULT_USER_URLS,
        attraction_evaluator=resources["attraction_evaluator"],
        restaurant_evaluator=resources["restaurant_evaluator"],
        hotel_evaluator=resources["hotel_evaluator"],
        transportation_evaluator=resources["transportation_evaluator"],
        general_evaluator=resources["general_evaluator"],
    )

    def set_cooldown():
        """Push the shared cooldown deadline COOLDOWN_SECONDS into the future."""
        nonlocal cooldown_until
        cooldown_until = max(cooldown_until, time.time() + COOLDOWN_SECONDS)
        print(
            "[worker] ERROR -> cooldown "
            f"{COOLDOWN_SECONDS}s until "
            f"{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(cooldown_until))}",
            flush=True
        )

    def wait_cooldown():
        """Block until the cooldown deadline passes or a stop is requested."""
        while not stop_event.is_set() and time.time() < cooldown_until:
            time.sleep(1.0)

    def _is_max_tokens_bad_request(e: Exception) -> bool:
        """Detect the non-retryable 400 error about a non-positive max_tokens.

        Example message:
        Error code: 400 - {'error': {'message': 'max_tokens must be at least 1, got -793.' ...}}
        """
        try:
            s = str(e)
        except Exception:
            s = repr(e)

        return "max_tokens must be at least 1" in s

    def mark_done(tag: str):
        """Count one finished item, report progress, and signal stop when all are done."""
        with done_counter.get_lock():
            done_counter.value += 1
            done = done_counter.value

        print(f"[worker] progress: {done}/{total_tasks} ({tag})", flush=True)

        if done >= total_tasks:
            stop_event.set()

    def worker_loop(worker_id: int):
        """Thread body: take tasks until the queue is drained after a stop."""
        MAX_ITEM_RETRY = 3  # each item may be re-queued (restarted from scratch) at most 3 times

        while True:
            wait_cooldown()

            if stop_event.is_set():
                # Stop requested: drain leftovers, acknowledging each one so
                # the queue's unfinished-task count can reach zero.
                try:
                    item = task_q.get(timeout=1.0)
                except queue.Empty:
                    return
                task_q.task_done()
                continue

            try:
                item = task_q.get(timeout=1.0)
            except queue.Empty:
                continue

            try:
                summary = agent.run_multiturn_trip(
                    meta_dict=item,
                    dialog_turns=dialog_turns,
                    keep_last_assistant=keep_last_assistant,
                    extra_save_dir="__unused__",
                    max_round_retry=3,         # per-round rollback retries (no restart from scratch)
                    round_retry_sleep=0.0,
                )

                # Safety net: a summary that terminated early is never written to disk.
                if summary.get("terminate_reason"):
                    raise RuntimeError(f"multiturn terminated early: {summary.get('terminate_reason')}")

                record = {
                    "trip_id": summary.get("trip_id"),
                    "dialog_turns": summary.get("dialog_turns"),
                    "summary": summary,
                }
                _append_jsonl(file_lock, output_jsonl_path, record)

                mark_done(f"trip_id={summary.get('trip_id')}")

            except Exception as e:
                if _is_max_tokens_bad_request(e):
                    # Non-retryable 400: no cooldown, no requeue, just drop.
                    print(f"[worker] NON-RETRYABLE(max_tokens) -> drop: {e}", flush=True)
                    mark_done("dropped")

                elif not stop_event.is_set():
                    cur_retry = int(item.get("_retry", 0) or 0)

                    if cur_retry >= MAX_ITEM_RETRY:
                        # Retry budget exhausted: drop the item and count it
                        # (intentionally without a cooldown).
                        print(
                            f"[worker] RETRY_LIMIT({MAX_ITEM_RETRY}) -> drop: "
                            f"trip_id={item.get('trip_id')} err={e}",
                            flush=True
                        )
                        mark_done("dropped_retry_limit")

                    else:
                        # Budget left: requeue, then cool down.
                        item["_retry"] = cur_retry + 1
                        task_q.put(item)

                        # The retry count is formatted straight from the item;
                        # the previous f-string referenced an undefined name
                        # and raised NameError here, killing the thread.
                        print(
                            f"[worker] trajectory failed -> requeue({item['_retry']}/{MAX_ITEM_RETRY}): {e}",
                            flush=True
                        )
                        set_cooldown()

            finally:
                # Always acknowledge the item taken above (a requeued copy is
                # a new queue entry with its own acknowledgement).
                task_q.task_done()

    threads = []
    for wid in range(REQUESTS_PER_IP):
        t = threading.Thread(target=worker_loop, args=(wid,), daemon=True)
        t.start()
        threads.append(t)

    for t in threads:
        t.join()

    print("[worker] process exit", flush=True)



def _pick_mp_context():
    """
    Linux 优先 fork（避免 spawn 导致每个子进程重复 import / 大加载）
    不支持 fork 的环境自动 fallback 到 spawn
    """
    try:
        methods = mp.get_all_start_methods()
    except Exception:
        methods = ["spawn"]

    if "fork" in methods:
        return mp.get_context("fork")
    return mp.get_context("spawn")


def run_multiturn_file_distributed(
    test_path: str,
    dialog_turns: int,
    keep_last_assistant: int,
    output_jsonl_path: str,
    user_prompt_template: str,
):
    """Run every not-yet-done item of a test file through one worker process.

    Steps:
    1. Pick the multiprocessing context; under ``fork`` the shared resources
       are preloaded in this (parent) process first.
    2. Load the test items (a JSON array) from ``test_path``.
    3. Make sure the output file exists, then collect the keys of samples
       already present in it so the run can resume.
    4. Drop items that are already done or duplicated inside the test file.
       Multi-turn runs key on ``instruction_ids_basic``; single-turn runs
       key on ``final_id_list_from_meta``.
    5. Start a single worker process, enqueue the remaining items, wait for
       the queue to be fully acknowledged, then stop and join the worker.

    Returns ``None``; results are appended to ``output_jsonl_path``.
    """
    ctx = _pick_mp_context()
    print(f"[runner] mp start method = {ctx.get_start_method()}", flush=True)

    if ctx.get_start_method() == "fork":
        preload_resources_in_parent(CONFIG_PATH)

    with open(test_path, "r", encoding="utf-8") as f:
        items = json.load(f)

    # A bare file name has an empty dirname, and os.makedirs("") raises.
    output_dir = os.path.dirname(output_jsonl_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # Append mode creates the file when missing and never truncates existing results.
    with open(output_jsonl_path, "a", encoding="utf-8"):
        pass

    done_keys = load_done_keys_from_output_jsonl(output_jsonl_path)
    print(f"[runner] already_done={len(done_keys)} (from {output_jsonl_path})", flush=True)

    filtered = []
    seen_in_this_run = set()
    skipped_done = 0
    skipped_dup = 0

    for it in items:
        trip_id = it.get("trip_id")

        # Multi-turn keys on instruction_ids_basic; single-turn keeps the meta-derived list.
        if dialog_turns > 1:
            fil = it.get("instruction_ids_basic", [])
        else:
            fil = final_id_list_from_meta(it)

        sk = make_sample_key(trip_id, fil)

        if sk in done_keys:
            skipped_done += 1
            continue
        if sk in seen_in_this_run:
            skipped_dup += 1
            continue

        seen_in_this_run.add(sk)
        filtered.append(it)

    total = len(filtered)
    print(
        f"[runner] loaded={len(items)} queued={total} "
        f"skipped_done={skipped_done} skipped_dup_in_test={skipped_dup}",
        flush=True
    )

    if total == 0:
        print(f"[Done] nothing to run. output={output_jsonl_path}", flush=True)
        return

    task_q = ctx.JoinableQueue()
    done_counter = ctx.Value("i", 0)
    stop_event = ctx.Event()
    file_lock = ctx.Lock()

    # Only a single worker ("IP") process is started.
    procs = []
    p = ctx.Process(
        target=_ip_process_main,
        args=(
            task_q,
            done_counter,
            total,
            stop_event,
            file_lock,
            output_jsonl_path,
            dialog_turns,
            keep_last_assistant,
            user_prompt_template,
        ),
        daemon=False,
    )
    p.start()
    procs.append(p)

    for it in filtered:
        task_q.put(it)

    # NOTE(review): join() has no timeout; if the worker process dies before
    # acknowledging every item this call blocks forever — confirm acceptable.
    task_q.join()

    stop_event.set()
    for p in procs:
        p.join()

    print(f"[Done] total={total}, finished={done_counter.value}, output={output_jsonl_path}", flush=True)


def make_output_jsonl_path(test_path: str, dialog_turns: int, output_dir: str) -> str:
    """Compose the output jsonl path for one (test file, run mode) pair.

    The file name is ``<test stem>_<agent model name>_<single|multi>_t<turns>.jsonl``
    placed under ``output_dir``; the mode is ``single`` only when
    ``dialog_turns`` equals 1.
    """
    stem, _ext = os.path.splitext(os.path.basename(test_path))
    mode = "multi" if dialog_turns != 1 else "single"
    filename = f"{stem}_{AGENT_CFG.model_name}_{mode}_t{dialog_turns}.jsonl"
    return os.path.join(output_dir, filename)


# Script entry point: run every test file in every run mode, one after another.
if __name__ == "__main__":

    # (test tag, test file path) pairs to process.
    TESTS = [
        ("test_easy", "        ("test_mid",  "    ]

    # (mode tag, number of dialog turns) pairs; 1 turn means single-turn.
    RUN_MODES = [
        ("single", 1),
        ("multi",  15),
    ]

    # User-simulator prompt template selected per test tag.
    PROMPT_BY_TEST = {
        "test_easy": user_prompt_easy_en,
        "test_mid":  user_prompt_mid_en,
    }

    for test_tag, test_path in TESTS:
        # Unknown tags fall back to the easy prompt.
        user_prompt_template = PROMPT_BY_TEST.get(test_tag, user_prompt_easy_en)
        # print(user_prompt_template)

        for mode_tag, dialog_turns in RUN_MODES:
            # Each (test, mode) pair writes to its own output file.
            output_jsonl_path = make_output_jsonl_path(
                test_path=test_path,
                dialog_turns=dialog_turns,
                output_dir=OUTPUT_DIR,
            )

            run_multiturn_file_distributed(
                test_path=test_path,
                dialog_turns=dialog_turns,
                keep_last_assistant=KEEP_LAST_ASSISTANT,
                output_jsonl_path=output_jsonl_path,
                user_prompt_template=user_prompt_template,  # prompt template chosen for this test
            )
