import json
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

from openai import OpenAI

# ====== Paths ======
INPUT_PATH = Path("interact/test/test_easy_adjust_time.json")
OUTPUT_PATH = Path("interact/test/test_easy_g.json")

# ====== LLM configuration ======
# NOTE(review): a credential is committed to source control here. It can now be
# supplied via the LLM_API_KEY / LLM_BASE_URL environment variables; the
# literals remain only as fallbacks so existing runs keep working. Consider
# rotating the key and removing the literal fallback.
API_KEY = os.environ.get("LLM_API_KEY", "1737787093780320300")
BASE_URL = os.environ.get("LLM_BASE_URL", "https://basicaiservice.sankuai.com/basicai/v1")

AGENT_MODEL_NAME = "LongCat-Flash-Chat"
TEMPERATURE = 0.0
MAX_TOKENS = 32768

# Concurrency level (tune to the service's QPS / rate limits).
MAX_WORKERS = 50

# Shared client used by call_llm from every worker thread.
client = OpenAI(
    api_key=API_KEY,
    base_url=BASE_URL,
)

def _render_constraint_from_id(rubric_results: dict, rubric_name: str, target_id: str) -> str:
    """
    从 rubric_results[rubric_name] 里找到 _id == target_id 的 slot，
    用 description 模板渲染出约束文案；找不到则退化为 description（若有）。
    """
    rubric_block = rubric_results.get(rubric_name, {})
    for _, rule_content in rubric_block.items():
        description_template = rule_content.get("description", "") or ""
        all_ranges = (
            rule_content
            .get("result", {})
            .get("all_labels_and_ranges", {})
        ) or {}

        for slot_key, slot_value in all_ranges.items():
            if slot_value.get("_id") == target_id:
                if "{slot}" in description_template:
                    return description_template.replace("{slot}", slot_key)
                return description_template

        # 若这个 rule 没有 ranges 但有 description，也可作为兜底（但优先找 match）
        if description_template and not all_ranges:
            return description_template

    return ""


def extract_constraints(item: dict) -> list[str]:
    """Collect the rendered constraint text for the LAST id of every chain.

    For each rubric in ``item["applied_modification_chains"]`` with a
    non-empty chain, the chain's final id is rendered through
    ``_render_constraint_from_id`` against ``item["rubric_results"]``.
    Rubrics whose text renders empty are skipped. Missing or ``null`` fields
    are treated as empty, matching ``build_basic_candidates``.
    """
    applied_chains = item.get("applied_modification_chains") or {}
    rubric_results = item.get("rubric_results") or {}

    constraints = []
    for rubric_name, chain in applied_chains.items():
        if not chain:
            continue
        constraint = _render_constraint_from_id(rubric_results, rubric_name, chain[-1])
        if constraint:
            constraints.append(constraint)
    return constraints


def build_basic_candidates(item: dict) -> list[dict]:
    """Build the selectable instruction candidates for the basic query.

    Produces one entry per rubric whose chain in
    ``item["applied_modification_chains"]`` is non-empty, using that chain's
    FIRST id rendered to readable text (empty string if it cannot be
    rendered). Each entry looks like::

        {"rubric": "RUBRIC_INCLUDE_CATEGORIES", "id": "attraction_1_1_1", "instruction": "..."}
    """
    chains = item.get("applied_modification_chains", {}) or {}
    results = item.get("rubric_results", {}) or {}

    return [
        {
            "rubric": name,
            "id": ids[0],
            "instruction": _render_constraint_from_id(results, name, ids[0]) or "",
        }
        for name, ids in chains.items()
        if ids
    ]


def compute_rubric_progress(applied_chains: dict, selected_ids: list[str]) -> dict:
    """Map each rubric to the 1-based position of its selected chain entry.

    For every rubric in ``applied_chains``, the value is the position
    (counting from 1) of the earliest chain entry that appears in
    ``selected_ids``, or 0 when nothing from that chain was selected.
    ``None`` for either argument, or for a chain, is treated as empty.
    """
    chosen = set(selected_ids or [])

    def _position(chain) -> int:
        hits = (pos for pos, cid in enumerate(chain or [], start=1) if cid in chosen)
        return next(hits, 0)

    return {name: _position(chain) for name, chain in (applied_chains or {}).items()}


def build_prompt_basic(route: dict, candidates: list[dict]) -> str:
    """Build the LLM prompt for the "basic" query.

    The model sees the trip facts from ``route`` and the candidate
    instructions (one per rubric, as produced by ``build_basic_candidates``).
    It may pick 0-4 candidate ids and must reply with strict JSON of the form
    ``{"instruction_ids": [...], "query_basic": "..."}``.
    """
    listing = json.dumps(candidates, ensure_ascii=False, indent=2)

    preamble = """
You are generating ONE realistic user query for a travel planning product.

Role:
- You are a REAL user planning a trip
- The query must sound natural, fluent, and human
- Do NOT sound like an assistant or a template

Hard requirements:
- Explicitly mention:
  - departure city
  - destination city
  - number of people
  - departure date
  - length of stay (days)
  - return date
- All information must be factually correct
- Do not omit or paraphrase away any key detail
- **Do NOT add, infer, assume, or supplement any information that is not explicitly provided below**

Default interpretation rule:
- **If transportation mode is NOT explicitly mentioned in selected instructions, assume ALL transportation modes are acceptable (e.g. flights or trains).**
- **Do NOT mention transportation mode in the query unless it is explicitly required by a selected instruction.**

Instruction selection rules:
- You MAY select **0 to 4** instruction IDs from the candidate list below.
- If you select an ID, you MUST reflect that instruction in the query.
- If you do not select an ID, you MUST NOT mention it.
- Do NOT introduce any constraints other than the selected instructions.
- Candidate list contains ONLY the first option for each RUBRIC.
""".strip()

    trip_section = "\n".join([
        "Trip information:",
        f"- From: {route['from']}",
        f"- To: {route['to']}",
        f"- Number of people: {route['number_of_people']}",
        f"- Departure date: {route['depart_date']}",
        f"- Stay duration: {route['stay_days']} days",
        f"- Return date: {route['return_date']}",
    ])

    candidate_section = "Instruction candidates (select 0~4 IDs):\n" + listing

    output_section = "\n".join([
        "Output format (STRICT JSON, no extra text):",
        "{",
        '  "instruction_ids": ["id1", "id2", "... up to 4 ..."],',
        '  "query_basic": "..."',
        "}",
    ])

    # Sections are separated by exactly one blank line.
    return "\n\n".join([preamble, trip_section, candidate_section, output_section])


def build_prompt_with_constraints(route: dict, constraints: list[str]) -> str:
    """Build the LLM prompt for the constraint-bearing query.

    The model receives the trip facts from ``route`` plus every string in
    ``constraints`` (all mandatory) and must reply with strict JSON of the
    form ``{"query_with_constraints": "..."}``.
    """
    preamble = """
You are generating ONE realistic user query for a travel planning product.

Role:
- You are a REAL user planning a trip
- The query must sound natural, fluent, and human
- Do NOT sound like an assistant or a template

Hard requirements:
- Explicitly mention:
  - departure city
  - destination city
  - number of people
  - departure date
  - length of stay (days)
  - return date
- All information must be factually correct
- Do not omit or paraphrase away any key detail
- **Do NOT add, infer, assume, or supplement any information that is not explicitly provided below**
- **Do NOT include anything beyond the listed constraints**
- MUST clearly include ALL constraints listed below
- Constraints should be naturally integrated into the request

Default transportation rule:
- **If transportation mode is NOT listed in the constraints, assume all transportation modes are acceptable (e.g. flights or trains).**
- **Do NOT mention transportation mode in the query unless it appears explicitly in the constraints.**
""".strip()

    trip_section = "\n".join([
        "Trip information:",
        f"- From: {route['from']}",
        f"- To: {route['to']}",
        f"- Number of people: {route['number_of_people']}",
        f"- Departure date: {route['depart_date']}",
        f"- Stay duration: {route['stay_days']} days",
        f"- Return date: {route['return_date']}",
    ])

    rendered_constraints = json.dumps(constraints, ensure_ascii=False, indent=2)
    constraint_section = "Constraints (must ALL be included):\n" + rendered_constraints

    output_section = "\n".join([
        "Output format (STRICT JSON, no extra text):",
        "{",
        '  "query_with_constraints": "..."',
        "}",
    ])

    # Sections are separated by exactly one blank line.
    return "\n\n".join([preamble, trip_section, constraint_section, output_section])


def _safe_json_loads(content: str) -> dict:
    s = content.strip()

    # 去掉 ```json ... ``` 包裹
    if s.startswith("```"):
        s = re.sub(r"^```(?:json)?\s*", "", s, flags=re.IGNORECASE)
        s = re.sub(r"\s*```$", "", s)

    # 截取第一个 { 到最后一个 }
    l = s.find("{")
    r = s.rfind("}")
    if l != -1 and r != -1 and r > l:
        s = s[l:r+1]

    return json.loads(s)


def call_llm(prompt: str, retry: int = 10, retry_delay: float = 40.0) -> dict:
    """Send ``prompt`` to the chat model and return its reply parsed as a dict.

    Args:
        prompt: user message content.
        retry: number of extra attempts after the first one
            (``retry + 1`` attempts in total).
        retry_delay: seconds to sleep before each retry.

    Returns:
        The parsed JSON object, or ``{}`` if every attempt failed. Callers
        index the result with ``.get``, so this never returns ``None``.
    """
    last_error = None
    total_attempts = retry + 1

    for attempt in range(total_attempts):
        if attempt > 0:
            # attempt == 1 is the first retry.
            print(f"⚠️ Retry {attempt}/{retry} after {retry_delay:g} seconds...")
            time.sleep(retry_delay)
        try:
            response = client.chat.completions.create(
                model=AGENT_MODEL_NAME,
                messages=[
                    {"role": "system", "content": "You generate realistic user queries for a travel planning scenario."},
                    {"role": "user", "content": prompt}
                ],
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS,
            )

            content = response.choices[0].message.content or ""
            result = _safe_json_loads(content)
            # Treat a valid-but-non-object reply (e.g. `null`) as a failure so it is retried.
            if not isinstance(result, dict):
                raise ValueError(f"expected a JSON object, got {type(result).__name__}")
            return result

        except Exception as e:
            last_error = e
            print("\n" + "=" * 80)
            print(f"❌ LLM CALL FAILED (attempt {attempt + 1})")
            print(f"📌 ERROR: {e}")
            print("=" * 80 + "\n")

    # Give up without raising; report why so the failure is not silent.
    print(f"❌ LLM CALL GAVE UP after {total_attempts} attempts; last error: {last_error}")
    return {}

def _valid_instruction_ids(raw_ids, allowed_ids: set) -> list:
    """Keep the model-returned ids that were actually offered as candidates.

    Order is preserved and duplicates are dropped. Anything that is not a list,
    and any element not in ``allowed_ids``, is discarded.
    """
    if not isinstance(raw_ids, list):
        return []
    kept = []
    for _id in raw_ids:
        try:
            offered = _id in allowed_ids
        except TypeError:  # unhashable junk from the model (e.g. a dict)
            offered = False
        if offered and _id not in kept:
            kept.append(_id)
    return kept


def _apply_basic_result(item: dict, llm_result: dict, allowed_ids: set) -> None:
    """Write query_basic, instruction_ids_basic and rubric_progress into ``item``."""
    instruction_ids = _valid_instruction_ids(llm_result.get("instruction_ids"), allowed_ids)
    item["query_basic"] = llm_result.get("query_basic", "") or ""
    item["instruction_ids_basic"] = instruction_ids

    applied = item.get("applied_modification_chains", {}) or {}
    item["rubric_progress"] = compute_rubric_progress(applied, instruction_ids)


def main():
    """Generate both query variants for every item and save the augmented data.

    Reads INPUT_PATH, submits two LLM calls per item to a thread pool (the
    basic query and the query with constraints), merges the results back into
    the items, and writes the whole list to OUTPUT_PATH. A failed call leaves
    empty values for its fields instead of aborting the run.
    """
    with open(INPUT_PATH, "r", encoding="utf-8") as f:
        data = json.load(f)

    # future -> (item index, output key, candidate ids offered to the model)
    future_to_meta = {}

    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        for idx, item in enumerate(data):
            route = item["route"][0]
            constraints = extract_constraints(item)

            # basic: candidates are the first id of each rubric's chain
            candidates = build_basic_candidates(item)
            candidate_ids = {c["id"] for c in candidates}
            prompt_basic = build_prompt_basic(route, candidates)

            # with constraints: the last id of each rubric's chain, all required
            prompt_constraints = build_prompt_with_constraints(route, constraints)

            future_to_meta[executor.submit(call_llm, prompt_basic)] = (idx, "query_basic", candidate_ids)
            future_to_meta[executor.submit(call_llm, prompt_constraints)] = (idx, "query_with_constraints", candidate_ids)

        # Results are merged on this thread only, so `data` is never written concurrently.
        for fut in as_completed(future_to_meta):
            idx, key, candidate_ids = future_to_meta[fut]
            item = data[idx]
            try:
                llm_result = fut.result()
                if key == "query_basic":
                    _apply_basic_result(item, llm_result, candidate_ids)
                else:
                    item[key] = llm_result.get(key, "") or ""

            except Exception as e:
                if key == "query_basic":
                    _apply_basic_result(item, {}, set())
                else:
                    item[key] = ""
                print(f"❌ idx={idx} key={key} failed: {e}")

    OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

    print(f"✅ Original items preserved. Queries appended and saved to: {OUTPUT_PATH}")


if __name__ == "__main__":
    main()
