import os, re, json, argparse, random
from typing import List, Dict, Any
import pandas as pd
from tenacity import retry, stop_after_attempt, wait_exponential

def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Two CSV paths are mandatory. The sampling/round settings are integer
    options. For each of the four roles (method, domain, stats, adj) an API
    key, base URL and model name can be supplied; their defaults are read
    from the ``KEY_*``, ``URL_*`` and ``MODEL_*`` environment variables.

    Returns:
        The ``argparse.Namespace`` produced by ``ArgumentParser.parse_args``.
    """
    parser = argparse.ArgumentParser(
        description="Derive a FIVE-dimension survey rubric + Discussion/Future-Work from OpenReview reviews."
    )

    # Required inputs.
    for flag in ("--submissions_csv", "--reviews_csv"):
        parser.add_argument(flag, required=True)
    parser.add_argument("--out_dir", default="rubric_out")

    # Integer tuning knobs, registered in the order they appear in --help.
    int_options = (
        ("--sample_papers", 50),
        ("--max_reviews_per_paper", 10),
        ("--per_dim_quotes", 6),
        ("--rounds", 3),
        ("--seed", 42),
    )
    for flag, fallback in int_options:
        parser.add_argument(flag, type=int, default=fallback)

    # Optional overrides for each role
    roles = ("method", "domain", "stats", "adj")
    for role in roles:
        parser.add_argument(f"--key_{role}", default=os.getenv(f"KEY_{role.upper()}"))

    default_url = "https://api.openai.com/v1"
    for role in roles:
        parser.add_argument(
            f"--url_{role}",
            default=os.getenv(f"URL_{role.upper()}", default_url),
        )

    for role in roles:
        # Only the adjudicator defaults to the larger model.
        default_model = "gpt-4o" if role == "adj" else "gpt-4o-mini"
        parser.add_argument(
            f"--model_{role}",
            default=os.getenv(f"MODEL_{role.upper()}", default_model),
        )

    return parser.parse_args()

# ===================== text utils =====================
def clean(s: str) -> str:
    """Trim ``s`` and collapse every run of whitespace into one space.

    A falsy input (``None`` or ``""``) yields the empty string.
    """
    trimmed = s.strip() if s else ""
    return re.sub(r"\s+", " ", trimmed)

def sents(x: str) -> List[str]:
    """Split text into sentences.

    The split happens on whitespace that directly follows a Latin or CJK
    sentence terminator (``. ! ? 。 ！ ？``). Each piece is stripped and empty
    pieces are dropped. A falsy input yields an empty list.
    """
    pieces = re.split(r"(?<=[\.!?。！？])\s+", x or "")
    sentences = []
    for piece in pieces:
        piece = piece.strip()
        if piece:
            sentences.append(piece)
    return sentences

# Keyword lexicon: one (dimension name, trigger keywords) pair per rubric
# dimension. The order of this list is the order in which dimensions are
# reported by bucket().
CATS = [
    ("coverage",          ["coverage","scope","comprehensive","breadth","recent work","omit","systematic","search","PRISMA","time window","inclusion","exclusion"]),
    ("structure",         ["structure","organization","section","flow","outline","coherent","narrative","roadmap","signposting"]),
    ("relevance",         ["relevance","on-topic","fit","significance","timely","up-to-date","practical value","impact","community"]),
    ("synthesis",         ["synthesis","taxonomy","categorization","framework","matrix","compare","positioning","integration","unify"]),
    ("critical_analysis", ["critical","critique","limitations","bias","assumption","open questions","gap","future work","challenge","governance"])
]

# One compiled alternation per dimension. Keywords are lower-cased because the
# sentence is lower-cased before matching (otherwise "PRISMA" could never
# match). The look-behind requires a keyword to START a word, so "search" no
# longer fires on "research" nor "fit" on "benefit", while inflected forms
# such as "compared" or "critically" still match their stem.
_CAT_PATTERNS = [
    (
        name,
        re.compile(r"(?<!\w)(?:" + "|".join(re.escape(k.lower()) for k in keys) + r")"),
    )
    for name, keys in CATS
]


def bucket(sentence: str) -> List[str]:
    """Return the names of every rubric dimension whose keywords occur in ``sentence``.

    Matching is case-insensitive and anchored at the start of a word; a
    sentence may fall into several dimensions or none.

    Args:
        sentence: Text to classify; ``None`` or ``""`` matches nothing.

    Returns:
        Dimension names in ``CATS`` order.
    """
    lowered = (sentence or "").lower()
    return [name for name, pattern in _CAT_PATTERNS if pattern.search(lowered)]

# ===================== data I/O =====================
def load_and_sample(sub_csv, rev_csv, sample_papers, max_reviews, seed):
    """Read both CSVs and return a sample of submissions with their reviews.

    Reviews whose ``review_body`` is 20 characters or shorter are discarded.
    When ``sample_papers`` is a positive number, that many submissions (capped
    at the number available) are drawn with ``random_state=seed``; otherwise
    all submissions are kept. For every kept ``forum_id`` at most
    ``max_reviews`` reviews are retained, in file order.

    Returns:
        ``(submissions, reviews)`` as two DataFrames; the reviews frame has a
        fresh 0..n-1 index.
    """
    submissions = pd.read_csv(sub_csv)
    reviews = pd.read_csv(rev_csv)

    # Normalise the body to text, then drop near-empty reviews.
    reviews["review_body"] = reviews["review_body"].fillna("").astype(str)
    long_enough = reviews["review_body"].str.len() > 20
    reviews = reviews[long_enough]

    if not sample_papers or sample_papers <= 0:
        chosen = submissions
    else:
        how_many = min(sample_papers, len(submissions))
        chosen = submissions.sample(n=how_many, random_state=seed)

    forum_ids = set(chosen["forum_id"].tolist())
    of_chosen = reviews[reviews["forum_id"].isin(forum_ids)]
    capped = of_chosen.groupby("forum_id").head(max_reviews)
    return chosen, capped.reset_index(drop=True)

def extract_evidence(subs_samp: pd.DataFrame, revs_samp: pd.DataFrame, per_dim_quotes=6) -> Dict[str, Any]:
    """Collect evidence quotes for each rubric dimension.

    Every review sentence is classified with ``bucket``. The first
    ``per_dim_quotes`` matching sentences per dimension are stored as quotes
    (truncated to 600 characters), while ``stats`` counts every match,
    including those beyond the quote cap. Dimensions that are still short of
    quotes afterwards are topped up with submission titles (400 characters)
    and abstracts (600 characters).

    Args:
        subs_samp: Submissions; reads ``forum_id`` and, if present, ``title``
            and ``abstract``.
        revs_samp: Reviews; reads ``review_body`` and, if present,
            ``review_note_id``.
        per_dim_quotes: Maximum number of quotes kept per dimension.

    Returns:
        Dict with ``quotes_by_dim``, ``stats`` and ``sample_size``.
    """
    dim_names = [name for name, _ in CATS]
    by_dim = {name: [] for name in dim_names}
    stats = {name: 0 for name in dim_names}

    def _is_missing(value) -> bool:
        # Empty CSV cells are loaded as NaN (or None); strings are never missing.
        return value is None or (not isinstance(value, str) and pd.isna(value))

    def _cell_text(value) -> str:
        # str(NaN) is the literal word "nan", which would otherwise be quoted
        # as if it were a real title/abstract/review.
        return "" if _is_missing(value) else clean(str(value))

    for _, r in revs_samp.iterrows():
        note_id = r.get("review_note_id", "NA")
        if _is_missing(note_id):
            note_id = "NA"  # same placeholder as when the column is absent
        pointer = f"note:{note_id}"
        for sent in sents(_cell_text(r["review_body"])):
            for c in bucket(sent):
                if len(by_dim[c]) < per_dim_quotes:
                    by_dim[c].append({
                        "source": "review",
                        "pointer": pointer,
                        "quote": sent[:600]
                    })
                stats[c] += 1

    # If some dims are sparse, fill with title/abstract cues
    for _, s in subs_samp.iterrows():
        t = _cell_text(s.get("title", ""))
        a = _cell_text(s.get("abstract", ""))
        for c in dim_names:
            if t and len(by_dim[c]) < per_dim_quotes:
                by_dim[c].append({"source":"submission","pointer":f"forum:{s['forum_id']}:title","quote":t[:400]})
            if a and len(by_dim[c]) < per_dim_quotes:
                by_dim[c].append({"source":"submission","pointer":f"forum:{s['forum_id']}:abstract","quote":a[:600]})

    return {
        "quotes_by_dim": by_dim,
        "stats": stats,
        "sample_size": {"submissions":int(len(subs_samp)),"reviews":int(len(revs_samp))}
    }

def render_evidence_block(ctx: Dict[str,Any], per_cat=6) -> str:
    """Format the collected quotes as a Markdown-style text block.

    Dimensions are emitted in ``CATS`` order. Each non-empty dimension gets a
    ``### <name>`` heading followed by at most ``per_cat`` bullet lines of the
    form ``- [source|pointer] quote``. Dimensions without quotes are skipped.
    """
    quotes_by_dim = ctx["quotes_by_dim"]
    out = []
    for dim, _keywords in CATS:
        selected = quotes_by_dim.get(dim, [])[:per_cat]
        if selected:
            out.append(f"### {dim}")
            out.extend(
                f"- [{item['source']}|{item['pointer']}] {item['quote']}"
                for item in selected
            )
    return "\n".join(out)

# ===================== OpenAI caller =====================
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, max=30))
def _create_chat_completion(client, model: str, messages: List[Dict[str, str]], temperature: float):
    """Issue one chat-completion request; retried up to 3 times with exponential back-off."""
    return client.chat.completions.create(model=model, messages=messages, temperature=temperature)


def call_openai(model: str, base_url: str, api_key: str, messages: List[Dict[str, str]], temperature=0.2) -> str:
    """Send ``messages`` to a chat-completions endpoint and return the reply text.

    Only the network request is retried. A missing API key or a missing
    ``openai`` package is a configuration error that retrying cannot fix, so
    both are raised immediately instead of after the back-off delays.

    Args:
        model: Model name passed to the endpoint.
        base_url: Base URL of the API.
        api_key: API key; must be non-empty.
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
        temperature: Sampling temperature.

    Returns:
        The stripped content of the first choice, or ``""`` when the response
        carries no usable content.

    Raises:
        RuntimeError: if ``api_key`` is empty.
        tenacity.RetryError: if the request still fails after the final attempt.
    """
    if not api_key:
        raise RuntimeError("Missing API key for model call.")

    from openai import OpenAI  # imported lazily, as in the rest of this script

    client = OpenAI(api_key=api_key, base_url=base_url)
    resp = _create_chat_completion(client, model, messages, temperature)
    try:
        return (resp.choices[0].message.content or "").strip()
    except (AttributeError, IndexError, KeyError, TypeError):
        # Best effort: a response without choices/content is treated as empty.
        return ""

# ===================== prompting =====================
# System prompt for the per-role proposal calls. Contains a single {ROLE}
# placeholder that is filled with str.format before the call.
SYS_ROLE = (
  "You are {ROLE}. Use ONLY the provided evidence snippets from OpenReview to judge a FIVE-DIMENSION rubric "
  "(Coverage, Structure, Relevance, Synthesis, Critical Analysis). "
  "Tasks: (1) assess adequacy, (2) suggest refinements, (3) comment on weighting (may propose adaptive ideas)."
)

# System prompt for the adjudication call. It is sent verbatim (never passed
# through str.format), so the literal braces of the embedded JSON example
# must stay unescaped.
SYS_ADJ  = (
  "You are the Adjudicator. Merge role proposals into a single consistent rubric. "
  "Keep EXACTLY FIVE core dimensions with semantics matching: Coverage, Structure, Relevance, Synthesis, Critical Analysis. "
  "ASSIGN EQUAL WEIGHTS (all five = 0.2), and output weights that sum to 1.0. "
  "Each dimension must include: definition, observable_indicators, and 0/3/5 anchors. "
  "Output ONLY valid JSON in this format: "
  "{\"dimensions\": [{\"name\":\"Coverage\",\"definition\":\"...\",\"observable_indicators\":[\"...\"],\"weight\":0.2,"
  "\"scale\":{\"0\":\"...\",\"3\":\"...\",\"5\":\"...\"}}], \"notes\":\"...\"}"
)

# System prompt for the "Discussion & Future Work" drafting call; sent verbatim.
DISCUSSION_SYS = (
  "You are drafting a 'Discussion & Future Work' section. "
  "Use ONLY the evidence snippets and the finalized rubric (JSON) to: "
  "(i) argue whether the five dimensions are justified by reviewer signals, "
  "(ii) discuss equal weighting vs adaptive weighting, "
  "(iii) propose smarter constructs (topic/venue-adaptive weights, uncertainty, reviewer–LLM agreement, interactions), "
  "(iv) outline concrete research directions. "
  "Write clearly in academic style (500–800 words) with brief inline references to evidence pointers like [review|note:ID]."
)

# User-message template shared by the proposal rounds and the discussion call.
# Filled with str.format using the keys SUBS, REVS, EVID, ROUND, TOTAL and TASK;
# any literal brace added to this template would have to be doubled.
USR_CTX = """Context
- We evaluate SURVEY papers (ICLR/NeurIPS 2022–2025).
- Evidence size: submissions={SUBS}, reviews={REVS}
- Evidence snippets grouped by the five dimensions:
{EVID}

Task (Round {ROUND}/{TOTAL}):
{TASK}

Constraints:
- Be specific and measurable; tie suggestions to verifiable snippets (source/pointer/quote).
- Do not invent evidence; reason only from the snippets above.
"""

# Seed rubric handed to the roles in Round 1 (serialized into the task text).
# Five dimensions, each with a definition, observable indicators, a weight of
# 0.20 and anchor descriptions for the scores 0, 3 and 5.
SEED_5 = {
  "dimensions": [
    {
      "name": "Coverage",
      "definition": "Breadth and currency of literature; explicit search protocol & inclusion/exclusion.",
      "observable_indicators": [
        "databases/keywords/time window or PRISMA-like flow",
        "coverage of key subfields including recent landmark works",
        "discussion of omissions or scope boundaries"
      ],
      "weight": 0.20,
      "scale": {"0":"unsystematic/major gaps","3":"partially systematic, some gaps","5":"systematic, up-to-date, reproducible"}
    },
    {
      "name": "Structure",
      "definition": "Logical organization and narrative flow that help readers navigate the field.",
      "observable_indicators": [
        "clear outline/section flow and signposting",
        "consistent terminology and transitions",
        "figures/tables that support the narrative"
      ],
      "weight": 0.20,
      "scale": {"0":"disorganized/confusing","3":"acceptable but uneven","5":"coherent and easy to follow"}
    },
    {
      "name": "Relevance",
      "definition": "Topical fit and significance for the venue/community; scientific or practical usefulness.",
      "observable_indicators": [
        "alignment with current community interests",
        "timeliness and significance of covered topics",
        "actionable takeaways for researchers/practitioners"
      ],
      "weight": 0.20,
      "scale": {"0":"off-topic/low value","3":"partly aligned/modest value","5":"highly aligned, timely, and useful"}
    },
    {
      "name": "Synthesis",
      "definition": "Integration beyond listing; comparative taxonomy/framework that reveals structure.",
      "observable_indicators": [
        "clear taxonomy or organizing dimensions",
        "comparative analysis across methods",
        "identification of trends, gaps, tensions"
      ],
      "weight": 0.20,
      "scale": {"0":"mostly a list","3":"some comparative synthesis","5":"compelling integrative framework"}
    },
    {
      "name": "Critical Analysis",
      "definition": "Depth of critique: assumptions, limitations, risks, open problems.",
      "observable_indicators": [
        "balanced assessment including weaknesses",
        "clear articulation of open questions and future directions",
        "discussion of risks (e.g., bias, ethics, governance)"
      ],
      "weight": 0.20,
      "scale": {"0":"little critique","3":"some critique","5":"deep, balanced, insightful"}
    }
  ],
  "notes": "Weights are set equal by design (0.2 each) for a fair, light-feedback baseline."
}

def safe_json(text: str) -> Dict[str,Any]:
    """Parse a JSON object out of model output without ever raising.

    Two candidates are tried in order: the span from the first ``{`` to the
    last ``}`` (which tolerates code fences or prose around the object), then
    the whole stripped text. Only a parsed *object* is accepted, so callers
    can always use ``.get`` on the result.

    Args:
        text: Raw model output; may be empty or ``None``.

    Returns:
        The parsed dict, or ``{"dimensions": [], "notes": <text>}`` when no
        JSON object can be recovered (``notes`` is ``""`` for empty input).
    """
    if not text:
        return {"dimensions": [], "notes": ""}

    candidates = []
    start, end = text.find("{"), text.rfind("}")
    if 0 <= start < end:  # skip the brace slice when there is no {...} span
        candidates.append(text[start:end + 1])
    candidates.append(text.strip())

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except (ValueError, TypeError, RecursionError):
            continue
        # A bare list/number/string is valid JSON but not a rubric object.
        if isinstance(parsed, dict):
            return parsed

    return {"dimensions": [], "notes": text}

# ===================== main =====================
def main():
    """Run the whole pipeline from CSV input to rubric and reports.

    Steps: (1) load and sample submissions/reviews, (2) extract keyword
    evidence, (3) ask three expert roles for proposals over ``--rounds``
    rounds, (4) let the adjudicator merge the last round into a final rubric
    whose weights are then forced to be equal, (5) write the reports.

    Files written under ``--out_dir``: ``evidence.json``,
    ``proposals.r<N>.jsonl`` (one per round), ``rubric.final.json``,
    ``rubric.report.md`` and ``discussion_future_work.md``.

    Raises:
        SystemExit: if ``--rounds`` is smaller than 1.
    """
    args = parse_args()
    if args.rounds < 1:
        # Adjudication consumes the last round's proposals, so fail before
        # any file or API work instead of with an IndexError afterwards.
        raise SystemExit("--rounds must be >= 1")
    random.seed(args.seed)
    os.makedirs(args.out_dir, exist_ok=True)

    # 1) Load & sample
    print("[1/5] Load & sample data...")
    subs, revs = load_and_sample(
        args.submissions_csv, args.reviews_csv,
        args.sample_papers, args.max_reviews_per_paper, args.seed
    )
    print(f"  samples: subs={len(subs)}, reviews={len(revs)}")

    # 2) Extract evidence
    print("[2/5] Extract evidence...")
    ctx = extract_evidence(subs, revs, args.per_dim_quotes)
    with open(os.path.join(args.out_dir,"evidence.json"),"w",encoding="utf-8") as f:
        json.dump(ctx,f,ensure_ascii=False,indent=2)
    # Rendered once and reused by every prompt below.
    evid_block = render_evidence_block(ctx, args.per_dim_quotes)

    # 3) Multi-round proposals (three roles)
    print(f"[3/5] Multi-round proposals ({args.rounds} rounds)...")
    ROLES = [
        ("Methodologist","method", args.model_method, args.url_method, args.key_method),
        ("DomainExpert","domain", args.model_domain, args.url_domain, args.key_domain),
        ("Statistician","stats",  args.model_stats,  args.url_stats,  args.key_stats),
    ]
    all_round_proposals = []

    for round_num in range(1, args.rounds + 1):
        print(f"  Round {round_num}/{args.rounds}...")
        proposals = []
        if round_num == 1:
            task = (
                "Critique the FIVE-DIMENSION rubric (Coverage, Structure, Relevance, Synthesis, Critical Analysis). "
                "Are these sufficient according to the evidence? Which refinements are needed? "
                "Propose insights on adaptive weighting, but remember the final rubric will use EQUAL weights. "
                "Seed (equal-weight): " + json.dumps(SEED_5, ensure_ascii=False)
            )
        else:
            prev = all_round_proposals[-1]
            # The previous round must actually be shown to the model;
            # otherwise there is nothing to "consider" or refine.
            task = (
                f"Refine your rubric considering other experts' Round {round_num-1} proposals. "
                "Address disagreements, improve precision, keep five dimensions. "
                f"Round {round_num-1} proposals (JSON): " + json.dumps(prev, ensure_ascii=False)
            )

        # The user message is the same for every role within a round.
        usr = USR_CTX.format(
            SUBS=ctx["sample_size"]["submissions"], REVS=ctx["sample_size"]["reviews"],
            EVID=evid_block, ROUND=round_num, TOTAL=args.rounds, TASK=task
        )
        for role_label, _short, model, url, key in ROLES:
            out = call_openai(model, url, key,
                [{"role":"system","content":SYS_ROLE.format(ROLE=role_label)},
                 {"role":"user","content":usr}], temperature=0.3)
            proposals.append({"role":role_label, "rubric": safe_json(out)})
        all_round_proposals.append(proposals)

        with open(os.path.join(args.out_dir,f"proposals.r{round_num}.jsonl"),"w",encoding="utf-8") as f:
            for p in proposals:
                f.write(json.dumps(p,ensure_ascii=False)+"\n")

    final_proposals = all_round_proposals[-1]

    # 4) Adjudication
    print("[4/5] Adjudication: merge to final five-dim rubric (equal weights)...")
    adj = call_openai(
        args.model_adj, args.url_adj, args.key_adj,
        [{"role":"system","content":SYS_ADJ},
         {"role":"user","content":json.dumps({"proposals":[p["rubric"] for p in final_proposals]},ensure_ascii=False)}],
        temperature=0.2
    )
    final = safe_json(adj)
    if not isinstance(final, dict):
        # Valid JSON that is not an object cannot hold a rubric.
        final = {"dimensions": [], "notes": adj}
    dims = final.get("dimensions", [])

    if dims:
        # Enforce equal weights regardless of what the adjudicator returned.
        ew = 1.0 / len(dims)
        for d in dims:
            d["weight"] = round(ew, 6)
    else:
        print("  [warn] adjudicator returned no dimensions; the final rubric is empty")

    with open(os.path.join(args.out_dir,"rubric.final.json"),"w",encoding="utf-8") as f:
        json.dump(final,f,ensure_ascii=False,indent=2)

    # 5) Reports
    print("[5/5] Render reports...")
    # rubric.report.md
    rep = ["# Final Five-Dimension Rubric (Equal-Weight, from OpenReview Evidence)",
           f"- samples: submissions={len(subs)}, reviews={len(revs)}", "## Dimensions"]
    for d in dims:
        rep.append(f"### {d.get('name','(unnamed)')}")
        if "weight" in d:
            rep.append(f"- weight: {round(float(d['weight']),4)}")
        if "definition" in d:
            rep.append(f"- definition: {d.get('definition')}")
        if "observable_indicators" in d:
            rep.append(f"- indicators: {d.get('observable_indicators')}")
        if "scale" in d:
            rep.append(f"- scale: {d.get('scale')}")
        rep.append("")
    with open(os.path.join(args.out_dir,"rubric.report.md"),"w",encoding="utf-8") as f:
        f.write("\n".join(rep))

    # discussion_future_work.md
    DISC_PROMPT = (
        f"Rubric JSON (final, equal-weight):\n{json.dumps(final, ensure_ascii=False, indent=2)}\n\n"
        + USR_CTX.format(
            SUBS=ctx["sample_size"]["submissions"],
            REVS=ctx["sample_size"]["reviews"],
            EVID=evid_block,
            ROUND=1, TOTAL=1,
            TASK="Draft the Discussion & Future Work as specified."
        )
    )
    disc = call_openai(
        args.model_domain, args.url_domain, args.key_domain,
        [{"role":"system","content":DISCUSSION_SYS},
         {"role":"user","content":DISC_PROMPT}],
        temperature=0.4
    )
    with open(os.path.join(args.out_dir,"discussion_future_work.md"),"w",encoding="utf-8") as f:
        f.write(disc or "")

    print("[DONE] rubric ->", os.path.join(args.out_dir,"rubric.final.json"))
    print("[DONE] report ->", os.path.join(args.out_dir,"rubric.report.md"))
    print("[DONE] discussion ->", os.path.join(args.out_dir,"discussion_future_work.md"))

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
