# -*- coding: utf-8 -*-
"""
RVB Evaluation Runner Script:
- Flattens RMB-style data (list of dicts) into a RewardBench-compatible JSONL format.
- Can run multiple RMs using the RewardBench CLI or score locally using a forward pass with a HuggingFace RM.
- Aggregates RSI / P90-P10 metrics for each prompt and exports to a CSV.

To run: python rvb_eval.py
"""

import os
import sys
import json
import glob
import math
import subprocess
import shutil
from pathlib import Path
from typing import List, Dict, Any, Optional
from collections import defaultdict

# =========================
# Configuration (all parameters are set here)
# =========================
# Central configuration: every tunable of the pipeline lives in this dict and is read by main().
CFG = {
    # Input can be a single .json file, e.g., "./output/chat.json",
    # or a directory, e.g., "./output" (all .json files below it are read recursively).
    "INPUT_PATH": "./input_data",

    # Intermediate and result directories
    "WORKDIR": "./rb_work",                 # flat/ and results/ subdirectories are created under this directory.
    "FLAT_NAME": "combined.flat.jsonl",     # Filename of the combined flattened JSONL (written to WORKDIR/flat/).

    # --- Option A: Evaluate with RewardBench CLI (Recommended) ---
    # Options A and B are independent: main() runs every option whose flag is True.
    "USE_REWARDBENCH": True,
    # List of RMs to evaluate (add/remove as needed; first download may be slow/large).
    "RB_MODELS": [
        "Ray2333/GRM-llama3.2-3B-rewardmodel-ft",
        "Skywork/Skywork-Reward-Gemma-2-27B-v0.2",
    ],
    "RB_BATCH_SIZE": 8,
    "RB_CHAT_TEMPLATE": "raw", # 'raw' is the safest chat template for most RMs.
    # RewardBench executable name; typically "rewardbench" if it's in the system's PATH.
    # main() looks it up with shutil.which() and skips Option A when it is not found.
    "RB_ENTRY": "rewardbench",

    # --- Option B: Local HuggingFace RM Forward Pass (Fallback/Example) ---
    "USE_LOCAL_RM": True,
    "LOCAL_RM_MODELS": [
        "OpenAssistant/reward-model-deberta-v3-large-v2"
    ],
    "LOCAL_DEVICE": "auto",       # "auto" | "cuda" | "cpu" ("auto" picks cuda when torch reports it is available)
    "LOCAL_MAX_LEN": 2048,        # Tokenizer truncation length (max_length) for each scored conversation.
    "LOCAL_BATCH_SIZE": 4,        # Number of conversations scored per forward pass.

    # Metrics and Aggregation
    "AGG_OUT_CSV": "metrics_agg.csv", # Filename of the aggregated per-prompt metrics CSV (written under WORKDIR).
}

# =========================
# Utility Functions
# =========================

def ensure_dir(p: Path):
    """Create directory *p* together with any missing parents; a no-op if it already exists."""
    os.makedirs(p, exist_ok=True)

def list_json_files(path: Path) -> List[Path]:
    """Collect the JSON input files designated by *path*.

    Args:
        path: Either a single file or a directory to scan recursively.

    Returns:
        ``[path]`` when *path* is a file whose suffix is ``.json`` (case-insensitive),
        an empty list for any other file or a missing path, and otherwise every
        regular file matching ``**/*.json`` below *path*, sorted by path.
    """
    if path.is_file():
        return [path] if path.suffix.lower() == ".json" else []
    pattern = str(path / "**" / "*.json")
    hits = (Path(hit) for hit in glob.glob(pattern, recursive=True))
    # glob order depends on the filesystem; sort so repeated runs process files
    # in the same order. Directories that merely end in ".json" are dropped
    # because they cannot be loaded as JSON documents.
    return sorted(hit for hit in hits if hit.is_file())

def load_json(path: Path) -> Any:
    """Read the UTF-8 encoded file at *path* and return its parsed JSON content."""
    with path.open("r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def sanitize_name(s: str) -> str:
    """Return *s* with every character that is neither alphanumeric nor one of ``-._`` replaced by ``_``."""
    allowed_punctuation = "-._"
    pieces = []
    for char in s:
        if char.isalnum() or char in allowed_punctuation:
            pieces.append(char)
        else:
            pieces.append("_")
    return "".join(pieces)

def to_messages_from_conversation_input(conv: List[Dict[str, Any]]) -> List[Dict[str,str]]:
    """Convert an RMB 'conversation_input' list into a RewardBench-style 'messages' list.

    Each turn is reduced to its "role" (default "user") and "content" (default "").
    An empty or missing conversation yields a single placeholder user message.
    """
    turns = [
        {"role": turn.get("role", "user"), "content": turn.get("content", "")}
        for turn in (conv or [])
    ]
    if turns:
        return turns
    # Nothing to map: hand back a placeholder so downstream code always has a prompt.
    return [{"role": "user", "content": "<empty>"}]

def flatten_rmb_to_jsonl(in_paths: List[Path], out_jsonl: Path):
    """Flatten RMB-style JSON files into a single JSONL file compatible with RewardBench.

    Every input file must contain a JSON list of records. For each record, one
    output line is written per candidate that has a non-empty "answer":
    ``{"messages": <conversation + assistant answer>, "meta": {...}}``.
    The prompt id stored in ``meta`` is the record's "bon_uid", else its "id",
    else ``"<file name>:<running record index>"``.

    Files that cannot be loaded or whose root is not a list are skipped with a
    warning. Records or candidates that are not JSON objects, and answers that
    are not strings, are skipped and counted instead of aborting the run.

    Args:
        in_paths: JSON files to read, processed in the given order.
        out_jsonl: Destination file; its parent directory is created if needed
            and an existing file is overwritten.
    """
    ensure_dir(out_jsonl.parent)
    n_rec, n_lines, n_bad = 0, 0, 0
    with out_jsonl.open("w", encoding="utf-8") as w:
        for fp in in_paths:
            try:
                data = load_json(fp)
            except Exception as e:
                # Best effort: one unreadable/invalid file must not stop the whole export.
                print(f"[WARN] skip {fp}: {e}")
                continue
            if not isinstance(data, list):
                print(f"[WARN] {fp} root is not a list; skipping.")
                continue
            for rec in data:
                if not isinstance(rec, dict):
                    # A stray scalar/list inside the root list has no fields to read.
                    n_bad += 1
                    continue
                conv = rec.get("conversation_input", [])
                msgs = to_messages_from_conversation_input(conv)
                pid = rec.get("bon_uid") or rec.get("id") or f"{fp.name}:{n_rec}"

                for c in rec.get("candidates", []) or []:
                    if not isinstance(c, dict):
                        n_bad += 1
                        continue
                    raw_answer = c.get("answer") or ""
                    if not isinstance(raw_answer, str):
                        n_bad += 1
                        continue
                    ans = raw_answer.strip()
                    if not ans:
                        continue
                    sample = {
                        "messages": msgs + [{"role": "assistant", "content": ans}],
                        "meta": {
                            "prompt_id": pid,
                            "source": c.get("source"),
                            "llm_name": c.get("llm_name"),
                            "temperature": c.get("temperature"),
                            "quality_tier": c.get("quality_tier"),
                            "tag": c.get("tag"),
                            "from_file": str(fp),
                        }
                    }
                    w.write(json.dumps(sample, ensure_ascii=False) + "\n")
                    n_lines += 1
                n_rec += 1
    if n_bad:
        print(f"[WARN] Skipped {n_bad} malformed record(s)/candidate(s) while flattening.")
    print(f"[FLAT] Wrote {n_lines} lines from {len(in_paths)} file(s) to {out_jsonl}")

def try_import(pkg: str) -> bool:
    """Return True when module *pkg* imports successfully, False when it raises ImportError."""
    try:
        __import__(pkg)
    except ImportError:
        return False
    else:
        return True

def run_cmd(cmd: str, cwd: Optional[Path]=None) -> int:
    """Echo *cmd*, execute it through the shell, and return the process exit code.

    The command string is interpreted by the shell (``shell=True``), so the caller
    is responsible for quoting. When *cwd* is given, the command runs in that directory.
    """
    print(f"[CMD] {cmd}")
    working_dir = None if not cwd else str(cwd)
    completed = subprocess.run(cmd, shell=True, cwd=working_dir)
    return completed.returncode

def percentile(arr: List[float], q: float) -> float:
    """Return the q-th percentile (q on a 0-100 scale) of *arr* using linear interpolation.

    The values are sorted, the fractional rank ``(n - 1) * q / 100`` is computed, and the
    result is interpolated between the two neighbouring order statistics.
    An empty input yields NaN.
    """
    if not arr:
        return float("nan")
    ordered = sorted(arr)
    rank = (len(ordered) - 1) * (q / 100.0)
    lower, upper = math.floor(rank), math.ceil(rank)
    if lower != upper:
        # Fractional rank: weight each neighbour by its distance to the other one.
        return ordered[lower] * (upper - rank) + ordered[upper] * (rank - lower)
    # Rank falls exactly on an element.
    return ordered[int(rank)]

# =========================
# A) RewardBench Evaluation
# =========================

def run_with_rewardbench(flat_jsonl: Path, workdir: Path, model_id: str, batch_size: int, chat_template: str, entry: str):
    """Score the flattened dataset with one reward model via the RewardBench CLI.

    The output directory ``<workdir>/results/rb_<sanitized model id>`` is created,
    the CLI command line is assembled and executed through ``run_cmd``, and the
    outcome is logged.

    Returns:
        A ``(out_dir, exit_code)`` tuple; a non-zero exit code means the CLI failed.
    """
    out_dir = workdir / "results" / f"rb_{sanitize_name(model_id)}"
    ensure_dir(out_dir)

    # Every value is wrapped in double quotes because the command goes through the shell.
    cli_parts = [
        f'"{entry}"',
        f'--model="{model_id}"',
        f'--dataset="{flat_jsonl}"',
        "--load_json",
        f'--chat_template="{chat_template}"',
        f"--batch_size={batch_size}",
        f'--output_dir="{out_dir}"',
    ]
    code = run_cmd(" ".join(cli_parts))

    if code == 0:
        print(f"[RB] DONE model={model_id} -> {out_dir}")
    else:
        print(f"[RB] FAILED model={model_id} (exit={code})")
    return out_dir, code

def discover_rb_scores(out_dir: Path) -> Optional[Path]:
    """Locate the scores file produced by a RewardBench run.

    RewardBench output filenames can change with versions, so this performs a
    lenient recursive search below *out_dir* for a file whose name contains
    'score'. JSONL files are preferred over CSV files.

    Args:
        out_dir: Directory the RewardBench CLI wrote its results to.

    Returns:
        The first matching regular file in sorted path order (so the choice is
        stable across runs when several files match), or None if nothing matches.
    """
    # "*score*" also matches names containing "scores", so one pattern per extension suffices.
    for pattern in ("**/*score*.jsonl", "**/*score*.csv"):
        hits = sorted(hit for hit in out_dir.glob(pattern) if hit.is_file())
        if hits:
            return hits[0]
    return None

# =========================
# B) Local HF RM Evaluation (Fallback)
# =========================

def local_rm_score(flat_jsonl: Path, workdir: Path, model_id: str, device: str = "auto", max_len: int = 2048, batch_size: int = 4):
    """Score every flattened conversation with a local HuggingFace reward model.

    Each line of *flat_jsonl* is rendered to plain text as ``[ROLE]\\ncontent\\n``
    per message, tokenized (truncated to *max_len*), and passed through an
    ``AutoModelForSequenceClassification`` in batches of *batch_size*. One
    ``{"score": float, "meta": {...}}`` line per input is written to
    ``<workdir>/results/local_<sanitized model id>/scores.jsonl``.

    Args:
        flat_jsonl: Flattened JSONL file to score.
        workdir: Pipeline working directory that receives the results.
        model_id: HuggingFace model identifier or local path.
        device: "cuda", "cpu", or "auto" (cuda when torch reports it available).
        max_len: Maximum token length used for truncation.
        batch_size: Number of conversations per forward pass.

    Returns:
        Path of the written scores file, or None when `transformers`/`torch`
        cannot be imported.
    """
    if not try_import("transformers") or not try_import("torch"):
        print("[LOCAL] `transformers` and `torch` are not installed, skipping local RM. Run: pip install transformers torch")
        return None
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    import torch

    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"

    tok = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSequenceClassification.from_pretrained(model_id).to(device).eval()

    # Batches are padded to a common length below. Tokenizers that ship without
    # a pad token would make that call fail, so fall back to the EOS token and
    # tell the model which id is padding.
    if tok.pad_token is None and tok.eos_token is not None:
        tok.pad_token = tok.eos_token
    if getattr(model.config, "pad_token_id", None) is None and tok.pad_token_id is not None:
        model.config.pad_token_id = tok.pad_token_id

    out_dir = workdir / "results" / ("local_" + sanitize_name(model_id))
    ensure_dir(out_dir)
    out_path = out_dir / "scores.jsonl"

    pending_texts, pending_meta = [], []

    def flush(writer):
        """Score the buffered texts, write one JSON line per text, and empty the buffers."""
        if not pending_texts:
            return
        enc = tok(pending_texts, return_tensors="pt", truncation=True, max_length=max_len, padding=True).to(device)
        with torch.no_grad():
            out = model(**enc)
        scores = out.logits.squeeze(-1).detach().cpu().tolist()
        # tolist() returns a bare number for a 0-d tensor; normalise to a list.
        if not isinstance(scores, list):
            scores = [float(scores)]
        for meta, sc in zip(pending_meta, scores):
            item = {"score": float(sc), "meta": meta}
            writer.write(json.dumps(item, ensure_ascii=False) + "\n")
        pending_texts.clear()
        pending_meta.clear()

    n = 0
    with open(flat_jsonl, "r", encoding="utf-8") as r, out_path.open("w", encoding="utf-8") as w:
        for line in r:
            try:
                obj = json.loads(line)
            except Exception:
                # Blank or corrupt lines are ignored rather than aborting the run.
                continue
            if not isinstance(obj, dict):
                # Valid JSON that is not an object carries no messages to score.
                continue
            msgs = obj.get("messages", [])
            # Most conservative concatenation (raw mode): [USER]... [ASSISTANT]...
            text = "".join(
                f"[{m.get('role', 'user').upper()}]\n{m.get('content', '')}\n"
                for m in msgs
            )
            pending_texts.append(text)
            meta = obj.get("meta", {})
            # For easier aggregation later: keep prompt_id and other tags.
            pending_meta.append({
                "prompt_id": meta.get("prompt_id"),
                "tag": meta.get("tag"),
                "source": meta.get("source"),
                "llm_name": meta.get("llm_name"),
                "temperature": meta.get("temperature"),
                "quality_tier": meta.get("quality_tier"),
            })
            if len(pending_texts) >= batch_size:
                flush(w)
            n += 1
        # Score whatever is left in a final, possibly smaller, batch.
        flush(w)
    print(f"[LOCAL] DONE model={model_id} scored {n} lines -> {out_path}")
    return out_path

# =========================
# Aggregate Metrics (RSI / P90-P10)
# =========================

def aggregate_scores_to_csv(score_files: List[Path], out_csv: Path, model_name_map: Dict[Path, str]):
    """Aggregate per-response scores into per-prompt spread metrics and write them as CSV.

    Supports both RewardBench and local outputs: each score file is read line by
    line as JSON objects shaped like ``{"score": float, "meta": {prompt_id, ...}}``.
    The score may also be stored under "rm_score" or "reward"; the prompt id may
    also be a top-level "prompt_id" or ``meta["id"]``.

    For every (model, prompt) pair the CSV holds: number of scores, min, max,
    RSI (max - min), P10, P90 and P90-P10.

    Lines that are not JSON objects, or whose score is not a finite number, are
    skipped and counted instead of aborting the aggregation.

    Args:
        score_files: Score files to read.
        out_csv: Destination CSV; its parent directory is created if needed.
        model_name_map: Maps each score file path to the model label used in the
            CSV. Files without a label are skipped.
    """
    import csv

    # Read all scores
    per_model_prompt_scores = defaultdict(lambda: defaultdict(list))  # model -> prompt_id -> [scores]

    def try_load_jsonl(p: Path):
        """Load one score file into per_model_prompt_scores."""
        model_name = model_name_map.get(p)
        if not model_name:
            print(f"[AGG] No model name registered for {p}; skipping.")
            return
        ok, skipped = 0, 0
        with p.open("r", encoding="utf-8", errors="ignore") as f:
            for line in f:
                try:
                    obj = json.loads(line)
                except Exception:
                    continue
                if not isinstance(obj, dict):
                    # e.g. a bare number or list: valid JSON but has no fields to read.
                    skipped += 1
                    continue
                # Compatible with different field names.
                score = obj.get("score")
                if score is None: score = obj.get("rm_score")
                if score is None: score = obj.get("reward")
                meta = obj.get("meta")
                if not isinstance(meta, dict):
                    meta = {}
                pid = meta.get("prompt_id") or obj.get("prompt_id") or meta.get("id")
                if score is None or pid is None:
                    continue
                try:
                    value = float(score)
                except (TypeError, ValueError):
                    skipped += 1
                    continue
                if not math.isfinite(value):
                    # NaN/inf would corrupt sorting and every derived metric.
                    skipped += 1
                    continue
                per_model_prompt_scores[model_name][pid].append(value)
                ok += 1
        print(f"[AGG] Loaded {ok} scored lines from {p}")
        if skipped:
            print(f"[AGG] Skipped {skipped} unusable line(s) in {p}")

    for sf in score_files:
        try_load_jsonl(sf)

    # Calculate RSI / P90-P10
    ensure_dir(out_csv.parent)
    with out_csv.open("w", newline="", encoding="utf-8") as w:
        cw = csv.writer(w)
        cw.writerow(["model", "prompt_id", "n", "score_min", "score_max", "RSI", "P10", "P90", "P90-P10"])
        for m, d in per_model_prompt_scores.items():
            for pid, arr in d.items():
                if not arr:
                    continue
                mn = min(arr)
                mx = max(arr)
                rsi = mx - mn
                p10 = percentile(arr, 10.0)
                p90 = percentile(arr, 90.0)
                cw.writerow([m, pid, len(arr), f"{mn:.6f}", f"{mx:.6f}", f"{rsi:.6f}", f"{p10:.6f}", f"{p90:.6f}", f"{(p90-p10):.6f}"])
    print(f"[AGG] Wrote per-prompt metrics to {out_csv}")

# =========================
# Main Workflow
# =========================

def main():
    """Run the full pipeline configured in CFG: flatten, score, aggregate.

    Steps:
        1. Flatten all input JSON files into ``<WORKDIR>/flat/<FLAT_NAME>``.
        2. Score the flattened file with the RewardBench CLI (if enabled and the
           executable is found) and/or with local HuggingFace reward models
           (if enabled and transformers/torch are importable).
        3. Aggregate all collected score files into ``<WORKDIR>/<AGG_OUT_CSV>``.

    Exits with status 1 when no input files are found and with status 2 when
    no scorer produced a score file.
    """
    in_path = Path(CFG["INPUT_PATH"])
    workdir = Path(CFG["WORKDIR"])
    ensure_dir(workdir)

    # 1) Flatten Data
    flat_dir = workdir / "flat"
    ensure_dir(flat_dir)
    flat_jsonl = flat_dir / CFG["FLAT_NAME"]
    inputs = list_json_files(in_path)
    if not inputs:
        print(f"[ERR] No input JSON files found in: {in_path}")
        sys.exit(1)
    flatten_rmb_to_jsonl(inputs, flat_jsonl)

    score_files = []
    model_name_map = {}

    # 2A) RewardBench Evaluation
    if CFG["USE_REWARDBENCH"]:
        # Check if rewardbench is available
        rb_ok = shutil.which(CFG["RB_ENTRY"]) is not None
        if not rb_ok:
            print(f"[WARN] Executable '{CFG['RB_ENTRY']}' not found, skipping RewardBench. Please run: pip install rewardbench")
        else:
            for mid in CFG["RB_MODELS"]:
                out_dir, code = run_with_rewardbench(
                    flat_jsonl, workdir, model_id=mid,
                    batch_size=CFG["RB_BATCH_SIZE"],
                    chat_template=CFG["RB_CHAT_TEMPLATE"],
                    entry=CFG["RB_ENTRY"]
                )
                if code == 0:
                    sf = discover_rb_scores(out_dir)
                    if sf:
                        score_files.append(sf)
                        model_name_map[sf] = f"RB::{mid}"
                    else:
                        print(f"[WARN] No scores file found in {out_dir}. The output path might differ in your RewardBench version.")

    # 2B) Local RM Evaluation (Fallback)
    if CFG["USE_LOCAL_RM"]:
        # Both packages are needed, as the warning says; check both up front.
        if not (try_import("transformers") and try_import("torch")):
            print("[WARN] Local RM evaluation requires transformers/torch, skipping.")
        else:
            for mid in CFG["LOCAL_RM_MODELS"]:
                try:
                    sf = local_rm_score(
                        flat_jsonl, workdir, model_id=mid,
                        device=CFG["LOCAL_DEVICE"],
                        max_len=CFG["LOCAL_MAX_LEN"],
                        batch_size=CFG["LOCAL_BATCH_SIZE"]
                    )
                except Exception as e:
                    # Mirror the RewardBench path: one failing model (download
                    # error, out-of-memory, ...) must not discard the scores
                    # already produced by the other models.
                    print(f"[LOCAL] FAILED model={mid} ({type(e).__name__}: {e})")
                    continue
                if sf:
                    score_files.append(Path(sf))
                    model_name_map[Path(sf)] = f"LOCAL::{mid}"

    if not score_files:
        print("[ERR] No score results available for aggregation. Please ensure either RewardBench or the local RM ran successfully.")
        sys.exit(2)

    # 3) Aggregate and Export
    out_csv = workdir / CFG["AGG_OUT_CSV"]
    aggregate_scores_to_csv(score_files, out_csv, model_name_map)

    print("\n[OK] Full pipeline finished. You can now:")
    print(f" - Open {flat_jsonl} to view the flattened data.")
    print(f" - Check the raw scores for each RM under {workdir / 'results' / '*'}")
    print(f" - View the aggregated RSI / P90-P10 metrics in {out_csv}")
    print("Tip: For cross-model comparison, it's recommended to first normalize scores (e.g., quantile or z-score) on a fixed calibration set before comparing RSI/bandwidth.")

# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()