# syntactic_detect_and_clarify_opt.py
import os, json, time, random, requests, threading
from typing import Dict, List, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

from utils.prompt import (
    syntactic_ambiguity_detection_prompt_template,
    syntactic_clarification_prompt_template,
)
from utils.config import CONFIG
from utils.utils import _strip_to_json, save_jsonl, load_jsonl

# Input and output locations for the two pipeline stages.
# NOTE(review): the constant is named *_DEV, but the path it holds is a
# "*_train.jsonl" file and both outputs are prefixed "train_" — confirm which
# split is intended.
RAW_MUSIQUE_DEV = "download/raw_data/musique/musique_full_v1.0_train.jsonl"
DETECT_OUT_FILE = "dataset/train_musique_ambiguity_detection.jsonl"
CLARIFY_OUT_FILE = "dataset/train_musique_syntactic_clarified.jsonl"

# Every question is sent to each of these models during detection.
DETECT_MODELS = [
    "openai/gpt-4.1",
    "meta-llama/llama-4-maverick",
    "qwen/qwen3-235b-a22b",
    "anthropic/claude-sonnet-4",
]
# Single model used to generate the clarified rewrites.
CLARIFY_MODEL = "openai/gpt-4.1"

# Endpoint and API key, read from the project configuration.
OPENROUTER_URL = CONFIG["OPENROUTER_URL"]
OPENROUTER_KEY = CONFIG['OPENROUTER_KEY']

# Per-thread storage; _get_session() caches one requests.Session per thread here.
_thread_local = threading.local()

def _get_session() -> requests.Session:
    """Return the calling thread's ``requests.Session``, building it on first use.

    The session is stored on the module-level thread-local object, so each
    thread gets its own instance. A newly created session has a single
    ``HTTPAdapter`` (64 pooled connections, no adapter-level retries) mounted
    for both ``http://`` and ``https://``.
    """
    cached = getattr(_thread_local, "session", None)
    if cached is not None:
        return cached

    session = requests.Session()
    pooled_adapter = requests.adapters.HTTPAdapter(
        pool_connections=64,
        pool_maxsize=64,
        max_retries=0,
    )
    # The same adapter instance serves both schemes.
    for scheme in ("http://", "https://"):
        session.mount(scheme, pooled_adapter)

    _thread_local.session = session
    return session

def _post_openrouter(payload: dict, timeout: float) -> requests.Response:
    """POST ``payload`` to the OpenRouter endpoint with fallback and retries.

    The request is first sent exactly as given. If that attempt returns HTTP
    400 or raises any ``requests.RequestException``, the call falls back to a
    copy of the payload whose ``response_format`` (when present) is replaced
    by ``{"type": "json_object"}`` and tries up to four more times:

    * HTTP 429 waits for the numeric ``Retry-After`` header when one is sent,
      otherwise for the current backoff.
    * HTTP 5xx, timeouts and connection errors wait for the current backoff.
    * Backoff starts at 1 s, doubles per retry and is capped at 8 s; up to
      0.25 s of random jitter is added to every wait.

    Args:
        payload: JSON-serialisable request body.
        timeout: Per-request timeout in seconds, passed to ``Session.post``.

    Returns:
        The first response whose status does not trigger ``raise_for_status``.

    Raises:
        requests.HTTPError: On a non-retryable error status, or when the final
            attempt still returned 429/5xx.
        requests.Timeout, requests.ConnectionError: When the final attempt
            failed at the transport level.
    """
    session = _get_session()
    headers = {"Authorization": f"Bearer {OPENROUTER_KEY}", "Content-Type": "application/json"}

    def _once(body: dict) -> requests.Response:
        return session.post(OPENROUTER_URL, headers=headers, json=body, timeout=timeout)

    def _retry_after(resp: requests.Response, default: float) -> float:
        # Retry-After may be delta-seconds or an HTTP-date; only the numeric
        # form is honoured, anything else falls back to our own backoff.
        raw_value = resp.headers.get("Retry-After")
        if not raw_value:
            return default
        try:
            seconds = float(raw_value)
        except (TypeError, ValueError):
            return default
        # Rejects NaN, infinities and negative values in one comparison.
        return seconds if 0 <= seconds < float("inf") else default

    # First attempt: the payload exactly as the caller built it.
    try:
        resp = _once(payload)
        if resp.status_code != 400:
            resp.raise_for_status()
            return resp
    except requests.RequestException:
        # Deliberate best-effort: any failure here drops to the fallback loop.
        pass

    fallback = dict(payload)
    if "response_format" in fallback:
        fallback["response_format"] = {"type": "json_object"}

    max_attempts = 4
    backoff = 1.0
    last_resp = None  # response of the most recent attempt, if it produced one
    last_exc = None   # transport error of the most recent attempt, if it raised
    for attempt in range(max_attempts):
        try:
            resp = _once(fallback)
        except (requests.Timeout, requests.ConnectionError) as exc:
            last_resp, last_exc = None, exc
            delay = backoff
        else:
            last_resp, last_exc = resp, None
            if resp.status_code == 429:
                delay = _retry_after(resp, backoff)
            elif resp.status_code >= 500:
                delay = backoff
            else:
                # Success, or a non-retryable client error that propagates.
                resp.raise_for_status()
                return resp

        # Sleeping after the final attempt would only delay the error.
        if attempt < max_attempts - 1:
            time.sleep(delay + random.random() * 0.25)
            backoff = min(backoff * 2, 8.0)

    # Retries exhausted: surface the outcome of the *last* attempt.
    if last_resp is None:
        raise last_exc
    last_resp.raise_for_status()
    return last_resp

def call_with_json_multi(model: str, sentence: str) -> Dict[str, object]:
    """Ask ``model`` whether ``sentence`` is syntactically ambiguous.

    The request asks for a JSON object with ``is_ambiguous`` ("Y" or "N") and
    ``categories`` (a list of integers in 1..14).

    Args:
        model: OpenRouter model identifier.
        sentence: The question to classify.

    Returns:
        The parsed JSON object produced by the model.

    Raises:
        requests.RequestException: If the HTTP call ultimately fails.
        ValueError: If the reply has no text content, cannot be parsed as
            JSON (``json.JSONDecodeError`` is a ``ValueError``), or parses to
            something other than a JSON object.
    """
    schema = {
        "type": "object",
        "properties": {
            "is_ambiguous": {"type": "string", "enum": ["Y", "N"]},
            "categories": {
                "type": "array",
                "items": {"type": "integer", "minimum": 1, "maximum": 14},
            },
        },
        "required": ["is_ambiguous", "categories"],
        "additionalProperties": False,
    }
    sys_prompt = (
        "Return ONLY valid JSON that matches the given schema. "
        "No markdown, no code fences, no explanations."
    )
    payload = {
        "model": model,
        "temperature": 0.0,
        "messages": [
            {"role": "system", "content": sys_prompt},
            {"role": "user",
             "content": syntactic_ambiguity_detection_prompt_template(sentence=sentence)},
        ],
        # Structured-output requests wrap the schema in a named envelope;
        # sending the bare schema here is an invalid request body.
        "response_format": {
            "type": "json_schema",
            "json_schema": {
                "name": "ambiguity_detection",
                "strict": True,
                "schema": schema,
            },
        },
    }
    r = _post_openrouter(payload, timeout=90)
    raw = r.json()["choices"][0]["message"]["content"]
    if not isinstance(raw, str):
        raise ValueError(f"model {model!r} returned no text content")
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        # Some models wrap the JSON in prose or fences; retry on the stripped text.
        parsed = json.loads(_strip_to_json(raw))
    if not isinstance(parsed, dict):
        # Downstream code calls .get() on this value, so reject arrays/scalars here.
        raise ValueError(f"model {model!r} returned JSON that is not an object")
    return parsed

def _detect_one_record(record: dict) -> dict:
    """Collect every detection model's verdict for a single record.

    All models in ``DETECT_MODELS`` are queried concurrently with the record's
    ``question``. A model whose call raises is recorded as not ambiguous with
    no categories.

    Args:
        record: A raw record providing ``id`` and ``question``.

    Returns:
        A dict with ``qid``, ``original_query`` and ``models`` (model name ->
        verdict dict).
    """
    qid = record["id"]
    question = record["question"]
    verdicts = {}

    pool_size = min(4, len(DETECT_MODELS))
    with ThreadPoolExecutor(max_workers=pool_size) as pool:
        pending = {}
        for model_name in DETECT_MODELS:
            pending[pool.submit(call_with_json_multi, model_name, question)] = model_name

        for finished in as_completed(pending):
            model_name = pending[finished]
            try:
                verdict = finished.result()
            except Exception:
                # Best-effort: a failed call counts as a "not ambiguous" vote.
                verdict = {"is_ambiguous": "N", "categories": []}
            verdicts[model_name] = verdict

    return {"qid": qid, "original_query": question, "models": verdicts}

def run_ambiguity_detection(n_sample: int = 300, workers: int = 8, start_index: int = 10000) -> None:
    """Run multi-model ambiguity detection and write the results as JSONL.

    Reads every non-blank line of ``RAW_MUSIQUE_DEV`` as a JSON record, runs
    ``_detect_one_record`` on each selected record in a thread pool, and saves
    the results, in input order, to ``DETECT_OUT_FILE``.

    Args:
        n_sample: Acts only as an on/off flag. When truthy, records before
            ``start_index`` are skipped; when falsy, the whole file is
            processed. It does NOT limit how many records are processed.
            NOTE(review): the name suggests a sample size — confirm intent.
        workers: Number of records processed concurrently.
        start_index: Index of the first record to process when ``n_sample``
            is truthy. Defaults to the previously hard-coded 10000.

    A record whose detection raises is still written, with an empty
    ``models`` mapping, and the number of such records is printed.
    """
    with open(RAW_MUSIQUE_DEV, encoding="utf-8") as f:
        dataset = [json.loads(line) for line in f if line.strip()]
    if n_sample:
        dataset = dataset[start_index:]

    # Pre-sized so results can be filled by index as futures finish out of order.
    results: List[Dict] = [None] * len(dataset)
    failed = 0
    with ThreadPoolExecutor(max_workers=workers) as ex:
        fut2idx = {ex.submit(_detect_one_record, rec): i for i, rec in enumerate(dataset)}
        for fut in tqdm(as_completed(fut2idx), total=len(fut2idx), desc="MUSIQUE dev (syntactic detect)"):
            i = fut2idx[fut]
            try:
                results[i] = fut.result()
            except Exception:
                # Keep a placeholder so the output stays aligned with the input.
                failed += 1
                rec = dataset[i]
                results[i] = {"qid": rec["id"], "original_query": rec["question"], "models": {}}

    save_jsonl(results, DETECT_OUT_FILE)
    if failed:
        print(f"⚠ {failed} record(s) failed detection and were saved with empty models")
    print(f"✅ Detection set saved to {DETECT_OUT_FILE}")

def call_generate_clarified(sentence: str, model: str = CLARIFY_MODEL, min_versions: int = 2) -> Dict[str, object]:
    """Ask ``model`` for clarified rewrites of an ambiguous ``sentence``.

    The request asks for a JSON object whose ``clarified_queries`` is a list
    of at least ``min_versions`` strings.

    Args:
        sentence: The ambiguous question to rewrite.
        model: OpenRouter model identifier.
        min_versions: Minimum number of rewrites requested, used both in the
            schema (``minItems``) and in the prompt template.

    Returns:
        The parsed JSON object produced by the model.

    Raises:
        requests.RequestException: If the HTTP call ultimately fails.
        ValueError: If the reply has no text content, cannot be parsed as
            JSON (``json.JSONDecodeError`` is a ``ValueError``), or parses to
            something other than a JSON object.
    """
    schema = {
        "type": "object",
        "properties": {
            "clarified_queries": {
                "type": "array",
                "items": {"type": "string"},
                "minItems": min_versions,
            }
        },
        "required": ["clarified_queries"],
        "additionalProperties": False,
    }
    sys_prompt = (
        "Return ONLY valid JSON that matches the given schema. "
        "No markdown, no code fences, no explanations."
    )
    payload = {
        "model": model,
        "temperature": 0.0,
        "messages": [
            {"role": "system", "content": sys_prompt},
            {
                "role": "user",
                "content": syntactic_clarification_prompt_template(
                    sentence=sentence, min_versions=min_versions
                ),
            },
        ],
        # Structured-output requests wrap the schema in a named envelope;
        # sending the bare schema here is an invalid request body.
        "response_format": {
            "type": "json_schema",
            "json_schema": {
                "name": "clarified_queries",
                "strict": True,
                "schema": schema,
            },
        },
    }
    r = _post_openrouter(payload, timeout=120)
    raw = r.json()["choices"][0]["message"]["content"]
    if not isinstance(raw, str):
        raise ValueError(f"model {model!r} returned no text content")
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        # Some models wrap the JSON in prose or fences; retry on the stripped text.
        parsed = json.loads(_strip_to_json(raw))
    if not isinstance(parsed, dict):
        raise ValueError(f"model {model!r} returned JSON that is not an object")
    return parsed

def filter_all_ambiguous(records: List[Dict]) -> List[Dict]:
    """Keep only the records that every model judged ambiguous.

    A record qualifies when its ``models`` mapping is non-empty and every
    entry has ``is_ambiguous == "Y"``. Records with no model verdicts (for
    example, placeholders written for a failed detection) are excluded:
    ``all()`` over an empty mapping is vacuously true and would otherwise let
    them through.

    Args:
        records: Detection results, each expected to carry a ``models`` dict
            mapping model name to that model's verdict dict.

    Returns:
        The qualifying records, in their original order.
    """
    def _unanimous(rec: Dict) -> bool:
        verdicts = rec.get("models")
        if not isinstance(verdicts, dict) or not verdicts:
            return False
        # A malformed (non-dict) verdict counts as "not ambiguous".
        return all(
            isinstance(v, dict) and v.get("is_ambiguous") == "Y"
            for v in verdicts.values()
        )

    return [rec for rec in records if _unanimous(rec)]

def run_clarification(workers: int = 8) -> None:
    """Generate clarified rewrites for every unanimously ambiguous record.

    Loads the detection results from ``DETECT_OUT_FILE``, keeps the records
    selected by ``filter_all_ambiguous``, asks ``CLARIFY_MODEL`` for at least
    two rewrites of each, and saves the successful results, in input order,
    to ``CLARIFY_OUT_FILE``.

    Args:
        workers: Number of records clarified concurrently.

    A record whose clarification raises is left out of the output; the number
    of such records and the first error are printed so the loss is visible.
    """
    records = load_jsonl(DETECT_OUT_FILE)
    amb_records = filter_all_ambiguous(records)
    print(f"✓ Total   records loaded : {len(records)}")
    print(f"✓ All-ambiguous records : {len(amb_records)}")

    # Pre-sized so results can be filled by index as futures finish out of order.
    clarified_data: List[Dict] = [None] * len(amb_records)

    def _clarify_idx(i: int) -> Tuple[int, Dict]:
        rec = amb_records[i]
        out = call_generate_clarified(rec["original_query"], model=CLARIFY_MODEL, min_versions=2)
        # `or []` also covers a model that returned "categories": null.
        categories_union = sorted(
            set().union(*((m.get("categories") or []) for m in rec["models"].values()))
        )
        return i, {
            "qid": rec["qid"],
            "original_query": rec["original_query"],
            "models": rec["models"],
            "categories_union": categories_union,
            "clarified_queries": out["clarified_queries"],
        }

    failed = 0
    first_error = None
    with ThreadPoolExecutor(max_workers=workers) as ex:
        fut2idx = {ex.submit(_clarify_idx, i): i for i in range(len(amb_records))}
        for fut in tqdm(as_completed(fut2idx), total=len(fut2idx), desc="Generating clarified (syntactic)"):
            try:
                i, item = fut.result()
            except Exception as exc:
                # Best-effort: skip the record, but remember that it was lost.
                failed += 1
                if first_error is None:
                    first_error = (amb_records[fut2idx[fut]].get("qid"), exc)
            else:
                clarified_data[i] = item

    clarified_data = [x for x in clarified_data if x is not None]
    save_jsonl(clarified_data, CLARIFY_OUT_FILE)
    if failed:
        qid, exc = first_error
        print(f"⚠ {failed} record(s) failed clarification and were skipped "
              f"(first: qid={qid!r}: {exc!r})")
    print(f"✅ Clarified set saved to {CLARIFY_OUT_FILE}")

def main() -> None:
    """Run the detection stage, then the clarification stage, with 8 workers each."""
    pool_size = 8
    run_ambiguity_detection(n_sample=10000, workers=pool_size)
    run_clarification(workers=pool_size)


# Only run the pipeline when the file is executed as a script.
if __name__ == "__main__":
    main()
