import copy
import sys
import os
import json
import re
import time
import queue
import threading
import traceback
import multiprocessing as mp
from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime

import yaml
from openai import OpenAI

# ====== PATH 修复 ======
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)
print("Added ROOT:", ROOT, flush=True)

# ====== 引入工具 ======
from Tools.General.General_tools import GeneralTool
from Tools.Attraction.Attraction_tools import AttractionTool
from Tools.Flight.Flight_tools import FlightTool
from Tools.Train.Train_tools import TrainTool
from Tools.Restaurant.Restaurant_tools import RestaurantTool
from Tools.Hotel.Hotel_tools import HotelTool

from Tools.tool_description import (
    Flight_tool_description,
    Train_tool_description,
    Hotel_tool_description,
    Attraction_tool_description,
    Restaurant_tool_description,
    General_tool_description,
)

from interact.prompt.agent_system_prompt import system_prompt_en, system_prompt_en_single_turn
from interact.prompt.user_prompt import user_prompt_easy_en, user_prompt_mid_en, user_prompt_easy_en_no_issue
from interact.prompt.user_prompt_style import user_prompt_hard_vague_style_en, USER_STYLE_DESCRIPTIONS
from interact.user_agent import UserSimulator
from interact.evaluation_agent import TripPlanEvaluator

from Evaluation.Attraction.val_attractions import AttractionEvaluator
from Evaluation.Restaurant.val_restaurants import RestaurantEvaluator
from Evaluation.Hotel.val_hotels import HotelEvaluator
from Evaluation.Transportation.val_transports import TransportationEvaluator
from Evaluation.General.val_general import GeneralEvaluator, load_or_build_cache


# ============================================================
#                    CONFIG / RUNNER CONFIG
# ============================================================

CONFIG_PATH = " _env_bool(name: str, default: bool) -> bool:
    val = os.environ.get(name)
    if val is None:
        return default
    return str(val).strip().lower() in {"1", "true", "yes", "y", "t"}

def _env_int(name: str) -> Optional[int]:
    """Read environment variable *name* as an integer.

    Returns:
        The parsed integer, or ``None`` when the variable is unset or its
        value cannot be converted with ``int()``.
    """
    raw = os.environ.get(name)
    try:
        parsed = None if raw is None else int(raw)
    except Exception:
        # Unparseable values are treated the same as "not configured".
        parsed = None
    return parsed

DEFAULT_THINKING = _env_bool("AGENT_THINKING", False)

# =========================
# ✅ Agent/User 独立配置
# =========================

@dataclass
class AgentLLMConfig:
    """Settings for the agent-side LLM endpoint and its sampling parameters."""

    model_name: str = "deepseek-v32"
    temperature: float = 0.7
    max_tokens: int = 128 * 1024
    tool_choice: str = "auto"
    # Default comes from the AGENT_THINKING environment variable.
    thinking: bool = DEFAULT_THINKING
    extra_body: Optional[Dict[str, Any]] = None  # extra fields passed through to the server
    timeout_sec: float = 1200.0
    base_url: Optional[str] = None
    api_key: str = " "


@dataclass
class UserLLMConfig:
    """Settings for the user-simulator LLM endpoint and its sampling parameters."""

    model_name: str = "deepseek-v32"
    temperature: float = 0.7
    max_tokens: int = 64 * 1024
    # NOTE(review): presumably the temperature used when a user-simulator call
    # is retried; its consumer is not visible in this part of the file.
    retry_temperature: float = 0.7
    extra_body: Optional[Dict[str, Any]] = None
    timeout_sec: float = 360.0
    base_url: Optional[str] = None
    api_key: str = " "


# 33.253.241.234
# 33.253.71.98
# 33.253.249.222

# The USER_IP_POOL environment variable may hold a comma-separated list of
# IPs; when it is unset or contains no usable entry, the built-in pool is used.
_env_user_ip_pool = os.environ.get("USER_IP_POOL", "")
USER_IP_POOL = [
    entry.strip() for entry in _env_user_ip_pool.split(",") if entry.strip()
] or [
    "33.253.241.234",
    "33.253.249.222"
]
# One OpenAI-compatible base URL per pooled IP.
DEFAULT_USER_URLS = [f"http://{ip}:8000/v1" for ip in USER_IP_POOL]

# USER_CFG.base_url is still set to the first URL of the pool; the actual
# failover is done by switching URLs inside TravelAgent.
USER_CFG = UserLLMConfig(base_url=DEFAULT_USER_URLS[0])


# Read from the REQUESTS_PER_IP environment variable (default "4").
REQUESTS_PER_IP = int(os.environ.get("REQUESTS_PER_IP", "4"))

# Agent model name and host IP, both overridable through the environment.
AGENT_MODEL_NAME = os.environ.get("AGENT_MODEL_NAME", "deepseek-v32")
AGENT_IP = os.environ.get("AGENT_IP", "33.253.249.222")

# Model names for which _resolve_agent_max_tokens returns 20 * 1024.
SELF_TRAINED_AGENT_MODELS = {
    "qwen2514_toucan_119k",
    "qwen2532_toucan_119k",
    "qwen2514_toucan_119k_travel_3k",
    "qwen2532_toucan_119k_travel_3k",
}

# Model names for which _resolve_agent_max_tokens returns 32 * 1024.
MAX_TOKENS_32K_AGENT_MODELS = {
    "qwen2.5-32b-instruct",
    "qwen3-32b",
    "gemini-3-flash-preview",
    "gemini-3-pro-preview",
}

def _resolve_agent_max_tokens(model_name: str) -> int:
    """Pick the ``max_tokens`` budget for the agent model *model_name*.

    A positive integer in the ``AGENT_MAX_TOKENS`` environment variable takes
    precedence. Otherwise the budget is 20K for models listed in
    ``SELF_TRAINED_AGENT_MODELS``, 32K for models listed in
    ``MAX_TOKENS_32K_AGENT_MODELS`` and 128K for everything else
    (K = 1024 tokens).
    """
    override = _env_int("AGENT_MAX_TOKENS")
    if override and override > 0:
        return override

    if model_name in SELF_TRAINED_AGENT_MODELS:
        budget_k = 20
    elif model_name in MAX_TOKENS_32K_AGENT_MODELS:
        budget_k = 32
    else:
        budget_k = 128
    return budget_k * 1024

# Agent endpoint precedence: the AGENT_BASE_URL environment variable, then a
# URL built from AGENT_IP, then the hosted fallback service.
DEFAULT_AGENT_URL = os.environ.get("AGENT_BASE_URL") or (
    f"http://{AGENT_IP}:8000/v1"
    if AGENT_IP
    else "https://basicaiservice.sankuai.com/basicai/v1"
)

# Module-wide agent configuration; the token budget depends on the model name.
AGENT_CFG = AgentLLMConfig(
    base_url=DEFAULT_AGENT_URL,
    max_tokens=_resolve_agent_max_tokens(AGENT_MODEL_NAME),
    model_name=AGENT_MODEL_NAME,
)



# "LongCat-Flash-Chat"
# "LongCat-Flash-Thinking"
# "qwen2.5-32b-instruct"
# "Qwen2.5-14B-Instruct"
# "Toucan-Qwen2.5-14B-Instruct-v0.1"
# "Toucan-Qwen2.5-32B-Instruct-v0.1"


# IP cooldown (seconds, per the constant's name)
COOLDOWN_SECONDS = 1 * 1

# Multi-turn setting
# NOTE(review): presumably how many recent assistant messages are kept in
# multi-turn history; the consumer is not visible in this part of the file.
KEEP_LAST_ASSISTANT = 3

# 输出会自动落在： <OUTPUT_DIR>/<test_name>_<agent_name>_<single|multi>_t<dialog_turns>.jsonl
OUTPUT_BASE_DIR_THINK = " = " = os.path.join(OUTPUT_BASE_DIR_THINK, AGENT_MODEL_NAME)
OUTPUT_DIR_NO_THINK = os.path.join(OUTPUT_BASE_DIR_NO_THINK, AGENT_MODEL_NAME)
OUTPUT_DIR = OUTPUT_DIR_THINK if DEFAULT_THINKING else OUTPUT_DIR_NO_THINK



# ============================================================
#           进程内资源初始化（避免 spawn 重复模块级大加载）
# ============================================================
GLOBAL_RESOURCES = None  # inherited by child processes after fork (copy-on-write, per the original author's note)

def _load_resources_for_process(config_path: str) -> Dict[str, Any]:
    """
    Load the config, data and evaluators needed by one worker process.

    Intended to be called once inside each child process so that the heavy
    objects are not built (or rebuilt) at import time under the spawn start
    method.

    Args:
        config_path: Path to the YAML config; its ``data_path`` mapping must
            contain the keys ``attraction``, ``restaurant``, ``hotel``,
            ``flight`` and ``train``.

    Returns:
        A dict with the parsed ``config`` and the five evaluators
        (``attraction_evaluator``, ``restaurant_evaluator``,
        ``hotel_evaluator``, ``transportation_evaluator``,
        ``general_evaluator``).
    """
    with open(config_path, "r", encoding="utf-8") as cfg_file:
        config = yaml.safe_load(cfg_file)

    paths = config["data_path"]
    attraction_path = paths["attraction"]
    restaurant_path = paths["restaurant"]
    hotel_path = paths["hotel"]
    flight_path = paths["flight"]
    # Looked up so a missing key still raises, but not used any further here.
    _train_path = paths["train"]

    with open(attraction_path, "r", encoding="utf-8") as attraction_file:
        attraction_records = json.load(attraction_file)

    resources: Dict[str, Any] = {"config": config}
    resources["attraction_evaluator"] = AttractionEvaluator(attraction_records)
    resources["restaurant_evaluator"] = RestaurantEvaluator(restaurants_path=restaurant_path)
    resources["hotel_evaluator"] = HotelEvaluator(hotel_path=hotel_path)
    resources["transportation_evaluator"] = TransportationEvaluator(flights_path=flight_path)
    # Usually the heaviest step: done inside the process, but only once.
    resources["general_evaluator"] = GeneralEvaluator(load_or_build_cache(config_path))
    return resources


# ============================================================
# 工具调用封装类（保持你原来的结构）
# ============================================================

class FunctionCall:
    """Attribute-style view of the ``function`` part of a tool-call dict."""

    def __init__(self, data: Dict[str, Any]):
        # Missing keys simply become None.
        self.name, self.arguments = (data.get(key) for key in ("name", "arguments"))


class ToolCall:
    """Attribute-style view of one entry of an assistant message's ``tool_calls``."""

    def __init__(self, data: Dict[str, Any]):
        # A missing or null "function" entry is treated as an empty dict.
        function_payload = data.get("function") or {}
        self.id = data.get("id")
        self.type = data.get("type")
        self.function = FunctionCall(function_payload)


class AssistantMessage:
    """Attribute-style view of an assistant message dict."""

    def __init__(self, data: Dict[str, Any]):
        self.role = data.get("role")
        self.content = data.get("content")
        # A missing or null "tool_calls" entry yields an empty list.
        raw_calls = data.get("tool_calls") or []
        self.tool_calls = list(map(ToolCall, raw_calls))


def serialize_tool_calls(tool_calls):
    """Convert tool-call objects back into plain dicts.

    Each object must expose ``id``, ``type`` and ``function.name`` /
    ``function.arguments``. Returns ``None`` (not an empty list) when
    *tool_calls* is empty or ``None``.
    """
    if not tool_calls:
        return None
    return [
        {
            "id": item.id,
            "type": item.type,
            "function": {
                "name": item.function.name,
                "arguments": item.function.arguments,
            },
        }
        for item in tool_calls
    ]


# ============================================================
# 每轮用户状态记录
# ============================================================

@dataclass
class RoundUserState:
    """Record of the user-side state and agent outcome for one dialogue round."""

    round_idx: int
    id_list_before: List[str]
    id_list_after: List[str]
    blocks: Dict[str, str]                # prompt blocks keyed HISTORY/NEW/MODIFY/ISSUE
    sim_prompt: Optional[str]
    sim_output: Optional[Dict[str, Any]]
    newly_added_ids: List[str]
    all_added_ids: List[str]
    final_answer: str
    # Token accounting; field names mirror the usage dict built by TravelAgent.chat.
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    usage_by_round: List[Dict[str, int]]
    eval_result: Dict[str, Any]
    use_next_turn_reward: Optional[bool]
    start_time: Optional[str]
    end_time: Optional[str]
    tool_calls: List[Any]


# ============================================================
#                  TravelAgent 主类（OpenAI SDK 版本）
# ============================================================

class TravelAgent:
    """
    每个进程只初始化一个 TravelAgent。
    为了线程安全与并发，OpenAI client 使用 thread-local（每线程一个 client）。
    """
    def __init__(
        self,
        *,
        config_path: str,
        agent_cfg: AgentLLMConfig,
        user_cfg: UserLLMConfig,
        user_prompt_template: str = user_prompt_easy_en,
        timeout_sec: float = 1200.0,  # legacy parameter kept for compatibility; not read here
        user_base_urls: Optional[List[str]] = None,
        attraction_evaluator: AttractionEvaluator,
        restaurant_evaluator: RestaurantEvaluator,
        hotel_evaluator: HotelEvaluator,
        transportation_evaluator: TransportationEvaluator,
        general_evaluator: GeneralEvaluator,
    ):
        """
        Wire up LLM endpoints, tools and the trip-plan evaluator.

        Args:
            config_path: Config file path handed to the data-backed tools.
            agent_cfg: Agent-side LLM settings (endpoint, key, timeout, sampling).
            user_cfg: User-simulator LLM settings.
            user_prompt_template: Prompt template stored for the user simulator.
            timeout_sec: Unused; accepted for backward compatibility.
            user_base_urls: Optional primary + fallback base URLs for the user
                endpoint. Falsy entries and duplicates are dropped, keeping
                first-seen order.
            attraction_evaluator, restaurant_evaluator, hotel_evaluator,
            transportation_evaluator, general_evaluator: Evaluators built once
                per process and injected here.
        """
        self.agent_cfg = agent_cfg
        self.user_cfg = user_cfg
        self.user_prompt_template = user_prompt_template

        # Agent and user endpoints keep independent base_url / api_key / timeout.
        self._agent_base_url = agent_cfg.base_url
        self._agent_api_key = agent_cfg.api_key
        self._agent_timeout_sec = agent_cfg.timeout_sec

        self._user_base_url = user_cfg.base_url
        self._user_api_key = user_cfg.api_key
        self._user_timeout_sec = user_cfg.timeout_sec

        # Holds the per-thread OpenAI clients / user simulator.
        self._tlocal = threading.local()

        # ---- user base_url pool (primary + fallbacks) ----
        primary_url = self._user_base_url
        candidates = user_base_urls or ([primary_url] if primary_url else [])
        # dict.fromkeys de-duplicates while preserving first-seen order.
        self._user_base_urls = list(dict.fromkeys(url for url in candidates if url))
        if not self._user_base_urls:
            self._user_base_urls = [primary_url]

        # Start rotation at the primary URL when it is part of the pool,
        # otherwise at index 0 (the primary URL itself is left unchanged).
        try:
            self._user_url_idx = self._user_base_urls.index(primary_url)
        except ValueError:
            self._user_url_idx = 0

        self._user_url_lock = threading.Lock()

        # Tool backends.
        general = self.general_tool = GeneralTool()
        attraction = self.attraction_tool = AttractionTool(config_path)
        flight = self.flight_tool = FlightTool(config_path)
        train = self.train_tool = TrainTool(config_path)
        restaurant = self.restaurant_tool = RestaurantTool(config_path)
        hotel = self.hotel_tool = HotelTool(config_path)

        # Tool name (as exposed to the model) -> bound callable.
        self.tools = {
            "search_flights": flight.search_flights,
            "get_flight_detail_with_products": flight.get_flight_detail_with_products,
            "get_airport_coordinates": flight.get_airport_coordinates,

            "search_trains": train.search_trains,
            "get_train_detail_with_products": train.get_train_detail_with_products,
            "get_station_coordinates": train.get_station_coordinates,

            "search_hotels": hotel.search_hotels,
            "get_hotel_detail_with_products": hotel.get_hotel_detail_with_products,
            "get_hotel_coordinates": hotel.get_hotel_coordinates,

            "search_attractions": attraction.search_attractions,
            "get_attraction_detail_with_products": attraction.get_attraction_detail_with_products,
            "get_attraction_coordinates": attraction.get_attraction_coordinates,

            "search_restaurants": restaurant.search_restaurants,
            "get_restaurant_detail_with_products": restaurant.get_restaurant_detail_with_products,
            "get_restaurant_coordinates": restaurant.get_restaurant_coordinates,

            "get_route_estimate": general.get_route_estimate,
            "get_city_center_coords": general.get_city_center_coords,
            "get_date_after": general.get_date_after,
        }

        # Tool schemas sent to the model with every request.
        self.tool_descriptions = (
            Flight_tool_description
            + Train_tool_description
            + Hotel_tool_description
            + Attraction_tool_description
            + Restaurant_tool_description
            + General_tool_description
        )

        # Evaluator that combines the injected per-domain evaluators.
        self.trip_plan_evaluator = TripPlanEvaluator(
            attraction_evaluator=attraction_evaluator,
            restaurant_evaluator=restaurant_evaluator,
            hotel_evaluator=hotel_evaluator,
            transportation_evaluator=transportation_evaluator,
            general_evaluator=general_evaluator,
        )


    def _get_agent_client(self) -> OpenAI:
        """Return the calling thread's OpenAI client for the agent endpoint, creating it on first use."""
        try:
            return self._tlocal.agent_client
        except AttributeError:
            # First request from this thread: fall through and build the client.
            pass
        self._tlocal.agent_client = OpenAI(
            api_key=self._agent_api_key,
            base_url=self._agent_base_url,
            timeout=self._agent_timeout_sec,
        )
        return self._tlocal.agent_client

    def _get_user_client(self) -> OpenAI:
        """Return the calling thread's OpenAI client for the user endpoint, creating it on first use.

        A new client is built against the ``_user_base_url`` that is current
        at creation time.
        """
        try:
            return self._tlocal.user_client
        except AttributeError:
            # Nothing cached for this thread yet.
            pass
        self._tlocal.user_client = OpenAI(
            api_key=self._user_api_key,
            base_url=self._user_base_url,
            timeout=self._user_timeout_sec,
        )
        return self._tlocal.user_client

    def _get_user_simulator(self) -> UserSimulator:
        """Return the calling thread's UserSimulator, creating it on first use.

        The simulator is bound to this thread's user client and configured
        from ``self.user_cfg`` (model name, temperature, max_tokens).
        """
        try:
            return self._tlocal.user_sim
        except AttributeError:
            # Nothing cached for this thread yet.
            pass
        cfg = self.user_cfg
        self._tlocal.user_sim = UserSimulator(
            client=self._get_user_client(),
            model_name=cfg.model_name,
            temperature=cfg.temperature,
            max_tokens=cfg.max_tokens,
        )
        return self._tlocal.user_sim

    def _invalidate_user_thread_local(self):
        """只清掉当前线程的 user client/sim，让它们按新 base_url 重建。"""
        if hasattr(self._tlocal, "user_client"):
            delattr(self._tlocal, "user_client")
        if hasattr(self._tlocal, "user_sim"):
            delattr(self._tlocal, "user_sim")
    
    def rotate_user_base_url(self, reason: str = "") -> str:
        """
        切到后备 base_url（轮询），并让当前线程下次请求重建 client/sim。
        返回新的 base_url。
        """
        with self._user_url_lock:
            if not self._user_base_urls:
                return self._user_base_url

            old = self._user_base_url
            self._user_url_idx = (self._user_url_idx + 1) % len(self._user_base_urls)
            self._user_base_url = self._user_base_urls[self._user_url_idx]

        self._invalidate_user_thread_local()
        print(f"[user-url] rotate: {old} -> {self._user_base_url}. reason={reason}", flush=True)
        return self._user_base_url


    # ======================== 核心 Chat Loop ========================
    def chat(self, history: Optional[List[Dict[str, str]]] = None, max_rounds=1000):
        """
        Run the agent's tool-calling loop until the model answers without tool calls.

        Each round sends the running message list to the agent model. If the
        reply contains tool calls, every call is executed and its result is
        appended as a ``tool`` message; otherwise the reply is the final answer.

        Args:
            history: Initial chat messages. Deep-copied, so the caller's list
                is never mutated. ``None`` means an empty conversation.
            max_rounds: Maximum number of model requests before giving up.

        Returns:
            A trace dict with the keys ``start_time``, ``end_time``,
            ``messages`` (full message list sent to the model),
            ``trace_log_messages`` (compact log of the same turns),
            ``tool_calls`` (one record per executed call), ``final_answer``
            (``"Too many rounds."`` when *max_rounds* is exhausted),
            ``usage`` (token totals) and ``usage_by_round``.
        """
        if history is None:
            history = []

        def _to_int(value):
            # Missing or malformed token counts are counted as 0.
            try:
                return int(value)
            except Exception:
                return 0

        def _read_usage(usage_obj, field):
            # The usage payload may be a plain dict or an SDK object.
            if isinstance(usage_obj, dict):
                return _to_int(usage_obj.get(field))
            return _to_int(getattr(usage_obj, field, None))

        def _to_tool_content(result):
            # Tool messages must carry a string. Strings pass through and
            # everything else is JSON-encoded. A tool output that is not
            # JSON-serializable must not abort the whole conversation, so
            # fall back to stringifying unknown objects, and finally to str().
            if isinstance(result, str):
                return result
            try:
                return json.dumps(result, ensure_ascii=False)
            except (TypeError, ValueError):
                pass
            try:
                return json.dumps(result, ensure_ascii=False, default=str)
            except (TypeError, ValueError):
                return str(result)

        usage_fields = ("prompt_tokens", "completion_tokens", "total_tokens")

        trace_log = {
            "start_time": datetime.now().isoformat(),
            "messages": [],
            "trace_log_messages": [],
            "tool_calls": [],
            "final_answer": None,
            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
            "usage_by_round": [],
        }

        messages = copy.deepcopy(history)
        trace_log["trace_log_messages"].extend(messages)

        for _ in range(max_rounds):
            client = self._get_agent_client()
            request_kwargs = {
                "model": self.agent_cfg.model_name,
                "messages": messages,
                "tools": self.tool_descriptions,
                "tool_choice": self.agent_cfg.tool_choice,
                "temperature": self.agent_cfg.temperature,
                "max_tokens": self.agent_cfg.max_tokens,
            }
            # Only deepseek-v32 receives extra_body: the thinking switch,
            # optionally overridden/extended by the configured extra_body.
            if self.agent_cfg.model_name == "deepseek-v32":
                extra_body = {"chat_template_kwargs": {"thinking": bool(self.agent_cfg.thinking)}}
                if isinstance(self.agent_cfg.extra_body, dict):
                    extra_body.update(self.agent_cfg.extra_body)
                request_kwargs["extra_body"] = extra_body

            resp = client.chat.completions.create(**request_kwargs)

            # ---------- token accounting ----------
            usage_obj = getattr(resp, "usage", None) or {}
            round_usage = {field: _read_usage(usage_obj, field) for field in usage_fields}
            for field, count in round_usage.items():
                trace_log["usage"][field] += count
            trace_log["usage_by_round"].append(round_usage)

            assistant_msg_dict = resp.choices[0].message.model_dump()
            messages.append(assistant_msg_dict)

            assistant_msg = AssistantMessage(assistant_msg_dict)
            trace_log["trace_log_messages"].append({
                "role": "assistant",
                "content": assistant_msg.content,
                "tool_calls": serialize_tool_calls(assistant_msg.tool_calls),
            })

            tool_calls = assistant_msg.tool_calls

            # ---------- no tool calls = final answer ----------
            if not tool_calls:
                trace_log["messages"] = messages
                trace_log["final_answer"] = assistant_msg.content or ""
                trace_log["end_time"] = datetime.now().isoformat()
                return trace_log

            # ---------- execute tool calls ----------
            for call in tool_calls:
                tool_name = call.function.name
                raw_args = call.function.arguments or "{}"

                # Unparseable arguments fall back to calling the tool with none.
                try:
                    args = json.loads(raw_args)
                except Exception:
                    args = {}

                tool_func = self.tools.get(tool_name)

                if tool_func is None:
                    result = {"error": f"Unknown tool {tool_name}"}
                    outcome = {"success": False, "error": result["error"]}
                else:
                    try:
                        output = tool_func(**args)
                    except Exception as e:
                        # Tool failures are reported back to the model
                        # instead of aborting the conversation.
                        err = str(e)
                        result = {"error": err}
                        outcome = {"success": False, "error": err}
                    else:
                        result = output
                        outcome = {"success": True, "data": output}

                trace_log["tool_calls"].append({
                    "tool": tool_name,
                    "args": args,
                    "result": outcome,
                })

                tool_msg = {
                    "role": "tool",
                    "tool_call_id": call.id,
                    "name": tool_name,
                    "content": _to_tool_content(result),
                }

                messages.append(tool_msg)
                trace_log["trace_log_messages"].append(tool_msg)

        # Round budget exhausted without a tool-free reply.
        trace_log["messages"] = messages
        trace_log["final_answer"] = "Too many rounds."
        trace_log["end_time"] = datetime.now().isoformat()
        return trace_log

    # ============================================================
    #                  Multi-turn helpers（静态方法）
    # ============================================================

    @staticmethod
    def _safe_format(desc: str, slot: str) -> str:
        try:
            return desc.format(slot=slot)
        except Exception:
            return (desc or "").replace("{slot}", slot)

    @staticmethod
    def _build_id2instruction(meta_dict: Dict[str, Any]) -> Dict[str, str]:
        id2instr = {}
        rubric_results = meta_dict.get("rubric_results", {}) or {}

        for _, rubric_group in rubric_results.items():
            for _, cfg in (rubric_group or {}).items():
                desc_tmpl = cfg.get("description", "") or ""
                result = cfg.get("result", {}) or {}
                all_labels = result.get("all_labels_and_ranges", {}) or {}

                for label_text, payload in all_labels.items():
                    _id = (payload or {}).get("_id")
                    if not _id:
                        continue
                    instr = TravelAgent._safe_format(desc_tmpl, label_text).strip()
                    id2instr[_id] = instr
        return id2instr

    @staticmethod
    def _flatten_applied_chains(meta_dict: Dict[str, Any]) -> List[str]:
        chains = meta_dict.get("applied_modification_chains") or {}
        out = []
        for _, chain in chains.items():
            if chain:
                out.append(chain[-1])
        return out

    @staticmethod
    def _build_blocks(meta_dict: Dict[str, Any], id_list: List[str], rubric_progress: Dict[str, int]) -> Dict[str, str]:
        id2instr = TravelAgent._build_id2instruction(meta_dict)

        # ===== HISTORY =====
        history_lines = []
        for cid in id_list:
            txt = id2instr.get(cid)
            if txt:
                history_lines.append(f"[ID: HIS_{cid}] {txt}")

        chains = meta_dict.get("applied_modification_chains", {}) or {}
        new_candidates = []
        modify_items = []  # (rubric_key, expected_id)

        for rubric_key in sorted(chains.keys()):
            chain = chains[rubric_key] or []
            if not chain:
                continue

            prog = rubric_progress.get(rubric_key, 0)

            if prog >= len(chain):
                continue

            if prog == 0:
                new_candidates.append(chain[0])
            else:
                modify_items.append((rubric_key, chain[prog]))

        # ===== NEW =====
        new_lines = []
        for cid in new_candidates:
            txt = id2instr.get(cid)
            if txt:
                new_lines.append(f"[ID: NEW_{cid}] {txt}")

        # ===== MODIFY =====
        modify_lines = []
        for rubric_key, expected_id in modify_items:
            chain = chains.get(rubric_key, [])
            prog = rubric_progress.get(rubric_key, 0)
            expected_txt = id2instr.get(expected_id)
            if not expected_txt:
                continue

            if expected_id in chain[:prog]:
                removed_id = chain[prog - 1]
                removed_txt = id2instr.get(removed_id, "")
                line = (
                    f"[ID: MOD_{expected_id}] {expected_txt} "
                    f"(Note: the previous requirement {removed_txt} is no longer needed)"
                )
            else:
                line = f"[ID: MOD_{expected_id}] {expected_txt}"

            modify_lines.append(line)

        return {
            "HISTORY": "\n".join(history_lines).strip(),
            "NEW": "\n".join(new_lines).strip(),
            "MODIFY": "\n".join(modify_lines).strip(),
        }

    @staticmethod
    def _parse_json_from_text(text: str) -> Optional[dict]:
        """从文本中解析 JSON（支持纯 JSON 或 ```json ... ``` 格式）"""
        # 1) 纯 JSON
        try:
            data = json.loads(text)
            return data if isinstance(data, dict) else None
        except Exception:
            pass
        
        # 2) ```json ... ```
        m = re.search(r"```json\s*(\{.*?\})\s*```", text, re.S)
        if m:
            try:
                data = json.loads(m.group(1))
                return data if isinstance(data, dict) else None
            except Exception:
                pass
        
        return None
    
    def _extract_and_replace_json_in_text(self, text: str) -> str:
        """从文本中提取 JSON 部分，转换为自然语言，并保留其他自然语言文本"""
        try:
            # 1. 尝试匹配 ```json ... ``` 格式（支持多行）
            # 先找到 ```json 的开始位置
            json_start_marker = "```json"
            json_end_marker = "```"
            
            start_idx = text.find(json_start_marker)
            if start_idx != -1:
                # 找到开始标记，查找对应的结束标记
                # 从 ```json 之后开始查找 ```
                search_start = start_idx + len(json_start_marker)
                end_idx = text.find(json_end_marker, search_start)
                
                if end_idx != -1:
                    # 提取 JSON 内容（去掉 ```json 和 ```）
                    json_content = text[search_start:end_idx].strip()
                    
                    # 提取 JSON 前后的文本
                    before_json = text[:start_idx].rstrip()
                    after_json = text[end_idx + len(json_end_marker):].lstrip()
                    
                    # 解析 JSON 并转换为自然语言
                    try:
                        plan_json = json.loads(json_content)
                        if plan_json and plan_json.get("trip_plan"):
                            natural_lang = self._json_plan_to_natural_language(plan_json)
                            # 组合：前面的文本 + 转换后的自然语言 + 后面的文本
                            result_parts = []
                            if before_json:
                                result_parts.append(before_json)
                            result_parts.append(natural_lang)
                            if after_json:
                                result_parts.append(after_json)
                            return "\n\n".join(result_parts)
                    except Exception:
                        # JSON 解析失败，保持原样
                        pass
            
            # 2. 尝试匹配纯 JSON（整个文本就是 JSON）
            try:
                plan_json = json.loads(text.strip())
                if plan_json and plan_json.get("trip_plan"):
                    # 整个文本都是 JSON，直接转换
                    return self._json_plan_to_natural_language(plan_json)
            except Exception:
                pass
            
            # 3. 没有找到有效的 JSON，返回原文本
            return text
            
        except Exception:
            # 任何错误都返回原文本
            return text

    def _get_detail_by_id(self, item_id: str, item_type: str, date_str: Optional[str] = None, 
                          check_in_date: Optional[str] = None, check_out_date: Optional[str] = None) -> Optional[str]:
        """根据 ID 从数据库获取详细信息，提取重要字段并组织为自然语言"""
        try:
            if item_type == "attraction":
                # 从 attractions 字典中查找
                attraction = self.attraction_tool.attractions.get(str(item_id))
                if not attraction:
                    return None
                
                # 提取重要字段
                name = attraction.get("poiName", "")
                city = attraction.get("city", "")
                rating = attraction.get("commentScore")
                comment_count = attraction.get("commentCount")
                level = attraction.get("sightLevelStr")
                opening_hours = attraction.get("opening_hours", {})
                open_time = opening_hours.get("open", "")
                close_time = opening_hours.get("close", "")
                features = attraction.get("shortFeatures", "")
                
                # 组织为自然语言
                parts = []
                if name:
                    parts.append(name)
                if city:
                    parts.append(f"in {city}")
                if rating:
                    parts.append(f"rated {rating}/5")
                if comment_count:
                    parts.append(f"({comment_count} reviews)")
                if level:
                    parts.append(f"{level} attraction")
                if open_time and close_time:
                    parts.append(f"open {open_time}-{close_time}")
                if features:
                    parts.append(f"features: {features}")
                
                return ", ".join(parts) if parts else name or item_id
            
            elif item_type == "restaurant":
                # 从 restaurants 索引中查找
                # 处理 restaurant_ 前缀
                restaurant_id = str(item_id)
                # if restaurant_id.startswith("restaurant_"):
                #     restaurant_id = restaurant_id.replace("restaurant_", "", 1)
                
                restaurant = self.restaurant_tool.index_restaurant.get(restaurant_id)
                if not restaurant:
                    return None
                
                # 提取重要字段
                name = restaurant.get("name", "")
                city = restaurant.get("real_city", "")
                stars = restaurant.get("stars")
                review_count = restaurant.get("review_count")
                category = restaurant.get("small_cate", "")
                avg_price = restaurant.get("avg_price")
                
                # 组织为自然语言
                parts = []
                if name:
                    parts.append(name)
                if city:
                    parts.append(f"in {city}")
                if category:
                    parts.append(f"({category})")
                if stars:
                    parts.append(f"{stars} stars")
                if review_count:
                    parts.append(f"({review_count} reviews)")
                if avg_price:
                    parts.append(f"avg price ¥{avg_price}")
                
                return ", ".join(parts) if parts else name or item_id
            
            elif item_type == "hotel":
                # 从 hotels 索引中查找
                hotel = self.hotel_tool.index_hotel.get(str(item_id))
                if not hotel:
                    return None
                
                # 提取重要字段
                name = hotel.get("name", "")
                city = hotel.get("real_city", "")
                stars = hotel.get("stars")
                review_count = hotel.get("review_count")
                hotel_type = hotel.get("hotel_type", "")
                good_rate = hotel.get("good_remarks_rate")
                
                # 组织为自然语言
                parts = []
                if name:
                    parts.append(name)
                if city:
                    parts.append(f"in {city}")
                if hotel_type:
                    parts.append(f"({hotel_type})")
                if stars:
                    parts.append(f"{stars} stars")
                if review_count:
                    parts.append(f"({review_count} reviews)")
                if good_rate is not None:
                    parts.append(f"{good_rate*100:.0f}% positive")
                
                return ", ".join(parts) if parts else name or item_id
            
            elif item_type == "flight":
                # 从 flights 列表中查找
                flight = None
                for f in self.flight_tool.flights:
                    if str(f.get("Flight_id", "")) == str(item_id):
                        flight = f
                        break
                
                if not flight:
                    return None
                
                # 提取重要字段
                flight_num = flight.get("Flight Number", "")
                airline = flight.get("Airline", "")
                dep_time = flight.get("Departure Time", "")
                arr_time = flight.get("Arrival Time", "")
                dep_airport = flight.get("Departure Airport", "")
                arr_airport = flight.get("Arrival Airport", "")
                dep_city = flight.get("Departure City", "")
                arr_city = flight.get("Arrival City", "")
                
                # 组织为自然语言
                parts = []
                if flight_num:
                    parts.append(f"Flight {flight_num}")
                if airline:
                    parts.append(f"({airline})")
                if dep_time and arr_time:
                    parts.append(f"{dep_time}-{arr_time}")
                if dep_airport and arr_airport:
                    parts.append(f"from {dep_airport} to {arr_airport}")
                elif dep_city and arr_city:
                    parts.append(f"from {dep_city} to {arr_city}")
                
                return ", ".join(parts) if parts else flight_num or item_id
            
            elif item_type == "train":
                # 从 trains 列表中查找
                train = None
                for t in self.train_tool.trains:
                    if str(t.get("Train_id", "")) == str(item_id):
                        train = t
                        break
                
                if not train:
                    return None
                
                # 提取重要字段
                train_num = train.get("Train Number", "")
                dep_time = train.get("Departure Time", "")
                arr_time = train.get("Arrival Time", "")
                dep_station = train.get("Departure Station", "")
                arr_station = train.get("Arrival Station", "")
                dep_city = train.get("Departure City", "")
                arr_city = train.get("Arrival City", "")
                
                # 组织为自然语言
                parts = []
                if train_num:
                    parts.append(f"Train {train_num}")
                if dep_time and arr_time:
                    parts.append(f"{dep_time}-{arr_time}")
                if dep_station and arr_station:
                    parts.append(f"from {dep_station} to {arr_station}")
                elif dep_city and arr_city:
                    parts.append(f"from {dep_city} to {arr_city}")
                
                return ", ".join(parts) if parts else train_num or item_id
            
        except Exception as e:
            # 静默失败，不打印错误（避免日志过多）
            return None
        
        return None

    def _identify_item_type(self, item_id: str) -> Optional[str]:
        """根据 ID 前缀识别类型"""
        try:
            if item_id is None:
                return None
            
            item_id_str = str(item_id).strip()
            if not item_id_str:
                return None
            
            # 景点：A_ 开头，或纯数字（可能是 poiId）
            if item_id_str.startswith("A_"):
                return "attraction"
            elif item_id_str.isdigit():
                # 纯数字可能是景点 ID (poiId)
                # 检查是否在 attractions 中存在
                try:
                    if hasattr(self, 'attraction_tool') and hasattr(self.attraction_tool, 'attractions'):
                        if self.attraction_tool.attractions.get(item_id_str):
                            return "attraction"
                except Exception:
                    pass
            
            # 餐厅：R_ 开头，或 restaurant_ 开头
            if item_id_str.startswith("R_"):
                return "restaurant"
            elif item_id_str.startswith("restaurant_"):
                return "restaurant"
            
            # 酒店：H_ 开头，或 Hotel_ 开头
            if item_id_str.startswith("H_"):
                return "hotel"
            elif item_id_str.startswith("Hotel_"):
                return "hotel"
            
            # 航班：T_FLT_ 开头，或 Flight_ 开头
            if item_id_str.startswith("T_FLT_"):
                return "flight"
            elif item_id_str.startswith("Flight_"):
                return "flight"
            
            # 火车：T_SHN_ 或 T_TRAIN_ 开头，或 Train_ 开头
            if item_id_str.startswith("T_SHN_") or item_id_str.startswith("T_TRAIN_"):
                return "train"
            elif item_id_str.startswith("Train_"):
                return "train"
            
            return None
        except Exception:
            # 静默失败，返回 None
            return None

    def _json_plan_to_natural_language(self, plan_json: dict) -> str:
        """将 JSON 格式的旅行计划转换为自然语言描述"""
        try:
            if not isinstance(plan_json, dict):
                return json.dumps(plan_json, ensure_ascii=False, indent=2)
            
            trip_plan = plan_json.get("trip_plan", {})
            if not trip_plan or not isinstance(trip_plan, dict):
                return json.dumps(plan_json, ensure_ascii=False, indent=2)
            
            start_date = trip_plan.get("start_date", "")
            end_date = trip_plan.get("end_date", "")
            number_of_people = trip_plan.get("number_of_people", "")
            daily_schedule = trip_plan.get("daily_schedule", [])
            
            # 确保 daily_schedule 是列表
            if not isinstance(daily_schedule, list):
                return "The assistant provided a trip plan, but the daily schedule format is invalid."
            
            if not daily_schedule:
                return "The assistant provided a trip plan, but the daily schedule is empty."
            
            # 构建自然语言描述
            lines = []
            
            # # 总体信息
            # if start_date and end_date:
            #     lines.append(f"The assistant provided a trip plan from {start_date} to {end_date}")
            #     if number_of_people:
            #         lines[-1] += f" for {number_of_people} people"
            #     lines[-1] += "."
            # else:
            #     lines.append("The assistant provided the following trip plan:")
            
            # 计算酒店 check_out 日期（用于酒店详情查询）
            from datetime import timedelta
            
            # 每天的安排
            for day_idx, day in enumerate(daily_schedule, 1):
                try:
                    # 确保 day 是字典
                    if not isinstance(day, dict):
                        lines.append(f"\nDay {day_idx}: (Invalid day format)")
                        continue
                    
                    date = day.get("date", "")
                    cities = day.get("cities", "")
                    
                    day_header = f"\nDay {day_idx}"
                    if date:
                        day_header += f" ({date})"
                    if cities:
                        day_header += f": {cities}"
                    lines.append(day_header)
                    
                    # 酒店信息
                    try:
                        hotel = day.get("hotel")
                        if hotel and isinstance(hotel, dict) and hotel.get("id"):
                            hotel_id = hotel.get("id")
                            hotel_type = self._identify_item_type(hotel_id)
                            if hotel_type == "hotel":
                                # 计算 check_out 日期：找到下一个有酒店的日期，或者使用当前日期的下一天
                                check_in_date = date
                                check_out_date = None
                                
                                # 查找后续日期中是否有酒店
                                try:
                                    for next_day in daily_schedule[day_idx:]:
                                        if not isinstance(next_day, dict):
                                            continue
                                        next_hotel = next_day.get("hotel")
                                        if next_hotel and isinstance(next_hotel, dict) and next_hotel.get("id") == hotel_id:
                                            # 如果后续还有同一天酒店，说明还在住
                                            continue
                                        elif next_hotel and isinstance(next_hotel, dict) and next_hotel.get("id") != hotel_id:
                                            # 如果后续有不同酒店，说明在这里退房
                                            check_out_date = next_day.get("date", "")
                                            break
                                except Exception:
                                    pass
                                
                                # 如果没找到，使用当前日期的下一天
                                if not check_out_date and date:
                                    try:
                                        check_in = datetime.strptime(date, "%Y-%m-%d")
                                        check_out_date = (check_in + timedelta(days=1)).strftime("%Y-%m-%d")
                                    except:
                                        check_out_date = date
                                
                                hotel_detail = self._get_detail_by_id(
                                    hotel_id, "hotel", 
                                    date_str=date,
                                    check_in_date=check_in_date,
                                    check_out_date=check_out_date
                                )
                                if hotel_detail:
                                    lines.append(f"  Hotel: {hotel_detail}")
                                else:
                                    lines.append(f"  Hotel: {hotel_id}")
                            else:
                                lines.append(f"  Hotel: {hotel_id}")
                    except Exception:
                        # 酒店信息处理失败，跳过
                        pass
                    
                    # 活动列表
                    try:
                        activities = day.get("activities", [])
                        if activities and isinstance(activities, list):
                            for activity in activities:
                                try:
                                    # 确保 activity 是字典
                                    if not isinstance(activity, dict):
                                        continue
                                    
                                    time_range = activity.get("time", "")
                                    activity_type = activity.get("type", "")
                                    description = activity.get("description", "")
                                    activity_id = activity.get("id")
                                    
                                    # 根据类型和 ID 获取详细信息
                                    detail_info = None
                                    if activity_id:
                                        try:
                                            item_type = self._identify_item_type(activity_id)
                                            if item_type:
                                                detail_info = self._get_detail_by_id(activity_id, item_type, date)
                                        except Exception:
                                            pass
                                    
                                    # 第一行：时间和类型
                                    type_line = ""
                                    if time_range:
                                        type_line += f"  [{time_range}] "
                                    type_line += activity_type or "Activity"
                                    if activity_id:
                                        type_line += f" (ID: {activity_id})"
                                    lines.append(type_line)
                                    
                                    # 第二行：description（如果有）
                                    if description:
                                        lines.append(f"    Description: {description}")
                                    
                                    # 第三行：详细信息（如果有）
                                    if detail_info:
                                        lines.append(f"    Details: {detail_info}")
                                except Exception:
                                    # 单个 activity 处理失败，跳过
                                    continue
                    except Exception:
                        # activities 处理失败，跳过
                        pass
                except Exception:
                    # 单个 day 处理失败，跳过
                    lines.append(f"\nDay {day_idx}: (Error processing day)")
                    continue
            
            return "\n".join(lines).strip()
        except Exception as e:
            # 如果整个转换过程失败，返回错误提示
            return f"The assistant provided a trip plan, but it could not be converted to natural language due to an error."

    def _history_text_for_sim(self, conv_simple: List[Dict[str, str]], keep_last_assistant: int = 3) -> str:
        OMIT_TEXT = "Earlier assistant responses are omitted due to context length limits."
        try:
            assistant_indices = [i for i, m in enumerate(conv_simple) if m.get("role") == "assistant"]
            keep_indices = set(assistant_indices[-keep_last_assistant:])

            lines = []
            for i, m in enumerate(conv_simple):
                try:
                    if m.get("role") == "user":
                        lines.append(f"User: {m.get('content', '')}")
                    elif m.get("role") == "assistant":
                        if i in keep_indices:
                                    content = m.get('content', '')
                                    
                                    # ✅ 检测并转换 JSON 格式的旅行计划，保留其他自然语言
                                    try:
                                        processed_content = self._extract_and_replace_json_in_text(content)
                                        lines.append(f"Assistant: {processed_content}")
                                    except Exception:
                                        # 如果转换失败，保持原样
                                        lines.append(f"Assistant: {content}")
                        else:
                            lines.append(f"Assistant: {OMIT_TEXT}")
                except Exception:
                    # 单个消息处理失败，跳过
                    continue
            return "\n\n".join(lines).strip()
        except Exception:
            # 如果整个方法失败，返回空字符串或基本格式
            return "\n\n".join([f"{m.get('role', 'unknown')}: {m.get('content', '')}" for m in conv_simple if isinstance(m, dict)]).strip()

    def _apply_proposed_ids_with_prefix(
        self,
        meta_dict: Dict[str, Any],
        current_id_list: List[str],
        rubric_progress: Dict[str, int],
        proposed_ids: List[Any],
    ) -> Tuple[bool, List[str], List[str], List[str]]:
        chains = meta_dict.get("applied_modification_chains", {}) or {}
        id2pos: Dict[str, Tuple[str, int]] = {}
        for rubric_key, chain in chains.items():
            for idx, cid in enumerate(chain):
                id2pos[cid] = (rubric_key, idx)

        errors: List[str] = []
        newly_added: List[str] = []
        removed: List[str] = []

        if not isinstance(proposed_ids, list):
            return False, [], [], [f"instruction_ids is not a list: {type(proposed_ids)}"]

        for raw in proposed_ids:
            if not isinstance(raw, str) or not raw:
                errors.append(f"invalid instruction_id (not str/empty): {raw!r}")
                continue

            if raw.startswith("NEW_"):
                op = "NEW"
                base_id = raw[len("NEW_"):]
            elif raw.startswith("MOD_"):
                op = "MOD"
                base_id = raw[len("MOD_"):]
            else:
                continue

            if base_id not in id2pos:
                errors.append(f"{raw} -> {base_id} not found in applied_modification_chains")
                continue

            rubric_key, idx = id2pos[base_id]
            chain = chains.get(rubric_key, []) or []

            if op == "NEW":
                if base_id not in current_id_list:
                    current_id_list.append(base_id)
                    newly_added.append(base_id)
                rubric_progress[rubric_key] = max(int(rubric_progress.get(rubric_key, 0) or 0), idx + 1)

            else:  # MOD
                if idx <= 0:
                    errors.append(f"{raw}: target is first element of chain, no previous to remove")
                    continue

                prev_id = chain[idx - 1]
                if prev_id not in current_id_list:
                    errors.append(f"{raw}: expected to remove prev_id={prev_id}, but it's not in current_id_list")
                    continue

                try:
                    current_id_list.remove(prev_id)
                    removed.append(prev_id)
                except ValueError:
                    errors.append(f"{raw}: failed to remove prev_id={prev_id} (ValueError)")
                    continue

                if base_id not in current_id_list:
                    current_id_list.append(base_id)
                    newly_added.append(base_id)

                rubric_progress[rubric_key] = max(int(rubric_progress.get(rubric_key, 0) or 0), idx + 1)

        ok = (len(errors) == 0)
        return ok, newly_added, removed, errors

    def rubric_all_satisfied(self, rubric_progress: dict, applied_chains: dict) -> bool:
        """Report whether every rubric's progress has reached its chain length.

        Args:
            rubric_progress: Rubric key -> progress counter. A missing or
                ``None`` entry counts as 0; ``None`` for the whole mapping is
                treated as empty.
            applied_chains: Rubric key -> chain (list of IDs). ``None`` for
                the mapping or for a chain is treated as empty.

        Returns:
            True when, for every rubric in ``applied_chains``, the progress
            value is at least the chain's length (trivially True when there
            are no chains); False otherwise.
        """
        progress = rubric_progress or {}
        for rubric, chain in (applied_chains or {}).items():
            required = len(chain or [])
            # Same coercion as the progress updates elsewhere in this class:
            # a stored None counts as zero.
            current = int(progress.get(rubric, 0) or 0)
            if current < required:
                return False
        return True

    def select_raw_answer_for_eval(self, current_raw_answer: str, conv_simple: List[Dict]) -> str:
        """Choose which assistant answer should be passed to the evaluator.

        The current answer is used unless it parses to a trip plan whose
        ``trip_plan.daily_schedule`` is present and is an empty list. Only in
        that case the conversation is scanned from newest to oldest and the
        first assistant message containing a non-empty daily schedule is
        returned instead.

        Args:
            current_raw_answer: The latest assistant answer (raw text).
            conv_simple: Conversation messages with ``role``/``content`` keys.
                Entries that are not dicts, or whose content is not a string,
                are skipped.

        Returns:
            The selected raw text; ``current_raw_answer`` itself when no
            fallback applies or none is found.
        """
        def parse_trip_plan_json(raw_text) -> Optional[dict]:
            """Parse raw_text as a JSON object, directly or from a ```json fence."""
            # Non-string content (e.g. None) cannot hold a plan; without this
            # guard re.search below would raise TypeError.
            if not isinstance(raw_text, str):
                return None

            # 1) The whole text is JSON.
            try:
                data = json.loads(raw_text)
                return data if isinstance(data, dict) else None
            except (ValueError, RecursionError):
                pass

            # 2) A ```json ... ``` fenced block inside the text.
            m = re.search(r"```json\s*(\{.*?\})\s*```", raw_text, re.S)
            if m:
                try:
                    data = json.loads(m.group(1))
                    return data if isinstance(data, dict) else None
                except (ValueError, RecursionError):
                    return None

            return None

        def extract_daily_schedule(raw_text) -> Optional[list]:
            """Return trip_plan.daily_schedule when it is a list, else None."""
            data = parse_trip_plan_json(raw_text)
            if not isinstance(data, dict):
                return None
            tp = data.get("trip_plan")
            if not isinstance(tp, dict):
                return None
            ds = tp.get("daily_schedule")
            return ds if isinstance(ds, list) else None

        # Fall back only when the schedule exists and is an empty list.
        current_schedule = extract_daily_schedule(current_raw_answer)
        if current_schedule is None or current_schedule:
            return current_raw_answer

        for msg in reversed(conv_simple or []):
            if not isinstance(msg, dict) or msg.get("role") != "assistant":
                continue
            raw_text = msg.get("content", "")
            if extract_daily_schedule(raw_text):
                return raw_text

        return current_raw_answer

    def run_multiturn_trip(
        self,
        meta_dict: Dict[str, Any],
        dialog_turns: int = 1,
        keep_last_assistant: int = 3,
        extra_save_dir: str = "__unused__",
        max_round_retry: int = 1,          # max attempts for a single round (state is rolled back before the round is rerun)
        round_retry_sleep: float = 0.0,    # optional sleep (seconds) after each failed attempt
    ) -> Dict[str, Any]:
        """Run up to ``dialog_turns`` dialogue rounds for one trip and return a summary.

        Each round builds the prompt blocks, (from round 2 on) asks the user
        simulator for the next user query and proposed instruction IDs, sends
        the query to the agent via ``self.chat``, evaluates the answer with
        ``self.trip_plan_evaluator`` and records a ``RoundUserState``. The
        evaluator's ``error_text`` is carried into the next round's ``ISSUE``
        block. The loop returns early once ``rubric_all_satisfied`` is True.

        Args:
            meta_dict: Trip record. Keys read here: ``trip_id``,
                ``rubric_progress``, ``query_with_constraints``,
                ``query_basic``, ``instruction_ids_basic``, ``user_style``,
                ``applied_modification_chains``.
            dialog_turns: Number of rounds. With 1, the single-turn system
                prompt and ``query_with_constraints`` are used; otherwise the
                multi-turn system prompt and ``query_basic``.
            keep_last_assistant: Passed to ``_history_text_for_sim`` and
                echoed in the summary.
            extra_save_dir: Not referenced in this method.
            max_round_retry: A round's exception is re-raised once the number
                of failed attempts reaches this value, so the default of 1
                re-raises on the first failure without rerunning.
            round_retry_sleep: Seconds to sleep before rerunning a failed
                round (only when > 0).

        Returns:
            A dict with ``trip_id``, ``dialog_turns``, ``keep_last_assistant``,
            ``final_id_list``, ``messages``, ``rounds`` (each round converted
            with ``asdict``) and ``terminate_reason`` (always ``None`` here).

        Raises:
            Exception: Whatever a round raised, once its attempts are
                exhausted.
        """

        trip_id = meta_dict.get("trip_id", "unknown_trip")
        rubric_progress = copy.deepcopy(meta_dict.get("rubric_progress", {}))

        conv_full: List[Dict[str, str]] = []
        conv_simple: List[Dict[str, str]] = []
        rounds: List[RoundUserState] = []

        if dialog_turns == 1:
            current_user_query = meta_dict.get("query_with_constraints", meta_dict.get("query_basic", ""))
            current_id_list = self._flatten_applied_chains(meta_dict)
        else:
            current_user_query = meta_dict.get("query_basic", "")
            current_id_list = meta_dict.get("instruction_ids_basic", [])

        issue_text = ""

        for r in range(1, dialog_turns + 1):
            round_try = 0

            # On failure, roll back to the state at the end of the previous round and rerun the current round
            while True:
                # ====== Snapshot: used to roll back to the end-of-previous-round state when this round fails ======
                snap_conv_full = copy.deepcopy(conv_full)
                snap_conv_simple = copy.deepcopy(conv_simple)
                snap_id_list = list(current_id_list)
                snap_progress = copy.deepcopy(rubric_progress)
                snap_issue_text = issue_text
                snap_rounds_len = len(rounds)
                snap_current_user_query = current_user_query

                try:
                    id_before = list(current_id_list)

                    blocks = self._build_blocks(meta_dict, current_id_list, rubric_progress)
                    blocks["ISSUE"] = issue_text or ""

                    sim_prompt = None
                    sim_output = None
                    newly_added = []
                    all_added = []
                    use_next_turn_reward = False

                    # ====== multi-turn: the user is only simulated from round 2 onwards ======
                    if r >= 2:
                        history_text = self._history_text_for_sim(conv_simple, keep_last_assistant)
                        max_retry = 3
                        retry = 0

                        while True:
                            # Look up user_style and substitute its description into the prompt
                            user_style = meta_dict.get("user_style", "")
                            style_description = ""
                            if user_style and user_style in USER_STYLE_DESCRIPTIONS:
                                style_description = USER_STYLE_DESCRIPTIONS[user_style].strip()

                            sim_prompt = self.user_prompt_template \
                                .replace("{{HISTORY}}", blocks.get("HISTORY", "")) \
                                .replace("{{NEW}}", blocks.get("NEW", "")) \
                                .replace("{{MODIFY}}", blocks.get("MODIFY", "")) \
                                .replace("{{HISTORY_MESSAGES}}", history_text) \
                                .replace("{{USER_STYLE}}", style_description)

                            # If user_sim.generate raises, rotate to the next base URL and retry
                            max_url_switch = max(1, len(getattr(self, "_user_base_urls", []) or []))
                            url_switch_cnt = 0
                            last_err = None

                            while True:
                                user_sim = self._get_user_simulator()
                                try:
                                    # Retries after a rejected proposal use the retry temperature.
                                    if retry > 0:
                                        sim_output = user_sim.generate(sim_prompt, temperature=self.user_cfg.retry_temperature)
                                    else:
                                        sim_output = user_sim.generate(sim_prompt, temperature=self.user_cfg.temperature)
                                    last_err = None
                                    break
                                except Exception as e:
                                    last_err = e
                                    if url_switch_cnt < max_url_switch:
                                        url_switch_cnt += 1
                                        self.rotate_user_base_url(reason=f"user_sim.generate error: {e}")
                                        continue
                                    raise last_err

                            proposed_ids = sim_output.get("instruction_ids", []) or []
                            # True when nothing was proposed or every proposed ID is a HIS_ one.
                            use_next_turn_reward = (
                                not proposed_ids or
                                all(isinstance(pid, str) and pid.startswith("HIS_") for pid in proposed_ids)
                            )

                            # Apply the proposal to copies so a rejected proposal leaves the state untouched.
                            tmp_id_list = list(current_id_list)
                            tmp_progress = copy.deepcopy(rubric_progress)

                            ok, newly_added_tmp, _, errors = self._apply_proposed_ids_with_prefix(
                                meta_dict=meta_dict,
                                current_id_list=tmp_id_list,
                                rubric_progress=tmp_progress,
                                proposed_ids=proposed_ids,
                            )

                            if ok:
                                current_id_list = tmp_id_list
                                rubric_progress = tmp_progress
                                newly_added = newly_added_tmp
                                current_user_query = sim_output.get("user_query", "")
                                all_added = proposed_ids
                                break

                            retry += 1
                            if retry >= max_retry:
                                # Give up validating: keep the last simulated query but do not change the ID list.
                                newly_added = []
                                current_user_query = sim_output.get("user_query", "")
                                break

                    # ====== The system prompt is added only in round 1 (a rerun after rollback does not duplicate it, because conv_full is rolled back too) ======
                    if r == 1:
                        if dialog_turns == 1:
                            conv_full.append({"role": "system", "content": system_prompt_en_single_turn})
                        else:
                            conv_full.append({"role": "system", "content": system_prompt_en})

                    # ====== Send the user query ======
                    conv_full.append({"role": "user", "content": current_user_query})
                    conv_simple.append({"role": "user", "content": current_user_query})

                    # ====== Agent generation ======
                    trace_log = self.chat(history=conv_full)
                    conv_full = copy.deepcopy(trace_log["messages"])
                    conv_simple.append({"role": "assistant", "content": trace_log["final_answer"]})

                    # ====== eval ======
                    raw_answer_for_eval = self.select_raw_answer_for_eval(
                        current_raw_answer=trace_log["final_answer"],
                        conv_simple=conv_simple
                    )

                    eval_result = self.trip_plan_evaluator.evaluate(
                        meta_dict=meta_dict,
                        raw_text=raw_answer_for_eval,
                        id_list=current_id_list
                    )
                    issue_text = eval_result.get("error_text", "") or ""

                    round_state = RoundUserState(
                        round_idx=r,
                        id_list_before=id_before,
                        id_list_after=list(current_id_list),
                        blocks=blocks,
                        sim_prompt=sim_prompt,
                        sim_output=sim_output,
                        newly_added_ids=newly_added,
                        all_added_ids=all_added,
                        final_answer=trace_log["final_answer"],
                        prompt_tokens=int((trace_log.get("usage") or {}).get("prompt_tokens", 0) or 0),
                        completion_tokens=int((trace_log.get("usage") or {}).get("completion_tokens", 0) or 0),
                        total_tokens=int((trace_log.get("usage") or {}).get("total_tokens", 0) or 0),
                        usage_by_round=trace_log.get("usage_by_round") or [],
                        tool_calls=trace_log["tool_calls"],
                        eval_result=eval_result,
                        use_next_turn_reward=use_next_turn_reward,
                        start_time=trace_log["start_time"],
                        end_time=trace_log["end_time"],
                    )
                    rounds.append(round_state)

                    # Early exit: stop as soon as every rubric chain has been fully proposed.
                    applied_chains = meta_dict.get("applied_modification_chains", {})
                    if self.rubric_all_satisfied(rubric_progress, applied_chains):
                        print(f"[trip_id={trip_id}] All rubrics are proposed at round {r}", flush=True)

                        summary = {
                            "trip_id": trip_id,
                            "dialog_turns": dialog_turns,
                            "keep_last_assistant": keep_last_assistant,
                            "final_id_list": list(current_id_list),
                            "messages": conv_full,
                            "rounds": [asdict(x) for x in rounds],
                            "terminate_reason": None,
                        }
                        return summary

                    # Round succeeded: leave the while loop and move on to round r+1
                    break

                except Exception as e:
                    # ====== Round failed: roll back to the end-of-previous-round state ======
                    conv_full = snap_conv_full
                    conv_simple = snap_conv_simple
                    current_id_list = snap_id_list
                    rubric_progress = snap_progress
                    issue_text = snap_issue_text
                    current_user_query = snap_current_user_query

                    # Roll back rounds as well (avoid pollution from a half-built state)
                    if len(rounds) > snap_rounds_len:
                        rounds[:] = rounds[:snap_rounds_len]

                    round_try += 1
                    print(f"[multiturn] round {r} failed (try {round_try}/{max_round_retry}): {e}", flush=True)

                    # Attempts for this round exhausted: re-raise to the caller
                    # (per the original note, a worker is expected to requeue the trip from scratch — not visible here)
                    if round_try >= max_round_retry:
                        raise

                    if round_retry_sleep > 0:
                        time.sleep(round_retry_sleep)

                    # Continue the while loop: rerun the current round r (not from the start)
                    continue

        summary = {
            "trip_id": trip_id,
            "dialog_turns": dialog_turns,
            "keep_last_assistant": keep_last_assistant,
            "final_id_list": list(current_id_list),
            "messages": conv_full,
            "rounds": [asdict(x) for x in rounds],
            "terminate_reason": None,  # reached only on success; every failure path above raises instead
        }
        return summary

# ============================================================
#                    DISTRIBUTED RUNNER
# ============================================================
def canonical_final_id_list(final_id_list):
    """Return an order-independent, duplicate-free form of an id list.

    Every non-``None`` entry is converted to ``str``; the distinct strings are
    returned as a sorted tuple so two lists with the same ids compare equal
    regardless of order or repetition. Anything that is not a ``list`` yields
    an empty tuple.
    """
    if isinstance(final_id_list, list):
        unique_ids = set()
        for entry in final_id_list:
            if entry is not None:
                unique_ids.add(str(entry))
        return tuple(sorted(unique_ids))
    return ()

def make_sample_key(trip_id, final_id_list):
    """Build the hashable key used to decide whether a sample was already run.

    The key is ``(str(trip_id), canonical_final_id_list(final_id_list))``.
    """
    trip_part = str(trip_id)
    ids_part = canonical_final_id_list(final_id_list)
    return (trip_part, ids_part)

def final_id_list_from_meta(meta: dict):
    """Derive the final id list of an input sample from its meta dict.

    - If ``meta["applied_modification_chains"]`` is a non-empty dict, the
      result is the last element of every value that is a non-empty list
      (in the dict's iteration order). Other values are ignored, and the
      result is returned even when it ends up empty.
    - Otherwise ``meta["final_id_list"]`` is returned when it is a list.
    - Otherwise an empty list is returned.
    """
    chains = meta.get("applied_modification_chains")
    if isinstance(chains, dict) and chains:
        return [
            chain[-1]
            for chain in chains.values()
            if isinstance(chain, list) and chain
        ]

    fallback = meta.get("final_id_list")
    return fallback if isinstance(fallback, list) else []

def final_id_list_from_output_record(rec: dict):
    """Extract the id list used for de-duplication from one output record.

    The turn count is read from ``rec["dialog_turns"]``, falling back to
    ``rec["summary"]["dialog_turns"]``; values that cannot be converted to an
    int count as 0.

    The first candidate below that is a ``list`` is returned:

    - multi-turn (turn count > 1):
        1. ``summary["rounds"][0]["id_list_before"]`` - the id list recorded
           for the first round, preferred over ``summary["final_id_list"]``
           so the key does not drift with later rounds;
        2. ``summary["instruction_ids_basic"]``;
        3. ``summary["final_id_list"]``;
        4. ``rec["final_id_list"]``.
    - single-turn (turn count <= 1):
        1. ``summary["final_id_list"]``;
        2. ``rec["final_id_list"]``.

    Malformed pieces (``rec`` not a dict, ``summary`` not a dict, first round
    not a dict) are skipped instead of raising, so one odd record cannot
    abort a scan of the whole output file.

    Returns:
        The selected list object, or ``[]`` when no candidate is a list.
    """
    if not isinstance(rec, dict):
        return []

    summary = rec.get("summary")
    if not isinstance(summary, dict):
        # Previously a truthy non-dict summary raised AttributeError on .get().
        summary = {}

    dt = rec.get("dialog_turns")
    if dt is None:
        dt = summary.get("dialog_turns")
    try:
        dt = int(dt or 0)
    except (TypeError, ValueError, OverflowError):
        dt = 0

    # Candidates in priority order; the single-turn chain is the tail of the
    # multi-turn chain, so both modes share one lookup loop.
    candidates = []
    if dt > 1:
        rounds = summary.get("rounds")
        if isinstance(rounds, list) and rounds:
            first_round = rounds[0]
            if isinstance(first_round, dict):
                candidates.append(first_round.get("id_list_before"))
        candidates.append(summary.get("instruction_ids_basic"))
    candidates.append(summary.get("final_id_list"))
    candidates.append(rec.get("final_id_list"))

    for candidate in candidates:
        if isinstance(candidate, list):
            return candidate
    return []

def load_done_keys_from_output_jsonl(output_jsonl_path: str):
    """Scan an output JSONL file and collect the sample keys already present.

    Every line that parses to a JSON object contributes one key built by
    ``make_sample_key`` from its trip id (``rec["trip_id"]`` or
    ``rec["summary"]["trip_id"]``) and the id list chosen by
    ``final_id_list_from_output_record``. No evaluation payload is required.

    Blank lines, lines that are not valid JSON (e.g. a truncated last line
    after an interrupted run) and JSON values that are not objects are
    skipped, so a single bad line cannot abort resumption.

    Args:
        output_jsonl_path: Path of the JSONL file; may be empty or missing.

    Returns:
        A set of sample keys; empty when the path is falsy or does not exist.
    """
    done = set()
    if not output_jsonl_path or not os.path.exists(output_jsonl_path):
        return done

    with open(output_jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except (ValueError, RecursionError):
                # json.JSONDecodeError is a ValueError; RecursionError covers
                # pathologically nested input.
                continue

            # `null`, numbers, strings and arrays are valid JSON but not
            # records; calling .get() on them used to raise AttributeError.
            if not isinstance(rec, dict):
                continue

            summary = rec.get("summary")
            if not isinstance(summary, dict):
                summary = {}

            trip_id = rec.get("trip_id") or summary.get("trip_id")
            fil = final_id_list_from_output_record(rec)
            done.add(make_sample_key(trip_id, fil))

    return done


def get_resources_for_process(config_path: str):
    """Return the resources a worker process should use.

    - If ``GLOBAL_RESOURCES`` is already populated (the fork case, where the
      child inherits what the parent preloaded), it is returned as-is and
      nothing is loaded.
    - Otherwise (e.g. under spawn, where the global starts out empty) the
      resources are loaded via ``_load_resources_for_process(config_path)``.
      The freshly loaded value is returned but not stored in the global.
    """
    global GLOBAL_RESOURCES
    inherited = GLOBAL_RESOURCES
    if inherited is None:
        return _load_resources_for_process(config_path)
    return inherited

def preload_resources_in_parent(config_path: str):
    """Load the shared resources once into ``GLOBAL_RESOURCES``.

    Intended to be called in the parent process before forking, so that
    forked children find ``GLOBAL_RESOURCES`` already populated. A second
    call is a no-op. After loading, ``gc.freeze()`` is attempted on a
    best-effort basis; any failure there is ignored.
    """
    global GLOBAL_RESOURCES
    if GLOBAL_RESOURCES is not None:
        # Already loaded; nothing to do.
        return

    print("[runner] preload resources in parent ...", flush=True)
    GLOBAL_RESOURCES = _load_resources_for_process(config_path)
    print("[runner] preload resources done.", flush=True)

    # Best effort only: a failing import/freeze must not break the run.
    try:
        import gc
        gc.freeze()
        print("[runner] gc.freeze() done.", flush=True)
    except Exception:
        pass


def _append_jsonl(file_lock: mp.Lock, output_jsonl_path: str, obj: dict):
    line = json.dumps(obj, ensure_ascii=False)
    with file_lock:
        with open(output_jsonl_path, "a", encoding="utf-8") as f:
            f.write(line + "\n")


def _ip_process_main(
    task_q,                 # mp.JoinableQueue
    done_counter,           # mp.Value('i', 0)
    total_tasks: int,
    stop_event,             # mp.Event
    file_lock,              # mp.Lock
    output_jsonl_path: str,
    dialog_turns: int,
    keep_last_assistant: int,
    user_prompt_template: str,   # forwarded to TravelAgent
):
    """Worker-process entry point: run queued trips on a pool of threads.

    Obtains resources via ``get_resources_for_process(CONFIG_PATH)``, builds
    one ``TravelAgent`` that is shared by every worker thread, starts
    ``REQUESTS_PER_IP`` daemon threads running ``worker_loop`` and waits for
    all of them to return.

    Args:
        task_q: Queue of items passed as ``meta_dict`` to the agent. Every
            successful ``get()`` is paired with exactly one ``task_done()``.
        done_counter: Shared counter, incremented once per item that either
            finished successfully or was dropped.
        total_tasks: When ``done_counter`` reaches this value the worker
            sets ``stop_event``.
        stop_event: Once set, threads stop processing, drain whatever is
            still in the queue, and exit.
        file_lock: Lock handed to ``_append_jsonl`` for writing records.
        output_jsonl_path: JSONL file that successful records are appended to.
        dialog_turns: Forwarded to ``agent.run_multiturn_trip``.
        keep_last_assistant: Forwarded to ``agent.run_multiturn_trip``.
        user_prompt_template: Forwarded to the ``TravelAgent`` constructor.
    """
    # Timestamp (time.time() scale) before which workers pause. Shared by all
    # threads of this process through the closures below; no lock guards it.
    cooldown_until = 0.0

    print("[worker] before get resources", flush=True)
    resources = get_resources_for_process(CONFIG_PATH)
    print("[worker] after get resources (shared if fork)", flush=True)

    # Agent and user settings are passed separately (AGENT_CFG / USER_CFG).
    agent = TravelAgent(
        config_path=CONFIG_PATH,
        agent_cfg=AGENT_CFG,
        user_cfg=USER_CFG,
        user_prompt_template=user_prompt_template,  # prompt template handed to the agent
        # Pass in the user-side primary + fallback URL list.
        user_base_urls=DEFAULT_USER_URLS,
        attraction_evaluator=resources["attraction_evaluator"],
        restaurant_evaluator=resources["restaurant_evaluator"],
        hotel_evaluator=resources["hotel_evaluator"],
        transportation_evaluator=resources["transportation_evaluator"],
        general_evaluator=resources["general_evaluator"],
    )

    def set_cooldown():
        """Extend the cooldown to at least COOLDOWN_SECONDS from now and log it."""
        nonlocal cooldown_until
        # max() keeps a later deadline already set by another thread.
        cooldown_until = max(cooldown_until, time.time() + COOLDOWN_SECONDS)
        print(
            "[worker] ERROR -> cooldown "
            f"{COOLDOWN_SECONDS}s until "
            f"{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(cooldown_until))}",
            flush=True
        )

    def wait_cooldown():
        """Sleep in 1-second steps until the cooldown expires or a stop is requested."""
        while not stop_event.is_set() and time.time() < cooldown_until:
            time.sleep(1.0)

    def _is_max_tokens_bad_request(e: Exception) -> bool:
        """
        Detect one specific error, e.g.:
        Error code: 400 - {'error': {'message': 'max_tokens must be at least 1, got -793.' ...}}
        The match is a substring test on the exception text. When it matches,
        the caller applies no cooldown and no retry and drops the item.
        """
        try:
            s = str(e)
        except Exception:
            s = repr(e)

        return "max_tokens must be at least 1" in s

    def worker_loop(worker_id: int):
        """Thread body: take items from the queue and run them until stopped.

        Per item: run the multi-turn trip; on success append a record and
        count it; on failure drop it (max_tokens error or retry limit hit)
        or requeue it with a cooldown. ``worker_id`` is not used in the body.
        """
        MAX_ITEM_RETRY = 3  # max number of times one item may be requeued (rerun from scratch)

        while True:
            wait_cooldown()

            if stop_event.is_set():
                # Stop requested: discard remaining items (marking each done)
                # and exit once the queue stays empty for a second.
                try:
                    item = task_q.get(timeout=1.0)
                except queue.Empty:
                    return
                task_q.task_done()
                continue

            try:
                item = task_q.get(timeout=1.0)
            except queue.Empty:
                # Nothing available yet; loop again (re-checks cooldown and stop).
                continue

            try:
                summary = agent.run_multiturn_trip(
                    meta_dict=item,
                    dialog_turns=dialog_turns,
                    keep_last_assistant=keep_last_assistant,
                    extra_save_dir="__unused__",
                    max_round_retry=3,         # per-round rollback-and-retry count (does not restart the trip)
                    round_retry_sleep=0.0,
                )

                # Safety net: never persist a summary that is not a complete success.
                if summary.get("terminate_reason"):
                    raise RuntimeError(f"multiturn terminated early: {summary.get('terminate_reason')}")

                record = {
                    "trip_id": summary.get("trip_id"),
                    "dialog_turns": summary.get("dialog_turns"),
                    "summary": summary,
                }
                _append_jsonl(file_lock, output_jsonl_path, record)

                with done_counter.get_lock():
                    done_counter.value += 1
                    done = done_counter.value

                print(
                    f"[worker] progress: {done}/{total_tasks} "
                    f"(trip_id={summary.get('trip_id')})",
                    flush=True
                )

                if done >= total_tasks:
                    stop_event.set()

            except Exception as e:
                # === max_tokens-style 400: no cooldown, no requeue, drop the item ===
                if _is_max_tokens_bad_request(e):
                    print(f"[worker] NON-RETRYABLE(max_tokens) -> drop: {e}", flush=True)

                    with done_counter.get_lock():
                        done_counter.value += 1
                        done = done_counter.value

                    print(f"[worker] progress: {done}/{total_tasks} (dropped)", flush=True)

                    if done >= total_tasks:
                        stop_event.set()

                else:
                    # Any other exception: requeue the item, at most MAX_ITEM_RETRY times.
                    # If a stop was already requested, the failed item is neither
                    # requeued nor counted.
                    if not stop_event.is_set():
                        # The retry count travels with the item under the "_retry" key.
                        cur_retry = int(item.get("_retry", 0) or 0)

                        if cur_retry >= MAX_ITEM_RETRY:
                            # Retry limit reached: do not requeue; drop the item and count it.
                            print(
                                f"[worker] RETRY_LIMIT({MAX_ITEM_RETRY}) -> drop: "
                                f"trip_id={item.get('trip_id')} err={e}",
                                flush=True
                            )

                            with done_counter.get_lock():
                                done_counter.value += 1
                                done = done_counter.value

                            print(f"[worker] progress: {done}/{total_tasks} (dropped_retry_limit)", flush=True)

                            if done >= total_tasks:
                                stop_event.set()

                            # No cooldown here (optional). To cool down anyway, enable set_cooldown() below.
                            # set_cooldown()

                        else:
                            # Below the limit: requeue + cooldown.
                            item["_retry"] = cur_retry + 1
                            task_q.put(item)

                            print(
                                f"[worker] trajectory failed -> requeue({item['_retry']}/{MAX_ITEM_RETRY}): {e}",
                                flush=True
                            )
                            set_cooldown()

            finally:
                # Runs for every item taken by the get() above, including
                # requeued ones (their put() has already happened by now).
                task_q.task_done()

    threads = []
    for wid in range(REQUESTS_PER_IP):
        t = threading.Thread(target=worker_loop, args=(wid,), daemon=True)
        t.start()
        threads.append(t)

    for t in threads:
        t.join()

    print("[worker] process exit", flush=True)



def _pick_mp_context():
    """
    Linux 优先 fork（避免 spawn 导致每个子进程重复 import / 大加载）
    不支持 fork 的环境自动 fallback 到 spawn
    """
    try:
        methods = mp.get_all_start_methods()
    except Exception:
        methods = ["spawn"]

    if "fork" in methods:
        return mp.get_context("fork")
    return mp.get_context("spawn")


def _wait_for_queue_or_dead_workers(task_q, procs, poll_seconds: float = 1.0) -> bool:
    """Wait until ``task_q`` is fully processed or every worker process has exited.

    ``JoinableQueue.join()`` has no timeout, so it is run in a daemon thread
    while this function polls worker liveness.

    Returns:
        True if the queue was drained; False if all workers exited first.
    """
    joiner = threading.Thread(target=task_q.join, daemon=True)
    joiner.start()
    while True:
        joiner.join(timeout=poll_seconds)
        if not joiner.is_alive():
            return True
        if not any(proc.is_alive() for proc in procs):
            # A worker may exit right after its final task_done(); give the
            # pending join one more chance before declaring failure.
            joiner.join(timeout=poll_seconds)
            return not joiner.is_alive()


def run_multiturn_file_distributed(
    test_path: str,
    dialog_turns: int,
    keep_last_assistant: int,
    output_jsonl_path: str,
    user_prompt_template: str,
):
    """Run every not-yet-processed sample of one test file in a worker process.

    Steps:
      1. Pick the multiprocessing context; under ``fork`` preload resources
         in the parent so the child inherits them.
      2. Load the samples (a JSON array) from ``test_path``.
      3. Make sure the output file exists, then read the sample keys it
         already contains and skip those samples as well as duplicates
         inside the test file.
      4. Start one worker process (``_ip_process_main``), enqueue the
         remaining samples and wait until they are all handled.

    Args:
        test_path: JSON file holding a list of sample meta dicts.
        dialog_turns: Number of dialog turns; ``> 1`` selects the multi-turn
            de-duplication key (``instruction_ids_basic``).
        keep_last_assistant: Forwarded to the worker.
        output_jsonl_path: JSONL file results are appended to; also the
            source of already-done keys, which makes the run resumable.
        user_prompt_template: Forwarded to the worker.
    """
    ctx = _pick_mp_context()
    print(f"[runner] mp start method = {ctx.get_start_method()}", flush=True)

    if ctx.get_start_method() == "fork":
        preload_resources_in_parent(CONFIG_PATH)

    with open(test_path, "r", encoding="utf-8") as f:
        items = json.load(f)

    # ====== Test mode: take one sample per style (5 samples in total) ======
    # ALL_STYLES = ["casual", "detailed", "formal", "impatient", "patient"]
    # selected_items = []
    #
    # for style in ALL_STYLES:
    #     style_items = [it for it in items if it.get("user_style") == style]
    #     if style_items:
    #         selected_items.append(style_items[0])
    #         print(f"[runner] found {style} style sample: {style_items[0].get('trip_id', 'unknown')}", flush=True)
    #     else:
    #         print(f"[runner] warning: no sample found for style {style}", flush=True)
    #
    # if selected_items:
    #     items = selected_items
    #     print(f"[runner] test mode: selected {len(items)} samples (one per style)", flush=True)
    # else:
    #     print(f"[runner] warning: no sample found for any style", flush=True)
    #     items = []
    # ====== End of test-mode code ======

    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the output path actually has a directory component.
    output_dir = os.path.dirname(output_jsonl_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # Append mode creates the file if missing and never truncates existing results.
    with open(output_jsonl_path, "a", encoding="utf-8"):
        pass

    done_keys = load_done_keys_from_output_jsonl(output_jsonl_path)
    print(f"[runner] already_done={len(done_keys)} (from {output_jsonl_path})", flush=True)

    filtered = []
    seen_in_this_run = set()
    skipped_done = 0
    skipped_dup = 0

    for it in items:
        trip_id = it.get("trip_id")

        # Multi-turn keys on instruction_ids_basic; single-turn keeps the
        # original logic (final id list derived from the meta dict).
        if dialog_turns > 1:
            fil = it.get("instruction_ids_basic", [])
        else:
            fil = final_id_list_from_meta(it)

        sk = make_sample_key(trip_id, fil)

        if sk in done_keys:
            skipped_done += 1
            continue
        if sk in seen_in_this_run:
            skipped_dup += 1
            continue

        seen_in_this_run.add(sk)
        filtered.append(it)

    total = len(filtered)
    print(
        f"[runner] loaded={len(items)} queued={total} "
        f"skipped_done={skipped_done} skipped_dup_in_test={skipped_dup}",
        flush=True
    )

    if total == 0:
        print(f"[Done] nothing to run. output={output_jsonl_path}", flush=True)
        return

    task_q = ctx.JoinableQueue()
    done_counter = ctx.Value("i", 0)
    stop_event = ctx.Event()
    file_lock = ctx.Lock()

    # Only a single worker ("IP") process is started.
    procs = []
    p = ctx.Process(
        target=_ip_process_main,
        args=(
            task_q,
            done_counter,
            total,
            stop_event,
            file_lock,
            output_jsonl_path,
            dialog_turns,
            keep_last_assistant,
            user_prompt_template,
        ),
        daemon=False,
    )
    p.start()
    procs.append(p)

    for it in filtered:
        task_q.put(it)

    # A bare task_q.join() would block forever if the worker process exited
    # early (e.g. it failed while loading resources or building the agent).
    drained = _wait_for_queue_or_dead_workers(task_q, procs)
    if not drained:
        print(
            "[runner] WARNING: worker process exited before the queue was drained; "
            "remaining items were not processed.",
            flush=True
        )
        # Unread items may still be buffered for the queue's pipe; do not let
        # the feeder thread block interpreter shutdown trying to flush them.
        task_q.cancel_join_thread()

    stop_event.set()
    for p in procs:
        p.join()

    print(f"[Done] total={total}, finished={done_counter.value}, output={output_jsonl_path}", flush=True)


def make_output_jsonl_path(test_path: str, dialog_turns: int, output_dir: str) -> str:
    """Build the output JSONL path for one (test file, turn count) run.

    The file name is ``<test stem>_<agent model name>_<mode>_t<turns>.jsonl``
    where ``<mode>`` is ``single`` when ``dialog_turns == 1`` and ``multi``
    otherwise, and the model name comes from ``AGENT_CFG.model_name``.
    """
    stem, _ext = os.path.splitext(os.path.basename(test_path))
    mode = "multi" if dialog_turns != 1 else "single"
    filename = f"{stem}_{AGENT_CFG.model_name}_{mode}_t{dialog_turns}.jsonl"
    return os.path.join(output_dir, filename)


if __name__ == "__main__":
    def _parse_tests_env(value: str) -> List[Tuple[str, str]]:
        """Parse a comma-separated TESTS value into ``(tag, path)`` pairs.

        Each entry is either ``tag:path`` (split on the first colon) or just
        ``path``, in which case the tag is the path's base name without its
        extension. Blank entries and entries with an empty tag or path are
        dropped.
        """
        out = []
        for raw in value.split(","):
            raw = raw.strip()
            if not raw:
                continue
            if ":" in raw:
                tag, path = raw.split(":", 1)
                tag = tag.strip()
                path = path.strip()
            else:
                path = raw
                tag = os.path.splitext(os.path.basename(path))[0]
            if tag and path:
                out.append((tag, path))
        return out

    # The TESTS environment variable, when non-empty, overrides the built-in list.
    _env_tests = os.environ.get("TESTS", "").strip()
    if _env_tests:
        TESTS = _parse_tests_env(_env_tests)
    else:
        # NOTE(review): the default entries on the next line appear truncated in
        # this copy (paths missing, closing bracket swallowed by the comment) -
        # restore the original (tag, path) tuples before running.
        TESTS = [
            # ("test_easy", "            # ("test_mid",  "            ("hard_vague_style", "        ]

    # (mode tag, dialog turn count) pairs to run for every test file.
    RUN_MODES = [
        # ("single", 1),
        ("multi", 15),
    ]

    # User prompt template per test tag; tags not listed here fall back to
    # user_prompt_easy_en in the loop below.
    PROMPT_BY_TEST = {
        # "test_easy": user_prompt_easy_en,
        # "test_mid":  user_prompt_mid_en,
        "hard_vague_style": user_prompt_hard_vague_style_en,
    }

    for test_tag, test_path in TESTS:
        user_prompt_template = PROMPT_BY_TEST.get(test_tag, user_prompt_easy_en)
        # print(user_prompt_template)

        for mode_tag, dialog_turns in RUN_MODES:
            output_jsonl_path = make_output_jsonl_path(
                test_path=test_path,
                dialog_turns=dialog_turns,
                output_dir=OUTPUT_DIR,
            )

            run_multiturn_file_distributed(
                test_path=test_path,
                dialog_turns=dialog_turns,
                keep_last_assistant=KEEP_LAST_ASSISTANT,
                output_jsonl_path=output_jsonl_path,
                user_prompt_template=user_prompt_template,  # template selected for this test tag
            )
