
import argparse
import logging
from datasets import load_from_disk, load_dataset
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
import re

from collections import Counter

# Root logging setup: INFO and above, each record stamped with time and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-scoped logger used by the functions in this file.
logger = logging.getLogger(__name__)

def normalize_scores(scores):
    """Rescale scores linearly onto [0, 1], leaving ``None`` entries in place.

    Args:
        scores: Sequence of numbers and/or ``None`` placeholders.

    Returns:
        A list the same length as ``scores``. ``None`` entries stay ``None``.
        If every non-``None`` score is identical, each of them maps to 0.5.
        An empty (or falsy) input yields an empty list.
    """
    if not scores:
        return []

    present = [value for value in scores if value is not None]
    if not present:
        # Nothing to scale: keep the placeholders, one per input slot.
        return [None] * len(scores)

    low, high = min(present), max(present)

    def rescale(value):
        if value is None:
            return None
        if high == low:
            # Degenerate range: every valid score sits at the midpoint.
            return 0.5
        return (value - low) / (high - low)

    return [rescale(value) for value in scores]

def get_model_columns(dataset) -> Dict[str, Dict[str, str]]:
    """Map each model name to the dataset columns holding its verdicts.

    Args:
        dataset: Object exposing a ``column_names`` sequence of strings.

    Returns:
        ``{model_name: {'single': column, 'bootstrapped': column}}`` where each
        inner key is present only if the matching column exists. A
        ``unit_test_results`` column is registered under ``'unit_tests'``.
    """
    single_suffix = '_verdicts_v1'
    bootstrapped_suffix = '_bootstrapped_verdicts_v1'
    column_names = dataset.column_names

    model_columns: Dict[str, Dict[str, str]] = {}

    # Add unit test approach
    if 'unit_test_results' in column_names:
        model_columns['unit_tests'] = {'single': 'unit_test_results'}

    for col in column_names:
        # The bootstrapped suffix also ends with the single suffix, so it has
        # to be tested first or it would be misfiled as a single-verdict column.
        if col.endswith(bootstrapped_suffix):
            kind, suffix = 'bootstrapped', bootstrapped_suffix
        elif col.endswith(single_suffix):
            kind, suffix = 'single', single_suffix
        else:
            continue

        # Strip only the trailing suffix (str.replace would hit every occurrence).
        model_name = col[:-len(suffix)]
        model_columns.setdefault(model_name, {})[kind] = col

    # Only keep models that have at least one type of verdict
    return {name: cols for name, cols in model_columns.items() if cols}

def calculate_row_selection_metrics(row_verdicts: List[List[Optional[bool]]], 
                                  row_ground_truth: List[List[bool]],
                                  problem_indices: Optional[List[int]] = None) -> Dict:
    """Score Selection@1: per row, the first sample with a truthy verdict is the pick.

    Row outcomes:
        * pick is correct                      -> tp
        * pick is incorrect                    -> fp
        * no pick, but a correct sample exists -> fn
        * no pick and no correct sample        -> tn

    Args:
        row_verdicts: One list of verdicts per row; ``None`` counts as False.
        row_ground_truth: One list of correctness flags per row.
        problem_indices: Optional grouping key per row for the normalized
            precision; when missing or empty, the row position is used.

    Returns:
        Dict with ``tp``, ``tn``, ``fp``, ``fn``, ``total``, ``accuracy``,
        ``precision``, ``recall``, ``f1`` and ``norm_precision`` (mean of
        per-group precisions; groups with no pick and no correct sample
        are left out of that mean).
    """
    counts = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
    per_problem = defaultdict(lambda: {'tp': 0, 'fp': 0, 'has_correct': False})
    group_keys = problem_indices or range(len(row_verdicts))

    for row, truth, key in zip(row_verdicts, row_ground_truth, group_keys):
        row_has_correct = any(truth)
        bucket = per_problem[key]
        bucket['has_correct'] |= row_has_correct

        # Position of the first truthy verdict; None/False verdicts never select.
        chosen = next((pos for pos, verdict in enumerate(row) if verdict), None)

        if chosen is None:
            # Abstained: only a miss if something correct was available.
            counts['fn' if row_has_correct else 'tn'] += 1
            continue

        outcome = 'tp' if truth[chosen] else 'fp'
        counts[outcome] += 1
        bucket[outcome] += 1

    total = len(row_verdicts)
    tp, fp, fn = counts['tp'], counts['fp'], counts['fn']

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    if (precision + recall) > 0:
        f1 = 2 * (precision * recall) / (precision + recall)
    else:
        f1 = 0.0

    # Per-group precision, averaged over groups that count.
    group_precisions = []
    for stats in per_problem.values():
        picked = stats['tp'] + stats['fp']
        if picked != 0:
            group_precisions.append(stats['tp'] / picked)
        elif stats['has_correct']:
            # Nothing selected although a correct sample existed.
            group_precisions.append(0.0)

    if group_precisions:
        norm_precision = sum(group_precisions) / len(group_precisions)
    else:
        norm_precision = 0.0

    return {
        'tp': counts['tp'], 'tn': counts['tn'], 'fp': counts['fp'], 'fn': counts['fn'],
        'total': total,
        'accuracy': (counts['tp'] + counts['tn']) / total if total > 0 else 0.0,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'norm_precision': norm_precision,
    }

def calculate_normalized_precision(verdicts: List[List[Optional[bool]]], 
                                 ground_truth: List[List[bool]], 
                                 problem_indices: List[int]) -> float:
    """Average the per-problem precision of positive verdicts.

    Args:
        verdicts: One list of sample verdicts per row; ``None`` counts as False.
        ground_truth: One list of correctness flags per row.
        problem_indices: Grouping key for each row; rows sharing a key are pooled.

    Returns:
        Mean precision over problems. A problem with no positive verdict
        contributes 0.0 if it has a correct sample and is ignored otherwise.
        Returns 0.0 when no problem contributes.
    """
    tallies = {}

    for verdict_row, gt_row, prob_idx in zip(verdicts, ground_truth, problem_indices):
        tally = tallies.setdefault(prob_idx, {'tp': 0, 'fp': 0, 'has_correct': False})
        if any(gt_row):
            tally['has_correct'] = True

        for verdict, is_correct in zip(verdict_row, gt_row):
            if not verdict:
                # None and False are both "not predicted positive".
                continue
            tally['tp' if is_correct else 'fp'] += 1

    contributions = []
    for tally in tallies.values():
        positives = tally['tp'] + tally['fp']
        if positives != 0:
            contributions.append(tally['tp'] / positives)
        elif tally['has_correct']:
            # Missed every correct sample of a solvable problem.
            contributions.append(0.0)

    if not contributions:
        return 0.0
    return sum(contributions) / len(contributions)

def calculate_metrics(verdicts: List[Optional[bool]], ground_truth: List[bool], 
                     problem_indices: Optional[List[int]] = None) -> Dict:
    """Compute confusion counts and derived metrics for flat sample verdicts.

    Args:
        verdicts: Predicted labels per sample; ``None`` is treated as False.
        ground_truth: True labels per sample, aligned with ``verdicts``.
        problem_indices: Optional grouping key per sample. When given, the
            normalized precision (mean of per-group precisions) is computed;
            otherwise it is reported as 0.0.

    Returns:
        Dict with ``tp``, ``tn``, ``fp``, ``fn``, ``accuracy``, ``precision``,
        ``recall``, ``f1``, ``norm_precision`` and ``total``. All ratios are
        0.0 when their denominator is zero.
    """
    predictions = [False if v is None else v for v in verdicts]
    labelled = list(zip(predictions, ground_truth))

    tally = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
    for predicted, actual in labelled:
        if predicted:
            tally['tp' if actual else 'fp'] += 1
        else:
            tally['fn' if actual else 'tn'] += 1

    tp, tn, fp, fn = tally['tp'], tally['tn'], tally['fp'], tally['fn']
    total = len(labelled)

    accuracy = (tp + tn) / total if total > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    if (precision + recall) > 0:
        f1 = 2 * (precision * recall) / (precision + recall)
    else:
        f1 = 0.0

    norm_precision = 0.0
    if problem_indices is not None:
        # Per group: [true positives, false positives, any correct sample?]
        groups = {}
        for predicted, actual, key in zip(predictions, ground_truth, problem_indices):
            entry = groups.setdefault(key, [0, 0, False])
            if predicted:
                entry[0 if actual else 1] += 1
            if actual:
                entry[2] = True

        group_precisions = []
        for group_tp, group_fp, has_correct in groups.values():
            if group_tp + group_fp != 0:
                group_precisions.append(group_tp / (group_tp + group_fp))
            elif has_correct:
                # No positive prediction although one was warranted.
                group_precisions.append(0.0)

        if group_precisions:
            norm_precision = sum(group_precisions) / len(group_precisions)

    return {
        'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'norm_precision': norm_precision,
        'total': total
    }

def print_precision_distribution(name: str, problem_precisions: List[float], bin_width: float = 0.1):
    """Print an ASCII histogram and summary statistics of per-problem precisions.

    Args:
        name: Label shown in the heading.
        problem_precisions: Precision values, nominally in [0, 1]. ``None``
            entries are counted as problems with no predictions and excluded
            from the histogram and the statistics.
        bin_width: Width of each histogram bin over the [0, 1] range. Values
            at or above the last bin edge are placed in the last bin.
    """
    print(f"\nPrecision distribution for {name}:")

    # Tolerance so quotients that should be whole numbers (0.3 / 0.1 evaluates
    # to 2.9999999999999996) are not truncated into the bin below.
    eps = 1e-9
    num_bins = max(1, int(1 / bin_width + eps))

    # Show as many decimals as the bin width needs, so labels stay distinct
    # for widths such as 0.05 (never fewer than one decimal).
    fraction_digits = f"{bin_width:.10f}".rstrip('0').partition('.')[2]
    decimals = max(1, len(fraction_digits))
    labels = [
        f"{i * bin_width:.{decimals}f}-{(i + 1) * bin_width:.{decimals}f}"
        for i in range(num_bins)
    ]

    # Count values in each bin (by position, so label text cannot merge bins)
    counts = [0] * num_bins
    skipped = 0
    valid_precisions = []
    for p in problem_precisions:
        if p is None:
            skipped += 1
            continue
        valid_precisions.append(p)
        bin_idx = int(p / bin_width + eps)
        counts[min(max(bin_idx, 0), num_bins - 1)] += 1

    # Scale the longest bar to 50 characters
    max_count = max(counts)
    scale_factor = 50 / max_count if max_count > 0 else 1

    for bin_range, count in zip(labels, counts):
        bar_length = int(count * scale_factor)
        print(f"{bin_range}: {'█' * bar_length} ({count})")

    # Print statistics
    if valid_precisions:
        n_valid = len(valid_precisions)
        mean = sum(valid_precisions) / n_valid
        sorted_prec = sorted(valid_precisions)
        middle = n_valid // 2
        if n_valid % 2:
            median = sorted_prec[middle]
        else:
            # Even count: average the two central values.
            median = (sorted_prec[middle - 1] + sorted_prec[middle]) / 2
        variance = sum((x - mean) ** 2 for x in valid_precisions) / n_valid
        std_dev = variance ** 0.5
        print(f"Mean: {mean:.3f}  Median: {median:.3f}  Std: {std_dev:.3f}")

    if skipped > 0:
        print(f"Problems with no predictions: {skipped}")
    print("-" * 80)

def print_results_table(results: Dict):
    """Print the sample-level metrics table, one row per model and verdict type.

    Args:
        results: Dict with a ``'models'`` mapping of model name to
            ``{'single': {...}, 'bootstrapped': {...}}``. A row is printed for
            each verdict type that contains a ``'sample_metrics'`` dict with
            ``accuracy``, ``precision``, ``recall``, ``f1``, ``tp``, ``tn``,
            ``fp``, ``fn`` and ``total`` (``norm_precision`` defaults to 0.0
            when absent).
    """
    rule = "-" * 180
    header = "{:<60} | {:>8} | {:>9} | {:>9} | {:>7} | {:>4} | {:>4} | {:>4} | {:>4} | {:>4} | {:>12}".format(
        "Model", "Accuracy", "Precision", "Norm Prec", "Recall", "F1", "TP", "TN", "FP", "FN", "Correct/Total"
    )
    # One row layout for both verdict types so every value lands under its
    # header; bootstrapped rows previously omitted the Norm Prec cell.
    row_format = (
        "{:<60} | {:>7.1%} | {:>8.1%} | {:>8.1%} | {:>6.1%} | {:>3.1%} | "
        "{:>4} | {:>4} | {:>4} | {:>4} | {:>4}/{:<4}"
    )

    print("\nSample-level Results:")
    print(rule)
    print(header)
    print(rule)

    for model_name, model_results in sorted(results['models'].items()):
        for verdict_kind in ('single', 'bootstrapped'):
            metrics = (model_results.get(verdict_kind) or {}).get('sample_metrics')
            if metrics is None:
                continue
            correct = metrics['tp'] + metrics['tn']
            print(row_format.format(
                f"{model_name} ({verdict_kind})",
                metrics['accuracy'],
                metrics['precision'],
                metrics.get('norm_precision', 0.0),
                metrics['recall'],
                metrics['f1'],
                metrics['tp'], metrics['tn'],
                metrics['fp'], metrics['fn'],
                correct, metrics['total']
            ))
    print(rule)

def print_selection_accuracies(results: Dict):
    """Print the Selection@1 metrics table, one row per model and verdict type.

    Args:
        results: Dict with a ``'models'`` mapping of model name to
            ``{'single': {...}, 'bootstrapped': {...}}``. A row is printed for
            each verdict type that contains a ``'row_metrics'`` dict with
            ``accuracy``, ``precision``, ``recall``, ``f1``, ``tp``, ``tn``,
            ``fp``, ``fn`` and ``total`` (``norm_precision`` defaults to 0.0
            when absent).
    """
    rule = "-" * 180
    header = "{:<60} | {:>8} | {:>9} | {:>9} | {:>7} | {:>4} | {:>4} | {:>4} | {:>4} | {:>4} | {:>12}".format(
        "Model", "Accuracy", "Precision", "Norm Prec", "Recall", "F1", "TP", "TN", "FP", "FN", "Correct/Total"
    )
    # One row layout for both verdict types so every value lands under its
    # header; bootstrapped rows previously omitted the Norm Prec cell.
    row_format = (
        "{:<60} | {:>7.1%} | {:>8.1%} | {:>8.1%} | {:>6.1%} | {:>3.1%} | "
        "{:>4} | {:>4} | {:>4} | {:>4} | {:>4}/{:<4}"
    )

    print("\nSelection@1 Accuracies:")
    print(rule)
    print(header)
    print(rule)

    for model_name, model_results in sorted(results['models'].items()):
        for verdict_kind in ('single', 'bootstrapped'):
            metrics = (model_results.get(verdict_kind) or {}).get('row_metrics')
            if metrics is None:
                continue
            correct = metrics['tp'] + metrics['tn']
            print(row_format.format(
                f"{model_name} ({verdict_kind})",
                metrics['accuracy'],
                metrics['precision'],
                metrics.get('norm_precision', 0.0),
                metrics['recall'],
                metrics['f1'],
                metrics['tp'], metrics['tn'],
                metrics['fp'], metrics['fn'],
                correct, metrics['total']
            ))
    print(rule)

def print_ensemble_results(results: Dict, dataset, reward_threshold: float):
    """Compute ensemble metrics and print them as a table.

    The metrics come from ``calculate_ensemble_metrics(results, dataset,
    reward_threshold)``. Approaches are printed in a fixed display order, and
    any approach missing from the computed metrics is skipped.

    Args:
        results: Per-model results passed through to the ensemble calculation.
        dataset: Dataset passed through to the ensemble calculation.
        reward_threshold: Threshold passed through to the ensemble calculation.
    """
    ensemble_metrics = calculate_ensemble_metrics(results, dataset, reward_threshold)

    title = "\nEnsembling Approaches:"
    if not ensemble_metrics:
        print(title)
        print("No valid ensemble results available")
        return

    divider = "-" * 180
    column_titles = (
        "Approach", "Accuracy", "Precision", "Norm Prec", "Recall", "F1",
        "TP", "TN", "FP", "FN", "Correct/Total",
    )

    print(title)
    print(divider)
    print("{:<60} | {:>8} | {:>9} | {:>9} | {:>7} | {:>4} | {:>4} | {:>4} | {:>4} | {:>4} | {:>12}".format(
        *column_titles
    ))
    print(divider)

    # Metric keys in display order.
    display_order = [
        'top_3_unanimous',
        'all_unanimous',
        'top_3_majority',
        'complete_majority',
        'Two-Stage (UT5 + Unanimous LM Judges)',
        'Two-Stage (Top-3 Freq + LM Judge Ranking)',
        'Two-Stage (Top-5 Freq + LM Judge Ranking)',
        'Two-Stage (Top-10 Freq + LM Judge Ranking)',
        'Two-Stage (Top-20 Freq + LM Judge Ranking)',
        'Two-Stage (Top-3 Freq + Unanimous LM Judge Ranking)',
        'Two-Stage (Top-5 Freq + Unanimous LM Judge Ranking)',
        'Two-Stage (Top-10 Freq + Unanimous LM Judge Ranking)',
        'Two-Stage (Top-20 Freq + Unanimous LM Judge Ranking)',
        'Two-Stage (Top-3 Freq + Top-3 RM Ranking)',
        'Two-Stage (Top-5 Freq + Top-3 RM Ranking)',
        'Two-Stage (Top-10 Freq + Top-3 RM Ranking)',
        'Two-Stage (Top-20 Freq + Top-3 RM Ranking)',
        'Top-1 RM Ensemble',
        'Top-3 RM Ensemble',
        'Top-5 RM Ensemble',
        'Top-10 RM Ensemble',
        'Top-All RM Ensemble',
    ]
    # Keys whose printed label differs from the key itself.
    display_labels = {
        'top_3_unanimous': 'Top-3 Unanimous Ensemble',
        'all_unanimous': 'All Judge Unanimous Ensemble',
        'top_3_majority': 'Top-3 Majority Ensemble',
        'complete_majority': 'Complete Majority Ensemble',
        'Top-All RM Ensemble': 'All RM Ensemble',
    }

    row_template = (
        "{:<60} | {:>7.1%} | {:>8.1%} | {:>8.1%} | {:>6.1%} | {:>3.1%} | "
        "{:>4} | {:>4} | {:>4} | {:>4} | {:>4}/{:<4}"
    )

    for key in display_order:
        if key not in ensemble_metrics:
            continue
        metrics = ensemble_metrics[key]
        n_correct = metrics['tp'] + metrics['tn']
        print(row_template.format(
            display_labels.get(key, key),
            metrics['accuracy'],
            metrics['precision'],
            metrics['norm_precision'],
            metrics['recall'],
            metrics['f1'],
            metrics['tp'], metrics['tn'],
            metrics['fp'], metrics['fn'],
            n_correct, metrics['total']
        ))
    print(divider)

    # Disabled: a per-ensemble precision histogram was sketched here. For each
    # printed ensemble it read ensemble_metrics[key]['problem_stats'], turned
    # each problem's tp/fp counts into tp / (tp + fp) (None when tp + fp == 0),
    # and passed the list to print_precision_distribution(name, precisions).

def get_top_reward_models(results: Dict, n: int = 3) -> List[str]:
    """Rank reward models by the normalized precision of their sample metrics.

    Only entries of ``results['models']`` whose name ends with ``(reward)``,
    that have a ``'single'`` section, and whose sample metrics include
    ``norm_precision`` are considered.

    Args:
        results: Results dict with a ``'models'`` mapping.
        n: How many names to return; ``None`` returns every ranked model.

    Returns:
        Model names with the `` (reward)`` marker removed, best first. Ties
        keep the order in which the models appear in ``results``.
    """
    scored = []
    for label, entry in results['models'].items():
        if not label.endswith('(reward)') or 'single' not in entry:
            continue
        sample_metrics = entry['single']['sample_metrics']
        if 'norm_precision' not in sample_metrics:
            continue
        scored.append((sample_metrics['norm_precision'], label.replace(' (reward)', '')))

    # Stable descending sort; slicing with n=None keeps the whole ranking.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [model for _, model in scored[:n]]

def calculate_rm_ensemble_score(answer_samples: List[int], dataset, row_idx: int, reward_models: List[str]) -> float:
    """Average an answer's reward-model scores across several reward models.

    For each reward model, the row's scores are min-max normalized over the
    whole row (every sample of the problem), and the normalized scores of the
    samples that produced this answer are averaged. Normalizing over the row,
    rather than over the answer's own samples, keeps the result comparable
    between answers of the same problem and lets single-sample answers
    receive a score.

    Args:
        answer_samples: Sample positions within the row that produced the answer.
        dataset: Mapping-like object where ``dataset[column][row_idx]`` yields
            the list of per-sample scores for that reward model.
        row_idx: Row (problem) position in the dataset.
        reward_models: Reward-model score column names.

    Returns:
        Mean over contributing reward models of the answer's mean normalized
        score, in [0, 1]. A reward model is skipped when its row has no
        scores, when all its valid scores are equal (no ranking information),
        or when none of the answer's samples has a score. Returns 0.0 if
        ``answer_samples`` is empty or no reward model contributes.
    """
    if not answer_samples:
        return 0.0

    per_model_means = []

    for rm_name in reward_models:
        row_scores = dataset[rm_name][row_idx]
        if not row_scores:
            continue

        row_valid = [s for s in row_scores if s is not None]
        if not row_valid:
            continue

        low, high = min(row_valid), max(row_valid)
        if high == low:
            # A flat row cannot distinguish answers; leave this model out.
            continue

        # Scores of the samples behind this answer; out-of-range positions
        # and missing scores are ignored.
        answer_scores = [
            row_scores[i] for i in answer_samples
            if 0 <= i < len(row_scores) and row_scores[i] is not None
        ]
        if not answer_scores:
            continue

        span = high - low
        normalized = [(s - low) / span for s in answer_scores]
        per_model_means.append(sum(normalized) / len(normalized))

    if not per_model_means:
        return 0.0
    return sum(per_model_means) / len(per_model_means)

def _flatten_scored_rows(verdict_rows, ground_truth_rows, row_keys):
    """Flatten per-row verdicts into aligned sample-level verdict/truth/key lists."""
    flat_verdicts, flat_truth, flat_keys = [], [], []
    for verdict_row, truth_row, key in zip(verdict_rows, ground_truth_rows, row_keys):
        # Pair samples within the row so a ragged row cannot shift later rows.
        for verdict, truth in zip(verdict_row, truth_row):
            flat_verdicts.append(verdict)
            flat_truth.append(truth)
            flat_keys.append(key)
    return flat_verdicts, flat_truth, flat_keys


def _evaluate_verdict_rows(verdict_rows, ground_truth_rows, problem_indices, drop_missing_rows=True):
    """Build the per-verdict-type result dict, or return None when no rows remain."""
    kept = [
        (verdict_row, truth_row, idx)
        for verdict_row, truth_row, idx in zip(verdict_rows, ground_truth_rows, problem_indices)
        if not (drop_missing_rows and verdict_row is None)
    ]
    if not kept:
        return None

    kept_verdicts = [verdict_row for verdict_row, _, _ in kept]
    kept_truth = [truth_row for _, truth_row, _ in kept]
    kept_indices = [idx for _, _, idx in kept]

    return {
        # Raw verdicts and ground truth are kept for ensemble calculations.
        'verdicts': kept_verdicts,
        'ground_truth': kept_truth,
        'sample_metrics': calculate_metrics(
            *_flatten_scored_rows(kept_verdicts, kept_truth, kept_indices)
        ),
        'row_metrics': calculate_row_selection_metrics(kept_verdicts, kept_truth),
    }


def process_dataset(dataset, max_rows: Optional[int] = None, reward_threshold: float = 0.5) -> Dict:
    """Process dataset and calculate metrics for all models.

    Args:
        dataset: Dataset supporting ``len()``, ``select(range)``,
            ``column_names`` and column access by name. It must provide the
            ``answer_correct`` and ``extracted_answers_using_R.E.`` columns;
            ``problem_id`` is used for grouping when present.
        max_rows: Process only the first ``max_rows`` rows (all rows when
            falsy).
        reward_threshold: Threshold applied to per-row min-max normalized
            reward scores when building per-problem reward statistics.

    Returns:
        Dict with:
            ``'dataset_stats'``: row counts and majority-voting counts.
            ``'models'``: mapping of model key to ``{'single': {...}}`` and,
            for verdict-column models, ``'bootstrapped'`` as well. Each
            populated section holds ``verdicts``, ``ground_truth``,
            ``sample_metrics`` and ``row_metrics``. Reward models are stored
            under ``"<column> (reward)"`` and their ``sample_metrics`` also
            carry ``problem_stats``.
    """
    # Get number of rows to process
    total_rows = len(dataset)
    rows_to_process = min(max_rows, total_rows) if max_rows else total_rows

    # Get model columns
    model_columns = get_model_columns(dataset)
    logger.info(f"Found {len(model_columns)} models with verdict columns")

    # Initialize results
    results = {
        'models': defaultdict(dict),
        'dataset_stats': {
            'total_rows': rows_to_process,
            'rows_with_correct': 0,
            'first_correct': 0
        }
    }
    dataset_stats = results['dataset_stats']

    # Process dataset
    dataset = dataset.select(range(rows_to_process))
    column_names = dataset.column_names

    # Columns are materialized once up front; fetching a column inside a
    # per-row loop would rebuild it for every row.
    if 'problem_id' in column_names:
        problem_indices = list(dataset['problem_id'])
    else:
        problem_indices = list(range(rows_to_process))
    ground_truth = list(dataset['answer_correct'])

    # Calculate dataset statistics
    dataset_stats['rows_with_correct'] = sum(1 for row in ground_truth if any(row))
    # An empty row has no first sample and is simply not counted.
    dataset_stats['first_correct'] = sum(1 for row in ground_truth if row and row[0])

    # Calculate majority voting accuracy
    majority_correct = 0
    majority_total = 0
    extracted_answers = dataset['extracted_answers_using_R.E.']

    for answers, answer_correct in zip(extracted_answers, ground_truth):
        # Get answer counts, ignoring NO_ANSWER
        answer_counts = {}
        for ans, is_correct in zip(answers, answer_correct):
            if ans == 'NO_ANSWER':
                continue
            tally = answer_counts.setdefault(ans, {'count': 0, 'correct': 0})
            tally['count'] += 1
            if is_correct:
                tally['correct'] += 1

        if not answer_counts:
            continue

        # Most frequent answer; ties go to the answer seen first.
        majority = max(answer_counts.values(), key=lambda t: t['count'])

        # Check if majority of samples with this answer are correct
        if majority['correct'] > majority['count'] / 2:
            majority_correct += 1
        majority_total += 1

    dataset_stats['majority_voting_correct'] = majority_correct
    dataset_stats['majority_voting_total'] = majority_total

    # Process each model
    for model_name, columns in model_columns.items():
        # Special handling for unit test results
        if model_name == 'unit_tests':
            try:
                unit_test_results = list(dataset[columns['single']])

                # Process different numbers of tests
                for num_tests in [5, 10, 15, None]:  # None means all tests
                    suffix = f" (first {num_tests})" if num_tests else ""
                    model_key = f"unit_tests{suffix}"

                    # Convert unit test results to verdicts (True if all tests passed)
                    verdicts = [
                        [all(result == '[Passed]' for result in sample_results[:num_tests])
                         for sample_results in row]
                        for row in unit_test_results
                    ]

                    evaluated = _evaluate_verdict_rows(
                        verdicts, ground_truth, problem_indices, drop_missing_rows=False
                    )
                    if evaluated is None:
                        logger.warning(f"No valid unit test results found for {model_key}")
                    else:
                        # Store results under new model key
                        results['models'][model_key] = {'single': evaluated}
            except Exception as e:
                logger.error(f"Error processing unit test results: {str(e)}")
            # Unit-test results live under their own keys, so there is nothing
            # to register (or warn about) under the bare 'unit_tests' name.
            continue

        model_results = {'single': {}, 'bootstrapped': {}}

        for verdict_kind in ('single', 'bootstrapped'):
            if verdict_kind not in columns:
                continue
            try:
                evaluated = _evaluate_verdict_rows(
                    dataset[columns[verdict_kind]], ground_truth, problem_indices
                )
            except Exception as e:
                # A failure in one verdict type must not discard the other.
                logger.error(f"Error processing {verdict_kind} verdicts for {model_name}: {str(e)}")
                continue

            if evaluated is None:
                logger.warning(f"No valid {verdict_kind} verdicts found for {model_name}")
            else:
                model_results[verdict_kind] = evaluated

        # Only add model results if we have valid data
        if model_results['single'] or model_results['bootstrapped']:
            results['models'][model_name] = model_results
        else:
            logger.warning(f"Skipping {model_name} due to no valid results")

    reward_models = [
        'skyworks_scores', 'urm_scores', 'offset_bias_scores', 'qrm_scores',
        'grm_scores', 'grm_llama32_scores', 'grm_gemma_scores', 'gpm_scores',
        'inform_scores', 'internlm_scores', 'internlm2_scores', 'qwen25_math_scores',
        'eurus_prm_scores', 'eurus_prm2_scores', 'armor_rm_score', 'skywork_gemma_scores',
        'armor_rm_correctness', 'qrm_gemma_correctness'
    ]

    # Process reward models
    for model_name in reward_models:
        if model_name not in column_names:
            logger.warning(f"Reward model column {model_name} not found in dataset")
            continue

        score_rows = list(dataset[model_name])

        # Check if model has any valid scores
        if not any(row and any(score is not None for score in row) for row in score_rows):
            logger.warning(f"Skipping {model_name} due to no valid scores")
            continue

        selection_rows = []
        problem_stats = defaultdict(lambda: {'tp': 0, 'fp': 0})

        for i, (row_scores, gt_row) in enumerate(zip(score_rows, ground_truth)):
            # Handle None or empty row_scores
            if not row_scores or all(s is None for s in row_scores):
                selection_rows.append([None] * len(gt_row))
                continue

            # Thresholded verdicts on row-normalized scores (per-problem stats only)
            normalized_scores = normalize_scores(row_scores)
            sample_verdicts = [score >= reward_threshold if score is not None else None
                               for score in normalized_scores]

            # For Selection@1: only the highest scoring sample is selected
            # (first one on ties). At least one score is valid at this point.
            best_idx = max(
                (idx for idx, score in enumerate(row_scores) if score is not None),
                key=lambda idx: row_scores[idx],
            )
            selection_verdicts = [False] * len(row_scores)
            selection_verdicts[best_idx] = True
            selection_rows.append(selection_verdicts)

            # Calculate per-problem stats using sample-level verdicts
            for verdict, gt in zip(sample_verdicts, gt_row):
                if verdict:
                    problem_stats[i]['tp' if gt else 'fp'] += 1

        rm_ground_truth = list(ground_truth)

        # NOTE(review): sample_metrics here are computed from the Selection@1
        # verdicts, while problem_stats uses the thresholded verdicts. This is
        # kept as found; confirm it is intended.
        metrics = calculate_metrics(
            *_flatten_scored_rows(selection_rows, rm_ground_truth, range(len(selection_rows)))
        )
        metrics['problem_stats'] = dict(problem_stats)

        results['models'][f"{model_name} (reward)"] = {
            'single': {
                'verdicts': selection_rows,
                'ground_truth': rm_ground_truth,
                'sample_metrics': metrics,
                'row_metrics': calculate_row_selection_metrics(selection_rows, rm_ground_truth)
            }
        }

    return results

def calculate_majority_at_k(dataset, k: int) -> float:
    """Calculate Majority@k accuracy across problems.

    For each problem the extracted answers are grouped (ignoring
    ``'NO_ANSWER'``) and ranked by frequency. The problem counts as correct
    when at least one of the ``k`` most frequent answers is marked correct for
    strictly more than half of the samples that produced it.

    Args:
        dataset: Mapping-like dataset exposing the columns
            ``'extracted_answers_using_R.E.'`` and ``'answer_correct'``, each
            holding one list per problem.
        k: Number of most frequent answers considered per problem.

    Returns:
        Fraction of problems (among those with at least one real answer) that
        are correct under the rule above, or ``0.0`` if no problem has a real
        answer.
    """
    # Read each column once. Indexing ``dataset[column][i]`` inside the loop
    # would fetch the whole column again for every row.
    answers_column = dataset['extracted_answers_using_R.E.']
    correctness_column = dataset['answer_correct']

    correct_problems = 0
    total_problems = 0

    for answers, answer_correct in zip(answers_column, correctness_column):
        # answer -> correctness flags of every sample that gave this answer;
        # the list length doubles as the answer's frequency.
        correctness_by_answer = defaultdict(list)
        for ans, is_correct in zip(answers, answer_correct):
            if ans != 'NO_ANSWER':
                correctness_by_answer[ans].append(is_correct)

        if not correctness_by_answer:
            continue  # No valid answers: the problem is not counted at all.
        total_problems += 1

        # Stable sort: answers with equal frequency keep first-seen order.
        top_k_flags = sorted(correctness_by_answer.values(), key=len, reverse=True)[:k]
        if any(sum(flags) > len(flags) / 2 for flags in top_k_flags):
            correct_problems += 1

    return correct_problems / total_problems if total_problems > 0 else 0.0

def print_dataset_statistics(stats: Dict, dataset):
    """Write the dataset-level summary to stdout.

    Prints the processed-row count, the share of rows with at least one
    correct answer, first-sample accuracy, the majority-voting accuracy (when
    ``stats`` carries a positive ``'majority_voting_total'``) and Majority@k
    for k in 1, 3, 5, 10, 20 computed over ``dataset``.
    """
    rule = "-" * 80
    print("\nDataset Statistics:")
    print(rule)

    total = stats['total_rows']
    print(f"Total rows processed: {total}")

    if total > 0:
        with_correct = stats['rows_with_correct']
        print(f"Rows with ≥1 correct answer: {with_correct} ({with_correct/total:.1%})")

        first_hits = stats['first_correct']
        print(f"First samples correct: {first_hits} ({first_hits/total:.1%})")

        # The majority-voting line is optional: it needs a positive denominator.
        if stats.get('majority_voting_total', 0) > 0:
            voted = stats['majority_voting_total']
            hits = stats['majority_voting_correct']
            print(f"Majority voting accuracy: {hits} / {voted} ({hits/voted:.1%})")

        # Majority@k is recomputed from the dataset for each k.
        for k in (1, 3, 5, 10, 20):
            print(f"Majority@{k} accuracy: {calculate_majority_at_k(dataset, k):.1%}")
    else:
        print("No rows were successfully processed")

    print(rule)

def _ens_unanimous_vote(votes) -> Optional[bool]:
    """Return True only if every non-None vote is truthy; None if nobody voted."""
    cast = [v for v in votes if v is not None]
    return all(cast) if cast else None


def _ens_majority_vote(votes) -> Optional[bool]:
    """Return True if strictly more than half of the non-None votes are truthy; None if nobody voted."""
    cast = [v for v in votes if v is not None]
    if not cast:
        return None
    return sum(1 for v in cast if v) > len(cast) / 2


def _ens_combine_votes(judge_rows, vote) -> List[Optional[bool]]:
    """Apply ``vote`` to each sample's column of judge verdicts (one row per judge)."""
    return [vote(column) for column in zip(*judge_rows)]


def _ens_sample_problem_indices(problem_indices, ground_truth_rows) -> List[int]:
    """Repeat each problem index once per sample of its row, in row-major order."""
    return [idx for idx, gt_row in zip(problem_indices, ground_truth_rows) for _ in gt_row]


def _ens_problem_stats(verdict_rows, ground_truth_rows) -> Dict:
    """Count true/false positives per problem; problems with no positive verdict get no entry."""
    stats = defaultdict(lambda: {'tp': 0, 'fp': 0})
    for i, (row, gt_row) in enumerate(zip(verdict_rows, ground_truth_rows)):
        for verdict, gt in zip(row, gt_row):
            if verdict:  # None and False never contribute
                stats[i]['tp' if gt else 'fp'] += 1
    return dict(stats)


def _ens_score(verdict_rows, ground_truth_rows, sample_indices) -> Dict:
    """Run calculate_metrics on flattened rows and attach the per-problem tp/fp counts."""
    metrics = calculate_metrics(
        [v for row in verdict_rows for v in row],
        [gt for row in ground_truth_rows for gt in row],
        list(sample_indices),
    )
    metrics['problem_stats'] = _ens_problem_stats(verdict_rows, ground_truth_rows)
    return metrics


def _ens_top_k_answers(answers, k):
    """Return the k most frequent answers as (answer, count), ignoring 'NO_ANSWER'; ties keep first-seen order."""
    return Counter(ans for ans in answers if ans != 'NO_ANSWER').most_common(k)


def _ens_rank_by_lm_judges(answers, judge_rows, k, num_samples):
    """Pick the best of the top-k frequent answers by mean and by unanimous LM-judge approval.

    Returns two verdict rows (mean-approval variant, strict-unanimous variant).
    Each row flags every sample whose answer equals the selected answer, or is
    all-None (length ``num_samples``) when no candidate received any verdict.
    """
    mean_scores = {}
    strict_scores = {}
    for ans, _ in _ens_top_k_answers(answers, k):
        # Non-None judge votes for each sample that produced this answer;
        # samples nobody judged are dropped.
        votes_per_sample = [
            [row[idx] for row in judge_rows if row[idx] is not None]
            for idx, candidate in enumerate(answers) if candidate == ans
        ]
        votes_per_sample = [votes for votes in votes_per_sample if votes]
        if not votes_per_sample:
            continue
        all_votes = [v for votes in votes_per_sample for v in votes]
        # Mean variant: share of truthy votes over all (sample, judge) pairs.
        mean_scores[ans] = sum(1 for v in all_votes if v) / len(all_votes)
        # Strict variant: share of samples that every voting judge approved.
        strict_scores[ans] = (
            sum(1 for votes in votes_per_sample if all(votes)) / len(votes_per_sample)
        )

    def select(scores):
        if not scores:
            return [None] * num_samples
        # max() keeps the first maximum, i.e. the more frequent answer on ties.
        best = max(scores.items(), key=lambda item: item[1])[0]
        return [ans == best for ans in answers]

    return select(mean_scores), select(strict_scores)


def _ens_passes_all_rms(sample_idx, row_idx, rm_names, score_cache, threshold) -> bool:
    """Return True if the sample's normalized score reaches ``threshold`` for every cached RM in ``rm_names``."""
    for rm_name in rm_names:
        if rm_name not in score_cache:
            continue  # RM without cached scores places no constraint
        score = score_cache[rm_name][row_idx][sample_idx]
        if score is None or score < threshold:
            return False
    return True


def _ens_mark_selected(selected_indices, num_samples) -> List[bool]:
    """Return a verdict row of ``num_samples`` booleans with the in-range selected indices set to True."""
    verdict = [False] * num_samples
    for idx in selected_indices:
        if idx < num_samples:
            verdict[idx] = True
    return verdict


def calculate_ensemble_metrics(results: Dict, dataset, reward_threshold: float) -> Dict:
    """Process dataset and calculate metrics for all ensemble approaches.

    Args:
        results: Output of the per-model pass. ``results['models']`` maps a
            model name to a dict whose ``'single'`` entry holds ``'verdicts'``
            and ``'ground_truth'`` (one list per problem) and
            ``'sample_metrics'`` (with a ``'precision'`` value).
        dataset: Dataset exposing the ``'extracted_answers_using_R.E.'`` and
            ``'answer_correct'`` columns and ``column_names``.
        reward_threshold: Minimum normalized reward-model score for a sample
            to pass a reward model.

    Returns:
        Dict mapping ensemble name to the metrics returned by
        ``calculate_metrics``. Judge-based ensembles additionally carry a
        ``'problem_stats'`` entry. An empty dict is returned when there is no
        eligible judge, no LM judge, or no row to evaluate.
    """
    # Judges eligible for ensembling: single-verdict results that are neither
    # reward models nor bootstrapped variants.
    models_with_single = [
        model_name
        for model_name, model_results in results['models'].items()
        if ('single' in model_results
            and 'verdicts' in model_results['single']
            and not model_name.endswith('(reward)')
            and not model_name.endswith('_bootstrapped')
            and not model_name.startswith('unit_tests_bootstrapped'))
    ]
    if not models_with_single:
        return {}

    # LM judges (everything except unit tests) drive the row structure and
    # most ensembles, so without them there is nothing to compute.
    lm_models = [m for m in models_with_single if not m.startswith('unit_tests')]
    if not lm_models:
        logging.getLogger(__name__).warning(
            "No LM judge verdicts available; skipping ensemble metrics"
        )
        return {}

    # Rank LM judges by sample-level precision; the best three form "top-3".
    sorted_lm_models = sorted(
        lm_models,
        key=lambda m: results['models'][m]['single']['sample_metrics']['precision'],
        reverse=True
    )
    top_3_models = sorted_lm_models[:3]

    # The most precise LM judge defines which rows exist and their ground truth.
    all_ground_truth = list(results['models'][sorted_lm_models[0]]['single']['ground_truth'])
    if not all_ground_truth:
        return {}
    problem_indices = list(range(len(all_ground_truth)))

    # all_verdicts[i][j] is the verdict row of judge j on problem i.
    per_judge_verdicts = [results['models'][m]['single']['verdicts'] for m in models_with_single]
    all_verdicts = [[judge[i] for judge in per_judge_verdicts] for i in problem_indices]

    lm_positions = [j for j, m in enumerate(models_with_single) if not m.startswith('unit_tests')]
    top_3_positions = [j for j, m in enumerate(models_with_single) if m in top_3_models]
    lm_verdicts = [[rows[j] for j in lm_positions] for rows in all_verdicts]
    top_3_verdicts = [[rows[j] for j in top_3_positions] for rows in all_verdicts]

    # One problem index per flattened sample, in the same row-major order as
    # the flattened verdicts.
    sample_indices = _ens_sample_problem_indices(problem_indices, all_ground_truth)

    ensemble_results = {}

    # Voting ensembles. Unanimous variants use LM judges only; the complete
    # majority also counts unit-test verdicts.
    voting_ensembles = (
        ('all_unanimous', lm_verdicts, _ens_unanimous_vote),
        ('top_3_unanimous', top_3_verdicts, _ens_unanimous_vote),
        ('complete_majority', all_verdicts, _ens_majority_vote),
        ('top_3_majority', top_3_verdicts, _ens_majority_vote),
    )
    for name, judge_rows_per_problem, vote in voting_ensembles:
        verdict_rows = [
            _ens_combine_votes(judge_rows, vote) for judge_rows in judge_rows_per_problem
        ]
        ensemble_results[name] = _ens_score(verdict_rows, all_ground_truth, sample_indices)

    # Read the answers column once instead of once per row and per k.
    answers_column = dataset['extracted_answers_using_R.E.']

    # Two-stage: keep the k most frequent answers, then rank them by LM judges.
    for k in [3, 5, 10, 20]:
        freq_unanimous_verdicts = []
        freq_strict_unanimous_verdicts = []
        for i, gt_row in enumerate(all_ground_truth):
            ranked, strict = _ens_rank_by_lm_judges(
                answers_column[i], lm_verdicts[i], k, len(gt_row)
            )
            freq_unanimous_verdicts.append(ranked)
            freq_strict_unanimous_verdicts.append(strict)

        ensemble_results[f'Two-Stage (Top-{k} Freq + LM Judge Ranking)'] = _ens_score(
            freq_unanimous_verdicts, all_ground_truth, sample_indices
        )
        ensemble_results[f'Two-Stage (Top-{k} Freq + Unanimous LM Judge Ranking)'] = _ens_score(
            freq_strict_unanimous_verdicts, all_ground_truth, sample_indices
        )

    # NOTE(review): models_with_single already excludes every name ending in
    # '(reward)', so this filter only matches names containing 'reward'
    # elsewhere. If reward models are always registered with the '(reward)'
    # suffix, sorted_rms is empty and the RM ensembles below run without any
    # reward model (every answered sample passes) — confirm the intended list.
    all_reward_models = [model for model in models_with_single if 'reward' in model]

    # Rank reward models by Selection@1 accuracy where that metric is present.
    rm_accuracies = [
        (model, results['models'][model]['single']['selection_metrics']['accuracy'])
        for model in all_reward_models
        if 'selection_metrics' in results['models'][model]['single']
    ]
    sorted_rms = [rm for rm, _ in sorted(rm_accuracies, key=lambda x: x[1], reverse=True)]

    correct_column = dataset['answer_correct']

    # Pre-compute min-max normalized scores per reward model and problem.
    normalized_scores_cache = {}
    for rm_name in sorted_rms:
        if rm_name not in dataset.column_names:
            continue
        rm_column = dataset[rm_name]
        normalized_scores_cache[rm_name] = [
            normalize_scores(rm_column[i]) if rm_column[i] else [None] * len(correct_column[i])
            for i in problem_indices
        ]

    # Direct RM ensembles: among samples passing every selected RM, choose the
    # answer with the highest share of passing samples.
    for num_rms in [1, 3, 5, 10, None]:
        ensemble_name = f'Top-{num_rms if num_rms else "All"} RM Ensemble'
        selected_rms = sorted_rms[:num_rms] if num_rms else sorted_rms

        row_verdicts = []
        row_ground_truth = []
        for idx in problem_indices:
            answers = answers_column[idx]
            ground_truth = correct_column[idx]
            row_ground_truth.append(ground_truth)

            if not answers:
                row_verdicts.append([None] * len(ground_truth))
                continue

            # answer -> indices of its samples that pass every selected RM
            passing_samples = defaultdict(list)
            for sample_idx, answer in enumerate(answers):
                if answer != 'NO_ANSWER' and _ens_passes_all_rms(
                    sample_idx, idx, selected_rms, normalized_scores_cache, reward_threshold
                ):
                    passing_samples[answer].append(sample_idx)

            answer_ratios = [
                (answer, len(indices) / sum(1 for ans in answers if ans == answer), indices)
                for answer, indices in passing_samples.items()
            ]
            selected = max(answer_ratios, key=lambda x: x[1])[2] if answer_ratios else []
            row_verdicts.append(_ens_mark_selected(selected, len(ground_truth)))

        # Problem indices must follow the row-major order of the flattened
        # verdicts (each index repeated once per sample of its row).
        ensemble_results[ensemble_name] = calculate_metrics(
            [v for row in row_verdicts for v in row],
            [gt for row in row_ground_truth for gt in row],
            _ens_sample_problem_indices(problem_indices, row_ground_truth)
        )

    # Two-stage: top-k frequent answers, then rank by the top-3 reward models.
    for k in [3, 5, 10, 20]:
        ensemble_name = f'Two-Stage (Top-{k} Freq + Top-3 RM Ranking)'
        row_verdicts = []
        row_ground_truth = []

        for idx in problem_indices:
            answers = answers_column[idx]
            ground_truth = correct_column[idx]
            row_ground_truth.append(ground_truth)

            top_k_answers = _ens_top_k_answers(answers, k) if answers else []
            if not top_k_answers:
                row_verdicts.append([None] * len(ground_truth))
                continue

            # (answer, share of its samples passing the top-3 RMs, passing indices)
            answer_scores = []
            for answer, _ in top_k_answers:
                sample_positions = [pos for pos, ans in enumerate(answers) if ans == answer]
                passing_indices = [
                    pos for pos in sample_positions
                    if _ens_passes_all_rms(
                        pos, idx, sorted_rms[:3], normalized_scores_cache, reward_threshold
                    )
                ]
                ratio = len(passing_indices) / len(sample_positions) if sample_positions else 0
                answer_scores.append((answer, ratio, passing_indices))

            selected = max(answer_scores, key=lambda x: x[1])[2]
            row_verdicts.append(_ens_mark_selected(selected, len(ground_truth)))

        ensemble_results[ensemble_name] = calculate_metrics(
            [v for row in row_verdicts for v in row],
            [gt for row in row_ground_truth for gt in row],
            _ens_sample_problem_indices(problem_indices, row_ground_truth)
        )

    return ensemble_results

def main():
    """Main entry point.

    Parses the command-line arguments, loads the merged dataset (from disk
    first, falling back to ``load_dataset``), runs ``process_dataset`` and
    prints the result tables, selection accuracies, ensemble results and
    dataset statistics.
    """
    parser = argparse.ArgumentParser(description='Analyze merged judge results dataset')
    parser.add_argument('--dataset', type=str, required=True,
                      help='Path to merged dataset')
    parser.add_argument('--max_rows', type=int,
                      help='Maximum number of rows to process')
    parser.add_argument('--reward_threshold', type=float, default=0.5,
                      help='Threshold for reward model scores (after normalization)')
    parser.add_argument('--verbose', action='store_true',
                      help='Enable verbose output')
    args = parser.parse_args()

    # Honour --verbose, which was previously parsed but never used.
    if args.verbose:
        logger.setLevel(logging.DEBUG)

    # Load dataset
    logger.info("Loading dataset: %s", args.dataset)
    try:
        dataset = load_from_disk(args.dataset)
    except Exception as exc:
        # Not a bare ``except``: KeyboardInterrupt / SystemExit must still
        # propagate. Any loading failure falls back to load_dataset.
        logger.info("load_from_disk failed (%s); falling back to load_dataset", exc)
        dataset = load_dataset(args.dataset)['data']

    # Process dataset
    results = process_dataset(dataset, args.max_rows, reward_threshold=args.reward_threshold)

    # Print results
    print_results_table(results)
    print_selection_accuracies(results)
    print_ensemble_results(results, dataset, args.reward_threshold)
    print_dataset_statistics(results['dataset_stats'], dataset)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
