import copy
import sys
import os
import json
import re
import time
import queue
import threading
import traceback
import multiprocessing as mp
from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime

import yaml
from openai import OpenAI

# ====== PATH fix ======
# Put the directory two levels above this file (presumably the project root
# containing the Tools/, interact/ and Evaluation/ packages imported below) at
# the front of sys.path so those imports resolve.
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)
print("Added ROOT:", ROOT, flush=True)

# ====== Tool imports ======
from Tools.General.General_tools import GeneralTool
from Tools.Attraction.Attraction_tools import AttractionTool
from Tools.Flight.Flight_tools import FlightTool
from Tools.Train.Train_tools import TrainTool
from Tools.Restaurant.Restaurant_tools import RestaurantTool
from Tools.Hotel.Hotel_tools import HotelTool

from Tools.tool_description import (
    Flight_tool_description,
    Train_tool_description,
    Hotel_tool_description,
    Attraction_tool_description,
    Restaurant_tool_description,
    General_tool_description,
)

from interact.prompt.agent_system_prompt import system_prompt_en, system_prompt_en_single_turn
from interact.prompt.user_prompt import user_prompt_easy_en, user_prompt_mid_en, user_prompt_easy_en_no_issue
from interact.user_agent import UserSimulator
from interact.evaluation_agent import TripPlanEvaluator

from Evaluation.Attraction.val_attractions import AttractionEvaluator
from Evaluation.Restaurant.val_restaurants import RestaurantEvaluator
from Evaluation.Hotel.val_hotels import HotelEvaluator
from Evaluation.Transportation.val_transports import TransportationEvaluator
from Evaluation.General.val_general import GeneralEvaluator, load_or_build_cache


# ============================================================
#                    CONFIG / RUNNER CONFIG
# ============================================================

CONFIG_PATH = " =========================
# ✅ Agent/User 独立配置
# =========================

@dataclass
class AgentLLMConfig:
    """Sampling and connection settings for the tool-calling agent model.

    TravelAgent builds its per-thread OpenAI client from ``base_url``,
    ``api_key`` and ``timeout_sec``; ``model_name``, ``temperature``,
    ``max_tokens`` and ``tool_choice`` go into every agent
    ``chat.completions.create`` call.
    """
    model_name: str = "qwen2.5-32b-instruct"
    temperature: float = 0.7
    max_tokens: int = 64 * 1024
    tool_choice: str = "auto"
    # Sent as chat_template_kwargs.thinking when model_name == "deepseek-v32".
    thinking: bool = False
    # Extra request-body fields to pass through to the server; check
    # TravelAgent.chat for which models actually receive them.
    extra_body: Optional[Dict[str, Any]] = None
    timeout_sec: float = 1200.0  # OpenAI client timeout, in seconds
    base_url: Optional[str] = None
    # NOTE(review): hard-coded credential; consider reading it from the environment.
    api_key: str = "1737787093780320300"


@dataclass
class UserLLMConfig:
    """Settings for the user-simulator model, independent of the agent's.

    ``base_url`` / ``api_key`` / ``timeout_sec`` build the per-thread OpenAI
    client; ``model_name`` / ``temperature`` / ``max_tokens`` configure the
    UserSimulator.
    """
    model_name: str = "deepseek-v32"
    temperature: float = 0.7  # temperature for the first simulator attempt of a round
    max_tokens: int = 16 * 1024
    # Used instead of ``temperature`` when a simulator turn is retried.
    retry_temperature: float = 0.7
    # NOTE(review): not read anywhere in the visible part of this file.
    extra_body: Optional[Dict[str, Any]] = None
    timeout_sec: float = 120.0  # OpenAI client timeout, in seconds
    base_url: Optional[str] = None  # primary user-simulator endpoint
    # NOTE(review): hard-coded credential; consider reading it from the environment.
    api_key: str = "1737787093780320300"


# User-simulator endpoints. The first entry is the primary; the rest are
# fallbacks for TravelAgent.rotate_user_base_url (presumably passed in as
# user_base_urls — confirm at the call site).
USER_IP_POOL = [
    "33.253.98.213",     # primary
    "33.253.210.217",    
    "33.253.69.150",
    "33.253.178.54"
]

# Alternative pool orderings kept for quick switching (inactive):
# USER_IP_POOL = [
#     "33.253.210.217",
#     "33.253.69.150",
#     "33.253.178.54",
#     "33.253.98.213"
# ]

# USER_IP_POOL = [
#     "33.253.69.150",
#     "33.253.178.54"
# ]

# USER_IP_POOL = [
#     "33.253.178.54",
#     "33.253.69.150"
# ]

# Base URLs handed to the OpenAI client, one per pool entry (port 8000, /v1).
DEFAULT_USER_URLS = [f"http://{ip}:8000/v1" for ip in USER_IP_POOL]

# USER_CFG.base_url is set to the first URL only; the actual failover happens
# inside TravelAgent, which switches between URLs.
USER_CFG = UserLLMConfig(base_url=DEFAULT_USER_URLS[0])


# NOTE(review): not referenced in the visible part of this file; presumably a
# per-endpoint request budget for the scheduling code — confirm.
REQUESTS_PER_IP = 20

# Previous agent setup (self-hosted endpoint), kept for reference:
# AGENT_MODEL_NAME = "qwen2.5-32b-instruct"
# AGENT_IP = "33.253.202.152"
# DEFAULT_AGENT_URL = f"http://{AGENT_IP}:8000/v1"

# AGENT_MODEL_NAME = "LongCat-Flash-Thinking"
AGENT_MODEL_NAME = "LongCat-Flash-Chat"
DEFAULT_AGENT_URL = "https://basicaiservice.sankuai.com/basicai/v1"
AGENT_CFG = AgentLLMConfig(model_name=AGENT_MODEL_NAME, base_url=DEFAULT_AGENT_URL)


# Other agent model names that can be substituted above:
# "LongCat-Flash-Chat"
# "LongCat-Flash-Thinking"
# "qwen2.5-32b-instruct"
# "Qwen2.5-14B-Instruct"
# "Toucan-Qwen2.5-14B-Instruct-v0.1"
# "Toucan-Qwen2.5-32B-Instruct-v0.1"


# Per-IP cooldown, in seconds (not referenced in the visible part of this file).
COOLDOWN_SECONDS = 1 * 1

# Multi-turn: number of most recent assistant replies shown verbatim to the
# user simulator (presumably passed as keep_last_assistant — confirm).
KEEP_LAST_ASSISTANT = 3

OUTPUT_DIR = " ============================================================
#           进程内资源初始化（避免 spawn 重复模块级大加载）
# ============================================================
# Per-process resource holder. The original note says forked child processes
# inherit it (copy-on-write); it is not assigned or read in the visible part
# of this file.
GLOBAL_RESOURCES = None

def _load_resources_for_process(config_path: str) -> Dict[str, Any]:
    """Load config, data and evaluators once inside the current process.

    Meant to be called a single time per worker process, so that heavy
    objects are not built at import time (the original note: under the
    ``spawn`` start method, module-level loading stalls or is repeated).

    Args:
        config_path: YAML file with a ``data_path`` mapping that must contain
            ``attraction``, ``restaurant``, ``hotel``, ``flight`` and ``train``.

    Returns:
        Dict with ``config`` plus the attraction / restaurant / hotel /
        transportation / general evaluators.
    """
    with open(config_path, "r", encoding="utf-8") as cfg_file:
        config = yaml.safe_load(cfg_file)

    # All five paths are looked up up-front so a missing key fails before any
    # heavy work. The train path is looked up but not used here.
    paths = config["data_path"]
    attraction_path, restaurant_path, hotel_path, flight_path, _train_path = (
        paths[key] for key in ("attraction", "restaurant", "hotel", "flight", "train")
    )

    with open(attraction_path, "r", encoding="utf-8") as data_file:
        attractions_data = json.load(data_file)

    resources: Dict[str, Any] = {"config": config}
    resources["attraction_evaluator"] = AttractionEvaluator(attractions_data)
    resources["restaurant_evaluator"] = RestaurantEvaluator(restaurants_path=restaurant_path)
    resources["hotel_evaluator"] = HotelEvaluator(hotel_path=hotel_path)
    resources["transportation_evaluator"] = TransportationEvaluator(flights_path=flight_path)
    # The cache build is the step the original author flagged as heavy.
    resources["general_evaluator"] = GeneralEvaluator(load_or_build_cache(config_path))
    return resources


# ============================================================
# Tool-call wrapper classes (original structure kept)
# ============================================================

class FunctionCall:
    """Attribute view over an OpenAI-style ``function`` payload.

    ``name`` is the tool name; ``arguments`` is left exactly as received
    (``None`` when absent).
    """

    def __init__(self, data: Dict[str, Any]):
        self.name, self.arguments = data.get("name"), data.get("arguments")


class ToolCall:
    """Attribute view over one OpenAI-style tool-call dict.

    Exposes ``id``, ``type`` and ``function``; a missing/empty ``function``
    entry becomes a FunctionCall with ``None`` fields.
    """

    def __init__(self, data: Dict[str, Any]):
        fn_payload = data.get("function") or {}
        self.id = data.get("id")
        self.type = data.get("type")
        self.function = FunctionCall(fn_payload)


class AssistantMessage:
    """Attribute view over an assistant message dict (as produced by ``model_dump()``).

    ``tool_calls`` is always a list of ToolCall, empty when the message has none.
    """

    def __init__(self, data: Dict[str, Any]):
        raw_calls = data.get("tool_calls") or []
        self.role = data.get("role")
        self.content = data.get("content")
        self.tool_calls = list(map(ToolCall, raw_calls))


def serialize_tool_calls(tool_calls):
    """Turn ToolCall-like objects back into OpenAI-style dicts.

    Returns ``None`` (not an empty list) when ``tool_calls`` is empty or None.
    """
    if not tool_calls:
        return None
    return [
        {
            "id": tc.id,
            "type": tc.type,
            "function": {
                "name": tc.function.name,
                "arguments": tc.function.arguments,
            },
        }
        for tc in tool_calls
    ]


# ============================================================
# Per-round user state record
# ============================================================

@dataclass
class RoundUserState:
    """Snapshot of one dialogue round's user-side state.

    NOTE(review): the code that builds these records is outside the visible
    part of this file; the field notes below are inferred from how the
    multi-turn helpers compute the same values — confirm against the caller.
    """
    round_idx: int
    id_list_before: List[str]             # active instruction ids at round start
    id_list_after: List[str]              # active instruction ids after this round's updates
    blocks: Dict[str, str]                # HISTORY/NEW/MODIFY/ISSUE prompt blocks
    sim_prompt: Optional[str]             # rendered user-simulator prompt, if one was sent
    sim_output: Optional[Dict[str, Any]]  # raw simulator output (e.g. instruction_ids, user_query)
    newly_added_ids: List[str]
    all_added_ids: List[str]
    final_answer: str
    eval_result: Dict[str, Any]
    use_next_turn_reward: Optional[bool]
    start_time: Optional[str]
    end_time: Optional[str]
    tool_calls: List[Any]


# ============================================================
#                  TravelAgent main class (OpenAI SDK version)
# ============================================================

class TravelAgent:
    """
    每个进程只初始化一个 TravelAgent。
    为了线程安全与并发，OpenAI client 使用 thread-local（每线程一个 client）。
    """
    def __init__(
        self,
        *,
        config_path: str,
        agent_cfg: AgentLLMConfig,
        user_cfg: UserLLMConfig,
        user_prompt_template: str = user_prompt_easy_en,
        timeout_sec: float = 1200.0,  # legacy parameter, unused (timeouts come from the cfgs)
        # Ordered base_url pool for the user simulator (primary + fallbacks).
        user_base_urls: Optional[List[str]] = None,
        # Evaluators are built once per process and injected here.
        attraction_evaluator: AttractionEvaluator,
        restaurant_evaluator: RestaurantEvaluator,
        hotel_evaluator: HotelEvaluator,
        transportation_evaluator: TransportationEvaluator,
        general_evaluator: GeneralEvaluator,
    ):
        """Set up LLM settings, the user-simulator URL pool, domain tools and the evaluator.

        Args:
            config_path: config file path passed to every domain tool constructor.
            agent_cfg: model / sampling / connection settings for the agent.
            user_cfg: model / sampling / connection settings for the user simulator.
            user_prompt_template: user-simulator prompt whose ``{{HISTORY}}``-style
                placeholders are filled in each round.
            timeout_sec: accepted for backward compatibility only; not used.
            user_base_urls: base_url pool for the user simulator. Empty entries and
                duplicates are dropped (first occurrence wins). Defaults to
                ``[user_cfg.base_url]``.
            attraction_evaluator, restaurant_evaluator, hotel_evaluator,
            transportation_evaluator, general_evaluator: forwarded to
                ``TripPlanEvaluator``.
        """
        self.agent_cfg = agent_cfg
        self.user_cfg = user_cfg
        self.user_prompt_template = user_prompt_template

        # Agent and user endpoints are configured independently.
        self._agent_base_url = agent_cfg.base_url
        self._agent_api_key = agent_cfg.api_key
        self._agent_timeout_sec = agent_cfg.timeout_sec

        self._user_base_url = user_cfg.base_url
        self._user_api_key = user_cfg.api_key
        self._user_timeout_sec = user_cfg.timeout_sec

        # Clients / simulators are cached per thread (see the _get_* methods).
        self._tlocal = threading.local()

        # ---------------------------
        # User base_url pool (primary + fallbacks), order-preserving de-dup.
        # ---------------------------
        candidates = user_base_urls or ([self._user_base_url] if self._user_base_url else [])
        self._user_base_urls = list(dict.fromkeys(u for u in candidates if u))
        if not self._user_base_urls:
            # Nothing usable: keep a single-entry pool so rotation has something to index.
            self._user_base_urls = [self._user_base_url]

        if self._user_base_url in self._user_base_urls:
            self._user_url_idx = self._user_base_urls.index(self._user_base_url)
        else:
            # A configured URL outside the pool stays in use until the first
            # rotation; rotation indices start from the pool head.
            self._user_url_idx = 0
            if not self._user_base_url:
                # No URL configured: use the pool head instead of building an
                # OpenAI client with base_url=None (the SDK's default endpoint).
                self._user_base_url = self._user_base_urls[0]

        self._user_url_lock = threading.Lock()

        # Domain tools.
        self.general_tool = GeneralTool()
        self.attraction_tool = AttractionTool(config_path)
        self.flight_tool = FlightTool(config_path)
        self.train_tool = TrainTool(config_path)
        self.restaurant_tool = RestaurantTool(config_path)
        self.hotel_tool = HotelTool(config_path)

        # Tool name (as exposed to the model) -> bound callable.
        self.tools = {
            "search_flights": self.flight_tool.search_flights,
            "get_flight_detail_with_products": self.flight_tool.get_flight_detail_with_products,
            "get_airport_coordinates": self.flight_tool.get_airport_coordinates,

            "search_trains": self.train_tool.search_trains,
            "get_train_detail_with_products": self.train_tool.get_train_detail_with_products,
            "get_station_coordinates": self.train_tool.get_station_coordinates,

            "search_hotels": self.hotel_tool.search_hotels,
            "get_hotel_detail_with_products": self.hotel_tool.get_hotel_detail_with_products,
            "get_hotel_coordinates": self.hotel_tool.get_hotel_coordinates,

            "search_attractions": self.attraction_tool.search_attractions,
            "get_attraction_detail_with_products": self.attraction_tool.get_attraction_detail_with_products,
            "get_attraction_coordinates": self.attraction_tool.get_attraction_coordinates,

            "search_restaurants": self.restaurant_tool.search_restaurants,
            "get_restaurant_detail_with_products": self.restaurant_tool.get_restaurant_detail_with_products,
            "get_restaurant_coordinates": self.restaurant_tool.get_restaurant_coordinates,

            "get_route_estimate": self.general_tool.get_route_estimate,
            "get_city_center_coords": self.general_tool.get_city_center_coords,
            "get_date_after": self.general_tool.get_date_after,
        }

        # Tool schemas sent with every agent request.
        self.tool_descriptions = (
            Flight_tool_description
            + Train_tool_description
            + Hotel_tool_description
            + Attraction_tool_description
            + Restaurant_tool_description
            + General_tool_description
        )

        # Trip-plan evaluator built from the injected per-domain evaluators.
        self.trip_plan_evaluator = TripPlanEvaluator(
            attraction_evaluator=attraction_evaluator,
            restaurant_evaluator=restaurant_evaluator,
            hotel_evaluator=hotel_evaluator,
            transportation_evaluator=transportation_evaluator,
            general_evaluator=general_evaluator,
        )


    def _get_agent_client(self) -> OpenAI:
        if not hasattr(self._tlocal, "agent_client"):
            self._tlocal.agent_client = OpenAI(
                api_key=self._agent_api_key,
                base_url=self._agent_base_url,
                timeout=self._agent_timeout_sec,
            )
        return self._tlocal.agent_client

    def _get_user_client(self) -> OpenAI:
        if not hasattr(self._tlocal, "user_client"):
            self._tlocal.user_client = OpenAI(
                api_key=self._user_api_key,
                base_url=self._user_base_url,
                timeout=self._user_timeout_sec,
            )
        return self._tlocal.user_client

    def _get_user_simulator(self) -> UserSimulator:
        if not hasattr(self._tlocal, "user_sim"):
            self._tlocal.user_sim = UserSimulator(
                client=self._get_user_client(),
                model_name=self.user_cfg.model_name,
                temperature=self.user_cfg.temperature,
                max_tokens=self.user_cfg.max_tokens,
            )
        return self._tlocal.user_sim

    def _invalidate_user_thread_local(self):
        """只清掉当前线程的 user client/sim，让它们按新 base_url 重建。"""
        if hasattr(self._tlocal, "user_client"):
            delattr(self._tlocal, "user_client")
        if hasattr(self._tlocal, "user_sim"):
            delattr(self._tlocal, "user_sim")
    
    def rotate_user_base_url(self, reason: str = "") -> str:
        """
        切到后备 base_url（轮询），并让当前线程下次请求重建 client/sim。
        返回新的 base_url。
        """
        with self._user_url_lock:
            if not self._user_base_urls:
                return self._user_base_url

            old = self._user_base_url
            self._user_url_idx = (self._user_url_idx + 1) % len(self._user_base_urls)
            self._user_base_url = self._user_base_urls[self._user_url_idx]

        self._invalidate_user_thread_local()
        print(f"[user-url] rotate: {old} -> {self._user_base_url}. reason={reason}", flush=True)
        return self._user_base_url


    # ======================== Core chat loop ========================
    def chat(self, history: Optional[List[Dict[str, str]]] = None, max_rounds=1000):
        """Run the tool-calling loop until the agent replies without tool calls.

        Each round sends the accumulated messages to the agent model. If the
        reply contains tool calls, each is executed against ``self.tools`` and
        its result appended as a ``role="tool"`` message; otherwise the reply
        content becomes the final answer.

        Args:
            history: initial chat messages (deep-copied; the caller's list is not modified).
            max_rounds: maximum number of model calls before giving up.

        Returns:
            Trace dict with ``start_time``, ``end_time``, ``messages`` (the full
            conversation as sent to the model), ``trace_log_messages`` (a
            slimmer copy for logging), ``tool_calls`` (one record per executed
            call) and ``final_answer`` ("Too many rounds." when ``max_rounds`` is
            exhausted).
        """
        if history is None:
            history = []

        trace_log = {
            "start_time": datetime.now().isoformat(),
            "messages": [],
            "trace_log_messages": [],
            "tool_calls": [],
            "final_answer": None,
        }

        messages = copy.deepcopy(history)
        trace_log["trace_log_messages"].extend(messages)

        for _ in range(max_rounds):
            client = self._get_agent_client()
            resp = client.chat.completions.create(**self._build_completion_kwargs(messages))

            assistant_msg_dict = resp.choices[0].message.model_dump()
            messages.append(assistant_msg_dict)

            assistant_msg = AssistantMessage(assistant_msg_dict)
            trace_log["trace_log_messages"].append({
                "role": "assistant",
                "content": assistant_msg.content,
                "tool_calls": serialize_tool_calls(assistant_msg.tool_calls),
            })

            # No tool calls means this reply is the final answer.
            if not assistant_msg.tool_calls:
                trace_log["messages"] = messages
                trace_log["final_answer"] = assistant_msg.content or ""
                trace_log["end_time"] = datetime.now().isoformat()
                return trace_log

            for call in assistant_msg.tool_calls:
                tool_msg, call_record = self._execute_tool_call(call)
                trace_log["tool_calls"].append(call_record)
                messages.append(tool_msg)
                trace_log["trace_log_messages"].append(tool_msg)

        trace_log["messages"] = messages
        trace_log["final_answer"] = "Too many rounds."
        trace_log["end_time"] = datetime.now().isoformat()
        return trace_log

    def _build_completion_kwargs(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Assemble keyword arguments for one agent ``chat.completions.create`` call.

        ``deepseek-v32`` additionally gets ``chat_template_kwargs.thinking``. A dict
        ``agent_cfg.extra_body`` is merged on top for every model (its keys win),
        matching its documented role as a server passthrough.
        """
        cfg = self.agent_cfg
        kwargs: Dict[str, Any] = {
            "model": cfg.model_name,
            "messages": messages,
            "tools": self.tool_descriptions,
            "tool_choice": cfg.tool_choice,
            "temperature": cfg.temperature,
            "max_tokens": cfg.max_tokens,
        }

        extra_body: Dict[str, Any] = {}
        if cfg.model_name == "deepseek-v32":
            extra_body["chat_template_kwargs"] = {"thinking": bool(cfg.thinking)}
        if isinstance(cfg.extra_body, dict):
            extra_body.update(cfg.extra_body)
        if extra_body:
            kwargs["extra_body"] = extra_body
        return kwargs

    @staticmethod
    def _parse_tool_arguments(tool_name: str, raw_args: Any) -> Tuple[Any, Optional[str]]:
        """Decode a tool call's arguments; return ``(args, error_message_or_None)``.

        Missing/empty arguments mean "no arguments". Undecodable JSON and
        non-object JSON are reported as errors rather than silently calling the
        tool with no (or wrong) arguments.
        """
        if isinstance(raw_args, dict):
            return raw_args, None
        try:
            parsed = json.loads(raw_args or "{}")
        except (TypeError, ValueError) as e:
            return {}, f"Invalid JSON arguments for tool {tool_name}: {e}"
        if not isinstance(parsed, dict):
            return parsed, f"Arguments for tool {tool_name} must be a JSON object, got {type(parsed).__name__}"
        return parsed, None

    def _execute_tool_call(self, call: "ToolCall") -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Run one tool call and return ``(tool_message, trace_record)``.

        Failures (unknown tool, bad arguments, exceptions raised by the tool)
        are not raised: they become ``{"error": ...}`` tool content so the
        model can see and react to them.
        """
        tool_name = call.function.name
        args, error = self._parse_tool_arguments(tool_name, call.function.arguments)

        tool_func = self.tools.get(tool_name)
        if tool_func is None:
            error = f"Unknown tool {tool_name}"

        output: Any = None
        if error is None:
            try:
                output = tool_func(**args)
            except Exception as e:  # report tool failures to the model instead of aborting the loop
                error = str(e)

        if error is None:
            result = output
            record = {"success": True, "data": output}
        else:
            result = {"error": error}
            record = {"success": False, "error": error}

        content = result if isinstance(result, str) else json.dumps(result, ensure_ascii=False)
        tool_msg = {
            "role": "tool",
            "tool_call_id": call.id,
            "name": tool_name,
            "content": content,
        }
        return tool_msg, {"tool": tool_name, "args": args, "result": record}

    # ============================================================
    #                  Multi-turn helpers (static and instance methods)
    # ============================================================

    @staticmethod
    def _safe_format(desc: str, slot: str) -> str:
        try:
            return desc.format(slot=slot)
        except Exception:
            return (desc or "").replace("{slot}", slot)

    @staticmethod
    def _build_id2instruction(meta_dict: Dict[str, Any]) -> Dict[str, str]:
        id2instr = {}
        rubric_results = meta_dict.get("rubric_results", {}) or {}

        for _, rubric_group in rubric_results.items():
            for _, cfg in (rubric_group or {}).items():
                desc_tmpl = cfg.get("description", "") or ""
                result = cfg.get("result", {}) or {}
                all_labels = result.get("all_labels_and_ranges", {}) or {}

                for label_text, payload in all_labels.items():
                    _id = (payload or {}).get("_id")
                    if not _id:
                        continue
                    instr = TravelAgent._safe_format(desc_tmpl, label_text).strip()
                    id2instr[_id] = instr
        return id2instr

    @staticmethod
    def _flatten_applied_chains(meta_dict: Dict[str, Any]) -> List[str]:
        chains = meta_dict.get("applied_modification_chains") or {}
        out = []
        for _, chain in chains.items():
            if chain:
                out.append(chain[-1])
        return out

    @staticmethod
    def _build_blocks(meta_dict: Dict[str, Any], id_list: List[str], rubric_progress: Dict[str, int]) -> Dict[str, str]:
        id2instr = TravelAgent._build_id2instruction(meta_dict)

        # ===== HISTORY =====
        history_lines = []
        for cid in id_list:
            txt = id2instr.get(cid)
            if txt:
                history_lines.append(f"[ID: HIS_{cid}] {txt}")

        chains = meta_dict.get("applied_modification_chains", {}) or {}
        new_candidates = []
        modify_items = []  # (rubric_key, expected_id)

        for rubric_key in sorted(chains.keys()):
            chain = chains[rubric_key] or []
            if not chain:
                continue

            prog = rubric_progress.get(rubric_key, 0)

            if prog >= len(chain):
                continue

            if prog == 0:
                new_candidates.append(chain[0])
            else:
                modify_items.append((rubric_key, chain[prog]))

        # ===== NEW =====
        new_lines = []
        for cid in new_candidates:
            txt = id2instr.get(cid)
            if txt:
                new_lines.append(f"[ID: NEW_{cid}] {txt}")

        # ===== MODIFY =====
        modify_lines = []
        for rubric_key, expected_id in modify_items:
            chain = chains.get(rubric_key, [])
            prog = rubric_progress.get(rubric_key, 0)
            expected_txt = id2instr.get(expected_id)
            if not expected_txt:
                continue

            if expected_id in chain[:prog]:
                removed_id = chain[prog - 1]
                removed_txt = id2instr.get(removed_id, "")
                line = (
                    f"[ID: MOD_{expected_id}] {expected_txt} "
                    f"(Note: the previous requirement {removed_txt} is no longer needed)"
                )
            else:
                line = f"[ID: MOD_{expected_id}] {expected_txt}"

            modify_lines.append(line)

        return {
            "HISTORY": "\n".join(history_lines).strip(),
            "NEW": "\n".join(new_lines).strip(),
            "MODIFY": "\n".join(modify_lines).strip(),
        }

    @staticmethod
    def _history_text_for_sim(conv_simple: List[Dict[str, str]], keep_last_assistant: int = 3) -> str:
        OMIT_TEXT = "Earlier assistant responses are omitted due to context length limits."
        assistant_indices = [i for i, m in enumerate(conv_simple) if m["role"] == "assistant"]
        keep_indices = set(assistant_indices[-keep_last_assistant:])

        lines = []
        for i, m in enumerate(conv_simple):
            if m["role"] == "user":
                lines.append(f"User: {m['content']}")
            elif m["role"] == "assistant":
                if i in keep_indices:
                    lines.append(f"Assistant: {m['content']}")
                else:
                    lines.append(f"Assistant: {OMIT_TEXT}")
        return "\n".join(lines).strip()

    def _apply_proposed_ids_with_prefix(
        self,
        meta_dict: Dict[str, Any],
        current_id_list: List[str],
        rubric_progress: Dict[str, int],
        proposed_ids: List[Any],
    ) -> Tuple[bool, List[str], List[str], List[str]]:
        chains = meta_dict.get("applied_modification_chains", {}) or {}
        id2pos: Dict[str, Tuple[str, int]] = {}
        for rubric_key, chain in chains.items():
            for idx, cid in enumerate(chain):
                id2pos[cid] = (rubric_key, idx)

        errors: List[str] = []
        newly_added: List[str] = []
        removed: List[str] = []

        if not isinstance(proposed_ids, list):
            return False, [], [], [f"instruction_ids is not a list: {type(proposed_ids)}"]

        for raw in proposed_ids:
            if not isinstance(raw, str) or not raw:
                errors.append(f"invalid instruction_id (not str/empty): {raw!r}")
                continue

            if raw.startswith("NEW_"):
                op = "NEW"
                base_id = raw[len("NEW_"):]
            elif raw.startswith("MOD_"):
                op = "MOD"
                base_id = raw[len("MOD_"):]
            else:
                continue

            if base_id not in id2pos:
                errors.append(f"{raw} -> {base_id} not found in applied_modification_chains")
                continue

            rubric_key, idx = id2pos[base_id]
            chain = chains.get(rubric_key, []) or []

            if op == "NEW":
                if base_id not in current_id_list:
                    current_id_list.append(base_id)
                    newly_added.append(base_id)
                rubric_progress[rubric_key] = max(int(rubric_progress.get(rubric_key, 0) or 0), idx + 1)

            else:  # MOD
                if idx <= 0:
                    errors.append(f"{raw}: target is first element of chain, no previous to remove")
                    continue

                prev_id = chain[idx - 1]
                if prev_id not in current_id_list:
                    errors.append(f"{raw}: expected to remove prev_id={prev_id}, but it's not in current_id_list")
                    continue

                try:
                    current_id_list.remove(prev_id)
                    removed.append(prev_id)
                except ValueError:
                    errors.append(f"{raw}: failed to remove prev_id={prev_id} (ValueError)")
                    continue

                if base_id not in current_id_list:
                    current_id_list.append(base_id)
                    newly_added.append(base_id)

                rubric_progress[rubric_key] = max(int(rubric_progress.get(rubric_key, 0) or 0), idx + 1)

        ok = (len(errors) == 0)
        return ok, newly_added, removed, errors

    def rubric_all_satisfied(self, rubric_progress: dict, applied_chains: dict) -> bool:
        for rubric, chains in applied_chains.items():
            required = len(chains)
            current = rubric_progress.get(rubric, 0)
            if current < required:
                return False
        return True

    def select_raw_answer_for_eval(self, current_raw_answer: str, conv_simple: List[Dict]) -> str:
        def parse_trip_plan_json(raw_text: str) -> Optional[dict]:
            # 1) 纯 JSON
            try:
                data = json.loads(raw_text)
                return data if isinstance(data, dict) else None
            except Exception:
                pass

            # 2) ```json ... ```
            m = re.search(r"```json\s*(\{.*?\})\s*```", raw_text, re.S)
            if m:
                try:
                    data = json.loads(m.group(1))
                    return data if isinstance(data, dict) else None
                except Exception:
                    return None

            return None

        def should_fallback_only_if_empty_schedule(raw_text: str) -> bool:
            data = parse_trip_plan_json(raw_text)
            if not isinstance(data, dict):
                return False

            tp = data.get("trip_plan")
            if not isinstance(tp, dict):
                return False

            # 关键：必须有 daily_schedule 字段，且必须是“空 list”
            if "daily_schedule" not in tp:
                return False

            ds = tp.get("daily_schedule")
            return isinstance(ds, list) and len(ds) == 0

        if not should_fallback_only_if_empty_schedule(current_raw_answer):
            return current_raw_answer

        def extract_nonempty_schedule(raw_text: str) -> bool:
            data = parse_trip_plan_json(raw_text)
            if not isinstance(data, dict):
                return False
            tp = data.get("trip_plan")
            if not isinstance(tp, dict):
                return False
            ds = tp.get("daily_schedule")
            return isinstance(ds, list) and len(ds) > 0

        for msg in reversed(conv_simple):
            if msg.get("role") != "assistant":
                continue
            raw_text = msg.get("content", "")
            if extract_nonempty_schedule(raw_text):
                return raw_text

        return current_raw_answer

    def run_multiturn_trip(
        self,
        meta_dict: Dict[str, Any],
        dialog_turns: int = 1,
        keep_last_assistant: int = 3,
        extra_save_dir: str = "__unused__",
    ) -> Dict[str, Any]:

        trip_id = meta_dict.get("trip_id", "unknown_trip")
        rubric_progress = copy.deepcopy(meta_dict.get("rubric_progress", {}))

        conv_full: List[Dict[str, str]] = []
        conv_simple: List[Dict[str, str]] = []
        rounds: List[RoundUserState] = []

        if dialog_turns == 1:
            current_user_query = meta_dict.get("query_with_constraints", meta_dict.get("query_basic", ""))
            current_id_list = self._flatten_applied_chains(meta_dict)
        else:
            current_user_query = meta_dict.get("query_basic", "")
            current_id_list = meta_dict.get("instruction_ids_basic", [])

        issue_text = ""

        for r in range(1, dialog_turns + 1):
            # ====== 快照：保证失败轮不落盘 & 回滚到上一轮 ======
            snap_conv_full = copy.deepcopy(conv_full)
            snap_conv_simple = copy.deepcopy(conv_simple)
            snap_id_list = list(current_id_list)
            snap_progress = copy.deepcopy(rubric_progress)
            snap_issue_text = issue_text

            try:
                id_before = list(current_id_list)

                blocks = self._build_blocks(meta_dict, current_id_list, rubric_progress)
                blocks["ISSUE"] = issue_text or ""

                sim_prompt = None
                sim_output = None
                newly_added = []
                all_added = []
                use_next_turn_reward = False

                if r >= 2:
                    history_text = self._history_text_for_sim(conv_simple, keep_last_assistant)
                    max_retry = 3
                    retry = 0

                    while True:
                        sim_prompt = self.user_prompt_template \
                            .replace("{{HISTORY}}", blocks.get("HISTORY", "")) \
                            .replace("{{NEW}}", blocks.get("NEW", "")) \
                            .replace("{{MODIFY}}", blocks.get("MODIFY", "")) \
                            .replace("{{ISSUE}}", blocks.get("ISSUE", "")) \
                            .replace("{{HISTORY_MESSAGES}}", history_text)

                        # ----------------------------------------------------------
                        # ✅ 核心改动：只要 user_sim.generate 报错，就换后备 URL 重试
                        # ----------------------------------------------------------
                        max_url_switch = max(1, len(getattr(self, "_user_base_urls", []) or []))
                        url_switch_cnt = 0
                        last_err = None

                        while True:
                            user_sim = self._get_user_simulator()
                            try:
                                # ✅ user 参数完全独立：默认温度、重试温度都从 user_cfg 取
                                if retry > 0:
                                    sim_output = user_sim.generate(sim_prompt, temperature=self.user_cfg.retry_temperature)
                                else:
                                    sim_output = user_sim.generate(sim_prompt, temperature=self.user_cfg.temperature)
                                last_err = None
                                break  # ✅ 成功
                            except Exception as e:
                                last_err = e
                                # ✅ 只要报错就换 ip/url（轮询）
                                if url_switch_cnt < max_url_switch:
                                    url_switch_cnt += 1
                                    self.rotate_user_base_url(reason=f"user_sim.generate error: {e}")
                                    continue
                                # ✅ 所有 url 都试过了仍失败：抛最后一个错误
                                raise last_err

                        proposed_ids = sim_output.get("instruction_ids", []) or []
                        use_next_turn_reward = (
                            not proposed_ids or
                            all(isinstance(pid, str) and pid.startswith("HIS_") for pid in proposed_ids)
                        )

                        tmp_id_list = list(current_id_list)
                        tmp_progress = copy.deepcopy(rubric_progress)

                        ok, newly_added_tmp, _, errors = self._apply_proposed_ids_with_prefix(
                            meta_dict=meta_dict,
                            current_id_list=tmp_id_list,
                            rubric_progress=tmp_progress,
                            proposed_ids=proposed_ids,
                        )

                        if ok:
                            current_id_list = tmp_id_list
                            rubric_progress = tmp_progress
                            newly_added = newly_added_tmp
                            current_user_query = sim_output.get("user_query", "")
                            all_added = proposed_ids
                            break

                        retry += 1
                        if retry >= max_retry:
                            newly_added = []
                            current_user_query = sim_output.get("user_query", "")
                            break

                if r == 1:
                    if dialog_turns == 1:
                        conv_full.append({"role": "system", "content": system_prompt_en_single_turn})
                    else:
                        conv_full.append({"role": "system", "content": system_prompt_en})

                conv_full.append({"role": "user", "content": current_user_query})
                conv_simple.append({"role": "user", "content": current_user_query})

                trace_log = self.chat(history=conv_full)
                conv_full = copy.deepcopy(trace_log["messages"])
                conv_simple.append({"role": "assistant", "content": trace_log["final_answer"]})

                raw_answer_for_eval = self.select_raw_answer_for_eval(
                    current_raw_answer=trace_log["final_answer"],
                    conv_simple=conv_simple
                )

                eval_result = self.trip_plan_evaluator.evaluate(
                    meta_dict=meta_dict,
                    raw_text=raw_answer_for_eval,
                    id_list=current_id_list
                )
                issue_text = eval_result.get("error_text", "") or ""

                round_state = RoundUserState(
                    round_idx=r,
                    id_list_before=id_before,
                    id_list_after=list(current_id_list),
                    blocks=blocks,
                    sim_prompt=sim_prompt,
                    sim_output=sim_output,
                    newly_added_ids=newly_added,
                    all_added_ids=all_added,
                    final_answer=trace_log["final_answer"],
                    tool_calls=trace_log["tool_calls"],
                    eval_result=eval_result,
                    use_next_turn_reward=use_next_turn_reward,
                    start_time=trace_log["start_time"],
                    end_time=trace_log["end_time"],
                )
                rounds.append(round_state)

                applied_chains = meta_dict.get("applied_modification_chains", {})
                if self.rubric_all_satisfied(rubric_progress, applied_chains):
                    print(f"All rubrics are proposed at round {r}", flush=True)
                    break

            except Exception as e:
                # ====== 本轮失败：回滚 ======
                conv_full = snap_conv_full
                conv_simple = snap_conv_simple
                current_id_list = snap_id_list
                rubric_progress = snap_progress
                issue_text = snap_issue_text

                # ====== 关键规则：只要已有轮次>=1就返回 partial；否则继续抛给上层按原逻辑处理 ======
                if len(rounds) >= 1:
                    summary = {
                        "trip_id": trip_id,
                        "dialog_turns": dialog_turns,
                        "keep_last_assistant": keep_last_assistant,
                        "final_id_list": list(current_id_list),
                        "messages": conv_full,
                        "rounds": [asdict(x) for x in rounds],
                        "terminate_reason": {
                            "round_failed": r,
                            "error": str(e),
                        },
                    }
                    return summary
                else:
                    raise

        summary = {
            "trip_id": trip_id,
            "dialog_turns": dialog_turns,
            "keep_last_assistant": keep_last_assistant,
            "final_id_list": list(current_id_list),
            "messages": conv_full,
            "rounds": [asdict(x) for x in rounds],
            "terminate_reason": None,
        }
        return summary


# ============================================================
#                    DISTRIBUTED RUNNER
# ============================================================
def canonical_final_id_list(final_id_list):
    """
    Normalize an id list into an order-independent, duplicate-free key.

    Non-list input yields an empty tuple. ``None`` entries are dropped; the
    remaining entries are stringified, de-duplicated and sorted, so two lists
    holding the same ids in any order produce the same tuple.
    """
    if isinstance(final_id_list, list):
        unique_ids = set()
        for entry in final_id_list:
            if entry is not None:
                unique_ids.add(str(entry))
        return tuple(sorted(unique_ids))
    return ()

def make_sample_key(trip_id, final_id_list):
    """Build the ``(str(trip_id), canonical id tuple)`` key used to de-duplicate samples."""
    key_trip = str(trip_id)
    key_ids = canonical_final_id_list(final_id_list)
    return key_trip, key_ids

def final_id_list_from_meta(meta: dict):
    """
    Derive the id list for a test item from its metadata.

    Preference order:
    - ``applied_modification_chains`` when it is a non-empty dict: the last id
      of every non-empty list-valued chain, in the dict's iteration order
      (even if that yields an empty list, no further fallback is tried);
    - otherwise ``meta["final_id_list"]`` when it is a list;
    - otherwise an empty list.
    """
    chains = meta.get("applied_modification_chains")
    if not (isinstance(chains, dict) and chains):
        fallback = meta.get("final_id_list")
        return fallback if isinstance(fallback, list) else []
    return [chain[-1] for chain in chains.values() if isinstance(chain, list) and chain]

def final_id_list_from_output_record(rec: dict):
    """
    Pick the id list used to de-duplicate one record of the output JSONL.

    The number of dialog turns is read from ``rec["dialog_turns"]``, falling
    back to ``rec["summary"]["dialog_turns"]``; a missing or non-integer value
    counts as 0.

    Multi-turn records (turns > 1) try, in order:
      1. ``summary["rounds"][0]["id_list_before"]`` -- the id list before the
         first round, which is stable even if the run stopped early (unlike
         ``summary["final_id_list"]``); presumably equal to the input item's
         ``instruction_ids_basic`` that the runner keys multi-turn items on;
      2. ``summary["instruction_ids_basic"]``;
      3. ``summary["final_id_list"]``;
      4. ``rec["final_id_list"]``.

    Single-turn records try only ``summary["final_id_list"]`` then
    ``rec["final_id_list"]``.

    The first candidate that is a list is returned as-is; otherwise ``[]``.
    Non-dict ``rec`` also yields ``[]``.
    """
    if not isinstance(rec, dict):
        return []

    summary = rec.get("summary") or {}

    turns = rec.get("dialog_turns")
    if turns is None:
        turns = summary.get("dialog_turns")
    try:
        turns = int(turns or 0)
    except Exception:
        turns = 0

    candidates = []
    if turns > 1:
        rounds = summary.get("rounds")
        if isinstance(rounds, list) and rounds:
            first_round = rounds[0] or {}
            candidates.append(first_round.get("id_list_before"))
        candidates.append(summary.get("instruction_ids_basic"))
    candidates.append(summary.get("final_id_list"))
    candidates.append(rec.get("final_id_list"))

    return next((cand for cand in candidates if isinstance(cand, list)), [])

def load_done_keys_from_output_jsonl(output_jsonl_path: str):
    """
    Scan an output JSONL file and collect the sample keys already produced.

    Every line that parses to a JSON object counts as "done" -- no evaluation
    result is required. Its key is ``make_sample_key(trip_id, ids)``:
    ``trip_id`` comes from ``rec["trip_id"]`` (falling back to
    ``rec["summary"]["trip_id"]``), and ``ids`` comes from
    ``final_id_list_from_output_record``.

    Blank lines, unparseable lines (e.g. a line truncated by a crash
    mid-write), lines that are not JSON objects, and records whose
    ``summary`` is present but not an object are skipped rather than
    aborting the resume scan.

    Returns an empty set when the path is empty or the file does not exist.
    """
    done = set()
    if not output_jsonl_path or not os.path.exists(output_jsonl_path):
        return done

    with open(output_jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except (ValueError, RecursionError):
                continue

            # Records are written as dicts with an optional dict "summary";
            # anything else would crash the .get() calls below (and inside
            # final_id_list_from_output_record).
            if not isinstance(rec, dict):
                continue
            summary = rec.get("summary") or {}
            if not isinstance(summary, dict):
                continue

            trip_id = rec.get("trip_id") or summary.get("trip_id")
            fil = final_id_list_from_output_record(rec)
            done.add(make_sample_key(trip_id, fil))

    return done


def get_resources_for_process(config_path: str):
    """
    Return the evaluator resources for the current process.

    - fork: the child inherited ``GLOBAL_RESOURCES`` preloaded by the parent,
      so it is returned directly without reloading.
    - spawn: ``GLOBAL_RESOURCES`` is still ``None`` in the fresh interpreter,
      so this process loads its own copy via ``_load_resources_for_process``
      (the result is not stored back into the global).
    """
    global GLOBAL_RESOURCES
    if GLOBAL_RESOURCES is None:
        return _load_resources_for_process(config_path)
    return GLOBAL_RESOURCES

def preload_resources_in_parent(config_path: str):
    """
    Load the shared resources once in the parent process, before forking.

    Intended to be called only in the parent. Children forked afterwards
    inherit ``GLOBAL_RESOURCES``. After loading, ``gc.freeze()`` is attempted
    (best effort) so the collector ignores the preloaded objects, which helps
    keep their memory pages shared copy-on-write with forked children.
    Does nothing if the resources are already loaded.
    """
    global GLOBAL_RESOURCES
    if GLOBAL_RESOURCES is not None:
        return

    print("[runner] preload resources in parent ...", flush=True)
    GLOBAL_RESOURCES = _load_resources_for_process(config_path)
    print("[runner] preload resources done.", flush=True)

    # Best effort only: any failure here (e.g. gc.freeze unavailable) is ignored.
    try:
        import gc
        gc.freeze()
        print("[runner] gc.freeze() done.", flush=True)
    except Exception:
        pass


def _append_jsonl(file_lock: mp.Lock, output_jsonl_path: str, obj: dict):
    line = json.dumps(obj, ensure_ascii=False)
    with file_lock:
        with open(output_jsonl_path, "a", encoding="utf-8") as f:
            f.write(line + "\n")


def _ip_process_main(
    task_q,                 # mp.JoinableQueue
    done_counter,           # mp.Value('i', 0)
    total_tasks: int,
    stop_event,             # mp.Event
    file_lock,              # mp.Lock
    output_jsonl_path: str,
    dialog_turns: int,
    keep_last_assistant: int,
    user_prompt_template: str,   # user-simulator prompt template for this test set
):
    """
    Worker-process entry point: run trajectories pulled from ``task_q``.

    Builds one ``TravelAgent`` for this process, then starts
    ``REQUESTS_PER_IP`` daemon threads. Each thread pulls items (test metadata
    dicts) from ``task_q`` and runs ``agent.run_multiturn_trip`` on them. For
    every success it appends one ``{"trip_id", "dialog_turns", "summary"}``
    record to ``output_jsonl_path`` via ``_append_jsonl``.

    Progress is tracked in the shared ``done_counter``. Once it reaches
    ``total_tasks`` the finishing thread sets ``stop_event``. All threads then
    mark any leftover queue items as done and exit when the queue is empty.
    Returns after every thread has joined.

    Per-item failure handling:
    - errors containing "max_tokens must be at least 1": the item is dropped
      and counted as done; no cooldown is applied.
    - any other exception: the item is requeued (unless ``stop_event`` is
      already set) with ``item["_retry"]`` incremented. The whole process
      then pauses for ``COOLDOWN_SECONDS``.
      NOTE(review): within this function ``_retry`` is only incremented,
      never checked, so an item that always fails is retried indefinitely.
    """
    cooldown_until = 0.0  # epoch-seconds deadline, shared by all threads of this process

    print("[worker] before get resources", flush=True)
    resources = get_resources_for_process(CONFIG_PATH)
    print("[worker] after get resources (shared if fork)", flush=True)

    # Agent and user-simulator parameters are configured independently (AGENT_CFG / USER_CFG).
    agent = TravelAgent(
        config_path=CONFIG_PATH,
        agent_cfg=AGENT_CFG,
        user_cfg=USER_CFG,
        user_prompt_template=user_prompt_template,
        # Primary + fallback base URLs for the user simulator.
        user_base_urls=DEFAULT_USER_URLS,
        attraction_evaluator=resources["attraction_evaluator"],
        restaurant_evaluator=resources["restaurant_evaluator"],
        hotel_evaluator=resources["hotel_evaluator"],
        transportation_evaluator=resources["transportation_evaluator"],
        general_evaluator=resources["general_evaluator"],
    )

    def set_cooldown():
        """Push the process-wide cooldown deadline to now + COOLDOWN_SECONDS (never shortens it)."""
        nonlocal cooldown_until
        cooldown_until = max(cooldown_until, time.time() + COOLDOWN_SECONDS)
        print(
            "[worker] ERROR -> cooldown "
            f"{COOLDOWN_SECONDS}s until "
            f"{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(cooldown_until))}",
            flush=True
        )

    def wait_cooldown():
        """Sleep in 1 s steps until the cooldown deadline passes or stop_event is set."""
        while not stop_event.is_set() and time.time() < cooldown_until:
            time.sleep(1.0)

    def _is_max_tokens_bad_request(e: Exception) -> bool:
        """
        Return True only for the one non-retryable error this worker special-cases, e.g.:
        Error code: 400 - {'error': {'message': 'max_tokens must be at least 1, got -793.' ...}}
        A match means: no cooldown, no retry, and the item is dropped.
        """
        try:
            s = str(e)
        except Exception:
            s = repr(e)

        return "max_tokens must be at least 1" in s

    def worker_loop(worker_id: int):
        """Thread body: process queue items until stop_event is set and the queue is empty (worker_id is unused)."""
        while True:
            wait_cooldown()

            if stop_event.is_set():
                # Shutting down: discard leftover items but still mark them
                # task_done() so the queue's unfinished-task count reaches zero;
                # exit once nothing arrives within the timeout.
                try:
                    item = task_q.get(timeout=1.0)
                except queue.Empty:
                    return
                task_q.task_done()
                continue

            try:
                item = task_q.get(timeout=1.0)
            except queue.Empty:
                continue

            try:
                summary = agent.run_multiturn_trip(
                    meta_dict=item,
                    dialog_turns=dialog_turns,
                    keep_last_assistant=keep_last_assistant,
                    extra_save_dir="__unused__",
                )

                record = {
                    "trip_id": summary.get("trip_id"),
                    "dialog_turns": summary.get("dialog_turns"),
                    "summary": summary,
                }
                _append_jsonl(file_lock, output_jsonl_path, record)

                with done_counter.get_lock():
                    done_counter.value += 1
                    done = done_counter.value

                # Print progress after every completed item.
                print(
                    f"[worker] progress: {done}/{total_tasks} "
                    f"(trip_id={summary.get('trip_id')})",
                    flush=True
                )

                if done >= total_tasks:
                    stop_event.set()

            except Exception as e:
                # max_tokens 400 errors: no cooldown, no requeue -- drop the item.
                if _is_max_tokens_bad_request(e):
                    print(f"[worker] NON-RETRYABLE(max_tokens) -> drop: {e}", flush=True)

                    # A dropped item still counts as finished; otherwise
                    # done_counter could never reach total_tasks.
                    with done_counter.get_lock():
                        done_counter.value += 1
                        done = done_counter.value

                    # Report progress for dropped items too.
                    print(f"[worker] progress: {done}/{total_tasks} (dropped)", flush=True)

                    if done >= total_tasks:
                        stop_event.set()

                    # No requeue, no cooldown.
                else:
                    # Any other error: requeue the item (unless shutting down)
                    # and put this whole process into cooldown.
                    # NOTE(review): when stop_event is set the item is silently
                    # dropped even though the log line still says "requeue".
                    if not stop_event.is_set():
                        item["_retry"] = int(item.get("_retry", 0) or 0) + 1
                        task_q.put(item)

                    print(f"[worker] trajectory failed -> requeue: {e}", flush=True)
                    set_cooldown()

            finally:
                # Balances the get() above; a requeued item was put() as a new task.
                task_q.task_done()

    # REQUESTS_PER_IP concurrent daemon threads, all sharing this process's agent.
    threads = []
    for wid in range(REQUESTS_PER_IP):
        t = threading.Thread(target=worker_loop, args=(wid,), daemon=True)
        t.start()
        threads.append(t)

    for t in threads:
        t.join()

    print("[worker] process exit", flush=True)



def _pick_mp_context():
    """
    Linux 优先 fork（避免 spawn 导致每个子进程重复 import / 大加载）
    不支持 fork 的环境自动 fallback 到 spawn
    """
    try:
        methods = mp.get_all_start_methods()
    except Exception:
        methods = ["spawn"]

    if "fork" in methods:
        return mp.get_context("fork")
    return mp.get_context("spawn")


def run_multiturn_file_distributed(
    test_path: str,
    dialog_turns: int,
    keep_last_assistant: int,
    output_jsonl_path: str,
    user_prompt_template: str,
):
    """
    Run every not-yet-done item of a test file through one worker process.

    Steps:
    1. Pick the multiprocessing start method. With ``fork``, the evaluator
       resources are preloaded in the parent so the child inherits them.
    2. Load the JSON test file and make sure the output JSONL exists.
    3. Skip items whose sample key is already in the output file (resume) or
       was already queued in this run (duplicates in the input). Multi-turn
       runs key on ``instruction_ids_basic``; single-turn runs key on
       ``final_id_list_from_meta(item)``.
    4. Start one ``_ip_process_main`` worker process and enqueue the remaining
       items. Wait until the worker reports them all finished, or until the
       worker process dies. Then signal stop and join it.

    Results are appended to ``output_jsonl_path`` by the worker.
    """
    ctx = _pick_mp_context()
    print(f"[runner] mp start method = {ctx.get_start_method()}", flush=True)

    # With fork, the child inherits the GLOBAL_RESOURCES loaded here, so they
    # are loaded once in the parent instead of once per process.
    if ctx.get_start_method() == "fork":
        preload_resources_in_parent(CONFIG_PATH)

    with open(test_path, "r", encoding="utf-8") as f:
        items = json.load(f)

    # os.path.dirname() is "" for a bare file name, and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    output_dir = os.path.dirname(output_jsonl_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    if not os.path.exists(output_jsonl_path):
        # Create an empty output file so the resume scan below has a file to read.
        with open(output_jsonl_path, "w", encoding="utf-8"):
            pass

    done_keys = load_done_keys_from_output_jsonl(output_jsonl_path)
    print(f"[runner] already_done={len(done_keys)} (from {output_jsonl_path})", flush=True)

    filtered = []
    seen_in_this_run = set()
    skipped_done = 0
    skipped_dup = 0

    for it in items:
        trip_id = it.get("trip_id")

        # Multi-turn keys on instruction_ids_basic; single-turn keys on the ids
        # derived from the item's modification chains / final_id_list.
        if dialog_turns > 1:
            fil = it.get("instruction_ids_basic", [])
        else:
            fil = final_id_list_from_meta(it)

        sk = make_sample_key(trip_id, fil)

        if sk in done_keys:
            skipped_done += 1
            continue
        if sk in seen_in_this_run:
            skipped_dup += 1
            continue

        seen_in_this_run.add(sk)
        filtered.append(it)

    total = len(filtered)
    print(
        f"[runner] loaded={len(items)} queued={total} "
        f"skipped_done={skipped_done} skipped_dup_in_test={skipped_dup}",
        flush=True
    )

    if total == 0:
        print(f"[Done] nothing to run. output={output_jsonl_path}", flush=True)
        return

    task_q = ctx.JoinableQueue()
    done_counter = ctx.Value("i", 0)
    stop_event = ctx.Event()
    file_lock = ctx.Lock()

    # Only a single worker ("IP") process is started.
    procs = []
    p = ctx.Process(
        target=_ip_process_main,
        args=(
            task_q,
            done_counter,
            total,
            stop_event,
            file_lock,
            output_jsonl_path,
            dialog_turns,
            keep_last_assistant,
            user_prompt_template,
        ),
        daemon=False,
    )
    p.start()
    procs.append(p)

    for it in filtered:
        task_q.put(it)

    # Wait until the worker reports every task finished: it sets stop_event
    # once done_counter reaches `total`. A bare task_q.join() has no timeout
    # and would block forever if the worker process died (e.g. while loading
    # resources or constructing the agent), so also poll its liveness.
    while not stop_event.wait(timeout=5.0):
        if not p.is_alive():
            break
    if not stop_event.is_set():
        print(
            f"[runner] WARNING: worker process exited early (exitcode={p.exitcode}), "
            f"finished={done_counter.value}/{total}",
            flush=True,
        )

    stop_event.set()
    for p in procs:
        p.join()

    print(f"[Done] total={total}, finished={done_counter.value}, output={output_jsonl_path}", flush=True)


def make_output_jsonl_path(test_path: str, dialog_turns: int, output_dir: str) -> str:
    """
    Build ``<output_dir>/<test name>_<agent model>_<mode>_t<dialog_turns>.jsonl``.

    The test name is the basename of ``test_path`` without its extension.
    The mode is ``single`` when ``dialog_turns == 1`` and ``multi`` otherwise.
    The model name is ``AGENT_CFG.model_name``.
    """
    base_name, _ext = os.path.splitext(os.path.basename(test_path))
    mode = "multi"
    if dialog_turns == 1:
        mode = "single"
    file_name = "{}_{}_{}_t{}.jsonl".format(base_name, AGENT_CFG.model_name, mode, dialog_turns)
    return os.path.join(output_dir, file_name)


if __name__ == "__main__":

    # (test tag, test file path) pairs to run.
    # NOTE(review): the path literals on the next line appear redacted/garbled and must be restored.
    TESTS = [
        ("test_easy", "        ("test_mid",  "    ]

    # (mode tag, dialog_turns): 1 turn -> "single" output files, anything else -> "multi".
    RUN_MODES = [
        ("single", 1),
        ("multi",  15),
    ]

    # User-simulator prompt template per test tag; unknown tags fall back to the easy prompt.
    PROMPT_BY_TEST = {
        "test_easy": user_prompt_easy_en,
        "test_mid":  user_prompt_mid_en,
    }

    # Run every (test file, mode) combination; each gets its own output JSONL path.
    for test_tag, test_path in TESTS:
        user_prompt_template = PROMPT_BY_TEST.get(test_tag, user_prompt_easy_en)
        # print(user_prompt_template)

        for mode_tag, dialog_turns in RUN_MODES:
            output_jsonl_path = make_output_jsonl_path(
                test_path=test_path,
                dialog_turns=dialog_turns,
                output_dir=OUTPUT_DIR,
            )

            run_multiturn_file_distributed(
                test_path=test_path,
                dialog_turns=dialog_turns,
                keep_last_assistant=KEEP_LAST_ASSISTANT,
                output_jsonl_path=output_jsonl_path,
                user_prompt_template=user_prompt_template,
            )