import os
import re
import sys
import json
import argparse
from typing import Dict, Any, Optional, List, Tuple

import pandas as pd
from tqdm import tqdm


# Heuristic regex fragments hinting that a title/abstract describes a survey,
# review, overview, or tutorial. They are OR-ed together into SURVEY_REGEX.
# Note: the bare "survey" / "review" alternatives already match any text that
# the longer phrases containing those words would match.
SURVEY_PATTERNS = [
    r"\bsurvey\b",  # the word "survey" anywhere
    r"\breview\b",  # the word "review" anywhere
    r"\b(comprehensive|systematic)\s+(survey|review)\b",
    r"\ban?\s+overview\b",  # "a overview" / "an overview"
    r"\ba\s+survey\s+of\b",
    r"\bliterature review\b",
    r"\btutorial\b",
]
# One compiled alternation of every pattern above, matched case-insensitively.
SURVEY_REGEX = re.compile("|".join(SURVEY_PATTERNS), flags=re.IGNORECASE)


def get_out_prefix_from_submissions_path(submissions_csv: str) -> str:
    """Derive the default output prefix from the submissions CSV path.

    A trailing ``.submissions.csv`` is removed when present (otherwise the
    path is used unchanged), and ``.filtered`` is appended to the result.
    """
    suffix = ".submissions.csv"
    has_suffix = submissions_csv.endswith(suffix)
    stem = submissions_csv[: -len(suffix)] if has_suffix else submissions_csv
    return f"{stem}.filtered"


# ---------------- OpenAI-compatible client ----------------
def _first_nonblank_text(getters) -> str:
    """Return the stripped value of the first getter yielding a non-blank string.

    Each zero-argument callable in *getters* is tried in order. A getter that
    raises, or that returns anything other than a non-blank ``str``, is
    skipped. Returns ``""`` when no getter qualifies.
    """
    for getter in getters:
        try:
            value = getter()
        except Exception:
            # Deliberately best-effort: the response object may lack an
            # attribute or hold an unexpected type at this location, which
            # simply means "try the next place".
            continue
        if isinstance(value, str) and value.strip():
            return value.strip()
    return ""


def _text_from_dumped_response(resp: Any) -> str:
    """Extract message text from the plain-dict dump of *resp*.

    Uses ``resp.model_dump()``, falling back to parsing
    ``resp.model_dump_json()``. The first choice's ``message.content`` may be
    a string or a list of parts; for a list, the first part whose ``text`` or
    ``content`` entry is a non-blank string wins. Parts that are not dicts are
    skipped. May raise on unexpected shapes; callers treat that as "no text".
    """
    try:
        dumped = resp.model_dump()
    except Exception:
        dumped = json.loads(resp.model_dump_json())
    if not dumped:
        return ""

    part = dumped.get("choices", [{}])[0].get("message", {}).get("content")
    if isinstance(part, str):
        return part
    if isinstance(part, list):
        for piece in part:
            if not isinstance(piece, dict):
                continue
            txt = piece.get("text") or piece.get("content")
            if isinstance(txt, str) and txt.strip():
                return txt
    return ""


def call_openai(model: str, base_url: str, api_key: str, messages: List[Dict[str, str]], temperature: float = 0.0) -> str:
    """Send a chat-completion request and return the reply text.

    Args:
        model: Model name passed to the chat-completions endpoint.
        base_url: Base URL of the OpenAI-compatible API.
        api_key: API key; must be non-empty.
        messages: Chat messages (dicts with ``role`` and ``content``).
        temperature: Sampling temperature forwarded to the API.

    Returns:
        The stripped reply text, or ``""`` if no text could be located in the
        response. Locations are probed in this order: the first choice's
        ``message.content``, its ``message.reasoning_content``, the choice's
        ``text`` attribute, and finally the dict dump of the response.

    Raises:
        RuntimeError: If *api_key* is empty.
    """
    # Validate before importing the SDK so a missing key is reported right
    # away and is never masked by an import problem.
    if not api_key:
        raise RuntimeError("OPENAI_API_KEY is required (or pass --api_key)")

    from openai import OpenAI

    client = OpenAI(api_key=api_key, base_url=base_url)
    resp = client.chat.completions.create(model=model, messages=messages, temperature=temperature)

    # Try to extract text content robustly across providers.
    return _first_nonblank_text((
        lambda: resp.choices[0].message.content,
        lambda: getattr(resp.choices[0].message, "reasoning_content", ""),
        lambda: getattr(resp.choices[0], "text", ""),
        lambda: _text_from_dumped_response(resp),
    ))


def safe_json(text: str) -> Dict[str, Any]:
    """Parse an LLM reply into a dict, tolerating surrounding non-JSON text.

    Parsing attempts, in order (the first one that yields a ``dict`` wins):
      1. the span from the first ``{`` to the last ``}``;
      2. the contents of a ```` ```json ```` fenced block, if present;
      3. the whole text;
      4. the first JSON object that can be decoded starting at any ``{``
         (rescues replies where a valid object is followed by text that
         itself contains braces).

    Args:
        text: Raw model output; ``None``/empty is treated as empty.

    Returns:
        The parsed dict, or a fallback
        ``{"is_survey": False, "reason": ..., "confidence": 0.0}`` whose
        ``reason`` is ``"empty LLM output"`` for empty input and otherwise the
        first 300 characters of the text.
    """
    text = (text or "").strip()
    if not text:
        return {"is_survey": False, "reason": "empty LLM output", "confidence": 0.0}

    candidates: List[str] = []
    first_brace = text.find("{")
    last_brace = text.rfind("}")
    # Only slice when both braces exist in a sensible order; otherwise the
    # slice would be empty or garbage.
    if first_brace != -1 and last_brace > first_brace:
        candidates.append(text[first_brace: last_brace + 1])
    if "```json" in text:
        candidates.append(text.split("```json", 1)[1].split("```", 1)[0].strip())
    candidates.append(text)

    for candidate in candidates:
        try:
            obj = json.loads(candidate)
        except (ValueError, RecursionError):
            continue
        if isinstance(obj, dict):
            return obj

    # Last resort: decode the first complete object starting at some "{",
    # ignoring whatever follows it.
    decoder = json.JSONDecoder()
    pos = first_brace
    while pos != -1:
        try:
            obj, _ = decoder.raw_decode(text, pos)
        except (ValueError, RecursionError):
            obj = None
        if isinstance(obj, dict):
            return obj
        pos = text.find("{", pos + 1)

    return {"is_survey": False, "reason": text[:300], "confidence": 0.0}


def build_messages(title: str, abstract: str) -> List[Dict[str, str]]:
    """Build the chat messages asking the model to classify one paper.

    Returns a two-element list: a system message with the classification
    instructions and required JSON output format, followed by a user message
    carrying the title, the abstract (or ``(missing)`` when it is empty), and
    a short definition note.
    """
    system_content = "".join((
        "You are an expert librarian. Decide if the paper is a SURVEY (systematic review, literature review, overview, tutorial) ",
        "based ONLY on title and abstract. If abstract is missing, rely on title. ",
        "If uncertain or ambiguous, answer not a survey. Output ONLY compact JSON with keys: ",
        "is_survey (boolean), reason (string, <=160 chars), confidence (0..1).",
    ))

    shown_abstract = abstract if abstract else "(missing)"
    user_lines = [
        f"Title: {title}",
        f"Abstract: {shown_abstract}",
        "",  # blank line separating the paper from the definition note
        "Definition notes: Survey/Review/Overview/Tutorial synthesize prior work; not proposing a new primary method/dataset as the main contribution.",
    ]
    user_content = "\n".join(user_lines)

    messages = []
    for role, content in (("system", system_content), ("user", user_content)):
        messages.append({"role": role, "content": content})
    return messages


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the survey filter.

    ``--submissions_csv`` and ``--reviews_csv`` are required. The model, base
    URL, and API key default to the ``OPENAI_MODEL``, ``OPENAI_BASE_URL``, and
    ``OPENAI_API_KEY`` environment variables respectively.
    """
    parser = argparse.ArgumentParser(
        description="Filter true surveys using an LLM over title/abstract.",
    )

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        ("--submissions_csv", {"required": True, "help": "Path to *.submissions.csv from openreview_re.py"}),
        ("--reviews_csv", {"required": True, "help": "Path to *.reviews.csv from openreview_re.py"}),
        ("--out_prefix", {"default": None, "help": "Output prefix; default: derive from submissions path + '.filtered'"}),
        ("--model", {"default": os.environ.get("OPENAI_MODEL", "gpt-4o-mini")}),
        ("--base_url", {"default": os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")}),
        ("--api_key", {"default": os.environ.get("OPENAI_API_KEY", "")}),
        ("--temperature", {"type": float, "default": 0.0}),
        ("--min_confidence", {"type": float, "default": 0.5, "help": "Minimum LLM confidence to keep as survey"}),
        ("--skip_if_regex_hits", {"action": "store_true", "help": "If heuristic regex says survey, skip LLM (keep)"}),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()


def _coerce_bool(value: Any) -> bool:
    """Interpret an LLM-provided flag as a boolean.

    Strings are matched case-insensitively against a few affirmative
    spellings, because plain ``bool("false")`` is ``True``. Any other value
    uses normal truthiness.
    """
    if isinstance(value, str):
        return value.strip().lower() in ("true", "yes", "y", "1")
    return bool(value)


def _coerce_confidence(value: Any) -> float:
    """Convert an LLM-provided confidence to ``float``; unusable values give 0.0."""
    try:
        return float(value or 0.0)
    except (TypeError, ValueError):
        # e.g. "high" or a nested structure: treat as no confidence rather
        # than aborting the whole run.
        return 0.0


def _write_jsonl(path: str, records) -> None:
    """Write each record of *records* to *path* as one JSON object per line.

    Uses ``json.dumps`` defaults apart from ``ensure_ascii=False``.
    NOTE(review): with those defaults a float NaN is written as the bare
    token ``NaN``, which strict JSON parsers reject — confirm downstream
    readers tolerate it.
    """
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def main():
    """Classify submissions as surveys and write the filtered outputs.

    Reads the submissions and reviews CSVs named on the command line,
    classifies every submission (optionally short-circuiting on the survey
    regex, otherwise by asking the LLM), and keeps the forums judged to be
    surveys with confidence at or above ``--min_confidence``. Writes, under
    the output prefix: filtered submissions and reviews as CSV and JSONL, plus
    a ``.classification.jsonl`` audit log with one entry per submission.
    """
    args = parse_args()

    out_prefix = args.out_prefix or get_out_prefix_from_submissions_path(args.submissions_csv)

    subs = pd.read_csv(args.submissions_csv)
    revs = pd.read_csv(args.reviews_csv)

    # Normalize fields: make sure title/abstract exist and are plain strings.
    for col in ["title", "abstract"]:
        if col not in subs.columns:
            subs[col] = ""
        subs[col] = subs[col].fillna("").astype(str)

    # Classify each submission via LLM
    results: List[Dict[str, Any]] = []
    keep_forums: List[str] = []

    print(f"[INFO] Classifying {len(subs)} submissions with model={args.model} base_url={args.base_url}")

    for _, row in tqdm(subs.iterrows(), total=len(subs), desc="classify"):
        forum_id = str(row.get("forum_id", ""))
        title = str(row.get("title", ""))
        abstract = str(row.get("abstract", ""))

        # Optional fast-path: if regex clearly says survey, accept without LLM
        if args.skip_if_regex_hits and SURVEY_REGEX.search(title + "\n" + abstract):
            results.append({
                "forum_id": forum_id,
                "is_survey": True,
                "confidence": 0.9,
                "reason": "regex:survey",
                "raw": {"mode": "regex"},
            })
            keep_forums.append(forum_id)
            continue

        msgs = build_messages(title=title, abstract=abstract)
        out_text = call_openai(args.model, args.base_url, args.api_key, msgs, temperature=args.temperature)
        parsed = safe_json(out_text)

        # The model's JSON is untrusted: values may arrive as strings or with
        # unexpected types, so coerce them defensively.
        is_survey = _coerce_bool(parsed.get("is_survey", False))
        confidence = _coerce_confidence(parsed.get("confidence", 0.0))
        reason = str(parsed.get("reason", "")).strip()

        if is_survey and confidence >= args.min_confidence:
            keep_forums.append(forum_id)

        results.append({
            "forum_id": forum_id,
            "is_survey": is_survey,
            "confidence": confidence,
            "reason": reason,
            "raw": parsed,
        })

    # Empty forum ids are dropped so they cannot match rows by accident.
    keep_set = {forum for forum in keep_forums if forum}
    subs_f = subs[subs["forum_id"].astype(str).isin(keep_set)].reset_index(drop=True)
    revs_f = revs[revs["forum_id"].astype(str).isin(keep_set)].reset_index(drop=True)

    # Write outputs (CSV + JSONL)
    out_sub_csv = f"{out_prefix}.submissions.csv"
    out_rev_csv = f"{out_prefix}.reviews.csv"
    out_sub_jsonl = f"{out_prefix}.submissions.jsonl"
    out_rev_jsonl = f"{out_prefix}.reviews.jsonl"
    out_audit = f"{out_prefix}.classification.jsonl"

    subs_f.to_csv(out_sub_csv, index=False)
    revs_f.to_csv(out_rev_csv, index=False)

    _write_jsonl(out_sub_jsonl, (r.to_dict() for _, r in subs_f.iterrows()))
    _write_jsonl(out_rev_jsonl, (r.to_dict() for _, r in revs_f.iterrows()))
    _write_jsonl(out_audit, results)

    print(f"[DONE] kept submissions: {len(subs_f)} / {len(subs)} -> {out_sub_csv}")
    print(f"[DONE] kept reviews    : {len(revs_f)} / {len(revs)} -> {out_rev_csv}")
    print(f"[DONE] audit log       : {out_audit}")


# Script entry point: run main() and translate failures into exit statuses.
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C: print a short notice to stdout and exit with status 130.
        print("\n[INTERRUPTED]")
        sys.exit(130)
    except Exception as e:
        # Any other error: print only its message to stderr (no traceback)
        # and exit with status 1.
        print(f"[ERROR] {e}", file=sys.stderr)
        sys.exit(1)


