import copy
import sys
import os
import json
import re
from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed

import yaml
from openai import OpenAI
import requests

# ====== PATH fix ======
# The project root is the parent of the directory holding this file; it is put
# at the front of sys.path so the project packages imported below resolve.
ROOT = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__))
)
if sys.path.count(ROOT) == 0:
    sys.path.insert(0, ROOT)

print("Added ROOT:", ROOT)

# ====== 引入工具 ======
from Tools.General.General_tools import GeneralTool
from Tools.Attraction.Attraction_tools import AttractionTool
from Tools.Flight.Flight_tools import FlightTool
from Tools.Train.Train_tools import TrainTool
from Tools.Restaurant.Restaurant_tools import RestaurantTool
from Tools.Hotel.Hotel_tools import HotelTool

from Tools.tool_description import (
    Flight_tool_description,
    Train_tool_description,
    Hotel_tool_description,
    Attraction_tool_description,
    Restaurant_tool_description,
    General_tool_description,
)

from interact.prompt.agent_system_prompt import system_prompt_en
from interact.prompt.user_prompt import user_prompt_easy_en
from interact.user_agent import UserSimulator
from interact.evaluation_agent import TripPlanEvaluator

from Evaluation.Attraction.val_attractions import AttractionEvaluator
from Evaluation.Restaurant.val_restaurants import RestaurantEvaluator
from Evaluation.Hotel.val_hotels import HotelEvaluator
from Evaluation.Transportation.val_transports import TransportationEvaluator
from Evaluation.General.val_general import GeneralEvaluator, load_or_build_cache


# ====== Load config ======
# NOTE(review): the original path literal was lost here -- the line was an
# unterminated string fused with the `with open(...)` statement, which made the
# whole module unparsable. The statement structure is restored below; the
# default location is a PLACEHOLDER and must be confirmed against the real
# project layout. The CONFIG_PATH environment variable overrides it.
CONFIG_PATH = os.environ.get("CONFIG_PATH", os.path.join(ROOT, "config.yaml"))
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Dataset locations, one entry per domain, read from the "data_path" config section.
data_path = config["data_path"]
(
    attraction_path,
    restaurant_path,
    hotel_path,
    flight_path,
    train_path,
) = (
    data_path[domain]
    for domain in ("attraction", "restaurant", "hotel", "flight", "train")
)

# ====== Build the evaluators ======
# The attraction evaluator receives the parsed JSON; the others receive a path.
with open(attraction_path, "r", encoding="utf-8") as f:
    attractions_data = json.load(f)

attraction_evaluator = AttractionEvaluator(attractions_data)
restaurant_evaluator = RestaurantEvaluator(restaurants_path=restaurant_path)
hotel_evaluator = HotelEvaluator(hotel_path=hotel_path)
transportation_evaluator = TransportationEvaluator(flights_path=flight_path)

# The general evaluator works on the cache produced from the config file.
data = load_or_build_cache(CONFIG_PATH)
general_evaluator = GeneralEvaluator(data)


# ====== LLM configuration ======
# SECURITY NOTE(review): a credential is hard-coded in source. The literal is
# kept as the fallback so existing runs keep working, but it can now be
# overridden with the LONGCAT_API_KEY environment variable. Rotate the key and
# drop the literal once deployments provide the variable.
API_KEY = os.environ.get("LONGCAT_API_KEY", "1737787093780320300")
BASE_URL = "https://basicaiservice.sankuai.com/basicai/v1"
limit_url = "https://basicaiservice.sankuai.com/basicai/openapi"

USER_MODEL_NAME = "LongCat-Flash-Chat"
# AGENT_MODEL_NAME = "LongCat-Flash-Thinking"
TEMPERATURE = 1.0
AGENT_MODEL_NAME = "LongCat-Flash-Chat"
# TEMPERATURE = 0.0
MAX_TOKENS = 65536

DIALOGUE_TURNS = 15  # number of dialogue turns to run; 1 means single-turn
TEST_PATH = "interact/test/test_easy.json"
# MODE = "easy"
EXTRA_SAVE_DIR = "interact/traces/test_easy_multi_turn"

# print_ass = 1
# ============================================================
# 工具调用封装类（保持你原来的结构）
# ============================================================

class FunctionCall:
    """Attribute view over the ``function`` part of a tool-call payload."""

    def __init__(self, data: Dict[str, Any]):
        # Keys that are absent from the payload become None.
        self.name, self.arguments = data.get("name"), data.get("arguments")


class ToolCall:
    """Attribute view over one entry of an assistant message's ``tool_calls``."""

    def __init__(self, data: Dict[str, Any]):
        # A missing or null "function" entry is treated as an empty payload.
        function_payload = data.get("function") or {}

        self.id = data.get("id")
        self.type = data.get("type")
        self.function = FunctionCall(function_payload)


class AssistantMessage:
    """Attribute view over an assistant message dict (role, content, tool calls)."""

    def __init__(self, data: Dict[str, Any]):
        self.role = data.get("role")
        self.content = data.get("content")

        # A missing or null "tool_calls" entry yields an empty list.
        wrapped_calls = []
        for raw_call in data.get("tool_calls") or []:
            wrapped_calls.append(ToolCall(raw_call))
        self.tool_calls = wrapped_calls


def serialize_tool_calls(tool_calls):
    """Turn tool-call objects back into plain dicts.

    Returns None when ``tool_calls`` is empty or None; otherwise a list with
    one ``{"id", "type", "function": {"name", "arguments"}}`` dict per call.
    """
    if tool_calls:
        return [
            {
                "id": item.id,
                "type": item.type,
                "function": {
                    "name": item.function.name,
                    "arguments": item.function.arguments,
                },
            }
            for item in tool_calls
        ]
    return None


# ============================================================
# User Simulator（单独类）
# ============================================================



# ============================================================
# 每轮用户状态记录
# ============================================================

# Snapshot of one dialogue round, as filled in by TravelAgent.run_multiturn_trip
# and serialized with dataclasses.asdict into the saved summary.
@dataclass
class RoundUserState:
    round_idx: int                        # round number, starting at 1
    id_list_before: List[str]             # instruction ids active when the round started
    id_list_after: List[str]              # instruction ids active after the simulator's proposals were applied
    blocks: Dict[str, str]                # HISTORY/NEW/MODIFY/ISSUE
    sim_prompt: Optional[str]             # prompt sent to the user simulator (None in round 1)
    sim_output: Optional[Dict[str, Any]]  # user simulator output (None in round 1)
    newly_added_ids: List[str]            # ids appended to the active list in this round
    all_added_ids: List[str]              # raw instruction_ids proposed by the simulator when they were accepted
    final_answer: str                     # agent's final answer for the round
    eval_result: Dict[str, Any]           # evaluator output for the round
    use_next_turn_reward: Optional[bool]  # True when the simulator proposed no ids or only HIS_ ids
    start_time: Optional[str]             # ISO timestamp taken from the chat trace
    end_time: Optional[str]               # ISO timestamp taken from the chat trace
    tool_calls: List[Any]                 # tool call records from the chat trace


# ============================================================
#                  TravelAgent 主类（OpenAI SDK 版本）
# ============================================================

class TravelAgent:
    def __init__(self, config_path: str = CONFIG_PATH):
        """Wire up the LLM client, domain tools, user simulator and plan evaluator.

        Args:
            config_path: Config path handed to every domain tool except
                ``GeneralTool``, which is constructed without arguments.
        """
        # OpenAI SDK client; also handed to the user simulator below.
        self.client = OpenAI(api_key=API_KEY, base_url=BASE_URL, timeout=600.0)

        # Domain tool instances.
        self.general_tool = GeneralTool()
        self.attraction_tool = AttractionTool(config_path)
        self.flight_tool = FlightTool(config_path)
        self.train_tool = TrainTool(config_path)
        self.restaurant_tool = RestaurantTool(config_path)
        self.hotel_tool = HotelTool(config_path)

        # Tool-name -> callable registry, filled one domain at a time.
        registry: Dict[str, Any] = {}
        registry.update(
            search_flights=self.flight_tool.search_flights,
            get_flight_detail_with_products=self.flight_tool.get_flight_detail_with_products,
            get_airport_coordinates=self.flight_tool.get_airport_coordinates,
        )
        registry.update(
            search_trains=self.train_tool.search_trains,
            get_train_detail_with_products=self.train_tool.get_train_detail_with_products,
            get_station_coordinates=self.train_tool.get_station_coordinates,
        )
        registry.update(
            search_hotels=self.hotel_tool.search_hotels,
            get_hotel_detail_with_products=self.hotel_tool.get_hotel_detail_with_products,
            get_hotel_coordinates=self.hotel_tool.get_hotel_coordinates,
        )
        registry.update(
            search_attractions=self.attraction_tool.search_attractions,
            get_attraction_detail_with_products=self.attraction_tool.get_attraction_detail_with_products,
            get_attraction_coordinates=self.attraction_tool.get_attraction_coordinates,
        )
        registry.update(
            search_restaurants=self.restaurant_tool.search_restaurants,
            get_restaurant_detail_with_products=self.restaurant_tool.get_restaurant_detail_with_products,
            get_restaurant_coordinates=self.restaurant_tool.get_restaurant_coordinates,
        )
        registry.update(
            get_route_estimate=self.general_tool.get_route_estimate,
            get_city_center_coords=self.general_tool.get_city_center_coords,
            get_date_after=self.general_tool.get_date_after,
        )
        self.tools = registry

        # Tool descriptions sent to the model, concatenated domain by domain.
        self.tool_descriptions = (
            Flight_tool_description
            + Train_tool_description
            + Hotel_tool_description
            + Attraction_tool_description
            + Restaurant_tool_description
            + General_tool_description
        )

        # User simulator (separate class instance) sharing the same client.
        self.user_simulator = UserSimulator(
            client=self.client,
            model_name=USER_MODEL_NAME,
            temperature=0.0,
            max_tokens=16384,
        )

        # Plan evaluator built on the module-level per-domain evaluators.
        self.trip_plan_evaluator = TripPlanEvaluator(
            attraction_evaluator=attraction_evaluator,
            restaurant_evaluator=restaurant_evaluator,
            hotel_evaluator=hotel_evaluator,
            transportation_evaluator=transportation_evaluator,
            general_evaluator=general_evaluator,
        )

    def _post_chat_completion(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: str = "auto",
        temperature: float = TEMPERATURE,
        max_tokens: int = MAX_TOKENS,
    ) -> Dict[str, Any]:
        """POST one chat-completion request to ``limit_url`` and return the decoded JSON.

        The chat parameters are wrapped under ``openapi_params`` next to the
        API key and the upstream base URL.

        Args:
            messages: Conversation messages to send.
            model: Model name.
            tools: Tool descriptions; ``None`` is sent as an empty list.
            tool_choice: Tool choice mode forwarded as-is.
            temperature: Sampling temperature.
            max_tokens: Completion token limit.

        Returns:
            The response body decoded from JSON.

        Raises:
            requests.HTTPError: When the service answers with an error status
                (via ``raise_for_status``).
        """
        payload = {
            "api_key": API_KEY,
            "base_url": "https://aigc.sankuai.com/v1/openai/native",
            "openapi_params": {
                "model": model,
                "messages": messages,
                "tools": tools or [],
                "tool_choice": tool_choice,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "timeout": 120.0,
            },
        }
        # Disabled option kept for reference (LongCat-Flash-Thinking):
        #   payload["openapi_params"]["enable_thinking"] = True
        #   payload["openapi_params"]["thinking_budget"] = 4096
        headers = {"Content-Type": "application/json"}
        resp = requests.post(limit_url, json=payload, headers=headers, timeout=120)
        resp.raise_for_status()

        # Decode the body once and reuse it for both the debug print and the
        # return value (it used to be parsed twice).
        data = resp.json()
        print(data)
        return data

    # ======================== 核心 Chat Loop ========================
    def chat(self, history: Optional[List[Dict[str, str]]] = None, max_rounds=1000):
        """Run the tool-calling loop until the model returns a final answer.

        Each iteration sends the running message list to the model. A reply
        without tool calls is the final answer. Otherwise every requested tool
        is executed and its result is appended as a ``tool`` message before the
        next iteration.

        Args:
            history: Initial messages. They are deep-copied, so the caller's
                list is not mutated.
            max_rounds: Maximum number of model calls before giving up.

        Returns:
            A trace dict with keys ``start_time``, ``end_time``, ``messages``
            (the full message list), ``trace_log_messages`` (normalized copy
            for logging), ``tool_calls`` (one ``{"tool", "args", "result"}``
            record per executed call, where ``result`` is
            ``{"success", "data"}`` or ``{"success", "error"}``) and
            ``final_answer`` (``"Too many rounds."`` when ``max_rounds`` is
            exhausted).
        """
        if history is None:
            history = []

        trace_log = {
            "start_time": datetime.now().isoformat(),
            "messages": [],
            "trace_log_messages": [],
            "tool_calls": [],
            "final_answer": None,
        }

        # Work on a copy of the incoming history.
        messages = copy.deepcopy(history)
        trace_log["trace_log_messages"].extend(messages)

        for _ in range(max_rounds):
            # The response is already printed by _post_chat_completion, so it
            # is not printed a second time here.
            data = self._post_chat_completion(
                messages=messages,
                model=AGENT_MODEL_NAME,
                tools=self.tool_descriptions,
                tool_choice="auto",
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS,
            )
            assistant_msg_dict = data["choices"][0]["message"]
            messages.append(assistant_msg_dict)

            assistant_msg = AssistantMessage(assistant_msg_dict)
            trace_log["trace_log_messages"].append({
                "role": "assistant",
                "content": assistant_msg.content,
                "tool_calls": serialize_tool_calls(assistant_msg.tool_calls),
            })

            tool_calls = assistant_msg.tool_calls

            # ---------- No tool calls: this is the final answer ----------
            if not tool_calls:
                final_answer = assistant_msg.content or ""
                trace_log["messages"] = messages
                trace_log["final_answer"] = final_answer
                trace_log["end_time"] = datetime.now().isoformat()
                return trace_log

            # ---------- Execute every requested tool ----------
            for call in tool_calls:
                tool_name = call.function.name
                raw_args = call.function.arguments

                if isinstance(raw_args, dict):
                    # Arguments that arrive already decoded are used as-is;
                    # feeding a dict to json.loads would fail and silently
                    # discard them.
                    args = raw_args
                else:
                    try:
                        args = json.loads(raw_args or "{}")
                    except (TypeError, ValueError):
                        # Malformed arguments: best effort, call with none.
                        args = {}

                tool_func = self.tools.get(tool_name)

                # `result` is what goes back to the model; `outcome` is the
                # success/data/error structure stored in the trace.
                if tool_func is None:
                    result = {"error": f"Unknown tool {tool_name}"}
                    outcome = {"success": False, "error": result["error"]}
                else:
                    try:
                        output = tool_func(**args)
                    except Exception as e:
                        err = str(e)
                        result = {"error": err}
                        outcome = {"success": False, "error": err}
                    else:
                        result = output
                        outcome = {"success": True, "data": output}

                trace_log["tool_calls"].append({
                    "tool": tool_name,
                    "args": args,
                    "result": outcome,
                })

                if isinstance(result, str):
                    content = result
                else:
                    content = json.dumps(result, ensure_ascii=False)

                # Hand the tool result back to the model.
                tool_msg = {
                    "role": "tool",
                    "tool_call_id": call.id,
                    "name": tool_name,
                    "content": content,
                }

                messages.append(tool_msg)
                trace_log["trace_log_messages"].append(tool_msg)

        # max_rounds exhausted without a final answer.
        trace_log["messages"] = messages
        trace_log["final_answer"] = "Too many rounds."
        trace_log["end_time"] = datetime.now().isoformat()

        return trace_log


    # ============================================================
    #                  Multi-turn helpers（静态方法）
    # ============================================================

    @staticmethod
    def _safe_format(desc: str, slot: str) -> str:
        try:
            return desc.format(slot=slot)
        except Exception:
            return (desc or "").replace("{slot}", slot)

    @staticmethod
    def _build_id2instruction(meta_dict: Dict[str, Any]) -> Dict[str, str]:
        """
        meta_dict["rubric_results"] -> 每个 label 的 _id 映射到一句 instruction（用 description）
        """
        id2instr = {}
        rubric_results = meta_dict.get("rubric_results", {}) or {}

        for _, rubric_group in rubric_results.items():
            for _, cfg in (rubric_group or {}).items():
                desc_tmpl = cfg.get("description", "") or ""
                result = cfg.get("result", {}) or {}
                all_labels = result.get("all_labels_and_ranges", {}) or {}

                for label_text, payload in all_labels.items():
                    _id = (payload or {}).get("_id")
                    if not _id:
                        continue
                    instr = TravelAgent._safe_format(desc_tmpl, label_text).strip()
                    id2instr[_id] = instr

        return id2instr

    @staticmethod
    def _flatten_applied_chains(meta_dict: Dict[str, Any]) -> List[str]:
        chains = meta_dict.get("applied_modification_chains") or {}
        out = []

        for _, chain in chains.items():
            if chain:               # 确保不是 None 或空列表
                out.append(chain[-1])  # 只取最后一个

        return out



    @staticmethod
    def _build_blocks(meta_dict: Dict[str, Any], id_list: List[str], rubric_progress: Dict[str, int]) -> Dict[str, str]:
        
        id2instr = TravelAgent._build_id2instruction(meta_dict)

        # ===== HISTORY =====
        history_lines = []        
        for cid in id_list:
            txt = id2instr.get(cid)
            if txt:
                history_lines.append(f"[ID: HIS_{cid}] {txt}")

        chains = meta_dict.get("applied_modification_chains", {}) or {}

        new_candidates = []
        modify_items = []  # (rubric_key, expected_id)

        for rubric_key in sorted(chains.keys()):
            chain = chains[rubric_key] or []
            if not chain:
                continue

            prog = rubric_progress.get(rubric_key, 0)

            # 已完成
            if prog >= len(chain):
                continue

            # 未开始 → NEW
            if prog == 0:
                new_candidates.append(chain[0])
            # 已开始未完成 → MODIFY
            else:
                modify_items.append((rubric_key, chain[prog]))

        # ===== NEW（可以多条）=====
        new_lines = []
        for cid in new_candidates:
            txt = id2instr.get(cid)
            if txt:
                new_lines.append(f"[ID: NEW_{cid}] {txt}")

        # ===== MODIFY（支持回退说明）=====
        modify_lines = []

        for rubric_key, expected_id in modify_items:
            chain = chains.get(rubric_key, [])
            prog = rubric_progress.get(rubric_key, 0)

            expected_txt = id2instr.get(expected_id)
            if not expected_txt:
                continue

            # ===== 回退检测：expected_id 之前出现过 =====
            if expected_id in chain[:prog]:
                # 被“取消”的是上一步
                removed_id = chain[prog - 1]
                removed_txt = id2instr.get(removed_id, "")

                line = (
                    f"[ID: MOD_{expected_id}] {expected_txt} "
                    f"(Note: the previous requirement {removed_txt} is no longer needed)"
                )
            else:
                line = f"[ID: MOD_{expected_id}] {expected_txt}"

            modify_lines.append(line)

        return {
            "HISTORY": "\n".join(history_lines).strip(),
            "NEW": "\n".join(new_lines).strip(),
            "MODIFY": "\n".join(modify_lines).strip(),
        }



    @staticmethod
    def _history_text_for_sim(conv_simple: List[Dict[str, str]], keep_last_assistant: int = 3) -> str:
        """
        用于填 {{HISTORY_MESSAGES}}
        - user 全保留
        - assistant 仅保留最近 keep_last_assistant 条
        - 更早的 assistant 用占位说明替换
        """

        OMIT_TEXT = "Earlier assistant responses are omitted due to context length limits."

        assistant_indices = [
            i for i, m in enumerate(conv_simple) if m["role"] == "assistant"
        ]

        keep_indices = set(assistant_indices[-keep_last_assistant:])

        lines = []
        for i, m in enumerate(conv_simple):
            if m["role"] == "user":
                lines.append(f"User: {m['content']}")
            elif m["role"] == "assistant":
                if i in keep_indices:
                    lines.append(f"Assistant: {m['content']}")
                else:
                    lines.append(f"Assistant: {OMIT_TEXT}")

        return "\n".join(lines).strip()

    

    def _all_rubrics_done(self, meta_dict: Dict[str, Any], rubric_progress: Dict[str, int]) -> bool:
        chains = meta_dict.get("applied_modification_chains", {}) or {}
        if not chains:
            return False
        for rubric_key, chain in chains.items():
            if int(rubric_progress.get(rubric_key, 0) or 0) < len(chain):
                return False
        return True


    def _apply_proposed_ids_with_prefix(
        self,
        meta_dict: Dict[str, Any],
        current_id_list: List[str],
        rubric_progress: Dict[str, int],
        proposed_ids: List[Any],
    ) -> Tuple[bool, List[str], List[str], List[str]]:
        """
        处理 NEW_/MOD_ 指令：
        - NEW_xxx: 提取 xxx 加入 current_id_list（若不在则 append），并更新 rubric_progress
        - MOD_xxx: 找到 xxx 在该 rubric chain 中的前一个 prev，要求 prev 必须在 current_id_list，
                然后删除 prev，再加入 xxx，并更新 rubric_progress

        返回:
        ok, newly_added_ids, removed_ids, errors
        """
        chains = meta_dict.get("applied_modification_chains", {}) or {}
        id2pos: Dict[str, Tuple[str, int]] = {}
        for rubric_key, chain in chains.items():
            for idx, cid in enumerate(chain):
                id2pos[cid] = (rubric_key, idx)

        errors: List[str] = []
        newly_added: List[str] = []
        removed: List[str] = []

        if not isinstance(proposed_ids, list):
            return False, [], [], [f"instruction_ids is not a list: {type(proposed_ids)}"]

        for raw in proposed_ids:
            if not isinstance(raw, str) or not raw:
                errors.append(f"invalid instruction_id (not str/empty): {raw!r}")
                continue

            if raw.startswith("NEW_"):
                op = "NEW"
                base_id = raw[len("NEW_"):]
            elif raw.startswith("MOD_"):
                op = "MOD"
                base_id = raw[len("MOD_"):]
            else:
                continue

            if base_id not in id2pos:
                errors.append(f"{raw} -> {base_id} not found in applied_modification_chains")
                continue

            rubric_key, idx = id2pos[base_id]
            chain = chains.get(rubric_key, []) or []

            if op == "NEW":
                if base_id not in current_id_list:
                    current_id_list.append(base_id)
                    newly_added.append(base_id)
                # progress：至少推进到 idx+1（1-based count）
                rubric_progress[rubric_key] = max(int(rubric_progress.get(rubric_key, 0) or 0), idx + 1)

            else:  # MOD
                if idx <= 0:
                    errors.append(f"{raw}: target is first element of chain, no previous to remove")
                    continue

                prev_id = chain[idx - 1]
                if prev_id not in current_id_list:
                    errors.append(f"{raw}: expected to remove prev_id={prev_id}, but it's not in current_id_list")
                    continue

                # 删除 prev（只删一次）
                try:
                    current_id_list.remove(prev_id)
                    removed.append(prev_id)
                except ValueError:
                    errors.append(f"{raw}: failed to remove prev_id={prev_id} (ValueError)")
                    continue

                if base_id not in current_id_list:
                    current_id_list.append(base_id)
                    newly_added.append(base_id)

                # rubric_progress +1（等价于推进到 idx+1）
                rubric_progress[rubric_key] = max(int(rubric_progress.get(rubric_key, 0) or 0), idx + 1)

        ok = (len(errors) == 0)
        return ok, newly_added, removed, errors


    # ============================================================
    # 保存每轮用户状态（额外保存）
    # ============================================================

    def rubric_all_satisfied(self, rubric_progress: dict, applied_chains: dict) -> bool:
        for rubric, chains in applied_chains.items():
            required = len(chains)
            current = rubric_progress.get(rubric, 0)
            if current < required:
                return False
        return True
    

    def should_fallback_to_prev_answer(self, final_answer: str) -> bool:
        """
        判断 final_answer 中是否包含 trip_plan.daily_schedule 为空
        是则返回 True
        """
        try:
            data = json.loads(final_answer)
        except Exception:
            # 不是合法 JSON，不回退
            return False

        trip_plan = data.get("trip_plan")
        if not isinstance(trip_plan, dict):
            return False

        daily_schedule = trip_plan.get("daily_schedule")
        return isinstance(daily_schedule, list) and len(daily_schedule) == 0


    def select_raw_answer_for_eval(self,current_raw_answer: str, conv_simple: List[Dict]) -> str:
        """
        从当前轮开始，向前回溯 assistant 输出，
        找到第一个 trip_plan.daily_schedule 非空的 raw_answer
        """

        def extract_daily_schedule(raw_text: str) -> Optional[list]:
            # 1) 直接 json
            try:
                data = json.loads(raw_text)
                if isinstance(data, dict):
                    tp = data.get("trip_plan", {})
                    if isinstance(tp, dict):
                        return tp.get("daily_schedule")
            except Exception:
                pass

            # 2) ```json``` 块
            match = re.search(r"```json\s*(\{.*?\})\s*```", raw_text, re.S)
            if match:
                try:
                    data = json.loads(match.group(1))
                    tp = data.get("trip_plan", {})
                    if isinstance(tp, dict):
                        return tp.get("daily_schedule")
                except Exception:
                    pass

            return None

        # ① 当前轮优先
        schedule = extract_daily_schedule(current_raw_answer)
        if isinstance(schedule, list) and len(schedule) > 0:
            return current_raw_answer

        # ② 向前回溯历史
        for msg in reversed(conv_simple):
            if msg.get("role") != "assistant":
                continue
            raw_text = msg.get("content", "")
            schedule = extract_daily_schedule(raw_text)
            if isinstance(schedule, list) and len(schedule) > 0:
                return raw_text

        # ③ 实在没有，兜底返回当前轮
        return current_raw_answer

    # ============================================================
    # 多轮主入口：一个 meta_dict 跑 dialog_turns 轮
    # ============================================================


    def _save_multiturn_summary(
        self,
        trip_id: str,
        summary: Dict[str, Any],
        save_root: str,
    ) -> str:
        """Write ``summary`` as JSON under ``save_root/<agent model name>/``.

        The file is named ``trajectory_<trip_id>_<YYYYmmdd_HHMMSS>.json``; the
        directory is created when missing. Returns the path that was written.
        """
        # One sub-directory per agent model.
        target_dir = os.path.join(save_root, AGENT_MODEL_NAME)
        os.makedirs(target_dir, exist_ok=True)

        stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        path = os.path.join(target_dir, f"trajectory_{trip_id}_{stamp}.json")

        with open(path, "w", encoding="utf-8") as out_file:
            json.dump(summary, out_file, ensure_ascii=False, indent=2)

        print(f"📁 Summary saved: {path}")
        return path



    def run_multiturn_trip(
        self,
        meta_dict: Dict[str, Any],
        dialog_turns: int = DIALOGUE_TURNS,
        keep_last_assistant: int = 3,
        extra_save_dir: str = "interact/traces/test_easy",
    ) -> Dict[str, Any]:
        """Run up to ``dialog_turns`` rounds of user/agent dialogue for one trip.

        Round 1 uses the query stored in ``meta_dict``. From round 2 on, the
        user simulator writes the next user query and proposes instruction ids,
        which are applied to the active id list and to the rubric progress.
        Every round the agent answers through ``chat``, the answer is
        evaluated, and a ``RoundUserState`` is recorded. The loop stops early
        once ``rubric_all_satisfied`` reports that all chains are complete.

        Args:
            meta_dict: Trip description. Keys read here: ``trip_id``,
                ``rubric_progress``, ``query_with_constraints``,
                ``query_basic``, ``instruction_ids_basic`` and
                ``applied_modification_chains``.
            dialog_turns: Maximum number of rounds; 1 means single-turn.
            keep_last_assistant: Number of recent assistant messages kept
                verbatim in the simulator's history text.
            extra_save_dir: Root directory for the saved summary.

        Returns:
            The summary dict, which is also written to disk by
            ``_save_multiturn_summary``.
        """
        
        os.makedirs(extra_save_dir, exist_ok=True)
        trip_id = meta_dict.get("trip_id", "unknown_trip")
        # Progress is tracked on a copy so meta_dict stays untouched.
        rubric_progress = copy.deepcopy(meta_dict.get("rubric_progress", {}))

        # conv_full: complete message list fed to the agent (system, tool messages, ...).
        # conv_simple: user queries and final assistant answers only.
        conv_full: List[Dict[str, str]] = []
        conv_simple: List[Dict[str, str]] = []

        rounds: List[RoundUserState] = []

        # ====== Initial query: dialog_turns == 1 uses query_with_constraints, otherwise query_basic ======
        if dialog_turns == 1:
            current_user_query = meta_dict.get("query_with_constraints", meta_dict.get("query_basic", ""))
            # Single turn: every constraint is considered already stated (used by the evaluator).
            current_id_list = self._flatten_applied_chains(meta_dict)
        else:
            current_user_query = meta_dict.get("query_basic", "")
            current_id_list = meta_dict.get("instruction_ids_basic", [])

        issue_text = ""  # no issue in round 1; later rounds use the previous round's eval error_text

        for r in range(1, dialog_turns + 1):

            id_before = list(current_id_list)

            # 1) Prompt blocks (ISSUE comes from the previous round's evaluation).
            blocks = self._build_blocks(meta_dict, current_id_list, rubric_progress)

            blocks["ISSUE"] = issue_text or ""
            sim_prompt = None
            sim_output = None
            newly_added = []
            all_added = []
            use_next_turn_reward = False

            # 2) From round 2 on, the user simulator produces user_query + instruction_ids.
            if r >= 2:
                history_text = self._history_text_for_sim(conv_simple, keep_last_assistant)

                max_retry = 3
                retry = 0
                sim_output = None
                sim_prompt = None
                newly_added = []

                while True:
                    sim_prompt = user_prompt_easy_en \
                        .replace("{{HISTORY}}", blocks.get("HISTORY", "")) \
                        .replace("{{NEW}}", blocks.get("NEW", "")) \
                        .replace("{{MODIFY}}", blocks.get("MODIFY", "")) \
                        .replace("{{ISSUE}}", blocks.get("ISSUE", "")) \
                        .replace("{{HISTORY_MESSAGES}}", history_text)
                    # Retries pass temperature=0.7; the first attempt uses the simulator's default.
                    if retry > 0:
                        sim_output = self.user_simulator.generate(sim_prompt,temperature=0.7)
                    else:
                        sim_output = self.user_simulator.generate(sim_prompt)
                    proposed_ids = sim_output.get("instruction_ids", []) or []

                    # True when the simulator proposed nothing, or only HIS_ ids.
                    use_next_turn_reward = (
                        not proposed_ids or
                        all(pid.startswith("HIS_") for pid in proposed_ids)
                    )

                    # Apply on copies first; on failure the copies are dropped and the simulator is asked again.
                    tmp_id_list = list(current_id_list)
                    tmp_progress = copy.deepcopy(rubric_progress)

                    ok, newly_added_tmp, removed_tmp, errors = self._apply_proposed_ids_with_prefix(
                        meta_dict=meta_dict,
                        current_id_list=tmp_id_list,
                        rubric_progress=tmp_progress,
                        proposed_ids=proposed_ids,
                    )

                    if ok:
                        current_id_list = tmp_id_list
                        rubric_progress = tmp_progress
                        newly_added = newly_added_tmp
                        current_user_query = sim_output.get("user_query", "")
                        all_added = proposed_ids
                        break

                    retry += 1

                    if retry >= max_retry:
                        # Stop retrying: keep the last user_query but leave ids and progress
                        # unchanged, and let the next round correct things.
                        newly_added = []
                        current_user_query = sim_output.get("user_query", "")
                        break


            # 3) Agent answer. The system prompt is added once, in round 1.
            if r == 1:
                conv_full.append({"role": "system", "content": system_prompt_en})
            conv_full.append({"role": "user", "content": current_user_query})
            conv_simple.append({"role": "user", "content": current_user_query})

            trace_log = self.chat(history=conv_full)

            # Continue from the full message list returned by chat.
            conv_full = copy.deepcopy(trace_log["messages"])
            conv_simple.append({"role": "assistant", "content": trace_log["final_answer"]})

            # Evaluate the newest answer that actually contains a non-empty daily schedule.
            raw_answer_for_eval = self.select_raw_answer_for_eval(
                    current_raw_answer=trace_log["final_answer"],
                    conv_simple=conv_simple
            )

            # 4) Evaluation.
            eval_result = self.trip_plan_evaluator.evaluate(
                meta_dict=meta_dict,
                raw_text=raw_answer_for_eval,
                id_list=current_id_list
            )

            issue_text = eval_result.get("error_text", "") or ""

            # 5) Record the state of this round.
            round_state = RoundUserState(
                round_idx=r,
                id_list_before=id_before,
                id_list_after=list(current_id_list),
                blocks=blocks,
                sim_prompt=sim_prompt,
                sim_output=sim_output,
                newly_added_ids=newly_added,
                all_added_ids=all_added,
                final_answer=trace_log["final_answer"],
                tool_calls=trace_log["tool_calls"],
                eval_result=eval_result,
                use_next_turn_reward=use_next_turn_reward,
                start_time = trace_log["start_time"],
                end_time = trace_log["end_time"],

            )
            rounds.append(round_state)

            # ====== Early stop: every rubric chain has been fully proposed ======
            applied_chains = meta_dict.get("applied_modification_chains", {})

            if self.rubric_all_satisfied(rubric_progress, applied_chains):
                print(f"All rubrics are proposed at round {r}")
                break


        summary = {
            "trip_id": trip_id,
            "dialog_turns": dialog_turns,
            "keep_last_assistant": keep_last_assistant,
            "final_id_list": list(current_id_list),
            "messages": conv_full,
            "rounds": [asdict(x) for x in rounds],
        }
        self._save_multiturn_summary(trip_id, summary, extra_save_dir)
        return summary

    # ============================================================
    # Load the test file (e.g. test_easy.json) and run every entry in it
    # ============================================================

    def run_multiturn_file(
        self,
        test_path: str,
        dialog_turns: int,
        keep_last_assistant: int = 3,
        extra_save_dir: str = "interact/traces/multiturn_easy",
        max_workers: int = 8,
    ):
        """
        Run ``run_multiturn_trip`` for every entry of a JSON test file on a thread pool.

        Up to ``max_workers`` entries are submitted up front; each time a
        future completes, the next unprocessed entry (if any) is submitted.
        ``result()`` is called on every future, so a crashing worker is
        counted and its traceback printed instead of being dropped silently.
        Failed entries are not retried.

        Args:
            test_path: Path to a UTF-8 JSON file; each element of the loaded
                value is passed to ``run_multiturn_trip`` as ``meta_dict``.
            dialog_turns: Forwarded unchanged to ``run_multiturn_trip``.
            keep_last_assistant: Forwarded unchanged to ``run_multiturn_trip``.
            extra_save_dir: Forwarded unchanged to ``run_multiturn_trip``.
            max_workers: Thread pool size and number of entries submitted
                before the first completion is awaited.

        Returns:
            None. A ``[Done]`` summary line is printed once all futures finish.
        """
        import traceback
        from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED

        with open(test_path, "r", encoding="utf-8") as fh:
            entries = json.load(fh)

        queue = iter(entries)
        total = len(entries)

        def run_entry(entry):
            # One trip per entry; the return value is intentionally discarded.
            self.run_multiturn_trip(
                meta_dict=entry,
                dialog_turns=dialog_turns,
                keep_last_assistant=keep_last_assistant,
                extra_save_dir=extra_save_dir,
            )

        def submit_next(pool, bucket):
            # Pull the next entry off the shared iterator and schedule it.
            # Returns False once the iterator is exhausted.
            try:
                entry = next(queue)
            except StopIteration:
                return False
            bucket.add(pool.submit(run_entry, entry))
            return True

        n_ok = 0
        n_crashed = 0

        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            in_flight = set()

            # Prime the pool with at most max_workers entries.
            for _ in range(max_workers):
                if not submit_next(pool, in_flight):
                    break

            # Drain: every completed future is inspected, then replaced by
            # one new submission while entries remain.
            while in_flight:
                completed, in_flight = wait(in_flight, return_when=FIRST_COMPLETED)

                for future in completed:
                    try:
                        future.result()
                    except Exception:
                        n_crashed += 1
                        print("========== Worker crashed ==========")
                        print(traceback.format_exc())
                        print("===================================")
                    else:
                        n_ok += 1

                    submit_next(pool, in_flight)

        print(f"[Done] total={total} finished={n_ok} failed={n_crashed}")


# ======================== Entry point ========================
if __name__ == "__main__":
    agent = TravelAgent()

    # Active configuration: single-turn run over TEST_PATH with 10 workers.
    # The multi-turn variant that was previously kept here as commented-out
    # code used dialog_turns=DIALOGUE_TURNS and
    # extra_save_dir="interact/traces/test_easy_multi_turn"
    # (NOTE(review): confirm DIALOGUE_TURNS is defined before enabling it).
    run_options = {
        "test_path": TEST_PATH,
        "dialog_turns": 1,
        "keep_last_assistant": 3,
        "extra_save_dir": "interact/traces/test_easy_single_turn",
        "max_workers": 10,
    }
    agent.run_multiturn_file(**run_options)
