#!/usr/bin/env python3
"""Unified judge-output generation for the exp1 datasets.

Datasets covered
----------------
- feedbackqa (1–4 scalar rating)
- helpsteer2 (0–4 scalar rating)
- ultrafeedback (1–5 scalar rating)
- helpsteer3 (pairwise preference A/B)
- helpsteer3_binary (pairwise preference A/B, binary label)
- judgebench (pairwise preference A/B; no ties)
- judgebench_binary (pairwise preference A/B; binary label)
- chatbot_arena_conversations (pairwise preference with ties)
- chatbot_arena_conversations_binary (pairwise preference; ties removed, binary label)
- allenai_preference_test_sets (pairwise strict preference; multiple splits)
- allenai_preference_test_sets_binary (pairwise strict preference; binary label)
- civilcomments_binary (binary toxicity classification)
- summarize_from_feedback_axis (validation/test axes, 1–7 scalar rating)
- yelp_review_full/test (5-class sentiment)
- tripadvisor_reviews (5-class sentiment)
- asset_ratings (0–100 simplification quality ratings)
- masterkey_exp (master-key perturbations on Multi-subject-RLVR; YES/NO strict scoring)
- masterkey_exp_with_ref (master-key perturbations with explicit references; YES/NO strict scoring)

Usage examples
--------------
Generate outputs for every dataset with the default model suite::

    python save_judge_outputs_exp1.py --datasets all

Limit execution to a particular dataset and subset of models::

    python save_judge_outputs_exp1.py --datasets feedbackqa helpsteer2 --models Qwen/Qwen3-1.7B meta-llama/Llama-3.1-8B-Instruct

The script reproduces the behaviour of the (now fragmented) dataset-specific
entrypoints while sharing common utilities for prompt formatting, text
parsing, batching, and model invocation. Where the original pipeline relied
on transformers’ eager generation we preserve that backend; vLLM is used for
larger HuggingFace-hosted corpora.
"""

from __future__ import annotations

import argparse
import ast
import csv
import gc
import os
import random
import re
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Tuple

import pandas as pd
import numpy as np
from datasets import load_dataset
from tqdm.auto import tqdm

import torch
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

try:
    import psutil
except ImportError:  # pragma: no cover - optional dependency
    psutil = None

# Make the sibling ``src`` directory importable so the project-local
# ``data_tools`` / ``llm_tools`` imports below resolve when this file is run
# directly as a script.
SCRIPT_DIR = Path(__file__).resolve().parent
SRC_DIR = SCRIPT_DIR.parent / 'src'
# Go through ``sys.path`` directly: ``os.sys`` only works because ``os``
# happens to import ``sys`` internally; it is not part of the ``os`` API.
if str(SRC_DIR) not in sys.path:
    sys.path.append(str(SRC_DIR))

from data_tools import (  # noqa: E402
    load_feedbackqa,
    load_helpsteer2,
    load_helpsteer3,
    load_ultrafeedback,
    load_review5k,
)
from llm_tools import (  # noqa: E402
    JUDGE_PROMPT,
    REVIEW5K_BINARY_JUDGE_PROMPT,
    REVIEW5K_JUDGE_PROMPT,
    TRIPADVISOR_REVIEW_JUDGE_PROMPT,
    ULTRAFEEDBACK_JUDGE_PROMPT,
    HELPSTEER2_JUDGE_PROMPT,
    YELP_REVIEW_FULL_JUDGE_PROMPT,
    PREFERENCE_JUDGE_PROMPT,
    HELPSTEER3_PREF_JUDGE_PROMPT,
    JUDGEBENCH_PREF_NO_TIE_PROMPT,
    PREFERENCE_BINARY_JUDGE_PROMPT,
    CIVILCOMMENTS_BINARY_JUDGE_PROMPT,
    CIVILCOMMENTS_SCORE_JUDGE_PROMPT,
    MASTERKEY_EXP_JUDGE_TEMPLATE,
    ASSET_SIMPLIFICATION_JUDGE_PROMPT,
)

# ---------------------------------------------------------------------------
# Global defaults
# ---------------------------------------------------------------------------

# vLLM backend only
DEVICE = 'cuda'
# Additional instruction-tuned judge models (Hugging Face identifiers), grouped
# by family below.
# NOTE(review): the name suggests these extend DEFAULT_MODELS for the binary
# runs; the place where this list is consumed is not visible here — confirm.
BINARY_ADDITIONAL_DEFAULT_MODELS = [
    # Qwen 2.5 instruction-tuned family (1.5B - 14B)
    'Qwen/Qwen2.5-1.5B-Instruct',
    'Qwen/Qwen2.5-3B-Instruct',
    'Qwen/Qwen2.5-7B-Instruct',
    'Qwen/Qwen2.5-14B-Instruct',
    # Llama 3.2 instruction-tuned refresh (1B - 3B)
    'meta-llama/Llama-3.2-1B-Instruct',
    'meta-llama/Llama-3.2-3B-Instruct',
    # Additional medium-scale instruct models (7B - 12B)
    'mistralai/Mistral-Nemo-12B-Instruct',
    'google/gemma-2-9b-it',
    'deepseek-ai/deepseek-llm-7b-chat',
    '01-ai/Yi-1.5-6B-Chat',
    '01-ai/Yi-1.5-9B-Chat',
    '01-ai/Yi-1.5-13B-Chat',
]

# Default judge model suite (Hugging Face model identifiers).
# NOTE(review): presumably used when no explicit ``--models`` list is given;
# the CLI wiring is not visible in this section — confirm.
DEFAULT_MODELS = [
    'Qwen/Qwen3-0.6B',
    'Qwen/Qwen3-1.7B',
    'Qwen/Qwen3-4B',
    'Qwen/Qwen3-8B',
    'meta-llama/Llama-3.1-8B-Instruct',
    'meta-llama/Llama-3.2-1B',
    'meta-llama/Llama-3.2-3B',
    'google/gemma-3-1b-it',
    'google/gemma-3-4b-it',
    'microsoft/Phi-4-mini-instruct',
    'mistralai/Mistral-7B-Instruct-v0.3',
]

# Judge model suite for the master-key experiments (Hugging Face identifiers).
# NOTE(review): presumably the default model list for the masterkey_exp*
# datasets; the selection logic is not visible in this section — confirm.
MASTERKEY_DEFAULT_MODELS = [
    'Qwen/Qwen3-0.6B',
    'Qwen/Qwen3-1.7B',
    'Qwen/Qwen3-4B',
    'Qwen/Qwen3-8B',
    'Qwen/Qwen2.5-1.5B-Instruct',
    'Qwen/Qwen2.5-3B-Instruct',
    'Qwen/Qwen2.5-7B-Instruct',
    'Qwen/Qwen2.5-14B-Instruct',
    'meta-llama/Llama-3.1-8B-Instruct',
    'meta-llama/Llama-3.2-1B',
    'meta-llama/Llama-3.2-3B',
    'meta-llama/Llama-3.2-1B-Instruct',
    'meta-llama/Llama-3.2-3B-Instruct',
    'google/gemma-3-1b-it',
    'google/gemma-3-4b-it',
    'google/gemma-2-9b-it',
    'microsoft/Phi-4-mini-instruct',
    'mistralai/Mistral-7B-Instruct-v0.3',
    '01-ai/Yi-1.5-6B-Chat',
    '01-ai/Yi-1.5-9B-Chat',
]

# Input CSVs and output directories for the two master-key experiments
# (see run_masterkey_exp / run_masterkey_exp_with_ref). ``~`` is expanded to
# the current user's home directory at import time.
MASTERKEY_CSV_PATH = Path('~/llm-judge-bias/data/miscellaneous/master_keys.csv').expanduser()
MASTERKEY_WITH_REF_CSV_PATH = Path('~/llm-judge-bias/data/miscellaneous/master_keys_with_ref.csv').expanduser()
MASTERKEY_OUTPUT_DIR = Path('~/llm-judge-bias/judge_outputs/master_keys').expanduser()
MASTERKEY_WITH_REF_OUTPUT_DIR = Path('~/llm-judge-bias/judge_outputs/master_keys_with_ref').expanduser()

# Per-model dtype overrides. Looked up by get_dtype() and, when present,
# forwarded to the vLLM engine as its ``dtype`` argument; models not listed
# here get no explicit dtype.
DTYPE_OVERRIDES: Dict[str, torch.dtype] = {
    'google/gemma-3-4b-it': torch.bfloat16,
    'google/gemma-3-1b-it': torch.bfloat16,
    'google/gemma-2-9b-it': torch.bfloat16,
}

# Context length (in tokens) to use per model. run_vllm_generation() passes the
# value to vLLM as ``max_model_len`` and derives the prompt token budget from
# it; models missing from this table fall back to the engine's own setting.
REVIEW5K_MODEL_CONTEXT_LENGTHS: Dict[str, int] = {
    'Qwen/Qwen3-0.6B': 32768,
    'Qwen/Qwen3-1.7B': 32768,
    'Qwen/Qwen3-4B': 32768,
    'Qwen/Qwen3-8B': 32768,
    'meta-llama/Llama-3.1-8B-Instruct': 131072,
    'meta-llama/Llama-3.2-1B': 65536,
    'meta-llama/Llama-3.2-3B': 65536,
    'meta-llama/Llama-3.2-1B-Instruct': 65536,
    'meta-llama/Llama-3.2-3B-Instruct': 65536,
    'Qwen/Qwen2.5-1.5B-Instruct': 32768,
    'Qwen/Qwen2.5-3B-Instruct': 32768,
    'Qwen/Qwen2.5-7B-Instruct': 32768,
    'Qwen/Qwen2.5-14B-Instruct': 32768,
    'google/gemma-2-9b-it': 8192,
    'mistralai/Mistral-Nemo-12B-Instruct': 32768,
    'deepseek-ai/deepseek-llm-7b-chat': 8192,
    '01-ai/Yi-1.5-6B-Chat': 4096,
    '01-ai/Yi-1.5-9B-Chat': 4096,
    '01-ai/Yi-1.5-13B-Chat': 4096,
    'allenai/tulu-2-dpo-7b': 4096,
    'allenai/tulu-2-dpo-13b': 4096,
    'NousResearch/Nous-Hermes-2-Mistral-7B-DPO': 8192,
}

# Alias of the table above: this is the same dict object, not a copy, so a
# mutation through either name is visible through both.
MODEL_CONTEXT_LENGTHS: Dict[str, int] = REVIEW5K_MODEL_CONTEXT_LENGTHS

# Text inserted into a prompt that had to be shortened; this is the default
# ``truncation_notice`` of run_vllm_generation().
PROMPT_TRUNCATION_NOTICE = "\n\n[Prompt truncated to satisfy model context length.]\n"

# Review-5K counterpart; default ``truncation_notice`` of build_review5k_prompt().
REVIEW5K_TRUNCATION_NOTICE = "\n\n[Paper truncated to satisfy model context length.]\n"
# Judge prompt template per scoring mode (keys mirror the MODE values).
REVIEW5K_PROMPT_TEMPLATES: Dict[str, str] = {
    'gaussian_mixture': REVIEW5K_JUDGE_PROMPT,
    'binary': REVIEW5K_BINARY_JUDGE_PROMPT,
}
# NOTE(review): the Review-5K runner that consumes the next three values is not
# visible in this section; their meaning is inferred from the names — confirm.
REVIEW5K_DEFAULT_MAX_NEW_TOKENS = 6
REVIEW5K_DEFAULT_RESERVE_TOKENS = 32
# Default ``reserve_tokens`` of run_vllm_generation(): tokens subtracted from
# the context length (together with max_new_tokens) to get the prompt budget.
DEFAULT_PROMPT_RESERVE_TOKENS = REVIEW5K_DEFAULT_RESERVE_TOKENS
REVIEW5K_FALLBACK_CONTEXT = 8192

# Batch size, overridable through the VLLM_BATCH_SIZE environment variable.
BATCH_SIZE = int(os.environ.get('VLLM_BATCH_SIZE', '16'))
# When True, the pairwise runners randomly swap responses A and B per example.
SWAP_AB: bool = False  # set by CLI flag
# 'binary' or 'gaussian_mixture'; selects prompt templates and output roots.
MODE: str = 'binary'  # scoring mode, toggled by CLI
# NOTE(review): the code applying these two thresholds is not visible in this
# section — confirm their exact semantics before changing them.
GAUSSIAN_PREFERENCE_THRESHOLD = 0.0
CIVILCOMMENTS_GAUSSIAN_THRESHOLD = 5

# File that log_generation_failure() appends one line to per failed run.
FAILURE_LOG_PATH = SCRIPT_DIR / 'generation_failures.log'

# Ensure deterministic CUDA kernel selection matches the legacy scripts.
# These calls run at import time and change process-wide PyTorch settings:
# the memory-efficient and flash scaled-dot-product-attention kernels are
# disabled and the math implementation is enabled.
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)

# Number of bytes in one gibibyte (2**30).
BYTES_IN_GIB = 1024 ** 3
# Free GPU memory (2 GiB) that maybe_adjust_gpu_memory_utilization() leaves
# unclaimed when it has to lower the requested utilization ratio.
GPU_MEMORY_SAFETY_MARGIN = 2 * BYTES_IN_GIB

# ---------------------------------------------------------------------------
# Helper utilities
# ---------------------------------------------------------------------------


def get_binary_output_dir(dataset: str) -> Path:
    """Create (if necessary) and return the output directory for ``dataset``.

    The directory is placed under ``judge_outputs/gaussian_mixture`` when the
    global ``MODE`` equals ``'gaussian_mixture'`` and under
    ``judge_outputs/binary`` for any other mode, relative to the parent of the
    script directory.
    """
    if MODE == 'gaussian_mixture':
        mode_root = 'judge_outputs/gaussian_mixture'
    else:
        mode_root = 'judge_outputs/binary'
    target = SCRIPT_DIR.parent / mode_root / dataset
    target.mkdir(parents=True, exist_ok=True)
    return target


def log_generation_failure(model_name: str, dataset_name: str, error: Exception) -> None:
    """Report a failed generation run on stdout and in the failure log.

    The same single-line message is printed and appended to
    ``FAILURE_LOG_PATH``. Logging is best effort: if the file cannot be
    written, the ``OSError`` is reported on stdout instead of being raised.
    """
    stamp = time.strftime('%Y-%m-%d %H:%M:%S')
    line = f"[generation_failure] {stamp} dataset={dataset_name} model={model_name} error={error}"
    print(line, flush=True)
    try:
        with open(FAILURE_LOG_PATH, 'a', encoding='utf-8') as log_file:
            log_file.write(line + '\n')
    except OSError as write_err:
        print(f"[generation_failure] {stamp} failed to append log: {write_err}", flush=True)


def _bytes_to_gib(value: float) -> float:
    """Convert a byte count into gibibytes (multiples of ``BYTES_IN_GIB``)."""
    gib = value / BYTES_IN_GIB
    return gib


def cleanup_orphan_vllm_workers(reason: str) -> None:
    """Terminate vLLM engine-core processes whose parent is no longer running.

    Does nothing when the optional ``psutil`` dependency is unavailable.
    A process is a candidate when the marker ``VLLM::EngineCore`` appears in
    its name or in any command-line part. Candidates whose parent is this
    process or this process's own parent are left alone, as are candidates
    whose parent is still running.

    Args:
        reason: Short description used only in the log message
            ("... before {reason}.").
    """
    if psutil is None:
        return
    current_pid = os.getpid()
    orphans = []
    marker = 'VLLM::EngineCore'
    for proc in psutil.process_iter(['pid', 'ppid', 'name', 'cmdline']):
        try:
            name = proc.info.get('name') or ''
            cmdline_parts = proc.info.get('cmdline') or []
            ppid = proc.info.get('ppid')
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            continue
        # Skip anything that is not a vLLM engine-core process.
        if marker not in name and all(marker not in part for part in cmdline_parts):
            continue
        # Engines spawned by this process or by its parent are not orphans.
        if ppid in {current_pid, os.getppid()}:
            continue
        # A parent pid of None, 0 or 1 is treated as "no live parent"; for any
        # other pid the parent is looked up, and a lookup failure also counts
        # as "not running".
        parent_running = False
        if ppid not in (None, 0, 1):
            try:
                parent_running = psutil.Process(ppid).is_running()
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                parent_running = False
        if parent_running:
            continue
        orphans.append(proc)
    if not orphans:
        return
    print(
        f"[vLLM] Terminating {len(orphans)} orphaned engine process(es) before {reason}.",
        flush=True,
    )
    # First ask every orphan to terminate ...
    for proc in orphans:
        try:
            proc.terminate()
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            continue
    # ... then wait for them against one shared 5-second deadline, and kill
    # whatever is still alive once its wait times out.
    deadline = time.time() + 5.0
    for proc in orphans:
        remaining = max(0.0, deadline - time.time())
        try:
            proc.wait(timeout=remaining)
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            continue
        except psutil.TimeoutExpired:
            try:
                proc.kill()
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                continue


def maybe_adjust_gpu_memory_utilization(engine_kwargs: Dict[str, object], model_name: str) -> None:
    """Lower ``gpu_memory_utilization`` in place when the GPU lacks free memory.

    Nothing happens when CUDA is unavailable or when the currently free memory
    already covers the requested share (default 0.75 if the key is absent).
    Otherwise the ratio is reduced to what fits after keeping
    ``GPU_MEMORY_SAFETY_MARGIN`` bytes free, but never below 0.05.

    Args:
        engine_kwargs: vLLM engine keyword arguments; may be mutated.
        model_name: Used only in the log and error messages.

    Raises:
        RuntimeError: If no free memory remains once the safety margin is
            subtracted.
    """
    if not torch.cuda.is_available():
        return

    wanted_ratio = float(engine_kwargs.get('gpu_memory_utilization', 0.75))
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    wanted_bytes = wanted_ratio * total_bytes
    if free_bytes >= wanted_bytes:
        return

    usable_bytes = free_bytes - GPU_MEMORY_SAFETY_MARGIN
    if usable_bytes <= 0:
        raise RuntimeError(
            f"Insufficient free GPU memory for {model_name}: "
            f"{_bytes_to_gib(free_bytes):.2f} GiB available, "
            f"requires at least {_bytes_to_gib(wanted_bytes):.2f} GiB."
        )

    # Cap at the requested ratio, then apply the 0.05 floor.
    new_ratio = min(wanted_ratio, usable_bytes / total_bytes)
    new_ratio = max(0.05, new_ratio)
    if new_ratio < wanted_ratio:
        engine_kwargs['gpu_memory_utilization'] = new_ratio
        print(
            f"[vLLM] Reduced gpu_memory_utilization to {new_ratio:.3f} for {model_name} "
            f"(free {_bytes_to_gib(free_bytes):.2f} GiB of {_bytes_to_gib(total_bytes):.2f} GiB).",
            flush=True,
        )


@dataclass
class Review5KTruncationStats:
    """Counter for how many prompts out of a batch had to be shortened."""

    # Number of prompts considered.
    total: int
    # Number of those prompts that were truncated.
    truncated: int = 0

    def log(self, model_name: str) -> None:
        """Print a one-line summary for ``model_name`` if anything was truncated."""
        if not self.truncated:
            return
        summary = f"[truncate] {self.truncated}/{self.total} prompts shortened for {model_name}"
        print(summary, flush=True)


def build_review5k_prompt(
    *,
    guidelines: str,
    paper: str,
    tokenizer,
    allowed_tokens: int,
    suffix_token_len: int,
    notice_token_len: int,
    prefix_template: str,
    suffix_template: str,
    truncation_notice: str = REVIEW5K_TRUNCATION_NOTICE,
) -> Tuple[str, bool]:
    """Assemble a Review-5K prompt, truncating the paper to fit the budget.

    The prompt is ``prefix_template`` (formatted with ``guidelines``), then the
    paper, then ``suffix_template``. If the paper does not fit into
    ``allowed_tokens`` alongside prefix and suffix, it is cut at token level
    and ``truncation_notice`` is inserted after it.

    Args:
        guidelines: Text substituted into ``prefix_template``.
        paper: Paper body; the only part that may be shortened.
        tokenizer: Object providing ``encode`` and ``decode``.
        allowed_tokens: Maximum number of tokens for the whole prompt.
        suffix_token_len: Token length of ``suffix_template``.
        notice_token_len: Token length of ``truncation_notice``.
        prefix_template: Template with a ``{guidelines}`` placeholder.
        suffix_template: Text appended after the paper.
        truncation_notice: Marker inserted after a truncated paper.

    Returns:
        ``(prompt, was_truncated)``.

    Raises:
        ValueError: If prefix and suffix alone exhaust the budget, if no room
            remains for the notice, or if the re-encoded truncated prompt
            still exceeds ``allowed_tokens``.
    """
    head = prefix_template.format(guidelines=guidelines)
    head_len = len(tokenizer.encode(head, add_special_tokens=False))
    paper_budget = allowed_tokens - head_len - suffix_token_len
    if paper_budget <= 0:
        raise ValueError(
            'Prompt template and guidelines exceed the available context tokens for Review-5K.'
        )

    paper_ids = tokenizer.encode(paper, add_special_tokens=False)
    if len(paper_ids) <= paper_budget:
        # Everything fits: use the paper text verbatim.
        return head + paper + suffix_template, False

    if paper_budget <= notice_token_len:
        raise ValueError(
            'Not enough tokens remain for the truncation notice; adjust reserve tokens or shorten guidelines.'
        )

    kept_ids = paper_ids[:paper_budget - notice_token_len]
    shortened_paper = tokenizer.decode(
        kept_ids,
        skip_special_tokens=False,
        clean_up_tokenization_spaces=False,
    )
    prompt = head + shortened_paper + truncation_notice + suffix_template

    # Decoding and re-encoding may not round-trip to the same token count, so
    # the final prompt is measured again.
    final_len = len(tokenizer.encode(prompt, add_special_tokens=False))
    if final_len > allowed_tokens:
        raise ValueError('Truncation failed to satisfy context budget for Review-5K prompt.')

    return prompt, True


def truncate_prompts_to_budget(
    prompts: List[str],
    *,
    tokenizer,
    allowed_prompt_tokens: int,
    notice: str,
) -> Tuple[List[str], int]:
    """Shorten prompts that exceed ``allowed_prompt_tokens``.

    A trailing ``Feedback:::\\nTotal rating:`` section (searched from the
    right) is treated as a suffix that must survive; only the text before it
    is trimmed, and ``notice`` is inserted between the trimmed text and the
    suffix. Prompts that already fit are returned unchanged.

    Args:
        prompts: Prompt strings to check.
        tokenizer: Object providing ``encode`` and ``decode``.
        allowed_prompt_tokens: Maximum token count per prompt; must be > 0.
        notice: Text marking a truncation; may be empty.

    Returns:
        ``(adjusted_prompts, number_of_truncated_prompts)``.

    Raises:
        ValueError: If the budget is not positive, or if suffix plus notice
            leave no room for any of the leading text.
    """
    if allowed_prompt_tokens <= 0:
        raise ValueError('Allowed prompt token budget must be positive.')

    def _encode(piece: str):
        return tokenizer.encode(piece, add_special_tokens=False)

    def _decode(ids) -> str:
        return tokenizer.decode(
            ids,
            skip_special_tokens=False,
            clean_up_tokenization_spaces=False,
        )

    notice_ids = _encode(notice) if notice else []

    def _assemble(head_ids, tail_text: str):
        """Join trimmed head, optional notice and optional tail; re-encode."""
        pieces: List[str] = [_decode(head_ids)]
        if notice_ids:
            pieces.append(notice)
        if tail_text:
            pieces.append(tail_text)
        joined = ''.join(pieces)
        return joined, _encode(joined)

    tail_markers = (
        "\nFeedback:::\nTotal rating:",
        "Feedback:::\nTotal rating:",
    )

    results: List[str] = []
    shortened_count = 0

    for prompt in prompts:
        # Split off the protected tail at the last occurrence of a marker.
        head_text, tail_text = prompt, ''
        for marker in tail_markers:
            cut = prompt.rfind(marker)
            if cut != -1:
                head_text, tail_text = prompt[:cut], prompt[cut:]
                break

        head_ids = _encode(head_text)
        tail_ids = _encode(tail_text) if tail_text else []

        if len(head_ids) + len(tail_ids) <= allowed_prompt_tokens:
            results.append(prompt)
            continue

        shortened_count += 1

        # tail_ids is empty when there is no tail, so one formula covers both cases.
        head_budget = allowed_prompt_tokens - len(tail_ids) - len(notice_ids)
        if head_budget <= 0:
            raise ValueError(
                'Not enough context budget remains for required suffix/notice; '
                'reduce reserve tokens or shorten the prompt template.'
            )

        kept_ids = head_ids[:head_budget]
        candidate, candidate_ids = _assemble(kept_ids, tail_text)

        # Re-encoding the joined text can yield more tokens than the parts did
        # separately; drop head tokens one at a time until it fits.
        while candidate_ids and len(candidate_ids) > allowed_prompt_tokens and kept_ids:
            kept_ids = kept_ids[:-1]
            candidate, candidate_ids = _assemble(kept_ids, tail_text)

        # Last resort: hard cut of the original token stream (no notice).
        if len(candidate_ids) > allowed_prompt_tokens:
            candidate = _decode((head_ids + tail_ids)[:allowed_prompt_tokens])

        results.append(candidate)

    return results, shortened_count


def get_dtype(model_name: str) -> Optional[torch.dtype]:
    """Return the dtype override for ``model_name``, or ``None`` if there is none."""
    if model_name in DTYPE_OVERRIDES:
        return DTYPE_OVERRIDES[model_name]
    return None


def parse_first_number(text: str, *, allow_float: bool = True) -> str:
    """Return the first number in ``text`` as a string, or ``''`` if none.

    An optional leading sign is included. With ``allow_float`` (the default) a
    decimal point is accepted (``4.5``, ``.5``); otherwise only the integer
    digits are matched.
    """
    if not text:
        return ''
    if allow_float:
        number_re = r"[-+]?[0-9]*\.?[0-9]+"
    else:
        number_re = r"[-+]?[0-9]+"
    found = re.search(number_re, text)
    if found is None:
        return ''
    return found.group(0)


def parse_total_rating(text: str, allowed: Iterable[int]) -> str:
    """Extract an allowed unsigned integer rating from ``text``.

    The number following ``Total rating:`` (case-insensitive) is tried first,
    then the first run of digits anywhere in ``text``. A candidate is accepted
    only if its exact digit string equals ``str(v)`` for some ``v`` in
    ``allowed``; otherwise ``''`` is returned.
    """
    if not text:
        return ''
    valid = {str(value) for value in allowed}
    attempts = (
        (r"TOTAL\s*RATING\s*:\s*([0-9]+)", re.IGNORECASE),
        (r"([0-9]+)", 0),
    )
    for pattern, flags in attempts:
        hit = re.search(pattern, text, flags=flags)
        if hit and hit.group(1) in valid:
            return hit.group(1)
    return ''


def parse_signed_total_rating(text: str, *, min_allowed: int = -3, max_allowed: int = 3, allow_zero: bool = True) -> str:
    """Extract a signed integer rating from ``text`` and validate its range.

    The integer after ``Total rating:`` (case-insensitive) is used when that
    label is present; only if the label is absent is the first integer in the
    text used instead.

    Returns:
        The matched text exactly as written (e.g. ``'+2'``), or ``''`` when
        nothing matches, the value lies outside
        ``[min_allowed, max_allowed]``, or it is zero while ``allow_zero`` is
        false.
    """
    if not text:
        return ''
    hit = (
        re.search(r"TOTAL\s*RATING\s*:\s*([+-]?\d+)", text, flags=re.IGNORECASE)
        or re.search(r"([+-]?\d+)", text)
    )
    if hit is None:
        return ''
    token = hit.group(1)
    try:
        number = int(token)
    except ValueError:
        return ''
    zero_rejected = number == 0 and not allow_zero
    within_range = min_allowed <= number <= max_allowed
    if zero_rejected or not within_range:
        return ''
    return token


# Labels of the shifted preference scale and their signed equivalents
# (shifted = signed + 3). The shifted value 3 corresponds to a signed 0.
SHIFTED_PREFERENCE_ALLOWED = [0, 1, 2, 3, 4, 5, 6]
SHIFTED_TO_SIGNED = {0: -3, 1: -2, 2: -1, 3: 0, 4: 1, 5: 2, 6: 3}


def extract_shifted_signed_score(text: str, *, allow_zero: bool = False) -> Optional[int]:
    """Decode a judge output into a signed preference score in ``-3..3``.

    Shifted labels ``0..6`` are mapped back through ``SHIFTED_TO_SIGNED``.
    Outputs in the legacy signed format are also accepted:

    * An explicitly signed, in-range value (``-2``, ``+1``) is taken as a
      legacy score. The shifted parser ignores the sign, so such values used
      to be read as shifted labels (``-2`` became ``-1``), which made the
      legacy fallback unreachable for them.
    * An unsigned value cannot be told apart from a shifted label and is
      therefore decoded as shifted; the legacy parser is consulted only when
      the shifted parser finds nothing.

    Args:
        text: Raw model output.
        allow_zero: Whether a legacy ``0`` is accepted. This does not affect
            the shifted label ``3``, which always decodes to ``0``.

    Returns:
        The signed score, or ``None`` if no valid rating is found.
    """
    legacy_token = parse_signed_total_rating(text, allow_zero=allow_zero)

    # A sign can only come from the legacy scale; trust it before the shifted
    # parser gets a chance to strip it.
    if legacy_token[:1] in ('+', '-'):
        return int(legacy_token)

    shifted_token = parse_total_rating(text, allowed=SHIFTED_PREFERENCE_ALLOWED)
    if shifted_token:
        return SHIFTED_TO_SIGNED[int(shifted_token)]

    if legacy_token:
        return int(legacy_token)
    return None


# ---------------------------------------------------------------------------
# Generation backend (vLLM)
# ---------------------------------------------------------------------------

def run_vllm_generation(
    model_name: str,
    prompts: List[str],
    *,
    max_new_tokens: int,
    dataset_name: Optional[str] = None,
    reserve_tokens: int = DEFAULT_PROMPT_RESERVE_TOKENS,
    truncation_notice: Optional[str] = PROMPT_TRUNCATION_NOTICE,
    apply_truncation: bool = True,
) -> Optional[List[str]]:
    """Generate one greedy completion per prompt with a fresh vLLM engine.

    A new engine is built for ``model_name`` and released again before
    returning. Prompts are first shortened to fit the model's context length
    minus ``max_new_tokens + reserve_tokens`` unless ``apply_truncation`` is
    false or the context length cannot be determined.

    Args:
        model_name: Model identifier handed to vLLM.
        prompts: Prompt strings, one completion is produced for each.
        max_new_tokens: Maximum number of tokens to generate per prompt.
        dataset_name: Label used in log messages (``'unknown'`` if omitted).
        reserve_tokens: Extra tokens kept free in addition to
            ``max_new_tokens`` when computing the prompt budget.
        truncation_notice: Text inserted into shortened prompts; ``None`` or
            ``''`` disables the notice.
        apply_truncation: Whether to shorten over-long prompts.

    Returns:
        The stripped completion texts in prompt order, or ``None`` if anything
        failed. Failures are logged via ``log_generation_failure`` rather than
        raised, so callers can move on to the next model.
    """
    engine_kwargs = {
        'model': model_name,
        'gpu_memory_utilization': float(os.environ.get('VLLM_GPU_MEMORY_UTILIZATION', '0.9')),
        # 'tensor_parallel_size': 4,
    }
    max_len_hint = MODEL_CONTEXT_LENGTHS.get(model_name)
    if max_len_hint is not None:
        engine_kwargs['max_model_len'] = max_len_hint
    dtype = get_dtype(model_name)
    if dtype is not None:
        engine_kwargs['dtype'] = dtype

    dataset_label = dataset_name or 'unknown'
    engine: Optional[LLM] = None
    try:
        cleanup_orphan_vllm_workers(f"initializing {model_name}")
        maybe_adjust_gpu_memory_utilization(engine_kwargs, model_name)
        engine = LLM(**engine_kwargs)

        # Prefer the configured context length; otherwise ask the engine.
        actual_max_len = max_len_hint
        if actual_max_len is None:
            engine_cfg = getattr(getattr(engine, 'llm_engine', None), 'model_config', None)
            actual_max_len = getattr(engine_cfg, 'max_model_len', None)

        if apply_truncation and actual_max_len is not None:
            tokenizer = engine.get_tokenizer()
            allowed_prompt_tokens = actual_max_len - (max_new_tokens + reserve_tokens)
            if allowed_prompt_tokens <= 0:
                raise ValueError(
                    f"max_model_len={actual_max_len} is insufficient for reserve "
                    f"{max_new_tokens + reserve_tokens} tokens"
                )
            notice_text = truncation_notice or ''
            adjusted_prompts, truncated = truncate_prompts_to_budget(
                prompts,
                tokenizer=tokenizer,
                allowed_prompt_tokens=allowed_prompt_tokens,
                notice=notice_text,
            )
            if truncated:
                print(
                    f"[truncate] {truncated}/{len(prompts)} prompts shortened for {model_name} "
                    f"({dataset_label}) to fit {allowed_prompt_tokens} prompt tokens.",
                    flush=True,
                )
            prompts = adjusted_prompts
        elif apply_truncation and actual_max_len is None:
            print(
                f"[warn] Unable to determine max context for {model_name}; skipping prompt truncation.",
                flush=True,
            )
        # Greedy decoding: temperature 0 with top_k=1.
        sampling_params = SamplingParams(
            max_tokens=max_new_tokens,
            temperature=0.0,
            top_k=1,
            top_p=1.0,
        )
        outputs = engine.generate(prompts, sampling_params)
        raw = [out.outputs[0].text.strip() for out in outputs]
        return raw
    except Exception as exc:  # noqa: BLE001
        # Deliberate catch-all: one failing model must not abort the whole run.
        cleanup_orphan_vllm_workers(f"handling failure for {model_name}")
        log_generation_failure(model_name, dataset_label, exc)
        return None
    finally:
        if engine is not None:
            del engine
        # Collect garbage first so memory held by now-unreachable objects is
        # handed back to the caching allocator, then release the cached
        # blocks. In the opposite order empty_cache() cannot free that memory.
        gc.collect()
        torch.cuda.empty_cache()


# ---------------------------------------------------------------------------
# Master-key RLVR helper utilities
# ---------------------------------------------------------------------------


def _load_masterkey_records(csv_path: Path = MASTERKEY_CSV_PATH) -> pd.DataFrame:
    """Load the master-key CSV and normalise its columns.

    * A ``label`` column is renamed to ``reference`` when no ``reference``
      column exists.
    * The ``attack`` column (a dict or its Python-literal string) is expanded
      into ``attack_type``, ``attack_id`` and ``attack_category``; all three
      are ``None`` when the column is missing or a cell cannot be parsed.
    * ``source_file`` is added as ``None`` if absent, and ``row_index`` holds
      the row position.

    Raises:
        FileNotFoundError: If ``csv_path`` does not exist.
    """
    if not csv_path.exists():
        raise FileNotFoundError(f'Master-key CSV not found: {csv_path}')

    records = pd.read_csv(csv_path).copy()
    label_only = 'label' in records.columns and 'reference' not in records.columns
    if label_only:
        records.rename(columns={'label': 'reference'}, inplace=True)

    def _attack_to_dict(cell):
        """Return the attack metadata of one cell as a dict ({} if unusable)."""
        if isinstance(cell, dict):
            return cell
        if not (isinstance(cell, str) and cell.strip()):
            return {}
        try:
            literal = ast.literal_eval(cell)
        except (ValueError, SyntaxError):
            return {}
        return literal if isinstance(literal, dict) else {}

    # Output column -> key inside the attack metadata dict.
    attack_fields = {
        'attack_type': 'type',
        'attack_id': 'id',
        'attack_category': 'category',
    }
    if 'attack' in records.columns:
        parsed = records['attack'].apply(_attack_to_dict)
        for column, key in attack_fields.items():
            records[column] = parsed.apply(
                lambda item, key=key: item.get(key) if isinstance(item, dict) else None
            )
    else:
        for column in attack_fields:
            records[column] = None

    if 'source_file' not in records.columns:
        records['source_file'] = None

    records['row_index'] = range(len(records))
    return records


def _parse_masterkey_output(text: str) -> str:
    if not text:
        return ''
    first_line = text.strip().splitlines()[0]
    match = re.search(r'\b(yes|no)\b', first_line, flags=re.IGNORECASE)
    if match:
        return match.group(1).upper()
    match = re.search(r'\b(yes|no)\b', text, flags=re.IGNORECASE)
    if match:
        return match.group(1).upper()
    return first_line.strip()


# ---------------------------------------------------------------------------
# Dataset runners
# ---------------------------------------------------------------------------


def _run_masterkey_generation(
    models: List[str],
    csv_path: Path,
    output_dir: Path,
    dataset_label: str,
) -> None:
    """Judge every master-key record with each model and write one CSV per model.

    Args:
        models: Model identifiers to run.
        csv_path: CSV file with the master-key records.
        output_dir: Directory for the ``<model>.csv`` result files; created
            if missing.
        dataset_label: Name used in log messages.

    Models whose generation fails (``run_vllm_generation`` returns ``None``)
    are skipped. Nothing is written when the CSV has no rows.
    """
    records = _load_masterkey_records(csv_path=csv_path)
    if records.empty:
        print(f"[{dataset_label}] No rows available in {csv_path}")
        return

    prompts = []
    for _, record in records.iterrows():
        prompts.append(
            MASTERKEY_EXP_JUDGE_TEMPLATE.format(
                question=record['question'],
                response=record['response'],
                reference=record['reference'],
            )
        )

    output_dir.mkdir(parents=True, exist_ok=True)

    for model_name in models:
        generations = run_vllm_generation(
            model_name,
            prompts,
            max_new_tokens=4,
            dataset_name=dataset_label,
        )
        if generations is None:
            continue

        # Files are named after the last path component of the model id.
        short_name = model_name.rsplit('/', 1)[-1]
        table = records.copy()
        table['raw_output'] = generations
        table['parsed_output'] = [_parse_masterkey_output(item) for item in generations]
        table['model'] = short_name
        table.to_csv(output_dir / f"{short_name}.csv", index=False)


def run_masterkey_exp(models: List[str]) -> None:
    """Run the master-key experiment (``MASTERKEY_CSV_PATH``) for ``models``."""
    _run_masterkey_generation(
        models=models,
        csv_path=MASTERKEY_CSV_PATH,
        output_dir=MASTERKEY_OUTPUT_DIR,
        dataset_label='master_keys',
    )


def run_masterkey_exp_with_ref(models: List[str]) -> None:
    """Run the master-key experiment on the with-reference CSV for ``models``."""
    _run_masterkey_generation(
        models=models,
        csv_path=MASTERKEY_WITH_REF_CSV_PATH,
        output_dir=MASTERKEY_WITH_REF_OUTPUT_DIR,
        dataset_label='master_keys_with_ref',
    )


def _run_scalar_rating_dataset(
    models: List[str],
    *,
    base_df: pd.DataFrame,
    prompts: List[str],
    output_subdir: str,
    dataset_name: str,
    max_new_tokens: int = 8,
) -> None:
    """Shared generate/parse/save loop of the scalar-rating datasets.

    For each model, one completion is generated per prompt and the first
    number in it is stored as ``parsed_output``. The result is written to
    ``judge_outputs/fully_gaussian/<output_subdir>/<model>.csv``. Models whose
    generation fails are skipped.

    Args:
        models: Model identifiers to run.
        base_df: Frame whose rows correspond one-to-one to ``prompts``.
        prompts: Fully formatted judge prompts.
        output_subdir: Directory name below ``fully_gaussian``.
        dataset_name: Label passed to ``run_vllm_generation`` for logging.
        max_new_tokens: Generation length limit per prompt.
    """
    output_dir = SCRIPT_DIR.parent / 'judge_outputs' / 'fully_gaussian' / output_subdir
    output_dir.mkdir(parents=True, exist_ok=True)

    for model_name in models:
        raw = run_vllm_generation(
            model_name,
            prompts,
            max_new_tokens=max_new_tokens,
            dataset_name=dataset_name,
        )
        if raw is None:
            continue
        result = base_df.copy()
        result['raw_output'] = raw
        result['parsed_output'] = [parse_first_number(text) for text in raw]
        # Files are named after the last path component of the model id.
        model_short = model_name.split('/')[-1]
        result['model'] = model_short
        result.to_csv(output_dir / f"{model_short}.csv", index=False)


def run_feedbackqa(models: List[str]) -> None:
    """Generate judge ratings for the FeedbackQA question/answer pairs."""
    ratings = load_feedbackqa()
    base_df = ratings[['question', 'answer']].copy()
    prompts = [
        JUDGE_PROMPT.format(question=q, answer=a)
        for q, a in zip(base_df['question'], base_df['answer'])
    ]
    _run_scalar_rating_dataset(
        models,
        base_df=base_df,
        prompts=prompts,
        output_subdir='feedbackqa',
        dataset_name='feedbackqa',
    )


def run_helpsteer2(models: List[str]) -> None:
    """Generate judge ratings for the HelpSteer2 question/answer pairs."""
    ratings = load_helpsteer2()
    base_df = ratings[['question', 'answer']].copy()
    prompts = [
        HELPSTEER2_JUDGE_PROMPT.format(
            question=q,
            answer=a,
        )
        for q, a in zip(base_df['question'], base_df['answer'])
    ]
    _run_scalar_rating_dataset(
        models,
        base_df=base_df,
        prompts=prompts,
        output_subdir='helpsteer2',
        dataset_name='helpsteer2',
    )


def run_ultrafeedback(models: List[str]) -> None:
    """Generate judge ratings for the UltraFeedback question/answer pairs.

    The prompt is parameterised with the rating bounds 1.0 and 5.0. Results go
    to the ``ultrafeedback_sampled`` directory, while the dataset label used
    for logging is ``ultrafeedback``.
    """
    ratings = load_ultrafeedback()
    base_df = ratings[['question', 'answer']].copy()
    prompts = [
        ULTRAFEEDBACK_JUDGE_PROMPT.format(
            min_rating=1.0,
            max_rating=5.0,
            question=q,
            answer=a,
        )
        for q, a in zip(base_df['question'], base_df['answer'])
    ]
    _run_scalar_rating_dataset(
        models,
        base_df=base_df,
        prompts=prompts,
        output_subdir='ultrafeedback_sampled',
        dataset_name='ultrafeedback',
    )


def _load_helpsteer3_pairs() -> pd.DataFrame:
    """Load the HelpSteer3 pairs and attach a binary gold label.

    ``overall_preference`` is coerced to numeric (unparseable values become
    NaN). ``gold_label_binary`` is 1 for a positive preference, 0 for a
    preference of zero or below, and ``<NA>`` where the preference is missing.
    """
    source = load_helpsteer3()
    wanted_columns = ['question', 'response1', 'response2', 'overall_preference']
    pairs = source[wanted_columns].copy()

    preference = pd.to_numeric(pairs['overall_preference'], errors='coerce')
    pairs['overall_preference'] = preference

    # Nullable integer column so rows with a missing preference stay <NA>.
    gold = pd.Series(pd.NA, index=pairs.index, dtype='Int64')
    prefers_second = preference > 0
    prefers_first_or_tie = preference <= 0
    gold.loc[prefers_second] = 1
    gold.loc[prefers_first_or_tie] = 0
    pairs['gold_label_binary'] = gold
    return pairs


def run_helpsteer3(models: List[str]) -> None:
    """Generate pairwise preference judgements for HelpSteer3.

    When the global ``SWAP_AB`` is set, each pair is shown in swapped order
    with probability 0.5. Model scores are mapped back to the original order:
    a positive score prefers ``response2`` ('B'), a negative one ``response1``
    ('A'), zero is a tie. One CSV per model is written to
    ``judge_outputs/fully_gaussian/helpsteer3``; models whose generation fails
    are skipped.
    """
    pairs = _load_helpsteer3_pairs()

    response_a: List[str] = []
    response_b: List[str] = []
    was_swapped: List[bool] = []
    for first, second in zip(pairs['response1'].tolist(), pairs['response2'].tolist()):
        # Short-circuit: the RNG is only consumed when swapping is enabled.
        swap = bool(SWAP_AB and random.random() < 0.5)
        response_a.append(second if swap else first)
        response_b.append(first if swap else second)
        was_swapped.append(swap)

    prompts = [
        HELPSTEER3_PREF_JUDGE_PROMPT.format(
            question=question,
            answer_a=shown_a,
            answer_b=shown_b,
        )
        for question, shown_a, shown_b in zip(pairs['question'].tolist(), response_a, response_b)
    ]

    output_dir = SCRIPT_DIR.parent / 'judge_outputs' / 'fully_gaussian' / 'helpsteer3'
    output_dir.mkdir(parents=True, exist_ok=True)

    def _interpret(score: Optional[int], swapped: bool):
        """Map a shown-order score to (original score, label, winner, binary)."""
        if score is None:
            return None, '', None, None
        original = -score if swapped else score
        if original > 0:
            return original, 'B', 'response2', 1
        if original < 0:
            return original, 'A', 'response1', 0
        return original, 'tie', None, None

    for model_name in models:
        raw = run_vllm_generation(
            model_name,
            prompts,
            max_new_tokens=8,
            dataset_name='helpsteer3',
        )
        if raw is None:
            continue

        scores_ab: List[Optional[int]] = [
            extract_shifted_signed_score(entry, allow_zero=False) for entry in raw
        ]
        decoded = [
            _interpret(score, was_swapped[position])
            for position, score in enumerate(scores_ab)
        ]

        result = pairs.copy()
        result['response_a'] = response_a
        result['response_b'] = response_b
        result['was_swapped'] = was_swapped
        result['raw_output'] = raw
        result['score_ab'] = scores_ab
        result['score_original_order'] = [item[0] for item in decoded]
        result['pref_A_or_B'] = [item[1] for item in decoded]
        result['preferred_resp'] = [item[2] for item in decoded]
        result['pred_label_binary'] = [item[3] for item in decoded]
        model_short = model_name.rsplit('/', 1)[-1]
        result['model'] = model_short
        result.to_csv(output_dir / f"{model_short}.csv", index=False)


def run_helpsteer3_binary(models: List[str]) -> None:
    """Judge HelpSteer3 pairs as a forced A/B choice and write one CSV per model.

    Rows whose ``overall_preference`` is missing or equal to 0 are dropped.
    When ``SWAP_AB`` is truthy each pair is presented in reversed order with
    probability 0.5; the presentation order is drawn once and shared by all
    models. ``MODE == 'gaussian_mixture'`` selects the signed-score prompt and
    parser, any other mode the 0/1 prompt and parser.

    Args:
        models: Model identifiers passed to ``run_vllm_generation``; a model
            whose generation returns ``None`` is skipped.
    """
    pairs = _load_helpsteer3_pairs().dropna(subset=['overall_preference'])
    kept = pairs[pairs['overall_preference'] != 0].reset_index(drop=True)

    signed_scoring = MODE == 'gaussian_mixture'
    if signed_scoring:
        template, token_budget = JUDGEBENCH_PREF_NO_TIE_PROMPT, 8
    else:
        template, token_budget = PREFERENCE_BINARY_JUDGE_PROMPT, 6

    first_responses = kept['response1'].tolist()
    second_responses = kept['response2'].tolist()
    # One coin flip per pair, and only when swapping is enabled.
    was_swapped: List[bool] = [
        bool(SWAP_AB and random.random() < 0.5)
        for _ in zip(first_responses, second_responses)
    ]
    response_a: List[str] = [
        second if flipped else first
        for flipped, first, second in zip(was_swapped, first_responses, second_responses)
    ]
    response_b: List[str] = [
        first if flipped else second
        for flipped, first, second in zip(was_swapped, first_responses, second_responses)
    ]

    prompts = [
        template.format(question=question, answer_a=shown_a, answer_b=shown_b)
        for question, shown_a, shown_b in zip(kept['question'].tolist(), response_a, response_b)
    ]

    output_dir = get_binary_output_dir('helpsteer3')

    def decode(score, flipped):
        """Return (original-order score, 'A'/'B' label, binary prediction)."""
        if signed_scoring:
            # Signed scores flip sign when the pair was shown reversed.
            restored = -score if flipped else score
            if restored > 0:
                return restored, 'B', 1
            if restored < 0:
                return restored, 'A', 0
            return restored, 'tie', None
        # 0/1 tokens are mirrored instead of negated.
        restored = 1 - score if flipped else score
        label = {0: 'A', 1: 'B'}.get(restored, '')
        return restored, label, (restored if restored in (0, 1) else None)

    for model_name in models:
        raw = run_vllm_generation(
            model_name,
            prompts,
            max_new_tokens=token_budget,
            dataset_name='helpsteer3_binary',
        )
        if raw is None:
            continue

        scores_ab: List[Optional[int]] = []
        scores_original: List[Optional[int]] = []
        pref_labels: List[str] = []
        pred_binary: List[Optional[int]] = []

        for position, text in enumerate(raw):
            if signed_scoring:
                score = extract_shifted_signed_score(text, allow_zero=False)
            else:
                token = parse_total_rating(text, allowed=[0, 1])
                score = None if token in ('', None) else int(token)
            scores_ab.append(score)

            if score is None:
                restored, label, prediction = None, '', None
            else:
                restored, label, prediction = decode(score, was_swapped[position])
            scores_original.append(restored)
            pref_labels.append(label)
            pred_binary.append(prediction)

        model_short = model_name.split('/')[-1]
        result = kept.copy()
        new_columns = {
            'response_a': response_a,
            'response_b': response_b,
            'was_swapped': was_swapped,
            'raw_output': raw,
            'score_ab': scores_ab,
            'score_original_order': scores_original,
            'pref_A_or_B': pref_labels,
            'pred_label_binary': pred_binary,
            'scoring_mode': MODE,
            'model': model_short,
        }
        for column, values in new_columns.items():
            result[column] = values
        result.to_csv(output_dir / f"{model_short}.csv", index=False)


def _load_judgebench_pairs() -> pd.DataFrame:
    """Load the JudgeBench ``claude`` and ``gpt`` splits into a single frame.

    Text fields are stripped, with missing values replaced by ``''``. Two gold
    encodings are derived from ``label_ab``:

    * ``gold_label_num``: 1 for ``'A>B'``, 0 for ``'B>A'``, otherwise ``None``.
    * ``gold_label_binary`` (nullable ``Int64``): 1 for ``'B>A'``, 0 for
      ``'A>B'``, otherwise ``<NA>``.

    Returns:
        One row per example, ``claude`` rows first, then ``gpt`` rows.
    """
    passthrough_keys = ('pair_id', 'original_id', 'source', 'response_model')
    stripped_keys = ('question', 'response_A', 'response_B')

    loaded_splits = [
        load_dataset('ScalerLab/JudgeBench', split=split_name)
        for split_name in ('claude', 'gpt')
    ]

    records: List[Dict[str, object]] = []
    for dataset in loaded_splits:
        for example in dataset:
            record: Dict[str, object] = {key: example.get(key) for key in passthrough_keys}
            for key in stripped_keys:
                record[key] = (example.get(key) or '').strip()
            record['label_ab'] = (example.get('label') or '').strip()
            records.append(record)

    base_df = pd.DataFrame(records)

    numeric_by_label = {'A>B': 1, 'B>A': 0}
    base_df['gold_label_num'] = base_df['label_ab'].map(
        lambda lbl: None if lbl is None else numeric_by_label.get(str(lbl).strip())
    )

    # Note the opposite polarity from gold_label_num: here B winning is 1.
    gold_binary = pd.Series(pd.NA, index=base_df.index, dtype='Int64')
    for label_text, value in (('B>A', 1), ('A>B', 0)):
        gold_binary.loc[base_df['label_ab'] == label_text] = value
    base_df['gold_label_binary'] = gold_binary
    return base_df


def run_judgebench(models: List[str]) -> None:
    """Judge every JudgeBench pair and write one CSV per model.

    Results are written to ``SCRIPT_DIR.parent / 'judge_outputs' / <root> /
    'judgebench'``, where ``<root>`` is ``gaussian_mixture`` when ``MODE`` is
    ``'gaussian_mixture'`` and ``binary`` otherwise. When ``SWAP_AB`` is
    truthy the A/B presentation order is re-drawn for every model.

    ``pred_label_num`` is 1 when A is preferred and 0 when B is preferred,
    the same encoding ``_load_judgebench_pairs`` uses for ``gold_label_num``.

    Args:
        models: Model identifiers passed to ``run_vllm_generation``; a model
            whose generation returns ``None`` is skipped.
    """
    base_df = _load_judgebench_pairs()

    gaussian_mode = MODE == 'gaussian_mixture'
    root_name = 'gaussian_mixture' if gaussian_mode else 'binary'
    output_dir = SCRIPT_DIR.parent / 'judge_outputs' / root_name / 'judgebench'
    output_dir.mkdir(parents=True, exist_ok=True)

    # These do not depend on the model, so extract them once.
    questions = base_df['question'].tolist()
    original_pairs = list(zip(base_df['response_A'].tolist(), base_df['response_B'].tolist()))

    for model_name in models:
        response_a: List[str] = []
        response_b: List[str] = []
        was_swapped: List[bool] = []
        for ra, rb in original_pairs:
            if SWAP_AB and random.random() < 0.5:
                response_a.append(rb)
                response_b.append(ra)
                was_swapped.append(True)
            else:
                response_a.append(ra)
                response_b.append(rb)
                was_swapped.append(False)

        prompts_swapped = [
            JUDGEBENCH_PREF_NO_TIE_PROMPT.format(
                question=q,
                answer_a=a,
                answer_b=b,
            )
            for q, a, b in zip(questions, response_a, response_b)
        ]
        raw = run_vllm_generation(
            model_name,
            prompts_swapped,
            max_new_tokens=8,
            dataset_name='judgebench',
        )
        if raw is None:
            continue

        scores_ab: List[Optional[int]] = []
        scores_original: List[Optional[int]] = []
        pref_labels: List[str] = []
        pred_label_num: List[Optional[int]] = []
        for idx, entry in enumerate(raw):
            if gaussian_mode:
                score = extract_shifted_signed_score(entry, allow_zero=False)
            else:
                # NOTE(review): in this mode the score is a 0/1 token, yet it
                # is still handled as a signed value below (0 is labelled
                # 'tie') and the prompt is the same as in gaussian mode.
                # run_judgebench_binary uses `1 - score` and a different
                # prompt; confirm which behaviour is intended here.
                token = parse_total_rating(entry, allowed=[0, 1])
                score = int(token) if token not in ('', None) else None

            scores_ab.append(score)
            if score is None:
                scores_original.append(None)
                pref_labels.append('')
                pred_label_num.append(None)
                continue

            # Undo the presentation swap so the score refers to the dataset order.
            s_orig = -score if was_swapped[idx] else score
            scores_original.append(s_orig)
            if s_orig > 0:
                pref_labels.append('B')
                pred_label_num.append(0)
            elif s_orig < 0:
                pref_labels.append('A')
                pred_label_num.append(1)
            else:
                pref_labels.append('tie')
                pred_label_num.append(None)

        result = base_df.copy()
        result['response_a'] = response_a
        result['response_b'] = response_b
        result['was_swapped'] = was_swapped
        result['raw_output'] = raw
        result['score_ab'] = scores_ab
        result['score_original_order'] = scores_original
        result['pref_A_or_B'] = pref_labels
        result['pred_label_num'] = pred_label_num
        model_short = model_name.split('/')[-1]
        result['model'] = model_short
        result.to_csv(output_dir / f"{model_short}.csv", index=False)

        # Remove a leftover `<model>_prefs.csv` (presumably from an earlier
        # naming scheme). The new CSV is already written, so a failure here
        # is reported rather than raised.
        old_path = output_dir / f"{model_short}_prefs.csv"
        if old_path.exists():
            try:
                old_path.unlink()
            except OSError as exc:
                print(f"[warn] Could not remove stale output {old_path}: {exc}", flush=True)


def run_judgebench_binary(models: List[str]) -> None:
    """Judge JudgeBench pairs that have a binary gold label; one CSV per model.

    Rows without ``gold_label_binary`` are dropped. When ``SWAP_AB`` is truthy
    the A/B presentation order is re-drawn for every model. ``MODE ==
    'gaussian_mixture'`` selects the signed-score prompt and parser, any other
    mode the 0/1 prompt and parser.

    Args:
        models: Model identifiers passed to ``run_vllm_generation``; a model
            whose generation returns ``None`` is skipped.
    """
    kept = _load_judgebench_pairs().dropna(subset=['gold_label_binary']).reset_index(drop=True)

    output_dir = get_binary_output_dir('judgebench')

    signed_scoring = MODE == 'gaussian_mixture'
    if signed_scoring:
        template, token_budget = JUDGEBENCH_PREF_NO_TIE_PROMPT, 8
    else:
        template, token_budget = PREFERENCE_BINARY_JUDGE_PROMPT, 6

    def decode(score, flipped):
        """Return (original-order score, 'A'/'B' label, binary prediction)."""
        if signed_scoring:
            restored = -score if flipped else score
            if restored > 0:
                return restored, 'B', 1
            if restored < 0:
                return restored, 'A', 0
            # A zero score gets an empty label in this function.
            return restored, '', None
        restored = 1 - score if flipped else score
        label = {0: 'A', 1: 'B'}.get(restored, '')
        return restored, label, (restored if restored in (0, 1) else None)

    for model_name in models:
        originals_a = kept['response_A'].tolist()
        originals_b = kept['response_B'].tolist()
        # One coin flip per pair, and only when swapping is enabled.
        was_swapped: List[bool] = [
            bool(SWAP_AB and random.random() < 0.5)
            for _ in zip(originals_a, originals_b)
        ]
        response_a: List[str] = [
            rb if flipped else ra
            for flipped, ra, rb in zip(was_swapped, originals_a, originals_b)
        ]
        response_b: List[str] = [
            ra if flipped else rb
            for flipped, ra, rb in zip(was_swapped, originals_a, originals_b)
        ]

        prompts = [
            template.format(question=question, answer_a=shown_a, answer_b=shown_b)
            for question, shown_a, shown_b in zip(kept['question'].tolist(), response_a, response_b)
        ]

        raw = run_vllm_generation(
            model_name,
            prompts,
            max_new_tokens=token_budget,
            dataset_name='judgebench_binary',
        )
        if raw is None:
            continue

        scores_ab: List[Optional[int]] = []
        scores_original: List[Optional[int]] = []
        pref_labels: List[str] = []
        pred_binary: List[Optional[int]] = []
        for position, text in enumerate(raw):
            if signed_scoring:
                score = extract_shifted_signed_score(text, allow_zero=False)
            else:
                token = parse_total_rating(text, allowed=[0, 1])
                score = None if token in ('', None) else int(token)
            scores_ab.append(score)

            if score is None:
                restored, label, prediction = None, '', None
            else:
                restored, label, prediction = decode(score, was_swapped[position])
            scores_original.append(restored)
            pref_labels.append(label)
            pred_binary.append(prediction)

        model_short = model_name.split('/')[-1]
        result = kept.copy()
        new_columns = {
            'response_a': response_a,
            'response_b': response_b,
            'was_swapped': was_swapped,
            'raw_output': raw,
            'score_ab': scores_ab,
            'score_original_order': scores_original,
            'pref_A_or_B': pref_labels,
            'pred_label_binary': pred_binary,
            'scoring_mode': MODE,
            'model': model_short,
        }
        for column, values in new_columns.items():
            result[column] = values
        result.to_csv(output_dir / f"{model_short}.csv", index=False)


def _format_arena_turns(conv: List[Dict[str, object]]) -> Dict[str, str]:
    """Extract first user question and first assistant reply from a conversation list.
    Returns {'question': str, 'answer': str} (empty strings if unavailable).
    """
    q = ''
    a = ''
    for msg in conv:
        if not isinstance(msg, dict):
            continue
        role = (msg.get('role') or '').strip().lower()
        content = (msg.get('content') or '').strip()
        if role == 'user' and not q:
            q = content
        elif role == 'assistant' and not a:
            a = content
        if q and a:
            break
    return {'question': q, 'answer': a}


def _load_chatbot_arena_pairs(sample_size: Optional[int], seed: int) -> pd.DataFrame:
    """Load Chatbot Arena battles as single-turn (question, response A, response B) rows.

    Examples are skipped when ``turn`` is greater than 2, when either
    conversation is not a list, or when side A's first question, side A's
    first answer or side B's first answer is empty.

    ``gold_label_num`` is -1 / 1 / 0 for a ``winner`` of ``model_a`` /
    ``model_b`` / ``tie`` (case-insensitive, stripped) and ``None`` for any
    other value. ``gold_label_binary`` (nullable ``Int64``) is 1 where
    ``gold_label_num`` is 1, 0 where it is -1, and ``<NA>`` otherwise.

    Args:
        sample_size: When not ``None`` and smaller than the number of usable
            rows, that many rows are sampled without replacement.
        seed: ``random_state`` used for the sampling.

    Returns:
        The (possibly sampled) frame. The column set is the same even when
        no example survives filtering.
    """
    columns = [
        'question_id',
        'model_a',
        'model_b',
        'winner',
        'turn',
        'question',
        'response_A',
        'response_B',
        'gold_label_num',
        'judge',
        'language',
        'anony',
    ]
    gold_by_winner = {'model_a': -1, 'model_b': 1, 'tie': 0}

    ds = load_dataset('lmsys/chatbot_arena_conversations', split='train')

    rows: List[Dict[str, object]] = []
    for ex in ds:
        turn = int(ex.get('turn', 0) or 0)
        if turn > 2:
            continue
        conv_a = ex.get('conversation_a') or []
        conv_b = ex.get('conversation_b') or []
        if not isinstance(conv_a, list) or not isinstance(conv_b, list):
            continue
        fa = _format_arena_turns(conv_a)
        fb = _format_arena_turns(conv_b)
        if not (fa['question'] and fa['answer'] and fb['answer']):
            continue

        winner = (ex.get('winner') or '').strip().lower()
        rows.append(
            {
                'question_id': ex.get('question_id'),
                'model_a': ex.get('model_a'),
                'model_b': ex.get('model_b'),
                'winner': ex.get('winner'),
                'turn': turn,
                # The guard above guarantees side A's question is non-empty,
                # so no fallback to side B's question is needed.
                'question': fa['question'],
                'response_A': fa['answer'],
                'response_B': fb['answer'],
                'gold_label_num': gold_by_winner.get(winner),
                'judge': ex.get('judge'),
                'language': ex.get('language'),
                'anony': ex.get('anony'),
            }
        )

    # Naming the columns keeps the schema intact when `rows` is empty, so the
    # `gold_label_num` lookups below cannot raise KeyError.
    base_df = pd.DataFrame(rows, columns=columns)
    if sample_size is not None and len(base_df) > sample_size:
        base_df = base_df.sample(n=sample_size, random_state=seed).reset_index(drop=True)

    binary = pd.Series(pd.NA, index=base_df.index, dtype='Int64')
    binary.loc[base_df['gold_label_num'] == 1] = 1
    binary.loc[base_df['gold_label_num'] == -1] = 0
    base_df['gold_label_binary'] = binary
    return base_df


def run_chatbot_arena(models: List[str], sample_size: int = 5000, seed: int = 42) -> None:
    """Judge sampled Chatbot Arena pairs and write one CSV per model.

    Results are written to ``SCRIPT_DIR.parent / 'judge_outputs' / 'binary' /
    'chatbot_arena_conversations'``. When ``SWAP_AB`` is truthy each pair is
    presented in reversed order with probability 0.5; the presentation order
    is drawn once and shared by all models.

    ``pred_label_num`` is 1 when B is preferred, -1 when A is preferred and
    0 for a tie.

    Args:
        models: Model identifiers passed to ``run_vllm_generation``; a model
            whose generation returns ``None`` is skipped.
        sample_size: Forwarded to ``_load_chatbot_arena_pairs``.
        seed: Forwarded to ``_load_chatbot_arena_pairs``.
    """
    base_df = _load_chatbot_arena_pairs(sample_size, seed)

    response_a: List[str] = []
    response_b: List[str] = []
    was_swapped: List[bool] = []
    for ra, rb in zip(base_df['response_A'].tolist(), base_df['response_B'].tolist()):
        if SWAP_AB and random.random() < 0.5:
            response_a.append(rb)
            response_b.append(ra)
            was_swapped.append(True)
        else:
            response_a.append(ra)
            response_b.append(rb)
            was_swapped.append(False)

    prompts = [
        HELPSTEER3_PREF_JUDGE_PROMPT.format(
            question=q,
            answer_a=a,
            answer_b=b,
        )
        for q, a, b in zip(base_df['question'].tolist(), response_a, response_b)
    ]

    output_dir = SCRIPT_DIR.parent / 'judge_outputs' / 'binary' / 'chatbot_arena_conversations'
    output_dir.mkdir(parents=True, exist_ok=True)

    for model_name in models:
        raw = run_vllm_generation(
            model_name,
            prompts,
            max_new_tokens=8,
            dataset_name='chatbot_arena_conversations',
        )
        if raw is None:
            continue

        scores_ab: List[Optional[int]] = []
        scores_original: List[Optional[int]] = []
        pref_labels: List[str] = []
        pred_label_num: List[Optional[int]] = []
        for idx, entry in enumerate(raw):
            # NOTE(review): allow_zero=False is passed although this dataset
            # has tie labels; confirm the helper can still return 0, otherwise
            # the 'tie' branch below is unreachable.
            score = extract_shifted_signed_score(entry, allow_zero=False)
            scores_ab.append(score)
            if score is None:
                scores_original.append(None)
                pref_labels.append('')
                pred_label_num.append(None)
                continue
            # Undo the presentation swap so the score refers to the dataset order.
            s_orig = -score if was_swapped[idx] else score
            scores_original.append(s_orig)
            if s_orig > 0:
                pref_labels.append('B')
                pred_label_num.append(1)
            elif s_orig < 0:
                pref_labels.append('A')
                pred_label_num.append(-1)
            else:
                pref_labels.append('tie')
                pred_label_num.append(0)

        result = base_df.copy()
        result['response_a'] = response_a
        result['response_b'] = response_b
        result['was_swapped'] = was_swapped
        result['raw_output'] = raw
        result['score_ab'] = scores_ab
        result['score_original_order'] = scores_original
        result['pref_A_or_B'] = pref_labels
        result['pred_label_num'] = pred_label_num
        model_short = model_name.split('/')[-1]
        result['model'] = model_short
        result.to_csv(output_dir / f"{model_short}.csv", index=False)

        # Remove a leftover `<model>_prefs.csv` (presumably from an earlier
        # naming scheme). The new CSV is already written, so a failure here
        # is reported rather than raised.
        old_path = output_dir / f"{model_short}_prefs.csv"
        if old_path.exists():
            try:
                old_path.unlink()
            except OSError as exc:
                print(f"[warn] Could not remove stale output {old_path}: {exc}", flush=True)


def run_chatbot_arena_binary(models: List[str], sample_size: int = 5000, seed: int = 42) -> None:
    """Judge Chatbot Arena pairs that have a binary gold label; one CSV per model.

    Rows without ``gold_label_binary`` are dropped after sampling. When
    ``SWAP_AB`` is truthy each pair is presented in reversed order with
    probability 0.5; the presentation order is drawn once and shared by all
    models. ``MODE == 'gaussian_mixture'`` selects the signed-score prompt and
    parser, any other mode the 0/1 prompt and parser.

    Args:
        models: Model identifiers passed to ``run_vllm_generation``; a model
            whose generation returns ``None`` is skipped.
        sample_size: Forwarded to ``_load_chatbot_arena_pairs``.
        seed: Forwarded to ``_load_chatbot_arena_pairs``.
    """
    sampled = _load_chatbot_arena_pairs(sample_size, seed)
    kept = sampled.dropna(subset=['gold_label_binary']).reset_index(drop=True)

    signed_scoring = MODE == 'gaussian_mixture'
    if signed_scoring:
        template, token_budget = JUDGEBENCH_PREF_NO_TIE_PROMPT, 8
    else:
        template, token_budget = PREFERENCE_BINARY_JUDGE_PROMPT, 6

    originals_a = kept['response_A'].tolist()
    originals_b = kept['response_B'].tolist()
    # One coin flip per pair, and only when swapping is enabled.
    was_swapped: List[bool] = [
        bool(SWAP_AB and random.random() < 0.5)
        for _ in zip(originals_a, originals_b)
    ]
    response_a: List[str] = [
        rb if flipped else ra
        for flipped, ra, rb in zip(was_swapped, originals_a, originals_b)
    ]
    response_b: List[str] = [
        ra if flipped else rb
        for flipped, ra, rb in zip(was_swapped, originals_a, originals_b)
    ]

    prompts = [
        template.format(question=question, answer_a=shown_a, answer_b=shown_b)
        for question, shown_a, shown_b in zip(kept['question'].tolist(), response_a, response_b)
    ]

    output_dir = get_binary_output_dir('chatbot_arena_conversations')

    def decode(score, flipped):
        """Return (original-order score, 'A'/'B' label, binary prediction)."""
        if signed_scoring:
            restored = -score if flipped else score
            if restored > 0:
                return restored, 'B', 1
            if restored < 0:
                return restored, 'A', 0
            return restored, 'tie', None
        restored = 1 - score if flipped else score
        label = {0: 'A', 1: 'B'}.get(restored, '')
        return restored, label, (restored if restored in (0, 1) else None)

    for model_name in models:
        raw = run_vllm_generation(
            model_name,
            prompts,
            max_new_tokens=token_budget,
            dataset_name='chatbot_arena_conversations_binary',
        )
        if raw is None:
            continue

        scores_ab: List[Optional[int]] = []
        scores_original: List[Optional[int]] = []
        pref_labels: List[str] = []
        pred_binary: List[Optional[int]] = []
        for position, text in enumerate(raw):
            if signed_scoring:
                score = extract_shifted_signed_score(text, allow_zero=False)
            else:
                token = parse_total_rating(text, allowed=[0, 1])
                score = None if token in ('', None) else int(token)
            scores_ab.append(score)

            if score is None:
                restored, label, prediction = None, '', None
            else:
                restored, label, prediction = decode(score, was_swapped[position])
            scores_original.append(restored)
            pref_labels.append(label)
            pred_binary.append(prediction)

        model_short = model_name.split('/')[-1]
        result = kept.copy()
        new_columns = {
            'response_a': response_a,
            'response_b': response_b,
            'was_swapped': was_swapped,
            'raw_output': raw,
            'score_ab': scores_ab,
            'score_original_order': scores_original,
            'pref_A_or_B': pref_labels,
            'pred_label_binary': pred_binary,
            'scoring_mode': MODE,
            'model': model_short,
        }
        for column, values in new_columns.items():
            result[column] = values
        result.to_csv(output_dir / f"{model_short}.csv", index=False)


def _load_summarize_split(split: str) -> pd.DataFrame:
    """Flatten one split of the ``axis`` config of summarize_from_feedback.

    Args:
        split: Split name passed to ``load_dataset``; also stored in the
            ``split`` column of every row.

    Returns:
        A frame with columns ``example_id``, ``post``, ``summary_text``,
        ``overall_score``, ``policy`` and ``split``. Text fields are stripped
        and default to ``''`` when missing.
    """

    def flatten(example) -> Dict[str, object]:
        # Missing or null nested sections are treated as empty dicts.
        info = example.get('info') or {}
        summary = example.get('summary') or {}
        axes = summary.get('axes') or {}
        return {
            'example_id': info.get('id'),
            'post': (info.get('post') or '').strip(),
            'summary_text': (summary.get('text') or '').strip(),
            'overall_score': axes.get('overall'),
            'policy': summary.get('policy'),
            'split': split,
        }

    dataset = load_dataset('openai/summarize_from_feedback', 'axis', split=split)
    return pd.DataFrame([flatten(example) for example in dataset])


def run_summarize_from_feedback_axis(models: List[str], splits: Tuple[str, ...] = ('validation', 'test')) -> None:
    """Rate summaries on a 1-7 scale and write one CSV per model and split.

    Results are written to ``SCRIPT_DIR.parent / 'judge_outputs' /
    'fully_gaussian' / 'summarize_from_feedback' / <split> / <model>.csv``.
    Splits that yield no prompts are skipped.

    Args:
        models: Model identifiers passed to ``run_vllm_generation``; a
            model/split whose generation returns ``None`` is skipped.
        splits: Dataset splits to load and judge.
    """
    split_frames = {split: _load_summarize_split(split) for split in splits}
    output_root = SCRIPT_DIR.parent / 'judge_outputs' / 'fully_gaussian' / 'summarize_from_feedback'
    for split in splits:
        (output_root / split).mkdir(parents=True, exist_ok=True)

    # Prompts depend only on the data, so build them once instead of once per
    # model. itertuples yields nothing for an empty frame, so a split with no
    # rows simply gets an empty prompt list.
    prompts_per_split = {
        split: [
            SUMMARIZE_AXIS_JUDGE_PROMPT.format(post=row.post, summary=row.summary_text)
            for row in frame.itertuples(index=False)
        ]
        for split, frame in split_frames.items()
    }

    for model_name in models:
        model_short = model_name.split('/')[-1]
        for split, prompts in prompts_per_split.items():
            if not prompts:
                continue
            raw = run_vllm_generation(
                model_name,
                prompts,
                max_new_tokens=4,
                dataset_name=f'summarize_from_feedback_axis/{split}',
            )
            if raw is None:
                continue
            parsed = [parse_total_rating(text, range(1, 8)) for text in raw]
            frame = split_frames[split].copy()
            frame['raw_output'] = raw
            frame['parsed_output'] = parsed
            frame['model'] = model_short
            frame.to_csv(output_root / split / f"{model_short}.csv", index=False)


def _load_yelp_subset(split: str, sample_size: int, seed: int) -> pd.DataFrame:
    """Return a seeded sample of a Yelp review split.

    Args:
        split: Split name passed to ``load_dataset``.
        sample_size: Upper bound on the number of reviews kept.
        seed: Seed for the shuffle that precedes selection.

    Returns:
        A frame with columns ``example_id`` (position within the sample),
        ``text`` and ``label`` (cast to ``int``; -1 when the key is absent).
    """
    dataset = load_dataset('Yelp/yelp_review_full', split=split)
    keep = min(sample_size, len(dataset))
    subset = dataset.shuffle(seed=seed).select(range(keep))
    records = [
        {
            'example_id': position,
            'text': item.get('text', ''),
            'label': int(item.get('label', -1)),
        }
        for position, item in enumerate(subset)
    ]
    # Passing the column names keeps the schema when nothing was sampled.
    return pd.DataFrame(records, columns=['example_id', 'text', 'label'])


def run_yelp_review_full(models: List[str], sample_size: int = 5000, seed: int = 42) -> None:
    """Rate sampled Yelp test reviews on a 1-5 scale; one CSV per model and split.

    Results are written to ``SCRIPT_DIR.parent / 'judge_outputs' /
    'fully_gaussian' / 'yelp_with_scores' / <model>_<split>.csv``. The parsed
    rating is stored under both ``parsed_output`` and ``parsed_score``.

    Args:
        models: Model identifiers passed to ``run_vllm_generation``; a model
            whose generation returns ``None`` is skipped.
        sample_size: Forwarded to ``_load_yelp_subset``.
        seed: Forwarded to ``_load_yelp_subset``.

    Raises:
        KeyError: If a frame has neither a ``text`` nor a ``review`` column.
    """
    split_frames = {'test': _load_yelp_subset('test', sample_size, seed)}
    output_dir = SCRIPT_DIR.parent / 'judge_outputs' / 'fully_gaussian' / 'yelp_with_scores'
    output_dir.mkdir(parents=True, exist_ok=True)

    # Prompts depend only on the data, so build them once instead of once per model.
    prompts_per_split = {}
    for split, frame in split_frames.items():
        # Be robust to column naming; prefer 'text' but fall back to 'review'.
        if 'text' in frame.columns:
            text_column = 'text'
        elif 'review' in frame.columns:
            text_column = 'review'
        else:
            raise KeyError("Expected 'text' (or 'review') column in Yelp frame")
        # YELP_REVIEW_FULL_JUDGE_PROMPT expects a 'text' field in the template.
        prompts_per_split[split] = [
            YELP_REVIEW_FULL_JUDGE_PROMPT.format(text=t)
            for t in frame[text_column].astype(str).tolist()
        ]

    for model_name in models:
        model_short = model_name.split('/')[-1]
        for split, frame in split_frames.items():
            raw = run_vllm_generation(
                model_name,
                prompts_per_split[split],
                max_new_tokens=8,
                dataset_name=f'yelp_review_full/{split}',
            )
            if raw is None:
                continue
            parsed = [parse_total_rating(text, range(1, 6)) for text in raw]
            result = frame.copy()
            result['raw_output'] = raw
            result['parsed_output'] = parsed
            result['parsed_score'] = parsed
            result['model'] = model_short
            # Take the file name from the variable rather than from
            # result['model'].iloc[0], which raises IndexError on an empty sample.
            result.to_csv(output_dir / f"{model_short}_{split}.csv", index=False)


def run_civilcomments_binary(models: List[str]) -> None:
    """Judge CivilComments toxicity and write one CSV per model.

    Reads ``SCRIPT_DIR.parent / 'data' / 'binary' / 'civilcomments.csv'``
    (columns ``text`` and ``label``). When ``MODE`` is ``'gaussian_mixture'``
    the judge output is parsed as an integer in 0-9 and binarised: scores at
    or above ``CIVILCOMMENTS_GAUSSIAN_THRESHOLD`` become 1, others 0. In any
    other mode the output is parsed as 0/1 and used directly.

    Args:
        models: Model identifiers passed to ``run_vllm_generation``; a model
            whose generation returns ``None`` is skipped.
    """
    data_path = SCRIPT_DIR.parent / 'data' / 'binary' / 'civilcomments.csv'
    df = pd.read_csv(data_path)
    texts = df['text'].astype(str).tolist()
    gold = pd.to_numeric(df['label'], errors='coerce').astype('Int64')

    gaussian_mode = MODE == 'gaussian_mixture'
    prompt_template = CIVILCOMMENTS_SCORE_JUDGE_PROMPT if gaussian_mode else CIVILCOMMENTS_BINARY_JUDGE_PROMPT
    max_tokens = 8 if gaussian_mode else 6
    allowed_values = list(range(0, 10)) if gaussian_mode else [0, 1]
    prompts = [prompt_template.format(text=text) for text in texts]
    output_dir = get_binary_output_dir('civilcomments')

    for model_name in models:
        raw = run_vllm_generation(
            model_name,
            prompts,
            max_new_tokens=max_tokens,
            dataset_name='civilcomments_binary',
        )
        if raw is None:
            continue

        tokens = [parse_total_rating(entry, allowed=allowed_values) for entry in raw]
        parsed: List[Optional[int]] = [
            int(tok) if tok not in ('', None) else None for tok in tokens
        ]
        if gaussian_mode:
            # Unparseable outputs stay None; everything else is thresholded.
            pred_binary: List[Optional[int]] = [
                None if score is None else (1 if score >= CIVILCOMMENTS_GAUSSIAN_THRESHOLD else 0)
                for score in parsed
            ]
        else:
            pred_binary = parsed

        model_short = model_name.split('/')[-1]
        result = pd.DataFrame(
            {
                'text': texts,
                'label': gold,
                'raw_output': raw,
                'parsed_output': parsed,
            }
        )
        result['pred_label_binary'] = pred_binary
        result['model'] = model_short
        result['scoring_mode'] = MODE
        # Take the file name from the variable rather than from
        # result['model'].iloc[0], which raises IndexError on an empty input file.
        result.to_csv(output_dir / f"{model_short}.csv", index=False)


def run_tripadvisor(models: List[str], sample_size: int = 5000, seed: int = 42) -> None:
    """Rate sampled TripAdvisor test reviews on a 1-5 scale; one CSV per model.

    Results are written to ``SCRIPT_DIR.parent / 'judge_outputs' /
    'fully_gaussian' / 'tripadvisor_reviews' / <model>.csv``. The parsed
    rating is stored under both ``parsed_output`` and ``parsed_score``.

    Args:
        models: Model identifiers passed to ``run_vllm_generation``; a model
            whose generation returns ``None`` is skipped.
        sample_size: Upper bound on the number of reviews kept.
        seed: Seed for the shuffle that precedes selection.
    """
    dataset = load_dataset('nhull/tripadvisor-split-dataset-v2', split='test')
    sample_count = min(sample_size, len(dataset))
    sampled = dataset.shuffle(seed=seed).select(range(sample_count))

    reviews = []
    labels = []
    for example in sampled:
        reviews.append(example.get('review', ''))
        # Labels are rounded to the nearest integer; -1 stands in for a missing key.
        labels.append(int(round(float(example.get('label', -1)))))
    frame = pd.DataFrame(
        {
            'example_id': list(range(sample_count)),
            'review': reviews,
            'label': labels,
        }
    )

    prompts = []
    for review_text in frame['review']:
        prompts.append(TRIPADVISOR_REVIEW_JUDGE_PROMPT.format(review=review_text))

    output_dir = SCRIPT_DIR.parent / 'judge_outputs' / 'fully_gaussian' / 'tripadvisor_reviews'
    output_dir.mkdir(parents=True, exist_ok=True)

    for model_name in models:
        raw = run_vllm_generation(
            model_name,
            prompts,
            max_new_tokens=8,
            dataset_name='tripadvisor_reviews',
        )
        if raw is None:
            continue
        ratings = [parse_total_rating(text, range(1, 6)) for text in raw]
        model_short = model_name.split('/')[-1]
        result = frame.copy()
        new_columns = {
            'raw_output': raw,
            'parsed_output': ratings,
            'parsed_score': ratings,
            'model': model_short,
        }
        for column, values in new_columns.items():
            result[column] = values
        result.to_csv(output_dir / f"{model_short}.csv", index=False)


def run_review_5k(
    models: List[str],
    split: str = 'all',
    *,
    max_new_tokens: int = REVIEW5K_DEFAULT_MAX_NEW_TOKENS,
    reserve_tokens: int = REVIEW5K_DEFAULT_RESERVE_TOKENS,
    skip_models: Optional[Iterable[str]] = None,
) -> None:
    """Judge Review-5K papers with each model and write one CSV per model.

    Prompts are built per model with that model's tokenizer so each prompt
    fits inside ``context - (max_new_tokens + reserve_tokens)`` tokens.

    Args:
        models: Model identifiers to evaluate.
        split: Split name forwarded to ``load_review5k``.
        max_new_tokens: Generation budget per prompt.
        reserve_tokens: Extra tokens kept free when computing the prompt budget.
        skip_models: Model identifiers to skip entirely.

    Raises:
        ValueError: If a model's context window leaves no room for the prompt.
    """
    project_root = SCRIPT_DIR.parent
    frame = load_review5k(split=split, project_root=project_root)

    gaussian_mode = MODE == 'gaussian_mixture'
    output_dir = (
        project_root / 'judge_outputs' / 'fully_gaussian' / 'review_5k'
        if gaussian_mode
        else get_binary_output_dir('review_5k')
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    skip_set = set(skip_models or [])
    if gaussian_mode:
        template_key, template_default = 'gaussian_mixture', REVIEW5K_JUDGE_PROMPT
    else:
        template_key, template_default = 'binary', REVIEW5K_BINARY_JUDGE_PROMPT
    prompt_template = REVIEW5K_PROMPT_TEMPLATES.get(template_key, template_default)
    # Everything before/after the paper placeholder is handled separately by the prompt builder.
    prefix_template, suffix_template = prompt_template.split('{paper}')

    for model_name in models:
        if model_name in skip_set:
            print(f"[skip] {model_name} (requested)")
            continue

        window = REVIEW5K_MODEL_CONTEXT_LENGTHS.get(model_name, REVIEW5K_FALLBACK_CONTEXT)
        allowed_prompt_tokens = window - (max_new_tokens + reserve_tokens)
        if allowed_prompt_tokens <= 0:
            raise ValueError(
                f"Insufficient context window for {model_name}: context={window}, "
                f"max_new_tokens={max_new_tokens}, reserve={reserve_tokens}"
            )

        if model_name not in REVIEW5K_MODEL_CONTEXT_LENGTHS:
            print(
                f"[warn] Context length unknown for {model_name}; using fallback {REVIEW5K_FALLBACK_CONTEXT} tokens.",
                flush=True,
            )

        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            use_fast=True,
            trust_remote_code=True,
        )
        suffix_token_len = len(tokenizer.encode(suffix_template, add_special_tokens=False))
        notice_token_len = len(tokenizer.encode(REVIEW5K_TRUNCATION_NOTICE, add_special_tokens=False))

        stats = Review5KTruncationStats(total=len(frame))
        built = [
            build_review5k_prompt(
                guidelines=str(record.get('system_prompt', '')),
                paper=str(record.get('paper_content', '')),
                tokenizer=tokenizer,
                allowed_tokens=allowed_prompt_tokens,
                suffix_token_len=suffix_token_len,
                notice_token_len=notice_token_len,
                prefix_template=prefix_template,
                suffix_template=suffix_template,
            )
            for _, record in frame.iterrows()
        ]
        prompts: List[str] = [prompt for prompt, _ in built]
        truncation_flags: List[bool] = [flag for _, flag in built]
        for flag in truncation_flags:
            if flag:
                stats.truncated += 1

        stats.log(model_name)

        # Prompts are already budgeted above, so backend-side truncation is disabled.
        raw = run_vllm_generation(
            model_name,
            prompts,
            max_new_tokens=max_new_tokens,
            dataset_name='review_5k',
            reserve_tokens=reserve_tokens,
            apply_truncation=False,
        )
        if raw is None:
            continue

        pred_binary: List[Optional[int]]
        if gaussian_mode:
            # 1-10 rating, falling back to the first integer found in the text.
            parsed = [
                parse_total_rating(generation, range(1, 11))
                or parse_first_number(generation, allow_float=False)
                for generation in raw
            ]
            pred_binary = [None] * len(parsed)
        else:
            # 0/1 verdict; anything else is kept as the parsed token but yields no binary label.
            parsed = [
                parse_total_rating(generation, [0, 1])
                or parse_first_number(generation, allow_float=False)
                for generation in raw
            ]
            pred_binary = [int(token) if token in ('0', '1') else None for token in parsed]

        model_short = model_name.split('/')[-1]
        result = frame.copy()
        new_columns = {
            'raw_output': raw,
            'parsed_output': parsed,
            'prompt_truncated': truncation_flags,
            'prompt_token_budget': allowed_prompt_tokens,
            'pred_label_binary': pred_binary,
            'scoring_mode': MODE,
            'model': model_short,
        }
        for column, values in new_columns.items():
            result[column] = values
        result.to_csv(
            output_dir / f"{model_short}.csv",
            index=False,
            quoting=csv.QUOTE_MINIMAL,
            escapechar='\\',
        )

        del tokenizer


def run_asset_ratings(models: List[str]) -> None:
    """Rate every ASSET simplification with each model and save one CSV per model.

    Args:
        models: Model identifiers handed to ``run_vllm_generation``.
    """
    ds = load_dataset('facebook/asset', 'ratings', split='full')
    records = list(ds)

    frame = pd.DataFrame(
        {
            'row_index': list(range(len(ds))),
            'example_index': [
                int(record.get('original_sentence_id', position) or 0)
                for position, record in enumerate(records)
            ],
            'worker_id': [record.get('worker_id') for record in records],
            'aspect': [record.get('aspect') for record in records],
            'original': [(record.get('original') or '').strip() for record in records],
            'simplification': [(record.get('simplification') or '').strip() for record in records],
            'human_rating': [float(record.get('rating', 0.0)) for record in records],
        }
    )

    prompts = [
        ASSET_SIMPLIFICATION_JUDGE_PROMPT.format(
            aspect=(aspect or ''),
            original=original,
            simplification=simplification,
        )
        for aspect, original, simplification in zip(
            frame['aspect'], frame['original'], frame['simplification']
        )
    ]

    output_dir = SCRIPT_DIR.parent / 'judge_outputs' / 'fully_gaussian' / 'asset'
    output_dir.mkdir(parents=True, exist_ok=True)

    for model_name in models:
        generations = run_vllm_generation(
            model_name,
            prompts,
            max_new_tokens=4,
            dataset_name='asset_ratings',
        )
        if generations is None:
            continue

        # Prefer a rating in 0..100; otherwise fall back to the first integer in the text.
        ratings = [
            parse_total_rating(generation, range(0, 101))
            or parse_first_number(generation, allow_float=False)
            for generation in generations
        ]
        model_short = model_name.split('/')[-1]

        result = frame.copy()
        result['raw_output'] = generations
        result['parsed_output'] = ratings
        result['model'] = model_short
        result.to_csv(output_dir / f"{model_short}.csv", index=False)


def _extract_last_user_question(prompt_turns: List[Dict[str, object]]) -> str:
    """Return the content of the last 'user' role in a prompt list of {role, content} dicts."""
    q = ''
    if not isinstance(prompt_turns, list):
        return q
    for msg in prompt_turns:
        if not isinstance(msg, dict):
            continue
        role = (msg.get('role') or '').strip().lower()
        content = (msg.get('content') or '').strip()
        if role == 'user':
            q = content  # keep last seen user message
    return q


def _extract_assistant_review(messages: List[Dict[str, object]]) -> str:
    if not isinstance(messages, list):
        return ''
    for msg in reversed(messages):
        if not isinstance(msg, dict):
            continue
        role = (msg.get('role') or '').strip().lower()
        if role == 'assistant':
            return (msg.get('content') or '').strip()
    return ''


def _load_allenai_preference_split(split: str) -> pd.DataFrame:
    """Load one split of ``allenai/preference-test-sets`` as a pairwise frame.

    The chosen response is stored as ``response_A`` and the rejected one as
    ``response_B``; both gold-label columns are therefore the constant 1.
    Examples missing a question, chosen, or rejected text are dropped.

    Args:
        split: Split name passed to ``load_dataset``.

    Returns:
        The pairwise frame, or an empty frame when no usable example exists.
    """
    records: List[Dict[str, object]] = []
    for example in load_dataset('allenai/preference-test-sets', split=split):
        question = _extract_last_user_question(example.get('prompt') or [])
        chosen = (example.get('chosen') or '').strip()
        rejected = (example.get('rejected') or '').strip()
        if not question or not chosen or not rejected:
            continue
        records.append(
            {
                'id': example.get('id'),
                'subset': example.get('subset'),
                'split': split,
                'question': question,
                'response_A': chosen,
                'response_B': rejected,
                'gold_label_num': 1,
            }
        )

    base_df = pd.DataFrame(records)
    if not base_df.empty:
        # Nullable integer column so that dropna-based filtering works on it.
        base_df['gold_label_binary'] = pd.Series(1, index=base_df.index, dtype='Int64')
    return base_df


def run_allenai_preference_test_sets(models: List[str], splits: Optional[Iterable[str]] = None) -> None:
    """Run the signed-score pairwise judge over allenai/preference-test-sets splits.

    For each split, every model's raw output is parsed with
    ``extract_shifted_signed_score``; positive scores (in the original A/B
    order) mean B is preferred, negative scores mean A is preferred. One CSV
    per (split, model) is written below ``judge_outputs/<root>/<split>/``.

    Args:
        models: Model identifiers handed to ``run_vllm_generation``.
        splits: Splits to process; defaults to the eight built-in splits.
    """
    if splits is not None:
        splits = list(splits)
    else:
        splits = [
            'anthropic_harmless',
            'anthropic_helpful',
            'summarize',
            'pku_better',
            'pku_safer',
            'shp',
            'mtbench_human',
            'mtbench_gpt4',
        ]

    root_name = 'gaussian_mixture' if MODE == 'gaussian_mixture' else 'fully_gaussian'
    output_root = SCRIPT_DIR.parent / 'judge_outputs' / root_name

    for split in splits:
        base_df = _load_allenai_preference_split(split)
        if base_df.empty:
            continue

        # One swap decision per row; the RNG is only consulted when swapping is enabled.
        was_swapped: List[bool] = [
            bool(SWAP_AB and random.random() < 0.5) for _ in range(len(base_df))
        ]
        originals_a = base_df['response_A'].tolist()
        originals_b = base_df['response_B'].tolist()
        response_a: List[str] = [
            second if flipped else first
            for first, second, flipped in zip(originals_a, originals_b, was_swapped)
        ]
        response_b: List[str] = [
            first if flipped else second
            for first, second, flipped in zip(originals_a, originals_b, was_swapped)
        ]

        prompts = [
            HELPSTEER3_PREF_JUDGE_PROMPT.format(
                question=question,
                answer_a=shown_a,
                answer_b=shown_b,
            )
            for question, shown_a, shown_b in zip(base_df['question'].tolist(), response_a, response_b)
        ]

        target_dir = output_root / split
        target_dir.mkdir(parents=True, exist_ok=True)

        for model_name in models:
            raw = run_vllm_generation(
                model_name,
                prompts,
                max_new_tokens=8,
                dataset_name=f'allenai_preference_test_sets/{split}',
            )
            if raw is None:
                continue

            scores_ab: List[Optional[int]] = []
            scores_original: List[Optional[int]] = []
            pref_labels: List[str] = []
            pred_label_num: List[Optional[int]] = []
            for position, generation in enumerate(raw):
                shown_score = extract_shifted_signed_score(generation, allow_zero=False)
                scores_ab.append(shown_score)
                if shown_score is None:
                    restored, verdict, numeric = None, '', None
                else:
                    # Undo the A/B swap so the score refers to the original order.
                    restored = -shown_score if was_swapped[position] else shown_score
                    if restored > 0:
                        verdict, numeric = 'B', 0
                    elif restored < 0:
                        verdict, numeric = 'A', 1
                    else:
                        verdict, numeric = 'tie', None
                scores_original.append(restored)
                pref_labels.append(verdict)
                pred_label_num.append(numeric)

            model_short = model_name.split('/')[-1]
            result = base_df.copy()
            new_columns = {
                'response_a': response_a,
                'response_b': response_b,
                'was_swapped': was_swapped,
                'raw_output': raw,
                'score_ab': scores_ab,
                'score_original_order': scores_original,
                'pref_A_or_B': pref_labels,
                'pred_label_num': pred_label_num,
                'model': model_short,
            }
            for column, values in new_columns.items():
                result[column] = values

            new_path = target_dir / f"{model_short}.csv"
            result.to_csv(new_path, index=False)

            # Best-effort removal of the output written under the older "_prefs" name.
            legacy_path = target_dir / f"{model_short}_prefs.csv"
            if legacy_path != new_path and legacy_path.exists():
                try:
                    legacy_path.unlink()
                except OSError:
                    pass


def run_allenai_preference_test_sets_binary(models: List[str], splits: Optional[Iterable[str]] = None) -> None:
    """Run the binary-label pairwise judge over allenai/preference-test-sets splits.

    In ``gaussian_mixture`` mode the signed-score prompt is used and the sign
    is binarised; otherwise the 0/1 prompt is used directly. One CSV per
    (split, model) is written below ``judge_outputs/<root>/<split>/``.

    Args:
        models: Model identifiers handed to ``run_vllm_generation``.
        splits: Splits to process; defaults to the eight built-in splits.
    """
    # NOTE(review): in 'gaussian_mixture' mode this resolves to the same
    # judge_outputs/gaussian_mixture/<split>/<model>.csv path that
    # run_allenai_preference_test_sets writes - confirm the overwrite is intended.
    if splits is not None:
        splits = list(splits)
    else:
        splits = [
            'anthropic_harmless',
            'anthropic_helpful',
            'summarize',
            'pku_better',
            'pku_safer',
            'shp',
            'mtbench_human',
            'mtbench_gpt4',
        ]

    gaussian_mode = MODE == 'gaussian_mixture'
    root_name = 'gaussian_mixture' if gaussian_mode else 'binary'
    output_root = SCRIPT_DIR.parent / 'judge_outputs' / root_name
    file_suffix = '.csv'

    for split in splits:
        base_df = _load_allenai_preference_split(split)
        if base_df.empty:
            continue

        filtered = base_df.dropna(subset=['gold_label_binary']).reset_index(drop=True)
        if gaussian_mode:
            prompt_template = JUDGEBENCH_PREF_NO_TIE_PROMPT
            max_tokens = 8
            dataset_slug = 'allenai_preference_test_sets_gaussian_mixture'
        else:
            prompt_template = PREFERENCE_BINARY_JUDGE_PROMPT
            max_tokens = 6
            dataset_slug = 'allenai_preference_test_sets_binary'

        # One swap decision per row; the RNG is only consulted when swapping is enabled.
        was_swapped: List[bool] = [
            bool(SWAP_AB and random.random() < 0.5) for _ in range(len(filtered))
        ]
        originals_a = filtered['response_A'].tolist()
        originals_b = filtered['response_B'].tolist()
        response_a: List[str] = [
            second if flipped else first
            for first, second, flipped in zip(originals_a, originals_b, was_swapped)
        ]
        response_b: List[str] = [
            first if flipped else second
            for first, second, flipped in zip(originals_a, originals_b, was_swapped)
        ]

        prompts = [
            prompt_template.format(
                question=question,
                answer_a=shown_a,
                answer_b=shown_b,
            )
            for question, shown_a, shown_b in zip(filtered['question'].tolist(), response_a, response_b)
        ]

        target_dir = output_root / split
        target_dir.mkdir(parents=True, exist_ok=True)

        for model_name in models:
            raw = run_vllm_generation(
                model_name,
                prompts,
                max_new_tokens=max_tokens,
                dataset_name=f'{dataset_slug}/{split}',
            )
            if raw is None:
                continue

            scores_ab: List[Optional[int]] = []
            scores_original: List[Optional[int]] = []
            pref_labels: List[str] = []
            pred_binary: List[Optional[int]] = []
            for position, generation in enumerate(raw):
                if gaussian_mode:
                    shown_score = extract_shifted_signed_score(generation, allow_zero=False)
                else:
                    token = parse_total_rating(generation, allowed=[0, 1])
                    shown_score = int(token) if token not in ('', None) else None
                scores_ab.append(shown_score)

                # NOTE(review): below, 1 encodes "B preferred", whereas
                # gold_label_binary is always 1 with the chosen response in slot A and
                # run_allenai_preference_test_sets uses 1 for "A preferred" - confirm
                # the convention expected downstream.
                if shown_score is None:
                    restored, verdict, binary_label = None, '', None
                elif gaussian_mode:
                    # Signed score: negate to undo the A/B swap.
                    restored = -shown_score if was_swapped[position] else shown_score
                    if restored > 0:
                        verdict, binary_label = 'B', 1
                    elif restored < 0:
                        verdict, binary_label = 'A', 0
                    else:
                        verdict, binary_label = 'tie', None
                else:
                    # 0/1 score: flip the bit to undo the A/B swap.
                    restored = 1 - shown_score if was_swapped[position] else shown_score
                    if restored == 0:
                        verdict = 'A'
                    elif restored == 1:
                        verdict = 'B'
                    else:
                        verdict = ''
                    binary_label = restored if restored in (0, 1) else None
                scores_original.append(restored)
                pref_labels.append(verdict)
                pred_binary.append(binary_label)

            model_short = model_name.split('/')[-1]
            result = filtered.copy()
            new_columns = {
                'response_a': response_a,
                'response_b': response_b,
                'was_swapped': was_swapped,
                'raw_output': raw,
                'score_ab': scores_ab,
                'score_original_order': scores_original,
                'pref_A_or_B': pref_labels,
                'pred_label_binary': pred_binary,
                'scoring_mode': MODE,
                'model': model_short,
            }
            for column, values in new_columns.items():
                result[column] = values

            new_path = target_dir / f"{model_short}{file_suffix}"
            result.to_csv(new_path, index=False)

            # Best-effort removal of the output written under the older "_prefs" name.
            legacy_path = target_dir / f"{model_short}_prefs.csv"
            if legacy_path != new_path and legacy_path.exists():
                try:
                    legacy_path.unlink()
                except OSError:
                    pass


def _allenai_split_alias(split: str, *, binary: bool) -> Callable[[List[str]], None]:
    """Build a runner that evaluates a single allenai/preference-test-sets split."""

    def _run(models: List[str]) -> None:
        # Look the target up at call time rather than when the registry is built.
        target = (
            run_allenai_preference_test_sets_binary
            if binary
            else run_allenai_preference_test_sets
        )
        return target(models, [split])

    return _run


# Registry of dataset name -> runner. ``--datasets all`` walks it in insertion order.
DATASET_RUNNERS: Dict[str, Callable[[List[str]], None]] = {
    'master_keys': run_masterkey_exp,
    'master_keys_with_ref': run_masterkey_exp_with_ref,
    'feedbackqa': run_feedbackqa,
    'helpsteer2': run_helpsteer2,
    'ultrafeedback': run_ultrafeedback,
    'helpsteer3': run_helpsteer3,
    'helpsteer3_binary': run_helpsteer3_binary,
    'judgebench': run_judgebench,
    'judgebench_binary': run_judgebench_binary,
    'chatbot_arena_conversations': run_chatbot_arena,
    'chatbot_arena_conversations_binary': run_chatbot_arena_binary,
    'allenai_preference_test_sets': run_allenai_preference_test_sets,
    'allenai_preference_test_sets_binary': run_allenai_preference_test_sets_binary,
    # Per-split aliases: '<prefix>/<split>' followed by '<prefix>/<split>_binary'.
    **{
        f'allenai_preference_test_sets/{split_name}{suffix}': _allenai_split_alias(
            split_name, binary=bool(suffix)
        )
        for split_name in (
            'anthropic_harmless',
            'anthropic_helpful',
            'pku_better',
            'pku_safer',
            'shp',
            'mtbench_human',
            'mtbench_gpt4',
            'summarize',
        )
        for suffix in ('', '_binary')
    },
    'summarize_from_feedback_axis/test': lambda models: run_summarize_from_feedback_axis(models),
    'review_5k': run_review_5k,
    'yelp_review_full/test': run_yelp_review_full,
    'civilcomments_binary': run_civilcomments_binary,
    'tripadvisor_reviews': run_tripadvisor,
    'asset_ratings': run_asset_ratings,
}


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the judge-output generator.

    Returns:
        The parsed namespace with ``datasets``, ``models``, ``sample_size``,
        ``seed``, ``swap_ab``, ``mode`` and the three ``review5k_*`` options.
    """
    parser = argparse.ArgumentParser(description='Generate judge outputs for exp1 datasets.')
    parser.add_argument(
        '--datasets',
        nargs='+',
        default=['all'],
        help='Datasets to process (default: all).'
             ' Choices: ' + ', '.join(DATASET_RUNNERS),
    )
    parser.add_argument('--models', nargs='+', default=None, help='Override the default model list.')
    # main() forwards the sample size to the Chatbot Arena runners too, so the help names them.
    parser.add_argument(
        '--sample-size',
        type=int,
        default=5000,
        help='Sample size for Yelp/TripAdvisor/Chatbot Arena datasets (default: 5000).',
    )
    # The seed also initialises the global RNG that drives --swap-ab.
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help='Seed for dataset sampling and random A/B swapping (default: 42).',
    )
    parser.add_argument('--swap-ab', action='store_true', help='Randomly swap A/B positions when prompting pairwise judges.')
    parser.add_argument(
        '--mode',
        choices=['binary', 'gaussian_mixture'],
        default='binary',
        help='Scoring mode for binary datasets. "binary" uses 0/1 prompts;'
             ' "gaussian_mixture" uses richer multi-level scores before binarisation.',
    )
    parser.add_argument(
        '--review5k-max-new-tokens',
        type=int,
        default=REVIEW5K_DEFAULT_MAX_NEW_TOKENS,
        help=f'Max tokens to generate for Review-5K prompts (default: {REVIEW5K_DEFAULT_MAX_NEW_TOKENS}).',
    )
    parser.add_argument(
        '--review5k-reserve-tokens',
        type=int,
        default=REVIEW5K_DEFAULT_RESERVE_TOKENS,
        help=f'Reserved tokens to keep free when truncating Review-5K prompts (default: {REVIEW5K_DEFAULT_RESERVE_TOKENS}).',
    )
    parser.add_argument(
        '--review5k-skip-models',
        nargs='+',
        default=None,
        help='Optional list of models to skip when generating Review-5K outputs.',
    )
    return parser.parse_args()


def main() -> None:
    """Entry point: parse the CLI, configure global state, and run each dataset.

    Raises:
        ValueError: If any requested dataset name is not registered in
            ``DATASET_RUNNERS``. This is raised before any runner starts, so a
            typo never wastes a partially completed (and expensive) run.
    """
    global SWAP_AB, MODE

    args = parse_args()

    if 'all' in args.datasets:
        # The 'allenai_preference_test_sets/<split>' entries are aliases for single
        # splits already covered by the two aggregate runners. Running them under
        # "all" would regenerate and overwrite the very same CSV files.
        requested = [
            name
            for name in DATASET_RUNNERS
            if not name.startswith('allenai_preference_test_sets/')
        ]
    else:
        requested = list(args.datasets)

    # Validate everything up front instead of failing midway through the loop.
    unknown = [name for name in requested if name not in DATASET_RUNNERS]
    if unknown:
        raise ValueError(f"Unsupported dataset: {', '.join(unknown)}")

    random.seed(args.seed)
    SWAP_AB = bool(args.swap_ab)
    MODE = args.mode

    # Runners that accept ``sample_size`` and ``seed`` keyword arguments.
    sampled_datasets = {
        'yelp_review_full/test',
        'tripadvisor_reviews',
        'chatbot_arena_conversations',
        'chatbot_arena_conversations_binary',
    }
    masterkey_datasets = {'master_keys', 'master_keys_with_ref'}

    for dataset in requested:
        runner = DATASET_RUNNERS[dataset]

        # An explicit --models list always wins; master-key datasets otherwise
        # use their own default suite.
        if args.models:
            model_list = args.models
        elif dataset in masterkey_datasets:
            model_list = MASTERKEY_DEFAULT_MODELS
        else:
            model_list = DEFAULT_MODELS

        if dataset in sampled_datasets:
            runner(model_list, sample_size=args.sample_size, seed=args.seed)
        elif dataset == 'review_5k':
            runner(
                model_list,
                max_new_tokens=args.review5k_max_new_tokens,
                reserve_tokens=args.review5k_reserve_tokens,
                skip_models=args.review5k_skip_models,
            )
        else:
            runner(model_list)


if __name__ == '__main__':
    main()
