import re
import json
import argparse
from typing import List, Dict, Any, Optional
from datetime import datetime
from tqdm import tqdm
import pandas as pd
from tabulate import tabulate

import openreview
import sys

def get_val(x):
    """Unwrap an OpenReview v2 content field.

    API v2 stores content fields as ``{"value": ...}``; when ``x`` is a dict
    carrying a ``"value"`` key, that inner value is returned. Anything else
    (plain values, dicts without ``"value"``, ``None``) is returned unchanged.
    """
    if not isinstance(x, dict):
        return x
    return x["value"] if "value" in x else x

def cget(content, key, default=""):
    """Fetch ``key`` from a note's ``content`` mapping, unwrapping v2 values.

    A falsy ``content`` (e.g. ``None``) is treated as an empty mapping, so
    ``default`` is returned. The looked-up value (or ``default``) is passed
    through ``get_val`` to strip any ``{"value": ...}`` wrapper.
    """
    mapping = content if content else {}
    raw = mapping.get(key, default)
    return get_val(raw)

def normalize(s: Optional[str]) -> str:
    """Return ``s`` with surrounding whitespace removed; falsy input yields ``""``."""
    if not s:
        return ""
    return s.strip()

# ---------------- Survey ----------------
# Alternatives OR-ed together (case-insensitively) to flag survey-like text.
# NOTE(review): the bare r"\bsurvey\b" and r"\breview\b" patterns already
# match everything the "comprehensive|systematic", "a survey of" and
# "literature review" patterns do, and r"\breview\b" also fires on any text
# that merely mentions "review" (e.g. "peer review"), so this filter is very
# permissive. Plural forms ("surveys", "reviews") are NOT matched.
SURVEY_PATTERNS = [
    r"\bsurvey\b",
    r"\breview\b",
    r"\b(comprehensive|systematic)\s+(survey|review)\b",
    r"\ban?\s+overview\b",
    r"\ba\s+survey\s+of\b",
    r"\bliterature review\b",
]
SURVEY_REGEX = re.compile("|".join(SURVEY_PATTERNS), re.IGNORECASE)

def is_survey_like(title: str, abstract: str = "") -> bool:
    """Return True if the title or abstract matches any of ``SURVEY_PATTERNS``.

    Title and abstract are joined with a newline and searched with the
    compiled ``SURVEY_REGEX``; a falsy ``abstract`` is treated as empty.
    """
    haystack = f"{title}\n{abstract or ''}"
    match = SURVEY_REGEX.search(haystack)
    return match is not None

# Case-insensitive substrings that mark a decision/venue string as "accepted"
# when --decision_keywords is not given.
# NOTE(review): "conference" is very broad — it also occurs in venue ids and
# in many non-acceptance venue strings.
DEFAULT_DECISION_KEYWORDS = [
    "accept",
    "accepted",
    "paper accepted",
    "oral",
    "spotlight",
    "poster",
    "conference",
    "recommend accept",
]

def extract_decision_string(note) -> str:
    """Return the first non-empty decision-like field of ``note`` as a string.

    Content fields are tried in priority order (decision fields first, then
    venue / venueid). If none is set, the note's own ``venue`` or ``venueid``
    attribute is used, falling back to ``""``.
    """
    content = note.content or {}
    candidate_keys = (
        "decision",
        "acceptance decision",
        "final_decision",
        "venue",
        "venueid",
        "submission_decision",
    )
    # Lazily look each key up and stop at the first truthy value.
    hit = next(
        (val for val in (cget(content, key, None) for key in candidate_keys) if val),
        None,
    )
    if hit is not None:
        return str(hit)
    fallback = getattr(note, "venue", "") or getattr(note, "venueid", "")
    return str(fallback)

def is_accepted_submission(
    note,
    decision_keywords: List[str],
    rejection_keywords: Optional[List[str]] = None,
) -> bool:
    """Heuristically decide whether ``note`` is an accepted submission.

    The decision string from ``extract_decision_string`` is lower-cased. It is
    rejected outright if it contains any of ``rejection_keywords``; otherwise it
    is accepted if it contains any of ``decision_keywords``. Both checks are
    case-insensitive substring matches.

    Args:
        note: Submission note.
        decision_keywords: Substrings that indicate acceptance.
        rejection_keywords: Substrings that veto acceptance even when a
            decision keyword also matches. Defaults to ``["reject", "withdraw"]``;
            pass ``[]`` to disable the veto (the previous behavior).

    Returns:
        True if the note looks accepted, False otherwise (including when no
        decision string is available).
    """
    dec = extract_decision_string(note).lower()
    if not dec:
        return False
    if rejection_keywords is None:
        rejection_keywords = ["reject", "withdraw"]
    # Broad acceptance keywords (e.g. the default "conference") also occur in
    # strings for non-accepted papers, e.g. a venueid ending in
    # "/Rejected_Submission" or venue strings for withdrawn / desk-rejected
    # papers, so explicit negative markers take precedence.
    if any(k.lower() in dec for k in rejection_keywords):
        return False
    return any(k.lower() in dec for k in decision_keywords)

# ---------------- OpenReview ----------------
def get_client(baseurl: str):
    """Build an OpenReview API v2 client pointed at ``baseurl``.

    For the legacy API, ``openreview.Client`` with https://api.openreview.net
    would be used instead.
    """
    client_cls = openreview.api.OpenReviewClient
    return client_cls(baseurl=baseurl)

# Case-insensitive substrings of an invitation id that mark a note as a
# review, meta-review, decision or comment. ("Review" alone already covers
# the *_Review variants.)
REVIEW_INVITATION_HINTS = [
    "Official_Review",
    "Paper_Review",
    "Review",
    "Meta_Review",
    "Decision",
    "Public_Comment",
    "Official_Comment",
]

def looks_like_review_invitation(inv: str) -> bool:
    """Return True if ``inv`` contains any ``REVIEW_INVITATION_HINTS`` entry (case-insensitive).

    A falsy ``inv`` is treated as the empty string and never matches.
    """
    lowered = (inv or "").lower()
    for hint in REVIEW_INVITATION_HINTS:
        if hint.lower() in lowered:
            return True
    return False

def is_public_readers(readers) -> bool:
    """Return True if a note's ``readers`` list marks it as visible to us.

    A note counts as public when ``"everyone"`` is a reader, or when any reader
    id contains ``"openreview.net"`` (case-insensitive). ``None`` is treated as
    an empty reader list.
    """
    reader_list = readers or []
    # Joined first, as before, so malformed (non-str) entries still raise.
    haystack = " ".join(reader_list).lower()
    if "everyone" in reader_list:
        return True
    return "openreview.net" in haystack

def get_note_invitation_string(note) -> str:
    """Return a string describing the note's invitation(s), for substring matching.

    A truthy ``invitation`` attribute wins and is returned via ``str``.
    Otherwise a non-empty ``invitations`` *list* is comma-joined. In every
    other case the result is ``""``.
    """
    single = getattr(note, "invitation", None)
    if single:
        return str(single)
    many = getattr(note, "invitations", None)
    if not isinstance(many, list) or not many:
        return ""
    return ",".join(map(str, many))

def list_public_children(client, forum_id: str) -> List[Any]:
    """Return the public review/comment-like notes found in forum ``forum_id``.

    Every note of the forum is fetched via ``openreview.tools.iterget_notes``.
    A note is kept when its readers pass ``is_public_readers`` AND it either
    has a review-like invitation or is a reply (truthy ``replyto``).
    """
    return [
        note
        for note in openreview.tools.iterget_notes(client, forum=forum_id)
        if is_public_readers(note.readers)
        and (looks_like_review_invitation(get_note_invitation_string(note)) or note.replyto)
    ]

# def search_survey_submissions(
#     client,
#     venue_ids: Optional[List[str]] = None,
#     invitations: Optional[List[str]] = None,
#     since_ms: Optional[int] = None,
#     accepted_only: bool = False,
#     decision_keywords: Optional[List[str]] = None,
# ) -> List[Any]:
#     if not venue_ids and not invitations:
#         raise ValueError("Must provide --venues or --invitations")

#     params: Dict[str, Any] = {}
#     if venue_ids:
#         params["content"] = {"venueid": venue_ids}
#     if invitations:
#         params["invitations"] = invitations
#     if since_ms:
#         params["mintcdate"] = since_ms

#     decision_keywords = decision_keywords or DEFAULT_DECISION_KEYWORDS
#     hits = []

#     for note in openreview.tools.iterget_notes(client, **params):
#         title = normalize(cget(note.content, "title", ""))
#         if not title:
#             continue
#         abstract = normalize(cget(note.content, "abstract", ""))

#         if not is_survey_like(title, abstract):
#             continue
#         if accepted_only and not is_accepted_submission(note, decision_keywords):
#             continue

#         hits.append(note)
#     return hits
def _is_survey_candidate(note, accepted_only: bool, decision_keywords: List[str]) -> bool:
    """Return True if ``note`` has a title, looks survey-like and passes the acceptance filter."""
    title = normalize(cget(note.content, "title", ""))
    if not title:
        return False
    abstract = normalize(cget(note.content, "abstract", ""))
    if not is_survey_like(title, abstract):
        return False
    return not accepted_only or is_accepted_submission(note, decision_keywords)


def search_survey_submissions(
    client,
    venue_ids: Optional[List[str]] = None,
    invitations: Optional[List[str]] = None,
    since_ms: Optional[int] = None,
    accepted_only: bool = False,
    decision_keywords: Optional[List[str]] = None,
) -> List[Any]:
    """Collect survey-like submissions from the given venues and/or invitations.

    One query is issued per venue id (filtering on ``content.venueid``) and per
    invitation id (filtering on ``invitation``). Results are merged in that
    order (venues first) and de-duplicated by forum id, keeping the first hit.

    Args:
        client: OpenReview client passed to ``openreview.tools.iterget_notes``.
        venue_ids: Venue ids to match against ``content.venueid``.
        invitations: Submission invitation ids.
        since_ms: If truthy, passed as ``mintcdate`` (epoch milliseconds) so
            only notes created at or after that time are returned.
        accepted_only: Keep only notes accepted by ``is_accepted_submission``.
        decision_keywords: Acceptance keywords; defaults to
            ``DEFAULT_DECISION_KEYWORDS`` when empty/None.

    Returns:
        The matching submission notes.

    Raises:
        ValueError: If neither ``venue_ids`` nor ``invitations`` is given.
    """
    if not venue_ids and not invitations:
        raise ValueError("Must provide --venues or --invitations")

    decision_keywords = decision_keywords or DEFAULT_DECISION_KEYWORDS

    base: Dict[str, Any] = {"mintcdate": since_ms} if since_ms else {}
    queries: List[Dict[str, Any]] = [
        {**base, "content": {"venueid": v}} for v in (venue_ids or [])
    ]
    # The notes filter keyword is ``invitation`` (singular); iterget_notes /
    # get_notes do not accept ``invitations=`` and would raise TypeError.
    queries.extend({**base, "invitation": inv} for inv in (invitations or []))

    seen_forums = set()
    hits: List[Any] = []
    for params in queries:
        for note in openreview.tools.iterget_notes(client, **params):
            if note.forum in seen_forums:
                continue
            if not _is_survey_candidate(note, accepted_only, decision_keywords):
                continue
            seen_forums.add(note.forum)
            hits.append(note)
    return hits

def _submission_venue(sub) -> str:
    """Return the submission's venue: content ``venue``, else the note's ``venue``/``venueid`` attribute."""
    return cget(sub.content, "venue", "") or getattr(sub, "venue", "") or getattr(sub, "venueid", "")


def _submission_row(sub) -> Dict[str, Any]:
    """Flatten one submission note into a row of the submissions table."""
    c = sub.content or {}
    return {
        "forum_id": sub.forum,
        "submission_number": getattr(sub, "number", None),
        "title": cget(c, "title", ""),
        "abstract": cget(c, "abstract", ""),
        "venue": _submission_venue(sub),
        "decision_string": extract_decision_string(sub),
        "tcdate": sub.tcdate,
    }


def _review_row(sub, note) -> Dict[str, Any]:
    """Flatten one public child note (review/comment/decision) of ``sub`` into a row."""
    content = note.content or {}
    return {
        "forum_id": sub.forum,
        "submission_title": cget(sub.content, "title", ""),
        "submission_venue": _submission_venue(sub),
        "review_note_id": note.id,
        "invitation": get_note_invitation_string(note),
        "readers": ";".join(note.readers or []),
        "replyto": note.replyto,
        "tcdate": note.tcdate,
        # Venues name these fields differently; take the first non-empty one.
        "review_title": cget(content, "title") or cget(content, "subject"),
        "review_rating": cget(content, "rating") or cget(content, "recommendation"),
        "review_confidence": cget(content, "confidence"),
        "review_body": cget(content, "review") or cget(content, "comment") or cget(content, "content") or "",
        "raw_json": json.dumps(content, ensure_ascii=False),
    }


def _write_jsonl(path: str, rows: List[Dict[str, Any]]) -> None:
    """Write ``rows`` to ``path`` as UTF-8 JSON Lines (one object per line).

    The plain row dicts are serialized rather than ``DataFrame.iterrows()``:
    pandas coerces ``None`` in numeric columns to NaN, which ``json.dumps``
    would emit as the invalid-JSON token ``NaN`` (and ints to floats).
    """
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")


def main():
    """CLI entry point.

    Searches survey-like submissions, optionally keeps only accepted ones,
    fetches their public reviews/comments, writes
    ``<out_prefix>.{submissions,reviews}.{csv,jsonl}`` and optionally prints
    preview tables. Exits with status 2 on invalid arguments.
    """
    ap = argparse.ArgumentParser(
        description="Fetch OpenReview reviews for survey-like papers with optional 'accepted only' filtering."
    )
    ap.add_argument("--venues", nargs="*", default=None,
                    help="Filter by venue ids, e.g. ICLR.cc/2025/Conference  NeurIPS.cc/2024/Conference")
    ap.add_argument("--invitations", nargs="*", default=None,
                    help="Filter by submission invitations, e.g. ICLR.cc/2025/Conference/-/Blind_Submission")
    ap.add_argument("--since", type=str, default=None,
                    help="Only submissions created since this date (YYYY-MM-DD)")
    ap.add_argument("--limit", type=int, default=None, help="Limit number of submissions processed")
    ap.add_argument("--accepted_only", action="store_true",
                    help="Keep only accepted submissions")
    ap.add_argument("--decision_keywords", type=str, default=None,
                    help="Comma-separated keywords to judge acceptance (case-insensitive). "
                         "Default includes accept/accepted/paper accepted/oral/spotlight/poster/conference")
    ap.add_argument("--out_prefix", type=str, default="openreview_survey_reviews",
                    help="Output file prefix (CSV & JSONL)")
    ap.add_argument("--no_preview", action="store_true",
                    help="Do not print preview tables to stdout")
    ap.add_argument("--show_topk", type=int, default=10,
                    help="How many submissions to preview in stdout (default: 10)")
    ap.add_argument("--baseurl", type=str, default="https://api2.openreview.net",
                    help="OpenReview API baseurl (default: https://api2.openreview.net)")
    args = ap.parse_args()

    if not args.venues and not args.invitations:
        print("[ERROR] You must provide at least one of --venues or --invitations for OpenReview v2.", file=sys.stderr)
        print("        Examples:", file=sys.stderr)
        print("          --venues ICLR.cc/2025/Conference", file=sys.stderr)
        print("          --invitations ICLR.cc/2025/Conference/-/Blind_Submission", file=sys.stderr)
        sys.exit(2)

    # A negative slice bound would silently drop the *last* submissions.
    if args.limit is not None and args.limit < 0:
        ap.error("--limit must be a non-negative integer")

    since_ms = None
    if args.since:
        try:
            since = datetime.strptime(args.since, "%Y-%m-%d")
        except ValueError:
            ap.error(f"--since must be a date in YYYY-MM-DD format, got {args.since!r}")
        # Naive datetime, so .timestamp() interprets it in the local timezone.
        since_ms = int(since.timestamp() * 1000)

    decision_keywords = None
    if args.decision_keywords:
        decision_keywords = [k.strip() for k in args.decision_keywords.split(",") if k.strip()]

    client = get_client(args.baseurl)

    print("[INFO] Searching survey-like submissions...")
    subs = search_survey_submissions(
        client,
        venue_ids=args.venues,
        invitations=args.invitations,
        since_ms=since_ms,
        accepted_only=args.accepted_only,
        decision_keywords=decision_keywords,
    )
    if args.limit:  # None or 0 means "no limit"
        subs = subs[: args.limit]

    sub_rows = [_submission_row(s) for s in subs]
    sub_df = pd.DataFrame(sub_rows)

    if not args.no_preview:
        if sub_df.empty:
            print("[PREVIEW] No submissions found.")
        else:
            show_cols = ["submission_number", "title", "venue", "decision_string"]
            print("\n[PREVIEW] Survey-like submissions{}:".format(" (accepted only)" if args.accepted_only else ""))
            print(tabulate(sub_df[show_cols].head(args.show_topk), headers="keys", tablefmt="github", showindex=False))

    print("[INFO] Fetching public reviews/comments for matched submissions...")
    review_rows: List[Dict[str, Any]] = []
    for s in tqdm(subs, desc="forums"):
        review_rows.extend(_review_row(s, r) for r in list_public_children(client, s.forum))

    reviews_df = pd.DataFrame(review_rows)

    sub_csv = f"{args.out_prefix}.submissions.csv"
    sub_jsonl = f"{args.out_prefix}.submissions.jsonl"
    rev_csv = f"{args.out_prefix}.reviews.csv"
    rev_jsonl = f"{args.out_prefix}.reviews.jsonl"

    sub_df.to_csv(sub_csv, index=False)
    _write_jsonl(sub_jsonl, sub_rows)

    reviews_df.to_csv(rev_csv, index=False)
    _write_jsonl(rev_jsonl, review_rows)

    print(f"\n[DONE] Submissions: {len(sub_df)} rows -> {sub_csv} / {sub_jsonl}")
    print(f"[DONE] Reviews    : {len(reviews_df)} rows -> {rev_csv} / {rev_jsonl}")
    if not args.no_preview and not reviews_df.empty:
        show_cols = ["submission_title", "invitation", "review_rating", "review_confidence"]
        print("\n[PREVIEW] First reviews:")
        print(tabulate(reviews_df[show_cols].head(min(args.show_topk, 10)), headers="keys", tablefmt="github", showindex=False))

# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
