"""Utilities for binary-judge aggregation experiments and validation sweeps."""

from __future__ import annotations

import itertools
from functools import partial
from pathlib import Path
from typing import Callable, Sequence, Tuple

import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from snorkel.labeling.model import LabelModel

from data_tools import collect_judge_outputs, extract_score_from_parsed_output, is_valid_binary_score
from pgm_tools import caresl_aggregate, caret_aggregate, majority_vote, uws_aggregate, ws_aggregate

# Parent of the directory that contains this module (``parents[1]`` of its resolved path).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Default root for the Gaussian-mixture judge outputs used by the yelp/liar2 loaders.
GAUSSIAN_JUDGE_ROOT = PROJECT_ROOT.joinpath('judge_outputs', 'gaussian_mixture')


# Prediction columns probed, first match wins, when the requested one is absent.
PRED_COLUMN_CANDIDATES = [
    'pred_label_binary',
    'pred_label_num',
    'parsed_output',
    'prediction',
    'pred',
]
# Gold-label columns probed in this order after any explicitly requested column.
GOLD_COLUMN_CANDIDATES = [
    'pref_A_or_B',
    'gold_label_binary',
    'gold_label_num',
    'label',
    'gold_label',
]
# Stripped, lower-cased label tokens -> numeric code; ties and "none" become NaN.
STRING_TO_FLOAT = {
    **dict.fromkeys(('a', 'model_a', 'left'), 0.0),
    **dict.fromkeys(('b', 'model_b', 'right'), 1.0),
    **dict.fromkeys(('tie', 'none'), float('nan')),
}

def normalize_series(series: pd.Series) -> pd.Series:
    if series.dtype.kind in 'biufc':
        return pd.to_numeric(series, errors='coerce')
    lower = series.astype(str).str.strip().str.lower()
    return lower.map(STRING_TO_FLOAT)


def _balance_gold_if_constant(
    judge_df: pd.DataFrame,
    gold_series: pd.Series,
    *,
    seed: int | None = None,
    fraction: float = 0.0,
) -> tuple[pd.DataFrame, pd.Series]:
    del seed, fraction
    return judge_df, gold_series


def _select_gold_series(
    base_df: pd.DataFrame,
    ids: pd.Index,
    primary: str | None,
) -> tuple[pd.Series, str]:
    """Choose the gold-label column of ``base_df``.

    Columns are tried in order: ``primary`` (when given), then the entries of
    ``GOLD_COLUMN_CANDIDATES``.  Each present column is normalised with
    ``normalize_series`` and re-indexed by ``ids``; all-NaN columns are
    skipped.  The first column with at least two distinct values wins.

    A constant ``gold_label_binary`` column is re-derived as
    ``gold_label_num > 0`` (missing entries stay NaN) and used if that yields
    two classes.  If no column varies, the first usable constant column is
    returned as a fallback.

    Returns:
        ``(series, column_name)``: the float gold series indexed by ``ids`` and
        the name of the column its values came from.

    Raises:
        KeyError: If no candidate column has any non-missing value.
    """
    candidates: list[str] = [primary] if primary else []
    for col in GOLD_COLUMN_CANDIDATES:
        if col not in candidates:
            candidates.append(col)

    def as_float(column: str) -> pd.Series:
        # Normalise one column and index it by the dataset's row ids.
        values = normalize_series(base_df[column]).to_numpy(dtype=float)
        return pd.Series(values, index=ids, dtype=float)

    def n_distinct(series: pd.Series) -> int:
        return pd.unique(series.dropna()).size

    fallback: tuple[pd.Series, str] | None = None
    for col in candidates:
        if col not in base_df.columns:
            continue
        series = as_float(col)
        if series.notna().sum() == 0:
            continue
        if n_distinct(series) > 1:
            return series, col

        source = col
        if col == 'gold_label_binary' and 'gold_label_num' in base_df.columns:
            derived = as_float('gold_label_num')
            if derived.notna().any():
                # ``NaN > 0`` is False: without ``where`` every missing
                # gold_label_num would silently become a class-0 label.
                rebuilt = (derived > 0).astype(float).where(derived.notna())
                if n_distinct(rebuilt) > 1:
                    return rebuilt, 'gold_label_num'
                series, source = rebuilt, 'gold_label_num'
        if fallback is None:
            fallback = (series, source)

    if fallback is not None:
        return fallback
    raise KeyError('No usable gold label column found; inspected columns: ' + ', '.join(candidates))


def load_binary_dataset(
    dir_path: Path,
    *,
    id_column: str | None = None,
    gold_column: str | None = 'gold_label_binary',
    pred_column: str | None = 'pred_label_binary',
    label_flip_seed: int = 0,
) -> Tuple[pd.DataFrame, pd.Series]:
    """Load every judge CSV in ``dir_path`` into a (judge votes, gold labels) pair.

    The first CSV (sorted by name) supplies the row ids and the gold labels,
    which are chosen by ``_select_gold_series`` starting from ``gold_column``.
    Every CSV with a prediction column contributes one judge column. That
    column is ``pred_column`` if present, otherwise the first match in
    ``PRED_COLUMN_CANDIDATES``. The judge column is named after the file stem,
    with any trailing ``_binary`` removed, and is aligned to the first file's
    ids. Entirely-missing judge columns and rows with a missing gold label are
    dropped.

    Args:
        dir_path: Directory holding ``*.csv`` judge outputs.
        id_column: Column used to align rows across files; when a file lacks
            it, rows are aligned by their index labels.
        gold_column: Preferred gold-label column.
        pred_column: Preferred prediction column.
        label_flip_seed: Base seed forwarded (plus a stable per-directory
            offset) to ``_balance_gold_if_constant``.

    Returns:
        ``(judge_df, gold_series)`` with matching, freshly reset indices.

    Raises:
        FileNotFoundError: ``dir_path`` contains no CSV files.
        ValueError: No file had a usable prediction column, or all were empty.
        KeyError: No usable gold column (raised by ``_select_gold_series``).
    """
    import zlib

    dir_path = Path(dir_path)
    csv_paths = sorted(dir_path.glob('*.csv'))
    if not csv_paths:
        raise FileNotFoundError(f'No judge outputs found in {dir_path}')

    def row_ids(frame: pd.DataFrame) -> pd.Index:
        # Align on ``id_column`` when the file has it, otherwise on index labels.
        if id_column and id_column in frame.columns:
            return pd.Index(frame[id_column].astype(str))
        return pd.Index(frame.index.astype(str))

    base_df = pd.read_csv(csv_paths[0])
    ids = row_ids(base_df)
    gold_series, _ = _select_gold_series(base_df, ids, gold_column)

    columns: dict[str, pd.Series] = {}
    for csv_path in csv_paths:
        df = pd.read_csv(csv_path)
        if pred_column and pred_column in df.columns:
            pred_col = pred_column
        else:
            pred_col = next((col for col in PRED_COLUMN_CANDIDATES if col in df.columns), None)
        if pred_col is None:
            continue

        pred_values = normalize_series(df[pred_col])
        # NOTE(review): duplicate ids make ``reindex`` raise; ids are assumed unique.
        pred_values.index = row_ids(df)
        stem = csv_path.stem
        if stem.endswith('_binary'):
            stem = stem[: -len('_binary')]
        columns[stem] = pred_values.reindex(ids)

    if not columns:
        raise ValueError(f'No usable judge columns found in {dir_path}')

    judge_df = pd.DataFrame(columns).loc[ids].reset_index(drop=True)
    gold_series = gold_series.reindex(ids).reset_index(drop=True)
    judge_df = judge_df.dropna(axis=1, how='all')
    if judge_df.empty:
        raise ValueError(f'All judge columns empty after loading {dir_path}')

    judge_df = judge_df.apply(pd.to_numeric, errors='coerce')
    gold_series = pd.to_numeric(gold_series, errors='coerce')

    has_gold = gold_series.notna()
    judge_df = judge_df.loc[has_gold].reset_index(drop=True)
    gold_series = gold_series.loc[has_gold].reset_index(drop=True)

    # ``hash(str)`` is salted per interpreter run (PYTHONHASHSEED), so derive
    # the per-directory offset from a stable checksum to keep seeds reproducible.
    seed_offset = zlib.crc32(dir_path.as_posix().encode('utf-8')) % 1_000_000_000
    return _balance_gold_if_constant(
        judge_df,
        gold_series,
        seed=label_flip_seed + seed_offset,
    )


def load_helpsteer3_bundle(judge_root: Path):
    """Load the judge bundle stored under ``judge_root / 'helpsteer3'``.

    Uses ``load_binary_dataset`` defaults: no id column, so rows align by
    position, and the default gold/prediction columns.
    """
    helpsteer_dir = judge_root / 'helpsteer3'
    return load_binary_dataset(helpsteer_dir)


def load_judgebench_bundle(judge_root: Path):
    """Load the judge bundle under ``judge_root / 'judgebench'``, keyed by ``pair_id``."""
    options = {
        'id_column': 'pair_id',
        'gold_column': 'gold_label_binary',
        'pred_column': 'pred_label_binary',
    }
    return load_binary_dataset(judge_root / 'judgebench', **options)


def load_chatbot_arena_bundle(judge_root: Path):
    """Load the bundle under ``judge_root / 'chatbot_arena_conversations'``, keyed by ``question_id``."""
    arena_dir = judge_root / 'chatbot_arena_conversations'
    return load_binary_dataset(
        arena_dir,
        id_column='question_id',
        gold_column='gold_label_binary',
        pred_column='pred_label_binary',
    )


def load_allenai_bundle(judge_root: Path, split: str):
    """Load the judge bundle stored under ``judge_root / split``.

    Rows are aligned across judge files by the ``id`` column, and the binary
    gold/prediction columns are requested.
    """
    split_dir = judge_root / split
    return load_binary_dataset(
        split_dir,
        id_column='id',
        gold_column='gold_label_binary',
        pred_column='pred_label_binary',
    )


def load_civilcomments_bundle(judge_root: Path):
    """Load the bundle under ``judge_root / 'civilcomments'``.

    Prefers ``label`` for gold and ``parsed_output`` for predictions; no id
    column is given, so rows align by position.
    """
    civil_dir = judge_root / 'civilcomments'
    return load_binary_dataset(civil_dir, pred_column='parsed_output', gold_column='label')


def _load_gaussian_binary_dataset(
    dataset: str,
    *,
    threshold: float,
    gaussian_root: Path | None = None,
) -> tuple[pd.DataFrame, pd.Series]:
    """Load thresholded judge scores and gold labels for a Gaussian-mixture dataset.

    Each ``*.csv`` under ``<root>/<dataset>`` contributes one judge column
    named after its file stem.  Scores are extracted from ``parsed_output``
    (with ``min_rating=0.0``, ``max_rating=9.0``) and binarised as
    ``score >= threshold``.  Scores that cannot be parsed stay NaN, and such
    rows are dropped together with their gold label.  All files must carry an
    identical ``label`` column.

    Args:
        dataset: Sub-directory name under the Gaussian judge root.
        threshold: Score at or above which a vote counts as 1.
        gaussian_root: Override for ``GAUSSIAN_JUDGE_ROOT``.

    Returns:
        ``(judge_df, gold_series)`` of 0/1 floats with reset indices.

    Raises:
        FileNotFoundError: Missing directory or no CSV files in it.
        ValueError: Label columns differ between files, or no complete rows remain.
    """
    root = Path(gaussian_root) if gaussian_root is not None else GAUSSIAN_JUDGE_ROOT
    judge_dir = root / dataset
    if not judge_dir.exists():
        raise FileNotFoundError(f'Gaussian judge outputs not found for {dataset!r} under {root}')

    columns: dict[str, pd.Series] = {}
    label_series: pd.Series | None = None

    for csv_path in sorted(judge_dir.glob('*.csv')):
        df = pd.read_csv(csv_path)
        scores = pd.to_numeric(
            df['parsed_output'].apply(
                lambda x: extract_score_from_parsed_output(x, min_rating=0.0, max_rating=9.0)
            ),
            errors='coerce',
        )
        # ``NaN >= threshold`` is False; without ``where`` an unparseable score
        # would silently become a 0 vote and survive the completeness mask below.
        columns[csv_path.stem] = (scores >= threshold).astype(float).where(scores.notna())

        labels = df['label'].astype(int)
        if label_series is None:
            label_series = labels
        elif not labels.equals(label_series):
            raise ValueError(
                f'Label ordering mismatch detected for {dataset!r} between {csv_path.name} and previous files'
            )

    if label_series is None:
        raise FileNotFoundError(f'No CSV files found for {dataset!r} in {judge_dir}')

    judge_df = pd.DataFrame(columns).apply(pd.to_numeric, errors='coerce')
    gold_series = label_series.astype(float).reset_index(drop=True)

    complete = judge_df.notna().all(axis=1)
    judge_df = judge_df.loc[complete].reset_index(drop=True)
    gold_series = gold_series.loc[complete].reset_index(drop=True)

    if judge_df.empty:
        raise ValueError(f'All rows removed when assembling {dataset!r} binary judge bundle')

    return judge_df, gold_series


def load_yelp_binary_bundle(
    gaussian_root: Path | None = None,
    *,
    threshold: float = 4.5,
) -> tuple[pd.DataFrame, pd.Series]:
    """Load the ``yelp`` Gaussian-mixture bundle, binarising scores at ``threshold``."""
    options = {'threshold': threshold, 'gaussian_root': gaussian_root}
    return _load_gaussian_binary_dataset('yelp', **options)


def load_liar2_binary_bundle(
    gaussian_root: Path | None = None,
    *,
    threshold: float = 4.5,
) -> tuple[pd.DataFrame, pd.Series]:
    """Load the ``liar2`` Gaussian-mixture bundle, binarising scores at ``threshold``."""
    options = {'threshold': threshold, 'gaussian_root': gaussian_root}
    return _load_gaussian_binary_dataset('liar2', **options)


# Dataset name -> ``loader(judge_root)`` returning ``(judge_df, gold_series)``.
# The trailing entries share one loader and differ only in the split sub-directory.
PREFERENCE_VALIDATION_LOADERS = {
    'helpsteer3': load_helpsteer3_bundle,
    'judgebench': load_judgebench_bundle,
    'chatbot_arena_conversations': load_chatbot_arena_bundle,
    **{
        split_name: partial(load_allenai_bundle, split=split_name)
        for split_name in (
            'anthropic_harmless',
            'anthropic_helpful',
            'summarize',
            'pku_better',
            'pku_safer',
            'shp',
            'mtbench_human',
            'mtbench_gpt4',
        )
    },
}


def sanitize_bundle(judge_df: pd.DataFrame, gold_series: pd.Series):
    """Clean a (judge votes, gold labels) pair for aggregation.

    Steps:
    1. Coerce judge values to numbers and treat +/-inf as missing.
    2. Align judges and gold by position and truncate to the shorter length.
    3. Drop rows where any judge value is missing.
    4. Drop judge columns whose standard deviation is at most 1e-8.

    Both inputs are re-indexed positionally first, so a judge frame with a
    non-default index no longer fails with an unalignable boolean mask.

    Args:
        judge_df: One column per judge.
        gold_series: Gold labels, any array-like, row-aligned by position.

    Returns:
        ``(judge_df, gold_series)`` with fresh 0..n-1 indices.

    Raises:
        ValueError: If no judge column survives the constant-column filter.
    """
    judge_df = judge_df.apply(pd.to_numeric, errors='coerce')
    judge_df = judge_df.replace([np.inf, -np.inf], np.nan).reset_index(drop=True)
    gold_series = pd.Series(gold_series).reset_index(drop=True)

    n_rows = min(len(judge_df), len(gold_series))
    judge_df = judge_df.iloc[:n_rows]
    gold_series = gold_series.iloc[:n_rows]

    # Positional mask, so alignment never depends on index labels.
    complete = judge_df.notna().all(axis=1).to_numpy()
    judge_df = judge_df.loc[complete].reset_index(drop=True)
    gold_series = gold_series.loc[complete].reset_index(drop=True)

    spread = judge_df.std(axis=0, numeric_only=True)
    judge_df = judge_df[spread[spread > 1e-8].index]
    if judge_df.empty:
        raise ValueError('All judge columns became constant after cleaning.')

    return judge_df, gold_series


def discretize_scores(judge_df: pd.DataFrame, label_space: str):
    """Encode signed judge scores as integer class codes.

    For ``label_space='binary'``: positive -> 1, negative -> 0, and zero or
    missing -> NaN.

    Returns:
        ``(encoded, inverse_map)`` where ``inverse_map`` converts codes back to
        the ±1 convention (``{0: -1, 1: 1}``).

    Raises:
        ValueError: For any ``label_space`` other than ``'binary'``.
    """
    if label_space != 'binary':
        raise ValueError(f'Unsupported label space: {label_space}')

    def encode(value):
        if pd.isna(value):
            return np.nan
        if value > 0:
            return 1
        if value < 0:
            return 0
        return np.nan

    inverse_map = {0: -1, 1: 1}
    # ``DataFrame.applymap`` is deprecated since pandas 2.1 in favour of the
    # equivalent ``DataFrame.map``; fall back only on older pandas.
    if hasattr(pd.DataFrame, 'map'):
        encoded = judge_df.map(encode)
    else:
        encoded = judge_df.applymap(encode)
    return encoded, inverse_map


def decode_codes(codes, inverse_map):
    """Translate integer class codes back into label values.

    Args:
        codes: 1-D sequence of codes; NaN entries are treated as missing.
        inverse_map: Code -> value table, e.g. ``{0: -1, 1: 1}``.

    Returns:
        A float array of decoded values.  Missing codes decode to NaN.  Codes
        absent from ``inverse_map`` also decode to NaN instead of raising
        ``KeyError``; an example is the ``-1`` abstain code that Snorkel's
        ``LabelModel.predict`` emits for ties.
    """
    flat = np.asarray(codes)
    decoded = np.full(len(flat), np.nan, dtype=float)
    for position, code in enumerate(flat):
        if pd.isna(code):
            continue
        value = inverse_map.get(int(code))
        if value is not None:
            decoded[position] = value
    return decoded


def values_to_pref(values, label_space: str):
    """Map real-valued scores to ±1 preferences by sign.

    Positive -> 1.0, negative -> -1.0, zero or NaN -> NaN.  Only the
    ``'binary'`` label space is supported; anything else raises ``ValueError``.
    """
    scores = np.asarray(values, dtype=float)
    if label_space != 'binary':
        raise ValueError(f'Unsupported label space: {label_space}')
    prefs = np.full(scores.shape, np.nan)
    prefs[scores > 0] = 1.0
    prefs[scores < 0] = -1.0
    return prefs


def run_label_model(discrete_df: pd.DataFrame, label_space: str, seed: int = 123):
    """Fit a Snorkel ``LabelModel`` on judge codes and return its predictions.

    Args:
        discrete_df: One column per judge holding class codes; NaN marks a
            missing vote.
        label_space: Only ``'binary'`` (cardinality 2) is supported.
        seed: Seed passed to ``LabelModel.fit``.

    Returns:
        The array from ``LabelModel.predict``.  With Snorkel's default
        tie-break policy this can contain ``-1`` (abstain) for tied rows.

    Raises:
        ValueError: For an unsupported ``label_space``.
    """
    cardinality = 2 if label_space == 'binary' else None
    if cardinality is None:
        raise ValueError(f'Unsupported label space: {label_space}')
    # Snorkel encodes abstentions as -1; casting NaN straight to int would
    # produce an arbitrary large negative value instead.
    label_matrix = discrete_df.fillna(-1).to_numpy(dtype=int)
    label_model = LabelModel(cardinality=cardinality, verbose=False)
    label_model.fit(label_matrix, n_epochs=1000, log_freq=100, seed=seed)
    return label_model.predict(label_matrix)


def pm1_to_binary(arr):
    """Convert ±1 (or any signed) values to 0/1 floats, keeping NaN as NaN.

    Values > 0 become 1.0, everything else that is not NaN becomes 0.0.  The
    input shape is preserved.
    """
    signed = np.asarray(arr, dtype=float)
    return np.where(np.isnan(signed), np.nan, (signed > 0).astype(float))


def compute_binary_accuracy(pred_pm1, gold_pm1):
    """Fraction of positions where prediction and gold agree on the sign.

    Both inputs follow the ±1 convention (anything > 0 counts as positive).
    Positions where either side is NaN are ignored; if none remain the result
    is NaN.
    """
    pred_bits = pm1_to_binary(pred_pm1)
    gold_bits = pm1_to_binary(gold_pm1)
    usable = ~(np.isnan(pred_bits) | np.isnan(gold_bits))
    if not usable.any():
        return np.nan
    hits = pred_bits[usable] == gold_bits[usable]
    return float(np.mean(hits))


def compute_binary_f1(pred_pm1, gold_pm1):
    """F1 score of the positive class for ±1 predictions against ±1 gold.

    Positions where either side is NaN are ignored (NaN if none remain).  When
    there are no positive predictions and no positive gold labels the score
    is 0.0.
    """
    pred_bits = pm1_to_binary(pred_pm1)
    gold_bits = pm1_to_binary(gold_pm1)
    usable = ~(np.isnan(pred_bits) | np.isnan(gold_bits))
    if not usable.any():
        return np.nan
    pred_pos = pred_bits[usable] == 1
    gold_pos = gold_bits[usable] == 1
    tp = int(np.sum(pred_pos & gold_pos))
    fp = int(np.sum(pred_pos & ~gold_pos))
    fn = int(np.sum(~pred_pos & gold_pos))
    denominator = 2 * tp + fp + fn
    if not denominator:
        return 0.0
    return float((2 * tp) / denominator)


# Fallback lam_S / lam_L grids for ``tune_caret_hyperparams`` (separate lists, same values).
DEFAULT_LAM_S_GRID = [1e-3, 1e-2, 1e-1]
DEFAULT_LAM_L_GRID = list(DEFAULT_LAM_S_GRID)


def _sample_validation_indices(n: int, val_fraction: float, random_state: int) -> np.ndarray:
    if n <= 0:
        return np.array([], dtype=int)
    val_size = int(np.round(n * val_fraction))
    val_size = min(max(val_size, 1), n)
    rng = np.random.default_rng(random_state)
    perm = rng.permutation(n)
    return perm[:val_size]


def tune_caret_hyperparams(
    judge_pm1: pd.DataFrame,
    gold_binary: pd.Series,
    *,
    lam_s_grid: list[float] | None = None,
    lam_l_grid: list[float] | None = None,
    val_fraction: float = 0.1,
    random_state: int = 0,
    class_balance: float = 50,
    ranks: tuple[int, ...] = (2, 3, 4),
):
    """Grid-search CARET's ``lam_L`` / ``lam_S`` on a random validation subset.

    A subset of rows is drawn with ``_sample_validation_indices``.  CARET is
    run on that subset alone, and its outputs are compared element-wise with
    the subset's gold labels to give an accuracy.  A combination replaces the
    current best when its accuracy is higher by more than 1e-12.  It also
    replaces the best when accuracies are ``np.isclose`` and its
    ``(lam_L, lam_S)`` sorts lower.  Combinations that raise are recorded, not
    propagated.

    Args:
        judge_pm1: Judge votes, presumably in {-1, +1} — confirm with callers.
        gold_binary: Gold labels compared with ``==`` against CARET's output
            (presumably 0/1).
        lam_s_grid / lam_l_grid: Grids to search; empty or None falls back to
            ``DEFAULT_LAM_S_GRID`` / ``DEFAULT_LAM_L_GRID``.
        val_fraction, random_state: Control the validation draw.
        class_balance, ranks: Forwarded to ``caret_aggregate``.

    Returns:
        ``(best, records)``.  ``best`` is ``None`` if no combination succeeded
        or the subset is empty.  Otherwise it is a dict with ``lam_L``,
        ``lam_S``, ``val_accuracy`` and ``val_size``.  ``records`` holds one
        dict per combination tried.
    """
    s_values = lam_s_grid or DEFAULT_LAM_S_GRID
    l_values = lam_l_grid or DEFAULT_LAM_L_GRID
    holdout = _sample_validation_indices(len(judge_pm1), val_fraction, random_state)
    search_log: list[dict[str, float | bool]] = []
    if holdout.size == 0:
        return None, search_log

    holdout_votes = judge_pm1.iloc[holdout].to_numpy()
    holdout_gold = gold_binary.iloc[holdout].to_numpy(dtype=float)
    winner: dict[str, float | bool] | None = None

    def beats_winner(acc: float, lam_L: float, lam_S: float) -> bool:
        # Strictly better accuracy wins; near-ties go to the smaller (lam_L, lam_S).
        if winner is None or acc > winner['val_accuracy'] + 1e-12:
            return True
        return bool(np.isclose(acc, winner['val_accuracy'])) and (lam_L, lam_S) < (
            winner['lam_L'],
            winner['lam_S'],
        )

    for lam_L in l_values:
        for lam_S in s_values:
            entry = {'lam_L': lam_L, 'lam_S': lam_S, 'success': False, 'val_accuracy': np.nan}
            try:
                raw_preds = caret_aggregate(
                    holdout_votes,
                    lam_S=lam_S,
                    lam_L=lam_L,
                    class_balance=class_balance,
                    ranks=ranks,
                )
                acc = float(np.mean(np.asarray(raw_preds, dtype=float) == holdout_gold))
                entry['success'] = True
                entry['val_accuracy'] = acc
                if beats_winner(acc, lam_L, lam_S):
                    winner = {
                        'lam_L': lam_L,
                        'lam_S': lam_S,
                        'val_accuracy': acc,
                        'val_size': int(holdout.size),
                    }
            except Exception as exc:
                entry['error'] = str(exc)
            search_log.append(entry)

    return winner, search_log


# CARET search settings used by ``evaluate_bundle``.
CARET_VAL_FRACTION = 0.1  # share of rows drawn for validation
CARET_RANDOM_STATE = 2025  # seed for that validation draw
CARET_RANKS = tuple(range(4, 8))  # (4, 5, 6, 7)


def evaluate_bundle(judge_df_binary: pd.DataFrame, gold_binary: pd.Series):
    """Score every aggregator on one judge bundle.

    Runs MV, WS (Snorkel label model), AVG, UWS, CARESL and a tuned CARET on
    the cleaned votes and reports accuracy and F1 against the gold labels.

    Args:
        judge_df_binary: Judge votes, one column per judge.  Mapped to the ±1
            convention with ``x * 2 - 1`` (presumably from 0/1 — confirm with
            callers).
        gold_binary: Gold labels row-aligned with ``judge_df_binary``, mapped
            the same way.

    Returns:
        ``(metrics, caret_meta)``.  ``metrics`` maps each aggregator name to
        ``{'accuracy': ..., 'f1': ...}``.  ``caret_meta`` records the CARET
        search, the chosen ``lam_L``/``lam_S``, the class balance used and a
        ``status``.  The status is ``'ok'`` on success, ``'no_valid_params'``
        when tuning found no working combination, and ``'error'`` when the
        final CARET run failed.

    Raises:
        ValueError: From ``sanitize_bundle`` when every judge column is
            constant after cleaning.
    """
    judge_pm1 = judge_df_binary * 2 - 1
    gold_pm1 = pd.Series(gold_binary.values * 2 - 1, dtype=float)

    # Drop incomplete/non-finite rows and near-constant judges, keeping gold aligned.
    judge_pm1, gold_pm1 = sanitize_bundle(judge_pm1, gold_pm1)
    gold_binary_clean = pd.Series((gold_pm1.values > 0).astype(float), index=judge_pm1.index)

    # Codes: positive vote -> 1, negative -> 0; inverse_map turns codes back into ±1.
    discrete_df, inverse_map = discretize_scores(judge_pm1, 'binary')

    aggregator_outputs: dict[str, np.ndarray] = {}
    mv_codes = majority_vote(discrete_df)
    aggregator_outputs['MV'] = decode_codes(mv_codes.to_numpy(), inverse_map)

    ws_codes = run_label_model(discrete_df, 'binary')
    aggregator_outputs['WS'] = decode_codes(ws_codes, inverse_map)

    # Sign of the mean vote; a mean of exactly 0 becomes NaN.
    aggregator_outputs['AVG'] = values_to_pref(judge_pm1.mean(axis=1), 'binary')

    uws_scores = uws_aggregate(judge_pm1)
    aggregator_outputs['UWS'] = values_to_pref(uws_scores, 'binary')

    try:
        caresl_scores = caresl_aggregate(judge_pm1)
        aggregator_outputs['CARESL'] = values_to_pref(caresl_scores, 'binary')
    except Exception:
        # Best effort: a CARESL failure is reported as all-NaN predictions.
        aggregator_outputs['CARESL'] = np.full(len(gold_pm1), np.nan)

    # CARET's class_balance is the gold positive rate as a percentage.
    # NOTE(review): this (and the tuning below) reads gold labels.
    positive_rate = float(gold_binary_clean.mean()) if len(gold_binary_clean) else 0.5
    class_balance_rate = float(np.clip(positive_rate, 0.0, 1.0) * 100.0)

    caret_best, caret_search = tune_caret_hyperparams(
        judge_pm1,
        gold_binary_clean,
        val_fraction=CARET_VAL_FRACTION,
        random_state=CARET_RANDOM_STATE,
        class_balance=class_balance_rate,
        ranks=CARET_RANKS,
    )
    caret_meta = {
        'lam_L': None,
        'lam_S': None,
        'val_accuracy': np.nan,
        'val_size': 0,
        'search_records': caret_search,
        'status': 'no_valid_params' if caret_best is None else 'pending',
        'class_balance_rate': class_balance_rate,
        'positive_rate': positive_rate,
    }

    try:
        if caret_best is None:
            raise RuntimeError('No successful hyperparameter combination for CARET')
        caret_meta.update({
            'lam_L': caret_best['lam_L'],
            'lam_S': caret_best['lam_S'],
            'val_accuracy': caret_best['val_accuracy'],
            'val_size': caret_best['val_size'],
            'status': 'ok',
        })
        # NOTE(review): the final run covers all rows, including the validation subset.
        caret_scores = caret_aggregate(
            judge_pm1.to_numpy(),
            lam_S=caret_best['lam_S'],
            lam_L=caret_best['lam_L'],
            class_balance=class_balance_rate,
            ranks=CARET_RANKS,
        )
        caret_scores = np.asarray(caret_scores, dtype=float)
        # CARET output 1 -> +1, 0 -> -1; any other value stays NaN.
        caret_pm1 = np.full_like(caret_scores, np.nan, dtype=float)
        caret_pm1[caret_scores == 1] = 1.0
        caret_pm1[caret_scores == 0] = -1.0
        aggregator_outputs['CARET'] = caret_pm1
    except Exception as err:
        aggregator_outputs['CARET'] = np.full(len(gold_pm1), np.nan)
        # Keep 'no_valid_params' distinguishable from a failing final CARET run.
        if caret_best is not None:
            caret_meta['status'] = 'error'
        caret_meta['error'] = str(err)

    metrics = {}
    for name, preds in aggregator_outputs.items():
        metrics[name] = {
            'accuracy': compute_binary_accuracy(preds, gold_pm1),
            'f1': compute_binary_f1(preds, gold_pm1),
        }
    return metrics, caret_meta


# Defaults for ``run_binary_validation_experiments``; the two grids hold the
# same values in separate lists.
VALIDATION_LAM_L_GRID = [1e-3, 5e-3, 1e-2, 5e-2, 1e-1]
VALIDATION_LAM_S_GRID = list(VALIDATION_LAM_L_GRID)
VALIDATION_VAL_FRACTION = 0.10
VALIDATION_RANDOM_STATE = 42


def _ensure_judge_directory(judge_root: Path, dataset: str) -> Path:
    candidates = [judge_root / dataset]
    for suffix in ('_binary', '', '_outputs'):
        candidates.append(judge_root / f'{dataset}{suffix}')
    for path in candidates:
        if path.exists() and path.is_dir():
            return path
    raise FileNotFoundError(f'Judge outputs not found for dataset {dataset!r} under {judge_root}')


def _map_elements(frame: pd.DataFrame, func) -> pd.DataFrame:
    """Apply ``func`` to every cell of ``frame``.

    Prefers ``DataFrame.map`` (pandas >= 2.1), because ``DataFrame.applymap``
    is deprecated there; older pandas falls back to ``applymap``.
    """
    if hasattr(pd.DataFrame, 'map'):
        return frame.map(func)
    return frame.applymap(func)


def _load_binary_validation_data(
    dataset: str,
    *,
    data_root: Path,
    judge_root: Path,
) -> tuple[pd.DataFrame, np.ndarray]:
    """Load ``(judge_df, gold_labels)`` for one dataset of the validation sweep.

    Resolution order:

    1. ``'yelp'`` / ``'liar2'``: the Gaussian-mixture loaders with their
       default root and threshold.
    2. Names in ``PREFERENCE_VALIDATION_LOADERS``: that loader followed by
       ``sanitize_bundle``.
    3. Otherwise: raw judge outputs are collected from the judge directory.
       Gold labels come from ``data_root/<dataset>.csv`` (``label`` column)
       when that file exists, else from ``load_binary_dataset`` on the judge
       directory.  Rows where any judge fails ``is_valid_binary_score``, or
       whose extracted score is missing, are dropped.

    Raises:
        FileNotFoundError: No judge directory (via ``_ensure_judge_directory``).
        KeyError: The data CSV has no ``label`` column.
        ValueError: Gold and judge row counts disagree, or no rows survive.
    """
    if dataset in {'yelp', 'liar2'}:
        loader = load_yelp_binary_bundle if dataset == 'yelp' else load_liar2_binary_bundle
        judge_df, gold_series = loader()
        return judge_df, gold_series.to_numpy(dtype=int)

    custom_loader = PREFERENCE_VALIDATION_LOADERS.get(dataset)
    if custom_loader is not None:
        judge_df, gold_series = custom_loader(judge_root)
        judge_df, gold_series = sanitize_bundle(judge_df, gold_series)
        return judge_df, gold_series.to_numpy(dtype=int)

    judge_dir = _ensure_judge_directory(judge_root, dataset)
    judge_df = collect_judge_outputs(
        data_dir=judge_dir,
        min_rating=0.0,
        max_rating=1.0,
    )
    judge_df = judge_df.reset_index(drop=True)

    data_path = data_root / f'{dataset}.csv'
    if data_path.exists():
        data_df = pd.read_csv(data_path)
        if 'label' not in data_df.columns:
            raise KeyError(f'Required column "label" not present in {data_path}')
        gold_labels = data_df['label'].astype(int).to_numpy()
    else:
        _, bundle_gold = load_binary_dataset(judge_dir)
        gold_labels = bundle_gold.astype(int).to_numpy()
        # NOTE(review): load_binary_dataset drops rows with missing gold and
        # re-indexes, so truncating by length assumes none were dropped — confirm.
        judge_df = judge_df.iloc[: len(gold_labels)]

    # Fail with a clear message instead of an opaque numpy boolean-index error below.
    if len(gold_labels) != len(judge_df):
        raise ValueError(
            f'Gold labels ({len(gold_labels)} rows) and judge outputs '
            f'({len(judge_df)} rows) disagree for {dataset}'
        )

    valid_mask = _map_elements(judge_df, is_valid_binary_score).all(axis=1)
    if not valid_mask.any():
        raise ValueError(f'No rows with complete judge coverage for {dataset}')
    judge_df = judge_df.loc[valid_mask].reset_index(drop=True)
    gold_labels = gold_labels[valid_mask.to_numpy()]

    # NOTE(review): outputs were collected with max_rating=1.0 but scores are
    # extracted with max_rating=9.0 — confirm the intended rating range.
    judge_df = _map_elements(
        judge_df,
        partial(extract_score_from_parsed_output, min_rating=0.0, max_rating=9.0),
    )
    judge_df = judge_df.apply(pd.to_numeric, errors='coerce')
    row_valid = judge_df.notna().all(axis=1)
    judge_df = judge_df.loc[row_valid].reset_index(drop=True)
    gold_labels = gold_labels[row_valid.to_numpy()]

    if judge_df.empty:
        raise ValueError(f'All rows dropped after parsing judge outputs for {dataset}')
    return judge_df, gold_labels


def _accuracy_on_indices(preds: np.ndarray, gold_labels: np.ndarray, indices: np.ndarray) -> float:
    preds = np.asarray(preds, dtype=int)
    subset = preds[indices]
    gold_subset = gold_labels[indices]
    if subset.size == 0:
        return float('nan')
    return float(np.mean(subset == gold_subset))


def _split_validation_indices(
    labels: np.ndarray,
    val_fraction: float,
    random_state: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Split row positions into ``(rest, validation)``, stratified by label when feasible."""
    indices = np.arange(labels.shape[0])
    if np.unique(labels).size > 1:
        try:
            rest, val = train_test_split(
                indices,
                test_size=val_fraction,
                random_state=random_state,
                stratify=labels,
            )
            return rest, val
        except ValueError:
            # train_test_split refuses to stratify when a class has a single
            # member or a split is smaller than the number of classes; use an
            # unstratified split instead of aborting the whole sweep.
            pass
    rest, val = train_test_split(
        indices,
        test_size=val_fraction,
        random_state=random_state,
        stratify=None,
    )
    return rest, val


def _search_caret_grid(
    judge_np: np.ndarray,
    labels: np.ndarray,
    val_idx: np.ndarray,
    lam_L_grid: Sequence[float],
    lam_S_grid: Sequence[float],
) -> tuple[float, float | None, float | None]:
    """Return ``(best_acc, best_lam_L, best_lam_S)``, scoring CARET only on ``val_idx``.

    CARET runs on every row with ``class_balance=50``; ties keep the first
    combination.  With an empty grid the lambdas come back as ``None``.
    """
    best_acc = -np.inf
    best_lam_L = None
    best_lam_S = None
    for lam_L, lam_S in itertools.product(lam_L_grid, lam_S_grid):
        caret_pred = np.asarray(
            caret_aggregate(
                judge_np,
                lam_L=lam_L,
                lam_S=lam_S,
                class_balance=50,
            ),
            dtype=int,
        )
        acc = accuracy_score(labels[val_idx], caret_pred[val_idx])
        if acc > best_acc + 1e-12:
            best_acc, best_lam_L, best_lam_S = acc, lam_L, lam_S
    return best_acc, best_lam_L, best_lam_S


def run_binary_validation_experiments(
    dataset_names: Sequence[str],
    *,
    data_root: Path,
    judge_root: Path,
    lam_L_grid: Sequence[float] = VALIDATION_LAM_L_GRID,
    lam_S_grid: Sequence[float] = VALIDATION_LAM_S_GRID,
    val_fraction: float = VALIDATION_VAL_FRACTION,
    random_state: int = VALIDATION_RANDOM_STATE,
) -> pd.DataFrame:
    """Tune CARET per dataset on a validation split and report test accuracies.

    For each dataset the steps are:

    1. Load votes and labels.
    2. Hold out ``val_fraction`` of the rows, stratified when feasible.
    3. Grid-search CARET's ``lam_L``/``lam_S`` on the held-out rows.
    4. Report accuracy on the remaining rows for MV, AVG, WS, UWS, CARESL and
       the tuned CARET.

    Returns:
        One row per dataset with per-aggregator test accuracy plus the chosen
        lambdas, validation accuracy and split sizes.

    Raises:
        RuntimeError: If the grid search selects no combination (e.g. empty grids).
    """
    results: list[dict[str, object]] = []
    for dataset in dataset_names:
        judge_df, gold_labels = _load_binary_validation_data(
            dataset,
            data_root=data_root,
            judge_root=judge_root,
        )
        judge_np = judge_df.to_numpy(dtype=float)
        labels = gold_labels.astype(int)
        idx_test, idx_val = _split_validation_indices(labels, val_fraction, random_state)

        best_acc, best_lam_L, best_lam_S = _search_caret_grid(
            judge_np, labels, idx_val, lam_L_grid, lam_S_grid
        )
        if best_lam_L is None or best_lam_S is None:
            raise RuntimeError(f'Hyperparameter search failed for dataset {dataset}')

        caret_pred_full = np.asarray(
            caret_aggregate(
                judge_np,
                lam_L=best_lam_L,
                lam_S=best_lam_S,
                class_balance=50,
            ),
            dtype=int,
        )

        # Baselines run in a fixed order in case any aggregator draws random numbers.
        # NOTE(review): UWS/CARESL/AVG are thresholded at 0.5, which assumes
        # [0, 1]-scaled outputs — confirm.
        judge_binary = judge_df.astype(int)
        mv_pred = np.asarray(majority_vote(judge_binary), dtype=int)
        ws_pred = np.asarray(ws_aggregate(judge_binary), dtype=int)
        uws_pred = np.asarray(uws_aggregate(judge_df) >= 0.5, dtype=int)
        avg_pred = np.asarray((judge_df.mean(axis=1) >= 0.5), dtype=int)
        caresl_pred = np.asarray(caresl_aggregate(judge_df) >= 0.5, dtype=int)

        results.append(
            {
                'dataset': dataset,
                'MV': _accuracy_on_indices(mv_pred, labels, idx_test),
                'AVG': _accuracy_on_indices(avg_pred, labels, idx_test),
                'WS': _accuracy_on_indices(ws_pred, labels, idx_test),
                'UWS': _accuracy_on_indices(uws_pred, labels, idx_test),
                'CARESL': _accuracy_on_indices(caresl_pred, labels, idx_test),
                'CARET': _accuracy_on_indices(caret_pred_full, labels, idx_test),
                'lam_L': best_lam_L,
                'lam_S': best_lam_S,
                'val_accuracy': best_acc,
                'val_size': int(idx_val.size),
                'test_size': int(idx_test.size),
            }
        )

    return pd.DataFrame(results)
