from __future__ import annotations

import argparse
import json
import logging
import math
import os
import random
import re
import sys
import time
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import lru_cache
from typing import Dict, List, Tuple

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from tqdm.auto import tqdm
from utils.config import CONFIG

# MediaWiki API endpoint queried by _wiki_get.
WIKI_ENDPOINT = "https://en.wikipedia.org/w/api.php"
# OpenRouter chat-completions endpoint that or_chat POSTs to.
OR_URL = "https://openrouter.ai/api/v1/chat/completions"


# Voting panel. vote_with_escalation asks MODELS[0] first and only consults the
# remaining models when that first vote is "Y"; detect() marks a query as
# ambiguous only when every model listed here voted "Y".
MODELS = [
    "meta-llama/llama-4-maverick",
    "qwen/qwen3-235b-a22b",
    "anthropic/claude-sonnet-4",
    "openai/gpt-4.1",
]

# srlimit requested by compute_metrics for the main search. wiki_search clamps
# every request to MAX_SRLIMIT, so the value actually sent is min(TOP_K, MAX_SRLIMIT).
TOP_K = 50
# Upper bound wiki_search applies to the srlimit parameter (also its default).
MAX_SRLIMIT = 20
# Number of attempts made by _wiki_get (default) and by or_chat.
MAX_RETRY = 4
REQUEST_SLEEP = 0.15   # seconds wiki_search sleeps after each non-cached call
# Maximum number of extracted constraints compute_metrics tries to relax.
MAX_RELAX_CAND = 4
MIN_VERS = 2  # minimum number of clarified queries requested by generate_clarified
CHK_EVERY = 200    # NOTE(review): not referenced anywhere in the visible code
BATCH = 128   # number of records detect() processes per batch

# Configure logging before anything can emit a record. The module-level
# logging.warning() helper calls basicConfig() with default settings when the
# root logger has no handlers, and a later basicConfig() is then a no-op - so
# issuing the missing-key warning first would silently discard the INFO level,
# format and stream chosen here.
logging.basicConfig(
    level=logging.INFO,
    format="[%(levelname)s %(asctime)s] %(message)s",
    datefmt="%H:%M:%S",
    stream=sys.stderr,
)

# OpenRouter API key taken from the project CONFIG mapping. A falsy value only
# produces a warning here; or_chat raises when it is actually needed.
OPENROUTER_KEY = CONFIG['OPENROUTER_KEY']
if not OPENROUTER_KEY:
    logging.warning("OPENROUTER_API_KEY is not set. Set env var before running LLM steps.")

# Headers installed on the Wikipedia session (WIKI).
UA_HEADERS = {"User-Agent": "gen-ambig/0.6 (contact: you@example.com)"}

from utils.prompt import (
    generalize_ambiguity_detection_prompt_template,
    generalize_ambiguity_clarification_prompt_template,
)

# Retry policy and connection-pool adapter shared by both sessions below.
# The policy lists the HTTP status codes (429 and common 5xx) that the adapter
# should retry, with up to 3 retries and a 0.3 back-off factor.
_retry = Retry(total=3, backoff_factor=0.3, status_forcelist=(429, 500, 502, 503, 504))
_adapter = HTTPAdapter(pool_connections=64, pool_maxsize=64, max_retries=_retry)

# Session used by _wiki_get for Wikipedia API calls; carries the User-Agent header.
WIKI = requests.Session()
WIKI.headers.update(UA_HEADERS)
WIKI.mount("https://", _adapter)
WIKI.mount("http://", _adapter)

# Session used by or_chat for OpenRouter calls; carries the bearer token and
# JSON content type. The Authorization header is built once, at import time,
# from whatever OPENROUTER_KEY holds at that moment.
OR = requests.Session()
OR.headers.update({"Authorization": f"Bearer {OPENROUTER_KEY}", "Content-Type": "application/json"})
OR.mount("https://", _adapter)
OR.mount("http://", _adapter)

# Thread-pool size detect() uses for the Wikipedia metrics stage.
WIKI_WORKERS = 16
# Thread-pool size used for LLM calls in detect() and clarify(); also the upper
# bound on the per-query pool created by vote_with_escalation.
LLM_WORKERS = 4

# Token: a run of ASCII letters, digits and apostrophes (used by toks()).
TOKEN_PAT = re.compile(r"[A-Za-z0-9']+")
# Numeric constraint: a YYYY-MM-DD date, a D/M/Y-style date separated by "/" or
# "-", or any run of digits (alternatives are tried in that order).
NUM_PAT = re.compile(r"\d{4}-\d{2}-\d{2}|\d{1,2}[/-]\d{1,2}[/-]\d{2,4}|\d+")
# Quoted constraint: shortest span between straight double quotes, straight
# single quotes, or curly double quotes; the quote characters are part of the match.
# NOTE(review): the single-quote alternative can also match text lying between
# two apostrophes (e.g. possessives) - confirm this is acceptable.
QUOTE_PAT = re.compile(r'".+?"|\'.+?\'|“.+?”')

def toks(txt: str) -> List[str]:
    """Return the lower-cased tokens of *txt*, in order of appearance.

    A token is whatever ``TOKEN_PAT`` matches in the lower-cased text.
    """
    lowered = txt.lower()
    return [match.group(0) for match in TOKEN_PAT.finditer(lowered)]

def kl_divergence(p: Counter, q: Counter) -> float:
    """Return the add-one-smoothed KL divergence of *p* from *q* (natural log).

    Both counters are smoothed over the union of their keys: every token's
    count is incremented by one and each total grows by the vocabulary size,
    so no probability is ever zero.
    """
    vocab = set(p) | set(q)
    smoothing_mass = len(vocab)
    p_total = sum(p.values()) + smoothing_mass
    q_total = sum(q.values()) + smoothing_mass

    def _term(token):
        # Contribution of one token: P(t) * ln(P(t) / Q(t)).
        p_prob = (p[token] + 1) / p_total
        q_prob = (q[token] + 1) / q_total
        return p_prob * math.log(p_prob / q_prob)

    return sum(_term(token) for token in vocab)

# Reference distribution that compute_metrics passes as `q` to kl_divergence;
# it holds a single token ("the") with count 1.
# NOTE(review): this looks like a placeholder background model - confirm.
BACKGROUND_Q = Counter({"the": 1})

def extract_constraints(q: str) -> List[str]:
    """Return the distinct quoted spans and numeric expressions found in *q*.

    Matches of ``QUOTE_PAT`` and ``NUM_PAT`` are de-duplicated and returned
    longest first. Matches of equal length keep their first-seen order
    (quoted spans before numbers, each in text order), so the result is the
    same on every run.
    """
    # dict.fromkeys de-duplicates while keeping insertion order. A set would
    # hand equal-length strings to the stable sort in hash order, which changes
    # between interpreter runs and makes any "first N" slice non-reproducible.
    unique = dict.fromkeys(QUOTE_PAT.findall(q) + NUM_PAT.findall(q))
    return sorted(unique, key=len, reverse=True)

def _wiki_get(params: Dict, tries: int = MAX_RETRY) -> Dict:
    """Run a Wikipedia search-API GET with retries and return the decoded JSON.

    Args:
        params: Query parameters merged over the defaults (``action=query``,
            ``list=search``, snippets and total hit count); caller keys win.
        tries: Maximum number of attempts.

    Returns:
        The decoded JSON body, or ``{}`` when every attempt was answered with
        HTTP 429 (or ``tries`` is 0).

    Raises:
        requests.RequestException: If the final attempt fails with a request error.
    """
    base = {"action": "query", "format": "json", "utf8": 1, "list": "search",
            "srprop": "snippet", "srinfo": "totalhits"}
    merged = {**base, **params}
    for i in range(tries):
        try:
            r = WIKI.get(WIKI_ENDPOINT, params=merged, timeout=20)
            if r.status_code == 429:
                # Retry-After may be absent, or carry an HTTP-date / non-integer
                # value that int() rejects; fall back to exponential back-off
                # instead of letting a ValueError escape the retry loop.
                try:
                    retry_after = max(0, int(r.headers["Retry-After"]))
                except (KeyError, ValueError):
                    retry_after = 2 ** i
                logging.warning(f"Wikipedia 429 – sleeping {retry_after}s (attempt {i+1}/{tries})")
                time.sleep(retry_after)
                continue
            r.raise_for_status()
            return r.json()
        except requests.RequestException as e:
            if i == tries - 1:
                raise
            logging.warning(f"Wikipedia error {e} – retry {i+1}/{tries}")
            time.sleep(2 ** i)
    # Only reached when no attempt produced a usable response; make the empty
    # result visible in the log rather than returning it silently.
    logging.error("Wikipedia request gave no usable response after %d attempts", tries)
    return {}

@lru_cache(maxsize=100_000)
def wiki_search(q: str, srlimit: int = MAX_SRLIMIT) -> Tuple[int, List[str]]:
    """Search Wikipedia for *q* and return ``(total_hits, snippets)``.

    The requested ``srlimit`` is clamped to ``MAX_SRLIMIT``. Snippets have
    their markup tags removed. Results are memoised per ``(q, srlimit)``, and
    each non-cached call sleeps ``REQUEST_SLEEP`` seconds before returning.
    """
    limit = min(srlimit, MAX_SRLIMIT)
    response = _wiki_get({"srsearch": q, "srlimit": limit})
    query_part = response.get("query", {})
    total_hits = query_part.get("searchinfo", {}).get("totalhits", 0)

    snippets = []
    for item in query_part.get("search", []) or []:
        # Drop anything that looks like an <...> tag from the snippet text.
        snippets.append(re.sub(r"<[^>]+>", "", item.get("snippet", "")))

    if REQUEST_SLEEP:
        time.sleep(REQUEST_SLEEP)
    return total_hits, snippets

def compute_metrics(q: str) -> Dict[str, float]:
    """Compute the search-based metrics for query *q*.

    Returns a dict with:
        total_hits: total hit count of the Wikipedia search for *q*.
        kl_divergence: KL divergence of the snippet token counts from
            ``BACKGROUND_Q``.
        relax_delta_ratio: largest ratio of hits obtained after deleting one
            extracted constraint from *q* to the original hits (at least 1.0;
            infinite when the original query has no hits and a relaxed query
            was tried).
    """
    hits, snippets = wiki_search(q, TOP_K)

    snippet_counts = Counter()
    for snippet in snippets:
        snippet_counts.update(toks(snippet))
    kl_val = kl_divergence(snippet_counts, BACKGROUND_Q)

    best_ratio = 1.0
    candidates = extract_constraints(q)[:MAX_RELAX_CAND]
    for constraint in candidates:
        relaxed_q = q.replace(constraint, "").strip()
        if not relaxed_q:
            # Removing the constraint left nothing to search for.
            continue
        relaxed_hits, _ = wiki_search(relaxed_q, 1)
        if hits == 0:
            ratio = float("inf")
        else:
            ratio = relaxed_hits / max(hits, 1)
        best_ratio = max(best_ratio, ratio)

    return {"total_hits": hits, "kl_divergence": kl_val, "relax_delta_ratio": best_ratio}

def _safe_json(r: requests.Response) -> Dict:
    """Decode the JSON body of *r*, returning ``{}`` when it is not valid JSON.

    On a decode failure the first 200 characters of the body are logged.
    """
    try:
        parsed = r.json()
    except ValueError:
        logging.error("Non-JSON response from OpenRouter: %s", r.text[:200])
        parsed = {}
    return parsed

# In-process memo of or_chat responses, keyed by (model, JSON-serialised
# messages, max tokens, whether a schema was supplied). Nothing in the visible
# code evicts entries.
LLM_CACHE: Dict[Tuple, str] = {}

def or_chat(messages, model: str, maxtok: int = 50, schema: Dict | None = None) -> str:
    """Send a chat completion to OpenRouter, with retries and an in-process cache.

    Args:
        messages: Chat messages (list of ``{"role", "content"}`` dicts).
        model: OpenRouter model identifier.
        maxtok: ``max_tokens`` sent with the request.
        schema: When truthy, JSON-object output is requested via
            ``response_format``; the schema content itself is not sent.

    Returns:
        The ``content`` of the first choice in the response.

    Raises:
        RuntimeError: If no API key is configured, if the last attempt's
            response carried an ``error`` field, or if every attempt was
            answered with HTTP 429.
        requests.RequestException: If the last attempt failed at HTTP level.
        KeyError: If the last attempt's response had no ``choices``.
    """
    if not OPENROUTER_KEY:
        raise RuntimeError("OPENROUTER_API_KEY missing in environment")

    # Consult the cache first so a hit does not build a payload it never sends.
    ck = (model, json.dumps(messages, ensure_ascii=False, sort_keys=True), maxtok, bool(schema))
    if ck in LLM_CACHE:
        return LLM_CACHE[ck]

    payload = {"model": model, "temperature": 0.0, "messages": messages, "max_tokens": maxtok}
    if schema:
        payload["response_format"] = {"type": "json_object"}

    for attempt in range(MAX_RETRY):
        try:
            r = OR.post(OR_URL, data=json.dumps(payload), timeout=90)
            if r.status_code == 429:
                # Retry-After may be absent, or carry an HTTP-date / non-integer
                # value that int() rejects. A ValueError is not in the except
                # tuple below, so it would abort the retry loop; fall back to
                # exponential back-off instead.
                try:
                    retry_after = max(0, int(r.headers["Retry-After"]))
                except (KeyError, ValueError):
                    retry_after = 2 ** attempt
                logging.warning(f"OpenRouter 429 – sleeping {retry_after}s (attempt {attempt+1}/{MAX_RETRY})")
                time.sleep(retry_after)
                continue
            r.raise_for_status()
            data = _safe_json(r)
            if "choices" in data and data["choices"]:
                out = data["choices"][0]["message"]["content"]
                LLM_CACHE[ck] = out
                return out
            if "error" in data:
                raise RuntimeError(f"OpenRouter error: {data['error']}")
            raise KeyError("'choices' missing in response JSON")
        except (requests.RequestException, KeyError, RuntimeError) as e:
            if attempt == MAX_RETRY - 1:
                raise
            # Exponential back-off with jitter before the next attempt.
            backoff = 2 ** attempt + random.random()
            logging.warning("or_chat failed (%s) – retrying in %.1fs", e, backoff)
            time.sleep(backoff)
    # Reached only when every attempt ended in a 429 `continue`.
    raise RuntimeError("or_chat failed after maximum retries")

def _strip_json(txt: str) -> str:
    txt = re.sub(r"```(?:json)?\s*|\s*```", "", txt, flags=re.I).strip()
    m = re.search(r"\{.*\}", txt, flags=re.S)
    return m.group(0) if m else "{}"

def llm_vote(model: str, q: str, metrics: Dict[str, float]) -> str:
    """Ask *model* whether query *q* is ambiguous and return its vote.

    The prompt is built from the query and its metrics. The reply's first
    ``{...}`` span is parsed as JSON and the value under ``"is_ambiguous"``
    is returned as-is. ``"N"`` is returned when no such span exists or when
    the call or the parsing raises; failures are logged.
    """
    prompt = generalize_ambiguity_detection_prompt_template(query=q, **metrics)
    messages = [
        {"role": "system", "content": "Return ONLY a JSON with key 'is_ambiguous'."},
        {"role": "user", "content": prompt},
    ]
    try:
        raw = or_chat(messages, model)
        found = re.search(r"\{.*\}", raw, flags=re.S)
        if not found:
            return "N"
        return json.loads(found.group(0))["is_ambiguous"]
    except Exception as e:
        logging.error("LLM vote failed for model %s – treating as 'N' (%s)", model, e)
        return "N"

def vote_one(model: str, q: str, metrics: Dict[str, float]) -> str:
    """Return *model*'s vote for query *q*; a direct pass-through to ``llm_vote``."""
    vote = llm_vote(model, q, metrics)
    return vote

def vote_with_escalation(q: str, metrics: Dict[str, float]) -> Dict[str, str]:
    """Collect model votes for *q*, escalating only when the first model says "Y".

    ``MODELS[0]`` votes first. If its vote is not ``"Y"``, only that vote is
    returned. Otherwise the remaining models vote concurrently and all votes
    are returned, keyed by model name in ``MODELS`` order. A model whose vote
    raises is recorded as ``"N"``.
    """
    first, rest = MODELS[0], MODELS[1:]
    votes: Dict[str, str] = {first: vote_one(first, q, metrics)}

    # Stop when the gate model does not flag the query, or when there is no
    # other model to ask (ThreadPoolExecutor rejects max_workers=0).
    if votes[first] != "Y" or not rest:
        return votes

    with ThreadPoolExecutor(max_workers=min(len(rest), LLM_WORKERS)) as ex:
        submitted = [(m, ex.submit(vote_one, m, q, metrics)) for m in rest]
        # Read results in submission order so the dict's key order is the same
        # on every run; the calls themselves still run concurrently.
        for m, f in submitted:
            try:
                votes[m] = f.result()
            except Exception as e:
                logging.error("LLM vote failed for model %s – treating as 'N' (%s)", m, e)
                votes[m] = "N"
    return votes

def detect(records: List[Dict], sample_size: int | None = None):
    """Yield one ambiguity-detection result per input record.

    Records are processed in batches of ``BATCH``. For each batch, Wikipedia
    metrics are computed concurrently, then the LLM votes are collected
    concurrently. Results are yielded batch by batch, sorted by ``qid``
    within each batch.

    Args:
        records: Dicts with at least ``"id"`` and ``"question"`` keys.
        sample_size: When truthy, process a random sample of at most this
            many records instead of all of them.

    Yields:
        Dicts with keys ``qid``, ``original_query``, ``metrics``, ``models``
        (per-model votes) and ``is_ambiguous`` (True only when every model in
        ``MODELS`` voted ``"Y"``).
    """
    if sample_size:
        available = len(records)
        records = random.sample(records, min(sample_size, available))
        # Log the pre-sampling size; re-measuring `records` after the
        # reassignment would print the sample size twice.
        logging.info("Subsampling %d / %d records", len(records), available)

    total = len(records)
    # The context manager closes the bar even when the consumer stops iterating
    # early or a batch raises.
    with tqdm(total=total, desc="Ambiguity detection", unit="q") as pbar:
        for bi in range(0, total, BATCH):
            batch = records[bi:bi + BATCH]

            # Stage 1: search-based metrics, keyed by record id.
            metrics_map: Dict[str, Dict] = {}
            with ThreadPoolExecutor(max_workers=WIKI_WORKERS) as ex:
                m_fut = {ex.submit(compute_metrics, rec["question"]): rec for rec in batch}
                for f in as_completed(m_fut):
                    rec = m_fut[f]
                    try:
                        metrics_map[rec["id"]] = f.result()
                    except Exception as e:
                        # Fall back to neutral metrics so the record still gets voted on.
                        logging.error("metrics failed for %s: %s", rec["id"], e)
                        metrics_map[rec["id"]] = {"total_hits": 0, "kl_divergence": 0.0, "relax_delta_ratio": 1.0}

            # Stage 2: LLM votes.
            out_batch = []
            with ThreadPoolExecutor(max_workers=LLM_WORKERS) as ex:
                v_fut = {ex.submit(vote_with_escalation, rec["question"], metrics_map[rec["id"]]): rec for rec in batch}
                for f in as_completed(v_fut):
                    rec = v_fut[f]
                    try:
                        model_votes = f.result()
                    except Exception as e:
                        logging.error("vote error %s: %s", rec["id"], e)
                        model_votes = {m: "N" for m in MODELS}

                    # A model with no recorded vote counts as "N".
                    is_ambig = all(model_votes.get(m, "N") == "Y" for m in MODELS)
                    out_batch.append(
                        {
                            "qid": rec["id"],
                            "original_query": rec["question"],
                            "metrics": metrics_map[rec["id"]],
                            "models": model_votes,
                            "is_ambiguous": is_ambig,
                        }
                    )

            out_batch.sort(key=lambda r: r["qid"])
            yield from out_batch

            pbar.update(len(batch))

def generate_clarified(q: str, model: str = "openai/gpt-4.1") -> List[str]:
    """Ask *model* for clarified rewrites of query *q*.

    The reply is parsed as JSON; if that fails, the first ``{...}`` span
    (code fences stripped) is parsed instead. Returns the value stored under
    ``"clarified_queries"``.

    Raises:
        KeyError: If the parsed reply has no ``"clarified_queries"`` key.
        json.JSONDecodeError: If the fallback parse fails too.
    """
    queries_spec = {
        "type": "array",
        "items": {"type": "string"},
        "minItems": MIN_VERS,
    }
    schema = {
        "type": "object",
        "properties": {"clarified_queries": queries_spec},
        "required": ["clarified_queries"],
        "additionalProperties": False,
    }
    user_prompt = generalize_ambiguity_clarification_prompt_template(query=q, min_versions=MIN_VERS)
    messages = [
        {"role": "system", "content": "Return ONLY valid JSON."},
        {"role": "user", "content": user_prompt},
    ]
    raw = or_chat(messages, model, maxtok=256, schema=schema)

    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        # The reply was not bare JSON; retry on the extracted {...} span.
        parsed = json.loads(_strip_json(raw))
    return parsed["clarified_queries"]

def clarify(detect_recs: List[Dict]) -> List[Dict]:
    """Generate clarified queries for every record flagged as ambiguous.

    Records whose ``"is_ambiguous"`` value is truthy are processed
    concurrently. Records whose generation raises are logged and omitted.
    The result is sorted by ``qid``; each entry holds ``qid``,
    ``original_query`` and ``clarified_queries``.
    """

    def _one(rec: Dict) -> Dict | None:
        # Never raises: a failure is logged and reported as None.
        try:
            cqs = generate_clarified(rec["original_query"])
        except Exception as e:
            logging.error("[%s] clarify error: %s", rec["qid"], e)
            return None
        return {"qid": rec["qid"], "original_query": rec["original_query"], "clarified_queries": cqs}

    targets = [rec for rec in detect_recs if rec["is_ambiguous"]]
    results: List[Dict] = []

    with ThreadPoolExecutor(max_workers=LLM_WORKERS) as ex:
        pending = [ex.submit(_one, rec) for rec in targets]
        progress = tqdm(as_completed(pending), total=len(pending), desc="Clarifying queries", unit="q")
        for done in progress:
            item = done.result()
            if item:
                results.append(item)

    results.sort(key=lambda r: r["qid"])
    return results

def load_jsonl(fp: str) -> List[Dict]:
    """Read the UTF-8 JSON Lines file *fp* and return its parsed rows.

    Lines that are empty or contain only whitespace are skipped.
    """
    rows = []
    with open(fp, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            rows.append(json.loads(line))
    return rows

def save_jsonl(rows: List[Dict], fp: str):
    """Write *rows* to *fp* as UTF-8 JSON Lines, one object per line.

    Missing parent directories are created. Non-ASCII characters are written
    as-is rather than escaped. An existing file is overwritten.
    """
    parent = os.path.dirname(fp)
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError - only create directories when there is one to create.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(fp, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(r, ensure_ascii=False) + "\n" for r in rows)

def main():
    """Command-line entry point: run detection, clarification, or both.

    ``--step detect`` writes the detection results; ``--step clarify`` reads
    previously saved detection results and writes the clarified set;
    ``--step all`` (default) does both in sequence.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--step", choices=["detect", "clarify", "all"], default="all")
    parser.add_argument("--in", dest="inp", default="download/raw_data/musique/musique_full_v1.0_train.jsonl")
    parser.add_argument("--det_out", default="dataset/train_general_ambiguity_detected.jsonl")
    parser.add_argument("--cl_out", default="dataset/train_general_ambiguity_clarified.jsonl")
    args = parser.parse_args()

    run_detect = args.step in ("detect", "all")
    run_clarify = args.step in ("clarify", "all")

    detected: List[Dict]
    if run_detect:
        detected = list(detect(load_jsonl(args.inp)))
        save_jsonl(detected, args.det_out)
        logging.info("Detection saved → %s (%d records)", args.det_out, len(detected))
    else:
        # Clarify-only run: reuse the detection output from an earlier run.
        detected = load_jsonl(args.det_out)

    if run_clarify:
        clarified = clarify(detected)
        save_jsonl(clarified, args.cl_out)
        logging.info("Clarified set saved → %s (%d records)", args.cl_out, len(clarified))

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C: log a warning instead of letting the traceback propagate.
        logging.warning("Interrupted by user – exiting")