#!/usr/bin/env python3
"""Run the gaussian_mixture_main_raw_scores notebook logic as a CLI script."""

from __future__ import annotations

import argparse
import json
import sys
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

REPO_NAME = "llm-judge-bias"
CWD = Path.cwd().resolve()

# Locate the repository root. Probed in order:
#   1. the working directory itself (it contains both `src` and `notebooks`),
#   2. a child of the working directory named REPO_NAME,
#   3. a child named REPO_NAME of the nearest ancestor that has one.
if all((CWD / marker).exists() for marker in ("src", "notebooks")):
    REPO_ROOT = CWD
elif (CWD / REPO_NAME).exists():
    REPO_ROOT = CWD / REPO_NAME
else:
    REPO_ROOT = next(
        (
            ancestor / REPO_NAME
            for ancestor in CWD.parents
            if (ancestor / REPO_NAME).exists()
        ),
        None,
    )
    if REPO_ROOT is None:
        raise RuntimeError(f"Could not locate {REPO_NAME} repository from {CWD}")

# Well-known locations inside the repository.
NOTEBOOK_DIR = REPO_ROOT / "notebooks"
SRC_ROOT = REPO_ROOT / "src"
DATA_ROOT = REPO_ROOT / "data"
RESULTS_DIR = NOTEBOOK_DIR / "results"
JUDGE_OUTPUT_ROOT = REPO_ROOT / "judge_outputs" / "gaussian_mixture"

# Make the modules under `src` importable (needed by the imports that follow).
_src_entry = str(SRC_ROOT)
if _src_entry not in sys.path:
    sys.path.append(_src_entry)

import pgm_tools  # noqa: E402
from pgm_tools import caresl_aggregate, majority_vote, uws_aggregate, ws_aggregate  # noqa: E402
from data_tools import extract_score_from_parsed_output  # noqa: E402

# Show floats with three decimals whenever pandas renders a frame.
pd.options.display.float_format = "{:0.3f}".format

# Baseline `tensor_opts` handed to every caret_aggregate call made through the
# wrapper below; callers may override individual keys.
FAST_TENSOR_OPTS = {
    "max_iters": 6,
    "early_stop_patience": 50,
    "improvement_tol": 1e-3,
}

# Keep a handle on the untouched implementation before it is monkeypatched.
_original_caret = pgm_tools.caret_aggregate


def caret_aggregate_supervised(J, **kwargs):
    """Call the original ``pgm_tools.caret_aggregate`` with fast tensor options.

    A ``tensor_opts`` keyword, when supplied and non-empty, is layered on top
    of ``FAST_TENSOR_OPTS``; every other keyword is forwarded unchanged.
    """
    overrides = kwargs.pop("tensor_opts", {})
    merged_opts = {**FAST_TENSOR_OPTS}
    if overrides:
        merged_opts.update(overrides)
    return _original_caret(J, tensor_opts=merged_opts, **kwargs)


# Install the wrapper both on the module and as the local name.
pgm_tools.caret_aggregate = caret_aggregate_supervised
caret_aggregate = caret_aggregate_supervised

# Fallback list of judge-CSV column names that may hold a score. load_dataset
# consults it, in this order, only when a dataset config has no
# "score_columns" entry; the first column present in the CSV is used.
SCORE_COLUMN_CANDIDATES = [
    "parsed_output",
    "score_original_order",
    "score_ab",
    "pred_label_num",
    "pred_label_binary",
]

def maybe_remap_score_range(series: pd.Series, dataset: str, judge: str) -> pd.Series:
    """Rescale scores that look like a 0..6 scale onto [-3, 3].

    The input is coerced to numeric (unparseable entries become NaN). The
    remap happens only when the observed values have more than one distinct
    value, lie inside [0, 6] and span at least 4.5; the observed minimum maps
    to -3 and the observed maximum to 3. A message is printed when a remap is
    applied. In every other case the coerced values are returned as floats.

    Args:
        series: Raw judge scores.
        dataset: Dataset name, used only in the log message.
        judge: Judge name, used only in the log message.

    Returns:
        A float Series with the same index as ``series``.
    """
    numeric = pd.to_numeric(series, errors="coerce")
    observed = numeric.dropna()
    if observed.empty:
        return numeric.astype(float)

    min_val, max_val = float(observed.min()), float(observed.max())
    distinct = int(observed.nunique())

    # Guard clause: anything that does not look like a 0..6 scale is left alone.
    outside_unit_scale = min_val < 0.0 or max_val > 6.0
    too_narrow = (max_val - min_val) < 4.5
    if distinct <= 1 or outside_unit_scale or too_narrow:
        return numeric.astype(float)

    unit_interval = (numeric - min_val) / (max_val - min_val)
    remapped = unit_interval * 6.0 - 3.0
    print(f"{dataset}/{judge}: remapped score range [{min_val:.2f}, {max_val:.2f}] -> [-3, 3]")
    return remapped.astype(float)


# Normalised (stripped, lower-cased) preference tokens and their binary label:
# 0.0 for the first option, 1.0 for the second.
PREF_LABEL_MAP = {
    "a": 0.0,
    "model_a": 0.0,
    "left": 0.0,
    "b": 1.0,
    "model_b": 1.0,
    "right": 1.0,
}

# Normalised tokens that explicitly mean "no usable preference".
PREF_NULL_VALUES = {"tie", "none", "nan", ""}


def extract_pref_labels(df: pd.DataFrame, column: str) -> pd.Series | None:
    """Turn a textual preference column into 0.0 / 1.0 labels.

    Values are stringified, stripped and lower-cased, then looked up in
    ``PREF_LABEL_MAP``. Tokens in ``PREF_NULL_VALUES`` and anything not in the
    map become NaN.

    Returns:
        A float Series indexed like ``df``, or ``None`` when the column is
        absent or fewer than two distinct non-NaN labels result.
    """
    if column not in df.columns:
        return None

    tokens = df[column].astype(str).str.strip().str.lower()
    is_null_token = tokens.isin(PREF_NULL_VALUES)
    codes = tokens.map(PREF_LABEL_MAP).mask(is_null_token, np.nan)
    codes = pd.to_numeric(codes, errors="coerce")
    labels = pd.Series(codes.to_numpy(dtype=float), index=df.index, dtype=float)

    # A label column is only useful if both classes actually occur.
    if np.unique(labels.dropna()).size < 2:
        return None
    return labels

# One entry per dataset. Keys as consumed by load_dataset() and main():
#   name              - dataset identifier (matched against --datasets, used as cache key).
#   label_source      - "csv": labels are read from `label_path`; any other value:
#                       labels are read from the first judge CSV.
#   label_path        - CSV holding the labels (only read when label_source == "csv").
#   label_column      - column holding the labels.
#   label_transform   - optional callable applied to the numeric label Series.
#   pref_label_column - optional textual preference column used as a fallback when
#                       the labels have at most one distinct value.
#   judge_subdir      - sub-directory of JUDGE_OUTPUT_ROOT with one CSV per judge.
#   min_rating / max_rating - inclusive bounds; rows with any judge score outside
#                       them are dropped.
#   binary_threshold  - scores >= this value count as the positive class.
#   score_columns     - candidate score columns, tried in order.
#   ranks             - forwarded to caret_aggregate as `ranks`.
DATASET_CONFIGS = [
    {"name": "civilcomments", "label_source": "csv", "label_path": DATA_ROOT / "binary/civilcomments.csv", "label_column": "label", "judge_subdir": "civilcomments", "min_rating": 0.0, "max_rating": 9.0, "binary_threshold": 4.5, "score_columns": ["parsed_output"], "ranks": (4, 5, 6, 7)},
    {"name": "yelp", "label_source": "csv", "label_path": DATA_ROOT / "binary/yelp.csv", "label_column": "label", "judge_subdir": "yelp", "min_rating": 0.0, "max_rating": 9.0, "binary_threshold": 4.5, "score_columns": ["parsed_output"], "ranks": (4, 5, 6, 7)},
    {"name": "liar2", "label_source": "csv", "label_path": DATA_ROOT / "binary/liar2.csv", "label_column": "label", "judge_subdir": "liar2", "min_rating": 0.0, "max_rating": 9.0, "binary_threshold": 4.5, "score_columns": ["parsed_output"], "ranks": (4, 5, 6, 7)},
    {"name": "judgebench", "label_source": "judge", "label_column": "gold_label_binary", "judge_subdir": "judgebench", "min_rating": -3.0, "max_rating": 3.0, "binary_threshold": 0.0, "score_columns": ["score_original_order", "score_ab"], "ranks": (4, 5, 6, 7)},
    {"name": "chatbot_arena", "label_source": "judge", "label_column": "gold_label_binary", "pref_label_column": "pref_A_or_B", "judge_subdir": "chatbot_arena_conversations", "min_rating": -3.0, "max_rating": 3.0, "binary_threshold": 0.0, "score_columns": ["score_original_order", "score_ab"], "ranks": (4, 5, 6, 7)},
    {"name": "anthropic_harmless", "label_source": "judge", "label_column": "gold_label_binary", "pref_label_column": "pref_A_or_B", "judge_subdir": "anthropic_harmless", "min_rating": -3.0, "max_rating": 3.0, "binary_threshold": 0.0, "score_columns": ["score_original_order", "score_ab"], "ranks": (4, 5, 6, 7)},
    {"name": "anthropic_helpful", "label_source": "judge", "label_column": "gold_label_binary", "pref_label_column": "pref_A_or_B", "judge_subdir": "anthropic_helpful", "min_rating": -3.0, "max_rating": 3.0, "binary_threshold": 0.0, "score_columns": ["score_original_order", "score_ab"], "ranks": (4, 5, 6, 7)},
    {"name": "summarize", "label_source": "judge", "label_column": "gold_label_binary", "pref_label_column": "pref_A_or_B", "judge_subdir": "summarize", "min_rating": -3.0, "max_rating": 3.0, "binary_threshold": 0.0, "score_columns": ["score_original_order", "score_ab"], "ranks": (4, 5, 6, 7)},
    {"name": "pku_better", "label_source": "judge", "label_column": "gold_label_binary", "pref_label_column": "pref_A_or_B", "judge_subdir": "pku_better", "min_rating": -3.0, "max_rating": 3.0, "binary_threshold": 0.0, "score_columns": ["score_original_order", "score_ab"], "ranks": (4, 5, 6, 7)},
    {"name": "pku_safer", "label_source": "judge", "label_column": "gold_label_binary", "pref_label_column": "pref_A_or_B", "judge_subdir": "pku_safer", "min_rating": -3.0, "max_rating": 3.0, "binary_threshold": 0.0, "score_columns": ["score_original_order", "score_ab"], "ranks": (4, 5, 6, 7)},
    {"name": "shp", "label_source": "judge", "label_column": "gold_label_binary", "pref_label_column": "pref_A_or_B", "judge_subdir": "shp", "min_rating": -3.0, "max_rating": 3.0, "binary_threshold": 0.0, "score_columns": ["score_original_order", "score_ab"], "ranks": (4, 5, 6, 7)},
    {"name": "mtbench_gpt4", "label_source": "judge", "label_column": "gold_label_binary", "pref_label_column": "pref_A_or_B", "judge_subdir": "mtbench_gpt4", "min_rating": -3.0, "max_rating": 3.0, "binary_threshold": 0.0, "score_columns": ["score_original_order", "score_ab"], "ranks": (4, 5, 6, 7)},
    {"name": "mtbench_human", "label_source": "judge", "label_column": "gold_label_binary", "pref_label_column": "pref_A_or_B", "judge_subdir": "mtbench_human", "min_rating": -3.0, "max_rating": 3.0, "binary_threshold": 0.0, "score_columns": ["score_original_order", "score_ab"], "ranks": (4, 5, 6, 7)},
    # NOTE(review): `(series > 0).astype(int)` sends NaN and every non-positive
    # value to 0, so rows with a missing `overall_preference` are no longer NaN
    # when load_dataset filters out missing labels - confirm this is intended.
    {"name": "helpsteer3", "label_source": "csv", "label_path": DATA_ROOT / "score/helpsteer3_sampled.csv", "label_column": "overall_preference", "label_transform": lambda series: (series > 0).astype(int), "judge_subdir": "helpsteer3", "min_rating": -3.0, "max_rating": 3.0, "binary_threshold": 0.0, "score_columns": ["score_original_order", "score_ab"], "ranks": (4, 5, 6, 7)},
]

# Candidate values for the `lam_L` / `lam_S` keyword arguments of
# caret_aggregate; main() tries every (lam_L, lam_S) pair.
LAM_L_GRID = [1e-3, 5e-3, 1e-2, 5e-2, 1e-1]
LAM_S_GRID = [1e-3, 5e-3, 1e-2, 5e-2, 1e-1]

# NOTE: the cache directory is created as a side effect of importing this module.
CACHE_DIR = REPO_ROOT / "outputs" / "gaussian_mixture"
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Default value of the --cache-path command-line option.
DEFAULT_CACHE_PATH = CACHE_DIR / "caret_results.json"


def load_cache(path: Path) -> dict:
    """Load the JSON results cache, falling back to an empty cache.

    Args:
        path: Location of the cache file.

    Returns:
        The cached mapping. An empty dict is returned when the file does not
        exist, cannot be decoded, or does not contain a JSON object; in the
        latter two cases a notice is printed so the loss of cached results is
        visible rather than silent.
    """
    if not path.exists():
        return {}

    try:
        cached = json.loads(path.read_text())
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        print(f"Ignoring unreadable cache file {path}: {exc}")
        return {}

    # Callers index the cache by key, so anything but a JSON object is unusable.
    if not isinstance(cached, dict):
        print(f"Ignoring cache file {path}: expected a JSON object, got {type(cached).__name__}")
        return {}
    return cached


def save_cache(cache: dict, path: Path) -> None:
    """Persist the results cache as pretty-printed, key-sorted JSON.

    The payload is written to a sibling temporary file and then moved over
    ``path``, so an interrupted run cannot leave a truncated cache behind.
    Missing parent directories are created first.

    Args:
        cache: JSON-serialisable cache contents.
        path: Destination file.

    Raises:
        TypeError: If ``cache`` contains a value ``json`` cannot serialise;
            the existing file is left untouched in that case.
    """
    # Serialise before touching the filesystem so a failure changes nothing.
    serialized = json.dumps(cache, indent=2, sort_keys=True)

    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(serialized)
    tmp_path.replace(path)



def _json_default(obj):
    """``json.dumps`` fallback that converts numpy values to builtins.

    Any numpy scalar (integers, floats, booleans, ...) is unwrapped with
    ``.item()`` and arrays are converted with ``.tolist()``.

    Raises:
        TypeError: For every other type, mirroring the stock encoder.
    """
    # np.generic is the common base of all numpy scalars, so np.bool_ is
    # covered as well as the integer and floating types.
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def write_individual_json(directory: Path, dataset: str, payload: dict | None):
    """Write one dataset's cache entry to ``gaussian_mixture_<dataset>.json``.

    Args:
        directory: Existing directory that receives the file.
        dataset: Dataset name; used in the file name and stored under the
            ``"dataset"`` key unless the payload already has that key.
        payload: Entry to write; ``None`` is treated as an empty entry. The
            mapping passed in is copied, not modified.
    """
    data = dict(payload or {})
    data.setdefault("dataset", dataset)
    target_path = directory / f"gaussian_mixture_{dataset}.json"
    target_path.write_text(json.dumps(data, indent=2, sort_keys=True, default=_json_default))


def _maybe_flip_all_positive_labels(
    labels: np.ndarray,
    judge_df: pd.DataFrame,
    cfg: dict,
) -> tuple[np.ndarray, pd.DataFrame]:
    """Post-processing hook for labels and scores; currently a no-op.

    Returns ``labels`` and ``judge_df`` unchanged; ``cfg`` is not used.
    """
    return labels, judge_df


def _binarize_by_sign(series: pd.Series) -> pd.Series:
    """Map values > 0 to 1.0 and other values to 0.0, keeping NaN as NaN."""
    # A plain `(series > 0).astype(int)` would turn missing labels into class 0
    # and hide them from the NaN filter applied later in load_dataset.
    return (series > 0).astype(float).where(series.notna())


def _observed_label_values(series: pd.Series) -> set:
    """Return the set of distinct non-NaN values of ``series`` as floats."""
    return set(np.unique(series.dropna().to_numpy(dtype=float)))


def _is_signed_binary(values: set) -> bool:
    """Tell whether ``values`` is within {-1, 0, 1} and contains -1 or 1."""
    return values.issubset({-1.0, 0.0, 1.0}) and bool(values.intersection({-1.0, 1.0}))


def _load_raw_labels(df: pd.DataFrame, cfg: dict) -> pd.Series:
    """Read the configured label column and apply the optional transform."""
    if cfg["label_source"] == "csv":
        source = pd.read_csv(cfg["label_path"])
    else:
        source = df
    labels = pd.to_numeric(source[cfg["label_column"]], errors="coerce")
    transform = cfg.get("label_transform")
    if transform is not None:
        labels = transform(labels)
    return labels


def _repair_labels(label_series: pd.Series, df: pd.DataFrame, cfg: dict) -> pd.Series:
    """Normalise labels to 0/1 and fall back to alternative columns of ``df``.

    Nothing is done when there is no non-NaN label at all. Otherwise:
    all-zero labels are replaced by the sign of ``gold_label_num`` when that
    column has data; labels within {-1, 0, 1} are binarised by sign; and if at
    most one distinct label remains, the configured ``pref_label_column`` is
    used when it yields both classes. Missing labels stay NaN throughout.
    """
    observed = _observed_label_values(label_series)
    if not observed:
        return label_series

    if observed == {0.0} and "gold_label_num" in df.columns:
        numeric_gold = pd.to_numeric(df["gold_label_num"], errors="coerce")
        if numeric_gold.notna().any():
            label_series = _binarize_by_sign(numeric_gold)
            observed = _observed_label_values(label_series)

    if _is_signed_binary(observed):
        label_series = _binarize_by_sign(label_series)

    if label_series.nunique(dropna=True) > 1:
        return label_series

    pref_column = cfg.get("pref_label_column")
    if not pref_column:
        return label_series
    pref_series = extract_pref_labels(df, pref_column)
    if pref_series is None or pref_series.nunique(dropna=True) <= 1:
        return label_series

    # Ties come back from extract_pref_labels as NaN and must stay NaN here.
    if _is_signed_binary(_observed_label_values(pref_series)):
        pref_series = _binarize_by_sign(pref_series)
    return pref_series


def _read_judge_scores(df: pd.DataFrame, cfg: dict) -> pd.Series | None:
    """Return numeric scores from the first configured column present in ``df``.

    ``parsed_output`` values that are not plain numbers are passed through
    ``extract_score_from_parsed_output``. Returns ``None`` when none of the
    candidate columns exists.
    """
    for column in cfg.get("score_columns", SCORE_COLUMN_CANDIDATES):
        if column not in df.columns:
            continue
        scores = pd.to_numeric(df[column], errors="coerce")
        if column == "parsed_output" and scores.isna().any():
            parsed = df[column].apply(
                lambda x: extract_score_from_parsed_output(
                    x,
                    min_rating=cfg["min_rating"],
                    max_rating=cfg["max_rating"],
                )
            )
            scores = pd.to_numeric(parsed, errors="coerce")
        return scores
    return None


def load_dataset(cfg: dict):
    """Load labels and per-judge scores for one dataset configuration.

    Every ``*.csv`` in the dataset's judge directory is one judge. Rows are
    kept only when the label is present and every judge score is present and
    within ``[min_rating, max_rating]``.

    Args:
        cfg: One entry of ``DATASET_CONFIGS``.

    Returns:
        ``(labels, judge_df, judge_names)``: an int array of labels, a
        DataFrame with one float column per judge (row-aligned with the
        labels), and the list of its column names.

    Raises:
        FileNotFoundError: If the judge directory holds no CSV files.
        ValueError: If the judge CSVs differ in row count or none of them
            provides a usable score column.
    """
    judge_dir = JUDGE_OUTPUT_ROOT / cfg["judge_subdir"]
    csv_paths = sorted(judge_dir.glob("*.csv"))
    if not csv_paths:
        raise FileNotFoundError(f"No judge outputs found in {judge_dir}")

    min_rating = cfg["min_rating"]
    max_rating = cfg["max_rating"]
    dataset_label = cfg.get("name", cfg.get("judge_subdir", ""))
    # The remap targets [-3, 3]; applying it to a non-negative scale (e.g.
    # 0..9) would push valid scores below min_rating and get their rows dropped.
    remap_allowed = min_rating < 0.0

    judge_columns = {}
    label_series = None
    expected_len = None

    for csv_path in csv_paths:
        df = pd.read_csv(csv_path)
        if expected_len is None:
            expected_len = len(df)
        elif len(df) != expected_len:
            raise ValueError(f"Row count mismatch for {csv_path.name}: {len(df)} vs {expected_len}")

        if label_series is None:
            label_series = _load_raw_labels(df, cfg)
        # Re-run for every judge file: a later CSV may carry the fallback
        # columns that an earlier one lacked.
        label_series = _repair_labels(label_series, df, cfg)

        score_series = _read_judge_scores(df, cfg)
        if score_series is None:
            continue

        judge_name = csv_path.stem.replace("_prefs", "")
        if remap_allowed:
            score_series = maybe_remap_score_range(score_series, dataset_label, judge_name)
        else:
            score_series = score_series.astype(float)
        judge_columns[judge_name] = score_series

    judge_df = pd.DataFrame(judge_columns)
    if judge_df.empty:
        raise ValueError(f"No usable judge scores found in {judge_dir}")

    within_bounds = judge_df.ge(min_rating) & judge_df.le(max_rating)
    mask = (
        (~label_series.isna())
        & within_bounds.all(axis=1)
        & judge_df.notna().all(axis=1)
    )

    labels = label_series[mask].astype(int).reset_index(drop=True)
    judge_df = judge_df[mask].reset_index(drop=True)

    label_array = labels.to_numpy(dtype=int)
    label_array, judge_df = _maybe_flip_all_positive_labels(label_array, judge_df, cfg)

    return label_array.astype(int, copy=False), judge_df, list(judge_df.columns)


def parse_args(argv: list[str] | None = None):
    """Parse the command-line options of this script.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Run Gaussian mixture aggregation with raw judge scores.")
    parser.add_argument("--datasets", nargs="+", help="Optional subset of dataset names to run.")
    parser.add_argument("--use-cache", action="store_true", help="Reuse cached results when available.")
    # The default is resolved in main() and lives under the results directory.
    parser.add_argument("--output", type=Path, help="Path for the metrics CSV (defaults to notebooks/results/gaussian_mixture_results.csv).")
    parser.add_argument("--judge-summary", type=Path, help="Optional path to write the dataset-to-judge mapping CSV.")
    parser.add_argument("--cache-path", type=Path, default=DEFAULT_CACHE_PATH, help="Location of the caret cache JSON.")
    parser.add_argument("--individual-cache-dir", type=Path, help="Directory for per-dataset JSON cache outputs.")
    return parser.parse_args(argv)


def _utc_timestamp() -> str:
    """Return the current UTC time as a naive ISO-8601 string (second precision)."""
    # Local import: only `datetime` itself is imported at file level.
    from datetime import timezone

    # Same text as the deprecated datetime.utcnow().isoformat(timespec="seconds").
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat(timespec="seconds")


def _record_failure(
    cache_data: dict,
    cache_path: Path,
    individual_dir: Path,
    skipped_datasets: dict[str, str],
    ds_name: str,
    reason: str,
    judges: list[str] | None = None,
) -> None:
    """Store an error entry for ``ds_name`` in both caches and the skip list."""
    entry = {"status": "error", "error": reason, "timestamp": _utc_timestamp()}
    if judges is not None:
        entry["judges"] = judges
    cache_data["datasets"][ds_name] = entry
    save_cache(cache_data, cache_path)
    write_individual_json(individual_dir, ds_name, entry)
    skipped_datasets[ds_name] = reason


def main(argv: list[str] | None = None) -> int:
    """Run the aggregation pipeline for the selected datasets.

    For each dataset the judge scores are loaded, a grid search over
    ``LAM_L_GRID`` x ``LAM_S_GRID`` picks the caret_aggregate predictions with
    the best accuracy on a 10% validation split, and all aggregators are then
    scored on the remaining 90%. Results are cached after every dataset and
    finally written as CSV files.

    Args:
        argv: Command-line arguments; ``None`` uses ``sys.argv[1:]``.

    Returns:
        Process exit code (always 0; failures of single datasets are recorded
        and reported, not raised).

    Raises:
        SystemExit: If ``--datasets`` names an unknown dataset.
    """
    args = parse_args(argv)
    force_rerun = not args.use_cache

    results_dir = RESULTS_DIR
    results_dir.mkdir(parents=True, exist_ok=True)

    output_path = args.output or (results_dir / "gaussian_mixture_results.csv")
    judge_summary_path = args.judge_summary or (results_dir / "gaussian_mixture_judges.csv")
    individual_dir = args.individual_cache_dir or results_dir
    cache_path = args.cache_path
    # Create every output location up front, including the cache's parent,
    # which may be a user-supplied path that does not exist yet.
    for directory in (output_path.parent, judge_summary_path.parent, individual_dir, cache_path.parent):
        directory.mkdir(parents=True, exist_ok=True)

    selected_configs = DATASET_CONFIGS
    if args.datasets:
        requested = set(args.datasets)
        missing = requested - {cfg["name"] for cfg in DATASET_CONFIGS}
        if missing:
            raise SystemExit(f"Unknown dataset(s): {', '.join(sorted(missing))}")
        # Preserve the notebook ordering for any requested subset
        selected_configs = [cfg for cfg in DATASET_CONFIGS if cfg["name"] in requested]

    cache_data = load_cache(cache_path)
    if "__meta__" not in cache_data:
        cache_data["__meta__"] = {
            "notebook": "gaussian_mixture_main_raw_scores",
            "created": _utc_timestamp(),
        }
    dataset_cache = cache_data.setdefault("datasets", {})

    results: list[dict] = []
    judge_registry: dict[str, list[str]] = {}
    skipped_datasets: dict[str, str] = {}

    for cfg in selected_configs:
        ds_name = cfg["name"]
        cached_entry = None if force_rerun else dataset_cache.get(ds_name)

        if cached_entry:
            status = cached_entry.get("status")
            if status == "ok":
                metrics = cached_entry.get("metrics")
                if metrics:
                    print(f"{ds_name}: using cached results")
                    results.append(metrics)
                    judge_registry[ds_name] = cached_entry.get("judges", [])
                    write_individual_json(individual_dir, ds_name, cached_entry)
                    continue
            elif status == "error":
                reason = cached_entry.get("error", "previous failure")
                print(f"{ds_name}: skipping (cached error: {reason})")
                skipped_datasets[ds_name] = reason
                write_individual_json(individual_dir, ds_name, cached_entry)
                continue

        try:
            y, judge_df, judge_names = load_dataset(cfg)
        except Exception as exc:  # pragma: no cover - defensive logging
            reason = str(exc)
            print(f"{ds_name}: skipped during load -> {reason}")
            _record_failure(cache_data, cache_path, individual_dir, skipped_datasets, ds_name, reason)
            continue

        judge_registry[ds_name] = judge_names

        print(f"{ds_name}: {len(y)} examples, {len(judge_names)} judges")
        print("Judges:", ", ".join(judge_names))

        reason = None
        if judge_df.shape[1] < 3:
            reason = "CARET requires at least 3 judges to form groups"
        elif len(y) == 0:
            # Nothing to split or score; np.mean / accuracy on empty arrays
            # would otherwise produce NaNs or raise outside any handler.
            reason = "no examples left after filtering labels and scores"
        if reason is not None:
            print(f"  skipping {ds_name}: {reason}")
            _record_failure(
                cache_data, cache_path, individual_dir, skipped_datasets, ds_name, reason, judge_names
            )
            continue

        pos_rate = float(np.clip(np.mean(y), 0.0, 1.0))
        class_balance = float(np.clip((1.0 - pos_rate) * 100.0, 0.0, 100.0))

        indices = np.arange(len(y))
        if np.unique(y).size > 1:
            idx_rest, idx_val = train_test_split(
                indices,
                test_size=0.10,
                random_state=42,
                stratify=y,
            )
        else:
            # Stratification is impossible with a single class: use a plain
            # seeded shuffle instead.
            val_size = max(1, int(round(0.10 * len(y))))
            perm = np.random.default_rng(42).permutation(indices)
            idx_val = perm[:val_size]
            idx_rest = perm[val_size:]
            if idx_rest.size == 0:
                idx_rest = idx_val

        # Everything outside the validation split is used for reporting.
        idx_test = idx_rest

        threshold = cfg["binary_threshold"]
        ranks = cfg.get("ranks", (4, 5, 6, 7))
        judge_binary = (judge_df >= threshold).astype(int)
        judge_scores_matrix = judge_df.to_numpy(dtype=float)

        best_acc = -1.0
        best_params: tuple[float | None, float | None] = (None, None)
        best_preds = None
        last_error = None
        for lam_L in LAM_L_GRID:
            for lam_S in LAM_S_GRID:
                try:
                    preds = caret_aggregate(
                        judge_scores_matrix,
                        lam_L=lam_L,
                        lam_S=lam_S,
                        class_balance=class_balance,
                        ranks=ranks,
                        tensor_opts=FAST_TENSOR_OPTS,
                    )
                except Exception as exc:  # pragma: no cover - defensive logging
                    last_error = str(exc)
                    continue

                preds = np.asarray(preds, dtype=int)
                val_acc = accuracy_score(y[idx_val], preds[idx_val])
                if val_acc > best_acc:
                    best_acc = val_acc
                    best_params = (lam_L, lam_S)
                    best_preds = preds

        if best_preds is None:
            reason = last_error or "caret_aggregate did not return any predictions"
            print(f"  skipping {ds_name}: {reason}")
            _record_failure(
                cache_data, cache_path, individual_dir, skipped_datasets, ds_name, reason, judge_names
            )
            continue

        caret_pred = np.asarray(best_preds, dtype=int)
        mv_pred = np.asarray(majority_vote(judge_binary), dtype=int)
        ws_pred = np.asarray(ws_aggregate(judge_binary), dtype=int)

        avg_scores = judge_df.mean(axis=1)
        avg_pred = np.asarray((avg_scores >= threshold), dtype=int)

        uws_scores = np.asarray(uws_aggregate(judge_df), dtype=float)
        uws_pred = np.asarray((uws_scores >= threshold), dtype=int)

        caresl_scores = np.asarray(caresl_aggregate(judge_df), dtype=float)
        caresl_pred = np.asarray((caresl_scores >= threshold), dtype=int)

        def acc(pred):
            """Accuracy of ``pred`` on the reporting split."""
            arr = np.asarray(pred, dtype=int)
            return float((arr[idx_test] == y[idx_test]).mean())

        metrics = {
            "dataset": ds_name,
            "n_examples": int(len(y)),
            "n_judges": int(judge_df.shape[1]),
            "mv": acc(mv_pred),
            "avg": acc(avg_pred),
            "ws": acc(ws_pred),
            "uws": acc(uws_pred),
            "caresl": acc(caresl_pred),
            "caret": acc(caret_pred),
            "lam_L": best_params[0],
            "lam_S": best_params[1],
            "val_acc": best_acc,
            "val_size": int(idx_val.size),
            "test_size": int(idx_test.size),
            "class_balance": class_balance,
        }

        results.append(metrics)

        dataset_cache[ds_name] = {
            "status": "ok",
            "metrics": metrics,
            "judges": judge_names,
            "timestamp": _utc_timestamp(),
            "hyperparams": {
                "lam_L_grid": list(LAM_L_GRID),
                "lam_S_grid": list(LAM_S_GRID),
                "ranks": list(ranks),
            },
        }
        save_cache(cache_data, cache_path)
        write_individual_json(individual_dir, ds_name, dataset_cache[ds_name])

        print(
            "  best val acc:",
            f"{best_acc:.3f}",
            "with lam_L",
            best_params[0],
            "lam_S",
            best_params[1],
        )

    if results:
        result_df = pd.DataFrame(results).sort_values("dataset")
        result_df.to_csv(output_path, index=False)
        print(f"Wrote metrics to {output_path}")
    else:
        print("No results to write.")

    if judge_registry:
        registry_df = pd.DataFrame(
            {
                "dataset": list(judge_registry.keys()),
                "judges": [", ".join(v) for v in judge_registry.values()],
            }
        ).sort_values("dataset")
        registry_df.to_csv(judge_summary_path, index=False)
        print(f"Wrote judge registry to {judge_summary_path}")

    if skipped_datasets:
        print("Skipped datasets:")
        for name, reason in skipped_datasets.items():
            print(f"  {name}: {reason}")

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
