# -*- coding: utf-8 -*-
"""
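This script aggregates scores from 3+ Reward Models (RMs), quantile-normalizes
them per RM, computes a per-prompt score-spread metric (RSI, the P90-P10 range
of the normalized scores) and a cross-RM consensus metric (mean pairwise
Spearman correlation), and then samples a final "Eval-Core" subset for analysis.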

It is designed to work with a specific data directory structure:
- Base Candidates: ./flat/combined.flat.jsonl
  Each line is a JSON object: {"messages":[...], "meta": {...}}
  The 'meta' field should contain prompt_id, cand_index, domain, task, etc.

- RM Scores: ./results/**/{*_outputs.jsonl|scores.jsonl}
  Each line is a JSON object that must contain a score ("results" or "score")
  and identifiers ("meta.prompt_id", "meta.cand_index"). If cand_index is
  missing, the script will attempt a fallback match based on response text.

To Run:
    Place this script in the root of your data directory (e.g., 'rb_work_grouped').
    Then execute:
    python rvb_select_core_from_grouped.py
"""

import os, json, glob, random
from collections import defaultdict, Counter
from typing import Dict, List, Tuple
import numpy as np
import pandas as pd

# ---------------- Configuration ----------------
# Assume the script is placed in the root directory of the grouped data.
# All input/output paths below are resolved relative to this script's location.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Base candidates: one JSON object per line.
COMBINED_PATH = os.path.join(BASE_DIR, "flat", "combined.flat.jsonl")
# Recursive pattern for every .jsonl under results/; the score loader later
# keeps only files named "*_outputs.jsonl" or "scores.jsonl".
RESULTS_GLOB = os.path.join(BASE_DIR, "results", "**", "*.jsonl")
# Directory that receives the CSV metrics and the final eval_core.jsonl.
OUT_DIR = os.path.join(BASE_DIR, "rvb_core_out")

# --- Selection and Thresholds ---
# High-span: prompts whose RSI_mean is at or above the (1 - TOP_SPAN_FRAC)
# quantile, i.e. roughly the top 35%.
TOP_SPAN_FRAC = 0.35
# Consensus (mean pairwise Spearman between RMs) >= this => "High-consensus".
HI_CONS_THR = 0.70
# Consensus < this => "Disagreement"; values in between are "Normal".
LO_CONS_THR = 0.40
# Number of candidates to select per prompt.
K_PER_PROMPT = 9
# Minimum distinct candidates (after filtering) for a prompt to get metrics.
# NOTE(review): this is lower than K_PER_PROMPT, so a prompt with 6-8
# candidates can be selected yet yield no rows from pick_k_for_prompt.
MIN_CANDS_AFTER_FILTER = 6
# Character n-gram size used for near-duplicate detection.
SHINGLE_N = 5
# Jaccard similarity at or above this marks a candidate as a near-duplicate.
NEAR_DUP_THR = 0.90

# One seed for the stdlib and NumPy global RNGs; the same value is also passed
# as random_state to the pandas .sample() calls further down.
RANDOM_SEED = 0
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# ---------------- Utilities ----------------
def load_jsonl(path):
    """Lazily yield one parsed JSON value per non-blank line of a UTF-8 file.

    Lines are stripped of surrounding whitespace before parsing; lines that
    are empty after stripping are skipped. A malformed line raises
    ``json.JSONDecodeError`` when the generator reaches it.
    """
    with open(path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        for text in stripped_lines:
            if not text:
                continue
            yield json.loads(text)

def ensure_dir(p):
    """Create directory ``p`` with any missing parents; no-op if it exists."""
    os.makedirs(name=p, exist_ok=True)

def last_assistant_text(messages: List[Dict]) -> str:
    """Return the content of the last message whose role is "assistant".

    The role comparison is case-insensitive and tolerates a missing/None role.

    Args:
        messages: Chat messages (dicts with "role" and "content"). ``None`` is
            treated as an empty conversation.

    Returns:
        The last assistant message's content, or "" when there is no assistant
        message or its content is missing/None.
    """
    text = ""
    for m in messages or []:
        if (m.get("role") or "").lower() == "assistant":
            content = m.get("content")
            # A JSON null content must not leak out as None: callers use the
            # result as a dict key and pass it to string-only helpers.
            text = "" if content is None else content
    return text

def build_cand_key(meta: Dict) -> Tuple[str, int]:
    """Return the ``(prompt_id, cand_index)`` key for a candidate's metadata.

    The prompt id is stringified as-is (a missing id becomes ``"None"``).
    There is no fallback for a missing candidate index here.

    Raises:
        KeyError: if ``meta`` has no ``cand_index`` or it is ``None``.
    """
    cand_index = meta.get("cand_index")
    if cand_index is None:
        raise KeyError("missing meta.cand_index")
    return str(meta.get("prompt_id")), int(cand_index)

def shingles(s, n=5):
    """Return the set of n-character shingles of ``s``.

    Non-printable characters (which includes newlines and tabs) are removed
    and the result is stripped before shingling. Text shorter than ``n``
    yields a single-element set holding the whole text, or an empty set when
    nothing is left.
    """
    cleaned = "".join(filter(str.isprintable, s)).strip()
    if len(cleaned) >= n:
        last_start = len(cleaned) - n
        return {cleaned[start:start + n] for start in range(last_start + 1)}
    return set() if not cleaned else {cleaned}

def jaccard(a: set, b: set) -> float:
    """Return |a ∩ b| / |a ∪ b|, or 0.0 when either set is empty."""
    if a and b:
        shared = len(a.intersection(b))
        total = len(a.union(b))
        if total:
            return shared / total
    return 0.0

def is_garbage(t: str) -> bool:
    """Return True for text that is too short, too fragmented, or degenerate.

    Text is flagged when it is:
      * not a string at all (e.g. None, or the float NaN pandas produces for
        rows with no matching text) or shorter than 30 characters;
      * split across more than 50 newlines;
      * dominated by one character (more than 40% of all characters).
    """
    # Non-string values must be rejected before len(): NaN is truthy, so a
    # plain `not t` check lets it through and len(NaN) raises TypeError.
    if not isinstance(t, str) or len(t) < 30:
        return True
    if t.count("\n") > 50:
        return True
    top_count = Counter(t).most_common(1)[0][1]
    return top_count / len(t) > 0.4

def quantile_normalize_by_group(df, group_col, score_col, out_col):
    """Return a copy of ``df`` with per-group quantile scores in ``out_col``.

    Within each group of ``group_col`` the scores are ranked (ties share the
    average rank) and mapped into (0, 1) with ``(rank - 0.5) / group_size``.

    Rows are addressed by position rather than by index label, so the result
    is correct for frames whose index has gaps or duplicate labels. The output
    column is always created, also for an empty frame. Rows that belong to no
    group (or whose score is NaN) get NaN.

    Args:
        df: Input frame; it is not modified.
        group_col: Column (or list of columns) defining the groups.
        score_col: Column holding the raw numeric scores.
        out_col: Name of the column to write the quantiles to.
    """
    out = df.copy()
    scores = out[score_col].to_numpy(dtype=float)
    quantiles = np.full(len(out), np.nan, dtype=float)
    if len(out):
        # GroupBy.indices maps each group to the integer positions of its rows.
        for positions in out.groupby(group_col).indices.values():
            group_scores = scores[positions]
            ranks = pd.Series(group_scores).rank(method="average").to_numpy()
            quantiles[positions] = (ranks - 0.5) / len(group_scores)
    out[out_col] = quantiles
    return out

def rsi_p90_p10(x):
    """Return the 90th-minus-10th percentile range of ``x``.

    Inputs with fewer than two values have no spread and yield 0.0.
    """
    values = np.asarray(x, dtype=float)
    if len(values) >= 2:
        p10, p90 = np.percentile(values, [10, 90])
        return float(p90 - p10)
    return 0.0

def spearman(x, y):
    """Return Spearman's rank correlation of ``x`` and ``y``.

    Both inputs are converted to average ranks and the Pearson correlation of
    those ranks is returned. Yields 0.0 when either input has fewer than two
    values or when the normalizer is not positive (e.g. a constant input).
    """
    if min(len(x), len(y)) < 2:
        return 0.0
    dx = pd.Series(x).rank().to_numpy()
    dy = pd.Series(y).rank().to_numpy()
    dx = dx - dx.mean()
    dy = dy - dy.mean()
    norm = np.sqrt(np.sum(dx ** 2)) * np.sqrt(np.sum(dy ** 2))
    if norm > 0:
        return float(np.sum(dx * dy) / norm)
    return 0.0

# ---------------- 1) Load Base Candidates ----------------
# Reads every candidate from COMBINED_PATH into `df_base` (one row per line)
# and builds `idx_by_pid_text`, which later lets RM score records without a
# cand_index be matched back to a candidate through their assistant text.
print(f"[INFO] Loading base candidates from: {COMBINED_PATH}")
base_rows = []
# prompt_id -> {assistant_text -> cand_index}
idx_by_pid_text = defaultdict(dict)
# prompt_id -> candidate indices seen so far; keeps a generated fallback index
# from reusing an index that an earlier candidate already holds.
_used_idx_by_pid = defaultdict(set)
for obj in load_jsonl(COMBINED_PATH):
    # `or` also covers an explicit JSON null, not just a missing key.
    meta = obj.get("meta") or {}
    pid = str(meta.get("prompt_id"))
    cidx = meta.get("cand_index")
    msgs = obj.get("messages") or []
    atext = last_assistant_text(msgs)

    if cidx is None:
        # Rare case; most combined.flat files include cand_index.
        # Use assistant text to create a stable index as a fallback: identical
        # text under the same prompt maps to the same index.
        cidx = idx_by_pid_text[pid].get(atext)
        if cidx is None:
            cidx = len(idx_by_pid_text[pid])
            while cidx in _used_idx_by_pid[pid]:
                cidx += 1
            idx_by_pid_text[pid][atext] = cidx
    else:
        cidx = int(cidx)
        idx_by_pid_text[pid][atext] = cidx
    _used_idx_by_pid[pid].add(cidx)

    row = {
        "prompt_id": pid,
        "cand_index": cidx,
        "cand_id": f"{pid}__{cidx}",
        "assistant_text": atext,
        "messages": msgs,
        # Metadata
        "domain": meta.get("domain"),
        "task": meta.get("task"),
        "source": meta.get("source"),
        "temperature": meta.get("temperature"),
        "meta": meta
    }
    base_rows.append(row)

if not base_rows:
    # Without this check the empty frame has no columns and the summary line
    # below fails with an opaque KeyError.
    raise RuntimeError(f"No base candidates found in: {COMBINED_PATH}")

df_base = pd.DataFrame(base_rows)
print(f"[INFO] Base candidates loaded: {len(df_base)} rows, {df_base['prompt_id'].nunique()} prompts")

# ---------------- 2) Collect RM Scores ----------------
def infer_rm_label(path: str) -> str:
    """Infer a readable RM label from the path of its score file.

    Rules, in order:
      * ``<dir>/<name>_outputs.jsonl`` -> ``"<dir>/<name>"``
      * ``<dir>/scores.jsonl``         -> ``"<dir>"``
      * anything else                  -> file name without its extension

    Only the trailing ``_outputs.jsonl`` is removed, so a file name that
    happens to contain that text earlier keeps it.
    """
    suffix = "_outputs.jsonl"
    fn = os.path.basename(path)
    parent = os.path.basename(os.path.dirname(path))
    if fn.endswith(suffix):
        return f"{parent}/{fn[:-len(suffix)]}"
    if fn == "scores.jsonl":
        return parent
    return os.path.splitext(fn)[0]

# Collect one row per (candidate, RM) score from every RM output file.
score_rows = []
rm_files = sorted(glob.glob(RESULTS_GLOB, recursive=True))
rm_files = [p for p in rm_files if (p.endswith("_outputs.jsonl") or os.path.basename(p) == "scores.jsonl")]
if not rm_files:
    raise RuntimeError(f"No RM score files found. Check the path: {RESULTS_GLOB}")

print("[INFO] Found RM score files:")
for p in rm_files:
    print(f"  - {os.path.relpath(p, BASE_DIR)}")

miss_cand_index_cases = 0  # records that could not be tied to a candidate
bad_score_cases = 0        # records whose score value is not numeric

for path in rm_files:
    rm_label = infer_rm_label(path)
    n_line, n_ok = 0, 0
    for obj in load_jsonl(path):
        n_line += 1
        # Handle different key names for the score; a null "results" must not
        # hide a valid "score".
        score = obj.get("results")
        if score is None:
            score = obj.get("score")
        meta = obj.get("meta") or {}
        pid = str(meta.get("prompt_id"))
        cidx = meta.get("cand_index")

        # str(None) == "None", so this is how a missing prompt_id shows up.
        if pid == "None":
            continue # Skip records without a valid prompt_id

        if cidx is None:
            # Attempt to match back using assistant text
            atext = last_assistant_text(obj.get("messages") or [])
            cidx = idx_by_pid_text.get(pid, {}).get(atext)
            if cidx is None:
                miss_cand_index_cases += 1
                continue
        try:
            cidx = int(cidx)
        except (ValueError, TypeError):
            miss_cand_index_cases += 1
            continue

        if score is not None:
            try:
                score = float(score)
            except (ValueError, TypeError):
                # e.g. a list or dict payload: skip the record, do not abort.
                bad_score_cases += 1
                continue

        score_rows.append({
            "prompt_id": pid,
            "cand_index": cidx,
            "cand_id": f"{pid}__{cidx}",
            "rm": rm_label,
            "score": score
        })
        n_ok += 1
    print(f"[INFO] Loaded {n_ok}/{n_line} score lines from {rm_label}")

if miss_cand_index_cases:
    print(f"[WARN] Skipped {miss_cand_index_cases} score records with no usable cand_index (and no text match).")
if bad_score_cases:
    print(f"[WARN] Skipped {bad_score_cases} score records whose score could not be converted to float.")

df_scores = pd.DataFrame(score_rows)
if not len(df_scores):
    raise RuntimeError("Failed to load any RM scores. Check if your score files contain meta.prompt_id and meta.cand_index.")
print(f"[INFO] Total score rows: {len(df_scores)} (from {df_scores['rm'].nunique()} distinct RMs)")

# --- Use only a specified set of RMs ---
# Names should match the output of infer_rm_label
TARGET_RMS = {
    "CIR-AMS/BTRM_Qwen2_7b_0613",
    "nicolinho/QRM-Llama3.1-8B-v2",
    "local_OpenAssistant_reward-model-deberta-v3-large-v2" # Example for a locally run model
}
print("[INFO] Using target RM set:", TARGET_RMS)

# Check for missing RMs to avoid silent errors
available_rms = set(df_scores["rm"].unique())
missing_rms = TARGET_RMS - available_rms
if missing_rms:
    print("[WARN] The following target RMs were not found and will be ignored:", missing_rms)

df_scores = df_scores[df_scores["rm"].isin(TARGET_RMS)].copy()
if df_scores.empty:
    # Fail here with the real cause instead of a later, misleading
    # "all scores were NaN" error.
    raise RuntimeError(
        "None of the target RMs were found among the loaded score files. "
        f"Available RM labels: {sorted(available_rms)}"
    )
print(f"[INFO] Kept {len(df_scores)} rows after RM filtering. Kept RMs:", sorted(df_scores["rm"].unique()))
if df_scores["rm"].nunique() < 2:
    # Consensus is a mean over RM pairs; with a single RM there are no pairs
    # and every prompt gets consensus 0.0.
    print("[WARN] Fewer than 2 RMs kept; consensus will be 0.0 for every prompt.")

# ---------------- 3) Merge into a Long DataFrame & Quantile Normalize ----------------
# One row per (candidate, RM) with the candidate's text and metadata attached.
base_cols = ["cand_id", "assistant_text", "domain", "task", "source", "temperature"]
# The base table can hold the same cand_id more than once (text-based fallback
# indices); merging against duplicates would multiply score rows.
base_unique = df_base[base_cols].drop_duplicates("cand_id")
df_long = df_scores.merge(base_unique, on="cand_id", how="left")

# Score rows without a base candidate carry NaN text and metadata and cannot be
# filtered or written out later, so drop them here with a visible warning.
unmatched = ~df_long["cand_id"].isin(base_unique["cand_id"])
if unmatched.any():
    print(f"[WARN] Dropping {int(unmatched.sum())} score rows whose cand_id is not among the base candidates.")
    df_long = df_long[~unmatched]

# Clean up missing scores. The index is reset so that it is again 0..n-1:
# later steps build a positional mask from index labels.
df_long = df_long.dropna(subset=["score"]).reset_index(drop=True)
if not len(df_long):
    raise RuntimeError("No score records left after merging (scores were all NaN or matched no base candidate).")

df_long = quantile_normalize_by_group(df_long, "rm", "score", "score_qA")

# ---------------- 4) Deduplicate and Filter Low-Quality Candidates ----------------
# Per prompt: drop garbage candidates, then greedily drop every candidate that
# is a near-duplicate (shingle Jaccard >= NEAR_DUP_THR) of one kept earlier.
# Candidates are tracked by cand_id rather than by row position, so the result
# does not depend on df_long having a contiguous 0..n-1 index.
drop_ids = set()
for pid, sub in df_long.groupby("prompt_id"):
    reps = sub.drop_duplicates("cand_id")
    kept_shingles = []  # shingle sets of the candidates kept for this prompt
    for cand_id, text in zip(reps["cand_id"], reps["assistant_text"]):
        # Non-string text (None/NaN from an unmatched merge) counts as garbage.
        if not isinstance(text, str) or is_garbage(text):
            drop_ids.add(cand_id)
            continue
        cand_shingles = shingles(text, SHINGLE_N)
        if any(jaccard(cand_shingles, kept) >= NEAR_DUP_THR for kept in kept_shingles):
            drop_ids.add(cand_id)
        else:
            kept_shingles.append(cand_shingles)

# cand_id embeds the prompt id, so one global set is enough.
keep_mask = ~df_long["cand_id"].isin(drop_ids).to_numpy()
df_f = df_long[keep_mask].copy()

# ---------------- 5) Compute Metrics: RSI / Consensus ----------------
# For every prompt with enough surviving candidates, record:
#   RSI_mean / RSI_min - mean / min over RMs of the P90-P10 normalized-score range
#   consensus          - mean pairwise Spearman correlation between RMs, using
#                        only candidates that every RM scored
metric_rows = []
for pid, grp in df_f.groupby("prompt_id"):
    n_cand = grp["cand_id"].nunique()
    if n_cand >= MIN_CANDS_AFTER_FILTER:
        # Spread of normalized scores, one value per RM.
        spreads = [rsi_p90_p10(rm_rows["score_qA"].values) for _, rm_rows in grp.groupby("rm")]
        if spreads:
            RSI_mean = float(np.mean(spreads))
            RSI_min = float(np.min(spreads))
        else:
            RSI_mean = 0.0
            RSI_min = 0.0

        # Candidate x RM table; rows missing any RM's score are dropped.
        wide = grp.pivot_table(values="score_qA", index="cand_id", columns="rm", aggfunc="mean")
        wide = wide.dropna()
        rm_names = list(wide.columns)
        pair_corrs = [
            spearman(wide[left].values, wide[right].values)
            for pos, left in enumerate(rm_names)
            for right in rm_names[pos + 1:]
        ]
        consensus = float(np.mean(pair_corrs)) if pair_corrs else 0.0

        # Dominant metadata for the prompt (most frequent non-null value)
        domain_vals = grp["domain"].dropna()
        domain = None if domain_vals.empty else domain_vals.mode().iloc[0]
        task_vals = grp["task"].dropna()
        task = None if task_vals.empty else task_vals.mode().iloc[0]

        metric_rows.append({
            "prompt_id": pid,
            "domain": domain,
            "task": task,
            "n_cand": int(n_cand),
            "RSI_mean": RSI_mean,
            "RSI_min": RSI_min,
            "consensus": consensus,
        })

met = pd.DataFrame(metric_rows)
if len(met) == 0:
    raise RuntimeError("No usable prompts left after filtering (too few candidates).")

# Despite its name, q80 is the (1 - TOP_SPAN_FRAC) quantile of RSI_mean.
q80 = met["RSI_mean"].quantile(1 - TOP_SPAN_FRAC)
met["span_bucket"] = np.where(met["RSI_mean"] >= q80, "High-span", "Non-high")
# First matching condition wins; anything in between is "Normal".
met["cons_bucket"] = np.select(
    [met["consensus"] >= HI_CONS_THR, met["consensus"] < LO_CONS_THR],
    ["High-consensus", "Disagreement"],
    default="Normal",
)

# Select a balanced set of high-span prompts: equal numbers from the
# High-consensus and Disagreement buckets. When exactly one of those buckets is
# empty, the Normal bucket stands in for it; when both are empty nothing can be
# balanced and the script stops.
hs = met[met["span_bucket"] == "High-span"].copy()
hs_hc = hs[hs["cons_bucket"] == "High-consensus"]
hs_dis = hs[hs["cons_bucket"] == "Disagreement"]

side_a, side_b = hs_hc, hs_dis
if hs_hc.empty or hs_dis.empty:
    # If one bucket is empty, fall back to the 'Normal' consensus group.
    print("[WARN] High-consensus or Disagreement bucket is empty. Falling back to Normal consensus.")
    hs_norm = hs[hs["cons_bucket"] == "Normal"]
    if hs_hc.empty and hs_dis.empty:
        side_a = side_b = hs.iloc[0:0]
    elif hs_hc.empty:
        side_a = hs_norm
    else:
        side_b = hs_norm

m = min(len(side_a), len(side_b))
if m > 0:
    sel_prompts = pd.concat([
        side_a.sample(n=m, random_state=RANDOM_SEED),
        side_b.sample(n=m, random_state=RANDOM_SEED)
    ], ignore_index=True)
else:
    sel_prompts = pd.DataFrame() # No prompts could be selected

if sel_prompts.empty:
    raise RuntimeError("Could not select any prompts based on the criteria. Check your data and thresholds.")
# q80 holds the (1 - TOP_SPAN_FRAC) quantile, so label it by the configured
# fraction rather than as an 80th percentile.
print(f"[INFO] Selected prompts: {len(sel_prompts)} (High-span balanced by consensus) | RSI_mean cutoff (top {TOP_SPAN_FRAC:.0%})={q80:.4f}")

# ---------------- 6) Select K Candidates Per Prompt (K/3 each from low/mid/high tiers) ----------------
# Prepare candidate summary scores: mean normalized score across RMs.
cand_scores = df_f.groupby(["prompt_id", "cand_id"])["score_qA"].mean().reset_index().rename(columns={"score_qA": "score_mean_q"})
cand_meta = df_base.drop_duplicates("cand_id")

def pick_k_for_prompt(pid) -> List[Dict]:
    """Pick up to K_PER_PROMPT candidates for a prompt, spread over score tiers.

    Candidates are split into "low" (<= 20th percentile of the mean normalized
    score), "high" (>= 80th percentile) and "mid" tiers. Each tier contributes
    up to ``K_PER_PROMPT // 3`` candidates (3 for the default of 9), preferring
    one candidate per distinct (source, temperature) pair and topping up with
    other candidates of that tier. Any remaining shortfall is filled from the
    prompt's other candidates.

    Args:
        pid: Prompt id to select candidates for.

    Returns:
        A list of candidate records (dicts of score, tier and base-candidate
        columns), or an empty list when the prompt has fewer than K_PER_PROMPT
        scored candidates.
    """
    sub = cand_scores[cand_scores["prompt_id"] == pid].copy()
    if len(sub) < K_PER_PROMPT:
        return []  # Not enough candidates

    # Per-tier quota follows K_PER_PROMPT instead of a fixed 3.
    per_tier = max(1, K_PER_PROMPT // 3)

    p20 = sub["score_mean_q"].quantile(0.2)
    p80 = sub["score_mean_q"].quantile(0.8)
    sub["tier"] = sub["score_mean_q"].apply(lambda v: "low" if v <= p20 else ("high" if v >= p80 else "mid"))

    join = sub.merge(cand_meta, on=["prompt_id", "cand_id"], how="left")

    chosen = []
    for tier_name in ("low", "mid", "high"):
        pool = join[join["tier"] == tier_name]
        if pool.empty:
            continue

        # Prioritize diversity in source/temperature. groupby drops rows with
        # a missing key; those rows can still enter through the backfill.
        groups = list(pool.groupby(["source", "temperature"]))
        random.shuffle(groups)

        # Only the first `per_tier` shuffled groups can contribute, so the
        # others are not sampled at all.
        slot = [
            g.sample(n=1, random_state=RANDOM_SEED).iloc[0].to_dict()
            for _, g in groups[:per_tier]
        ]

        # Backfill if not enough diverse candidates were found
        if len(slot) < per_tier:
            taken = {r["cand_id"] for r in slot}
            remaining = pool[~pool["cand_id"].isin(taken)]
            if not remaining.empty:
                n_needed = min(per_tier - len(slot), len(remaining))
                slot.extend(remaining.sample(n=n_needed, random_state=RANDOM_SEED).to_dict("records"))

        chosen.extend(slot[:per_tier])

    # Backfill if the total count is less than K
    if len(chosen) < K_PER_PROMPT:
        taken = {r["cand_id"] for r in chosen}
        rest = join[~join["cand_id"].isin(taken)]
        n_needed = min(K_PER_PROMPT - len(chosen), len(rest))
        if n_needed > 0:
            chosen.extend(rest.sample(n=n_needed, random_state=RANDOM_SEED).to_dict("records"))

    return chosen[:K_PER_PROMPT]

# ---------------- 7) Write Output Files ----------------
# Writes the per-prompt metric tables as CSV and the selected candidates as
# JSONL (one candidate per line, selection metrics stored under "meta").
ensure_dir(OUT_DIR)
met.to_csv(os.path.join(OUT_DIR, "prompt_metrics_all.csv"), index=False)
sel_prompts.to_csv(os.path.join(OUT_DIR, "prompt_metrics_selected.csv"), index=False)

out_path = os.path.join(OUT_DIR, "eval_core.jsonl")
w = 0                  # candidate rows written
n_prompts_written = 0  # prompts that contributed at least one row
with open(out_path, "w", encoding="utf-8") as f:
    for _, prow in sel_prompts.iterrows():
        pid = prow["prompt_id"]
        picked_cands = pick_k_for_prompt(pid)
        if not picked_cands:
            # A selected prompt can still yield nothing from pick_k_for_prompt.
            print(f"[WARN] No candidates picked for prompt {pid}; it is omitted from the eval set.")
            continue
        n_prompts_written += 1
        for cand_data in picked_cands:
            # Copy the metadata so the selection fields are not written back
            # into the dicts shared with the base-candidate table.
            base_meta = cand_data.get("meta")
            rec_meta = dict(base_meta) if isinstance(base_meta, dict) else {}
            # Add selection metrics to meta
            rec_meta["selection_span_bucket"] = prow.get("span_bucket")
            rec_meta["selection_cons_bucket"] = prow.get("cons_bucket")
            rec_meta["selection_RSI_mean"] = prow.get("RSI_mean")
            rec_meta["selection_consensus"] = prow.get("consensus")
            rec = {
                "prompt_id": pid,
                "cand_index": cand_data.get("cand_index"),
                "domain": cand_data.get("domain"),
                "task": cand_data.get("task"),
                "messages": cand_data.get("messages"),
                "meta": rec_meta,
            }
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
            w += 1

# Report the prompts actually written, not merely the ones selected.
print(f"[OK] Wrote {n_prompts_written} prompts ({w} total candidate rows) to the eval set.")
print(f"[OK] Prompt metrics (all): {os.path.join(OUT_DIR, 'prompt_metrics_all.csv')}")
print(f"[OK] Prompt metrics (selected): {os.path.join(OUT_DIR, 'prompt_metrics_selected.csv')}")
print(f"[OK] Final eval core set: {out_path}")