
import re
import pandas as pd
import numpy as np
from pathlib import Path

# Allowed values for each dataset-context dimension. The key order here is the
# order in which expand_with_context emits its context_* columns.
CONTEXT_CATEGORIES = dict(
    scale=["small", "medium", "large"],
    seq_len=["short", "medium", "long"],
    structure=["none", "explicit_static", "explicit_dynamic"],
    time_irregularity=["low", "medium", "high"],
    modality=["binary", "text/code", "text/dialogue"],
    heterogeneity=["low", "medium", "high"],
    cold_start=["low", "medium", "high"],
)

# Hand-assigned context profile for each known dataset name (lower-case keys).
# Every profile lists the CONTEXT_CATEGORIES keys in the same order, with values
# drawn from the matching allowed list.
DATASET_CONTEXT = {
    "assistments 2009": {
        "scale": "medium", "seq_len": "short", "structure": "explicit_static",
        "time_irregularity": "medium", "modality": "binary",
        "heterogeneity": "medium", "cold_start": "medium",
    },
    "assistments 2012": {
        "scale": "medium", "seq_len": "short", "structure": "explicit_static",
        "time_irregularity": "medium", "modality": "binary",
        "heterogeneity": "medium", "cold_start": "medium",
    },
    "assistments 2015": {
        "scale": "medium", "seq_len": "medium", "structure": "explicit_static",
        "time_irregularity": "medium", "modality": "binary",
        "heterogeneity": "medium", "cold_start": "medium",
    },
    "assistments 2017": {
        "scale": "medium", "seq_len": "medium", "structure": "explicit_static",
        "time_irregularity": "medium", "modality": "binary",
        "heterogeneity": "medium", "cold_start": "medium",
    },
    "assistments": {
        "scale": "medium", "seq_len": "medium", "structure": "explicit_static",
        "time_irregularity": "medium", "modality": "binary",
        "heterogeneity": "medium", "cold_start": "medium",
    },
    "ednet": {
        "scale": "large", "seq_len": "long", "structure": "explicit_static",
        "time_irregularity": "high", "modality": "binary",
        "heterogeneity": "high", "cold_start": "high",
    },
    "statics": {
        "scale": "medium", "seq_len": "medium", "structure": "explicit_static",
        "time_irregularity": "medium", "modality": "binary",
        "heterogeneity": "medium", "cold_start": "medium",
    },
    "statics2011": {
        "scale": "medium", "seq_len": "medium", "structure": "explicit_static",
        "time_irregularity": "medium", "modality": "binary",
        "heterogeneity": "medium", "cold_start": "medium",
    },
    "kdd 2010": {
        "scale": "large", "seq_len": "medium", "structure": "explicit_static",
        "time_irregularity": "medium", "modality": "binary",
        "heterogeneity": "high", "cold_start": "high",
    },
    "kdd cup 2010": {
        "scale": "large", "seq_len": "medium", "structure": "explicit_static",
        "time_irregularity": "medium", "modality": "binary",
        "heterogeneity": "high", "cold_start": "high",
    },
    "duolingo": {
        "scale": "large", "seq_len": "long", "structure": "explicit_static",
        "time_irregularity": "high", "modality": "binary",
        "heterogeneity": "high", "cold_start": "high",
    },
    "junyi": {
        "scale": "large", "seq_len": "medium", "structure": "explicit_static",
        "time_irregularity": "medium", "modality": "binary",
        "heterogeneity": "high", "cold_start": "high",
    },
    "khan academy": {
        "scale": "large", "seq_len": "medium", "structure": "explicit_static",
        "time_irregularity": "medium", "modality": "binary",
        "heterogeneity": "high", "cold_start": "high",
    },
    "cs1": {
        "scale": "medium", "seq_len": "medium", "structure": "none",
        "time_irregularity": "high", "modality": "text/code",
        "heterogeneity": "medium", "cold_start": "medium",
    },
    "programming": {
        "scale": "medium", "seq_len": "medium", "structure": "none",
        "time_irregularity": "high", "modality": "text/code",
        "heterogeneity": "medium", "cold_start": "medium",
    },
    "dialogue": {
        "scale": "medium", "seq_len": "short", "structure": "none",
        "time_irregularity": "medium", "modality": "text/dialogue",
        "heterogeneity": "medium", "cold_start": "medium",
    },
}

def safe_lower(s):
    """Return ``s`` as a lower-cased string, or ``""`` when ``s`` is missing (None/NaN)."""
    if pd.isna(s):
        return ""
    return str(s).lower()

def split_listish(x):
    """Split a loosely delimited string into unique, whitespace-normalized items.

    Items are separated by ``;``, ``,``, ``/`` or the literal word `` and ``.
    Each item is trimmed and internal whitespace runs collapse to one space.
    Duplicates are detected case-insensitively: the first spelling seen is
    kept and original order is preserved. Missing values (None/NaN) yield [].
    """
    if pd.isna(x):
        return []
    unique_items = []
    seen_folded = set()
    for fragment in re.split(r"[;,/]| and ", str(x)):
        item = re.sub(r"\s+", " ", fragment.strip())
        if not item:
            continue
        folded = item.lower()
        if folded in seen_folded:
            continue
        seen_folded.add(folded)
        unique_items.append(item)
    return unique_items

def guess_context_for_dataset(name: str):
    """Return a context profile (a fresh dict copy) for a free-text dataset name.

    Lookup order:
      1. exact match of the trimmed, lower-cased name against DATASET_CONTEXT;
      2. any name containing "assistments", resolved by the first release year
         mentioned (2009/2012/2015/2017), else the generic "assistments" entry;
      3. any name containing both "kdd" and "2010" -> the "kdd cup 2010" entry;
      4. a match that ignores spaces, hyphens and underscores
         (e.g. "Statics 2011" -> "statics2011", "Khan-Academy" -> "khan academy").

    Missing (None/NaN/empty) or unknown names return ``{}``. A copy is
    returned so callers cannot accidentally mutate the shared table.
    """
    key = safe_lower(name).strip()
    if not key:
        return {}
    if key in DATASET_CONTEXT:
        return dict(DATASET_CONTEXT[key])
    if "assistments" in key:
        # "2015" is included so the dedicated 2015 entry is reachable too.
        for year in ("2009", "2012", "2015", "2017"):
            if year in key:
                return dict(DATASET_CONTEXT[f"assistments {year}"])
        return dict(DATASET_CONTEXT["assistments"])
    if "kdd" in key and "2010" in key:
        return dict(DATASET_CONTEXT["kdd cup 2010"])
    compact = re.sub(r"[\s_-]+", "", key)
    for known, ctx in DATASET_CONTEXT.items():
        if re.sub(r"[\s_-]+", "", known) == compact:
            return dict(ctx)
    return {}

# A value in [0, 1] written as 0.xxx or 1 / 1.0 / 1.00... The trailing negative
# lookahead rejects numbers that merely *start* like one, so "15%" or "1.05" are
# not read as 1.0.
_UNIT_INTERVAL_VALUE = r"([0]\.[0-9]+|1(?:\.0+)?)(?!\.?\d)"

# (canonical metric name, compiled pattern); group 2 of each pattern is the value.
# The order fixes the order of the returned pairs: all AUC, then accuracy,
# then log loss, then F1.
_PERFORMANCE_PATTERNS = (
    ("AUC-ROC", re.compile(r"(AUC(?:-?ROC)?)\s*[:=]\s*" + _UNIT_INTERVAL_VALUE, re.I)),
    ("accuracy", re.compile(r"(acc(?:uracy)?)\s*[:=]\s*" + _UNIT_INTERVAL_VALUE, re.I)),
    # \b stops the short "xe" alias from matching inside longer words.
    ("log_loss", re.compile(r"(log\s*loss|\bxe|cross[-\s]?entropy)\s*[:=]\s*([0]\.[0-9]+|[1-9]\d*(?:\.[0-9]+)?)", re.I)),
    ("f1", re.compile(r"(F1)\s*[:=]\s*" + _UNIT_INTERVAL_VALUE, re.I)),
)

def parse_numeric_performance(text: str):
    """Extract ``(metric, value)`` pairs written as ``NAME: value`` or ``NAME = value``.

    Recognized names (case-insensitive): AUC / AUC-ROC / AUCROC -> "AUC-ROC",
    acc / accuracy -> "accuracy", log loss / xe / cross-entropy -> "log_loss",
    F1 -> "f1". AUC, accuracy and F1 values must lie in [0, 1] as decimals
    (percentages are ignored); log-loss values may be any positive decimal.

    Returns a list grouped by metric in the order above, with matches in text
    order within each group. Missing text (None/NaN) returns [].
    """
    if pd.isna(text):
        return []
    s = str(text)
    pairs = []
    for metric, pattern in _PERFORMANCE_PATTERNS:
        pairs.extend((metric, float(m.group(2))) for m in pattern.finditer(s))
    return pairs

# "ood" must be a whole word; as a bare substring it matched "good",
# "likelihood", "food", ... and wrongly granted the OOD bonus.
_OOD_TOKEN = re.compile(r"\bood\b")

def protocol_quality_factor(text: str):
    """Return a multiplicative weight from a free-text split-protocol description.

    * x1.25 when the text describes a student-wise split combined with a
      temporal ordering, or any unseen / inductive / cold-start /
      out-of-distribution (OOD) evaluation.
    * x0.50 when the text mentions a leakage-prone protocol (interaction-wise
      or random split, shared question ids, "leakage", mixed history).

    Both adjustments can apply together (1.25 * 0.5 = 0.625). Missing text
    gives 1.0.

    NOTE(review): matching is keyword-based and ignores negation, so e.g.
    "split to avoid leakage" is still penalised — confirm this is acceptable.
    """
    s = safe_lower(text)
    student_wise = any(kw in s for kw in ("student-wise", "student wise", "student split"))
    temporal = any(kw in s for kw in ("chronological", "time", "temporal"))
    ood = (
        any(kw in s for kw in ("unseen", "inductive", "cold-start", "cold start", "out-of-distribution"))
        or _OOD_TOKEN.search(s) is not None
    )
    leakage_prone = any(
        kw in s
        for kw in ("interaction-wise", "random split", "question id in train and test", "leakage", "mixing history")
    )
    factor = 1.0
    if (student_wise and temporal) or ood:
        factor *= 1.25
    if leakage_prone:
        factor *= 0.50
    return factor

_REPORTED_METRIC_NAME = re.compile(r"\b(auc|accuracy|f1|log\s*loss|cross[-\s]?entropy)\b")
_REPORTED_DECIMAL = re.compile(r"\b\d+\.\d+\b")
# "ci"/"cis" must stand alone (e.g. "95% CI", "CI95"). As a bare substring it
# matched "precision", "specificity", "coefficient", "decision", ...
_CI_TOKEN = re.compile(r"\bcis?(?![a-z])")
_VARIANCE_KEYWORDS = ("confidence interval", "std", "stdev", "variance", "±", "+/-")

def reporting_completeness_factor(metrics_text: str, perf_text: str):
    """Return a multiplicative weight for how completely results are reported.

    Both texts are lower-cased and concatenated, then:
      * x1.10 if a metric name (AUC, accuracy, F1, log loss, cross-entropy)
        appears and some decimal number appears anywhere in the text;
      * x0.75 if the text says "directional", or says "improves by" without
        the exact-number evidence above;
      * x1.05 if a variability measure is mentioned (CI, std, variance, ±, +/-).
    Missing texts are treated as empty strings.
    """
    text = safe_lower(metrics_text) + " " + safe_lower(perf_text)
    has_exact = bool(_REPORTED_METRIC_NAME.search(text)) and bool(_REPORTED_DECIMAL.search(text))
    has_var = _CI_TOKEN.search(text) is not None or any(k in text for k in _VARIANCE_KEYWORDS)
    factor = 1.0
    if has_exact:
        factor *= 1.10
    if "directional" in text or ("improves by" in text and not has_exact):
        factor *= 0.75
    if has_var:
        factor *= 1.05
    return factor

def compute_quality_weight(row):
    """Return the combined quality weight for one paper record.

    The weight is the product of:
      * protocol_quality_factor(row["split_protocol"]), and
      * reporting_completeness_factor(row["metrics"],
        reported_performance + " " + performance_summary).

    ``row`` is anything with a ``.get(key, default)`` method; in this file it is
    a pandas Series yielded by ``DataFrame.iterrows``. Missing fields count as
    empty text.
    """
    # NaN (pandas' missing marker) is truthy, so ``value or ""`` would not
    # catch it and ``nan + " "`` would raise TypeError. safe_lower maps
    # None/NaN to "". Lower-casing here is harmless because
    # reporting_completeness_factor lower-cases its input anyway.
    perf_text = (
        safe_lower(row.get("reported_performance", ""))
        + " "
        + safe_lower(row.get("performance_summary", ""))
    )
    protocol = protocol_quality_factor(row.get("split_protocol", ""))
    reporting = reporting_completeness_factor(row.get("metrics", ""), perf_text)
    return protocol * reporting

def _primary_performance(perf_pairs):
    """Pick one headline ``(metric, value)`` from parsed ``(metric, value)`` pairs.

    All AUC values are averaged into a single "AUC-ROC" figure when present;
    otherwise the first parsed pair is used; ``(None, None)`` when nothing parsed.
    """
    aucs = [value for metric, value in perf_pairs if metric.lower().startswith("auc")]
    if aucs:
        return "AUC-ROC", float(np.nanmean(aucs))
    if perf_pairs:
        return perf_pairs[0]
    return None, None

def expand_with_context(df):
    """Explode paper-level rows into one row per (paper, dataset) with context columns.

    For every row of ``df``:
      * datasets come from ``datasets`` (falling back to ``data_summary``, then
        to the placeholder "(unspecified)"), split with split_listish;
      * the headline metric/value is derived from numbers parsed out of
        ``reported_performance`` and ``performance_summary``;
      * ``quality_weight`` comes from compute_quality_weight;
      * one ``context_<dimension>`` column per CONTEXT_CATEGORIES key is filled
        from guess_context_for_dataset, defaulting to "unknown".

    The returned DataFrame always has the full column schema, even when ``df``
    is empty, so downstream code that indexes ``quality_weight`` etc. still works.

    NOTE(review): ``paper_id`` is "<authors>_<year>" formatted directly, so a
    missing author or year renders as "nan" (and a float year as "2019.0").
    """
    context_columns = {k: f"context_{k}" for k in CONTEXT_CATEGORIES}
    base_columns = [
        "paper_id", "year", "title", "venue", "model", "acronym", "model_family",
        "dataset", "primary_metric", "primary_value", "quality_weight",
    ]
    rows = []
    for _, r in df.iterrows():
        datasets = (
            split_listish(r.get("datasets", ""))
            or split_listish(r.get("data_summary", ""))
            or ["(unspecified)"]
        )
        perf_pairs = (
            parse_numeric_performance(r.get("reported_performance", ""))
            + parse_numeric_performance(r.get("performance_summary", ""))
        )
        primary_metric, primary_value = _primary_performance(perf_pairs)
        weight = compute_quality_weight(r)
        paper_id = f"{r.get('authors','')}_{r.get('year','')}".strip()
        for ds in datasets:
            ctx = guess_context_for_dataset(ds)
            rows.append({
                "paper_id": paper_id,
                "year": r.get("year"),
                "title": r.get("title"),
                "venue": r.get("venue"),
                "model": r.get("model"),
                "acronym": r.get("acronym"),
                "model_family": r.get("model_family"),
                "dataset": ds,
                "primary_metric": primary_metric,
                "primary_value": primary_value,
                "quality_weight": weight,
                **{col: ctx.get(k, "unknown") for k, col in context_columns.items()},
            })
    return pd.DataFrame(rows, columns=base_columns + list(context_columns.values()))

def weighted_counts(df_long, groupby_fields):
    """Count rows per group, both raw (``n``) and quality-weighted (``n_weighted``).

    Missing ``quality_weight`` values count with weight 1.0. Groups whose key is
    NaN are dropped (pandas groupby default). The input frame is not modified.
    Returns one row per group: the group key column(s), then ``n`` and
    ``n_weighted``.
    """
    weighted_input = df_long.assign(w=df_long["quality_weight"].fillna(1.0))
    summary = weighted_input.groupby(groupby_fields).agg(
        n=("w", "size"),
        n_weighted=("w", "sum"),
    )
    return summary.reset_index()

def _weighted_median(values, weights):
    """Return the lower weighted median of ``values`` under non-negative ``weights``.

    NaN values (and their weights) are ignored. Returns NaN when nothing with
    positive weight remains.
    """
    x = np.asarray(values, dtype=float)
    w = np.asarray(weights, dtype=float)
    keep = ~np.isnan(x)
    x, w = x[keep], w[keep]
    if x.size == 0:
        return float("nan")
    order = np.argsort(x, kind="stable")
    x_sorted = x[order]
    cw = np.cumsum(w[order])
    total = cw[-1]
    if not total > 0:
        return float("nan")
    # Compare against half of *this group's* weight, so the result is a true
    # median however much weight other groups carry.
    idx = int(np.searchsorted(cw, 0.5 * total))
    return float(x_sorted[min(idx, x_sorted.size - 1)])

def compute_weighted_win_rates(comparisons_df: pd.DataFrame, context_join: pd.DataFrame):
    """Per context dimension/value and model family, compute weighted win statistics.

    ``comparisons_df`` is read for the columns paper_id, dataset, model_family,
    is_winner, normalized_rank and quality_weight. ``context_join`` supplies the
    ``context_*`` columns, joined on (paper_id, dataset, model_family);
    presumably it is the output of expand_with_context — confirm with caller.

    For each ``context_*`` column and each of its values:
      * weights are ``quality_weight`` with +/-inf and NaN treated as 0;
        a context value whose total weight is 0 is skipped;
      * ``win_rate`` = sum(is_winner * w) / sum(w) within the model family;
      * ``median_norm_rank`` = weighted (lower) median of normalized_rank
        within the model family.
      A family whose own weights sum to 0 gets NaN for both statistics.

    Returns columns: context_dim (prefix stripped), context_value,
    model_family, win_rate, median_norm_rank.
    """
    keys = ["paper_id", "dataset", "model_family"]
    out_columns = ["context_dim", "context_value", "model_family", "win_rate", "median_norm_rank"]
    join_cols = keys + [col for col in context_join.columns if col.startswith("context_")]
    c = comparisons_df.merge(context_join[join_cols].drop_duplicates(), on=keys, how="left")
    results = []
    context_cols = [col for col in c.columns if col.startswith("context_")]
    for ctx_col in context_cols:
        for ctx_val, sub in c.groupby(ctx_col):
            w = sub["quality_weight"].replace([np.inf, -np.inf], np.nan).fillna(0.0)
            if w.sum() == 0:
                continue
            records = []
            # Normalize within each family. Normalizing over the whole context
            # turned "win_rate" into a share of all wins and pushed the
            # median search past the family's total weight.
            for family, g in sub.assign(_w=w).groupby("model_family"):
                fam_w = g["_w"]
                fam_total = fam_w.sum()
                win_rate = float((g["is_winner"] * fam_w).sum() / fam_total) if fam_total > 0 else float("nan")
                records.append({
                    "model_family": family,
                    "win_rate": win_rate,
                    "median_norm_rank": _weighted_median(g["normalized_rank"].to_numpy(), fam_w.to_numpy()),
                })
            merged = pd.DataFrame(records, columns=out_columns[2:])
            merged.insert(0, "context_dim", ctx_col.replace("context_", ""))
            merged.insert(1, "context_value", ctx_val)
            results.append(merged)
    return pd.concat(results, ignore_index=True) if results else pd.DataFrame(columns=out_columns)
