import argparse
from typing import List, Dict, Any, Optional, Tuple, Union
import json
import os
import numpy as np
from loguru import logger
from datasets import load_from_disk, load_dataset
from tqdm import tqdm
import copy
from tabulate import tabulate
try:
    import sys
    basedir = os.path.abspath(os.path.join(os.path.dirname(__file__),"../../metal-ama"))
    sys.path.append(basedir)
except:
    pass
from metal.label_model import LabelModel
from metal.label_model.continuous_model import ContinuousModel

from scipy.stats import pointbiserialr, pearsonr, spearmanr

from collections import defaultdict

from itertools import product 

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
from sklearn.linear_model import LogisticRegression

# Drop the handlers registered on the loguru logger so far and install a single
# sink that forwards each formatted record to print() (INFO level and above).
# end="" is used because the formatted message is printed as-is, without an
# extra newline appended by print().
logger.remove()
logger.add(lambda msg: print(msg, end=""), format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {message}", level="INFO")


# Registry: dataset identifier -> names of the reward-model score columns
# available for that dataset. VerifierDataset looks these names up as keys of
# each example and uses the dataset path as the key into this mapping.
DATASET_TO_REWARD_MODELS = {
    'anonymous_research/MATH_with_LM_Judges_and_Reward_Model_Results_V2': [
        'armor_rm_score', 'grm_scores', 'skyworks_scores', 'urm_scores', 'qrm_scores',
        'gpm_scores', 'grm_gemma_scores', 'grm_llama32_scores', 'internlm_scores', 'offset_bias_scores'
    ],
    'anonymous_research/MATH_with_LM_Judges_and_Reward_Model_Results_V3': [
        'armor_rm_score', 'eurus_prm2_scores', 'eurus_prm_scores', 'gpm_scores',
        'grm_gemma_scores', 'grm_llama32_scores', 'grm_scores', 'inform_scores', 
        'internlm2_scores', 'internlm_scores', 'offset_bias_scores', 'qrm_gemma_scores',
        'qrm_scores', 'qwen25_math_scores', 'skyworks_scores', 'urm_scores'
    ], 
    'anonymous_research/GPQA_with_GPT-4o-mini_Samples_and_RMs_and_LM_Judges_v1': [
        'armor_rm_score', 'eurus_prm2_scores', 'eurus_prm_scores', 'gpm_scores', 
        'grm_gemma_scores', 'grm_llama32_scores', 'grm_scores', 'internlm2_scores',
        'internlm_scores', 'offset_bias_scores', 'qrm_gemma_scores', 'qrm_scores', 
        'qwen25_math_scores', 'skywork_gemma_scores', 'skyworks_scores', 'urm_scores'
    ],
    'anonymous_research/CodeContests_Llama_70B_with_LM_Judges_and_RMs_v1': [
        'armor_rm_score', 'gpm_scores', 'grm_gemma_scores', 'grm_llama32_scores',
        'grm_scores', 'internlm_scores', 'offset_bias_scores', 'qrm_scores', 'skyworks_scores', 'urm_scores'
    ],
    'anonymous_research/AIMO_GPT-4o-mini_with_LM_Judges_and_RMs_v2': [
        'armor_rm_score', 'eurus_prm2_scores', 'eurus_prm_scores', 'gpm_scores',
        'grm_gemma_scores', 'grm_llama32_scores', 'grm_scores', 'internlm2_scores',
        'internlm_scores', 'offset_bias_scores', 'qrm_gemma_scores', 'qrm_scores',
        'qwen25_math_scores', 'skywork_gemma_scores',
        'skyworks_scores', 'urm_scores'
    ],
    'anonymous_research/GPQA_with_GPT-4o-mini_Samples_and_RMs_and_LM_Judges_v2': [
        'armor_rm_score', 'gpm_scores', 'grm_gemma_scores', 'grm_llama32_scores', 'grm_scores',
        'internlm_scores', 'offset_bias_scores', 'qrm_scores', 'skyworks_scores', 'urm_scores'
    ]
}

# Registry: dataset identifier -> names of the LM-judge verdict columns
# available for that dataset. VerifierDataset looks these names up as keys of
# each example and uses the dataset path as the key into this mapping.
DATASET_TO_LM_JUDGES = {
    'anonymous_research/MATH_with_LM_Judges_and_Reward_Model_Results_V2': [
        'judge_qwen2-72b-instruct_verdicts', 'judge_qwen2.5-72b-instruct-turbo_verdicts', 
        'judge_qwq-32b-preview_verdicts', 'judge_nous-hermes-2-mixtral-8x7b-dpo_verdicts', 
        'judge_llama-3.1-nemotron-70b-instruct-hf_verdicts', 'judge_meta-llama-3.1-405b-instruct-turbo_verdicts', 
        'judge_gemma-2-27b-it_verdicts', 'judge_claude-3-5-sonnet-latest_verdicts', 
        'judge_llama-3.3-70b-instruct-turbo_verdicts', 'judge_gpt-4o_verdicts'
    ],
    'anonymous_research/MATH_with_LM_Judges_and_Reward_Model_Results_V3': [
        'judge_qwen2-72b-instruct_verdicts', 'judge_qwen2.5-72b-instruct-turbo_verdicts', 
        'judge_qwq-32b-preview_verdicts', 'judge_nous-hermes-2-mixtral-8x7b-dpo_verdicts', 
        'judge_llama-3.1-nemotron-70b-instruct-hf_verdicts', 'judge_meta-llama-3.1-405b-instruct-turbo_verdicts', 
        'judge_gemma-2-27b-it_verdicts', 'judge_claude-3-5-sonnet-latest_verdicts', 
        'judge_llama-3.3-70b-instruct-turbo_verdicts', 'judge_gpt-4o_verdicts'
    ],
    'anonymous_research/GPQA_with_GPT-4o-mini_Samples_and_RMs_and_LM_Judges_v1': [
         'judge_gpt-4o_verdicts', 'judge_llama-3.3-70b-instruct-turbo_verdicts',
         'judge_meta-llama-3.1-405b-instruct-turbo_verdicts', 'judge_nous-hermes-2-mixtral-8x7b-dpo_verdicts',
         'judge_qwen2-72b-instruct_verdicts', 'judge_qwen2.5-72b-instruct-turbo_verdicts'
    ],
    'anonymous_research/CodeContests_Llama_70B_with_LM_Judges_and_RMs_v1': [
         'judge_claude-3-5-sonnet-latest_verdicts', 'judge_gemma-2-27b-it_verdicts', 
         'judge_gpt-4o_verdicts', 'judge_llama-3.1-nemotron-70b-instruct-hf_verdicts',
         'judge_llama-3.3-70b-instruct-turbo_verdicts', 'judge_meta-llama-3.1-405b-instruct-turbo_verdicts', 
         'judge_mixtral-8x22b-instruct-v0.1_verdicts', 'judge_nous-hermes-2-mixtral-8x7b-dpo_verdicts',
         'judge_qwen2-72b-instruct_verdicts', 'judge_qwen2.5-72b-instruct-turbo_verdicts',
         'judge_qwq-32b-preview_verdicts', 'judge_wizardlm-2-8x22b_verdicts',
    ],
    'anonymous_research/AIMO_GPT-4o-mini_with_LM_Judges_and_RMs_v2': [
         'judge_gpt-4o_verdicts', 'judge_llama-3.1-nemotron-70b-instruct-hf_verdicts',
         'judge_llama-3.3-70b-instruct-turbo_verdicts', 'judge_meta-llama-3.1-405b-instruct-turbo_verdicts',
         'judge_nous-hermes-2-mixtral-8x7b-dpo_verdicts', 'judge_qwen2-72b-instruct_verdicts',
         'judge_qwen2.5-72b-instruct-turbo_verdicts', 'judge_qwq-32b-preview_verdicts'
    ],
    'anonymous_research/GPQA_with_GPT-4o-mini_Samples_and_RMs_and_LM_Judges_v2': [
        'judge_gpt-4o_verdicts',
        'judge_llama-3.3-70b-instruct-turbo_verdicts',
        'judge_meta-llama-3.1-405b-instruct-turbo_verdicts',
        'judge_nous-hermes-2-mixtral-8x7b-dpo_verdicts',
        'judge_qwen2-72b-instruct_verdicts',
        'judge_qwen2.5-72b-instruct-turbo_verdicts'
    ]
}



# Registry: dataset identifier -> relative path under 'combinations/'.
# NOTE(review): presumably the location where per-dataset combination results
# are stored; the code that consumes this mapping is not visible in this
# section - confirm against its users. Several datasets listed here have no
# entry in DATASET_TO_REWARD_MODELS / DATASET_TO_LM_JUDGES.
DATASET_TO_COMBINATIONS_RESULTS_PATH = {
    'anonymous_research/MATH_with_LM_Judges_and_Reward_Model_Results_V2': 'combinations/math_v2',
    'anonymous_research/MATH_with_LM_Judges_and_Reward_Model_Results_V3': 'combinations/math_v3',
    'anonymous_research/GPQA_with_GPT-4o-mini_Samples_and_RMs_and_LM_Judges_v1': 'combinations/gpqa_v1',
    'anonymous_research/CodeContests_Llama_70B_with_LM_Judges_and_RMs_v1': 'combinations/codecontests',
    'anonymous_research/AIMO_GPT-4o-mini_with_LM_Judges_and_RMs_v2': 'combinations/aimo',
    'anonymous_research/GPQA_with_GPT-4o-mini_Samples_and_RMs_and_LM_Judges_v2': 'combinations/gpqa_v2',
    'anonymous_research/MATH_with_RM_LJ_UT_v1': 'combinations/math_rm_lj_ut_v1',
    'anonymous_research/MATH-500_with_Llama_3.1_8B_Instruct': 'combinations/math_500_8b',
    'anonymous_research/MMLU-College_with_Llama_3.1_8B_Instruct': 'combinations/mmlu_college_8b',
    'anonymous_research/BBH_with_Llama_3.1_8B_Instruct': 'combinations/bbh_8b',
    'anonymous_research/AlpacaEval_with_Llama_3.1_8B_Instruct': 'combinations/alpaca_eval_8b',
    'anonymous_research/MMLU-Pro_with_Llama_3.1_8B_Instruct': 'combinations/mmlu_pro_8b',
    'anonymous_research/AIMO_with_Llama_3.1_8B_Instruct': 'combinations/aimo_8b',
    'anonymous_research/GPQA_with_Llama_3.1_8B_Instruct': 'combinations/gpqa_8b',
    'anonymous_research/ArenaHard_with_Llama_3.1_8B_Instruct': 'combinations/arena_hard_8b',
    'anonymous_research/MATH500_with_Llama_3.1_70B_Instruct': 'combinations/math_500_70b',
    'anonymous_research/AIMO_with_Llama_3.1_70B_Instruct': 'combinations/aimo_70b',
    'anonymous_research/MMLU_with_Llama_3.1_70B_Instruct': 'combinations/mmlu_70b',
}


class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to built-in types.

    Pass it as ``cls=NpEncoder`` to ``json.dump`` / ``json.dumps`` so that
    results holding NumPy values can be serialized.
    """

    def default(self, obj):
        """Return a JSON-serializable equivalent of ``obj``.

        NumPy booleans, integers and floats become ``bool``, ``int`` and
        ``float``; arrays become (nested) lists. Anything else is delegated to
        the base class, which raises ``TypeError``.
        """
        # np.bool_ is neither np.integer nor np.floating, so it needs its own
        # branch; without it json raises "Object of type bool_ is not JSON
        # serializable".
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


class VerifierDataset:
    """Verifier scores, binary votes and ground-truth labels for one dataset.

    The constructor loads the dataset, selects the reward-model and LM-judge
    columns registered for it in DATASET_TO_REWARD_MODELS and
    DATASET_TO_LM_JUDGES, and builds the following attributes:

    * ``samples``: one dict per kept problem (after reward-score normalization).
    * ``reward_models`` / ``lm_judges`` / ``verifiers``: selected column names;
      ``verifiers`` is reward models followed by LM judges.
    * ``score_matrices`` / ``vote_matrices``: arrays unpacked as
      (n_problems, n_generations, n_verifiers), with missing entries set to 0.
    * ``score_matrices_with_nones`` / ``vote_matrices_with_nones``: the same
      arrays before missing entries (None) were replaced.
    * ``true_labels``: integer correctness labels per problem and generation.
    * ``ws_reward_thresholds`` / ``reward_thresholds``: only set when at least
      one reward model is selected.
    """

    def __init__(self, verifiers, dataset_path, reward_threshold, ws_threshold_strategy, ws_class_balance, hard_problem_cutoff, verbose=False):        
        """Load the dataset and build the score, vote and label arrays.

        Args:
            verifiers: Optional container of column names. When not None, only
                the registered reward models / LM judges contained in it are
                kept; when None, all registered ones are used.
            dataset_path: Passed to ``load_from_disk`` and, if that raises,
                to ``load_dataset`` (whose ``'data'`` split is used). It is
                also the key into the two registries, so an unregistered path
                raises ``KeyError``.
            reward_threshold: Forwarded to ``_compute_thresholds``.
            ws_threshold_strategy: Stored on the instance and forwarded to
                ``_compute_thresholds``.
            ws_class_balance: Forwarded to ``_compute_thresholds``.
            hard_problem_cutoff: When not None, problems whose mean
                correctness is greater than this value are dropped.
            verbose: When true, log dataset statistics and the missing-data
                analysis.
        """
        self.dataset_path = dataset_path 
        # Try a local on-disk dataset first; on any failure fall back to
        # load_dataset and take its 'data' split.
        try:
            dataset = load_from_disk(self.dataset_path)
        except Exception as e:
            dataset = load_dataset(self.dataset_path)['data']
       
        self.subsampled_verifiers = True if verifiers is not None else False
        # Keep the registry order; `verifiers` only acts as a filter.
        self.reward_models = [rm for rm in DATASET_TO_REWARD_MODELS[self.dataset_path] if rm in verifiers] if verifiers is not None else DATASET_TO_REWARD_MODELS[self.dataset_path]
        self.lm_judges = [judge for judge in DATASET_TO_LM_JUDGES[self.dataset_path] if judge in verifiers] if verifiers is not None else DATASET_TO_LM_JUDGES[self.dataset_path]
        self.verifiers = self.reward_models + self.lm_judges

        self.samples = []
        for i in range(len(dataset)):
            example = dataset[i]
            # Ground truth lives under 'answer_correct' or, failing that, 'correctness'.
            correctness = example.get('answer_correct', example.get('correctness', None))
            # Keep only hard problems: skip any problem whose mean correctness
            # exceeds the cutoff.
            # NOTE(review): if neither correctness key exists, `correctness` is
            # None and the mean below will fail - confirm every dataset has one.
            if hard_problem_cutoff is not None and np.array(correctness).mean() > hard_problem_cutoff:
                continue 

            # NOTE(review): the trailing ** expansion copies every original
            # column and, being last, overrides the explicit keys above it when
            # the example has a column of the same name (e.g. 'correctness').
            self.samples.append({
                'samples': example['samples'],
                'correctness': correctness,
                'extracted_answers': example.get('extracted_answers', None),
                **{k: example[k] for k in example.keys()}
            })

        # normalize_reward_model_scores is defined outside this section.
        self.samples = normalize_reward_model_scores(self.samples, self.reward_models)

        if verbose:
            logger.info("\nAnalyzing dataset statistics...")
            dataset_stats = analyze_dataset_stats(self.samples)
            log_dataset_stats(dataset_stats)
            
            # Analyze missing data
            logger.info("\nAnalyzing missing data patterns...")
            missing_data_stats = analyze_missing_data(self.samples, self.reward_models, self.lm_judges)
            log_missing_data_analysis(missing_data_stats)

        self.ws_threshold_strategy = ws_threshold_strategy 
        # Thresholds are only computed (and the attributes only exist) when at
        # least one reward model is selected.
        if len(self.reward_models) > 0:
            self.ws_reward_thresholds, self.reward_thresholds = _compute_thresholds(self.samples, reward_threshold, self.ws_threshold_strategy, self.reward_models, ws_class_balance)

        self.score_matrices = []
        self.vote_matrices = []
        self.true_labels = []
        for i, sample in enumerate(self.samples):
            # The 'correctness' key is always set when samples are built above,
            # so this guard is not expected to trigger.
            if 'correctness' not in sample:
                continue 

            votes = []
            scores = []
            for j, rm in enumerate(self.reward_models):
                if rm in sample:
                    s = sample[rm]
                    # Threshold is indexed by (reward model, problem).
                    thresh = self.ws_reward_thresholds[j, i] 
                    # Vote 1 when the score is strictly above the threshold, 0
                    # otherwise; a missing score (None) stays None.
                    v = np.array([1 if (x is not None and x > thresh) else (0 if x is not None else None) for x in s])
                    votes.append(v)
                    scores.append(s)

            for judge in self.lm_judges:
                if judge in sample:
                    s = sample[judge]
                    # Boolean verdicts become 0/1; other values (including
                    # None) are kept. Verdicts serve as both vote and score.
                    s = np.array([int(x) if isinstance(x, bool) else x for x in s])
                    votes.append(s)
                    scores.append(s) 

            # Transpose from (verifier, generation) to (generation, verifier).
            votes = np.array(votes).T 
            scores = np.array(scores).T 
            ground_truth = np.array(sample['correctness']).astype(int)

            self.score_matrices.append(scores)
            self.vote_matrices.append(votes)
            self.true_labels.append(ground_truth )

        # NOTE(review): stacking into one array (and the 3-way shape unpack
        # below) assumes every problem has the same number of generations and
        # the same set of verifier columns - confirm.
        self.score_matrices = np.array(self.score_matrices)
        self.vote_matrices = np.array(self.vote_matrices)

        # Keep copies that still contain None so callers can tell missing
        # entries apart from real zeros.
        self.score_matrices_with_nones = copy.deepcopy(self.score_matrices)
        self.vote_matrices_with_nones = copy.deepcopy(self.vote_matrices)

        # NOTE: by setting None to 0s, we effectively ignore those verifier scores when ensembling.
        self.score_matrices = np.where(self.score_matrices == None, 0, self.score_matrices)
        self.vote_matrices = np.where(self.vote_matrices == None, 0, self.vote_matrices)

        self.true_labels = np.array(self.true_labels)
        self.n_problems, self.n_generations, self.n_verifiers = self.score_matrices.shape


def calculate_predictive_accuracies(vd: VerifierDataset) -> Dict[str, Dict[str, float]]:
    """Compute per-verifier classification metrics of the binary votes.

    Entries whose vote is None (the verifier produced no score) are excluded
    from that verifier's metrics.

    Args:
        vd: Object providing ``verifiers``, ``true_labels`` and
            ``vote_matrices_with_nones`` (verifiers on the last axis).

    Returns:
        Mapping verifier name -> ``{'accuracy', 'f1', 'precision', 'tnr',
        'tpr'}``. A metric is ``nan`` when there is nothing to evaluate it on
        (no valid votes at all, or none for the class in question).
    """
    results = {v: {} for v in vd.verifiers}
    flattened_labels = vd.true_labels.flatten()
    nan = float('nan')

    for i, verifier in enumerate(vd.verifiers):
        flat_votes = vd.vote_matrices_with_nones[:, :, i].flatten()
        # An explicit identity test works for object and numeric dtypes alike.
        valid_mask = np.array([vote is not None for vote in flat_votes], dtype=bool)

        labels = flattened_labels[valid_mask]
        votes = flat_votes[valid_mask].astype(int)
        metrics = results[verifier]

        if labels.size:
            metrics['accuracy'] = accuracy_score(labels, votes)
            metrics['f1'] = f1_score(labels, votes)
            metrics['precision'] = precision_score(labels, votes)
        else:
            metrics['accuracy'] = metrics['f1'] = metrics['precision'] = nan

        # Per-class accuracy: true-negative rate on label 0, true-positive
        # rate on label 1. Boolean masks keep working when a class is empty,
        # unlike an empty float index array, which NumPy rejects.
        for metric_name, label_value in (('tnr', 0), ('tpr', 1)):
            class_mask = labels == label_value
            if class_mask.any():
                metrics[metric_name] = accuracy_score(labels[class_mask], votes[class_mask])
            else:
                metrics[metric_name] = nan

    return results


def calculate_coverage(vd: VerifierDataset):
    """Return, for each generation index, how many problems it answered correctly.

    Sums ``vd.true_labels`` over its first axis. The original computed this
    sum and discarded it, so the function always returned None.

    NOTE(review): if "coverage" is meant to be the fraction of problems with
    at least one correct generation, this needs a different reduction -
    confirm the intended definition.
    """
    return vd.true_labels.sum(axis=0)

def calculate_correlations(vd: VerifierDataset) -> Dict[str, Dict[str, Any]]:
    """Correlate each verifier's outputs with ground-truth correctness.

    Reward-model scores use the point-biserial correlation; LM-judge verdicts
    (restricted to the values 0 and 1) use Pearson's r. Entries that are None,
    of an unexpected type, or whose correctness is None are ignored.

    Returns:
        ``{"reward_models": {...}, "lm_judges": {...}}`` where each inner dict
        maps a column name to its ``correlation``, ``p_value`` and
        ``n_samples``. Columns without any usable entry are omitted.
    """
    sequence_types = (list, np.ndarray)

    def gather_pairs(column, is_usable):
        """Collect (value, int(correct)) pairs of ``column`` over all samples."""
        values, labels = [], []
        for sample in vd.samples:
            if column not in sample or 'correctness' not in sample:
                continue
            column_values = sample[column]
            correct = sample['correctness']
            if not (isinstance(column_values, sequence_types) and isinstance(correct, sequence_types)):
                continue
            for value, is_correct in zip(column_values, correct):
                if is_usable(value) and is_correct is not None:
                    values.append(value)
                    labels.append(int(is_correct))
        return values, labels

    def is_numeric_score(value):
        return value is not None and isinstance(value, (int, float))

    def is_binary_verdict(value):
        return value is not None and value in [0, 1]

    reward_model_stats = {}
    for rm in vd.reward_models:
        scores, labels = gather_pairs(rm, is_numeric_score)
        if not scores:
            continue
        # Binary variable first, continuous variable second.
        correlation, p_value = pointbiserialr(labels, scores)
        reward_model_stats[rm] = {
            "correlation": correlation,
            "p_value": p_value,
            "n_samples": len(scores)
        }

    judge_stats = {}
    for judge in vd.lm_judges:
        verdicts, labels = gather_pairs(judge, is_binary_verdict)
        if not verdicts:
            continue
        correlation, p_value = pearsonr(verdicts, labels)
        judge_stats[judge] = {
            "correlation": correlation,
            "p_value": p_value,
            "n_samples": len(verdicts)
        }

    return {
        "reward_models": reward_model_stats,
        "lm_judges": judge_stats
    }

def log_correlations(correlations: Dict[str, Dict[str, Any]]):
    """Log the reward-model and LM-judge correlation tables.

    Each table is sorted by absolute correlation, strongest first, and is only
    printed when it has at least one row; the section title is always logged.

    Args:
        correlations: Dict with ``"reward_models"`` and ``"lm_judges"``
            entries, each mapping a name to a dict holding ``correlation``,
            ``p_value`` and ``n_samples``.
    """

    def emit_section(title, category, name_header):
        logger.info(title)
        ranked = sorted(
            correlations[category].items(),
            key=lambda item: abs(item[1]["correlation"]),
            reverse=True,
        )
        rows = [
            [
                name,
                f"{stats['correlation']:.3f}",
                f"{stats['p_value']:.2e}",
                f"{stats['n_samples']:,}",  # thousands separator
            ]
            for name, stats in ranked
        ]
        if not rows:
            return
        table = tabulate(rows,
                         headers=[name_header, "Correlation", "P-value", "N"],
                         tablefmt="grid",
                         floatfmt=".3f")
        logger.info("\n" + table)

    emit_section("\nReward Model Score Correlations with Correctness:", "reward_models", "Model")
    emit_section("\nLM Judge Verdict Correlations with Correctness:", "lm_judges", "Judge")

def calculate_majority_baseline(vd: VerifierDataset) -> Dict[str, Dict[str, Any]]:
    """Score a per-problem majority-vote baseline.

    For each problem, generations whose extracted answer is ``"NO_ANSWER"``
    are ignored. The prediction is 1 when strictly more than half of the
    remaining generations are correct, and it is compared against the label
    "at least one generation of the problem is correct". Problems without a
    ``'correctness'`` entry or without any valid answer are skipped.

    Extracted answers are read from ``'extracted_answer'`` and, when that is
    absent or None, from ``'extracted_answers'``. The latter is the key
    ``VerifierDataset`` stores, and the original only checked the former.

    Returns:
        ``{"majority_classification": metrics}`` with counts, accuracy,
        precision, recall and F1 (ratios default to 0.0 when undefined).
    """
    counts = {"tp": 0, "tn": 0, "fp": 0, "fn": 0}
    total_valid = 0

    for sample in vd.samples:
        if 'correctness' not in sample:
            continue
        correctness = sample['correctness']

        answers = sample.get('extracted_answer')
        if answers is None:
            answers = sample.get('extracted_answers')
        if answers is None:
            answers = []

        # Correctness flags of the generations that produced a real answer.
        valid_flags = [
            correctness[idx]
            for idx, answer in enumerate(answers)
            if answer != "NO_ANSWER" and idx < len(correctness)
        ]
        if not valid_flags:
            continue

        total_valid += 1
        true_label = 1 if any(correctness) else 0
        correct_count = sum(1 for is_correct in valid_flags if is_correct)
        prediction = 1 if correct_count > len(valid_flags) / 2 else 0

        if prediction == true_label:
            counts["tp" if prediction == 1 else "tn"] += 1
        else:
            counts["fp" if prediction == 1 else "fn"] += 1

    tp, tn, fp, fn = counts["tp"], counts["tn"], counts["fp"], counts["fn"]
    correct_predictions = tp + tn
    accuracy = correct_predictions / total_valid if total_valid > 0 else 0.0
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0

    baseline_metrics = {
        "total_valid": total_valid,
        "correct_predictions": correct_predictions,
        "incorrect_predictions": fp + fn,
        "accuracy": accuracy,
        "true_positives": tp,
        "true_negatives": tn,
        "false_positives": fp,
        "false_negatives": fn,
        "precision": precision,
        "recall": recall,
        "f1": f1
    }
    return {"majority_classification": baseline_metrics}

def calculate_first_positive_lm_judge_baseline(vd: VerifierDataset, thresholds: List[float]) -> Dict[str, Dict[str, Any]]:
    """Evaluate "pick the first generation enough LM judges accept".

    For every threshold and every problem, generations are scanned in order
    and the first one whose fraction of positive verdicts (among valid 0/1
    verdicts) reaches the threshold is selected. A selected generation counts
    as a true positive when it is correct and a false positive otherwise.
    When nothing is selected, the problem is a false negative if any of its
    generations is correct and a true negative if none is.

    Returns:
        Mapping ``"threshold_<value:.2f>"`` -> counts, accuracy, precision,
        recall and F1 for that threshold.
    """

    def first_accepted_index(sample, n_generations, min_fraction):
        """Index of the first generation meeting ``min_fraction``, else None."""
        for candidate in range(n_generations):
            n_valid = 0
            n_positive = 0
            for judge in vd.lm_judges:
                if judge not in sample:
                    continue
                verdicts = sample[judge]
                if not isinstance(verdicts, (list, np.ndarray)):
                    continue
                if candidate >= len(verdicts):
                    continue
                verdict = verdicts[candidate]
                if verdict is None or verdict not in [0, 1]:
                    continue
                n_valid += 1
                if verdict == 1:
                    n_positive += 1
            if n_valid > 0 and (n_positive / n_valid) >= min_fraction:
                return candidate
        return None

    results = {}

    for threshold in thresholds:
        outcome_counts = {"tp": 0, "tn": 0, "fp": 0, "fn": 0}
        n_rows = 0

        for sample in vd.samples:
            if 'correctness' not in sample:
                continue
            correctness = sample['correctness']
            if not isinstance(correctness, (list, np.ndarray)) or len(correctness) == 0:
                continue

            n_rows += 1
            chosen = first_accepted_index(sample, len(correctness), threshold)
            if chosen is not None:
                outcome = "tp" if correctness[chosen] else "fp"
            else:
                outcome = "fn" if any(correctness) else "tn"
            outcome_counts[outcome] += 1

        tp = outcome_counts["tp"]
        tn = outcome_counts["tn"]
        fp = outcome_counts["fp"]
        fn = outcome_counts["fn"]

        # Ratios stay at 0.0 whenever their denominator is zero.
        accuracy = (tp + tn) / n_rows if n_rows > 0 else 0.0
        precision = tp / (tp + fp) if tp + fp > 0 else 0.0
        recall = tp / (tp + fn) if tp + fn > 0 else 0.0
        if precision + recall > 0:
            f1 = 2 * precision * recall / (precision + recall)
        else:
            f1 = 0.0

        results[f"threshold_{threshold:.2f}"] = {
            "total_rows": n_rows,
            "correct_rows": tp + tn,
            "accuracy": accuracy,
            "true_positives": tp,
            "true_negatives": tn,
            "false_positives": fp,
            "false_negatives": fn,
            "precision": precision,
            "recall": recall,
            "f1": f1
        }

    return results

def calculate_first_positive_reward_model_baseline(vd: VerifierDataset, thresholds: List[float]) -> Dict[str, Dict[str, Any]]:
    """Evaluate "pick the first generation enough reward models accept".

    A reward model accepts a generation when its score is ``>=`` that model's
    reward threshold. For every entry of ``thresholds`` and every problem,
    generations are scanned in order and the first one whose fraction of
    accepting reward models (among those with a valid numeric score) reaches
    the entry is selected. A selected generation counts as a true positive
    when it is correct and a false positive otherwise. When nothing is
    selected, the problem is a false negative if any of its generations is
    correct and a true negative if none is.

    ``vd.reward_thresholds`` may be a single scalar applied everywhere
    (Python or NumPy int/float) or an array indexed ``[reward_model, problem]``.

    Args:
        vd: Object providing ``samples``, ``reward_models`` and, when reward
            models are present, ``reward_thresholds``.
        thresholds: Required fractions of accepting reward models.

    Returns:
        Mapping ``"threshold_<value:.2f>"`` -> counts, accuracy, precision,
        recall and F1 for that fraction.
    """
    # The attribute is only set when reward models exist, so read it leniently.
    reward_thresholds = getattr(vd, "reward_thresholds", None)
    # isinstance (not an exact type comparison) so np.float64 / int static
    # thresholds are treated as scalars instead of being indexed.
    is_static = isinstance(reward_thresholds, (int, float, np.integer, np.floating))

    results = {}

    for threshold in thresholds:
        tp = tn = fp = fn = 0
        total_rows = 0

        for i, sample in enumerate(vd.samples):
            if 'correctness' not in sample:
                continue
            correctness = sample['correctness']
            if not isinstance(correctness, (list, np.ndarray)) or len(correctness) == 0:
                continue

            total_rows += 1

            # Resolve each reward model's score list and threshold once per
            # problem; neither depends on the generation index.
            scored_models = []
            for j, rm in enumerate(vd.reward_models):
                if rm in sample and isinstance(sample[rm], (list, np.ndarray)):
                    thresh = reward_thresholds if is_static else reward_thresholds[j, i]
                    scored_models.append((sample[rm], thresh))

            selected_idx = None
            for idx in range(len(correctness)):
                positive_scores = 0
                total_valid_scores = 0
                for scores, thresh in scored_models:
                    if idx < len(scores) and scores[idx] is not None and isinstance(scores[idx], (int, float)):
                        total_valid_scores += 1
                        if scores[idx] >= thresh:
                            positive_scores += 1

                if total_valid_scores > 0 and (positive_scores / total_valid_scores) >= threshold:
                    selected_idx = idx
                    break

            if selected_idx is not None:
                if correctness[selected_idx]:
                    tp += 1
                else:
                    fp += 1
            elif any(correctness):
                fn += 1
            else:
                tn += 1

        # Ratios stay at 0.0 whenever their denominator is zero.
        accuracy = (tp + tn) / total_rows if total_rows > 0 else 0.0
        precision = tp / (tp + fp) if tp + fp > 0 else 0.0
        recall = tp / (tp + fn) if tp + fn > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0

        results[f"threshold_{threshold:.2f}"] = {
            "total_rows": total_rows,
            "correct_rows": tp + tn,
            "accuracy": accuracy,
            "true_positives": tp,
            "true_negatives": tn,
            "false_positives": fp,
            "false_negatives": fn,
            "precision": precision,
            "recall": recall,
            "f1": f1
        }

    return results

def _tally_field_coverage(sample, field, num_samples, counts, is_valid):
    """Add one sample's missing/invalid/present tallies for ``field`` to ``counts``."""
    if field not in sample:
        counts["total_missing"] += num_samples
        return

    values = sample[field]
    if not isinstance(values, (list, np.ndarray)):
        counts["total_invalid"] += num_samples
        return

    # Only positions that exist can be present or invalid; positions beyond
    # the end of a short list are missing. The original counted those
    # positions as invalid as well, so the three tallies exceeded the total.
    n_available = min(len(values), num_samples)
    n_present = sum(1 for value in values[:n_available] if is_valid(value))
    counts["total_present"] += n_present
    counts["total_invalid"] += n_available - n_present
    counts["total_missing"] += num_samples - n_available


def analyze_missing_data(samples: List[Dict], reward_models: List[str], lm_judges: List[str]) -> Dict[str, Any]:
    """Tally present, missing and invalid verifier outputs per column.

    For every sample the expected number of values per column is
    ``len(sample['samples'])``. Each expected position is classified as
    exactly one of:

    * missing - the column is absent, or its list is too short;
    * invalid - the column is not a list/array, or the value is None or not
      an int/float (LM-judge verdicts must additionally be 0 or 1);
    * present - anything else.

    Args:
        samples: Problem dicts, each with a ``'samples'`` list.
        reward_models: Reward-model score column names.
        lm_judges: LM-judge verdict column names.

    Returns:
        Dict with ``total_rows``, ``total_samples``, ``samples_per_row`` and,
        under ``reward_models`` / ``lm_judges``, per-column counts plus their
        percentages of ``total_expected`` (0.0 when nothing is expected).
    """
    total_samples_per_row = [len(sample['samples']) for sample in samples]
    total_possible_values = sum(total_samples_per_row)

    missing_stats = {
        "total_rows": len(samples),
        "total_samples": total_possible_values,
        "samples_per_row": total_samples_per_row,
        "reward_models": {k: {'total_missing': 0, 'total_invalid': 0, 'total_present': 0} for k in reward_models},
        "lm_judges": {k: {'total_missing': 0, 'total_invalid': 0, 'total_present': 0} for k in lm_judges},
    }

    def is_valid_score(value):
        return value is not None and isinstance(value, (int, float))

    def is_valid_verdict(value):
        return is_valid_score(value) and value in [0, 1]

    for sample, num_samples in zip(samples, total_samples_per_row):
        for rm_field, counts in missing_stats["reward_models"].items():
            _tally_field_coverage(sample, rm_field, num_samples, counts, is_valid_score)
        for judge_field, counts in missing_stats["lm_judges"].items():
            _tally_field_coverage(sample, judge_field, num_samples, counts, is_valid_verdict)

    def percentage(count):
        # Guard the empty dataset: nothing expected means 0% of everything.
        if total_possible_values == 0:
            return 0.0
        return (count / total_possible_values) * 100

    for category in ["reward_models", "lm_judges"]:
        for stats in missing_stats[category].values():
            stats["missing_percentage"] = percentage(stats["total_missing"])
            stats["invalid_percentage"] = percentage(stats["total_invalid"])
            stats["present_percentage"] = percentage(stats["total_present"])
            stats["total_expected"] = total_possible_values

    return missing_stats



def log_missing_data_analysis(stats: Dict[str, Any]):
    """Log the missing data analysis results in table format.

    Three grid tables are written through the module logger: a dataset
    overview, the reward model coverage and the LM judge coverage.

    Args:
        stats: Mapping with ``total_rows`` and ``total_samples`` plus the
            ``reward_models`` and ``lm_judges`` sub-mappings. Every entry of
            a sub-mapping must provide ``total_expected``, ``total_present``,
            ``total_missing``, ``total_invalid`` and the matching
            ``present_percentage`` / ``missing_percentage`` /
            ``invalid_percentage`` values.
    """

    def _coverage_rows(per_verifier):
        # One row per verifier: raw counts next to their percentage.
        return [
            [
                name,
                entry['total_expected'],
                f"{entry['total_present']} ({entry['present_percentage']:.1f}%)",
                f"{entry['total_missing']} ({entry['missing_percentage']:.1f}%)",
                f"{entry['total_invalid']} ({entry['invalid_percentage']:.1f}%)",
            ]
            for name, entry in per_verifier.items()
        ]

    # Dataset Overview Table
    logger.info("\nDataset Overview:")
    total_rows = stats['total_rows']
    # With zero rows there is no average to report; avoid a ZeroDivisionError.
    if total_rows:
        average_samples = f"{stats['total_samples'] / total_rows:.1f}"
    else:
        average_samples = "n/a"
    overview_data = [
        ["Total rows", total_rows],
        ["Total samples", stats['total_samples']],
        ["Average samples per row", average_samples],
    ]
    logger.info("\n" + tabulate(overview_data, tablefmt="grid"))

    # Reward Model Coverage Table
    logger.info("\nReward Model Coverage:")
    rm_headers = ["Model", "Expected", "Present", "Missing", "Invalid"]
    logger.info("\n" + tabulate(_coverage_rows(stats["reward_models"]), headers=rm_headers, tablefmt="grid"))

    # LM Judge Coverage Table
    logger.info("\nLM Judge Coverage:")
    judge_headers = ["Judge", "Expected", "Present", "Missing", "Invalid"]
    logger.info("\n" + tabulate(_coverage_rows(stats["lm_judges"]), headers=judge_headers, tablefmt="grid"))

def calculate_individual_model_first_positive(vd: VerifierDataset, thresholds: List[float]) -> Dict[str, Any]:
    """Calculate performance metrics for first positive sample from each individual model.

    For every verifier and every row of ``vd.samples`` that contains both the
    verifier's field and ``'correctness'``, the first candidate accepted by
    the verifier is selected and the row is counted as:

    * true positive  - a candidate was selected and it is correct,
    * false positive - a candidate was selected and it is incorrect,
    * true negative  - nothing was selected and no candidate is correct,
    * false negative - nothing was selected although a correct candidate exists.

    A reward model accepts a candidate whose score is ``>=`` the reward
    threshold. A selection is only made when the fraction of accepted
    candidates among the validly scored ones is at least the entry of
    ``thresholds`` being evaluated. An LM judge accepts a candidate whose
    verdict equals 1.

    Args:
        vd: Object exposing ``samples``, ``reward_models``, ``lm_judges`` and
            ``reward_thresholds``. ``reward_thresholds`` may be any scalar
            (Python or NumPy) or an array indexed as
            ``[reward_model_index, row_index]``.
        thresholds: Minimum fractions of accepted candidates to evaluate for
            the reward models.

    Returns:
        ``{"reward_models_first_positive": {rm: {"threshold_X.XX": stats}},
        "lm_judges_first_positive": {judge: stats}}`` where every ``stats``
        holds the row counts, the confusion counts and accuracy / precision /
        recall / f1.
    """

    def _blank_stats():
        # Fresh counter set; key order is kept stable for downstream consumers.
        return {
            "total_rows": 0,
            "correct_rows": 0,
            "accuracy": 0.0,
            "true_positives": 0,
            "true_negatives": 0,
            "false_positives": 0,
            "false_negatives": 0,
            "precision": 0.0,
            "recall": 0.0,
            "f1": 0.0,
        }

    def _record(stats, selected_idx, correctness):
        # Score one row given the selected candidate (None = nothing selected).
        stats["total_rows"] += 1
        if selected_idx is not None:
            if correctness[selected_idx]:
                stats["correct_rows"] += 1
                stats["true_positives"] += 1
            else:
                stats["false_positives"] += 1
        elif any(correctness):
            stats["false_negatives"] += 1
        else:
            stats["correct_rows"] += 1
            stats["true_negatives"] += 1

    def _finalize(stats):
        # Derive the ratio metrics; untouched defaults stay at 0.0.
        if stats["total_rows"] > 0:
            tp = stats["true_positives"]
            stats["accuracy"] = stats["correct_rows"] / stats["total_rows"]
            if tp + stats["false_positives"] > 0:
                stats["precision"] = tp / (tp + stats["false_positives"])
            if tp + stats["false_negatives"] > 0:
                stats["recall"] = tp / (tp + stats["false_negatives"])
            if stats["precision"] + stats["recall"] > 0:
                stats["f1"] = (
                    2 * stats["precision"] * stats["recall"]
                    / (stats["precision"] + stats["recall"])
                )

    performance_stats = {
        "reward_models_first_positive": {
            rm: {f"threshold_{threshold:.2f}": _blank_stats() for threshold in thresholds}
            for rm in vd.reward_models
        },
        "lm_judges_first_positive": {judge: _blank_stats() for judge in vd.lm_judges},
    }

    # A static threshold may be a Python float, an int or a NumPy scalar;
    # anything with dimensions is a per-verifier/per-problem table.
    scalar_threshold = np.ndim(vd.reward_thresholds) == 0

    # For reward models
    for j, rm in enumerate(vd.reward_models):
        rm_stats = performance_stats["reward_models_first_positive"][rm]
        for i, example in enumerate(vd.samples):
            if rm not in example or 'correctness' not in example:
                continue

            correctness = example['correctness']
            thresh = vd.reward_thresholds if scalar_threshold else vd.reward_thresholds[j, i]

            # The tally does not depend on the acceptance fraction, so it is
            # computed once per row rather than once per threshold.
            positive_scores = 0
            total_valid_scores = 0
            first_positive_idx = None
            for idx, (score, _) in enumerate(zip(example[rm], correctness)):
                if score is None or not isinstance(score, (int, float)):
                    continue
                total_valid_scores += 1
                if score >= thresh:
                    positive_scores += 1
                    if first_positive_idx is None:
                        first_positive_idx = idx

            for threshold in thresholds:
                threshold_met = (
                    total_valid_scores > 0
                    and (positive_scores / total_valid_scores) >= threshold
                )
                selected_idx = first_positive_idx if threshold_met else None
                _record(rm_stats[f"threshold_{threshold:.2f}"], selected_idx, correctness)

    # For LM judges
    for judge in vd.lm_judges:
        judge_stats = performance_stats["lm_judges_first_positive"][judge]
        for example in vd.samples:
            if judge not in example or 'correctness' not in example:
                continue

            correctness = example['correctness']
            selected_idx = next(
                (
                    idx
                    for idx, (verdict, _) in enumerate(zip(example[judge], correctness))
                    if verdict == 1
                ),
                None,
            )
            _record(judge_stats, selected_idx, correctness)

    # Calculate metrics
    for per_threshold in performance_stats["reward_models_first_positive"].values():
        for threshold_stats in per_threshold.values():
            _finalize(threshold_stats)
    for judge_stats in performance_stats["lm_judges_first_positive"].values():
        _finalize(judge_stats)

    return performance_stats

def calculate_rm_performance(vd: VerifierDataset) -> Dict[str, Any]:
    """Score each reward model by the correctness of its top-ranked candidate.

    For every row that carries both the model's scores and ``'correctness'``
    as a list or array, the candidate with the highest valid score (first one
    on ties) is picked. A correct pick counts as a true positive, an
    incorrect one as a false positive. Rows without any valid score are
    ignored.

    Args:
        vd: Object exposing ``reward_models`` and ``samples``.

    Returns:
        ``{"reward_models": {name: stats}, "lm_judges": {}}``. The
        ``"lm_judges"`` entry is always left empty by this function.
    """
    array_like = (list, np.ndarray)
    report = {"reward_models": {}, "lm_judges": {}}

    for model_name in vd.reward_models:
        hits = 0
        misses = 0

        for row in vd.samples:
            if not (model_name in row and 'correctness' in row):
                continue

            row_scores = row[model_name]
            labels = row['correctness']
            if not (isinstance(row_scores, array_like) and isinstance(labels, array_like)):
                continue

            # Track the first maximum among numeric scores that have a label.
            label_count = len(labels)
            best_pos = None
            best_val = None
            for pos, val in enumerate(row_scores):
                if pos >= label_count or val is None or not isinstance(val, (int, float)):
                    continue
                if best_pos is None or val > best_val:
                    best_pos, best_val = pos, val

            if best_pos is None:
                continue

            if labels[best_pos]:
                hits += 1
            else:
                misses += 1

        evaluated = hits + misses
        summary = {
            "total_valid": evaluated,
            "correct_predictions": hits,
            "incorrect_predictions": misses,
            "accuracy": 0.0,
            "true_positives": hits,
            "true_negatives": 0,
            "false_positives": misses,
            "false_negatives": 0,
            "precision": 0.0,
            "recall": 0.0,
            "f1": 0.0,
        }

        if evaluated > 0:
            summary["accuracy"] = hits / evaluated
            summary["precision"] = hits / evaluated
            # False negatives are never counted here, so the denominator is just the true positives.
            recall_denominator = summary["true_positives"] + summary["false_negatives"]
            summary["recall"] = hits / recall_denominator if recall_denominator > 0 else 0.0
            pr_total = summary["precision"] + summary["recall"]
            if pr_total > 0:
                summary["f1"] = 2 * summary["precision"] * summary["recall"] / pr_total
            else:
                summary["f1"] = 0.0

        report["reward_models"][model_name] = summary

    return report


def calculate_judge_performance(vd: VerifierDataset, tiebreaker='first') -> Dict[str, Any]:
    """Calculate performance metrics for individual LM judges.

    For every row that carries both the judge's verdicts and
    ``'correctness'`` as a list or array, one candidate is selected among
    those with a numeric verdict and a matching correctness label. A correct
    selection counts as a true positive, an incorrect one as a false
    positive. Rows without any valid verdict are ignored.

    Args:
        vd: Object exposing ``lm_judges`` and ``samples``.
        tiebreaker: ``'first'`` selects the first candidate with the highest
            verdict. ``'random'`` draws uniformly (via ``np.random.choice``)
            among the candidates whose verdict equals 1, or among all valid
            candidates when the judge approved none.

    Returns:
        Mapping from judge name (suffixed with ``"_random"`` for the random
        tiebreaker) to its statistics dictionary.

    Raises:
        ValueError: If ``tiebreaker`` is neither ``'first'`` nor ``'random'``.
    """
    if tiebreaker not in ('first', 'random'):
        raise ValueError(f"Unsupported tiebreaker {tiebreaker!r}; expected 'first' or 'random'")

    metrics = {}

    # Calculate selected-sample performance for each LM judge
    for judge in vd.lm_judges:
        judge_stats = {
            "total_valid": 0,
            "correct_predictions": 0,
            "incorrect_predictions": 0,
            "accuracy": 0.0,
            "true_positives": 0,
            "true_negatives": 0,
            "false_positives": 0,
            "false_negatives": 0,
            "precision": 0.0,
            "recall": 0.0,
            "f1": 0.0
        }

        # Process each sample
        for sample in vd.samples:
            if judge not in sample or 'correctness' not in sample:
                continue

            scores = sample[judge]
            correctness = sample['correctness']

            if not isinstance(scores, (list, np.ndarray)) or not isinstance(correctness, (list, np.ndarray)):
                continue

            # Candidates with a numeric verdict and a correctness label
            valid_scores = [(i, s) for i, s in enumerate(scores)
                            if s is not None and isinstance(s, (int, float)) and i < len(correctness)]

            if not valid_scores:
                continue

            judge_stats["total_valid"] += 1

            if tiebreaker == 'first':
                selected_idx = max(valid_scores, key=lambda pair: pair[1])[0]
            else:
                candidate_idxs = [i for i, s in valid_scores if s == 1]
                if not candidate_idxs:
                    # Fall back to the valid candidates only: drawing from
                    # range(len(scores)) could select an invalid verdict or
                    # an index that has no correctness label.
                    candidate_idxs = [i for i, _ in valid_scores]
                selected_idx = np.random.choice(candidate_idxs)

            # Update metrics based on correctness of the selected sample
            if correctness[selected_idx]:
                judge_stats["correct_predictions"] += 1
                judge_stats["true_positives"] += 1
            else:
                judge_stats["incorrect_predictions"] += 1
                judge_stats["false_positives"] += 1

        # Calculate final metrics
        if judge_stats["total_valid"] > 0:
            judge_stats["accuracy"] = judge_stats["correct_predictions"] / judge_stats["total_valid"]
            judge_stats["precision"] = judge_stats["true_positives"] / judge_stats["total_valid"]
            judge_stats["recall"] = judge_stats["true_positives"] / (judge_stats["true_positives"] + judge_stats["false_negatives"]) if judge_stats["true_positives"] + judge_stats["false_negatives"] > 0 else 0.0
            judge_stats["f1"] = 2 * judge_stats["precision"] * judge_stats["recall"] / (judge_stats["precision"] + judge_stats["recall"]) if judge_stats["precision"] + judge_stats["recall"] > 0 else 0.0

        judge_name = judge if tiebreaker == 'first' else f"{judge}_random"
        metrics[judge_name] = judge_stats

    return metrics

def naive_ensemble(vd: VerifierDataset, verifiers: List[str]) -> Tuple[float, float, float, float, int, int, int, int]:
    """Pick, per problem, the candidate with the largest summed verifier score.

    Rows of ``vd.samples`` without ``'correctness'`` are skipped. For the
    remaining rows the scores of every listed verifier present in the row are
    stacked (``None`` entries count as 0), summed per candidate, and the
    arg-max candidate is looked up in ``vd.true_labels``.

    Args:
        vd: Object exposing ``samples``, ``true_labels`` (indexed as
            ``[row, candidate]``) and ``n_problems``.
        verifiers: Names of the score fields to sum.

    Returns:
        ``(accuracy, precision, recall, f1, tp, tn, fp, fn)``. Precision
        equals accuracy, recall is fixed at 1.0, ``tn`` and ``fn`` are 0, and
        ``tp`` / ``fp`` are ``accuracy`` resp. ``1 - accuracy`` scaled by
        ``vd.n_problems``.
    """
    per_problem = []
    for row in vd.samples:
        if 'correctness' not in row:
            continue

        columns = []
        for name in verifiers:
            if name not in row:
                continue
            raw = np.array(row[name])
            # Elementwise comparison on purpose: missing entries become 0.
            columns.append(np.where(raw == None, 0, raw).astype(float))  # noqa: E711

        # Transpose so that candidates are rows and verifiers are columns.
        per_problem.append(np.array(columns).T)

    stacked = np.array(per_problem)
    summed = stacked.sum(axis=2)
    chosen = summed.argmax(axis=1)

    picked_labels = [vd.true_labels[row_idx, cand_idx] for row_idx, cand_idx in enumerate(chosen)]
    accuracy = np.array(picked_labels).mean()

    precision = accuracy
    recall = 1.0
    if precision > 0:
        f1 = 2 * precision / (precision + 1)
    else:
        f1 = 0.0

    true_positives = accuracy*vd.n_problems
    false_positives = vd.n_problems*(1 - accuracy)
    return accuracy, precision, recall, f1, true_positives, 0, false_positives, 0



def log_model_performance(stats: Dict[str, Any]):
    """Write per-model, per-judge and baseline metrics to the log as grid tables.

    Only the sections present in ``stats`` (``"reward_models"``,
    ``"lm_judges"``, ``"baselines"``) are reported, and entries whose row
    count is not positive are left out. Baseline entries may omit metrics;
    missing values are shown as zero.
    """
    logger.info("\nIndividual Model Performance Metrics:")

    columns = ["Model", "Total Rows", "Accuracy", "Precision", "Recall", "F1", "TP", "TN", "FP", "FN"]

    def _emit_sorted_section(title, section_key):
        # Shared by the reward-model and judge sections: strict key access, sorted by name.
        logger.info(title)
        rows = [
            [
                name,
                entry["total_valid"],
                f"{entry['accuracy']:.2%}",
                f"{entry['precision']:.2%}",
                f"{entry['recall']:.2%}",
                f"{entry['f1']:.2%}",
                entry["true_positives"],
                entry["true_negatives"],
                entry["false_positives"],
                entry["false_negatives"],
            ]
            for name, entry in sorted(stats[section_key].items())
            if entry["total_valid"] > 0
        ]
        if rows:
            logger.info("\n" + tabulate(rows, headers=columns, tablefmt="grid"))

    # Reward models: performance of the highest rewarded sample
    if "reward_models" in stats:
        _emit_sorted_section("\nReward Models - Highest Rewarded Sample Performance:", "reward_models")

    if "lm_judges" in stats:
        _emit_sorted_section("LM Judges Performance:", "lm_judges")

    # Baselines keep insertion order and tolerate missing metrics
    if "baselines" in stats:
        logger.info("\nBaseline Performance:")
        baseline_rows = []
        for label, entry in stats["baselines"].items():
            row_count = entry.get("total_valid", entry.get("total_rows", 0))
            if not (row_count > 0):
                continue
            percentages = [
                f"{entry.get(metric, 0.0):.2%}"
                for metric in ("accuracy", "precision", "recall", "f1")
            ]
            counts = [
                entry.get(counter, 0)
                for counter in ("true_positives", "true_negatives", "false_positives", "false_negatives")
            ]
            baseline_rows.append([label, row_count, *percentages, *counts])
        if baseline_rows:
            logger.info("\n" + tabulate(baseline_rows, headers=columns, tablefmt="grid"))
    
def log_predictive_accuracies_with_ws(stats: Dict[str, Any]):
    """Log the predictive accuracies for each model and judge in table format.

    Besides the table, the Pearson and Spearman correlations between each
    metric and its ``ws_``-prefixed counterpart are logged.

    Args:
        stats: Mapping from verifier name to a metrics mapping providing
            ``accuracy``, ``ws_accuracy``, ``tpr``, ``ws_tpr``, ``tnr`` and
            ``ws_tnr``.
    """
    logger.info("\nIndividual Model Performance Metrics:")

    headers = [
        "Verifier",
        "Accuracy",
        "WS Accuracy",
        "TPR",
        "WS TPR",
        "TNR",
        "WS TNR"
    ]
    metric_keys = ["accuracy", "ws_accuracy", "tpr", "ws_tpr", "tnr", "ws_tnr"]

    logger.info("\nPredictive accuracies of each verifier:\n")
    table_data = [
        [verifier_name] + [f"{metrics[key]:.2%}" for key in metric_keys]
        for verifier_name, metrics in stats.items()
    ]

    if table_data:
        logger.info("\n" + tabulate(table_data, headers=headers, tablefmt="grid"))

    # pearsonr rejects inputs shorter than two points, so a correlation over
    # zero or one verifier cannot be computed.
    if len(stats) < 2:
        logger.warning("\nSkipping correlation analysis: at least two verifiers are required.")
        return

    series = {key: [metrics[key] for metrics in stats.values()] for key in metric_keys}
    compared = [
        ("accuracies", "accuracy", "ws_accuracy"),
        ("tpr", "tpr", "ws_tpr"),
        ("tnr", "tnr", "ws_tnr"),
    ]

    for method_name, correlate in (("Pearson", pearsonr), ("Spearman", spearmanr)):
        for label, true_key, ws_key in compared:
            coefficient = correlate(series[true_key], series[ws_key])[0]
            logger.info(f"\n{method_name} correlation of true vs ws {label}: {coefficient}")

def log_predictive_accuracies(stats: Dict[str, Any]):
    """Write a grid table of accuracy, recall and precision per verifier to the log.

    Each value of ``stats`` must provide ``accuracy``, ``tpr`` (shown in the
    "Recall" column) and ``precision``. No table is emitted for empty input.
    """
    logger.info("\nIndividual Model Performance Metrics:")
    logger.info("\nPredictive accuracies of each verifier:\n")

    rows = [
        [
            name,
            f"{entry['accuracy']:.2%}",
            f"{entry['tpr']:.2%}",
            f"{entry['precision']:.2%}",
        ]
        for name, entry in stats.items()
    ]
    if not rows:
        return

    column_titles = ["Verifier", "Accuracy", "Recall", "Precision"]
    logger.info("\n" + tabulate(rows, headers=column_titles, tablefmt="grid"))


def calculate_first_positive_joint_baseline(vd: VerifierDataset, thresholds: List[float]) -> Dict[str, Dict[str, Any]]:
    """Calculate accuracy of first sample that gets majority positive verdicts from both LM judges AND reward models.

    For each threshold and each row with a non-empty ``'correctness'`` list
    or array, candidates are scanned in order. A candidate is selected when
    the fraction of valid judge verdicts equal to 1 and the fraction of valid
    reward scores ``>=`` the reward threshold both reach ``threshold``. Each
    group needs at least one valid vote for that candidate. A selected
    candidate counts as a true or false positive by its correctness. A row
    without a selection counts as a true negative when no candidate is
    correct and as a false negative otherwise.

    Args:
        vd: Object exposing ``samples``, ``lm_judges``, ``reward_models`` and
            ``reward_thresholds``. ``reward_thresholds`` may be any scalar
            (Python or NumPy) or an array indexed as
            ``[reward_model_index, row_index]``.
        thresholds: Approval fractions to evaluate.

    Returns:
        Mapping ``"threshold_X.XX" -> stats`` with row counts, confusion
        counts and accuracy / precision / recall / f1.
    """
    sequence_types = (list, np.ndarray)
    # A static threshold may be a Python float, an int or a NumPy scalar;
    # anything with dimensions is a per-verifier/per-problem table.
    scalar_threshold = np.ndim(vd.reward_thresholds) == 0

    def _first_jointly_approved(sample, row_idx, n_candidates, threshold):
        # Index of the first candidate approved by both groups, or None.
        judge_columns = [
            sample[judge]
            for judge in vd.lm_judges
            if judge in sample and isinstance(sample[judge], sequence_types)
        ]
        # The reward threshold depends on (model, row) only, so resolve it once per row.
        reward_columns = [
            (sample[rm], vd.reward_thresholds if scalar_threshold else vd.reward_thresholds[j, row_idx])
            for j, rm in enumerate(vd.reward_models)
            if rm in sample and isinstance(sample[rm], sequence_types)
        ]

        for idx in range(n_candidates):
            positive_verdicts = 0
            total_valid_verdicts = 0
            for verdicts in judge_columns:
                if idx < len(verdicts) and verdicts[idx] is not None and verdicts[idx] in [0, 1]:
                    total_valid_verdicts += 1
                    if verdicts[idx] == 1:
                        positive_verdicts += 1

            positive_scores = 0
            total_valid_scores = 0
            for scores, thresh in reward_columns:
                if idx < len(scores) and scores[idx] is not None and isinstance(scores[idx], (int, float)):
                    total_valid_scores += 1
                    if scores[idx] >= thresh:
                        positive_scores += 1

            lm_judge_approval = (total_valid_verdicts > 0 and
                                 (positive_verdicts / total_valid_verdicts) >= threshold)
            reward_model_approval = (total_valid_scores > 0 and
                                     (positive_scores / total_valid_scores) >= threshold)

            if lm_judge_approval and reward_model_approval:
                return idx
        return None

    results = {}

    for threshold in thresholds:
        stats = {
            "total_rows": 0,
            "correct_rows": 0,
            "accuracy": 0.0,
            "true_positives": 0,
            "true_negatives": 0,
            "false_positives": 0,
            "false_negatives": 0,
            "precision": 0.0,
            "recall": 0.0,
            "f1": 0.0
        }

        for i, sample in enumerate(vd.samples):
            if 'correctness' not in sample:
                continue

            correctness = sample['correctness']
            if not isinstance(correctness, sequence_types) or len(correctness) == 0:
                continue

            stats["total_rows"] += 1
            selected_idx = _first_jointly_approved(sample, i, len(correctness), threshold)

            if selected_idx is not None:
                if correctness[selected_idx]:
                    stats["correct_rows"] += 1
                    stats["true_positives"] += 1
                else:
                    stats["false_positives"] += 1
            elif any(correctness):
                stats["false_negatives"] += 1
            else:
                stats["correct_rows"] += 1
                stats["true_negatives"] += 1

        # Calculate metrics
        if stats["total_rows"] > 0:
            stats["accuracy"] = stats["correct_rows"] / stats["total_rows"]

            if stats["true_positives"] + stats["false_positives"] > 0:
                stats["precision"] = stats["true_positives"] / (stats["true_positives"] + stats["false_positives"])

            if stats["true_positives"] + stats["false_negatives"] > 0:
                stats["recall"] = stats["true_positives"] / (stats["true_positives"] + stats["false_negatives"])

            if stats["precision"] + stats["recall"] > 0:
                stats["f1"] = 2 * stats["precision"] * stats["recall"] / (stats["precision"] + stats["recall"])

        results[f"threshold_{threshold:.2f}"] = stats

    return results

def calculate_first_sample_baseline(vd: VerifierDataset) -> Dict[str, Any]:
    """Measure how often simply taking the first candidate of a row is correct.

    Rows are ignored when ``'correctness'`` is absent, is not a non-empty
    list or array, or when the first ``'extracted_answer'`` is the
    ``"NO_ANSWER"`` marker. Every remaining row is treated as a positive
    prediction, so precision equals accuracy.

    Returns:
        ``{"first_sample": stats}`` with the row counts, true/false positives,
        accuracy and precision.
    """
    considered = 0
    first_correct = 0

    for row in vd.samples:
        if 'correctness' not in row:
            continue

        labels = row.get('correctness', [])
        answers = row.get('extracted_answer', [])

        if not isinstance(labels, (list, np.ndarray)) or len(labels) == 0:
            continue

        # Rows whose first candidate produced no answer are not evaluated.
        if len(answers) > 0 and answers[0] == "NO_ANSWER":
            continue

        considered += 1
        if labels[0]:
            first_correct += 1

    summary = {
        "total_rows": considered,
        "correct_rows": first_correct,
        "accuracy": 0.0,
        "true_positives": first_correct,
        "false_positives": considered - first_correct,
        "precision": 0.0,
    }

    if considered > 0:
        ratio = first_correct / considered
        summary["accuracy"] = ratio
        summary["precision"] = ratio

    return {"first_sample": summary}

def normalize_reward_model_scores(samples: List[Dict], reward_models: List[str]) -> List[Dict]:
    """Normalize reward model scores across all samples for each model separately.

    Each model's valid scores (non-``None`` ``int``/``float`` values inside a
    list or array) are min-max scaled to ``[0, 1]`` using the minimum and
    maximum observed over all samples. When every valid score of a model is
    identical, the scores are set to 0.5. Invalid entries become ``None``.
    Models without any valid score are left untouched.

    Args:
        samples: Rows mapping field names to per-candidate values.
        reward_models: Names of the score fields to normalize.

    Returns:
        A deep copy of ``samples`` with normalized score lists; the input is
        not modified.
    """

    def _is_valid(score):
        return score is not None and isinstance(score, (int, float))

    normalized_samples = copy.deepcopy(samples)
    for rm in reward_models:
        # Collect all valid scores for this reward model
        all_scores = []
        for sample in samples:
            if rm in sample and isinstance(sample[rm], (list, np.ndarray)):
                all_scores.extend(s for s in sample[rm] if _is_valid(s))

        if not all_scores:
            continue

        # Calculate normalization parameters for this reward model
        min_score = min(all_scores)
        max_score = max(all_scores)
        constant_scores = min_score == max_score

        for sample in normalized_samples:
            if rm not in sample or not isinstance(sample[rm], (list, np.ndarray)):
                continue
            if constant_scores:
                # No spread to scale by: place every valid score at the midpoint.
                sample[rm] = [0.5 if _is_valid(s) else None for s in sample[rm]]
            else:
                sample[rm] = [
                    (s - min_score) / (max_score - min_score) if _is_valid(s) else None
                    for s in sample[rm]
                ]

    return normalized_samples


def _get_top_k_verifiers(verifier_accuracies: Dict[str, float], k: int) -> List[str]:
    """Return the names of the ``k`` verifiers with the highest accuracy, best first.

    Verifiers with equal accuracy keep their order from ``verifier_accuracies``.
    """
    ranked_names = sorted(verifier_accuracies, key=verifier_accuracies.get, reverse=True)
    return ranked_names[:k]


def _format_results(acc, prec, rec, f1, tp, tn, fp, fn, n_problems):
    """Pack ensemble metrics into a statistics dictionary.

    Args:
        acc, prec, rec, f1: Metric values stored unchanged.
        tp, tn, fp, fn: Confusion counts stored unchanged.
        n_problems: Number of evaluated problems, stored as ``total_valid``.

    Returns:
        A dictionary with ``total_valid``, ``correct_predictions``,
        ``incorrect_predictions``, ``accuracy``, the four confusion counts
        and ``precision`` / ``recall`` / ``f1``. The correct and incorrect
        prediction counts are derived from ``acc`` and always add up to
        ``n_problems``.
    """
    # Round instead of truncating: an accuracy such as 0.57 multiplied by 100
    # yields 56.99999999999999 in floating point, which int() would cut to 56.
    correct_predictions = int(round(acc * n_problems))
    results = {
        "total_valid": n_problems,
        "correct_predictions": correct_predictions,
        "incorrect_predictions": int(n_problems - correct_predictions),
        "accuracy": acc,
        "true_positives": tp,
        "true_negatives": tn,
        "false_positives": fp,
        "false_negatives": fn,
        "precision": prec,
        "recall": rec,
        "f1": f1
    }
    return results

def calculate_baselines(vd: VerifierDataset) -> Dict[str, Any]:
    """Run every baseline ensemble on ``vd`` and collect the formatted results.

    Each baseline is evaluated by its own helper and formatted with
    ``_format_results`` using ``vd.n_problems``. Baselines that are skipped
    (RM-only / LM-only variants when ``vd.subsampled_verifiers`` is truthy,
    top-k variants when there are not more than ``k`` candidates) are
    reported as empty dicts.

    Args:
        vd: verifier dataset

    Returns:
        Mapping from baseline name to its formatted metrics. The names of the
        prediction-weighted and naive Bayes baselines are suffixed with
        ``vd.ws_threshold_strategy``.
    """

    def _summarize(metrics):
        # Exactly eight metrics are expected; a different count raises here.
        acc, prec, rec, f1, tp, tn, fp, fn = metrics
        return _format_results(acc, prec, rec, f1, tp, tn, fp, fn, vd.n_problems)

    def _top_k_baseline(accuracies, k):
        # Only meaningful when there are strictly more than k candidates.
        if len(accuracies) > k:
            return _summarize(naive_ensemble(vd, _get_top_k_verifiers(accuracies, k)))
        return {}

    rm_accuracies = {
        rm: stats['accuracy']
        for rm, stats in calculate_rm_performance(vd)['reward_models'].items()
    }
    judge_accuracies = {
        judge: stats['accuracy']
        for judge, stats in calculate_judge_performance(vd).items()
    }

    # Unweighted ensembles over reward models, LM judges and all verifiers.
    highest_rm_score_baseline = {}
    highest_lm_agreement_baseline = {}
    if not vd.subsampled_verifiers:
        highest_rm_score_baseline = _summarize(naive_ensemble(vd, vd.reward_models))
        highest_lm_agreement_baseline = _summarize(naive_ensemble(vd, vd.lm_judges))
    highest_joint_baseline = _summarize(naive_ensemble(vd, vd.verifiers))

    # Unweighted ensembles restricted to the most accurate verifiers.
    top3_rm_score_baseline = _top_k_baseline(rm_accuracies, 3)
    top5_rm_score_baseline = _top_k_baseline(rm_accuracies, 5)
    top3_lm_agreement_baseline = _top_k_baseline(judge_accuracies, 3)
    top5_lm_agreement_baseline = _top_k_baseline(judge_accuracies, 5)

    # Weighted ensembles over all verifiers.
    correlation_weighted_ensemble = _summarize(evaluate_correlation_weighted_ensemble(vd))
    selection_accuracy_weighted_ensemble_baseline = _summarize(
        selection_accuracy_weighted_ensemble(vd, rm_accuracies, judge_accuracies)
    )
    accuracy_weighted_ensemble_baseline = _summarize(accuracy_weighted_ensemble(vd))
    precision_weighted_ensemble_baseline = _summarize(precision_weighted_ensemble(vd))
    recall_weighted_ensemble_baseline = _summarize(recall_weighted_ensemble(vd))
    f1_weighted_ensemble_baseline = _summarize(f1_weighted_ensemble(vd))

    # naive_bayes appends per-sample prediction stats, which are not reported here.
    *nb_metrics, _nb_pred_results = naive_bayes(vd)
    naive_bayes_baseline = _summarize(nb_metrics)

    # RM-only and LM-only weighted ensembles.
    rm_correlation_weighted_ensemble = {}
    lm_judge_alignment_weighted_ensemble = {}
    if not vd.subsampled_verifiers:
        rm_correlation_weighted_ensemble = _summarize(
            evaluate_rm_correlation_weighted_ensemble(vd)
        )
        lm_judge_alignment_weighted_ensemble = _summarize(
            evaluate_lm_judge_alignment_weighted_ensemble(vd)
        )

    # Grid search when verifiers are subsampled, random search otherwise.
    search_args = (
        {"search_size": None, "grid_size": 21}
        if vd.subsampled_verifiers
        else {"search_size": 10, "grid_size": None}
    )
    *search_metrics, _best_weights = search_weighted_ensemble(vd, **search_args)
    search_weighted_ensemble_baseline = _summarize(search_metrics)

    *lr_metrics, _lr_pred_results = logistic_regression_ensemble(vd)
    logistic_regression_ensemble_baseline = _summarize(lr_metrics)

    strategy = vd.ws_threshold_strategy
    return {
        "highest_rm_score": highest_rm_score_baseline,
        "highest_lm_agreement": highest_lm_agreement_baseline,
        "highest_joint_score": highest_joint_baseline,
        "top3_rm_score": top3_rm_score_baseline,
        "top5_rm_score": top5_rm_score_baseline,
        "top3_lm_agreement": top3_lm_agreement_baseline,
        "top5_lm_agreement": top5_lm_agreement_baseline,
        "correlation_weighted_ensemble": correlation_weighted_ensemble,
        "rm_correlation_weighted_ensemble": rm_correlation_weighted_ensemble,
        "lm_judge_alignment_weighted_ensemble": lm_judge_alignment_weighted_ensemble,
        "search_weighted_ensemble": search_weighted_ensemble_baseline,
        "logistic_regression_ensemble": logistic_regression_ensemble_baseline,
        "selection_accuracy_weighted_ensemble": selection_accuracy_weighted_ensemble_baseline,
        f"prediction_accuracy_weighted_ensemble_{strategy}": accuracy_weighted_ensemble_baseline,
        f"prediction_precision_weighted_ensemble_{strategy}": precision_weighted_ensemble_baseline,
        f"prediction_recall_weighted_ensemble_{strategy}": recall_weighted_ensemble_baseline,
        f"prediction_f1_weighted_ensemble_{strategy}": f1_weighted_ensemble_baseline,
        f"naive_bayes_{strategy}": naive_bayes_baseline
    }

def analyze_dataset_stats(samples: List[Dict]) -> Dict[str, Any]:
    """Summarise the dataset: row count, generation count and solvable rows.

    A row counts as solvable when it has a 'correctness' entry with at least
    one truthy value. Rows without a 'samples' entry contribute zero
    generations. The average is 0 for an empty dataset.
    """
    row_count = len(samples)
    generation_count = 0
    solvable_count = 0

    for row in samples:
        generation_count += len(row.get('samples', []))
        if 'correctness' in row and any(row['correctness']):
            solvable_count += 1

    if row_count > 0:
        mean_generations = generation_count / row_count
    else:
        mean_generations = 0

    return {
        "total_rows": row_count,
        "solvable_rows": solvable_count,
        "total_samples": generation_count,
        "avg_samples_per_row": mean_generations
    }

def log_dataset_stats(stats: Dict[str, Any]):
    """Log the dataset overview as a grid-formatted table.

    Reads the 'total_rows', 'solvable_rows', 'total_samples' and
    'avg_samples_per_row' entries of ``stats``; the average is shown with one
    decimal place.
    """
    logger.info("\nDataset Overview:")

    labelled_keys = (
        ("Total rows", "total_rows"),
        ("Rows with correct sample", "solvable_rows"),
        ("Total samples", "total_samples"),
    )
    rows = [[label, stats[key]] for label, key in labelled_keys]
    rows.append(["Average samples per row", f"{stats['avg_samples_per_row']:.1f}"])

    rendered = tabulate(rows, tablefmt="grid")
    logger.info("\n" + rendered)

def evaluate_reward_model(samples: List[Dict], rm_name: str, threshold: float) -> Tuple[float, float, float, float, int, int, int, int]:
    """Score a single reward model on its top-ranked generation per row.

    For every row that has both ``rm_name`` and 'correctness' (each non-empty),
    the generation with the highest score is picked. The prediction is 1 when
    that score is strictly greater than ``threshold`` and 0 otherwise; the
    ground truth is the correctness of the picked generation.

    Returns:
        accuracy, precision, recall, f1, true_positives, true_negatives,
        false_positives, false_negatives. All zeros when no row qualifies.
    """
    picked_flags = []
    picked_labels = []

    for row in samples:
        if rm_name not in row or 'correctness' not in row:
            continue
        rm_scores = row[rm_name]
        labels = row['correctness']
        if len(rm_scores) == 0 or len(labels) == 0:
            continue

        top = np.argmax(rm_scores)
        picked_flags.append(1 if rm_scores[top] > threshold else 0)
        picked_labels.append(labels[top])

    if len(picked_flags) == 0:
        return 0.0, 0.0, 0.0, 0.0, 0, 0, 0, 0

    pred = np.array(picked_flags)
    truth = np.array(picked_labels)

    pred_pos, pred_neg = pred == 1, pred == 0
    truth_pos, truth_neg = truth == 1, truth == 0

    tp = (pred_pos & truth_pos).sum()
    tn = (pred_neg & truth_neg).sum()
    fp = (pred_pos & truth_neg).sum()
    fn = (pred_neg & truth_pos).sum()

    accuracy = (pred == truth).mean()

    precision = 0.0
    if (tp + fp) > 0:
        precision = tp / (tp + fp)

    recall = 0.0
    if (tp + fn) > 0:
        recall = tp / (tp + fn)

    f1 = 0.0
    if (precision + recall) > 0:
        f1 = 2 * precision * recall / (precision + recall)

    return accuracy, precision, recall, f1, tp, tn, fp, fn

def evaluate_lm_judge(samples: List[Dict], judge_name: str) -> Tuple[float, float, float, float, int, int, int, int]:
    """Score a single LM judge on the first generation it accepts per row.

    For every row that has both ``judge_name`` and 'correctness' (each
    non-empty), the first generation whose verdict equals 1 is selected and
    counted as a positive prediction. When the judge accepts nothing, the
    prediction is 0 and the ground truth is the correctness of the first
    generation.

    Returns:
        accuracy, precision, recall, f1, true_positives, true_negatives,
        false_positives, false_negatives. All zeros when no row qualifies.
    """
    picked_flags = []
    picked_labels = []

    for row in samples:
        if not (judge_name in row and 'correctness' in row):
            continue
        votes = row[judge_name]
        labels = row['correctness']
        if len(votes) == 0 or len(labels) == 0:
            continue

        first_accept = next(
            (pos for pos, vote in enumerate(votes) if vote == 1),
            None,
        )
        if first_accept is None:
            picked_flags.append(0)
            picked_labels.append(labels[0])
        else:
            picked_flags.append(1)
            picked_labels.append(labels[first_accept])

    if len(picked_flags) == 0:
        return 0.0, 0.0, 0.0, 0.0, 0, 0, 0, 0

    pred = np.array(picked_flags)
    truth = np.array(picked_labels)

    pred_pos, pred_neg = pred == 1, pred == 0
    truth_pos, truth_neg = truth == 1, truth == 0

    tp = (pred_pos & truth_pos).sum()
    tn = (pred_neg & truth_neg).sum()
    fp = (pred_pos & truth_neg).sum()
    fn = (pred_neg & truth_pos).sum()

    accuracy = (pred == truth).mean()

    precision = 0.0
    if (tp + fp) > 0:
        precision = tp / (tp + fp)

    recall = 0.0
    if (tp + fn) > 0:
        recall = tp / (tp + fn)

    f1 = 0.0
    if (precision + recall) > 0:
        f1 = 2 * precision * recall / (precision + recall)

    return accuracy, precision, recall, f1, tp, tn, fp, fn

def evaluate_highest_rm_score(samples: List[Dict], reward_models: List[str]) -> Tuple[float, float, float, float, int, int, int, int]:
    """Evaluate highest reward model score baseline.

    For every row with a 'correctness' entry, the reward model whose maximum
    score is the largest is chosen, and that model's top-scoring generation is
    selected. The selection always counts as a positive prediction; the ground
    truth is the correctness of the selected generation.

    Reward models that are missing from a row, or that have an empty score
    list for it, are ignored for that row. Rows where no reward model is
    usable are skipped.

    Returns:
        accuracy, precision, recall, f1, true_positives, true_negatives,
        false_positives, false_negatives. All zeros when no row qualifies.
    """
    predictions = []
    ground_truth = []

    for sample in samples:
        if 'correctness' not in sample:
            continue

        # Keep every reward model paired with its own scores. Indexing
        # `reward_models` with a position computed over only the models
        # present in this row would point at the wrong model (or a missing
        # one) whenever a model is absent from the row.
        available = [
            (rm, sample[rm])
            for rm in reward_models
            if rm in sample and len(sample[rm]) > 0
        ]
        if not available:
            continue

        # max() returns the first maximal entry, matching np.argmax on ties.
        _, best_scores = max(available, key=lambda pair: max(pair[1]))
        best_sample_idx = np.argmax(best_scores)

        predictions.append(1)  # Predict the highest scoring sample
        ground_truth.append(sample['correctness'][best_sample_idx])

    if not predictions:
        return 0.0, 0.0, 0.0, 0.0, 0, 0, 0, 0

    predictions = np.array(predictions)
    ground_truth = np.array(ground_truth)

    tp = ((predictions == 1) & (ground_truth == 1)).sum()
    tn = ((predictions == 0) & (ground_truth == 0)).sum()
    fp = ((predictions == 1) & (ground_truth == 0)).sum()
    fn = ((predictions == 0) & (ground_truth == 1)).sum()

    accuracy = (predictions == ground_truth).mean()
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    return accuracy, precision, recall, f1, tp, tn, fp, fn

def evaluate_highest_lm_agreement(samples: List[Dict], lm_judges: List[str]) -> Tuple[float, float, float, float, int, int, int, int]:
    """Evaluate highest LM agreement baseline.

    For every row with a 'correctness' entry, the verdicts of all judges
    present in the row are summed per generation and the generation with the
    largest total is selected (the first one on ties). The selection always
    counts as a positive prediction; the ground truth is the correctness of
    the selected generation.

    Missing verdicts (``None``) count as no vote, consistent with the other
    ensemble evaluators in this file.

    Returns:
        accuracy, precision, recall, f1, true_positives, true_negatives,
        false_positives, false_negatives. All zeros when no row qualifies.
    """
    predictions = []
    ground_truth = []

    for sample in samples:
        if 'correctness' not in sample:
            continue

        # Count positive verdicts for each generation
        positive_counts = np.zeros(len(sample['correctness']))
        for judge in lm_judges:
            if judge in sample:
                # A None verdict cannot be added to a float array; treat it
                # as 0 so one missing verdict does not abort the evaluation.
                verdicts = [0 if v is None else v for v in sample[judge]]
                positive_counts += np.array(verdicts)

        if len(positive_counts) > 0:
            best_idx = np.argmax(positive_counts)
            predictions.append(1)  # Predict sample with most positive verdicts
            ground_truth.append(sample['correctness'][best_idx])

    if not predictions:
        return 0.0, 0.0, 0.0, 0.0, 0, 0, 0, 0

    predictions = np.array(predictions)
    ground_truth = np.array(ground_truth)

    tp = ((predictions == 1) & (ground_truth == 1)).sum()
    tn = ((predictions == 0) & (ground_truth == 0)).sum()
    fp = ((predictions == 1) & (ground_truth == 0)).sum()
    fn = ((predictions == 0) & (ground_truth == 1)).sum()

    accuracy = (predictions == ground_truth).mean()
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    return accuracy, precision, recall, f1, tp, tn, fp, fn

def weighted_ensemble_selection(samples: List[Dict], reward_models: List[str], 
                              lm_judges: List[str], rm_correlations: Dict[str, float],
                              judge_accuracies: Dict[str, float]) -> Tuple[float, float, float, float, int, int, int, int]:
    """Evaluate weighted ensemble selection.

    For every row with a 'correctness' entry, a combined score is built per
    generation: reward-model scores weighted by their (non-negative)
    correlation plus judge verdicts weighted by how far the judge's accuracy
    exceeds 0.5. The generation with the highest combined score is selected
    and counted as a positive prediction; the ground truth is its correctness.

    Missing scores or verdicts (``None``) contribute 0, consistent with the
    other weighted-ensemble evaluators in this file.

    Args:
        samples: dataset rows.
        reward_models: keys of the reward-model score lists in each row.
        lm_judges: keys of the judge verdict lists in each row.
        rm_correlations: weight source per reward model; negatives become 0.
        judge_accuracies: accuracy per judge; values at or below 0.5 become 0.

    Returns:
        accuracy, precision, recall, f1, true_positives, true_negatives,
        false_positives, false_negatives. All zeros when no row qualifies.
    """
    predictions = []
    ground_truth = []

    for sample in samples:
        if 'correctness' not in sample:
            continue

        # Calculate weighted scores for each generation
        weighted_scores = np.zeros(len(sample['correctness']))

        # Add RM scores
        for rm in reward_models:
            if rm in sample and rm in rm_correlations:
                weight = max(0, rm_correlations[rm])  # Only use positive correlations
                # None cannot take part in arithmetic; count it as 0.
                scores = [0 if s is None else s for s in sample[rm]]
                weighted_scores += weight * np.array(scores)

        # Add LM verdicts
        for judge in lm_judges:
            if judge in sample and judge in judge_accuracies:
                weight = max(0, judge_accuracies[judge] - 0.5)  # Only use judges better than random
                verdicts = [0 if v is None else v for v in sample[judge]]
                weighted_scores += weight * np.array(verdicts)

        if len(weighted_scores) > 0:
            best_idx = np.argmax(weighted_scores)
            predictions.append(1)  # Predict best weighted score
            ground_truth.append(sample['correctness'][best_idx])

    if not predictions:
        return 0.0, 0.0, 0.0, 0.0, 0, 0, 0, 0

    predictions = np.array(predictions)
    ground_truth = np.array(ground_truth)

    tp = ((predictions == 1) & (ground_truth == 1)).sum()
    tn = ((predictions == 0) & (ground_truth == 0)).sum()
    fp = ((predictions == 1) & (ground_truth == 0)).sum()
    fn = ((predictions == 0) & (ground_truth == 1)).sum()

    accuracy = (predictions == ground_truth).mean()
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    return accuracy, precision, recall, f1, tp, tn, fp, fn



def logistic_regression_ensemble(vd: VerifierDataset) -> Tuple[float, float, float, float, int, int, int, int]:
    """Fit a logistic regression on verifier scores and use it to select generations.

    The classifier is trained and evaluated on the same data: every
    (problem, generation) pair is one example whose features are the verifier
    scores and whose label is the ground truth. For each problem the
    generation with the highest predicted positive probability is selected.

    Returns:
        accuracy, precision, recall, f1, true_positives, true_negatives,
        false_positives, false_negatives, pred_results -- nine entries, one
        more than the annotation lists. Precision equals the selection
        accuracy and recall is fixed at 1.0; the count entries are floats
        derived from the accuracy. ``pred_results`` holds the accuracy, recall
        and precision of the rounded per-generation probabilities.
    """
    features = vd.score_matrices.reshape((-1, vd.n_verifiers))
    targets = vd.true_labels.flatten()

    model = LogisticRegression(random_state=0)
    model.fit(features, targets)

    positive_probs = model.predict_proba(features)[:, 1]
    positive_probs = positive_probs.reshape(vd.n_problems, vd.n_generations)

    # Per-generation classification quality of the thresholded probabilities.
    flat_preds = np.round(positive_probs).flatten()
    flat_truth = vd.true_labels.flatten()
    pred_results = {
        'acc': accuracy_score(flat_truth, flat_preds),
        'recall': recall_score(flat_truth, flat_preds),
        'precision': precision_score(flat_truth, flat_preds)
    }

    # Selection quality: is the most probable generation of each problem correct?
    chosen = positive_probs.argmax(axis=1)
    hits = [vd.true_labels[row, col] for row, col in enumerate(chosen)]
    accuracy = np.array(hits).mean()

    precision = accuracy
    recall = 1.0
    if precision > 0:
        f1 = 2 * precision / (precision + 1)
    else:
        f1 = 0.0

    n = vd.n_problems
    return accuracy, precision, recall, f1, accuracy*n, 0, n*(1 - accuracy), 0, pred_results


def selection_accuracy_weighted_ensemble(vd: VerifierDataset, rm_accuracies, judge_accuracies) -> Tuple[float, float, float, float, int, int, int, int]:
    """Select generations with an ensemble weighted by per-verifier accuracy.

    The weight vector is the values of ``rm_accuracies`` followed by the
    values of ``judge_accuracies``.
    NOTE(review): this presumably has to match the verifier order along the
    last axis of ``vd.score_matrices`` -- confirm against the caller.

    Returns:
        accuracy, precision, recall, f1, true_positives, true_negatives,
        false_positives, false_negatives. Precision equals the selection
        accuracy and recall is fixed at 1.0; the count entries are floats
        derived from the accuracy.
    """
    weights = np.array([*rm_accuracies.values(), *judge_accuracies.values()])
    combined = vd.score_matrices.dot(weights)
    chosen = combined.argmax(axis=1)

    hits = [vd.true_labels[row, col] for row, col in enumerate(chosen)]
    accuracy = np.array(hits).mean()

    precision = accuracy
    recall = 1.0
    if precision > 0:
        f1 = 2 * precision / (precision + 1)
    else:
        f1 = 0.0

    n = vd.n_problems
    return accuracy, precision, recall, f1, accuracy*n, 0, n*(1 - accuracy), 0

def accuracy_weighted_ensemble(vd: VerifierDataset) -> Tuple[float, float, float, float, int, int, int, int]:
    """Select generations with an ensemble weighted by predictive accuracy.

    Each verifier's weight is its 'accuracy' from
    ``calculate_predictive_accuracies`` minus 0.5, floored at 0, and the
    weights are normalised to sum to 1 when any of them is positive. For each
    problem the generation with the highest weighted score is selected.

    Returns:
        accuracy, precision, recall, f1, true_positives, true_negatives,
        false_positives, false_negatives. Precision equals the selection
        accuracy and recall is fixed at 1.0; the count entries are floats
        derived from the accuracy.
    """
    results = calculate_predictive_accuracies(vd)
    weights = np.array([res['accuracy'] for res in results.values()])
    weights = np.maximum(0, weights - 0.5)

    # When no verifier beats 0.5 every weight is 0; dividing by the zero sum
    # would turn the weights into NaN and emit a RuntimeWarning. Leaving them
    # at zero selects the same (first) generation without the NaNs.
    total_weight = weights.sum()
    if total_weight > 0:
        weights = weights / total_weight

    weighted_ensemble = vd.score_matrices.dot(weights)
    best_idx = weighted_ensemble.argmax(axis=1)

    accuracy = np.array([vd.true_labels[i, idx] for i, idx in enumerate(best_idx)]).mean()
    precision = accuracy
    recall = 1.0
    f1 = 2 * precision / (precision + 1) if precision > 0 else 0.0

    return accuracy, precision, recall, f1, accuracy*vd.n_problems, 0, vd.n_problems*(1 - accuracy), 0

def precision_weighted_ensemble(vd: VerifierDataset) -> Tuple[float, float, float, float, int, int, int, int]:
    """Select generations with an ensemble weighted by predictive precision.

    Each verifier's weight is its 'precision' from
    ``calculate_predictive_accuracies`` minus 0.5, floored at 0. For each
    problem the generation with the highest weighted score is selected.

    Returns:
        accuracy, precision, recall, f1, true_positives, true_negatives,
        false_positives, false_negatives. Precision equals the selection
        accuracy and recall is fixed at 1.0; the count entries are floats
        derived from the accuracy.
    """
    per_verifier = calculate_predictive_accuracies(vd)
    raw_precision = np.array([stats['precision'] for stats in per_verifier.values()])
    weights = np.maximum(0, raw_precision - 0.5)

    combined = vd.score_matrices.dot(weights)
    chosen = combined.argmax(axis=1)

    hits = [vd.true_labels[row, col] for row, col in enumerate(chosen)]
    accuracy = np.array(hits).mean()

    precision = accuracy
    recall = 1.0
    if precision > 0:
        f1 = 2 * precision / (precision + 1)
    else:
        f1 = 0.0

    n = vd.n_problems
    return accuracy, precision, recall, f1, accuracy*n, 0, n*(1 - accuracy), 0

def naive_bayes(vd: VerifierDataset) -> Tuple[float, float, float, float, int, int, int, int]:
    """Combine binary verifier votes with a naive Bayes model and select generations.

    Per-verifier true-positive and true-negative rates come from
    ``calculate_predictive_accuracies``; the class prior is the mean of
    ``vd.true_labels``. Votes are treated as conditionally independent given
    the label, so the likelihood of a generation is the product over
    verifiers. For each problem the generation with the highest posterior
    probability of being correct is selected.

    Returns:
        accuracy, precision, recall, f1, true_positives, true_negatives,
        false_positives, false_negatives, pred_results -- nine entries, one
        more than the annotation lists. Precision equals the selection
        accuracy and recall is fixed at 1.0; the count entries are floats
        derived from the accuracy. ``pred_results`` holds the accuracy, recall
        and precision of the rounded per-generation posteriors.
    """
    per_verifier = calculate_predictive_accuracies(vd)

    tpr = np.array([stats['tpr'] for stats in per_verifier.values()])
    tnr = np.array([stats['tnr'] for stats in per_verifier.values()])

    # Pr(vote = 1 | y = 0) and Pr(vote = 0 | y = 1)
    fpr = 1 - tnr
    fnr = 1 - tpr

    votes = vd.vote_matrices
    no_votes = 1 - votes

    # Likelihood of the observed votes under each label, multiplied over the
    # verifier axis -> (problems x samples)
    likelihood_y1 = np.prod(votes * tpr + no_votes * fnr, axis=2)
    likelihood_y0 = np.prod(votes * fpr + no_votes * tnr, axis=2)

    # Weight the likelihoods by the class prior and normalise.
    class_balance = vd.true_labels.mean()
    joint_y1 = likelihood_y1 * class_balance
    joint_y0 = likelihood_y0 * (1 - class_balance)
    prob_correct = joint_y1 / (joint_y1 + joint_y0)

    # Per-generation classification quality of the thresholded posteriors.
    flat_preds = np.round(prob_correct).flatten()
    flat_truth = vd.true_labels.flatten()
    pred_results = {
        'acc': accuracy_score(flat_truth, flat_preds),
        'recall': recall_score(flat_truth, flat_preds),
        'precision': precision_score(flat_truth, flat_preds)
    }

    # Selection quality: pick the most probable generation of each problem.
    chosen = np.argmax(prob_correct, axis=1)  # (problems)
    hits = [vd.true_labels[row, col] for row, col in enumerate(chosen)]
    accuracy = np.array(hits).mean()

    precision = accuracy
    recall = 1.0
    if precision > 0:
        f1 = 2 * precision / (precision + 1)
    else:
        f1 = 0.0

    n = vd.n_problems
    return accuracy, precision, recall, f1, accuracy*n, 0, n*(1 - accuracy), 0, pred_results


def recall_weighted_ensemble(vd: VerifierDataset) -> Tuple[float, float, float, float, int, int, int, int]:
    """Select generations with an ensemble weighted by predictive recall.

    Each verifier's weight is its 'tpr' (true-positive rate) from
    ``calculate_predictive_accuracies`` minus 0.5, floored at 0. For each
    problem the generation with the highest weighted score is selected.

    Returns:
        accuracy, precision, recall, f1, true_positives, true_negatives,
        false_positives, false_negatives. Precision equals the selection
        accuracy and recall is fixed at 1.0; the count entries are floats
        derived from the accuracy.
    """
    per_verifier = calculate_predictive_accuracies(vd)
    raw_recall = np.array([stats['tpr'] for stats in per_verifier.values()])
    weights = np.maximum(0, raw_recall - 0.5)

    combined = vd.score_matrices.dot(weights)
    chosen = combined.argmax(axis=1)

    hits = [vd.true_labels[row, col] for row, col in enumerate(chosen)]
    accuracy = np.array(hits).mean()

    precision = accuracy
    recall = 1.0
    if precision > 0:
        f1 = 2 * precision / (precision + 1)
    else:
        f1 = 0.0

    n = vd.n_problems
    return accuracy, precision, recall, f1, accuracy*n, 0, n*(1 - accuracy), 0

def f1_weighted_ensemble(vd: VerifierDataset) -> Tuple[float, float, float, float, int, int, int, int]:
    """Select generations with an ensemble weighted by predictive F1 score.

    Each verifier's weight is its 'f1' from
    ``calculate_predictive_accuracies`` minus 0.5, floored at 0. For each
    problem the generation with the highest weighted score is selected.

    Returns:
        accuracy, precision, recall, f1, true_positives, true_negatives,
        false_positives, false_negatives. Precision equals the selection
        accuracy and recall is fixed at 1.0; the count entries are floats
        derived from the accuracy.
    """
    per_verifier = calculate_predictive_accuracies(vd)
    raw_f1 = np.array([stats['f1'] for stats in per_verifier.values()])
    weights = np.maximum(0, raw_f1 - 0.5)

    combined = vd.score_matrices.dot(weights)
    chosen = combined.argmax(axis=1)

    hits = [vd.true_labels[row, col] for row, col in enumerate(chosen)]
    accuracy = np.array(hits).mean()

    precision = accuracy
    recall = 1.0
    if precision > 0:
        f1 = 2 * precision / (precision + 1)
    else:
        f1 = 0.0

    n = vd.n_problems
    return accuracy, precision, recall, f1, accuracy*n, 0, n*(1 - accuracy), 0


def search_weighted_ensemble(
    vd: VerifierDataset, search_size: int = None, grid_size: int = None
) -> Tuple[float, float, float, float, int, int, int, int]:
    """
    Evaluate best weighted ensemble found using a random or grid search over verifier weights.
    We assume that the entire dataset's labels are used to determine the best weights.

    Args:
        vd: verifier dataset.
        search_size: if set, the number of random weight vectors to try
            (drawn with ``np.random.random``).
        grid_size: used when ``search_size`` is None; the number of evenly
            spaced values in [0, 1] tried for every verifier (the full
            Cartesian product is searched).

    Returns:
        accuracy, precision, recall, f1, true_positives, true_negatives,
        false_positives, false_negatives, best_weights -- nine entries, one
        more than the annotation lists. Precision equals the selection
        accuracy and recall is fixed at 1.0; the count entries are floats
        derived from the accuracy. ``best_weights`` are the normalised
        weights rounded to three decimals.

    Raises:
        ValueError: if neither ``search_size`` nor ``grid_size`` is set.
        IndexError: if every candidate weight vector is skipped, so there is
            no result to report.
    """

    results = {}

    if search_size is not None:
        search_weights = np.random.random((search_size, vd.n_verifiers))
    elif grid_size is not None:
        sweep = np.linspace(0, 1, grid_size)
        search_weights = list(product(sweep, repeat=vd.n_verifiers))
    else:
        raise ValueError("Either search_size or grid_size must be set.")

    for weights in tqdm(search_weights):
        if sum(weights) == 0:
            continue

        weights = np.array(weights)
        weights /= sum(weights)  # normalize (this won't change using the weighted ensemble to rank)

        if 1 in weights:
            continue  # we don't consider [0, 0, .., 1] cases, which are equivalent to just using one verifier.

        # Weight vectors that agree after rounding are evaluated only once.
        weights_key = "_".join(np.round(weights, 3).astype(str))
        if weights_key in results:
            continue

        weighted_ensemble = vd.score_matrices.dot(weights)
        best_idx = weighted_ensemble.argmax(axis=1)

        accuracy = np.array([vd.true_labels[i, idx] for i, idx in enumerate(best_idx)]).mean()
        results[weights_key] = accuracy

    # Sort by largest accuracy; the sort is stable, so ties keep search order.
    ranked = sorted(results.items(), key=lambda item: item[1], reverse=True)

    best_key, best_accuracy = ranked[0]
    best_weights = [float(w) for w in best_key.split("_")]

    precision = best_accuracy
    recall = 1.0
    f1 = 2 * precision / (precision + 1) if precision > 0 else 0.0

    print(f"Number of weight combinations searched: {len(results)}.\n Best weights: {best_weights}")

    # Build one row per evaluated weight vector. Fewer than ten vectors may
    # have been evaluated (small search_size, skipped or duplicate
    # candidates), so the row count must not be hard-coded.
    top_weights = np.array([[float(w) for w in key.split("_")] for key, _ in ranked[:10]])
    print(f"Top 10 weights: \n{top_weights}")

    return best_accuracy, precision, recall, f1, best_accuracy*vd.n_problems, 0, vd.n_problems*(1 - best_accuracy), 0, best_weights


def evaluate_correlation_weighted_ensemble(
    vd: VerifierDataset, 
) -> Tuple[float, float, float, float, int, int, int, int]:
    """
    Evaluate ensemble weighted by point-biserial correlation (RMs) and verdict alignment (LM judges).

    Reward-model weights are the correlation between the model's scores and
    the ground-truth labels (``np.corrcoef``), floored at 0. Judge weights
    are the judge's accuracy on its non-missing verdicts minus 0.5, floored
    at 0. For every sample with a 'correctness' entry the weighted scores are
    summed per generation (missing or non-numeric entries count as 0) and the
    highest-scoring generation is selected as a positive prediction.

    Returns:
        accuracy, precision, recall, f1, true_positives, true_negatives, false_positives, false_negatives
        (all zeros when no sample has a 'correctness' entry)
    """
    # Calculate weights for each verifier
    rm_weights = {}  # Point-biserial correlations
    judge_weights = {}  # Verdict alignments

    flat_labels = vd.true_labels.flatten()

    # Calculate RM weights using point-biserial correlation
    for rm in vd.reward_models:
        scores = np.array([sample[rm] for sample in vd.samples]).flatten()
        correlation = np.corrcoef(scores, flat_labels)[0, 1]
        rm_weights[rm] = max(0, correlation)

    # Calculate judge weights from agreement with the labels
    for judge in vd.lm_judges:
        scores = np.array([sample[judge] for sample in vd.samples]).flatten()
        # `!= None` is deliberate: it is an element-wise comparison on the array.
        valid_score_idxs = np.where(scores != None)[0]
        accuracy = (scores[valid_score_idxs] == flat_labels[valid_score_idxs]).mean()
        judge_weights[judge] = max(0, accuracy - 0.5)

    # Make predictions using weighted scores
    predictions = []
    ground_truth = []

    for sample in vd.samples:
        if 'correctness' in sample:
            weighted_scores = np.zeros(len(sample['correctness']))
            total_weight = 0

            # Add weighted RM scores
            for rm in vd.reward_models:
                if rm in sample and rm in rm_weights:
                    scores = sample[rm]
                    # Convert scores to numpy array, replacing None with 0
                    scores_array = np.array([s if s is not None and isinstance(s, (int, float)) else 0 
                                          for s in scores])
                    weight = rm_weights[rm]
                    weighted_scores += weight * scores_array
                    total_weight += weight

            # Add weighted LM verdicts
            for judge in vd.lm_judges:
                if judge in sample and judge in judge_weights:
                    verdicts = sample[judge]
                    # Convert verdicts to numpy array, replacing None with 0
                    verdicts_array = np.array([v if v is not None and v in [0, 1] else 0 
                                             for v in verdicts])
                    weight = judge_weights[judge]
                    weighted_scores += weight * verdicts_array
                    total_weight += weight

            if total_weight > 0:  # Normalize by total weight
                weighted_scores /= total_weight

            if len(weighted_scores) > 0:
                best_idx = np.argmax(weighted_scores)
                predictions.append(1)  # Predict best weighted score
                ground_truth.append(sample['correctness'][best_idx])

    if not predictions:
        return 0.0, 0.0, 0.0, 0.0, 0, 0, 0, 0

    predictions = np.array(predictions)
    ground_truth = np.array(ground_truth)

    tp = ((predictions == 1) & (ground_truth == 1)).sum()
    tn = ((predictions == 0) & (ground_truth == 0)).sum()
    fp = ((predictions == 1) & (ground_truth == 0)).sum()
    fn = ((predictions == 0) & (ground_truth == 1)).sum()

    accuracy = (predictions == ground_truth).mean()

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    return accuracy, precision, recall, f1, tp, tn, fp, fn

def evaluate_rm_correlation_weighted_ensemble(
    vd: VerifierDataset, 
) -> Tuple[float, float, float, float, int, int, int, int]:
    """
    Evaluate a reward-model-only ensemble weighted by correlation with correctness.

    Each reward model's weight is the correlation (``np.corrcoef``) between
    its valid scores and the matching correctness labels, floored at 0.
    Scores that are not ``int``/``float`` instances (including ``None``) are
    left out of the correlation and count as 0 when scoring. For every sample
    with a 'correctness' entry, the generation with the highest weighted
    score is selected as a positive prediction.

    Returns:
        accuracy, precision, recall, f1, true_positives, true_negatives, false_positives, false_negatives
        (all zeros when no sample has a 'correctness' entry)
    """
    # Step 1: one non-negative weight per reward model.
    rm_weights = {}
    for rm in vd.reward_models:
        paired_scores = []
        paired_labels = []
        for row in vd.samples:
            if rm not in row or 'correctness' not in row:
                continue
            for score, label in zip(row[rm], row['correctness']):
                if isinstance(score, (int, float)):
                    paired_scores.append(score)
                    paired_labels.append(label)
        if paired_scores:
            correlation = np.corrcoef(paired_scores, paired_labels)[0,1]
            rm_weights[rm] = max(0, correlation)  # negative correlations get no say

    # Step 2: select one generation per sample with the weighted scores.
    picked_flags = []
    picked_labels = []
    for row in vd.samples:
        if 'correctness' not in row:
            continue

        combined = np.zeros(len(row['correctness']))
        weight_sum = 0
        for rm in vd.reward_models:
            if rm not in row or rm not in rm_weights:
                continue
            cleaned = np.array([
                score if isinstance(score, (int, float)) else 0
                for score in row[rm]
            ])
            weight = rm_weights[rm]
            combined += weight * cleaned
            weight_sum += weight

        if weight_sum > 0:  # turn the weighted sum into a weighted mean
            combined /= weight_sum

        if len(combined) > 0:
            winner = np.argmax(combined)
            picked_flags.append(1)  # the selected generation is the positive prediction
            picked_labels.append(row['correctness'][winner])

    if len(picked_flags) == 0:
        return 0.0, 0.0, 0.0, 0.0, 0, 0, 0, 0

    # Step 3: confusion counts and derived metrics.
    pred = np.array(picked_flags)
    truth = np.array(picked_labels)

    pred_pos, pred_neg = pred == 1, pred == 0
    truth_pos, truth_neg = truth == 1, truth == 0

    tp = (pred_pos & truth_pos).sum()
    tn = (pred_neg & truth_neg).sum()
    fp = (pred_pos & truth_neg).sum()
    fn = (pred_neg & truth_pos).sum()

    accuracy = (pred == truth).mean()

    precision = 0.0
    if (tp + fp) > 0:
        precision = tp / (tp + fp)

    recall = 0.0
    if (tp + fn) > 0:
        recall = tp / (tp + fn)

    f1 = 0.0
    if (precision + recall) > 0:
        f1 = 2 * precision * recall / (precision + recall)

    return accuracy, precision, recall, f1, tp, tn, fp, fn

def evaluate_lm_judge_alignment_weighted_ensemble(
    vd: VerifierDataset, 
) -> Tuple[float, float, float, float, int, int, int, int]:
    """
    Score a best-of-N pick made by an accuracy-weighted vote of the LM judges.

    Every judge in ``vd.lm_judges`` receives the weight ``max(0, agreement - 0.5)``,
    where ``agreement`` is the fraction of its valid 0/1 verdicts (pooled over all
    samples) that equal the ``correctness`` labels. Note that the weights are thus
    derived from the ground-truth labels themselves. For each sample, the
    generation with the largest weighted verdict sum is selected and counted as a
    positive prediction; its ``correctness`` entry is the ground truth.

    Args:
        vd: dataset exposing ``lm_judges`` (judge keys) and ``samples`` (dicts
            holding a ``'correctness'`` list and one verdict list per judge).

    Returns:
        accuracy, precision, recall, f1, true_positives, true_negatives,
        false_positives, false_negatives. All zeros when no sample yields a
        prediction. Since every prediction is positive, the true/false negative
        counts are always 0.
    """
    def _usable(verdict):
        # Only non-missing binary verdicts take part in weighting and voting.
        return verdict is not None and verdict in [0, 1]

    # Pass 1: turn each judge's pooled agreement with the labels into a weight.
    judge_weights = {}
    for judge in vd.lm_judges:
        pooled_verdicts = []
        pooled_labels = []
        for sample in vd.samples:
            if judge not in sample or 'correctness' not in sample:
                continue
            for verdict, label in zip(sample[judge], sample['correctness']):
                if _usable(verdict):
                    pooled_verdicts.append(verdict)
                    pooled_labels.append(label)
        if not pooled_verdicts:
            continue
        agreement = (np.array(pooled_verdicts) == np.array(pooled_labels)).mean()
        # A judge at or below chance (0.5) contributes nothing.
        judge_weights[judge] = max(0, agreement - 0.5)

    # Pass 2: per sample, pick the generation with the largest weighted vote.
    selected_labels = []
    for sample in vd.samples:
        if 'correctness' not in sample:
            continue
        labels = sample['correctness']
        combined = np.zeros(len(labels))
        weight_sum = 0

        for judge in vd.lm_judges:
            if judge not in sample or judge not in judge_weights:
                continue
            judge_weight = judge_weights[judge]
            # Unusable verdicts are treated as a 0 ("incorrect") vote.
            votes = np.array([v if _usable(v) else 0 for v in sample[judge]])
            combined += judge_weight * votes
            weight_sum += judge_weight

        if weight_sum > 0:
            combined /= weight_sum

        if len(combined) > 0:
            winner = np.argmax(combined)
            selected_labels.append(labels[winner])

    if not selected_labels:
        return 0.0, 0.0, 0.0, 0.0, 0, 0, 0, 0

    # The selected generation is always predicted to be correct.
    predictions = np.ones(len(selected_labels), dtype=int)
    ground_truth = np.array(selected_labels)

    predicted_pos = predictions == 1
    predicted_neg = predictions == 0
    actual_pos = ground_truth == 1
    actual_neg = ground_truth == 0

    tp = (predicted_pos & actual_pos).sum()
    tn = (predicted_neg & actual_neg).sum()
    fp = (predicted_pos & actual_neg).sum()
    fn = (predicted_neg & actual_pos).sum()

    accuracy = (predictions == ground_truth).mean()

    precision = 0.0
    if (tp + fp) > 0:
        precision = tp / (tp + fp)

    recall = 0.0
    if (tp + fn) > 0:
        recall = tp / (tp + fn)

    f1 = 0.0
    if (precision + recall) > 0:
        f1 = 2 * precision * recall / (precision + recall)

    return accuracy, precision, recall, f1, tp, tn, fp, fn


def calculate_ws_methods(vd, ws_data_strategy, ws_class_balance, ws_label_model, use_continuous_marginals):
    """
    Dispatch to the weak-supervision evaluation matching ``ws_data_strategy``.

    Args:
        vd: verifier dataset handed through to the evaluation function.
        ws_data_strategy: one of ``"first_sample"``, ``"per_problem"`` or ``"all_data"``.
        ws_class_balance: class-balance estimate forwarded as ``class_balance``.
        ws_label_model: label model name (forwarded to per_problem / all_data only).
        use_continuous_marginals: forwarded to per_problem / all_data only.

    Returns:
        - ``"per_problem"`` / ``"all_data"``: whatever the corresponding
          ``evaluate_weak_supervision_*`` function returns.
        - ``"first_sample"``: a dict with a single entry, keyed
          ``weak_supervision_threshold_<threshold strategy>_first_sample``, holding the metrics.

    Raises:
        ValueError: if ``ws_data_strategy`` is not one of the three supported values
            (previously this fell through and failed with an UnboundLocalError).
    """
    if ws_data_strategy == "per_problem":
        # use WS on each problem; voting matrix is (n_problems, n_generations, n_verifiers)
        return evaluate_weak_supervision_per_problem(
            vd, 
            class_balance=ws_class_balance,
            ws_label_model=ws_label_model,
            use_continuous_marginals=use_continuous_marginals
        )

    if ws_data_strategy == "all_data":
        return evaluate_weak_supervision_all_data(
            vd, 
            class_balance=ws_class_balance,
            ws_label_model=ws_label_model,
            use_continuous_marginals=use_continuous_marginals
        )

    if ws_data_strategy != "first_sample":
        raise ValueError(
            f"Unknown ws_data_strategy {ws_data_strategy!r}; "
            "expected 'first_sample', 'per_problem' or 'all_data'."
        )

    acc, prec, rec, f1, tp, tn, fp, fn = evaluate_weak_supervision_first_sample(
        vd, 
        class_balance=ws_class_balance
    )

    ws_baselines = {}
    ws_baselines[f"weak_supervision_threshold_{vd.ws_threshold_strategy}_{ws_data_strategy}"] = {
        "total_valid": vd.n_problems,
        "correct_predictions": tp + tn,
        "incorrect_predictions": fp + fn,
        "accuracy": acc,
        "precision": prec,
        "recall": rec,
        "f1": f1,
        "true_positives": tp,
        "true_negatives": tn,
        "false_positives": fp,
        "false_negatives": fn
    }

    return ws_baselines



def _get_ws_estimated_verifier_class_accuracies(label_model, n_verifiers):
    """
        After computing the label model, we get each verifier's class-conditional accuracy, estimated using WS:
        Pr(lf_i = y | y = 1) and Pr(lf_i = y | y = 0)

        This is useful in comparing WS's estimated TPR/FPR against the true TPR/FPR. 

        Args:
        - label_model: trained WS label model.
        - n_verifiers: number of verifiers 

        Returns:
        - (n_verifiers, n_classes) accuracy matrix A where A[i, j] = Pr(lf_i = y | y = j).
    """

    weights = label_model.get_conditional_probs().reshape((n_verifiers, -1, label_model.k))

    weights = weights[:, 1:, :]  # This keeps only the last two rows of the 3-row dimension
    verifier_accuracies = np.array([np.diag(matrix) for matrix in weights])

    return verifier_accuracies

def _get_ws_estimated_verifier_accuracies(label_model, n_verifiers):
    """
        Collapse the WS-estimated class-conditional accuracies into one overall accuracy per verifier:
        Pr(lf_i = y)

        The result can serve as a per-verifier weight when ensembling scores.

        Args:
        - label_model: trained WS label model; its class prior is read from `label_model.p`.
        - n_verifiers: number of verifiers

        Returns:
        - (n_verifiers) accuracy vector A where A[i] = Pr(lf_i = y).
    """
    class_accs = _get_ws_estimated_verifier_class_accuracies(label_model, n_verifiers)
    prior = label_model.p

    acc_given_class_0 = class_accs[:, 0]
    acc_given_class_1 = class_accs[:, 1]

    # Law of total probability over the two classes:
    # Pr(lf_i = y) = Pr(lf_i = y | y = 0)Pr(y = 0) + Pr(lf_i = y | y = 1)Pr(y = 1)
    return acc_given_class_0 * prior[0] + acc_given_class_1 * prior[1]


def _get_ws_estimated_verifier_conditional_accuracies(label_model, n_verifiers):
    """
        After computing the label model, we get each verifier's accuracy conditioned on the verifier output, estimated using WS:
        Pr(lf_i = y | lf_i = 1) and Pr(lf_i = y | lf_i = 0)

        This is useful for determining the accuracy of a verifier, given that it votes in a certain way.

        Args:
        - label_model: trained WS label model.
        - n_verifiers: number of verifiers 

        Returns:
        - (n_verifiers, n_classes) accuracy matrix A where A[i, j] = Pr(lf_i = y | lf_i = j).
    """
    weights = label_model.get_conditional_probs().reshape((n_verifiers, -1, label_model.k))
    weights = weights[:, 1:, :]  # This keeps only the last two rows of the 3-row dimension

    vote_conditional_accuracies = []
    for idx in range(n_verifiers):
        all_votes_one_hot = np.eye(2)
        X = np.exp(all_votes_one_hot @ np.log(weights[idx]) + np.log(label_model.p))
        X = X / X.sum(axis=1).reshape(-1, 1)
        vote_conditional_accuracies.append(np.diag(X))

    return np.array(vote_conditional_accuracies) 

def evaluate_weak_supervision_first_sample(vd: VerifierDataset, 
    class_balance: Union[float, None]) -> Tuple[float, float, float, float, int, int, int, int]:
    """
    Fit a binary label model on the first generation of every problem and score its predictions.

    Args:
        vd: dataset providing ``vote_matrices`` and ``true_labels``; index 0 along
            the second axis of each is taken as the "first sample".
        class_balance: assumed probability that a generation is correct. When None,
            the true labels are passed to the label model as ``Y_dev`` instead.

    Returns:
        accuracy, precision, recall, f1, true_positives, true_negatives,
        false_positives, false_negatives. All zeros when ``vd.vote_matrices`` is empty.
    """

    if len(vd.vote_matrices) == 0:
        return 0.0, 0.0, 0.0, 0.0, 0, 0, 0, 0
    
    vote_matrices = vd.vote_matrices[:, 0, :] # first sample 
    ground_truth = vd.true_labels[:, 0]

    # Scale votes to be 1-based as required by the metal library
    votes_scaled = vote_matrices + 1
    ground_truth_scaled = ground_truth + 1
    
    # Initialize label model
    label_model = LabelModel(k=2, seed=123)

    if class_balance is None:
        print(f"True class balance is {ground_truth.mean()}")
        cb_args = {'Y_dev': ground_truth_scaled, 'class_balance': None}
    else:
        cb_args = {'Y_dev': None, 'class_balance': np.array([1-class_balance, class_balance])}

    label_model.train_model(
        L_train=votes_scaled,
        abstains=False,
        symmetric=False,
        n_epochs=1000,
        log_train_every=100,
        lr=0.001,
        **cb_args,
    )
    
    # Column index of the larger probability is the 0-based predicted class.
    probs = label_model.predict_proba(votes_scaled)
    preds = probs.argmax(axis=1)

    # Pin the label set so the matrix is always 2x2; without it a run in which only
    # one class occurs yields a 1x1 matrix and the 4-way unpacking fails.
    tn, fp, fn, tp = confusion_matrix(ground_truth, preds, labels=[0, 1]).ravel()
    accuracy = accuracy_score(ground_truth, preds)
    precision = precision_score(ground_truth, preds)
    recall = recall_score(ground_truth, preds)
    f1 = f1_score(ground_truth, preds)
    
    return accuracy, precision, recall, f1, tp, tn, fp, fn
        
def evaluate_weak_supervision_per_problem(vd: VerifierDataset,
    class_balance: Union[float, None] = None, ws_label_model: str = 'binary', use_continuous_marginals: bool = False) -> Union[Tuple[Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]], Tuple[float, float, float, float, int, int, int, int]]:
    """
    Fit one label model per problem and evaluate several best-of-N selection strategies.

    For every problem a fresh binary label model is trained on that problem's vote
    matrix. Each strategy then picks a single generation, which counts as a true
    positive when its label is truthy and as a false positive otherwise:

    - ``probability``: largest label-model probability of class 1.
    - ``weight_by_accs``: scores weighted by each verifier's estimated accuracy.
    - ``weight_by_class_accs`` / ``weight_by_conditional_accs``: the weight is looked
      up per (generation, verifier) according to that verifier's binary vote.
    - ``weight_by_accs_top_k``: as ``weight_by_accs`` but only the k verifiers with
      the highest estimated accuracy are kept.

    All weights are centred at 0.5 and clipped at 0, so verifiers estimated to be
    no better than chance do not contribute.

    Args:
        vd: dataset providing ``vote_matrices``, ``true_labels``, ``score_matrices``,
            ``verifiers``, ``n_verifiers``, ``n_problems`` and ``ws_threshold_strategy``.
        class_balance: assumed probability that a generation is correct. When None,
            the problem's true labels are passed to the label model as ``Y_dev``.
        ws_label_model: accepted for signature parity; not used by this function.
        use_continuous_marginals: if True, the problem's score matrix is passed to
            training as ``L_train_continuous`` and ``_continuous_m`` is appended to
            the result keys.

    Returns:
        ``(ws_results, predictive_results)``: ``ws_results`` maps a strategy key to
        its metrics dict; ``predictive_results`` maps each verifier to the mean (over
        problems) of its estimated ``ws_accuracy``, ``ws_tpr`` and ``ws_tnr``.
        NOTE: when ``vd.vote_matrices`` is empty, an 8-tuple of zeros is returned
        instead; this legacy shape is preserved for backward compatibility.
    """
    if len(vd.vote_matrices) == 0:
        return 0.0, 0.0, 0.0, 0.0, 0, 0, 0, 0
    
    tp = defaultdict(int)
    fp = defaultdict(int)
    ws_selection_strategies = ['probability', 'weight_by_accs', 'weight_by_class_accs', 'weight_by_conditional_accs', 'weight_by_accs_top_3', 'weight_by_accs_top_5', 'weight_by_accs_top_10']
    predictive_results = defaultdict(lambda: defaultdict(list))
    for problem_idx, (votes, truth) in enumerate(zip(vd.vote_matrices, vd.true_labels)):
        print(f"Problem {problem_idx} out of {vd.n_problems}")
        label_model = LabelModel(k=2, seed=123)
        votes_scaled = votes + 1  # metal expects 1-based votes
        # Bound once under its own name: the problem index used to be overwritten by
        # the verifier loop below, so the wrong problem's scores were ensembled.
        problem_scores = vd.score_matrices[problem_idx]

        if class_balance is None:
            print(f"True class balance is {truth.mean()}")
            cb_args = {'Y_dev': truth+1, 'class_balance': None}
        else:
            cb_args = {'Y_dev': None, 'class_balance': np.array([1-class_balance, class_balance])}

        label_model.train_model(
            votes_scaled, 
            L_train_continuous=problem_scores if use_continuous_marginals else None,
            abstains=False, 
            symmetric=False, 
            n_epochs=1000, 
            log_train_every=100,
            lr=0.001,
            **cb_args,
            )
        
        probs = label_model.predict_proba(votes_scaled)
        scores = probs[:, 1] 

        verifier_class_accuracies = _get_ws_estimated_verifier_class_accuracies(label_model, vd.n_verifiers)
        verifier_conditional_accuracies = _get_ws_estimated_verifier_conditional_accuracies(label_model, vd.n_verifiers)
        verifier_accuracies = _get_ws_estimated_verifier_accuracies(label_model, vd.n_verifiers)

        for verifier_idx, v in enumerate(vd.verifiers):
            predictive_results[v]['ws_accuracy'].append(verifier_accuracies[verifier_idx])
            predictive_results[v]['ws_tpr'].append(verifier_class_accuracies[verifier_idx, 1])
            predictive_results[v]['ws_tnr'].append(verifier_class_accuracies[verifier_idx, 0])

        for strategy in ws_selection_strategies:
            if strategy == "probability":
                # Simply select the largest Pr(y = 1 | binary verifier votes)
                best_idx = np.argmax(scores)
            else:
                # extract the estimated accuracies from weak supervision, and use those as weights in ensembling
                if strategy == "weight_by_class_accs":
                    weights = verifier_class_accuracies
                elif strategy == "weight_by_conditional_accs":
                    weights = verifier_conditional_accuracies
                else:  # weight_by_accs and its top-k variants
                    weights = verifier_accuracies

                if len(weights.shape) == 2:
                    # Pick, per (generation, verifier), the accuracy matching that verifier's vote.
                    weights = weights[np.arange(len(weights)), votes.astype(int)]
                    weighted_ensemble = np.sum(np.maximum(0, weights-0.5) * problem_scores, axis=1)
                else:
                    if "top_" in strategy:
                        # Zero out everything but the k most accurate verifiers. Work on a
                        # copy so later strategies still see the full accuracy vector, and
                        # clamp the count so k > n_verifiers keeps every verifier.
                        k = int(strategy.split("top_")[-1])
                        weights = weights.copy()
                        n_dropped = max(0, len(weights) - k)
                        weights[np.argsort(weights)[:n_dropped]] = 0
                    # we center the accuracies at 0.5, since 0.5 acc = random guessing. We also cut out worse than random verifiers
                    weighted_ensemble = np.maximum(0, weights-0.5).dot(problem_scores.T)
                best_idx = np.argmax(weighted_ensemble)

            # The selected generation is always treated as a positive prediction.
            if truth[best_idx]:
                tp[strategy] += 1
            else:
                fp[strategy] += 1
    
    # Average the per-problem verifier estimates.
    for v in vd.verifiers:
        for metric in ('ws_accuracy', 'ws_tpr', 'ws_tnr'):
            predictive_results[v][metric] = np.array(predictive_results[v][metric]).mean()

    ws_results = {}
    for strategy in ws_selection_strategies:
        accuracy = tp[strategy] / vd.n_problems
        precision = tp[strategy] / vd.n_problems  # Since we always make a positive prediction
        recall = 1.0  # We always select a sample
        f1 = 2 * precision / (precision + 1) if precision > 0 else 0.0
        tn = 0
        fn = 0

        ws_results[f"weak_supervision_threshold_{vd.ws_threshold_strategy}_{strategy}{'_continuous_m' if use_continuous_marginals else ''}"] = {
            "total_valid": vd.n_problems,
            "correct_predictions": tp[strategy] + tn,
            "incorrect_predictions": fp[strategy] + fn,
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "true_positives": tp[strategy],
            "true_negatives": tn,
            "false_positives": fp[strategy],
            "false_negatives": fn
        }
    
    return ws_results, predictive_results

def evaluate_weak_supervision_all_data(vd: VerifierDataset, 
    class_balance: Union[float, None] = None, ws_label_model: str = "binary", use_continuous_marginals: bool = False) -> Union[Tuple[Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]], Tuple[float, float, float, float, int, int, int, int]]:
    """
    Fit a single label model on all (problem, generation) rows and evaluate best-of-N selection.

    The vote matrices are flattened to ``(n_problems * n_generations, n_verifiers)``
    and one label model is trained on them. Its per-row predictive accuracy is
    logged, then each selection strategy picks one generation per problem; a pick
    counts as a true positive when its label is truthy, otherwise as a false positive.

    Strategies: ``probability`` (largest label-model probability of class 1) and,
    for the binary label model only, ensembles of the verifier scores weighted by
    WS-estimated accuracies (``weight_by_accs``, ``weight_by_class_accs``,
    ``weight_by_conditional_accs``, ``weight_by_accs_top_k``, ``weight_by_recall``).
    Weights are centred at 0.5 and clipped at 0.

    Args:
        vd: dataset providing ``vote_matrices``, ``true_labels``, ``score_matrices``,
            ``verifiers``, ``reward_models``, ``n_verifiers``, ``n_problems``,
            ``n_generations`` and ``ws_threshold_strategy``.
        class_balance: assumed probability that a generation is correct. When None,
            the true labels are passed to the label model as ``Y_dev``.
        ws_label_model: ``"binary"`` uses ``LabelModel``; anything else uses
            ``ContinuousModel`` restricted to the first ``len(vd.reward_models)`` columns.
        use_continuous_marginals: if True, the flattened score matrix is passed to
            training as ``L_train_continuous`` and ``_continuous_m`` is appended to
            the result keys.

    Returns:
        ``(ws_results, predictive_results)``: ``ws_results`` maps a strategy key to
        its metrics dict; ``predictive_results`` maps each verifier to its estimated
        ``ws_accuracy`` / ``ws_tpr`` / ``ws_tnr`` (left empty for non-binary models).
        NOTE: when ``vd.vote_matrices`` is empty, an 8-tuple of zeros is returned
        instead; this legacy shape is preserved for backward compatibility.
    """
    if len(vd.vote_matrices) == 0:
        return 0.0, 0.0, 0.0, 0.0, 0, 0, 0, 0

    voting_matrix = vd.vote_matrices.reshape((-1, vd.n_verifiers))
    ground_truth_flattened = vd.true_labels.flatten()
    is_binary = ws_label_model == "binary"

    if is_binary:
        label_model = LabelModel(k=2, seed=123)
        voting_matrix  = voting_matrix + 1  # metal expects 1-based votes
        n_verifiers = vd.n_verifiers
    else:
        label_model = ContinuousModel(k=2, seed=123)
        voting_matrix = voting_matrix[:, :len(vd.reward_models)] # temporary hack to just evaluate on reward models 
        n_verifiers=len(vd.reward_models)

    if class_balance is None:
        print(f"True class balance is {ground_truth_flattened.mean()}")
        cb_args = {'Y_dev': ground_truth_flattened+1, 'class_balance': None}
    else:
        cb_args = {'Y_dev': None, 'class_balance': np.array([1-class_balance, class_balance])}

    label_model.train_model(
        voting_matrix, 
        L_train_continuous=vd.score_matrices.reshape((-1, vd.n_verifiers)) if use_continuous_marginals else None,
        abstains=False, 
        symmetric=False, 
        n_epochs=50000, 
        log_train_every=10000,
        lr=0.0001,
        **cb_args,
        )
    
    probs = label_model.predict_proba(voting_matrix)
    scores = probs[:, 1].reshape(vd.n_problems, vd.n_generations)

    # Per-row predictive quality of the label model (thresholding at 0.5 via rounding).
    preds_flattened = np.round(scores).flatten()
    overall_acc = accuracy_score(ground_truth_flattened, preds_flattened)
    overall_f1 = f1_score(ground_truth_flattened, preds_flattened)

    indices_0 = np.where(ground_truth_flattened == 0)[0]
    indices_1 = np.where(ground_truth_flattened == 1)[0]
    acc_0 = accuracy_score(ground_truth_flattened[indices_0], preds_flattened[indices_0])
    acc_1 = accuracy_score(ground_truth_flattened[indices_1], preds_flattened[indices_1])

    logger.info(f"Predictive accuracy for WS (all_data) is {overall_acc}\nY=0 acc: {acc_0}\nY=1 acc: {acc_1}\nF1: {overall_f1}")

    predictive_results = {v : {} for v in vd.verifiers}
    ws_selection_strategies = ['probability', 'weight_by_accs', 'weight_by_class_accs', 'weight_by_conditional_accs', 'weight_by_accs_top_3', 'weight_by_accs_top_5', 'weight_by_accs_top_10', 'weight_by_recall']
    if is_binary:
        verifier_class_accuracies = _get_ws_estimated_verifier_class_accuracies(label_model, n_verifiers)
        verifier_conditional_accuracies = _get_ws_estimated_verifier_conditional_accuracies(label_model, n_verifiers)
        verifier_accuracies = _get_ws_estimated_verifier_accuracies(label_model, n_verifiers)

        for verifier_idx, v in enumerate(vd.verifiers):
            predictive_results[v]['ws_accuracy'] = verifier_accuracies[verifier_idx]
            predictive_results[v]['ws_tpr'] = verifier_class_accuracies[verifier_idx, 1]
            predictive_results[v]['ws_tnr'] = verifier_class_accuracies[verifier_idx, 0]
    else:
        # The accuracy estimates are only extracted for the binary label model, so the
        # accuracy-weighted strategies cannot be evaluated here (they used to crash
        # with an UnboundLocalError).
        logger.warning("Accuracy-weighted selection strategies are only available for the binary label model; evaluating 'probability' only.")
        ws_selection_strategies = ['probability']

    ws_results = {}
    for strategy in ws_selection_strategies:
        # The ensemble weights do not depend on the problem, so build them once per strategy.
        ensemble_weights = None
        if strategy != "probability":
            if strategy == "weight_by_class_accs":
                raw_weights = verifier_class_accuracies
            elif strategy == "weight_by_conditional_accs":
                raw_weights = verifier_conditional_accuracies
            elif strategy == "weight_by_recall":
                raw_weights = verifier_class_accuracies[:, 1]
            else:  # weight_by_accs and its top-k variants
                raw_weights = verifier_accuracies

            if "top_" in strategy:
                # Zero out everything but the k most accurate verifiers. Work on a copy so
                # other strategies still see the full accuracy vector, and clamp the count
                # so k > n_verifiers keeps every verifier.
                k = int(strategy.split("top_")[-1])
                raw_weights = raw_weights.copy()
                n_dropped = max(0, len(raw_weights) - k)
                raw_weights[np.argsort(raw_weights)[:n_dropped]] = 0

            # we center the accuracies at 0.5, since 0.5 acc = random guessing. We also cut out worse than random verifiers
            ensemble_weights = np.maximum(0, raw_weights - 0.5)

        tp = 0
        fp = 0
        for i in range(vd.n_problems):
            if ensemble_weights is None:
                # Simply select the largest Pr(y = 1 | binary verifier votes)
                best_idx = np.argmax(scores[i])
            elif len(ensemble_weights.shape) == 2:
                # when we have class-conditional accuracies, we need to use one or the other depending on the value of the binary verifier score
                chosen = ensemble_weights[np.arange(len(ensemble_weights)), vd.vote_matrices[i].astype(int)]
                best_idx = np.argmax(np.sum(chosen * vd.score_matrices[i], axis=1))
            else:
                best_idx = np.argmax(ensemble_weights.dot(vd.score_matrices[i].T))

            # The selected generation is always treated as a positive prediction.
            if vd.true_labels[i, best_idx]:
                tp += 1
            else:
                fp += 1
        
        # Calculate metrics
        accuracy = tp / vd.n_problems
        precision = tp / vd.n_problems  # Since we always make a positive prediction
        recall = 1.0  # We always select a sample
        f1 = 2 * precision / (precision + 1) if precision > 0 else 0.0
        tn = 0
        fn = 0

        ws_results[f"weak_supervision_threshold_{vd.ws_threshold_strategy}_{strategy}{'_continuous_m' if use_continuous_marginals else ''}"] = {
            "total_valid": vd.n_problems,
            "correct_predictions": tp + tn,
            "incorrect_predictions": fp + fn,
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "true_positives": tp,
            "true_negatives": tn,
            "false_positives": fp,
            "false_negatives": fn
        }

    return ws_results, predictive_results

def _compute_thresholds(samples: List[Dict], 
                        reward_threshold: Union[float, None], 
                        ws_threshold_strategy: str,
                        reward_models: List[str],
                        ws_class_balance: Union[float, None]) -> Tuple[np.ndarray, Union[np.ndarray, float]]:
    """
        Determines thresholds for converting reward model scores into {0, 1}. 
        If reward_threshold is not None, we just use a static threshold for the reward models (when used in first_sample). 
        Otherwise, if reward_threshold is None, we use the WS threshold strategy to set thresholds for the first_sample reward models too.

        If ws_threshold_strategy is `mean_per_problem`, we use the mean value of each reward model per problem to determine the threshold.
        Similarly, if ws_threshold_strategy is just `mean`, we use the mean value of each reward model across all problems to determine the threshold.

        We construct similar definitions for when ws_threshold_strategy is `median_per_problem` and `median`. 

        Returns:
        - Array of weak supervision thresholds. This is a (n_verifiers, n_problems) matrix.
        - Array of reward model thresholds. This is either just a float, or is a (n_verifiers, n_problems) matrix. Note that this is only really used in first_sample baselines!
    """

    rt = None 
    if reward_threshold is not None: 
        rt = reward_threshold 

    all_normalized_scores = np.array([sample[rm] for rm in reward_models for sample in samples]).reshape(len(reward_models), len(samples), -1)
    all_labels = np.array([sample['correctness'] for sample in samples]).reshape(len(samples), -1)
    try:
        wst = np.ones((len(reward_models), len(samples))) * float(ws_threshold_strategy)
        if rt is None:
            rt = wst.copy()
        return wst, rt 
    except ValueError:
        if "per_problem" in ws_threshold_strategy:
            if "mean" in ws_threshold_strategy:
                print(f"Using the mean reward model score per problem as the threshold.")
                means = all_normalized_scores.mean(axis=2) # means is (len(reward_models), num_samples)
                wst = means
            elif "median" in ws_threshold_strategy:
                print(f"Using the median reward model score per problem as the threshold.")
                medians = np.median(all_normalized_scores, axis=2)
                wst = medians
            elif "threshold_cb" in ws_threshold_strategy:
                print(f"Splitting by class balance.")
                wst = np.zeros((len(reward_models), len(samples)))
                for i in range(len(reward_models)):
                    for j in range(len(samples)):
                        cb = ws_class_balance if ws_class_balance is not None else all_labels[j].mean()
                        sorted_row = np.sort(all_normalized_scores[i, j])
                        index = int(np.ceil((1-cb) * len(sorted_row))) - 1
                        wst[i, j] = sorted_row[index]     
        else:
            all_normalized_scores = all_normalized_scores.reshape(len(reward_models), -1)
            if "mean" in ws_threshold_strategy:
                print(f"Using the mean reward model score across all data as the threshold.")
                means = all_normalized_scores.mean(axis=1)
                means = np.tile(means, (len(samples), 1)).T
                wst = means
            elif "median" in ws_threshold_strategy:
                print(f"Using the median reward model score across all data as the threshold.")
                medians = np.median(all_normalized_scores, axis=1)
                medians = np.tile(medians, (len(samples), 1)).T
                wst = medians
            elif "threshold_cb" in ws_threshold_strategy:
                print(f"Splitting by class balance.")
                wst = np.zeros(len(reward_models))
                cb = ws_class_balance if ws_class_balance is not None else all_labels.mean()
                for i in range(len(reward_models)):
                    sorted_row = np.sort(all_normalized_scores[i])
                    index = int(np.ceil((1-cb) * len(sorted_row))) - 1
                    wst[i] = sorted_row[index]

                wst = np.tile(wst, (len(samples), 1)).T        

        if rt is None:
            rt = wst.copy()

        return wst, rt 

def _str_to_bool(value):
    """Parse a command-line boolean flag value.

    argparse's ``type=bool`` simply calls ``bool()`` on the raw string, so every
    non-empty value (including ``"False"``) evaluates to ``True``. This parser
    accepts the usual spellings case-insensitively and rejects anything else.

    Args:
        value: Raw command-line string (or an actual ``bool``, returned as-is).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value (true/false), got {value!r}")


def main():
    """Command-line entry point: parse arguments, compute all metrics, and log them.

    Steps, in order:
      1. Parse and validate the command-line arguments and seed numpy's RNG.
      2. Build a ``VerifierDataset`` from the parsed options.
      3. Compute predictive accuracies, per-model performance, LM judge
         performance and baselines; unless ``--skip_ws`` is set, also the
         weak supervision metrics.
      4. With ``--greedy``, additionally compute the majority / first-sample /
         first-positive baselines and merge them into the "baselines" entry.
      5. Log the results. When ``--verifier_subset`` has fewer than 10 entries,
         dump the performance stats to a JSON file under the dataset's
         combinations results folder.
      6. With ``--verbose``, compute and log correlations.

    Note: ``--ws_allow_abstains`` and ``--ws_learning_rate`` are parsed here but
    are not read anywhere inside this function.

    Raises:
        ValueError: If weak supervision is requested with fewer than 3 verifiers,
            or if ``--ws_continuous_marginals`` is combined with a non-binary
            ``--ws_label_model`` while weak supervision is enabled.
        NotImplementedError: If results must be saved for a verifier subset whose
            size is not 2, 3 or 4.
    """
    parser = argparse.ArgumentParser(description="Analyze dataset statistics")
    parser.add_argument("--dataset_path", type=str, required=True,
                      help="Path to the dataset")
    parser.add_argument("--reward_threshold", type=float, default=None,
                      help="Threshold for converting reward model scores to binary predictions (0.0-1.0).")
    parser.add_argument("--threshold_strategy", type=str, default='mean_per_problem',
                      help="Threshold strategy: 'mean', 'median', 'mean_per_problem', 'median_per_problem', 'threshold_cb', 'threshold_cb_per_problem'. Must be set if --reward_threshold is None.")
    # NOTE: type=bool would turn "False" into True; use an explicit parser instead.
    parser.add_argument("--ws_allow_abstains", type=_str_to_bool, default=False,
                      help="Whether to allow weak supervision to abstain from voting")
    parser.add_argument("--ws_learning_rate", type=float, default=0.001,
                      help="Learning rate for weak supervision model training")
    parser.add_argument("--skip_ws", action="store_true",
                      help="Skip weak supervision calculations")
    parser.add_argument("--ws_data_strategy", type=str, default='per_problem',
                      choices=['first_sample', 'per_problem', 'all_data'],
                      help="How to construct voting matrices across the (n_problems, n_generations, n_verifiers) data.")
    parser.add_argument("--ws_class_balance", type=float, default=None,
                      help="Estimate of class balance (probability a generation is correct). If set to none, 'cheats' and uses labeled data to estimate class balance.")
    parser.add_argument("--ws_label_model", type=str, default="binary", choices = ["binary", "continuous"])
    parser.add_argument("--ws_continuous_marginals", action="store_true")
    parser.add_argument("--verifier_subset", nargs="+", default=None, 
                      help="If set, we use a subset of verifiers.")
    parser.add_argument("--greedy", action="store_true", 
                      help="If set, we evaluate first_sample baselines too.")
    parser.add_argument("--verbose", action="store_true", 
                      help="If set, logs extra information.")
    parser.add_argument("--tiebreaker", type=str, default='first', 
                      help="Principle for tiebreaking LM judge outputs (only implemented for single judge selection)")
    parser.add_argument("--seed", type=int, default=0, 
                      help="random seed")
    parser.add_argument("--hard_problem_cutoff", type=float, default=None, 
                      help="If set, we only keep problems where the fraction of correct generations is less than this cutoff.")

    args = parser.parse_args()

    np.random.seed(args.seed)

    if args.verifier_subset is not None and len(args.verifier_subset) < 3 and not args.skip_ws:
        raise ValueError("Cannot do weak supervision with less than 3 verifiers (underparameterized)")

    # Validate the flag combination up front (instead of asserting after the
    # dataset is loaded) so a bad invocation fails before any expensive work
    # and the check is not stripped when running with -O.
    if not args.skip_ws and args.ws_continuous_marginals and args.ws_label_model != 'binary':
        raise ValueError("--ws_continuous_marginals requires --ws_label_model binary")

    vd = VerifierDataset(
        args.verifier_subset,
        args.dataset_path,
        args.reward_threshold,
        args.threshold_strategy,
        args.ws_class_balance,
        args.hard_problem_cutoff,
        verbose=args.verbose
    )

    logger.info(f"Class balance is {vd.true_labels.mean()}")
    logger.info(f"Dimensions: {vd.n_problems} problems, {vd.n_generations} generations, {vd.n_verifiers} verifiers")

    predictive_results = calculate_predictive_accuracies(vd)

    # Calculate model performance metrics
    logger.info("\nCalculating individual model performance metrics...")
    model_performance_stats = calculate_rm_performance(vd)

    model_performance_stats["lm_judges"] = calculate_judge_performance(vd, tiebreaker=args.tiebreaker)

    # Calculate baseline metrics with multiple WS thresholds
    logger.info("\nCalculating baseline metrics...")
    baseline_stats = calculate_baselines(vd)

    ws_stats = {}
    if not args.skip_ws:
        logger.info("\nCalculating Weak Supervision metrics...")
        ws_stats, ws_predictive_results = calculate_ws_methods(
            vd, 
            args.ws_data_strategy, 
            args.ws_class_balance,
            args.ws_label_model,
            args.ws_continuous_marginals)

    # Calculate other baseline metrics
    if args.greedy:
        # the thresholds used to determine if enough of the verifiers 'pass' do not change, but we define passing in terms of reward_thresholds (so either set to a float, or constructed using mean/median)
        first_sample_thresholds = [round(t, 2) for t in np.arange(0.5, 1.0, 0.05)]
        majority_baseline_stats = calculate_majority_baseline(vd)
        first_sample_baseline_stats = calculate_first_sample_baseline(vd)
        first_positive_lm_judge_stats = calculate_first_positive_lm_judge_baseline(vd, first_sample_thresholds)
        first_positive_reward_model_stats = calculate_first_positive_reward_model_baseline(vd, first_sample_thresholds)
        first_positive_joint_stats = calculate_first_positive_joint_baseline(vd, first_sample_thresholds)
        
        # Combine all metrics
        model_performance_stats["baselines"] = {
            **baseline_stats,
            **ws_stats,
            "majority_classification": majority_baseline_stats,
            "first_sample": first_sample_baseline_stats,
            **{f"first_positive_lm_judge_{k}": v for k, v in first_positive_lm_judge_stats.items()},
            **{f"first_positive_reward_model_{k}": v for k, v in first_positive_reward_model_stats.items()},
            **{f"first_positive_joint_{k}": v for k, v in first_positive_joint_stats.items()}
        }

        individual_first_positive_stats = calculate_individual_model_first_positive(vd, first_sample_thresholds)
        model_performance_stats.update(individual_first_positive_stats)
    else:
        model_performance_stats["baselines"] = {
            **baseline_stats,
            **ws_stats,
        }

    if not args.skip_ws:
        # Merge the WS predictive results into the per-key predictive results.
        predictive_results = {
            key: {**predictive_results[key], **ws_predictive_results[key]} for key in predictive_results.keys()
        }
        log_predictive_accuracies_with_ws(predictive_results)
    else:
        log_predictive_accuracies(predictive_results)
    
    # Log all performance metrics
    log_model_performance(model_performance_stats)

    if args.verifier_subset is not None and len(args.verifier_subset) < 10:
        model_performance_stats['verifiers'] = vd.verifiers 

        # Results for small verifier combinations are grouped by combination size.
        folder_names = {2: "pairs", 3: "triples", 4: "quadruples"}
        subset_size = len(args.verifier_subset)
        if subset_size not in folder_names:
            raise NotImplementedError(
                f"Folder name only defined for subsets of 2, 3 or 4 verifiers (got {subset_size})."
            )
        folder_name = folder_names[subset_size]

        folder = DATASET_TO_COMBINATIONS_RESULTS_PATH[args.dataset_path]

        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        output_dir = os.path.join(folder, folder_name)
        os.makedirs(output_dir, exist_ok=True)

        with open(os.path.join(output_dir, f"{'_'.join(vd.verifiers)}.json"), "w") as f: 
            json.dump(model_performance_stats, f, cls=NpEncoder)

    if args.verbose:
        logger.info("\nCalculating correlations...")
        correlation_stats = calculate_correlations(vd)
        log_correlations(correlation_stats)
# Run the analysis only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()