#!/usr/bin/env python3
"""
mmlu_standalone.py

Evaluate the MMLU multiple‐choice dataset with conformal risk control
and full OpenAI cost accounting. This is a standalone version that includes all
the necessary functionality without relying on external modules.
"""
import openai
import numpy as np
from datasets import load_dataset
from typing import List, Dict, Tuple, Any, Optional
import random
import json
import time
import os
import argparse
import concurrent.futures
from tqdm import tqdm

# Names of the MMLU sub-categories this script knows about. Each entry is
# passed as the configuration name when loading the "cais/mmlu" dataset.
MMLU_SUBJECTS = [
    "abstract_algebra",
    "anatomy",
    "astronomy",
    "business_ethics",
    "clinical_knowledge",
    "college_biology",
    "college_chemistry",
    "college_computer_science",
    "college_mathematics",
    "college_medicine",
    "college_physics",
    "computer_security",
    "conceptual_physics",
    "econometrics",
    "electrical_engineering",
    "elementary_mathematics",
    "formal_logic",
    "global_facts",
    "high_school_biology",
    "high_school_chemistry",
    "high_school_computer_science",
    "high_school_european_history",
    "high_school_geography",
    "high_school_government_and_politics",
    "high_school_macroeconomics",
    "high_school_mathematics",
    "high_school_microeconomics",
    "high_school_physics",
    "high_school_psychology",
    "high_school_statistics",
    "high_school_us_history",
    "high_school_world_history",
    "human_aging",
    "human_sexuality",
    "international_law",
    "jurisprudence",
    "logical_fallacies",
    "machine_learning",
    "management",
    "marketing",
    "medical_genetics",
    "miscellaneous",
    "moral_disputes",
    "moral_scenarios",
    "nutrition",
    "philosophy",
    "prehistory",
    "professional_accounting",
    "professional_law",
    "professional_medicine",
    "professional_psychology",
    "public_relations",
    "security_studies",
    "sociology",
    "us_foreign_policy",
    "virology",
    "world_religions",
]

class OpenAIWrapper:
    """Wrapper for the OpenAI chat API that returns per-choice confidence scores.

    Besides issuing the request, the wrapper keeps running totals of prompt
    tokens, completion tokens, API calls and estimated cost. The totals are
    updated under a lock so that one instance can be shared by several worker
    threads.
    """

    # System prompt sent with every request. Built from explicit pieces so the
    # exact text (including the trailing space on the first instruction line)
    # cannot be altered by an editor stripping whitespace.
    _SYSTEM_PROMPT = (
        "You are an expert who evaluates multiple choice questions.\n"
        "\n"
        "# Instructions\n"
        "- Assign a confidence score to each answer choice on a scale from 0 to 1 \n"
        "- 0 means certainly incorrect, 1 means certainly correct\n"
        "- Don't assign similar scores to choices unless genuinely equally uncertain\n"
        "\n"
        "# Response Format\n"
        "- Output ONLY a valid JSON object with a \"scores\" key containing an array of numbers\n"
        "- Example: {\"scores\": [0.1, 0.8, 0.05, 0.05]}\n"
        "- NO explanations, just the JSON"
    )

    def __init__(self, model_name, api_key=None):
        """Initialize with model name and optional API key.

        Parameters:
        -----------
        model_name : str
            Model identifier passed to the chat completions endpoint. It is
            also used (directly or via ``model_name_map``) to look up pricing.
        api_key : str, optional
            Explicit API key. When omitted the client is created without one.
        """
        # Function-scope import: the lock is only needed by this class.
        import threading

        self.model_name = model_name
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0
        self.api_calls = 0
        self.total_cost = 0.0

        # Guards the four running totals above.
        self._lock = threading.Lock()
        # Ensures the "no pricing information" warning is printed only once.
        self._warned_unpriced = False

        # Define pricing per 1M tokens
        self.pricing = {
            "gpt-4.1-2025-04-14": {"input": 2.00, "output": 8.00, "cached_input": 0.50},
            "gpt-4.1-mini-2025-04-14": {"input": 0.40, "output": 1.60, "cached_input": 0.10},
            "gpt-4.1-nano-2025-04-14": {"input": 0.10, "output": 0.40, "cached_input": 0.025}
        }

        # Map shortened model names to full names for pricing lookup
        self.model_name_map = {
            "gpt-4.1": "gpt-4.1-2025-04-14",
            "gpt-4.1-mini": "gpt-4.1-mini-2025-04-14",
            "gpt-4.1-nano": "gpt-4.1-nano-2025-04-14"
        }

        if api_key:
            self.client = openai.OpenAI(api_key=api_key)
        else:
            # Use environment variable
            self.client = openai.OpenAI()

    @staticmethod
    def _random_scores(count):
        """Return ``count`` random weights, each drawn from [0.1, 0.9)."""
        # Random rather than uniform so downstream conformal sets don't get too large.
        return [0.1 + 0.8 * random.random() for _ in range(count)]

    @staticmethod
    def _spread_and_normalize(scores, choice_indices, num_choices):
        """Place ``scores`` at ``choice_indices`` in a zero vector and normalize to sum 1."""
        full_scores = [0.0] * num_choices
        for i, score in zip(choice_indices, scores):
            full_scores[i] = score
        total = sum(full_scores)
        if total > 0:
            full_scores = [s / total for s in full_scores]
        return full_scores

    def _record_usage(self, prompt_tokens, completion_tokens):
        """Add one call's usage to the running totals and return that call's cost.

        A model without a pricing entry is recorded with zero cost (and a
        one-time warning) instead of raising, so a successful response is
        never discarded because of a missing price.
        """
        pricing_model = self.model_name_map.get(self.model_name, self.model_name)
        rates = self.pricing.get(pricing_model)
        if rates is None:
            call_cost = 0.0
        else:
            # Convert from per 1M tokens to per token
            input_cost = prompt_tokens * (rates["input"] / 1000000)
            output_cost = completion_tokens * (rates["output"] / 1000000)
            call_cost = input_cost + output_cost

        with self._lock:
            warn = rates is None and not self._warned_unpriced
            if warn:
                self._warned_unpriced = True
            self.total_prompt_tokens += prompt_tokens
            self.total_completion_tokens += completion_tokens
            self.api_calls += 1
            self.total_cost += call_cost

        if warn:
            print(f"Warning: no pricing information for model '{self.model_name}'; "
                  f"cost is recorded as 0.0")
        return call_cost

    def _parse_scores(self, content, expected_count):
        """Turn the model's JSON reply into ``expected_count`` scores clamped to [0, 1].

        Raises an exception (``ValueError``/``TypeError`` or a JSON decoding
        error) when the reply cannot be interpreted; the caller converts that
        into a random fallback.
        """
        response_json = json.loads(content)

        if isinstance(response_json, dict) and "scores" in response_json:
            raw_scores = response_json["scores"]
        elif isinstance(response_json, list):
            # The reply is a bare array
            raw_scores = response_json
        elif isinstance(response_json, dict):
            # Try to find any array of the right length in the response
            raw_scores = next(
                (value for value in response_json.values()
                 if isinstance(value, list) and len(value) == expected_count),
                None,
            )
            if raw_scores is None:
                print(f"Could not find scores in response: {content}")
                raw_scores = self._random_scores(expected_count)
        else:
            raise ValueError(
                f"unexpected JSON payload of type {type(response_json).__name__}"
            )

        if not isinstance(raw_scores, list):
            raise ValueError("'scores' is not a list")

        scores = list(raw_scores)

        # Ensure we have the right number of scores
        if len(scores) != expected_count:
            print(f"Warning: Got {len(scores)} scores for {expected_count} choices")
            if len(scores) < expected_count:
                # Pad with random values rather than a constant to keep variation
                scores.extend(self._random_scores(expected_count - len(scores)))
            else:
                scores = scores[:expected_count]

        # Ensure all scores are between 0 and 1
        scores = [min(max(float(s), 0.0), 1.0) for s in scores]

        # Add small randomness to break ties and avoid all equal scores
        if all(s == scores[0] for s in scores):
            print("All scores are equal, adding random variation")
            scores = [s + random.uniform(-0.05, 0.05) for s in scores]
            scores = [min(max(s, 0.0), 1.0) for s in scores]

        return scores

    def get_scores(self, problem, answer_choices, filter_choices=None):
        """
        Get confidence scores for each answer choice.

        Parameters:
        -----------
        problem : str
            The problem statement
        answer_choices : list
            List of answer choices
        filter_choices : list, optional
            If provided (and shorter than ``answer_choices``), only these
            choices (indices) will be sent to the model

        Returns:
        --------
        tuple
            ``(scores, call_cost, prompt_tokens, completion_tokens)``.
            ``scores`` always has one entry per item of ``answer_choices``
            and is normalized to sum to 1; choices that were filtered out
            get 0.0. When the request itself fails, random scores are
            returned together with ``0.0, 0, 0``. When the reply cannot be
            parsed, random scores are returned with the real cost and
            token counts of the call.
        """
        # Use original choices by default
        choices_to_send = answer_choices
        choice_indices = list(range(len(answer_choices)))

        # If filter_choices is provided, filter the choices
        if filter_choices is not None and len(filter_choices) < len(answer_choices):
            choices_to_send = [answer_choices[i] for i in filter_choices]
            choice_indices = filter_choices

        num_sent = len(choices_to_send)
        num_choices = len(answer_choices)

        user_prompt = (
            f"Question:\n{problem}\n\n"
            f"Answer Choices:\n{json.dumps(choices_to_send)}\n\n"
            "Respond ONLY with a JSON object containing your confidence scores "
            "for these choices, like: {\"scores\": [0.1, 0.8, 0.05, 0.05]}"
        )
        messages = [
            {"role": "system", "content": self._SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ]

        # Only the request and the unpacking of its reply live in this try:
        # accounting problems must not be mistaken for API failures.
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=0.1,
                response_format={"type": "json_object"},
                max_tokens=50  # Keep it small for just scores
            )
            prompt_tokens = response.usage.prompt_tokens
            completion_tokens = response.usage.completion_tokens
            content = response.choices[0].message.content
        except Exception as e:
            print(f"API request error: {e}")
            # Return random scores in case of API error; zero cost for failed requests
            fallback = self._spread_and_normalize(
                self._random_scores(num_sent), choice_indices, num_choices
            )
            return fallback, 0.0, 0, 0

        total_call_cost = self._record_usage(prompt_tokens, completion_tokens)

        try:
            scores = self._parse_scores(content, num_sent)
            if sum(scores) <= 0:
                # Nothing to normalize by: fall back to a random distribution
                scores = self._random_scores(num_sent)
        except Exception as e:
            print(f"Error parsing JSON response: {e}")
            print(f"Raw content: {content}")
            # Use a random distribution to ensure conformal sets don't get too large
            scores = self._random_scores(num_sent)

        full_scores = self._spread_and_normalize(scores, choice_indices, num_choices)
        return full_scores, total_call_cost, prompt_tokens, completion_tokens

    def reset_tracking(self):
        """Reset the tracking counters."""
        with self._lock:
            self.total_prompt_tokens = 0
            self.total_completion_tokens = 0
            self.api_calls = 0
            self.total_cost = 0.0



# Function to load MMLU dataset
def load_mmlu_data(subject=None,
                   num_examples: int = 1000,
                   random_seed: int = 42):
    """Build a shuffled pool of MMLU examples, optionally balanced over subjects.

    Parameters
    ----------
    subject : str | None
        Name of a single MMLU sub-category to load. With *None*, every
        entry of ``MMLU_SUBJECTS`` is loaded and the pool is filled with
        (roughly) the same number of examples from each subject that loaded
        successfully.
    num_examples : int
        Upper bound on the number of examples returned.
    random_seed : int
        Seed applied to Python's global ``random`` module before any
        shuffling, so repeated calls shuffle identically.

    Returns
    -------
    data : dict
        ``{"train": [], "validation": examples, "test": []}``. All examples
        (drawn from whichever of the train/validation/test splits exist)
        are placed under *validation*; each example dict additionally
        carries a ``"subject"`` key.
    subject_name : str
        ``subject`` in single-subject mode, ``"MMLU"`` in all-subjects
        mode, or ``"error"`` (with an empty pool) when anything went wrong.
    """
    import random
    random.seed(random_seed)

    def _collect_examples(ds, subj_name: str):
        # Flatten whichever of the three canonical splits are present,
        # tagging every row with the subject it came from.
        rows = []
        for split_name in ("train", "validation", "test"):
            if split_name not in ds:
                continue
            for row in ds[split_name]:
                row["subject"] = subj_name
                rows.append(row)
        return rows

    def _as_pool(rows):
        # The whole pool lives under "validation"; the other splits stay empty.
        return {"train": [], "validation": rows, "test": []}

    try:
        if subject is not None:
            # Single-subject mode
            pool = _collect_examples(load_dataset("cais/mmlu", subject), subject)
            if not pool:
                raise ValueError(f"No examples found for subject '{subject}'.")

            random.shuffle(pool)
            return _as_pool(pool[: min(num_examples, len(pool))]), subject

        # All-subjects mode with balanced sampling
        rows_by_subject = {}
        loaded_subjects = []

        for name in MMLU_SUBJECTS:
            try:
                found = _collect_examples(load_dataset("cais/mmlu", name), name)
            except Exception as e:
                print(f"Skipping subject '{name}' due to error: {e}")
                continue

            if not found:
                print(f"No examples found in {name}.")
                continue

            rows_by_subject[name] = found
            loaded_subjects.append(name)
            print(f"Found {len(found)} examples in {name}.")

        if not loaded_subjects:
            raise RuntimeError("Failed to load any MMLU subjects.")

        # Even share per subject; the first `extra` subjects get one more
        # so that the shares add up to num_examples.
        base_quota, extra = divmod(num_examples, len(loaded_subjects))

        print(f"Targeting {base_quota} examples per subject across {len(loaded_subjects)} subjects.")

        pooled = []
        for position, name in enumerate(loaded_subjects):
            quota = base_quota + 1 if position < extra else base_quota

            candidates = rows_by_subject[name]
            random.shuffle(candidates)

            # A subject with fewer rows than its quota contributes all of them.
            picked = candidates[:min(quota, len(candidates))]
            pooled.extend(picked)
            print(f"Added {len(picked)} examples from {name}.")

        random.shuffle(pooled)

        print(f"Created balanced dataset with {len(pooled)} total examples "
              f"across {len(loaded_subjects)} subjects.")

        # Report how the final pool is distributed over subjects.
        tally = {}
        for row in pooled:
            tally.setdefault(row["subject"], 0)
            tally[row["subject"]] += 1

        print("Final distribution of examples by subject:")
        for name, count in sorted(tally.items(), key=lambda pair: -pair[1]):
            print(f"  {name}: {count} examples ({count/len(pooled):.1%})")

        return _as_pool(pooled), "MMLU"

    except Exception as exc:
        print(f"Error loading MMLU data: {exc}")
        return {"train": [], "validation": [], "test": []}, "error"

# Function to process MMLU problems
def process_mmlu_problem(example):
    """Convert a raw MMLU example into the record used by the conformal pipeline.

    Reads the ``"question"``, ``"choices"`` and ``"answer"`` fields of
    ``example`` and returns a dict with keys ``"problem"`` (the question),
    ``"choices"`` (a new list holding the answer options) and
    ``"correct_index"`` (the value of ``"answer"``).
    """
    return {
        "problem": example["question"],
        # Copy so the returned list is independent of the source example.
        "choices": list(example["choices"]),
        "correct_index": example["answer"],
    }


def process_single_calibration_example(args):
    """Score one calibration example with both models (worker for the thread pool).

    ``args`` is a 4-tuple ``(index, example, small_model, large_model)``.
    Both models are queried through ``get_scores`` on the full list of
    choices, small model first, with a half-second pause in between.
    Returns a dict with the question (``"context"``), the choices, both
    score lists, the correct index, and each call's cost and token counts.
    """
    index, example, small_model, large_model = args

    item = process_mmlu_problem(example)
    question = item["problem"]
    options = item["choices"]

    print(f"Getting scores for calibration example {index+1}")

    small_scores, small_cost, small_in, small_out = small_model.get_scores(question, options)
    print(f"Small model scores: {small_scores}")

    # Pause between the two requests to stay clear of rate limits.
    time.sleep(0.5)

    large_scores, large_cost, large_in, large_out = large_model.get_scores(question, options)
    print(f"Large model scores: {large_scores}")

    record = {
        "context": question,
        "choices": options,
        "small_scores": small_scores,
        "large_scores": large_scores,
        "correct_index": item["correct_index"],
        "small_cost": small_cost,
        "large_cost": large_cost,
    }
    record["small_tokens"] = {"prompt": small_in, "completion": small_out}
    record["large_tokens"] = {"prompt": large_in, "completion": large_out}
    return record

def create_calibration_dataset_parallel(small_model, large_model, mmlu_data, num_samples=300, max_workers=4):
    """Create a calibration dataset with scores from both models using a thread pool.

    Parameters
    ----------
    small_model, large_model
        Model wrappers; their tracking counters are reset before scoring.
    mmlu_data : dict
        Data dict whose ``"validation"`` list is the pool to draw from.
    num_samples : int
        Maximum number of examples to score. A larger pool is randomly
        subsampled with ``random.sample``.
    max_workers : int
        Size of the thread pool.

    Returns
    -------
    list
        One result dict per successfully processed example, ordered by the
        example's position in the (sub)sampled pool. Examples whose worker
        raised are reported and left out.
    """
    # Reset tracking
    small_model.reset_tracking()
    large_model.reset_tracking()

    # Use a subset of validation data for calibration
    validation_sample = mmlu_data["validation"]
    if len(validation_sample) > num_samples:
        # Randomly sample to ensure diversity
        validation_sample = random.sample(validation_sample, num_samples)

    # Results keyed by the example's original position, so the final order
    # does not depend on which request happened to finish first.
    results_by_index = {}

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_index = {
            executor.submit(
                process_single_calibration_example,
                (i, example, small_model, large_model),
            ): i
            for i, example in enumerate(validation_sample)
        }

        # Process results as they complete
        for future in tqdm(concurrent.futures.as_completed(future_to_index),
                           total=len(future_to_index), desc="Processing calibration examples"):
            idx = future_to_index[future]
            try:
                results_by_index[idx] = future.result()
            except Exception as exc:
                print(f"Processing example {idx} generated an exception: {exc}")

    # Restore the original order using the index we already have, instead of
    # searching the sample for a matching question text for every result.
    return [results_by_index[i] for i in sorted(results_by_index)]

def compute_standard_conformal_threshold(calibration_data, alpha=0.1):
    """Compute the standard conformal threshold lambda for a given risk level alpha.

    For each lambda on a 101-point grid over [0, 1], every calibration
    sample gets a conformal set made of the choices whose small-model score
    is within ``lambda`` of the small model's top score. A sample incurs
    loss 1 only when the correct answer is outside that set *and* the large
    model's top choice (over all options) is correct. The selected lambda
    is the first grid value whose bound
    ``n/(n+1) * avg_loss + 1/(n+1)`` does not exceed ``alpha``; if none
    qualifies, the largest grid value is used and a warning is printed.

    Parameters
    ----------
    calibration_data : list
        Dicts with keys ``"small_scores"``, ``"large_scores"`` and
        ``"correct_index"``.
    alpha : float
        Target risk level.

    Returns
    -------
    tuple
        ``(selected_lambda, lambda_metrics)`` where ``lambda_metrics`` has
        one dict per grid value with keys ``"lambda"``, ``"avg_loss"``,
        ``"avg_set_size"``, ``"small_model_usage"`` (fraction of samples
        whose set has exactly one element) and ``"risk_bound"``.

    Raises
    ------
    ValueError
        If ``calibration_data`` is empty.
    """
    n = len(calibration_data)
    if n == 0:
        raise ValueError("calibration_data must contain at least one example")

    lambda_values = np.linspace(0, 1, 101)

    # Quantities that do not depend on lambda are computed once per sample
    # rather than once per (sample, lambda) pair.
    prepared = []
    for sample in calibration_data:
        small_scores = sample["small_scores"]
        correct_index = sample["correct_index"]
        # Large model's choice from all options
        large_model_correct = bool(np.argmax(sample["large_scores"]) == correct_index)
        prepared.append((small_scores, max(small_scores), correct_index, large_model_correct))

    # Store detailed metrics for each lambda
    lambda_metrics = []

    for lam in lambda_values:
        loss_sum = 0
        total_set_size = 0
        small_model_usage_count = 0

        for small_scores, max_score, correct_index, large_model_correct in prepared:
            # Construct conformal action set
            conformal_set = [i for i, score in enumerate(small_scores) if score >= max_score - lam]
            total_set_size += len(conformal_set)

            # A singleton set is a case where we'd use the small model
            if len(conformal_set) == 1:
                small_model_usage_count += 1

            # Loss only when the set misses the correct answer AND the large
            # model would have answered correctly with all options available;
            # if the large model is wrong anyway, nothing was lost.
            if correct_index not in conformal_set and large_model_correct:
                loss_sum += 1

        avg_loss = loss_sum / n

        lambda_metrics.append({
            "lambda": lam,
            "avg_loss": avg_loss,
            "avg_set_size": total_set_size / n,
            "small_model_usage": small_model_usage_count / n,
            "risk_bound": (n / (n + 1)) * avg_loss + (1 / (n + 1))
        })

    # Find the smallest lambda that satisfies the risk constraint
    selected_idx = next(
        (i for i, metrics in enumerate(lambda_metrics) if metrics["risk_bound"] <= alpha),
        None,
    )

    if selected_idx is None:
        print("Warning: Could not find a suitable lambda, using the maximum value")
        selected_idx = len(lambda_values) - 1

    selected_lambda = lambda_values[selected_idx]

    # Print detailed information about lambda selection
    print("\nLambda Selection Details:")
    print(f"Target alpha: {alpha}")
    print(f"Selected lambda: {selected_lambda:.4f}")

    # Show the grid points immediately around the selected one
    for i in range(max(0, selected_idx-2), min(len(lambda_metrics), selected_idx+3)):
        metrics = lambda_metrics[i]
        print(f"λ={metrics['lambda']:.4f}: Loss={metrics['avg_loss']:.4f}, Risk={metrics['risk_bound']:.4f}, SmallUsage={metrics['small_model_usage']:.2%}")

    return selected_lambda, lambda_metrics


def precompute_test_scores(test_examples, small_model, large_model, max_workers=4):
    """Precompute scores from both models for all test examples.

    Each example is scored once by the small model and once by the large
    model (both on the full list of choices) in a thread pool.

    Parameters
    ----------
    test_examples : list
        Raw MMLU examples.
    small_model, large_model
        Model wrappers exposing ``get_scores``.
    max_workers : int
        Size of the thread pool.

    Returns
    -------
    list
        One dict per successfully processed example, in the same order as
        ``test_examples`` (failed examples are reported and left out).
        Each dict holds the problem, choices, correct index, both score
        lists, both argmax predictions, and each call's cost and token
        counts.
    """

    def process_test_example(args):
        # Worker: score a single example with both models.
        _idx, example, small_model, large_model = args
        processed = process_mmlu_problem(example)

        # Get scores from both models
        small_scores, small_cost, small_prompt_tokens, small_completion_tokens = small_model.get_scores(
            processed["problem"], processed["choices"]
        )

        large_scores, large_cost, large_prompt_tokens, large_completion_tokens = large_model.get_scores(
            processed["problem"], processed["choices"]
        )

        return {
            "problem": processed["problem"],
            "choices": processed["choices"],
            "correct_index": processed["correct_index"],
            "small_scores": small_scores,
            "large_scores": large_scores,
            "small_pred": np.argmax(small_scores),
            "large_pred": np.argmax(large_scores),
            "small_cost": small_cost,
            "large_cost": large_cost,
            "small_tokens": {"prompt": small_prompt_tokens, "completion": small_completion_tokens},
            "large_tokens": {"prompt": large_prompt_tokens, "completion": large_completion_tokens}
        }

    # Results keyed by input position so the returned order is deterministic
    # instead of following request completion order.
    results_by_idx = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_idx = {
            executor.submit(process_test_example, (i, example, small_model, large_model)): i
            for i, example in enumerate(test_examples)
        }

        for future in tqdm(concurrent.futures.as_completed(future_to_idx),
                           total=len(future_to_idx), desc="Precomputing test example scores"):
            idx = future_to_idx[future]
            try:
                results_by_idx[idx] = future.result()

                # NOTE(review): this pause runs in the collecting thread after
                # all tasks were already submitted, so it only paces result
                # collection; it does not throttle the workers' API requests.
                time.sleep(0.1)
            except Exception as exc:
                print(f"Processing test example {idx} generated an exception: {exc}")

    return [results_by_idx[i] for i in sorted(results_by_idx)]

def calculate_random_baseline(test_data, small_model_usage_fraction, num_samples=100):
    """
    Simulate a router that picks the small model at random.

    In each of ``num_samples`` trials, every example is answered by the
    small model with probability ``small_model_usage_fraction`` and by the
    large model otherwise, using the precomputed ``small_pred`` /
    ``large_pred`` and ``small_cost`` / ``large_cost`` fields.

    Returns ``((accuracy_mean, accuracy_std), (cost_mean, cost_std))``,
    where accuracy and average per-example cost are aggregated over trials.
    """
    import numpy as np
    import random

    trial_accuracies = []
    trial_costs = []

    for _ in range(num_samples):
        hits = 0
        spent = 0.0

        for record in test_data:
            small_pred, large_pred = record["small_pred"], record["large_pred"]
            truth = record["correct_index"]
            small_cost, large_cost = record["small_cost"], record["large_cost"]

            # One random draw per example decides which model answers it.
            if random.random() < small_model_usage_fraction:
                chosen_pred, chosen_cost = small_pred, small_cost
            else:
                chosen_pred, chosen_cost = large_pred, large_cost

            spent += chosen_cost
            hits += 1 if chosen_pred == truth else 0

        # Per-trial accuracy and average cost per example
        trial_accuracies.append(hits / len(test_data))
        trial_costs.append(spent / len(test_data))

    accuracy_stats = (np.mean(trial_accuracies), np.std(trial_accuracies))
    cost_stats = (np.mean(trial_costs), np.std(trial_costs))
    return accuracy_stats, cost_stats

def calculate_cost_matched_random_baseline(test_data, target_cost, small_cost_avg, large_cost_avg, num_samples=100):
    """
    Simulate a random router whose expected cost matches ``target_cost``.

    Parameters:
    -----------
    test_data : list
        Test examples with precomputed ``small_pred``, ``large_pred``,
        ``correct_index``, ``small_cost`` and ``large_cost`` fields
    target_cost : float
        Average cost per example the baseline should match
    small_cost_avg : float
        Average cost of small model per example
    large_cost_avg : float
        Average cost of large model per example
    num_samples : int
        Number of random trials to average over

    Returns:
    --------
    tuple
        ``((accuracy_mean, accuracy_std), (cost_mean, cost_std),
        small_model_fraction)``
    """
    import numpy as np
    import random

    # Solve  target = p * small + (1 - p) * large  for p, the probability
    # of routing an example to the small model.
    if large_cost_avg == small_cost_avg:
        # Identical costs make p undetermined: use an even split.
        small_model_fraction = 0.5
    else:
        small_model_fraction = (large_cost_avg - target_cost) / (large_cost_avg - small_cost_avg)

    # A target outside the achievable range is clamped to all-small / all-large.
    small_model_fraction = max(0.0, min(1.0, small_model_fraction))

    trial_accuracies = []
    trial_costs = []

    for _ in range(num_samples):
        hits = 0
        spent = 0.0

        for record in test_data:
            small_pred, large_pred = record["small_pred"], record["large_pred"]
            truth = record["correct_index"]
            small_cost, large_cost = record["small_cost"], record["large_cost"]

            # One random draw per example decides which model answers it.
            if random.random() < small_model_fraction:
                chosen_pred, chosen_cost = small_pred, small_cost
            else:
                chosen_pred, chosen_cost = large_pred, large_cost

            spent += chosen_cost
            hits += 1 if chosen_pred == truth else 0

        # Per-trial accuracy and average cost per example
        trial_accuracies.append(hits / len(test_data))
        trial_costs.append(spent / len(test_data))

    accuracy_stats = (np.mean(trial_accuracies), np.std(trial_accuracies))
    cost_stats = (np.mean(trial_costs), np.std(trial_costs))
    return accuracy_stats, cost_stats, small_model_fraction

def run_conformal_evaluation(
    alphas=[0.1, 0.2, 0.3, 0.4, 0.5], 
    api_key=None, 
    max_workers=4,
    num_trials=5,
    calibration_size=300,
    total_examples=1500,
    random_seed=42, 
    results_dir=None,
    subject=None,
    include_random_baseline=True,
    include_unrestricted_hybrid=True):
    """
    Run conformal evaluation with properly independent trials.
    Now includes cost tracking and unrestricted hybrid option.
    """
    print(f"Starting revised conformal evaluation with {num_trials} independent trials")
    print(f"Method: Standard conformal")
    
    # Create results directory first
    method_name = "standard"
    if results_dir is None:
        # We'll update with actual subject name after loading data
        results_dir = f"mmlu_results_{method_name}_revised"
    os.makedirs(results_dir, exist_ok=True)
    
    # Load the full dataset
    full_mmlu_data, subject_name = load_mmlu_data(subject=subject, num_examples=total_examples * 2)  # Request extra to ensure enough examples
    print(f"Subject: {subject_name}")
    
    # Update results directory with subject name
    if results_dir == f"mmlu_results_{method_name}_revised":
        results_dir = f"mmlu_results_{method_name}_revised_{subject_name}"
        os.makedirs(results_dir, exist_ok=True)
    
    # Combine all examples for the pool
    all_available_examples = full_mmlu_data["train"] + full_mmlu_data["validation"] + full_mmlu_data["test"]
    print(f"Total available examples: {len(all_available_examples)}")
    
    # Results storage for trial statistics
    trial_results = []
    small_model_accuracies = []
    large_model_accuracies = []
    small_model_costs = []
    large_model_costs = []
    
    # Hybrid model statistics
    hybrid_accuracies = {alpha: [] for alpha in alphas}
    large_model_calls = {alpha: [] for alpha in alphas}
    lambda_values = {alpha: [] for alpha in alphas}
    hybrid_costs = {alpha: [] for alpha in alphas}
    
    # Unrestricted hybrid statistics (if enabled)
    if include_unrestricted_hybrid:
        unrestricted_hybrid_accuracies = {alpha: [] for alpha in alphas}
        unrestricted_hybrid_costs = {alpha: [] for alpha in alphas}
        unrestricted_random_baseline_accuracies = {alpha: [] for alpha in alphas}
        unrestricted_random_baseline_costs = {alpha: [] for alpha in alphas}
        unrestricted_random_baseline_sm_fractions = {alpha: [] for alpha in alphas}
    
    # Random baseline statistics (if enabled)
    if include_random_baseline:
        random_baseline_accuracies = {alpha: [] for alpha in alphas}
        random_baseline_costs = {alpha: [] for alpha in alphas}
        random_baseline_sm_fractions = {alpha: [] for alpha in alphas}
    
    # For each trial
    for trial_idx in range(num_trials):
        print(f"\n{'='*30}")
        print(f" TRIAL {trial_idx+1}/{num_trials} ")
        print(f"{'='*30}")
        
        # Set trial-specific random seed for reproducibility
        trial_seed = random_seed + trial_idx
        random.seed(trial_seed)
        
        # Sample a subset of examples for this trial
        if len(all_available_examples) <= total_examples:
            trial_pool = all_available_examples.copy()
            print(f"Using all {len(trial_pool)} available examples for this trial")
        else:
            trial_pool = random.sample(all_available_examples, total_examples)
            print(f"Sampled {len(trial_pool)} examples for this trial")
        
        # Initialize models for this trial
        small_model = OpenAIWrapper("gpt-4.1-nano-2025-04-14", api_key)
        large_model = OpenAIWrapper("gpt-4.1-2025-04-14", api_key)
        
        # First check if we already have scored examples for this trial
        scored_examples_file = os.path.join(results_dir, f"scored_examples_trial_{trial_idx+1}.json")
        if os.path.exists(scored_examples_file):
            print(f"Found pre-computed scores for trial {trial_idx+1}, loading from {scored_examples_file}")
            try:
                with open(scored_examples_file, "r") as f:
                    saved_data = json.load(f)
                    if "examples" in saved_data and len(saved_data["examples"]) > 0:
                        print(f"Loaded {len(saved_data['examples'])} pre-computed examples")
                        scored_examples = saved_data["examples"]
                    else:
                        print("No examples found in saved file, computing scores...")
                        scored_examples = precompute_test_scores(
                            trial_pool,
                            small_model,
                            large_model,
                            max_workers=max_workers
                        )
            except Exception as e:
                print(f"Error loading pre-computed scores: {e}")
                print("Computing scores from scratch...")
                scored_examples = precompute_test_scores(
                    trial_pool,
                    small_model,
                    large_model,
                    max_workers=max_workers
                )
        else:
            # Score all examples in this trial's pool
            print(f"No pre-computed scores found for trial {trial_idx+1}, computing scores...")
            scored_examples = precompute_test_scores(
                trial_pool,
                small_model,
                large_model,
                max_workers=max_workers
            )
        
        # Calculate trial-specific baseline accuracies and costs
        small_correct = sum(1 for example in scored_examples if example["small_pred"] == example["correct_index"])
        large_correct = sum(1 for example in scored_examples if example["large_pred"] == example["correct_index"])
        
        small_acc = small_correct / len(scored_examples)
        large_acc = large_correct / len(scored_examples)
        
        small_avg_cost = sum(example["small_cost"] for example in scored_examples) / len(scored_examples)
        large_avg_cost = sum(example["large_cost"] for example in scored_examples) / len(scored_examples)
        
        small_model_accuracies.append(small_acc)
        large_model_accuracies.append(large_acc)
        small_model_costs.append(small_avg_cost)
        large_model_costs.append(large_avg_cost)
        
        print(f"Trial {trial_idx+1} baseline accuracies and costs:")
        print(f"  Small model: {small_acc:.4f}, Cost: ${small_avg_cost:.6f} per example")
        print(f"  Large model: {large_acc:.4f}, Cost: ${large_avg_cost:.6f} per example")
        
        # Save scored examples for this trial
        scored_examples_file = os.path.join(results_dir, f"scored_examples_trial_{trial_idx+1}.json")
        with open(scored_examples_file, "w") as f:
            json.dump({
                "trial": trial_idx + 1,
                "trial_seed": trial_seed,
                "subject": subject_name,
                "small_model_accuracy": small_acc,
                "small_model_cost": small_avg_cost,
                "large_model_accuracy": large_acc,
                "large_model_cost": large_avg_cost,
                "examples": scored_examples
            }, f, default=custom_json_serializer)
        
        # Shuffle examples before splitting
        random.shuffle(scored_examples)
        
        # Split into calibration and evaluation sets
        if calibration_size >= len(scored_examples):
            print(f"Warning: Requested calibration size {calibration_size} exceeds available data {len(scored_examples)}")
            calibration_size = len(scored_examples) // 2
            
        calibration_examples = scored_examples[:calibration_size]
        test_examples = scored_examples[calibration_size:]
        
        # Save the data split information for reproducibility
        data_split_info = {
            "trial": trial_idx + 1,
            "trial_seed": trial_seed,
            "subject": subject_name,
            "total_examples": len(scored_examples),
            "calibration_size": len(calibration_examples),
            "test_size": len(test_examples),
            "small_model_accuracy": small_acc,
            "small_model_cost": small_avg_cost,
            "large_model_accuracy": large_acc,
            "large_model_cost": large_avg_cost
        }
        
        # Save data split info to file
        split_file = os.path.join(results_dir, f"data_split_trial_{trial_idx+1}.json")
        with open(split_file, "w") as f:
            json.dump(data_split_info, f, default=custom_json_serializer)
        
        print(f"Using {len(calibration_examples)} examples for calibration and {len(test_examples)} for testing")
        
        # Check if we already have calibration data for this trial
        calibration_file = os.path.join(results_dir, f"calibration_data_{method_name}_trial_{trial_idx+1}.json")
        if os.path.exists(calibration_file):
            print(f"Found pre-computed calibration data for trial {trial_idx+1}, loading from {calibration_file}")
            try:
                with open(calibration_file, "r") as f:
                    calibration_data = json.load(f)
                print(f"Loaded {len(calibration_data)} pre-computed calibration examples")
            except Exception as e:
                print(f"Error loading pre-computed calibration data: {e}")
                print("Creating calibration data from scratch...")
                # Extract calibration data in the format needed for conformal prediction
                calibration_data = []
                for example in calibration_examples:
                    calibration_data.append({
                        "context": example["problem"],
                        "choices": example["choices"],
                        "small_scores": example["small_scores"],
                        "large_scores": example["large_scores"],
                        "correct_index": example["correct_index"],
                        "small_cost": example["small_cost"],
                        "large_cost": example["large_cost"]
                    })
                # Save calibration data for future reference
                with open(calibration_file, "w") as f:
                    json.dump(calibration_data, f, default=custom_json_serializer)
        else:
            print(f"Creating calibration data for trial {trial_idx+1}")
            # Extract calibration data in the format needed for conformal prediction
            calibration_data = []
            for example in calibration_examples:
                calibration_data.append({
                    "context": example["problem"],
                    "choices": example["choices"],
                    "small_scores": example["small_scores"],
                    "large_scores": example["large_scores"],
                    "correct_index": example["correct_index"],
                    "small_cost": example["small_cost"],
                    "large_cost": example["large_cost"]
                })
            # Save calibration data for future reference
            with open(calibration_file, "w") as f:
                json.dump(calibration_data, f, default=custom_json_serializer)
        
        # We will use test_examples directly since they already have all the scores we need
        test_data = test_examples
        
        # Check if we already have detailed trial results (which would include thresholds)
        detailed_trial_file = os.path.join(results_dir, f"detailed_trial_{trial_idx+1}_results.json")
        if os.path.exists(detailed_trial_file):
            print(f"Found existing detailed results for trial {trial_idx+1}, attempting to load thresholds...")
            try:
                with open(detailed_trial_file, "r") as f:
                    trial_result = json.load(f)
                
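                # Try to extract lambda values from existing results.
                # NOTE(review): json.load returns string keys (e.g. "0.1"), so the
                # `alpha in ...` membership test below fails for numeric alphas and
                # the threshold is recomputed instead of reused — confirm and fix.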
                if "hybrid_results" in trial_result:
                    standard_lambda_thresholds = {}
                    for alpha in alphas:
                        if alpha in trial_result["hybrid_results"] and "avg_lambda" in trial_result["hybrid_results"][alpha]:
                            # Use the stored lambda value
                            standard_lambda_thresholds[alpha] = trial_result["hybrid_results"][alpha]["avg_lambda"]
                            print(f"Loaded threshold for α={alpha}: {standard_lambda_thresholds[alpha]:.4f}")
                        else:
                            # Compute it if not found
                            lambda_threshold, _ = compute_standard_conformal_threshold(calibration_data, alpha=alpha)
                            standard_lambda_thresholds[alpha] = lambda_threshold
                            print(f"Computed threshold for α={alpha}: {lambda_threshold:.4f}")
                else:
                    # Compute all thresholds if not found
                    standard_lambda_thresholds = {}
                    for alpha in alphas:
                        lambda_threshold, _ = compute_standard_conformal_threshold(calibration_data, alpha=alpha)
                        standard_lambda_thresholds[alpha] = lambda_threshold
                        print(f"Standard lambda threshold for α={alpha}: {lambda_threshold:.4f}")
            except Exception as e:
                print(f"Error loading existing thresholds: {e}")
                print("Computing thresholds from scratch...")
                # Compute thresholds if loading failed
                standard_lambda_thresholds = {}
                for alpha in alphas:
                    lambda_threshold, _ = compute_standard_conformal_threshold(calibration_data, alpha=alpha)
                    standard_lambda_thresholds[alpha] = lambda_threshold
                    print(f"Standard lambda threshold for α={alpha}: {lambda_threshold:.4f}")
        else:
            # Compute thresholds if no file exists
            print("Computing conformal thresholds...")
            standard_lambda_thresholds = {}
            for alpha in alphas:
                lambda_threshold, _ = compute_standard_conformal_threshold(calibration_data, alpha=alpha)
                standard_lambda_thresholds[alpha] = lambda_threshold
                print(f"Standard lambda threshold for α={alpha}: {lambda_threshold:.4f}")
        
        # For each alpha, evaluate hybrid model performance using precomputed scores
        trial_hybrid_results = {}
        
        for alpha in alphas:
            print(f"\nEvaluating hybrid model with alpha = {alpha}")
            
            # Evaluate hybrid model and unrestricted hybrid model (if enabled) on each test example
            hybrid_results = []
            unrestricted_hybrid_results = [] if include_unrestricted_hybrid else None
            
            for example in tqdm(test_data, desc=f"Evaluating hybrid model (α={alpha})"):
                # Extract precomputed scores and data
                small_scores = example["small_scores"]
                large_scores = example["large_scores"]
                correct_index = example["correct_index"]
                small_cost = example["small_cost"]
                large_cost = example["large_cost"]
                choices = example["choices"]
                problem = example["problem"]
                
                # Use precomputed standard lambda threshold
                lambda_threshold = standard_lambda_thresholds[alpha]
                
                # Construct conformal action set
                max_score = max(small_scores)
                conformal_set = [i for i, score in enumerate(small_scores) if score >= max_score - lambda_threshold]
                
                # === STANDARD HYBRID MODEL ===
                # If only one action in conformal set, choose it directly (use small model)
                if len(conformal_set) == 1:
                    hybrid_pred = conformal_set[0]
                    used_large_model = False
                    example_cost = small_cost  # Only small model cost
                else:
                    # We need to simulate getting refined scores from the large model
                    # but only for choices in the conformal set
                    filtered_choices = [choices[i] for i in conformal_set]
                    
                    # Since we've already precomputed all large model scores, we can just filter them
                    # But in a real scenario, we would make a new API call with filtered choices
                    filtered_large_scores = [large_scores[i] for i in conformal_set]
                    
                    # Find best choice
                    best_idx_in_conformal = np.argmax(filtered_large_scores)
                    hybrid_pred = conformal_set[best_idx_in_conformal]
                    used_large_model = True
                    
                    # For cost calculation, we need to estimate what the cost would be
                    # if we had made the API call with only the filtered choices
                    # We'll assume half of the prompt tokens are fixed and the other
                    # half scales linearly with the number of choices kept,
                    # and output tokens remain constant
                    
                    # Calculate the scaling factor based on number of choices
                    scaling_factor = len(conformal_set) / len(choices)
                    
                    # Calculate the estimated cost of a large model call with filtered choices
                    large_prompt_tokens = example["large_tokens"]["prompt"]
                    large_completion_tokens = example["large_tokens"]["completion"]
                    
                    # Estimate tokens for filtered call
                    filtered_prompt_tokens = int(large_prompt_tokens * (0.5 + 0.5 * scaling_factor))
                    filtered_completion_tokens = large_completion_tokens
                    
                    # Recalculate cost using the pricing structure in the OpenAIWrapper
                    # Get the pricing model
                    pricing_model = "gpt-4.1-2025-04-14"  # The large model
                    
                    # Calculate cost (convert from per 1M tokens to per token)
                    pricing = {"input": 2.00, "output": 8.00}  # Hardcoded for simplicity
                    filtered_input_cost = filtered_prompt_tokens * (pricing["input"] / 1000000)
                    filtered_output_cost = filtered_completion_tokens * (pricing["output"] / 1000000)
                    filtered_large_cost = filtered_input_cost + filtered_output_cost
                    
                    # Total cost is small model + filtered large model
                    example_cost = small_cost + filtered_large_cost
                
                hybrid_correct = (hybrid_pred == correct_index)
                
                hybrid_results.append({
                    "hybrid_correct": hybrid_correct,
                    "used_large_model": used_large_model,
                    "lambda_value": lambda_threshold,
                    "cost": example_cost,
                    "conformal_set_size": len(conformal_set)
                })
                
                # === UNRESTRICTED HYBRID MODEL (if enabled) ===
                if include_unrestricted_hybrid:
                    # Unrestricted hybrid always queries the small model first and
                    # uses the conformal set size to decide whether to escalate;
                    # when it escalates, the large model chooses among all options
                    
                    if len(conformal_set) == 1:
                        # Same as regular hybrid when conformal set has one element
                        unrestricted_pred = conformal_set[0]
                        unrestricted_used_large = False
                        unrestricted_cost = small_cost
                    else:
                        # Use large model with all choices (already precomputed)
                        unrestricted_pred = np.argmax(large_scores)
                        unrestricted_used_large = True
                        unrestricted_cost = small_cost + large_cost
                    
                    unrestricted_hybrid_correct = (unrestricted_pred == correct_index)
                    
                    unrestricted_hybrid_results.append({
                        "hybrid_correct": unrestricted_hybrid_correct,
                        "used_large_model": unrestricted_used_large,
                        "lambda_value": lambda_threshold,
                        "cost": unrestricted_cost,
                        "conformal_set_size": len(conformal_set)
                    })
            
            # Calculate hybrid model metrics
            hybrid_correct_count = sum(1 for r in hybrid_results if r["hybrid_correct"])
            large_model_used_count = sum(1 for r in hybrid_results if r["used_large_model"])
            avg_lambda = sum(r["lambda_value"] for r in hybrid_results) / len(hybrid_results)
            avg_cost = sum(r["cost"] for r in hybrid_results) / len(hybrid_results)
            
            hybrid_accuracy = hybrid_correct_count / len(test_data)
            large_model_usage = large_model_used_count / len(test_data)
            small_model_usage = 1.0 - large_model_usage
            
            hybrid_accuracies[alpha].append(hybrid_accuracy)
            large_model_calls[alpha].append(large_model_usage)
            lambda_values[alpha].append(avg_lambda)
            hybrid_costs[alpha].append(avg_cost)
            
            print(f"Hybrid model accuracy (α={alpha}): {hybrid_accuracy:.4f}")
            print(f"Large model usage: {large_model_usage:.2%}")
            print(f"Small model usage: {small_model_usage:.2%}")
            print(f"Average lambda: {avg_lambda:.4f}")
            print(f"Average cost per example: ${avg_cost:.6f}")
            
            # Store results for this alpha
            trial_hybrid_results[alpha] = {
                "accuracy": hybrid_accuracy,
                "large_model_usage": large_model_usage,
                "small_model_usage": small_model_usage,
                "avg_lambda": avg_lambda,
                "avg_cost": avg_cost
            }
            
            # Calculate unrestricted hybrid metrics (if enabled)
            if include_unrestricted_hybrid:
                unrestricted_correct_count = sum(1 for r in unrestricted_hybrid_results if r["hybrid_correct"])
                unrestricted_large_model_used_count = sum(1 for r in unrestricted_hybrid_results if r["used_large_model"])
                unrestricted_avg_cost = sum(r["cost"] for r in unrestricted_hybrid_results) / len(unrestricted_hybrid_results)
                
                unrestricted_accuracy = unrestricted_correct_count / len(test_data)
                unrestricted_large_model_usage = unrestricted_large_model_used_count / len(test_data)
                
                unrestricted_hybrid_accuracies[alpha].append(unrestricted_accuracy)
                unrestricted_hybrid_costs[alpha].append(unrestricted_avg_cost)
                
                print(f"Unrestricted hybrid accuracy (α={alpha}): {unrestricted_accuracy:.4f}")
                print(f"Unrestricted large model usage: {unrestricted_large_model_usage:.2%}")
                print(f"Unrestricted avg cost per example: ${unrestricted_avg_cost:.6f}")
                
                # Add unrestricted hybrid results
                trial_hybrid_results[alpha]["unrestricted"] = {
                    "accuracy": unrestricted_accuracy,
                    "large_model_usage": unrestricted_large_model_usage,
                    "avg_cost": unrestricted_avg_cost
                }
            
            # Calculate cost-matched random baseline for conformal method
            if include_random_baseline:
                print(f"Calculating cost-matched random baseline for conformal method (α={alpha})")
                (random_avg, random_std), (random_cost_avg, random_cost_std), sm_fraction = calculate_cost_matched_random_baseline(
                    test_data, avg_cost, small_avg_cost, large_avg_cost, num_samples=10
                )
                random_baseline_accuracies[alpha].append(random_avg)
                random_baseline_costs[alpha].append(random_cost_avg)
                random_baseline_sm_fractions[alpha].append(sm_fraction)
                
                print(f"Random baseline accuracy: {random_avg:.4f} ± {random_std:.4f}")
                print(f"Random baseline avg cost: ${random_cost_avg:.6f} ± ${random_cost_std:.6f}")
                print(f"Small model fraction: {sm_fraction:.2%}")
                
                # Add to trial results
                trial_hybrid_results[alpha]["random_baseline"] = {
                    "accuracy": random_avg,
                    "accuracy_std": random_std,
                    "avg_cost": random_cost_avg,
                    "cost_std": random_cost_std,
                    "small_model_fraction": sm_fraction
                }
                
                # If unrestricted hybrid is enabled, calculate a baseline for it too
                if include_unrestricted_hybrid:
                    print(f"Calculating cost-matched random baseline for unrestricted hybrid (α={alpha})")
                    (unrestricted_random_avg, unrestricted_random_std), (unrestricted_random_cost_avg, unrestricted_random_cost_std), unrestricted_sm_fraction = calculate_cost_matched_random_baseline(
                        test_data, unrestricted_avg_cost, small_avg_cost, large_avg_cost, num_samples=10
                    )
                    unrestricted_random_baseline_accuracies[alpha].append(unrestricted_random_avg)
                    unrestricted_random_baseline_costs[alpha].append(unrestricted_random_cost_avg)
                    unrestricted_random_baseline_sm_fractions[alpha].append(unrestricted_sm_fraction)
                    
                    print(f"Unrestricted random baseline accuracy: {unrestricted_random_avg:.4f} ± {unrestricted_random_std:.4f}")
                    print(f"Unrestricted random baseline avg cost: ${unrestricted_random_cost_avg:.6f} ± ${unrestricted_random_cost_std:.6f}")
                    print(f"Unrestricted small model fraction: {unrestricted_sm_fraction:.2%}")
                    
                    # Add to trial results
                    trial_hybrid_results[alpha]["unrestricted_random_baseline"] = {
                        "accuracy": unrestricted_random_avg,
                        "accuracy_std": unrestricted_random_std,
                        "avg_cost": unrestricted_random_cost_avg,
                        "cost_std": unrestricted_random_cost_std,
                        "small_model_fraction": unrestricted_sm_fraction
                    }
        
        # Save all results for this trial
        trial_results.append({
            "trial": trial_idx + 1,
            "trial_seed": trial_seed,
            "subject": subject_name,
            "method": "standard",
            "calibration_size": len(calibration_examples),
            "test_size": len(test_examples),
            "small_model_accuracy": small_acc,
            "large_model_accuracy": large_acc,
            "small_model_cost": small_avg_cost,
            "large_model_cost": large_avg_cost,
            "hybrid_results": trial_hybrid_results
        })
        
        # Save detailed trial results
        detailed_trial_file = os.path.join(results_dir, f"detailed_trial_{trial_idx+1}_results.json")
        with open(detailed_trial_file, "w") as f:
            json.dump(trial_results[-1], f, default=custom_json_serializer)
    
    # Compute summary statistics across all trials
    small_avg = np.mean(small_model_accuracies)
    small_std = np.std(small_model_accuracies)
    large_avg = np.mean(large_model_accuracies)
    large_std = np.std(large_model_accuracies)
    small_cost_avg = np.mean(small_model_costs)
    small_cost_std = np.std(small_model_costs)
    large_cost_avg = np.mean(large_model_costs)
    large_cost_std = np.std(large_model_costs)
    
    print(f"\n===== Summary Statistics Across {num_trials} Trials =====")
    print(f"Small model: {small_avg:.4f} ± {small_std:.4f}, Cost: ${small_cost_avg:.6f} ± ${small_cost_std:.6f}")
    print(f"Large model: {large_avg:.4f} ± {large_std:.4f}, Cost: ${large_cost_avg:.6f} ± ${large_cost_std:.6f}")
    
    hybrid_avg = {alpha: np.mean(hybrid_accuracies[alpha]) for alpha in alphas}
    hybrid_std = {alpha: np.std(hybrid_accuracies[alpha]) for alpha in alphas}
    calls_avg = {alpha: np.mean(large_model_calls[alpha]) for alpha in alphas}
    calls_std = {alpha: np.std(large_model_calls[alpha]) for alpha in alphas}
    lambda_avg = {alpha: np.mean(lambda_values[alpha]) for alpha in alphas}
    lambda_std = {alpha: np.std(lambda_values[alpha]) for alpha in alphas}
    hybrid_cost_avg = {alpha: np.mean(hybrid_costs[alpha]) for alpha in alphas}
    hybrid_cost_std = {alpha: np.std(hybrid_costs[alpha]) for alpha in alphas}
    
    # Unrestricted hybrid statistics (if enabled)
    if include_unrestricted_hybrid:
        unrestricted_avg = {alpha: np.mean(unrestricted_hybrid_accuracies[alpha]) for alpha in alphas}
        unrestricted_std = {alpha: np.std(unrestricted_hybrid_accuracies[alpha]) for alpha in alphas}
        unrestricted_cost_avg = {alpha: np.mean(unrestricted_hybrid_costs[alpha]) for alpha in alphas}
        unrestricted_cost_std = {alpha: np.std(unrestricted_hybrid_costs[alpha]) for alpha in alphas}
    
    for alpha in alphas:
        print(f"Alpha={alpha}:")
        print(f"  Hybrid accuracy: {hybrid_avg[alpha]:.4f} ± {hybrid_std[alpha]:.4f}")
        print(f"  Large model usage: {calls_avg[alpha]:.2%} ± {calls_std[alpha]:.2%}")
        print(f"  Lambda threshold: {lambda_avg[alpha]:.4f} ± {lambda_std[alpha]:.4f}")
        print(f"  Hybrid cost: ${hybrid_cost_avg[alpha]:.6f} ± ${hybrid_cost_std[alpha]:.6f}")
        
        if include_unrestricted_hybrid:
            print(f"  Unrestricted hybrid accuracy: {unrestricted_avg[alpha]:.4f} ± {unrestricted_std[alpha]:.4f}")
            print(f"  Unrestricted cost: ${unrestricted_cost_avg[alpha]:.6f} ± ${unrestricted_cost_std[alpha]:.6f}")
    
    # Process random baseline data if available
    random_baseline_data = {}
    if include_random_baseline:
        for alpha in alphas:
            if alpha in random_baseline_accuracies and random_baseline_accuracies[alpha]:
                random_avg = np.mean(random_baseline_accuracies[alpha])
                random_std = np.std(random_baseline_accuracies[alpha])
                random_cost_avg = np.mean(random_baseline_costs[alpha])
                random_cost_std = np.std(random_baseline_costs[alpha])
                random_sm_frac_avg = np.mean(random_baseline_sm_fractions[alpha])
                random_sm_frac_std = np.std(random_baseline_sm_fractions[alpha])
                
                random_baseline_data[alpha] = {
                    "avg_accuracy": random_avg,
                    "std_accuracy": random_std,
                    "avg_cost": random_cost_avg,
                    "std_cost": random_cost_std,
                    "avg_small_model_fraction": random_sm_frac_avg,
                    "std_small_model_fraction": random_sm_frac_std
                }
                print(f"  Random baseline (α={alpha}): {random_avg:.4f} ± {random_std:.4f}, Cost: ${random_cost_avg:.6f} ± ${random_cost_std:.6f}")
    
    # Process unrestricted random baseline data if available
    unrestricted_random_baseline_data = {}
    if include_random_baseline and include_unrestricted_hybrid:
        for alpha in alphas:
            if alpha in unrestricted_random_baseline_accuracies and unrestricted_random_baseline_accuracies[alpha]:
                unrestricted_random_avg = np.mean(unrestricted_random_baseline_accuracies[alpha])
                unrestricted_random_std = np.std(unrestricted_random_baseline_accuracies[alpha])
                unrestricted_random_cost_avg = np.mean(unrestricted_random_baseline_costs[alpha])
                unrestricted_random_cost_std = np.std(unrestricted_random_baseline_costs[alpha])
                unrestricted_random_sm_frac_avg = np.mean(unrestricted_random_baseline_sm_fractions[alpha])
                unrestricted_random_sm_frac_std = np.std(unrestricted_random_baseline_sm_fractions[alpha])
                
                unrestricted_random_baseline_data[alpha] = {
                    "avg_accuracy": unrestricted_random_avg,
                    "std_accuracy": unrestricted_random_std,
                    "avg_cost": unrestricted_random_cost_avg,
                    "std_cost": unrestricted_random_cost_std,
                    "avg_small_model_fraction": unrestricted_random_sm_frac_avg,
                    "std_small_model_fraction": unrestricted_random_sm_frac_std
                }
                print(f"  Unrestricted random baseline (α={alpha}): {unrestricted_random_avg:.4f} ± {unrestricted_random_std:.4f}, Cost: ${unrestricted_random_cost_avg:.6f} ± ${unrestricted_random_cost_std:.6f}")
    
    # Compile final results
    final_results = {
        "subject": subject_name,
        "method": "standard",
        "iterations": num_trials,
        "random_seed": random_seed,
        "calibration_size": calibration_size,
        "total_examples": total_examples,
        "small_model": {
            "avg_accuracy": small_avg,
            "std_accuracy": small_std,
            "avg_cost": small_cost_avg,
            "std_cost": small_cost_std,
            "x_position": 1.0,  # 100% small model
            "x_std": 0.0        # No variation in x position
        },
        "large_model": {
            "avg_accuracy": large_avg,
            "std_accuracy": large_std,
            "avg_cost": large_cost_avg,
            "std_cost": large_cost_std,
            "x_position": 0.0,  # 0% small model
            "x_std": 0.0        # No variation in x position
        },
        "hybrid_models": {
            alpha: {
                "avg_accuracy": hybrid_avg[alpha],
                "std_accuracy": hybrid_std[alpha],
                "avg_large_model_usage": calls_avg[alpha],
                "std_large_model_usage": calls_std[alpha],
                "avg_lambda": lambda_avg[alpha],
                "std_lambda": lambda_std[alpha],
                "avg_cost": hybrid_cost_avg[alpha],
                "std_cost": hybrid_cost_std[alpha]
            } for alpha in alphas
        },
        "all_trials": trial_results
    }
    
    # Add unrestricted hybrid data (if enabled)
    if include_unrestricted_hybrid:
        final_results["unrestricted_hybrid"] = {
            alpha: {
                "avg_accuracy": unrestricted_avg[alpha],
                "std_accuracy": unrestricted_std[alpha],
                "avg_cost": unrestricted_cost_avg[alpha],
                "std_cost": unrestricted_cost_std[alpha]
            } for alpha in alphas
        }
    
    # Add random baseline data if available
    if random_baseline_data:
        final_results["random_baseline"] = random_baseline_data
    
    # Add unrestricted random baseline data if available
    if include_unrestricted_hybrid and include_random_baseline and unrestricted_random_baseline_data:
        final_results["unrestricted_random_baseline"] = unrestricted_random_baseline_data
    
    # Save final results to the results directory
    final_file = os.path.join(results_dir, f"mmlu_{subject_name}_{method_name}_final_results_{num_trials}_trials.json")
    with open(final_file, "w") as f:
        json.dump(final_results, f, default=custom_json_serializer)
    
    # Plot the results
    try:
        from plotting import plot_cost_vs_accuracy_improved, plot_enhanced_performance_improved
        
        # Plot cost vs accuracy (improved)
        cost_plot = plot_cost_vs_accuracy_improved(final_results, output_dir=results_dir)
        print(f"Created cost vs accuracy plot: {cost_plot}")
        
        # Plot enhanced performance (improved)
        perf_plot = plot_enhanced_performance_improved(final_results, output_dir=results_dir)
        print(f"Created enhanced performance plot: {perf_plot}")
    except ImportError:
        print("Could not import plotting module for generating plots.")
        print("Results are saved to JSON files for later analysis.")
    
    return final_results

def custom_json_serializer(obj):
    """Fallback serializer for values ``json`` cannot encode by itself.

    Intended to be passed as the ``default=`` argument of ``json.dump`` /
    ``json.dumps``.

    Args:
        obj: The value the JSON encoder could not serialize.

    Returns:
        A plain ``bool``, ``int`` or ``float`` for NumPy scalars, a (nested)
        ``list`` for NumPy arrays, and ``str(obj)`` for anything else. If even
        ``str(obj)`` raises, ``None`` is returned so that one bad value does
        not abort the whole dump.
    """
    # np.bool_ is not an np.integer subclass; without this branch it would be
    # written as the string "True"/"False" instead of a JSON boolean.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    try:
        return str(obj)
    except Exception:
        # Best effort: a broken __str__/__repr__ degrades to null. Catching
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate.
        return None

# Function moved to conformal_plotting.py

def _build_arg_parser():
    """Build the command-line parser for this script.

    ``--include_random`` and ``--include_unrestricted`` are on by default;
    ``--no_random`` and ``--no_unrestricted`` are the switches that turn
    them off.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate MMLU with conformal alignment and cost tracking"
    )
    parser.add_argument("--api_key", type=str, default=None,
                        help="OpenAI API key (or env OPENAI_API_KEY)")
    parser.add_argument("--num_trials", type=int, default=30,
                        help="Number of independent trials")
    parser.add_argument("--alphas", type=float, nargs="+",
                        default=[0.05, 0.1, 0.15, 0.2, 0.25],
                        help="Risk levels for conformal calibration")
    parser.add_argument("--max_workers", type=int, default=8,
                        help="Parallel workers for scoring")
    parser.add_argument("--calibration_size", type=int, default=500,
                        help="Calibration set size")
    parser.add_argument("--total_examples", type=int, default=1000,
                        help="Total MMLU examples to sample")
    parser.add_argument("--random_seed", type=int, default=42,
                        help="Random seed for splitting")
    parser.add_argument("--results_dir", type=str, default="mmlu_results",
                        help="Directory to save results")
    parser.add_argument("--evaluation_mode",
                        choices=["run", "plot_only"],
                        default="run",
                        help="Mode: run (full evaluation) or plot_only (load & plot only)")
    parser.add_argument("--standard_plot",
                        action="store_true",
                        default=False,
                        help="In plot_only mode, use the original plotting style")
    parser.add_argument("--include_random", action="store_true", default=True,
                        help="Also compute a random‐baseline routing")
    # The flag above defaults to True, so a store_false twin is the only way
    # to actually disable the random baseline from the command line.
    parser.add_argument("--no_random", action="store_false", dest="include_random",
                        help="Skip the random-baseline routing")
    parser.add_argument("--include_unrestricted", action="store_true", default=True,
                        help="Include unrestricted hybrid in plots")
    parser.add_argument("--no_unrestricted", action="store_false", dest="include_unrestricted",
                        help="Exclude unrestricted hybrid from plots")
    parser.add_argument("--subject", type=str, default=None,
                        help="MMLU subject to evaluate (if None, samples from all subjects)")
    return parser


def _find_latest_final_results(results_dir):
    """Return the path of the newest final-results JSON in ``results_dir``.

    A final-results file is one whose name contains ``final_results`` and
    ends with ``_trials.json``. "Newest" is decided by modification time,
    with the file name as a deterministic tie-break.

    Returns:
        The file path, or ``None`` when ``results_dir`` is empty/unset, is not
        a directory, or holds no matching file.
    """
    if not results_dir or not os.path.isdir(results_dir):
        return None
    candidates = [
        os.path.join(results_dir, name)
        for name in os.listdir(results_dir)
        if name.endswith("_trials.json") and "final_results" in name
    ]
    if not candidates:
        return None
    return max(candidates, key=lambda path: (os.path.getmtime(path), path))


def _print_saved_summary(final_results):
    """Print the summary statistics stored in a loaded final-results dict.

    Args:
        final_results: Dict loaded from a final-results JSON file. Must hold
            ``iterations``, ``small_model``, ``large_model`` and
            ``hybrid_models``; ``subject``, ``unrestricted_hybrid`` and
            ``random_baseline`` are optional.

    Raises:
        KeyError: if a required key is missing.
    """
    small = final_results["small_model"]
    large = final_results["large_model"]

    print(f"\n===== Summary Statistics From {final_results['iterations']} Trials =====")
    print(f"Subject: {final_results.get('subject', 'Unknown')}")
    print(f"Small model: {small['avg_accuracy']:.4f} ± {small['std_accuracy']:.4f}, "
          f"Cost: ${small['avg_cost']:.6f} ± ${small['std_cost']:.6f}")
    print(f"Large model: {large['avg_accuracy']:.4f} ± {large['std_accuracy']:.4f}, "
          f"Cost: ${large['avg_cost']:.6f} ± ${large['std_cost']:.6f}")

    unrestricted_by_alpha = final_results.get("unrestricted_hybrid", {})
    baseline_by_alpha = final_results.get("random_baseline", {})

    # Alpha keys are whatever the JSON file holds (strings after a round trip),
    # so the optional sections are looked up with the very same key objects.
    for alpha, hybrid in final_results["hybrid_models"].items():
        print(f"\nAlpha={alpha}:")
        print(f"  Hybrid accuracy: {hybrid['avg_accuracy']:.4f} ± {hybrid['std_accuracy']:.4f}")
        print(f"  Large model usage: {hybrid['avg_large_model_usage']:.2%} ± {hybrid['std_large_model_usage']:.2%}")
        print(f"  Lambda threshold: {hybrid['avg_lambda']:.4f} ± {hybrid['std_lambda']:.4f}")
        print(f"  Hybrid cost: ${hybrid['avg_cost']:.6f} ± ${hybrid['std_cost']:.6f}")

        if alpha in unrestricted_by_alpha:
            unrestricted = unrestricted_by_alpha[alpha]
            print(f"  Unrestricted hybrid accuracy: {unrestricted['avg_accuracy']:.4f} ± {unrestricted['std_accuracy']:.4f}")
            print(f"  Unrestricted cost: ${unrestricted['avg_cost']:.6f} ± ${unrestricted['std_cost']:.6f}")

        if alpha in baseline_by_alpha:
            # Deliberately not named ``random``: that would shadow the module.
            baseline = baseline_by_alpha[alpha]
            print(f"  Random baseline: {baseline['avg_accuracy']:.4f} ± {baseline['std_accuracy']:.4f}, "
                  f"Cost: ${baseline['avg_cost']:.6f} ± ${baseline['std_cost']:.6f}")


def _create_plots(results_dir):
    """Draw plots with the optional ``plotting`` module.

    Returns:
        ``True`` when ``create_simple_plots`` ran, ``False`` when an
        ``ImportError`` was raised while importing or running it.
    """
    try:
        from plotting import create_simple_plots
        create_simple_plots(results_dir)
    except ImportError:
        return False
    return True


def main(argv=None):
    """Command-line entry point.

    Behaviour, in order:
      1. Warn (but continue) if ``--subject`` is not in ``MMLU_SUBJECTS``.
      2. In ``plot_only`` mode, only plot what is in ``--results_dir``.
      3. If ``--results_dir`` already holds a final-results file, print its
         summary, try to plot, and stop without running a new evaluation.
      4. Otherwise run the full conformal evaluation.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        int: process exit status (0 on success, 1 on a plot_only failure).
    """
    args = _build_arg_parser().parse_args(argv)

    # Check if subject is valid
    if args.subject is not None and args.subject not in MMLU_SUBJECTS:
        print(f"Warning: '{args.subject}' is not in the list of known MMLU subjects")
        print(f"Valid subjects are: {', '.join(MMLU_SUBJECTS)}")
        print("Continuing anyway, but be aware that this might fail if the subject doesn't exist")

    # ─── PLOT-ONLY BRANCH ─────────────────────────────────────────────────────
    if args.evaluation_mode == "plot_only":
        if not args.results_dir:
            print("Error: --results_dir must be specified for plot_only mode")
            return 1
        print(f"Loading and plotting results from {args.results_dir}...")
        if not _create_plots(args.results_dir):
            print("Error: Could not import plotting module for plot_only mode")
            print("Make sure the plotting.py file is available")
            # Plotting was the only thing requested, so this is a failure.
            return 1
        return 0

    # ─── FULL EVALUATION ───────────────────────────────────────────────────────
    # Reuse previously saved results instead of re-running the evaluation.
    existing_file = _find_latest_final_results(args.results_dir)
    if existing_file is not None:
        print(f"Results directory {args.results_dir} already contains results files.")
        print("Loading existing results to display statistics...")

        with open(existing_file, "r", encoding="utf-8") as f:
            final_results = json.load(f)

        _print_saved_summary(final_results)

        # Plots are a bonus here; the statistics were already shown.
        if not _create_plots(args.results_dir):
            print("\nWarning: Could not import plotting module for generating plots.")
            print("Results statistics have been displayed but plots could not be generated.")
        return 0

    # If no existing results, run full evaluation
    run_conformal_evaluation(
        alphas=args.alphas,
        api_key=args.api_key,
        max_workers=args.max_workers,
        num_trials=args.num_trials,
        calibration_size=args.calibration_size,
        total_examples=args.total_examples,
        random_seed=args.random_seed,
        results_dir=args.results_dir,
        subject=args.subject,
        include_random_baseline=args.include_random,
        include_unrestricted_hybrid=args.include_unrestricted
    )
    return 0


if __name__ == "__main__":
    raise SystemExit(main())