import torch
from datasets import load_dataset
import numpy as np
import argparse
from scipy import stats as scipy_stats
import re
import os
import logging
import time
from typing import List, Dict, Any, Optional
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from collections import defaultdict
from utils import generate_openai, generate_anthropic, generate_together

# Configure logging once at import time: INFO level, with timestamp and level
# prefixed to every message.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)

# Global statistics for tracking API calls.
# Written by process_with_retries() and reported by print_generation_stats();
# see process_with_retries() for exactly when each counter is bumped.
generation_stats = {
    'total_attempts': 0,                   # logical calls (not individual tries)
    'failed_attempts': 0,                  # calls that exhausted every try
    'retried_attempts': 0,                 # individual tries that raised
    'failures_by_model': defaultdict(int),  # per-model failed calls (only when model_name is given)
    'retries_by_model': defaultdict(int)    # per-model raised tries (only when model_name is given)
}

# Judge models known to this script. main() rejects any --models value that is
# not listed here, and get_generate_function() picks the backend from
# substrings of the name ("gpt", "claude", otherwise the Together backend).
JUDGE_MODELS = [
    "gpt-4o",
    "gpt-4o-mini",
    "claude-3-5-sonnet-latest",
    "claude-3-5-haiku-latest",
    "meta-llama/Llama-3.3-70B-Instruct-Turbo",
    "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
    "Qwen/Qwen2-72B-Instruct",
    "Qwen/Qwen2.5-72B-Instruct-Turbo",
    "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO"
]

def process_with_retries(func, *args, model_name=None, max_retries=5, **kwargs):
    """Call ``func(*args, **kwargs)``, retrying with exponential backoff.

    Args:
        func: Callable to invoke.
        *args: Positional arguments forwarded to ``func``.
        model_name: Optional label used for per-model statistics and log
            messages. It is consumed here and NOT forwarded to ``func``.
        max_retries: Maximum number of tries (must be >= 1).
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns on its first successful try.

    Raises:
        ValueError: If ``max_retries`` is smaller than 1 (previously the
            function silently returned ``None`` without ever calling ``func``).
        Exception: The exception raised by the final try, re-raised unchanged.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be at least 1, got {max_retries}")

    # NOTE: generation_stats is a plain module-level dict updated without a
    # lock; counts may be slightly off when this runs from several threads.
    generation_stats['total_attempts'] += 1

    for attempt in range(max_retries):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            # Every raised try is recorded, including the final one.
            if model_name:
                generation_stats['retries_by_model'][model_name] += 1
            generation_stats['retried_attempts'] += 1

            if attempt == max_retries - 1:
                if model_name:
                    generation_stats['failures_by_model'][model_name] += 1
                generation_stats['failed_attempts'] += 1
                logger.error(f"All {max_retries} attempts failed for {model_name}: {str(e)}")
                # Bare raise keeps the original traceback intact.
                raise

            # Backoff doubles each time: 1, 2, 4, 8 seconds for the default
            # of five tries (no sleep follows the last failure).
            sleep_time = 2 ** attempt
            logger.warning(f"Attempt {attempt + 1} failed for {model_name}, retrying in {sleep_time}s: {str(e)}")
            time.sleep(sleep_time)

def get_generate_function(model_name: str):
    """Pick the generation backend for ``model_name``.

    The match is a case-insensitive substring test: names containing "gpt"
    use the OpenAI helper, names containing "claude" use the Anthropic helper,
    and everything else falls through to the Together helper.
    """
    lowered = model_name.lower()
    if "gpt" in lowered:
        return generate_openai
    if "claude" in lowered:
        return generate_anthropic
    return generate_together

def create_evaluation_prompt(instruction: str, sample: str) -> List[Dict[str, str]]:
    """Build the chat messages asking a judge model to grade a solution.

    Args:
        instruction: The problem statement.
        sample: The candidate solution to be evaluated.

    Returns:
        A two-element list: a system message describing the evaluator role and
        the required "True"/"False" final line, followed by a user message
        containing the problem and the solution.
    """
    # The leading spaces and whitespace-only lines below are part of the
    # prompt text and are reproduced verbatim.
    system_lines = (
        "You are a rigorous evaluator for reasoning tasks spanning mathematics, science, coding, logic, and other technical domains. Your task is to determine if a given solution demonstrates valid reasoning and reaches the correct conclusion.",
        "    ",
        '    First, analyze the solution by carefully examining the reasoning and steps presented. Then provide your verdict as EXACTLY "True" or "False" on a new line at the end of your response.',
        "    ",
        "    Example response format:",
        "    The solution demonstrates valid reasoning because...",
        "    True",
        "",
        "    or:",
        "    The solution contains an error in...",
        "    False",
    )
    user_lines = (
        f"Problem: {instruction}",
        "",
        "Solution to evaluate:",
        f"{sample}",
        "",
        "Evaluate this solution. Provide your analysis, then on a new line respond with EXACTLY True or False:",
    )

    system_turn = {"role": "system", "content": "\n".join(system_lines)}
    user_turn = {"role": "user", "content": "\n".join(user_lines)}
    return [system_turn, user_turn]

def extract_verdict(response: str) -> Optional[bool]:
    """Extract the True/False verdict from a judge model's response.

    Only the last three non-empty lines are inspected, starting from the very
    last one. A line counts as a verdict when, after lower-casing and removing
    surrounding punctuation, whitespace, quotes and markdown emphasis markers
    (so ``**True**`` or `` `False`. `` are accepted), it is exactly ``true``
    or ``false``.

    Args:
        response: Raw text returned by the judge model.

    Returns:
        ``True`` or ``False`` when a verdict line is found, ``None`` when the
        response is empty or no verdict can be recognised.
    """
    if not response:
        return None

    non_empty_lines = [line.strip() for line in response.split('\n') if line.strip()]
    if not non_empty_lines:
        return None

    # Characters tolerated around the verdict word: sentence punctuation,
    # whitespace, and markdown/quoting wrappers that models commonly add.
    wrapper_chars = '.,!? \t\n\r*_`\'"'

    for line in reversed(non_empty_lines[-3:]):
        token = line.lower().strip(wrapper_chars)
        if token == "true":
            return True
        if token == "false":
            return False

    return None

def calculate_judge_verdicts(model_name: str, prompt: str, response: str, temperature: float = 0.0,
                           is_correct=None, verbose=False) -> Optional[bool]:
    """Ask one judge model for a single verdict on ``response``.

    Args:
        model_name: Judge model to query.
        prompt: The problem/instruction the solution answers.
        response: The candidate solution to be judged.
        temperature: Sampling temperature passed to the generate function.
        is_correct: Optional ground truth, only echoed in verbose output.
        verbose: When true, print the raw judge response and parsed verdict.

    Returns:
        The parsed verdict, or ``None`` if no verdict could be parsed or any
        error occurred (errors are logged, never raised).
    """
    try:
        raw_judgement = process_with_retries(
            get_generate_function(model_name),
            model=model_name,
            messages=create_evaluation_prompt(prompt, response),
            temperature=temperature,
            model_name=model_name,
        )

        # Quiet path: nothing to print, just parse and return.
        if not verbose:
            return extract_verdict(raw_judgement)

        print(f"\nJudge Model: {model_name}")
        print("Response:", raw_judgement)

        parsed = extract_verdict(raw_judgement)

        print(f"Extracted verdict: {parsed}")
        if is_correct is not None:
            print(f"Ground truth: {is_correct}")
        return parsed

    except Exception as err:
        logger.error(f"Error getting verdict from {model_name}: {str(err)}")
        return None

def get_bootstrapped_verdict(model_name: str, prompt: str, response: str, num_samples: int = 10,
                           temperature: float = 0.7, num_workers: int = 4) -> Optional[bool]:
    """Return the majority verdict over several parallel judge generations.

    Args:
        model_name: Name of the judge model
        prompt: The problem/instruction
        response: The solution to evaluate
        num_samples: Number of times to sample from the judge model
        temperature: Temperature for generation sampling
        num_workers: Number of parallel workers for sampling

    Returns:
        ``True`` when strictly more than half of the usable verdicts are
        true, ``False`` otherwise (ties included), and ``None`` when no usable
        verdict was obtained or an error occurred.
    """
    try:
        generate_fn = get_generate_function(model_name)
        messages = create_evaluation_prompt(prompt, response)

        def sample_once():
            # One judge generation, parsed to True/False/None.
            return extract_verdict(
                process_with_retries(
                    generate_fn,
                    model=model_name,
                    messages=messages,
                    temperature=temperature,
                    model_name=model_name,
                )
            )

        usable = []
        with ThreadPoolExecutor(max_workers=num_workers) as pool:
            pending = [pool.submit(sample_once) for _ in range(num_samples)]
            for task in pending:
                try:
                    outcome = task.result()
                except Exception as err:
                    logger.error(f"Error in parallel verdict collection: {str(err)}")
                    continue
                if outcome is not None:
                    usable.append(outcome)

        if not usable:
            return None

        # Strict majority: twice the number of True votes must exceed the total.
        votes_for = sum(1 for v in usable if v)
        return votes_for * 2 > len(usable)

    except Exception as err:
        logger.error(f"Error getting bootstrapped verdict from {model_name}: {str(err)}")
        return None

def print_selection_accuracies(results: Dict):
    """Print the Selection@1 accuracy table, two rows per model.

    For every key in ``results`` except ``'dataset_stats'`` a row is printed
    for the ``'single'`` approach and one for the ``'bootstrapped'`` approach.
    Each row reads ``row_selections`` (per-row "selection was right" flags)
    and ``ground_truth_by_row``; a row "has a correct answer" when any of its
    ground-truth entries is truthy.
    """
    rule = "-" * 108
    row_template = "{:<60} | {:>7.1%} | {:>4} | {:>4} | {:>4} | {:>4} | {:>4}/{:<4}"

    def report(approach: Dict, label: str, empty_note: str):
        # Print one table row for a single approach of a single model.
        picks = approach.get('row_selections', [])
        if not picks:
            print(label.ljust(60), empty_note)
            return

        solvable_flags = [any(row) for row in approach.get('ground_truth_by_row', [])]

        tp = tn = fp = fn = 0
        for picked, solvable in zip(picks, solvable_flags):
            if picked and solvable:
                tp += 1      # right selection, row had a correct answer
            elif picked:
                tn += 1      # right outcome, row had no correct answer
            elif solvable:
                fn += 1      # wrong outcome although a correct answer existed
            else:
                fp += 1      # wrong outcome, row had no correct answer

        total = len(picks)
        correct = tp + tn
        accuracy = correct / total if total > 0 else 0
        print(row_template.format(label, accuracy, tp, tn, fp, fn, correct, total))

    print("\nSelection@1 Accuracies:")
    print(rule)
    print("{:<60} | {:<8} | {:<4} | {:<4} | {:<4} | {:<4} | {:<12}".format(
        "Model", "Accuracy", "TP", "TN", "FP", "FN", "Correct/Total"
    ))
    print(rule)

    model_names = [name for name in results.keys() if name != 'dataset_stats']
    for model_name in model_names:
        report(
            results[model_name]['single'],
            f"{model_name} (single, temp=0.0)",
            "| No valid results",
        )
        report(
            results[model_name]['bootstrapped'],
            f"{model_name} (multi, temp=0.7)",
            "| No valid bootstrapped results",
        )
    print(rule)

def process_sample(model_name: str, prompt: str, response: str, is_correct: bool,
                  num_judge_samples: int, num_sample_workers: int, verbose: bool):
    """Judge one sample: a temperature-0 verdict plus optional sampled verdicts.

    Args:
        model_name: Judge model to query.
        prompt: The problem/instruction.
        response: The candidate solution.
        is_correct: Ground-truth correctness of the sample.
        num_judge_samples: When greater than 1, this many extra verdicts are
            drawn at temperature 0.7 and combined by strict majority.
        num_sample_workers: Accepted for interface compatibility; not used
            here (the extra verdicts are collected sequentially).
        verbose: Forwarded to calculate_judge_verdicts.

    Returns:
        Dict with keys ``single``, ``bootstrapped``, ``bootstrapped_verdicts``,
        ``bootstrapped_counts`` (``correct``/``total``) and ``ground_truth``.
        Errors are logged and leave the remaining fields at their defaults.
    """
    outcome = {
        'single': None,
        'bootstrapped': None,
        'bootstrapped_verdicts': [],
        'bootstrapped_counts': {'correct': 0, 'total': 0},
        'ground_truth': is_correct
    }

    try:
        # Deterministic verdict first.
        outcome['single'] = calculate_judge_verdicts(
            model_name=model_name,
            prompt=prompt,
            response=response,
            temperature=0.0,
            is_correct=is_correct,
            verbose=verbose
        )

        if num_judge_samples > 1:
            sampled = [
                calculate_judge_verdicts(
                    model_name=model_name,
                    prompt=prompt,
                    response=response,
                    temperature=0.7,
                    is_correct=is_correct,
                    verbose=verbose
                )
                for _ in range(num_judge_samples)
            ]
            # Unparseable/failed generations come back as None and are dropped.
            usable = [verdict for verdict in sampled if verdict is not None]

            outcome['bootstrapped_verdicts'] = usable
            if usable:
                votes_for = sum(1 for verdict in usable if verdict)
                outcome['bootstrapped'] = votes_for > len(usable) / 2
                outcome['bootstrapped_counts']['correct'] = votes_for
                outcome['bootstrapped_counts']['total'] = len(usable)

    except Exception as err:
        logger.error(f"Error processing sample with {model_name}: {str(err)}")

    return outcome

def parse_arguments():
    """Build the command-line parser and return the parsed arguments.

    Options are registered in the order listed below so that ``--help``
    output keeps the same ordering.
    """
    parser = argparse.ArgumentParser(description='Calculate judge verdicts for datasets')

    option_specs = [
        ('--dataset', dict(type=str, required=True,
                           help='Dataset to analyze (required)')),
        ('--max_rows', dict(type=int, default=None,
                            help='Maximum number of rows to process')),
        ('--max_samples', dict(type=int, default=None,
                               help='Maximum number of samples per row to process')),
        ('--num_judge_samples', dict(type=int, default=10,
                                     help='Number of times to sample from each judge model for bootstrapped verdict (default: 10)')),
        ('--num_row_workers', dict(type=int, default=4,
                                   help='Number of parallel workers for processing rows (default: 4)')),
        ('--num_sample_workers', dict(type=int, default=4,
                                      help='Number of parallel workers for processing samples within each row (default: 4)')),
        ('--verbose', dict(action='store_true',
                           help='Enable verbose output')),
        ('--models', dict(nargs='+', default=JUDGE_MODELS,
                          help=f'Judge models to use (default: all). Available models: {JUDGE_MODELS}')),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()

def print_generation_stats():
    """Print the global API-call counters, followed by per-model breakdowns.

    The per-model sections are printed only when the corresponding mapping in
    ``generation_stats`` is non-empty.
    """
    divider = "-" * 60

    print("\nGeneration Statistics:")
    print(divider)
    print(f"Total API calls attempted: {generation_stats['total_attempts']}")
    print(f"Failed attempts: {generation_stats['failed_attempts']}")
    print(f"Retried attempts: {generation_stats['retried_attempts']}")

    per_model_sections = (
        ('failures_by_model', "\nFailures by model:", "failures"),
        ('retries_by_model', "\nRetries by model:", "retries"),
    )
    for stats_key, heading, noun in per_model_sections:
        if not generation_stats[stats_key]:
            continue
        print(heading)
        for model, count in generation_stats[stats_key].items():
            print(f"  {model}: {count} {noun}")

    print(divider)

def calculate_approach_statistics(verdicts: List[Optional[bool]], ground_truth: List[bool],
                                row_selections: List[bool]):
    """Calculate statistics for a single approach (single or bootstrapped).

    Args:
        verdicts: Judge verdicts per sample; ``None`` entries (no usable
            verdict) are ignored for accuracy and correlation.
        ground_truth: Ground-truth correctness per sample, aligned with
            ``verdicts``. Entries are compared by truthiness, so ``1``/``0``
            work as well as ``True``/``False``.
        row_selections: Per-row flags telling whether the row-level selection
            was right; may be empty.

    Returns:
        Dict with ``correlation`` and ``p_value`` (point-biserial, ``None``
        when fewer than two valid pairs exist), ``accuracy`` (``None`` when
        there is no valid pair), ``correct``, ``total`` (length of
        ``verdicts`` including ``None`` entries) and the selection summary
        ``selection_accuracy`` / ``selection_correct`` / ``selection_total``.
    """
    selection_correct = sum(row_selections) if row_selections else 0
    selection_total = len(row_selections) if row_selections else 0

    stats = {
        'correlation': None,
        'p_value': None,
        'accuracy': None,
        'correct': 0,
        'total': len(verdicts),
        'selection_accuracy': selection_correct / selection_total if row_selections else None,
        'selection_correct': selection_correct,
        'selection_total': selection_total
    }

    # Keep only samples with a usable verdict, normalised to plain bools so
    # the comparison does not depend on object identity.
    valid_pairs = [(bool(v), bool(gt)) for v, gt in zip(verdicts, ground_truth) if v is not None]
    if not valid_pairs:
        return stats

    correct_predictions = sum(1 for v, gt in valid_pairs if v == gt)
    stats['accuracy'] = correct_predictions / len(valid_pairs)
    stats['correct'] = correct_predictions

    # The correlation needs at least two observations; with fewer the call
    # would raise, so the fields simply stay None.
    if len(valid_pairs) >= 2:
        numeric_verdicts = [1 if v else 0 for v, _ in valid_pairs]
        numeric_ground_truth = [1 if gt else 0 for _, gt in valid_pairs]
        point_biserial = scipy_stats.pointbiserialr(numeric_verdicts, numeric_ground_truth)
        stats['correlation'] = point_biserial.correlation
        stats['p_value'] = point_biserial.pvalue

    return stats

def print_results_table(results: Dict):
    """Print the per-sample accuracy table for every judge model.

    Each model (every key except ``'dataset_stats'``) gets a "single" row and
    a "multi" row followed by a separator line. Accuracy is computed over the
    samples whose verdict is not ``None``; the correlation and p-value columns
    are printed as ``-`` placeholders.
    """
    rule = "-" * 140

    def summarize(verdicts, ground_truth) -> str:
        # Text after the model label for one row.
        scored = [(v, gt) for v, gt in zip(verdicts, ground_truth) if v is not None]
        if not scored:
            return "No valid results"
        hits = sum(1 for v, gt in scored if v == gt)
        count = len(scored)
        return f"{hits / count:>7.1%} | {'-':>9} | {'-':>9} | {hits}/{count}"

    print("\nResults:")
    print(rule)
    print(f"{'Model':<60} | {'Accuracy':<8} | {'Correlation':<10} | {'P-value':<10} | {'Correct/Total'}")
    print(rule)

    for model_name, model_results in results.items():
        if model_name == 'dataset_stats':
            continue

        # Single-verdict row.
        print(f"{model_name} (single, temp=0.0)"[:60].ljust(60), end=" | ")
        single_verdicts = model_results['single']['verdicts']
        single_truth = model_results['single']['ground_truth']
        if single_verdicts:
            print(summarize(single_verdicts, single_truth))
        else:
            print("No valid results")

        # Bootstrapped row; an empty verdict list means bootstrapping was off.
        print(f"{model_name} (multi, temp=0.7)"[:60].ljust(60), end=" | ")
        if model_results.get('bootstrapped', {}).get('verdicts', []):
            multi_verdicts = model_results['bootstrapped']['verdicts']
            multi_truth = model_results['bootstrapped']['ground_truth']
            print(summarize(multi_verdicts, multi_truth))
        else:
            print("Bootstrapping disabled (requires num_judge_samples > 1)")

        print(rule)

def process_dataset(dataset, model_names: List[str], max_rows: Optional[int] = None,
                   max_samples: Optional[int] = None,
                   num_judge_samples: int = 10, num_row_workers: int = 4,
                   num_sample_workers: int = 4, verbose=False) -> Dict[str, Any]:
    """Process dataset rows in parallel and aggregate per-model results.

    Args:
        dataset: Indexable collection of rows; each row is handed to
            process_row.
        model_names: Judge models to evaluate.
        max_rows: Optional cap on the number of rows (taken from the start).
        max_samples: Optional cap on samples per row, forwarded to process_row.
        num_judge_samples: Judge samples per evaluation; bootstrapped results
            are aggregated only when this is greater than 1.
        num_row_workers: Size of the thread pool used for rows.
        num_sample_workers: Forwarded to process_row.
        verbose: Forwarded to process_row.

    Returns:
        Dict keyed by model name (each with ``single`` and ``bootstrapped``
        sub-dicts of aggregated lists) plus a top-level ``dataset_stats``
        entry. Rows are aggregated in submission order. A row whose
        processing raises is logged and skipped.
    """
    total_rows = len(dataset)
    rows_to_process = range(min(max_rows, total_rows) if max_rows is not None else total_rows)

    # Initialize result dictionaries
    all_results = {model: {
        'single': {
            'verdicts': [],
            'ground_truth': [],
            'ground_truth_by_row': [],
            'row_selections': [],
            'bootstrapped_counts': []
        },
        'bootstrapped': {
            'verdicts': [],
            'ground_truth': [],
            'ground_truth_by_row': [],
            'row_selections': []
        }
    } for model in model_names}

    # Add dataset_stats at top level
    all_results['dataset_stats'] = {
        'total_rows': 0,
        'rows_with_correct': 0,
        'first_correct': 0,
        'ground_truth_by_row': []
    }

    # Without any judge model there is nothing to evaluate; previously every
    # row was still processed and then failed while looking up the first model.
    if not model_names:
        logger.warning("No judge models supplied; nothing to process")
        return all_results

    # Ground truth for the dataset statistics is read from this model's results.
    first_model = next(iter(model_names))
    dataset_stats = all_results['dataset_stats']

    logger.info(f"Processing {len(rows_to_process)} rows with {num_row_workers} row workers "
                f"and {num_sample_workers} sample workers per row")

    # Process rows in parallel
    with ThreadPoolExecutor(max_workers=num_row_workers) as executor:
        future_to_row = {}

        for i in rows_to_process:
            logger.info(f"Submitting row {i + 1}/{len(rows_to_process)}")
            future = executor.submit(
                process_row,
                row_data=dataset[i],
                model_names=model_names,
                num_judge_samples=num_judge_samples,
                num_sample_workers=num_sample_workers,
                max_samples=max_samples,
                verbose=verbose
            )
            future_to_row[future] = i

        # Collect results in submission order (dicts preserve insertion order).
        for future, row_idx in future_to_row.items():
            try:
                row_results = future.result()
                logger.info(f"Completed row {row_idx + 1}/{len(rows_to_process)}")

                ground_truth = row_results[first_model]['single']['ground_truth']

                # Update dataset stats
                dataset_stats['total_rows'] += 1
                dataset_stats['ground_truth_by_row'].append(ground_truth)
                if any(ground_truth):
                    dataset_stats['rows_with_correct'] += 1
                # A row can come back without samples (e.g. process_row hit an
                # error); guard the index so the statistics stay consistent
                # and the model results below are still aggregated.
                if ground_truth and ground_truth[0]:
                    dataset_stats['first_correct'] += 1

                # Aggregate model results
                for model_name in model_names:
                    aggregated = all_results[model_name]
                    row_model = row_results[model_name]

                    approaches = ['single']
                    if num_judge_samples > 1:
                        approaches.append('bootstrapped')

                    for approach in approaches:
                        for field in ('verdicts', 'ground_truth', 'ground_truth_by_row', 'row_selections'):
                            aggregated[approach][field].extend(row_model[approach][field])

            except Exception as e:
                logger.error(f"Error processing row {row_idx}: {str(e)}")

    return all_results

def process_row(row_data: Dict, model_names: List[str], num_judge_samples: int,
                num_sample_workers: int, max_samples: Optional[int] = None, verbose: bool = False):
    """Judge every sample of one dataset row with every model.

    Args:
        row_data: Row providing ``'problem'``, ``'samples'`` and
            ``'answer_correct'`` (per-sample ground truth).
        model_names: Judge models to evaluate.
        num_judge_samples: When greater than 1 the bootstrapped approach is
            evaluated in addition to the single-verdict approach.
        num_sample_workers: Forwarded to process_sample.
        max_samples: Optional cap on the number of samples taken from the row.
        verbose: Forwarded to process_sample.

    Returns:
        Dict keyed by model name, each with ``single`` and ``bootstrapped``
        sub-dicts holding ``verdicts``, ``ground_truth``,
        ``ground_truth_by_row`` and ``row_selections``. If an error occurs the
        structure filled so far is returned and the error is logged.
    """
    problem = row_data['problem']
    samples = row_data['samples'][:max_samples] if max_samples is not None else row_data['samples']
    correctness = row_data['answer_correct'][:max_samples] if max_samples is not None else row_data['answer_correct']

    # Check if row has any correct samples
    has_correct_sample = any(correctness)
    bootstrapping = num_judge_samples > 1

    results = {model: {
        'single': {
            'verdicts': [],
            'ground_truth': [],
            'ground_truth_by_row': [],
            'row_selections': []
        },
        'bootstrapped': {
            'verdicts': [],
            'ground_truth': [],
            'ground_truth_by_row': [],
            'row_selections': []
        }
    } for model in model_names}

    try:
        for model_name in model_names:
            single = results[model_name]['single']
            bootstrapped = results[model_name]['bootstrapped']
            bootstrapped_counts = []  # True-vote count per sample

            for sample, is_correct in zip(samples, correctness):
                # process_sample yields the temperature-0 verdict and, when
                # num_judge_samples > 1, the sampled verdicts in the same
                # call, so one call per sample is enough. Calling it twice
                # would repeat the temperature-0 judge request for nothing.
                sample_result = process_sample(
                    model_name=model_name,
                    prompt=problem,
                    response=sample,
                    is_correct=is_correct,
                    num_judge_samples=num_judge_samples,
                    num_sample_workers=num_sample_workers,
                    verbose=verbose
                )

                single['verdicts'].append(sample_result['single'])
                single['ground_truth'].append(is_correct)

                if bootstrapping:
                    bootstrapped['verdicts'].append(sample_result['bootstrapped'])
                    bootstrapped['ground_truth'].append(is_correct)
                    # Store the count of True verdicts for this sample
                    bootstrapped_counts.append(
                        sum(1 for v in sample_result['bootstrapped_verdicts'] if v)
                    )

            # Store ground truth for this row
            single['ground_truth_by_row'].append(correctness)
            bootstrapped['ground_truth_by_row'].append(correctness)

            # Calculate Selection@1 for single verdict approach
            single_verdicts = single['verdicts']
            logger.info(f"Row verdicts: {single_verdicts}")
            logger.info(f"Row correctness: {correctness}")
            logger.info(f"Has correct sample: {has_correct_sample}")

            # Index of the first sample the judge accepted, or None if the
            # judge accepted nothing (i.e. abstained for this row).
            first_true_idx = next(
                (i for i, v in enumerate(single_verdicts) if v is True), None
            )
            if first_true_idx is not None:
                if correctness[first_true_idx]:
                    # True Positive: picked a correct answer when one exists
                    single['row_selections'].append(True)
                    logger.info(f"True Positive: Selected correct answer at index {first_true_idx}")
                else:
                    # False Positive: picked an incorrect answer
                    single['row_selections'].append(False)
                    logger.info(f"False Positive: Selected incorrect answer at index {first_true_idx}")
            elif has_correct_sample:
                # False Negative: abstained when correct answer exists
                single['row_selections'].append(False)
                logger.info("False Negative: Abstained when correct answer exists")
            else:
                # True Negative: correctly abstained when no correct answers
                single['row_selections'].append(True)
                logger.info("True Negative: Correctly abstained when no correct answers")

            # Calculate Selection@1 for bootstrapped approach
            if bootstrapping and bootstrapped_counts:
                # Always select the (first) sample with the highest True count
                best_sample_idx = bootstrapped_counts.index(max(bootstrapped_counts))
                # Selection is correct if we picked a correct answer when one exists
                bootstrapped['row_selections'].append(
                    correctness[best_sample_idx] if has_correct_sample else False
                )

        return results
    except Exception as e:
        logger.error(f"Error in process_row: {str(e)}")
        return results  # Whatever was collected before the error

def print_row_statistics(statistics: Dict):
    """Print the row-level summary taken from ``statistics['dataset_stats']``.

    Missing entries default to zero. Percentages are shown only when at least
    one row was processed.
    """
    divider = "-" * 80

    print("\nRow-level Statistics:")
    print(divider)

    summary = statistics.get('dataset_stats', {})
    n_rows = summary.get('total_rows', 0)
    n_with_correct = summary.get('rows_with_correct', 0)
    n_first_correct = summary.get('first_correct', 0)

    print(f"Total rows processed: {n_rows}")
    if not n_rows > 0:
        print("No rows were successfully processed")
    else:
        print(f"Rows with ≥1 correct answer: {n_with_correct} ({n_with_correct/n_rows:.1%})")
        print(f"First samples correct: {n_first_correct} ({n_first_correct/n_rows:.1%})")
    print(divider)

def main():
    """Command-line entry point: judge a dataset, report, and optionally save verdicts.

    Steps:
        1. Parse the CLI arguments and check every requested judge model
           against JUDGE_MODELS.
        2. Load the dataset and pick the "data" split, else "train", else the
           loaded object itself.
        3. Run process_dataset and compute per-model statistics for the
           'single' and 'bootstrapped' approaches.
        4. Print the result tables.
        5. If --output_path is given, attach one verdict column per model and
           approach to the processed rows and save the dataset to disk.

    Raises:
        ValueError: If a requested model is not listed in JUDGE_MODELS.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, required=True, help='Path to dataset')
    parser.add_argument('--models', nargs='+', required=True, help='Models to use for judging')
    parser.add_argument('--max_rows', type=int, help='Maximum number of rows to process')
    parser.add_argument('--max_samples', type=int, help='Maximum number of samples per row')
    parser.add_argument('--num_judge_samples', type=int, default=1, help='Number of judge samples per evaluation')
    parser.add_argument('--num_row_workers', type=int, default=4, help='Number of parallel row workers')
    parser.add_argument('--num_sample_workers', type=int, default=4, help='Number of parallel sample workers')
    parser.add_argument('--verbose', action='store_true', help='Enable verbose output')
    parser.add_argument('--output_path', type=str, help='Path to save the dataset with verdicts')
    args = parser.parse_args()

    # Validate selected models
    for model in args.models:
        if model not in JUDGE_MODELS:
            raise ValueError(f"Invalid model: {model}. Available models: {JUDGE_MODELS}")

    # Load dataset
    logger.info(f"Loading dataset: {args.dataset}")
    dataset = load_dataset(args.dataset)
    # Prefer the "data" split, then "train"; fall back to the loaded object.
    # Only a missing key is tolerated so that interrupts and real errors
    # are not silently swallowed.
    data = dataset
    for split_name in ("data", "train"):
        try:
            data = dataset[split_name]
        except KeyError:
            continue
        break

    # Print processing configuration
    total_rows = len(data)
    if args.max_rows is not None:
        # Clamp so the log line and the later row selection never exceed
        # the number of rows that actually exist.
        rows_to_process = max(0, min(args.max_rows, total_rows))
    else:
        rows_to_process = total_rows
    logger.info(f"\nProcessing {rows_to_process}/{total_rows} rows")
    if args.max_samples is not None:
        logger.info(f"Limited to first {args.max_samples} samples per row")
    logger.info(f"Using {args.num_row_workers} row workers and {args.num_sample_workers} sample workers")
    logger.info(f"Bootstrapped verdict sampling: {args.num_judge_samples} samples per evaluation")
    if args.verbose:
        logger.info("Verbose output enabled")

    # Process dataset
    results = process_dataset(
        data,
        args.models,
        max_rows=args.max_rows,
        max_samples=args.max_samples,
        num_judge_samples=args.num_judge_samples,
        num_row_workers=args.num_row_workers,
        num_sample_workers=args.num_sample_workers,
        verbose=args.verbose
    )

    # Calculate statistics
    statistics = {}
    for model_name in args.models:
        model_results = results[model_name]
        statistics[model_name] = {
            approach: calculate_approach_statistics(
                verdicts=model_results[approach]['verdicts'],
                ground_truth=model_results[approach]['ground_truth'],
                row_selections=model_results[approach]['row_selections']
            )
            for approach in ('single', 'bootstrapped')
        }

    # Add dataset stats at top level
    statistics['dataset_stats'] = results['dataset_stats']

    # Print results
    print_results_table(results)
    print_selection_accuracies(results)
    # statistics carries the same 'dataset_stats' object as results.
    print_row_statistics(statistics)
    print_generation_stats()

    # Create new dataset with verdict columns
    if args.output_path:
        logger.info(f"Saving dataset with verdicts to: {args.output_path}")

        # Initialize verdict columns
        verdict_columns = {}
        for model in args.models:
            verdict_columns[f"{model}_verdicts_v1"] = []
            verdict_columns[f"{model}_bootstrapped_verdicts_v1"] = []

        num_result_rows = len(results[args.models[0]]['single']['ground_truth_by_row'])
        # Width of each row's slice in the flat verdict lists. Computed once,
        # and only when there is a row to read, so an empty run never
        # indexes data[0].
        # NOTE(review): assumes every row contributed exactly this many
        # verdicts (uniform sample count, no failed rows) - confirm against
        # process_dataset.
        if num_result_rows:
            samples_per_row = args.max_samples if args.max_samples else len(data[0]['samples'])
        else:
            samples_per_row = 0

        # Collect verdicts for each row
        for row_idx in range(num_result_rows):
            start_idx = row_idx * samples_per_row
            end_idx = start_idx + samples_per_row
            for model in args.models:
                # Get single verdicts for this row
                single_verdicts = results[model]['single']['verdicts'][start_idx:end_idx]
                verdict_columns[f"{model}_verdicts_v1"].append(single_verdicts)

                # Get bootstrapped verdicts if available
                if args.num_judge_samples > 1:
                    bootstrapped_verdicts = results[model]['bootstrapped']['verdicts'][start_idx:end_idx]
                    verdict_columns[f"{model}_bootstrapped_verdicts_v1"].append(bootstrapped_verdicts)
                else:
                    verdict_columns[f"{model}_bootstrapped_verdicts_v1"].append(None)

        # Create new dataset with only the processed rows (already clamped
        # to the dataset size above).
        new_dataset = data.select(range(rows_to_process))

        # Add verdict columns
        for col_name, verdicts in verdict_columns.items():
            new_dataset = new_dataset.add_column(col_name, verdicts)

        # Save dataset
        try:
            new_dataset.save_to_disk(args.output_path)
            logger.info(f"Dataset successfully saved to: {args.output_path}")
        except Exception as e:
            logger.error(f"Error saving dataset: {str(e)}")

# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
