#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
ComplexFuncBench domain analyzer (invoked-preferred classification)

No CLI args. Configure paths in main().
"""

from __future__ import annotations
import csv
import re
import statistics as stats
from collections import Counter, defaultdict
from typing import Any, Dict, Iterable, Iterator, List, Set, Tuple

from typing import Optional, List, Dict, Any

# ----------------------------- IO helpers -----------------------------

def iter_jsonl(path: str) -> Iterator[Dict[str, Any]]:
    """Yield one parsed JSON value per non-blank line of a ``.jsonl`` file.

    Blank (whitespace-only) lines are skipped.

    Args:
        path: Path to a UTF-8 encoded JSON Lines file.

    Yields:
        The decoded JSON value of each non-blank line (expected to be a dict).

    Raises:
        RuntimeError: If a line is not valid JSON. The message carries the
            1-based line number, and the original ``json.JSONDecodeError`` is
            chained as ``__cause__`` so column/offset details are preserved.
    """
    with open(path, "r", encoding="utf-8") as f:
        for ln, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as e:
                raise RuntimeError(f"JSON parse error at line {ln}: {e}") from e
            # Yield outside the try so errors thrown into the generator by the
            # consumer are never mistaken for parse errors.
            yield obj

def _content_to_str(x: Any) -> str:
    """Normalize message content to a string."""
    if isinstance(x, str):
        return x
    if isinstance(x, list):
        return "".join(str(t) for t in x)
    return str(x)

def normalize_conversations(rec: Dict[str, Any]) -> Dict[str, Any]:
    """Coerce every message's ``content`` in ``rec["conversations"]`` to ``str``.

    The record is modified in place and also returned, for call chaining.
    Non-dict messages and dicts without a ``content`` key are left untouched;
    a missing or non-list ``conversations`` value is ignored.
    """
    messages = rec.get("conversations")
    if not isinstance(messages, list):
        return rec
    for msg in messages:
        if isinstance(msg, dict) and "content" in msg:
            msg["content"] = _content_to_str(msg.get("content", ""))
    return rec

def join_conversation_text(convs: Any) -> str:
    """Return every message's content joined by newlines, lower-cased.

    Non-dict entries are skipped; a dict without ``content`` contributes an
    empty line. Non-list input yields ``""``.
    """
    if not isinstance(convs, list):
        return ""
    pieces = [_content_to_str(m.get("content", "")) for m in convs if isinstance(m, dict)]
    return "\n".join(pieces).lower()

# -------------------------- Domain dictionaries --------------------------

# Canonical domain labels; detect_invoked_domains() pre-seeds one evidence
# list per label.
DOMAIN_LABELS = ("hotels", "flights", "car_rentals", "attractions", "taxi")

# Function-name hints per domain. map_function_to_domain() tests each hint as
# a case-insensitive substring of a function's name. Some hints overlap (e.g.
# "searchDestination" under hotels is a substring of "cars/searchDestination"
# under car_rentals), so which domain wins is decided by that function's
# matching order; dict order here can therefore matter.
NAME_HINTS: Dict[str, Tuple[str, ...]] = {
    "hotels": (
        "Search_Hotels", "Get_Room_Availability", "Get_Room_List_With_Availability",
        "Get_Hotel_Policies", "searchHotels", "searchHotelsByCoordinates",
        "getFilter", "getSortBy", "searchDestination"  # hotel destination
    ),
    "flights": (
        "Search_Flights", "Get_Flight_Status", "Get_Flight_Prices",
        "searchFlights", "getAirports"
    ),
    "car_rentals": (
        "Search_Car_Location", "Search_Car_Rentals", "Get_Car_Availability",
        "searchCarRentals", "cars/searchDestination"
    ),
    "attractions": (
        "Search_Attractions", "Get_Attraction_Details", "Get_Attraction_Availability",
    ),
    "taxi": (
        "Search_Taxi", "Get_Taxi_Price", "Book_Taxi",
    ),
}

# Endpoint path fragments per domain. They are searched for as plain
# substrings in lower-cased function descriptions (map_function_to_domain)
# and in the lower-cased conversation text (detect_invoked_domains), so they
# must stay lower-case.
ENDPOINT_HINTS: Dict[str, Tuple[str, ...]] = {
    "hotels": ("/api/v1/hotels",),
    "flights": ("/api/v1/flights",),
    "car_rentals": ("/api/v1/cars",),
    "attractions": ("/api/v1/attractions",),
    "taxi": ("/api/v1/taxi",),
}

# Lower-case sample-id fragments -> domain label, consulted by
# domain_from_id() as the weakest fallback signal.
ID_PREFIX_MAP: Dict[str, str] = {
    "car-rental": "car_rentals",
    "car_rental": "car_rentals",
    "cars": "car_rentals",
    "hotel": "hotels",
    "hotels": "hotels",
    "flight": "flights",
    "flights": "flights",
    "attraction": "attractions",
    "attractions": "attractions",
    "taxi": "taxi",
}

# ----------------------------- Classification -----------------------------

def domain_from_id(sample_id: str | None) -> str | None:
    """Map an id prefix to a domain label (very weak fallback)."""
    if not sample_id:
        return None
    s = sample_id.strip().lower()
    s = re.split(r"-\d+$", s)[0]  # drop trailing numeric suffix like "-123"
    for part in re.split(r"[-_:+/ ]+", s):
        d = ID_PREFIX_MAP.get(part)
        if d:
            return d
    return None

def map_function_to_domain(name: str, description: str = "") -> str | None:
    """Map a function (by name, then description) to a domain label.

    Resolution order:
      1. ``NAME_HINTS``: case-insensitive substring match against ``name``.
         When hints from several domains match, the LONGEST matching hint
         wins (equal lengths go to the first domain/hint listed). This stops
         a specific hint such as ``"cars/searchDestination"`` from being
         shadowed by the shorter ``"searchDestination"`` listed under hotels.
      2. ``ENDPOINT_HINTS``: substring match against the lower-cased description.
      3. Plain keyword checks on the lower-cased description.

    Args:
        name: Function name (may be empty/None).
        description: Function description (may be empty/None).

    Returns:
        A domain label, or ``None`` if nothing matches.
    """
    n = (name or "").strip().lower()
    d = (description or "").lower()

    best_dom: str | None = None
    best_len = 0
    for dom, keys in NAME_HINTS.items():
        for k in keys:
            kl = k.lower()
            # Strict ">" keeps the earliest hint on equal-length ties.
            if len(kl) > best_len and kl in n:
                best_dom, best_len = dom, len(kl)
    if best_dom is not None:
        return best_dom

    for dom, eps in ENDPOINT_HINTS.items():
        if any(ep in d for ep in eps):
            return dom

    if "hotel" in d:
        return "hotels"
    if "flight" in d:
        return "flights"
    if "car rental" in d or "cars" in d:
        return "car_rentals"
    if "attraction" in d:
        return "attractions"
    if "taxi" in d:
        return "taxi"
    return None

def record_allowed_domains(funcs: Any) -> Set[str]:
    """Return the domains implied by the functions a sample makes available.

    Each dict entry of ``funcs`` is mapped with ``map_function_to_domain``;
    non-dict and unmappable entries are ignored. If nothing maps (or ``funcs``
    is not a list), the result is ``{"unknown"}``.
    """
    if not isinstance(funcs, list):
        return {"unknown"}
    mapped = {
        map_function_to_domain(f.get("name", ""), f.get("description", ""))
        for f in funcs
        if isinstance(f, dict)
    }
    mapped.discard(None)
    return mapped or {"unknown"}

def compile_name_regex(names: List[str]) -> re.Pattern | None:
    """Compile a case-insensitive regex to detect function name mentions."""
    names = [n for n in names if n]
    if not names:
        return None
    pat = r"|".join(re.escape(n) for n in names)
    # Matches raw function names, or JSON-ish '"name":"<func>"'
    return re.compile(rf"(?:\"name\"\s*:\s*\")?(?:{pat})(?:\")?", flags=re.IGNORECASE)

def detect_invoked_domains(convs: Any, funcs: Any) -> Tuple[Set[str], Dict[str, List[str]]]:
    """
    Heuristically detect which domains are ACTUALLY INVOKED in the conversation text.

    Evidence sources:
      - Function-name mentions: each regex match of an available function
        name credits every function whose lower-cased (stripped) name is a
        substring of the match text, as ``"func:<name>"`` once per match.
      - Endpoint substrings from ``ENDPOINT_HINTS``, as ``"ep:<endpoint>"``.

    Non-dict entries of ``funcs`` and non-string/blank names are ignored.

    Returns:
      (invoked_domains, evidence_by_domain); the latter only contains domains
      with at least one piece of evidence.
    """
    text = join_conversation_text(convs)  # lower-cased
    evidence: Dict[str, List[str]] = {d: [] for d in DOMAIN_LABELS}
    hits: Set[str] = set()

    # Resolve each function's domain once, rather than once per regex match.
    # Entries: (name, lower-cased name, domain or None).
    func_table: List[Tuple[str, str, str | None]] = []
    if isinstance(funcs, list):
        for f in funcs:
            if not isinstance(f, dict):
                continue
            nm = f.get("name")
            if isinstance(nm, str) and nm.strip():
                nm = nm.strip()
                func_table.append((nm, nm.lower(), map_function_to_domain(nm, f.get("description", ""))))
    name_rx = compile_name_regex([fn for fn, _, _ in func_table])

    # 1) function-name hits
    if name_rx is not None:
        for match in name_rx.findall(text):
            m_low = str(match).lower()
            for fn, fn_low, dom in func_table:
                if dom and fn_low in m_low:
                    hits.add(dom)
                    evidence[dom].append(f"func:{fn}")

    # 2) endpoint hits
    for dom, eps in ENDPOINT_HINTS.items():
        for ep in eps:
            if ep in text:
                hits.add(dom)
                evidence[dom].append(f"ep:{ep}")

    evidence = {d: v for d, v in evidence.items() if v}
    return hits, evidence

def classify_record(rec: Dict[str, Any]) -> Dict[str, Any]:
    """
    Classify one record using 'invoked', 'allowed' and id-prefix signals.

    Priority:
      invoked_domains (if non-empty)  >  allowed_domains (if single)  >  id prefix

    Resulting ``domain_bucket``:
      - a single domain label when the winning signal names exactly one domain;
      - "multi_domain" when several domains are invoked, or several are
        allowed and the id gives no usable hint;
      - "unknown" when no signal names any domain at all (nothing points to
        more than one domain, so "multi_domain" would be misleading).

    Returns:
      dict with ``id``, ``allowed_domains`` / ``invoked_domains`` (sorted
      lists), ``primary_domain`` (label or None), ``domain_bucket`` and
      ``evidence``.
    """
    funcs = rec.get("functions", [])
    convs = rec.get("conversations", [])
    sample_id = rec.get("id", "")

    allowed = record_allowed_domains(funcs)
    invoked, evidence = detect_invoked_domains(convs, funcs)
    id_dom = domain_from_id(sample_id)
    # record_allowed_domains() returns {"unknown"} when nothing maps.
    known_allowed = allowed - {"unknown"}

    primary: str | None = None
    if invoked:
        if len(invoked) == 1:
            primary = next(iter(invoked))
            bucket = primary
        else:
            bucket = "multi_domain"
    elif len(known_allowed) == 1:
        primary = next(iter(known_allowed))
        bucket = primary
    elif id_dom:
        primary = id_dom
        bucket = primary
    elif known_allowed:
        bucket = "multi_domain"
    else:
        bucket = "unknown"

    return {
        "id": sample_id,
        "allowed_domains": sorted(allowed),
        "invoked_domains": sorted(invoked),
        "primary_domain": primary,
        "domain_bucket": bucket,
        "evidence": evidence,
    }

# ------------------------------- Stats -------------------------------

def conv_stats(convs: List[Dict[str, Any]]) -> Tuple[int, int, int]:
    """Summarise a conversation as ``(#messages, #assistant messages, total chars)``.

    ``#messages`` counts every list entry, while the other two values only
    consider dict entries. Character counts use the string form of each
    ``content``. Non-list input gives ``(0, 0, 0)``.
    """
    if not isinstance(convs, list):
        return 0, 0, 0
    dict_msgs = [m for m in convs if isinstance(m, dict)]
    n_assistant = sum(m.get("role") == "assistant" for m in dict_msgs)
    n_chars = sum(len(_content_to_str(m.get("content", ""))) for m in dict_msgs)
    return len(convs), n_assistant, n_chars

def analyze_domains(jsonl_path: str) -> Tuple[Dict[str, List[Dict[str, Any]]], Counter, Counter, Dict[str, Dict[str, float]]]:
    """
    Classify every record of ``jsonl_path`` and aggregate the results.

    Returns a 4-tuple:
      - bucket: domain_bucket -> list of rows (classification fields plus
        ``turns``, ``assistant_turns`` and ``chars``), in file order
      - domain_counts: number of records per bucket
      - multi_counts: for "multi_domain" rows, counts per "+"-joined domain
        combination (invoked domains if any, else allowed domains)
      - per_bucket_stats: per bucket, ``count`` and the mean turns /
        assistant turns / chars, all as floats
    """
    bucket: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    domain_counts: Counter = Counter()
    multi_counts: Counter = Counter()

    for rec in iter_jsonl(jsonl_path):
        rec = normalize_conversations(rec)
        cls = classify_record(rec)
        n_turns, n_asst, n_chars = conv_stats(rec.get("conversations", []))

        label = cls["domain_bucket"]
        bucket[label].append({**cls, "turns": n_turns, "assistant_turns": n_asst, "chars": n_chars})
        domain_counts[label] += 1
        if label == "multi_domain":
            multi_counts["+".join(cls["invoked_domains"] or cls["allowed_domains"])] += 1

    def _avg(rows: List[Dict[str, Any]], key: str) -> float:
        values = [r[key] for r in rows]
        return float(stats.mean(values)) if values else 0.0

    per_bucket_stats: Dict[str, Dict[str, float]] = {
        label: {
            "count": float(len(rows)),
            "avg_turns": _avg(rows, "turns"),
            "avg_assistant_turns": _avg(rows, "assistant_turns"),
            "avg_chars": _avg(rows, "chars"),
        }
        for label, rows in bucket.items()
        if rows
    }

    return bucket, domain_counts, multi_counts, per_bucket_stats

def print_summary(domain_counts: Counter, multi_counts: Counter, per_bucket_stats: Dict[str, Dict[str, float]]) -> None:
    """Print bucket counts with shares, multi-domain composition, and per-bucket averages."""
    denom = sum(domain_counts.values()) or 1  # avoid dividing by zero on empty input
    print("\n=== Domain summary (invoked-preferred) ===")
    for label, n in domain_counts.most_common():
        share = n / denom
        print(f"{label:14s}: {n:4d} ({share:.1%})")

    if multi_counts:
        print("\n--- Multi-domain composition ---")
        for combo, n in multi_counts.most_common():
            print(f"{combo:40s}: {n}")

    print("\n=== Basic stats per domain ===")
    for label, st in per_bucket_stats.items():
        fields = (
            f"{label:14s}",
            f"#items={int(st['count']):4d}",
            f"avg_turns={st['avg_turns']:.2f}",
            f"avg_assistant_turns={st['avg_assistant_turns']:.2f}",
            f"avg_chars={st['avg_chars']:.1f}",
        )
        print(" | ".join(fields))

# ------------------------------- Main -------------------------------

def _ensure_parent_dir(path: str) -> None:
    """Create the parent directory of ``path`` if it has one (no-op for bare filenames)."""
    parent = os.path.dirname(path)
    if parent:  # os.makedirs("") raises FileNotFoundError
        os.makedirs(parent, exist_ok=True)


def main() -> None:
    """Classify the dataset, print a summary, and optionally export the results.

    Paths and export toggles are configured in the block at the top of this
    function (there are no CLI arguments). The JSONL export writes one
    classified row per record; the CSV export writes a flat subset of columns.

    Raises:
        FileNotFoundError: If the input JSONL does not exist.
    """
    # >>>>>>> EDIT THESE PATHS IF NEEDED <<<<<<<
    in_path = "../datasets/ComplexFuncBench.jsonl"
    write_jsonl = True
    out_jsonl = "../datasets/ComplexFuncBench.classified.jsonl"
    write_csv = True
    out_csv = "../datasets/ComplexFuncBench.domain_stats.csv"
    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>

    if not os.path.exists(in_path):
        raise FileNotFoundError(f"Input not found: {in_path}")

    bucket, domain_counts, multi_counts, per_bucket_stats = analyze_domains(in_path)
    print_summary(domain_counts, multi_counts, per_bucket_stats)

    # Every classified row, in bucket order; shared by both exports.
    all_rows = [row for items in bucket.values() for row in items]

    if write_jsonl:
        _ensure_parent_dir(out_jsonl)
        with open(out_jsonl, "w", encoding="utf-8") as w:
            w.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in all_rows)
        print(f"\nWrote classified JSONL: {out_jsonl}")

    if write_csv:
        _ensure_parent_dir(out_csv)
        rows: List[Dict[str, Any]] = [
            {
                "id": it["id"],
                "domain_bucket": it["domain_bucket"],
                "primary_domain": it["primary_domain"] or "",
                "invoked_domains": "+".join(it["invoked_domains"]),
                "allowed_domains": "+".join(it["allowed_domains"]),
                "turns": it["turns"],
                "assistant_turns": it["assistant_turns"],
                "chars": it["chars"],
            }
            for it in all_rows
        ]
        if rows:
            with open(out_csv, "w", newline="", encoding="utf-8") as w:
                writer = csv.DictWriter(w, fieldnames=list(rows[0].keys()))
                writer.writeheader()
                writer.writerows(rows)
            print(f"Wrote CSV: {out_csv}")
        else:
            print("(CSV export skipped; no rows)")

def cal_statistical_statement_of_prompt(
    prompt_file_path: str = "../datasets/ComplexFuncBench.prompts.jsonl",
) -> None:
    """Print, per intent, the number of prompts and their mean/median length.

    Args:
        prompt_file_path: JSONL file of records with ``intent`` and ``prompt``
            fields. Records with a missing or empty intent are skipped; a
            missing or null prompt counts as length 0. Length is ``len(prompt)``.
    """
    from statistics import fmean, median  # fmean for float mean

    intent_to_num = defaultdict(int)
    intent_to_len = defaultdict(list)

    # Read line by line and collect lengths per intent
    for record in iter_jsonl(prompt_file_path):
        intent = record.get("intent")
        if not intent:
            continue
        prompt_len = len(record.get("prompt") or "")  # null prompt -> 0

        intent_to_num[intent] += 1
        intent_to_len[intent].append(prompt_len)

    # Compute average and median prompt length per intent
    intent_avg_len = {k: fmean(v) for k, v in intent_to_len.items() if v}
    intent_median_len = {k: median(v) for k, v in intent_to_len.items() if v}

    # Pretty print
    print("=== Counts ===")
    print(dict(intent_to_num))
    print("\n=== Average lengths ===")
    print({k: round(v, 2) for k, v in intent_avg_len.items()})
    print("\n=== Median lengths ===")
    print(intent_median_len)

    # Tabular summary
    print("\n=== Summary per intent ===")
    for intent in sorted(intent_to_num.keys()):
        cnt = intent_to_num[intent]
        avg = intent_avg_len.get(intent, 0.0)
        med = intent_median_len.get(intent, 0.0)
        print(f"{intent:20s} | #items={cnt:4d} | avg_len={avg:8.2f} | median_len={med}")

import json

def select_samples():
    """Pick short prompts per intent from ``prompt_file_path``, report and return them.

    Reads the module-level settings ``prompt_file_path``, ``INTENT_MEDIAN``,
    ``SINGLE_THRESHOLD`` and ``TOTAL_THRESHOLD``. Records are scanned in file
    order. A record is kept when all of these hold:
      - its intent is listed in ``INTENT_MEDIAN``;
      - that intent has fewer than ``SINGLE_THRESHOLD`` picks so far;
      - its prompt is no longer than the intent's median.
    Scanning stops once ``TOTAL_THRESHOLD`` records are selected.

    Returns:
        ``(selected_by_intent, selected_all)``: intent -> list of selected
        records, and all selected records in selection order.
    """
    selected_by_intent: dict[str, list] = {k: [] for k in INTENT_MEDIAN}
    selected_all: list = []

    for record in iter_jsonl(prompt_file_path):
        # Overall cap; also covers a cap <= 0 before anything is selected.
        if len(selected_all) >= TOTAL_THRESHOLD:
            break

        intent = record.get("intent")
        prompt = record.get("prompt") or ""  # treat a null prompt as empty

        # Skip unknown intents (not present in INTENT_MEDIAN)
        if intent not in INTENT_MEDIAN:
            continue

        # Enforce per-intent cap
        picks = selected_by_intent[intent]
        if len(picks) >= SINGLE_THRESHOLD:
            continue

        # Length filter: prompt length must be <= median for that intent
        if len(prompt) <= INTENT_MEDIAN[intent]:
            picks.append(record)
            selected_all.append(record)
            # Stop as soon as the overall cap is reached (avoids reading on).
            if len(selected_all) >= TOTAL_THRESHOLD:
                break

    # ---- Reporting ----
    print("=== Per-intent counts ===")
    counts = {k: len(v) for k, v in selected_by_intent.items()}
    print(counts)
    print(f"\nTotal selected: {len(selected_all)} (target {TOTAL_THRESHOLD})")

    # List the selected ids per intent
    for k, recs in selected_by_intent.items():
        ids_preview = [r.get("id") for r in recs]
        print(f"{k:12s}: {len(recs):2d} | preview ids: {ids_preview}")

    return selected_by_intent, selected_all

prompt_file_path = "../datasets/ComplexFuncBench.prompts.jsonl"

# Median prompt length per intent (given by you)
INTENT_MEDIAN = {
    "car_rental": 19580.0,
    "cross": 24746.5,
    "hotels": 18948.0,
    "attraction": 17486.0,
    "flights": 59062.5,
}

SINGLE_THRESHOLD = 5   # max samples per intent
TOTAL_THRESHOLD = 25   # max samples overall

import json
import os
from collections import defaultdict
from statistics import mean

# ---- paths ----
# NOTE: this re-binds prompt_file_path, assigned earlier in the module, to the same path.
prompt_file_path     = "../datasets/ComplexFuncBench.prompts.jsonl"
# Intended destination for the selected prompts. It is not referenced by any
# code in this file; presumably it is passed to write_selected_jsonl() by hand.
selected_out_jsonl   = "../datasets/ComplexFuncBench.prompts.selected.jsonl"

# ---------- io helpers ----------
def iter_jsonl(path: str):
    """Yield JSON objects line by line from a JSONL file.

    Blank (whitespace-only) lines are skipped. NOTE: this definition re-binds
    the module-level name ``iter_jsonl`` and therefore replaces any earlier
    definition for code that calls it after import.

    Raises:
        json.JSONDecodeError: On a malformed line. The message names the file
            and its 1-based line number; the exception's own lineno/colno
            refer to the offending line only.
    """
    with open(path, "r", encoding="utf-8") as f:
        for ln, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as e:
                # Same exception type as before, but say WHERE in the file it failed.
                raise json.JSONDecodeError(f"{e.msg} ({path}, line {ln})", e.doc, e.pos) from e
            yield obj

def load_id_map(path: str) -> dict:
    """Index the records of a JSONL file by their ``id`` field.

    Records whose id is missing or falsy are dropped; for duplicate ids the
    record read last wins.
    """
    return {obj.get("id"): obj for obj in iter_jsonl(path) if obj.get("id")}

def select_by_ids(id_map: dict, target_ids: dict[str, list[str]]):
    """Look up ``target_ids`` in ``id_map`` and group the hits by intent.

    The grouping key is the intent under which an id is listed in
    ``target_ids``, not the record's own ``intent`` field. The prompt text
    comes from the record (``""`` if absent). Ids that have no record, or an
    empty one, are collected separately.

    Returns:
      intent_to_prompts: defaultdict intent -> list of {"id", "prompt"}
      missing: ids not found in id_map, in encounter order
    """
    grouped = defaultdict(list)
    not_found = []
    for intent, wanted in target_ids.items():
        for rid in wanted:
            hit = id_map.get(rid)
            if hit:
                grouped[intent].append({"id": rid, "prompt": hit.get("prompt", "")})
            else:
                not_found.append(rid)
    return grouped, not_found

def write_selected_jsonl(intent_to_prompts: dict, out_path: str):
    """Write selected (id, intent, prompt) to one JSONL file."""
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    threshold = 2
    with open(out_path, "w", encoding="utf-8") as f:
        for intent, items in intent_to_prompts.items():
            if threshold < 0:
                break
            threshold -= 1
            for rec in items:
                obj = {"id": rec["id"], "intent": intent, "prompt": rec["prompt"]}
                f.write(json.dumps(obj, ensure_ascii=False) + "\n")
    print(f"Wrote selected prompts: {out_path}")

# ---------- CE/PPL ----------
def run_ce_ppl(intent_to_prompts: dict, scorer=None):
    """
    Score every prompt with a cross-entropy / perplexity function and print results.

    Args:
        intent_to_prompts: intent -> list of ``{"id", "prompt"}`` dicts.
        scorer: Callable ``text -> (ce_nats, ce_bits, ppl)``. Defaults to a
            global ``ce_ppl_sliding_streaming``. No definition of it is
            visible in this file; it must exist in the running process,
            otherwise a NameError is raised.

    Prints per-record metrics and per-intent averages of CE (nats) and PPL.

    Returns:
        The number of prompts evaluated.
    """
    if scorer is None:
        # Looked up at call time so it can be defined after this module loads.
        scorer = ce_ppl_sliding_streaming  # noqa: F821
    total = 0
    for intent, items in intent_to_prompts.items():
        print(f"\n--- {intent} ({len(items)} items) ---")
        ce_vals, ppl_vals = [], []
        for i, rec in enumerate(items, 1):
            rid = rec["id"]
            text = rec["prompt"]
            ce_nats, ce_bits, ppl = scorer(text)
            ce_vals.append(ce_nats)
            ppl_vals.append(ppl)
            total += 1
            print(f"[{i:02d}] id={rid} | len={len(text)} | "
                  f"CE(nats)={ce_nats:.6f} | CE(bits)={ce_bits:.6f} | PPL={ppl:.6f}")
        if ce_vals:
            print(f" -> avg CE(nats)={mean(ce_vals):.6f} | avg PPL={mean(ppl_vals):.6f}")
    print(f"\nTotal evaluated: {total}")
    return total

def interact_with_gpt_oss(prompt):
    """Generate a short sampled completion for ``prompt`` with a local causal LM.

    The tokenizer and model are loaded from ``MODEL_DIR`` (local files only,
    remote code trusted) on every call, so each call pays the full load cost.

    Args:
        prompt: Raw text prompt; it is truncated to 4096 tokens.

    Returns:
        Only the newly generated text (the prompt tokens are stripped),
        decoded without special tokens.
    """
    tok = AutoTokenizer.from_pretrained(
        MODEL_DIR, use_fast=True, trust_remote_code=True, local_files_only=True
    )
    if tok.pad_token_id is None and tok.eos_token_id is not None:
        tok.pad_token = tok.eos_token  # reuse EOS as the padding token

    # Prefer bf16 where the GPU supports it, otherwise fall back to fp16.
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        trust_remote_code=True,
        local_files_only=True,
        torch_dtype=dtype,
        device_map="auto",  # automatic device placement
        low_cpu_mem_usage=True,
        attn_implementation="sdpa",  # PyTorch scaled-dot-product attention
    )

    # The tokenizer returns CPU tensors; move them to the model's device.
    inputs = tok(prompt, return_tensors="pt", truncation=True, max_length=4096).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=128,  # keep the completion short
            use_cache=False,  # no KV cache: slower, but saves VRAM
            do_sample=True, temperature=0.7, top_p=0.9
        )

    # Drop the echoed prompt tokens; decode only what was generated.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tok.decode(new_tokens, skip_special_tokens=True)

import os, torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from openai import OpenAI

os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

MODEL_DIR = "<  your_model_dir_here  >"  #
import os, json
from typing import Optional, List, Dict, Any
from openai import OpenAI, OpenAIError

# --- OpenAI credentials ---
# Read from the environment only. A hard-coded placeholder must not be
# assigned here: any non-empty string would pass the missing-key check in
# _client_once(), and the problem would only surface once a request is made.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

_client: Optional[OpenAI] = None


def _client_once() -> OpenAI:
    """Return a lazily created, module-wide OpenAI client.

    On first use the client is built from ``OPENAI_API_KEY`` plus the optional
    ``OPENAI_ORG_ID`` and ``OPENAI_PROJECT`` environment variables. Later
    calls reuse the same instance.

    Raises:
        RuntimeError: If ``OPENAI_API_KEY`` is empty or unset at first use.
    """
    global _client
    if _client is not None:
        return _client
    if not OPENAI_API_KEY:
        raise RuntimeError("Missing OPENAI_API_KEY. Set it in env or .env.")
    _client = OpenAI(
        api_key=OPENAI_API_KEY,
        organization=os.getenv("OPENAI_ORG_ID"),
        project=os.getenv("OPENAI_PROJECT"),
    )
    return _client

def flatten_conversations_to_transcript(convs: List[Dict[str, Any]]) -> str:
    """Render a chat message list as a plain-text transcript, one line per message.

    Each line has the form ``"<Role>: <content>"``.
      - ``role`` falls back to ``from``, then ``"user"``; it is capitalised
        as given, and non-standard roles are kept.
      - ``content`` falls back to ``value``, then ``""``. A ``None`` content
        is rendered as empty rather than as the text ``"None"``.
      - A ``function_call`` field is appended as ``[function_call] <json>``.
    """
    lines = []
    for m in convs:
        role = m.get("role") or m.get("from") or "user"
        content = m.get("content", m.get("value", ""))
        if content is None:
            content = ""
        if "function_call" in m:
            fc = json.dumps(m["function_call"], ensure_ascii=False)
            content = f"{content}\n[function_call] {fc}" if content else f"[function_call] {fc}"
        lines.append(f"{str(role).capitalize()}: {content}")
    return "\n".join(lines)

def get_openai_available_model_list() -> list[dict[str, str | None | Any]]:
    """Return ``[{"id": <model id>}, ...]`` for every model the client can list, sorted by id."""
    client = _client_once()
    return sorted(({"id": m.id} for m in client.models.list()), key=lambda entry: entry["id"])

def interact_with_gpt(
    prompt: Any,
    model: str = "gpt-5", # "gpt-5", "gpt-5-nano", "gpt-5-mini"
    system: Optional[str] = None,
    verbosity: Optional[str] = None,          # 'low' | 'medium' | 'high'
    reasoning_effort: Optional[str] = None,   # 'minimal' | 'low' | 'medium' | 'high'
    temperature: Optional[float] = None,
) -> str:
    """Send ``prompt`` to the OpenAI Responses API and return the output text.

    Args:
        prompt: A list of chat messages, which is flattened into a plain-text
            transcript; any other value is sent as ``str(prompt)``.
        model: Model name.
        system: Optional system instructions, sent as ``instructions``.
        verbosity: Optional output verbosity, sent as ``text.verbosity``.
        reasoning_effort: Optional reasoning effort, sent as ``reasoning.effort``.
        temperature: Optional sampling temperature.

    Only options that are not ``None`` are forwarded, so a call with the
    defaults sends just ``model`` and ``input``.
    """
    client = _client_once()

    payload = flatten_conversations_to_transcript(prompt) if isinstance(prompt, list) else str(prompt)

    extra: Dict[str, Any] = {}
    if system is not None:
        extra["instructions"] = system
    if verbosity is not None:
        extra["text"] = {"verbosity": verbosity}
    if reasoning_effort is not None:
        extra["reasoning"] = {"effort": reasoning_effort}
    if temperature is not None:
        extra["temperature"] = temperature

    resp = client.responses.create(model=model, input=payload, **extra)
    return resp.output_text

from typing import Optional

def interact_with_oss_v2(
    prompt: str,
    *,
    model_id: str = "openai/gpt-oss-20b",
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    top_p: float = 0.9,
    seed: Optional[int] = None,
) -> str:
    """Generate a sampled chat reply to ``prompt`` with a Hugging Face model.

    The prompt is sent as a single user turn through the tokenizer's chat
    template. The tokenizer and model are loaded from ``model_id`` on every call.

    Args:
        prompt: User message text.
        model_id: Hub id or local path of the model.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        seed: If given, passed to ``transformers.set_seed`` before generating,
            so sampling is repeatable.

    Returns:
        The generated reply, with prompt tokens removed, special tokens
        skipped, and surrounding whitespace stripped.
    """
    from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

    if seed is not None:
        set_seed(seed)

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    messages = [
        {"role": "user", "content": prompt},
    ]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    # Keep only the tokens produced after the prompt.
    gen_ids = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(gen_ids, skip_special_tokens=True).strip()

# ---------- main ----------
if __name__ == "__main__":
    path = "../datasets/ComplexFuncBench.jsonl"
    conversations = # conversations
    if conversations is None or conversations.strip() == "":
        print("start traversal")
        for rec in iter_jsonl(path):
            id = rec.get("id")
            if rec.get("id") == "Attraction-9":
                conversations = rec.get("conversations", [])
                print(json.dumps(rec, ensure_ascii=False, indent=2))
                break
    print("==============relection of original version is: ==============")
    print(interact_with_gpt(conversations))
