#!/usr/bin/env python3
"""
truthqa_standalone.py

Evaluate the TruthfulQA multiple‐choice dataset with conformal risk control
and full OpenAI cost accounting. This is a standalone version that includes all
the necessary functionality without relying on an external mmlu_cost module.
"""
import os
import random
import json
import time
import math
import argparse
import numpy as np
import concurrent.futures
from tqdm import tqdm
from typing import List, Dict, Tuple, Any, Optional
import openai
from datasets import load_dataset

# -----------------------------------------------------------------------------
# OpenAI API Wrapper for confidence scores with cost tracking
# -----------------------------------------------------------------------------
class OpenAIWrapper:
    """Wrapper for OpenAI API to get confidence scores with cost tracking.

    Every successful call to ``get_scores`` adds its token usage and its
    estimated dollar cost to the running totals stored on the instance
    (``total_prompt_tokens``, ``total_completion_tokens``, ``api_calls``,
    ``total_cost``).

    NOTE(review): the counters are updated without any locking, while other
    code in this file shares one wrapper between thread-pool workers. Totals
    may therefore be slightly off under contention -- confirm whether exact
    accounting is required.
    """

    def __init__(self, model_name, api_key=None):
        """Initialize with model name and optional API key."""
        self.model_name = model_name
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0
        self.api_calls = 0
        self.total_cost = 0.0

        # Define pricing per 1M tokens
        self.pricing = {
            "gpt-4.1-2025-04-14": {"input": 2.00, "output": 8.00, "cached_input": 0.50},
            "gpt-4.1-mini-2025-04-14": {"input": 0.40, "output": 1.60, "cached_input": 0.10},
            "gpt-4.1-nano-2025-04-14": {"input": 0.10, "output": 0.40, "cached_input": 0.025}
        }

        # Map shortened model names to full names for pricing lookup
        self.model_name_map = {
            "gpt-4.1": "gpt-4.1-2025-04-14",
            "gpt-4.1-mini": "gpt-4.1-mini-2025-04-14",
            "gpt-4.1-nano": "gpt-4.1-nano-2025-04-14"
        }

        if api_key:
            self.client = openai.OpenAI(api_key=api_key)
        else:
            # Use environment variable
            self.client = openai.OpenAI()

    @staticmethod
    def _random_scores(count):
        """Return ``count`` random scores in [0.1, 0.9), normalized to sum to 1."""
        raw = [0.1 + 0.8 * random.random() for _ in range(count)]
        total = sum(raw)
        if not raw:
            return []
        return [s / total for s in raw]

    @staticmethod
    def _expand_scores(scores, indices, size):
        """Scatter ``scores`` to positions ``indices`` of a zero list of length ``size``."""
        full_scores = [0.0] * size
        for i, score in zip(indices, scores):
            full_scores[i] = score
        return full_scores

    def _call_cost(self, prompt_tokens, completion_tokens):
        """Return the dollar cost of one call; 0.0 when the model has no pricing entry."""
        pricing_model = self.model_name_map.get(self.model_name, self.model_name)
        rates = self.pricing.get(pricing_model)
        if rates is None:
            # An unknown model must not turn a successful API call into a
            # failure; report zero cost and warn once per instance.
            if not getattr(self, "_warned_unpriced", False):
                print(f"Warning: no pricing data for model '{self.model_name}'; cost will be reported as 0")
                self._warned_unpriced = True
            return 0.0

        # Convert from per 1M tokens to per token
        input_cost = prompt_tokens * (rates["input"] / 1000000)
        output_cost = completion_tokens * (rates["output"] / 1000000)
        return input_cost + output_cost

    @staticmethod
    def _parse_scores(content, expected):
        """
        Extract ``expected`` clamped scores from the model's JSON reply.

        Raises on malformed JSON or non-numeric scores; the caller treats any
        exception as "unparseable" and substitutes random scores.
        """
        response_json = json.loads(content)

        scores = None
        if isinstance(response_json, list):
            # The model answered with a bare array
            scores = response_json
        elif isinstance(response_json, dict):
            if "scores" in response_json:
                scores = response_json["scores"]
            else:
                # Try to find any array of the right length in the response
                scores = next(
                    (value for value in response_json.values()
                     if isinstance(value, list) and len(value) == expected),
                    None,
                )
                if scores is None:
                    print(f"Could not find scores in response: {content}")
                    # Use a random distribution rather than uniform to ensure conformal sets don't get too large
                    scores = [0.1 + 0.8 * random.random() for _ in range(expected)]

        if not isinstance(scores, list):
            raise ValueError(f"expected a list of scores, got {type(scores).__name__}")
        scores = list(scores)

        # Ensure we have the right number of scores
        if len(scores) != expected:
            print(f"Warning: Got {len(scores)} scores for {expected} choices")
            if len(scores) < expected:
                # Add random values rather than uniform values to ensure variation
                scores.extend([0.1 + 0.8 * random.random() for _ in range(expected - len(scores))])
            else:
                scores = scores[:expected]

        # Ensure all scores are between 0 and 1
        scores = [min(max(float(s), 0.0), 1.0) for s in scores]

        # Add small randomness to break ties and avoid all equal scores
        if all(s == scores[0] for s in scores):
            print("All scores are equal, adding random variation")
            scores = [s + random.uniform(-0.05, 0.05) for s in scores]
            scores = [min(max(s, 0.0), 1.0) for s in scores]

        return scores

    def get_scores(self, problem, answer_choices, filter_choices=None):
        """
        Get confidence scores for each answer choice.

        Parameters:
        -----------
        problem : str
            The problem statement
        answer_choices : list
            List of answer choices
        filter_choices : list, optional
            If provided (and shorter than ``answer_choices``), only these
            choices (indices) will be sent to the model

        Returns:
        --------
        tuple
            ``(scores, cost, prompt_tokens, completion_tokens)``.
            ``scores`` always has one entry per item of ``answer_choices``
            and sums to 1; choices that were filtered out get 0.0.
            If the API call fails, or the reply cannot be parsed, random
            scores are returned instead. A failed API call reports zero cost
            and zero tokens.
        """
        # Use original choices by default
        choices_to_send = answer_choices
        choice_indices = list(range(len(answer_choices)))

        # If filter_choices is provided, filter the choices
        if filter_choices is not None and len(filter_choices) < len(answer_choices):
            choices_to_send = [answer_choices[i] for i in filter_choices]
            choice_indices = filter_choices

        num_sent = len(choices_to_send)
        num_total = len(answer_choices)

        # More structured prompt specifically designed for consistent token outputs
        messages = [
            {"role": "system", "content": """You are an expert who evaluates multiple choice questions.

# Instructions
- Assign a confidence score to each answer choice on a scale from 0 to 1 
- 0 means certainly incorrect, 1 means certainly correct
- Don't assign similar scores to choices unless genuinely equally uncertain

# Response Format
- Output ONLY a valid JSON object with a "scores" key containing an array of numbers
- Example: {"scores": [0.1, 0.8, 0.05, 0.05]}
- NO explanations, just the JSON"""},
            {"role": "user", "content": f"""Question:
{problem}

Answer Choices:
{json.dumps(choices_to_send)}

Respond ONLY with a JSON object containing your confidence scores for these choices, like: {{"scores": [0.1, 0.8, 0.05, 0.05]}}"""}
        ]

        # Only the request itself is treated as an "API error"; bookkeeping
        # and parsing problems are handled separately below.
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=0.1,
                response_format={"type": "json_object"},
                max_tokens=50  # Keep it small for just scores
            )
            prompt_tokens = response.usage.prompt_tokens
            completion_tokens = response.usage.completion_tokens
            content = response.choices[0].message.content
        except Exception as e:
            print(f"API request error: {e}")
            # Return random scores in case of API error; zero cost for failed requests
            fallback = self._random_scores(num_sent)
            return self._expand_scores(fallback, choice_indices, num_total), 0.0, 0, 0

        # Update tracking counters
        self.total_prompt_tokens += prompt_tokens
        self.total_completion_tokens += completion_tokens
        self.api_calls += 1
        total_call_cost = self._call_cost(prompt_tokens, completion_tokens)
        self.total_cost += total_call_cost

        try:
            sent_scores = self._parse_scores(content, num_sent)
            total = sum(sent_scores)
            if total > 0:
                # Normalize to sum to 1
                sent_scores = [s / total for s in sent_scores]
            else:
                # Random distribution instead of uniform
                sent_scores = self._random_scores(num_sent)
        except Exception as e:
            print(f"Error parsing JSON response: {e}")
            print(f"Raw content: {content}")
            # Use a random distribution to ensure conformal sets don't get too large
            sent_scores = self._random_scores(num_sent)

        # Expand back to the original size with zeros for choices not sent
        full_scores = self._expand_scores(sent_scores, choice_indices, num_total)
        return full_scores, total_call_cost, prompt_tokens, completion_tokens

    def reset_tracking(self):
        """Reset the tracking counters."""
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0
        self.api_calls = 0
        self.total_cost = 0.0


# -----------------------------------------------------------------------------
# TruthfulQA dataset loading functions
# -----------------------------------------------------------------------------
def load_truthfulqa_data(subject=None,
                         num_examples: int = 1000,
                         random_seed: int = 42):
    """
    Download EleutherAI/truthful_qa_mc and return a shuffled pool of examples.

    The global ``random`` generator is seeded with ``random_seed`` before the
    shuffle. No calibration/evaluation split is made here: the same first
    ``num_examples`` shuffled examples are placed under both "validation" and
    "test", and "train" is left empty. ``subject`` is accepted but not used.

    Returns: (data_dict, subject_name) where subject_name is "TruthfulQA"
    """
    random.seed(random_seed)

    ds = load_dataset("EleutherAI/truthful_qa_mc")
    print(f"Available splits in the dataset: {list(ds.keys())}")

    # Prefer the 'train' split; otherwise pool every split that exists.
    if 'train' in ds:
        pool = list(ds['train'])
    else:
        pool = []
        for split_name in ds.keys():
            pool += list(ds[split_name])

    print(f"Loaded {len(pool)} multiple-choice questions from EleutherAI/truthful_qa_mc")

    # Print one (pre-shuffle) example so the record layout can be eyeballed.
    if pool:
        first = pool[0]
        print(f"Sample format - keys: {list(first.keys())}")
        if all(field in first for field in ('question', 'choices', 'label')):
            print(f"Sample question: {first['question'][:50]}...")
            print(f"Number of choices: {len(first.get('choices', []))}")
            print(f"Label index: {first.get('label', 'None')}")

    random.shuffle(pool)

    # Cap at the requested number of examples
    limit = min(num_examples, len(pool))

    data = {"train": []}
    # Both entries hold the same examples; the caller applies calibration_size.
    data["validation"] = pool[:limit]
    data["test"] = pool[:limit]

    print(f"Loaded {len(data['validation'])} examples - calibration_size parameter will determine the split")

    return data, "TruthfulQA"

def process_truthfulqa_problem(example: dict) -> dict:
    """
    Normalize one TruthfulQA record into
    {'problem': str, 'choices': List[str], 'correct_index': int}.

    Two input layouts are recognized:
    - EleutherAI/truthful_qa_mc: 'choices' (list) plus 'label' (0-based index)
    - the 'mc1_targets' layout: {'choices': [...], 'labels': [...]} where the
      first label equal to 1 marks the correct choice (index 0 if none is)

    Anything else, or a 'choices' value that is not a list of at least two
    items, is replaced by four placeholder options. An index that is not an
    int within range is reset to 0. A warning is printed in each case.
    """
    placeholder_options = ("Option A", "Option B", "Option C", "Option D")
    question = example["question"]

    direct_format = "choices" in example and "label" in example
    if direct_format:
        choices, correct_index = example["choices"], example["label"]
    elif "mc1_targets" in example:
        targets = example["mc1_targets"]
        choices = targets["choices"]
        labels = targets["labels"]
        try:
            correct_index = labels.index(1)
        except ValueError:
            correct_index = 0
    else:
        # Neither layout matched
        print(f"WARNING: Unexpected example format for question: {question[:50]}...")
        choices = list(placeholder_options)
        correct_index = 0

    # Fall back to placeholders when the choices are unusable
    choices_usable = isinstance(choices, list) and len(choices) >= 2
    if not choices_usable:
        print(f"WARNING: Invalid choices for question: {question[:50]}...")
        choices = list(placeholder_options)

    # Fall back to the first choice when the index is unusable
    index_usable = isinstance(correct_index, int) and 0 <= correct_index < len(choices)
    if not index_usable:
        print(f"WARNING: Invalid correct_index ({correct_index}) for question: {question[:50]}...")
        correct_index = 0

    return dict(problem=question, choices=choices, correct_index=correct_index)

# -----------------------------------------------------------------------------
# Conformal evaluation helper functions
# -----------------------------------------------------------------------------
def process_single_calibration_example(args):
    """
    Score one calibration example with both models (thread-pool worker).

    ``args`` is the tuple ``(index, example, small_model, large_model)``.
    Each model's ``get_scores`` is called once on the full choice list, with a
    0.5 s pause between the two calls. Returns a dict holding the question
    (under "context"), the choices, both score lists, the correct index, and
    the per-call cost and token counts of each model.
    """
    index, example, small_model, large_model = args

    processed = process_truthfulqa_problem(example)

    print(f"Getting scores for calibration example {index+1}")

    question = processed["problem"]
    options = processed["choices"]

    # Small model first
    small_result = small_model.get_scores(question, options)
    small_scores, small_cost, small_prompt_tokens, small_completion_tokens = small_result
    print(f"Small model scores: {small_scores}")

    # Brief pause between the two requests to ease rate limiting
    time.sleep(0.5)

    # Then the large model on the same question and choices
    large_result = large_model.get_scores(question, options)
    large_scores, large_cost, large_prompt_tokens, large_completion_tokens = large_result
    print(f"Large model scores: {large_scores}")

    record = {
        "context": question,
        "choices": options,
        "small_scores": small_scores,
        "large_scores": large_scores,
        "correct_index": processed["correct_index"],
        "small_cost": small_cost,
        "large_cost": large_cost,
    }
    record["small_tokens"] = {"prompt": small_prompt_tokens, "completion": small_completion_tokens}
    record["large_tokens"] = {"prompt": large_prompt_tokens, "completion": large_completion_tokens}
    return record

def create_calibration_dataset_parallel(small_model, large_model, data, num_samples=300, max_workers=4):
    """
    Create a calibration dataset with scores from both models using a thread pool.

    Parameters:
    -----------
    small_model, large_model :
        Model wrappers; ``reset_tracking()`` is called on both before any
        example is processed, and both are passed to every worker.
    data : dict
        Must contain a "validation" list of raw examples.
    num_samples : int
        Maximum number of calibration examples. If more are available, a
        random subset of this size is drawn with ``random.sample``.
    max_workers : int
        Thread pool size.

    Returns:
    --------
    list
        One result dict per successfully processed example, ordered by the
        example's position in the (possibly subsampled) validation list.
        Examples whose worker raised are reported on stdout and omitted.
    """
    # Reset tracking
    small_model.reset_tracking()
    large_model.reset_tracking()

    # Use a subset of validation data for calibration
    validation_sample = data["validation"]
    if len(validation_sample) > num_samples:
        # Randomly sample to ensure diversity
        validation_sample = random.sample(validation_sample, num_samples)

    # (original position, result) pairs, collected in completion order
    indexed_results = []

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_index = {
            executor.submit(process_single_calibration_example,
                            (i, example, small_model, large_model)): i
            for i, example in enumerate(validation_sample)
        }

        # Process results as they complete
        for future in tqdm(concurrent.futures.as_completed(future_to_index),
                           total=len(future_to_index), desc="Processing calibration examples"):
            idx = future_to_index[future]
            try:
                indexed_results.append((idx, future.result()))
            except Exception as exc:
                print(f"Processing example {idx} generated an exception: {exc}")

    # Restore the submission order using the index carried with each result.
    # This avoids re-locating every result by question text, which is slow
    # and ambiguous when two examples share the same question.
    indexed_results.sort(key=lambda pair: pair[0])
    return [result for _, result in indexed_results]

def compute_standard_conformal_threshold(calibration_data, alpha=0.1):
    """
    Compute the standard conformal threshold lambda for a given risk level alpha.

    For each of 101 evenly spaced lambda values in [0, 1], every calibration
    sample gets a conformal set: the choices whose small-model score is
    within ``lambda`` of that sample's top small-model score. A sample incurs
    loss 1 only when the correct answer is outside that set AND the large
    model's top choice (over all options) is correct; otherwise the loss is 0.

    The selected lambda is the smallest one whose risk bound
    ``(n / (n + 1)) * avg_loss + 1 / (n + 1)`` is at most ``alpha``. If none
    qualifies, the largest lambda (1.0) is used and a warning is printed.

    Parameters:
    -----------
    calibration_data : list
        Dicts with "small_scores", "large_scores" and "correct_index".
        May be empty: every lambda then has zero loss and a risk bound of 1.
    alpha : float
        Target risk level.

    Returns:
    --------
    tuple
        ``(selected_lambda, lambda_metrics)`` where ``lambda_metrics`` holds
        one dict per lambda with keys "lambda", "avg_loss", "avg_set_size",
        "small_model_usage" and "risk_bound".
    """
    n = len(calibration_data)
    lambda_values = np.linspace(0, 1, 101)

    # Per-sample quantities that do not depend on lambda are computed once.
    prepared = []
    for sample in calibration_data:
        small_scores = sample["small_scores"]
        correct_index = sample["correct_index"]
        # Large model's choice from all options
        large_model_correct = (np.argmax(sample["large_scores"]) == correct_index)
        prepared.append((small_scores, max(small_scores), correct_index, large_model_correct))

    # Store detailed metrics for each lambda
    lambda_metrics = []

    for lam in lambda_values:
        loss_sum = 0
        total_set_size = 0
        small_model_usage_count = 0

        for small_scores, max_score, correct_index, large_model_correct in prepared:
            # Construct conformal action set
            cutoff = max_score - lam
            conformal_set = [i for i, score in enumerate(small_scores) if score >= cutoff]
            total_set_size += len(conformal_set)

            # A singleton set means the small model's answer would be used
            if len(conformal_set) == 1:
                small_model_usage_count += 1

            # Loss only when the set excludes the correct answer that the
            # large model would otherwise have found.
            if large_model_correct and correct_index not in conformal_set:
                loss_sum += 1

        if n:
            avg_loss = loss_sum / n
            avg_set_size = total_set_size / n
            small_model_usage = small_model_usage_count / n
        else:
            # No calibration data: nothing observed, so avoid dividing by zero.
            avg_loss = 0.0
            avg_set_size = 0.0
            small_model_usage = 0.0

        lambda_metrics.append({
            "lambda": lam,
            "avg_loss": avg_loss,
            "avg_set_size": avg_set_size,
            "small_model_usage": small_model_usage,
            "risk_bound": (n / (n + 1)) * avg_loss + (1 / (n + 1))
        })

    # Find the smallest lambda that satisfies the risk constraint
    selected_lambda = None
    selected_idx = None

    for i, metrics in enumerate(lambda_metrics):
        if metrics["risk_bound"] <= alpha:
            selected_lambda = lambda_values[i]
            selected_idx = i
            break

    if selected_lambda is None:
        print("Warning: Could not find a suitable lambda, using the maximum value")
        selected_lambda = lambda_values[-1]
        selected_idx = len(lambda_values) - 1

    # Print detailed information about lambda selection
    print("\nLambda Selection Details:")
    print(f"Target alpha: {alpha}")
    print(f"Selected lambda: {selected_lambda:.4f}")

    # Show the metrics for up to two lambdas on either side of the selection
    for i in range(max(0, selected_idx-2), min(len(lambda_metrics), selected_idx+3)):
        metrics = lambda_metrics[i]
        print(f"λ={metrics['lambda']:.4f}: Loss={metrics['avg_loss']:.4f}, Risk={metrics['risk_bound']:.4f}, SmallUsage={metrics['small_model_usage']:.2%}")

    return selected_lambda, lambda_metrics


def precompute_test_scores(test_examples, small_model, large_model, max_workers=4):
    """
    Precompute scores from both models for all test examples.

    Each example is scored once by each model (on the full choice list) in a
    thread pool, so later analysis can reuse the same scores.

    Parameters:
    -----------
    test_examples : iterable
        Raw dataset examples.
    small_model, large_model :
        Model wrappers exposing ``get_scores(problem, choices)``.
    max_workers : int
        Thread pool size.

    Returns:
    --------
    list
        One dict per successfully processed example, in the same order as
        ``test_examples``. Each holds the problem, choices, correct index,
        both score lists, both argmax predictions, and per-call cost and
        token counts. Examples whose worker raised are reported on stdout
        and omitted.
    """

    def process_test_example(args):
        """Score a single example with both models and bundle the results."""
        idx, example, small_model, large_model = args
        processed = process_truthfulqa_problem(example)

        # Get scores from both models
        small_scores, small_cost, small_prompt_tokens, small_completion_tokens = small_model.get_scores(
            processed["problem"], processed["choices"]
        )

        large_scores, large_cost, large_prompt_tokens, large_completion_tokens = large_model.get_scores(
            processed["problem"], processed["choices"]
        )

        return {
            "problem": processed["problem"],
            "choices": processed["choices"],
            "correct_index": processed["correct_index"],
            "small_scores": small_scores,
            "large_scores": large_scores,
            "small_pred": np.argmax(small_scores),
            "large_pred": np.argmax(large_scores),
            "small_cost": small_cost,
            "large_cost": large_cost,
            "small_tokens": {"prompt": small_prompt_tokens, "completion": small_completion_tokens},
            "large_tokens": {"prompt": large_prompt_tokens, "completion": large_completion_tokens}
        }

    # (original position, result) pairs, collected in completion order
    indexed_results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_idx = {
            executor.submit(process_test_example, (i, example, small_model, large_model)): i
            for i, example in enumerate(test_examples)
        }

        # No sleep here: every task is already submitted by this point, so
        # pausing in the collecting thread would not slow the workers'
        # requests -- it would only delay gathering the results.
        for future in tqdm(concurrent.futures.as_completed(future_to_idx),
                           total=len(future_to_idx), desc="Precomputing test example scores"):
            idx = future_to_idx[future]
            try:
                indexed_results.append((idx, future.result()))
            except Exception as exc:
                print(f"Processing test example {idx} generated an exception: {exc}")

    # Return results in input order so repeated runs line up example by example
    indexed_results.sort(key=lambda pair: pair[0])
    return [result for _, result in indexed_results]

def calculate_random_baseline(test_data, small_model_usage_fraction, num_samples=100):
    """
    Simulate a router that picks the small model at random with a fixed probability.

    For each of ``num_samples`` trials, every example in ``test_data``
    independently uses the small model's precomputed prediction and cost with
    probability ``small_model_usage_fraction``, otherwise the large model's.
    The trial's accuracy and mean cost per example are recorded.

    Returns ``((acc_mean, acc_std), (cost_mean, cost_std))`` across trials.
    """
    trial_accuracies = []
    trial_costs = []

    for _trial in range(num_samples):
        hits = 0
        spent = 0.0

        for record in test_data:
            small_choice = (record["small_pred"], record["small_cost"])
            large_choice = (record["large_pred"], record["large_cost"])
            truth = record["correct_index"]

            # One random draw per example decides which model answers
            use_small = random.random() < small_model_usage_fraction
            pred, cost = small_choice if use_small else large_choice

            spent += cost
            if pred == truth:
                hits += 1

        # Per-trial accuracy and average cost per example
        trial_accuracies.append(hits / len(test_data))
        trial_costs.append(spent / len(test_data))

    accuracy_stats = (np.mean(trial_accuracies), np.std(trial_accuracies))
    cost_stats = (np.mean(trial_costs), np.std(trial_costs))
    return accuracy_stats, cost_stats

def calculate_cost_matched_random_baseline(test_data, target_cost, small_cost_avg, large_cost_avg, num_samples=100):
    """
    Simulate a random router whose expected cost matches a target cost.

    The probability ``p`` of using the small model is chosen so that
    ``target_cost = p * small_cost_avg + (1 - p) * large_cost_avg``, then
    clamped to [0, 1]. When the two average costs are equal, ``p`` is 0.5.

    Parameters:
    -----------
    test_data : list
        Test examples with precomputed predictions and costs for both models
    target_cost : float
        Average cost per example to match
    small_cost_avg : float
        Average cost of the small model per example
    large_cost_avg : float
        Average cost of the large model per example
    num_samples : int
        Number of random trials to average over

    Returns:
    --------
    tuple
        ((accuracy mean, accuracy std), (cost mean, cost std), p)
    """
    cost_gap = large_cost_avg - small_cost_avg
    if large_cost_avg == small_cost_avg:
        # Identical costs: any mix hits the target, so split evenly
        small_model_fraction = 0.5
    else:
        # Solve target = p*small + (1-p)*large for p
        small_model_fraction = (large_cost_avg - target_cost) / cost_gap

    # Clamp to the valid probability range
    capped = min(1.0, small_model_fraction)
    small_model_fraction = max(0.0, capped)

    trial_accuracies = []
    trial_costs = []

    for _trial in range(num_samples):
        hits = 0
        spent = 0.0

        for record in test_data:
            small_choice = (record["small_pred"], record["small_cost"])
            large_choice = (record["large_pred"], record["large_cost"])
            truth = record["correct_index"]

            # One random draw per example decides which model answers
            use_small = random.random() < small_model_fraction
            pred, cost = small_choice if use_small else large_choice

            spent += cost
            if pred == truth:
                hits += 1

        # Per-trial accuracy and average cost per example
        trial_accuracies.append(hits / len(test_data))
        trial_costs.append(spent / len(test_data))

    accuracy_stats = (np.mean(trial_accuracies), np.std(trial_accuracies))
    cost_stats = (np.mean(trial_costs), np.std(trial_costs))
    return accuracy_stats, cost_stats, small_model_fraction

def custom_json_serializer(obj):
    """
    Fallback converter for ``json.dump(..., default=custom_json_serializer)``.

    NumPy booleans, integers, floats and arrays are converted to the matching
    built-in ``bool`` / ``int`` / ``float`` / ``list``. Any other object is
    converted with ``str()``; if even that raises, ``None`` is returned.
    """
    # NumPy booleans must be handled explicitly: otherwise they fall through
    # to str() and are written as the strings "True"/"False" instead of JSON
    # booleans.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()

    try:
        return str(obj)
    except Exception:
        # Best effort: an object whose __str__ fails is serialized as null.
        # (Narrower than a bare except, so KeyboardInterrupt/SystemExit pass through.)
        return None

# -----------------------------------------------------------------------------
# Main: Conformal evaluation function
# -----------------------------------------------------------------------------
def run_conformal_evaluation(
    alphas=[0.1, 0.2, 0.3, 0.4, 0.5], 
    api_key=None, 
    max_workers=4,
    num_trials=5,
    calibration_size=300,
    total_examples=1500,
    random_seed=42, 
    results_dir=None,
    subject=None,
    include_random_baseline=True,
    include_unrestricted_hybrid=True):
    """
    Run conformal evaluation with properly independent trials.
    Now includes cost tracking and unrestricted hybrid option.
    """
    print(f"Starting revised conformal evaluation with {num_trials} independent trials")
    print(f"Method: Standard conformal")
    
    # Create results directory first
    method_name = "standard"
    if results_dir is None:
        # We'll update with actual subject name after loading data
        results_dir = f"truthqa_results_{method_name}_revised"
    os.makedirs(results_dir, exist_ok=True)
    
    # Load the full dataset
    full_data, subject_name = load_truthfulqa_data(subject=subject, num_examples=total_examples * 2)  # Request extra to ensure enough examples
    print(f"Subject: {subject_name}")
    
    # Update results directory with subject name
    if results_dir == f"truthqa_results_{method_name}_revised":
        results_dir = f"truthqa_results_{method_name}_revised_{subject_name}"
        os.makedirs(results_dir, exist_ok=True)
    
    # Combine all examples for the pool
    all_available_examples = full_data["train"] + full_data["validation"] + full_data["test"]
    print(f"Total available examples: {len(all_available_examples)}")
    
    # Results storage for trial statistics
    trial_results = []
    small_model_accuracies = []
    large_model_accuracies = []
    small_model_costs = []
    large_model_costs = []
    
    # Hybrid model statistics
    hybrid_accuracies = {alpha: [] for alpha in alphas}
    large_model_calls = {alpha: [] for alpha in alphas}
    lambda_values = {alpha: [] for alpha in alphas}
    hybrid_costs = {alpha: [] for alpha in alphas}
    
    # Unrestricted hybrid statistics (if enabled)
    if include_unrestricted_hybrid:
        unrestricted_hybrid_accuracies = {alpha: [] for alpha in alphas}
        unrestricted_hybrid_costs = {alpha: [] for alpha in alphas}
        unrestricted_random_baseline_accuracies = {alpha: [] for alpha in alphas}
        unrestricted_random_baseline_costs = {alpha: [] for alpha in alphas}
        unrestricted_random_baseline_sm_fractions = {alpha: [] for alpha in alphas}
    
    # Random baseline statistics (if enabled)
    if include_random_baseline:
        random_baseline_accuracies = {alpha: [] for alpha in alphas}
        random_baseline_costs = {alpha: [] for alpha in alphas}
        random_baseline_sm_fractions = {alpha: [] for alpha in alphas}
    
    # For each trial
    for trial_idx in range(num_trials):
        print(f"\n{'='*30}")
        print(f" TRIAL {trial_idx+1}/{num_trials} ")
        print(f"{'='*30}")
        
        # Set trial-specific random seed for reproducibility
        trial_seed = random_seed + trial_idx
        random.seed(trial_seed)
        
        # Sample a subset of examples for this trial
        if len(all_available_examples) <= total_examples:
            trial_pool = all_available_examples.copy()
            print(f"Using all {len(trial_pool)} available examples for this trial")
        else:
            trial_pool = random.sample(all_available_examples, total_examples)
            print(f"Sampled {len(trial_pool)} examples for this trial")
        
        # Initialize models for this trial
        small_model = OpenAIWrapper("gpt-4.1-nano-2025-04-14", api_key)
        large_model = OpenAIWrapper("gpt-4.1-2025-04-14", api_key)
        
        # First check if we already have scored examples for this trial
        scored_examples_file = os.path.join(results_dir, f"scored_examples_trial_{trial_idx+1}.json")
        if os.path.exists(scored_examples_file):
            print(f"Found pre-computed scores for trial {trial_idx+1}, loading from {scored_examples_file}")
            try:
                with open(scored_examples_file, "r") as f:
                    saved_data = json.load(f)
                    if "examples" in saved_data and len(saved_data["examples"]) > 0:
                        print(f"Loaded {len(saved_data['examples'])} pre-computed examples")
                        scored_examples = saved_data["examples"]
                    else:
                        print("No examples found in saved file, computing scores...")
                        scored_examples = precompute_test_scores(
                            trial_pool,
                            small_model,
                            large_model,
                            max_workers=max_workers
                        )
            except Exception as e:
                print(f"Error loading pre-computed scores: {e}")
                print("Computing scores from scratch...")
                scored_examples = precompute_test_scores(
                    trial_pool,
                    small_model,
                    large_model,
                    max_workers=max_workers
                )
        else:
            # Score all examples in this trial's pool
            print(f"No pre-computed scores found for trial {trial_idx+1}, computing scores...")
            scored_examples = precompute_test_scores(
                trial_pool,
                small_model,
                large_model,
                max_workers=max_workers
            )
        
        # Calculate trial-specific baseline accuracies and costs
        small_correct = sum(1 for example in scored_examples if example["small_pred"] == example["correct_index"])
        large_correct = sum(1 for example in scored_examples if example["large_pred"] == example["correct_index"])
        
        small_acc = small_correct / len(scored_examples)
        large_acc = large_correct / len(scored_examples)
        
        small_avg_cost = sum(example["small_cost"] for example in scored_examples) / len(scored_examples)
        large_avg_cost = sum(example["large_cost"] for example in scored_examples) / len(scored_examples)
        
        small_model_accuracies.append(small_acc)
        large_model_accuracies.append(large_acc)
        small_model_costs.append(small_avg_cost)
        large_model_costs.append(large_avg_cost)
        
        print(f"Trial {trial_idx+1} baseline accuracies and costs:")
        print(f"  Small model: {small_acc:.4f}, Cost: ${small_avg_cost:.6f} per example")
        print(f"  Large model: {large_acc:.4f}, Cost: ${large_avg_cost:.6f} per example")
        
        # Save scored examples for this trial
        scored_examples_file = os.path.join(results_dir, f"scored_examples_trial_{trial_idx+1}.json")
        with open(scored_examples_file, "w") as f:
            json.dump({
                "trial": trial_idx + 1,
                "trial_seed": trial_seed,
                "subject": subject_name,
                "small_model_accuracy": small_acc,
                "small_model_cost": small_avg_cost,
                "large_model_accuracy": large_acc,
                "large_model_cost": large_avg_cost,
                "examples": scored_examples
            }, f, default=custom_json_serializer)
        
        # Shuffle examples before splitting
        random.shuffle(scored_examples)
        
        # Split into calibration and evaluation sets
        if calibration_size >= len(scored_examples):
            print(f"Warning: Requested calibration size {calibration_size} exceeds available data {len(scored_examples)}")
            calibration_size = len(scored_examples) // 2
            
        calibration_examples = scored_examples[:calibration_size]
        test_examples = scored_examples[calibration_size:]
        
        # Save the data split information for reproducibility
        data_split_info = {
            "trial": trial_idx + 1,
            "trial_seed": trial_seed,
            "subject": subject_name,
            "total_examples": len(scored_examples),
            "calibration_size": len(calibration_examples),
            "test_size": len(test_examples),
            "small_model_accuracy": small_acc,
            "small_model_cost": small_avg_cost,
            "large_model_accuracy": large_acc,
            "large_model_cost": large_avg_cost
        }
        
        # Save data split info to file
        split_file = os.path.join(results_dir, f"data_split_trial_{trial_idx+1}.json")
        with open(split_file, "w") as f:
            json.dump(data_split_info, f, default=custom_json_serializer)
        
        print(f"Using {len(calibration_examples)} examples for calibration and {len(test_examples)} for testing")
        
        # Check if we already have calibration data for this trial
        calibration_file = os.path.join(results_dir, f"calibration_data_{method_name}_trial_{trial_idx+1}.json")
        if os.path.exists(calibration_file):
            print(f"Found pre-computed calibration data for trial {trial_idx+1}, loading from {calibration_file}")
            try:
                with open(calibration_file, "r") as f:
                    calibration_data = json.load(f)
                print(f"Loaded {len(calibration_data)} pre-computed calibration examples")
            except Exception as e:
                print(f"Error loading pre-computed calibration data: {e}")
                print("Creating calibration data from scratch...")
                # Extract calibration data in the format needed for conformal prediction
                calibration_data = []
                for example in calibration_examples:
                    calibration_data.append({
                        "context": example["problem"],
                        "choices": example["choices"],
                        "small_scores": example["small_scores"],
                        "large_scores": example["large_scores"],
                        "correct_index": example["correct_index"],
                        "small_cost": example["small_cost"],
                        "large_cost": example["large_cost"]
                    })
                # Save calibration data for future reference
                with open(calibration_file, "w") as f:
                    json.dump(calibration_data, f, default=custom_json_serializer)
        else:
            print(f"Creating calibration data for trial {trial_idx+1}")
            # Extract calibration data in the format needed for conformal prediction
            calibration_data = []
            for example in calibration_examples:
                calibration_data.append({
                    "context": example["problem"],
                    "choices": example["choices"],
                    "small_scores": example["small_scores"],
                    "large_scores": example["large_scores"],
                    "correct_index": example["correct_index"],
                    "small_cost": example["small_cost"],
                    "large_cost": example["large_cost"]
                })
            # Save calibration data for future reference
            with open(calibration_file, "w") as f:
                json.dump(calibration_data, f, default=custom_json_serializer)
        
        # We will use test_examples directly since they already have all the scores we need
        test_data = test_examples
        
        # Check if we already have detailed trial results (which would include thresholds)
        detailed_trial_file = os.path.join(results_dir, f"detailed_trial_{trial_idx+1}_results.json")
        if os.path.exists(detailed_trial_file):
            print(f"Found existing detailed results for trial {trial_idx+1}, attempting to load thresholds...")
            try:
                with open(detailed_trial_file, "r") as f:
                    trial_result = json.load(f)
                
                # Try to extract lambda values from existing results
                if "hybrid_results" in trial_result:
                    standard_lambda_thresholds = {}
                    for alpha in alphas:
                        if alpha in trial_result["hybrid_results"] and "avg_lambda" in trial_result["hybrid_results"][alpha]:
                            # Use the stored lambda value
                            standard_lambda_thresholds[alpha] = trial_result["hybrid_results"][alpha]["avg_lambda"]
                            print(f"Loaded threshold for α={alpha}: {standard_lambda_thresholds[alpha]:.4f}")
                        else:
                            # Compute it if not found
                            lambda_threshold, _ = compute_standard_conformal_threshold(calibration_data, alpha=alpha)
                            standard_lambda_thresholds[alpha] = lambda_threshold
                            print(f"Computed threshold for α={alpha}: {lambda_threshold:.4f}")
                else:
                    # Compute all thresholds if not found
                    standard_lambda_thresholds = {}
                    for alpha in alphas:
                        lambda_threshold, _ = compute_standard_conformal_threshold(calibration_data, alpha=alpha)
                        standard_lambda_thresholds[alpha] = lambda_threshold
                        print(f"Standard lambda threshold for α={alpha}: {lambda_threshold:.4f}")
            except Exception as e:
                print(f"Error loading existing thresholds: {e}")
                print("Computing thresholds from scratch...")
                # Compute thresholds if loading failed
                standard_lambda_thresholds = {}
                for alpha in alphas:
                    lambda_threshold, _ = compute_standard_conformal_threshold(calibration_data, alpha=alpha)
                    standard_lambda_thresholds[alpha] = lambda_threshold
                    print(f"Standard lambda threshold for α={alpha}: {lambda_threshold:.4f}")
        else:
            # Compute thresholds if no file exists
            print("Computing conformal thresholds...")
            standard_lambda_thresholds = {}
            for alpha in alphas:
                lambda_threshold, _ = compute_standard_conformal_threshold(calibration_data, alpha=alpha)
                standard_lambda_thresholds[alpha] = lambda_threshold
                print(f"Standard lambda threshold for α={alpha}: {lambda_threshold:.4f}")
        
        # For each alpha, evaluate hybrid model performance using precomputed scores
        trial_hybrid_results = {}
        
        for alpha in alphas:
            print(f"\nEvaluating hybrid model with alpha = {alpha}")
            
            # Evaluate hybrid model and unrestricted hybrid model (if enabled) on each test example
            hybrid_results = []
            unrestricted_hybrid_results = [] if include_unrestricted_hybrid else None
            
            for example in tqdm(test_data, desc=f"Evaluating hybrid model (α={alpha})"):
                # Extract precomputed scores and data
                small_scores = example["small_scores"]
                large_scores = example["large_scores"]
                correct_index = example["correct_index"]
                small_cost = example["small_cost"]
                large_cost = example["large_cost"]
                choices = example["choices"]
                problem = example["problem"]
                
                # Use precomputed standard lambda threshold
                lambda_threshold = standard_lambda_thresholds[alpha]
                
                # Construct conformal action set
                max_score = max(small_scores)
                conformal_set = [i for i, score in enumerate(small_scores) if score >= max_score - lambda_threshold]
                
                # === STANDARD HYBRID MODEL ===
                # If only one action in conformal set, choose it directly (use small model)
                if len(conformal_set) == 1:
                    hybrid_pred = conformal_set[0]
                    used_large_model = False
                    example_cost = small_cost  # Only small model cost
                else:
                    # We need to simulate getting refined scores from the large model
                    # but only for choices in the conformal set
                    filtered_choices = [choices[i] for i in conformal_set]
                    
                    # Since we've already precomputed all large model scores, we can just filter them
                    # But in a real scenario, we would make a new API call with filtered choices
                    filtered_large_scores = [large_scores[i] for i in conformal_set]
                    
                    # Find best choice
                    best_idx_in_conformal = np.argmax(filtered_large_scores)
                    hybrid_pred = conformal_set[best_idx_in_conformal]
                    used_large_model = True
                    
                    # For cost calculation, we need to estimate what the cost would be
                    # if we had made the API call with only the filtered choices
                    # We'll assume input tokens scale linearly with number of choices,
                    # and output tokens remain constant
                    
                    # Calculate the scaling factor based on number of choices
                    scaling_factor = len(conformal_set) / len(choices)
                    
                    # Calculate the estimated cost of a large model call with filtered choices
                    large_prompt_tokens = example["large_tokens"]["prompt"]
                    large_completion_tokens = example["large_tokens"]["completion"]
                    
                    # Estimate tokens for filtered call
                    filtered_prompt_tokens = int(large_prompt_tokens * (0.5 + 0.5 * scaling_factor))
                    filtered_completion_tokens = large_completion_tokens
                    
                    # Recalculate cost using the pricing structure in the OpenAIWrapper
                    # Get the pricing model
                    pricing_model = large_model.model_name
                    if pricing_model in large_model.model_name_map:
                        pricing_model = large_model.model_name_map[pricing_model]
                    
                    # Calculate cost (convert from per 1M tokens to per token)
                    filtered_input_cost = filtered_prompt_tokens * (large_model.pricing[pricing_model]["input"] / 1000000)
                    filtered_output_cost = filtered_completion_tokens * (large_model.pricing[pricing_model]["output"] / 1000000)
                    filtered_large_cost = filtered_input_cost + filtered_output_cost
                    
                    # Total cost is small model + filtered large model
                    example_cost = small_cost + filtered_large_cost
                
                hybrid_correct = (hybrid_pred == correct_index)
                
                hybrid_results.append({
                    "hybrid_correct": hybrid_correct,
                    "used_large_model": used_large_model,
                    "lambda_value": lambda_threshold,
                    "cost": example_cost,
                    "conformal_set_size": len(conformal_set)
                })
                
                # === UNRESTRICTED HYBRID MODEL (if enabled) ===
                if include_unrestricted_hybrid:
                    # Unrestricted hybrid always uses small model first,
                    # then passes all options to large model if needed
                    # but decides to use large model based on conformal set size
                    
                    if len(conformal_set) == 1:
                        # Same as regular hybrid when conformal set has one element
                        unrestricted_pred = conformal_set[0]
                        unrestricted_used_large = False
                        unrestricted_cost = small_cost
                    else:
                        # Use large model with all choices (already precomputed)
                        unrestricted_pred = np.argmax(large_scores)
                        unrestricted_used_large = True
                        unrestricted_cost = small_cost + large_cost
                    
                    unrestricted_hybrid_correct = (unrestricted_pred == correct_index)
                    
                    unrestricted_hybrid_results.append({
                        "hybrid_correct": unrestricted_hybrid_correct,
                        "used_large_model": unrestricted_used_large,
                        "lambda_value": lambda_threshold,
                        "cost": unrestricted_cost,
                        "conformal_set_size": len(conformal_set)
                    })
            
            # Calculate hybrid model metrics
            hybrid_correct_count = sum(1 for r in hybrid_results if r["hybrid_correct"])
            large_model_used_count = sum(1 for r in hybrid_results if r["used_large_model"])
            avg_lambda = sum(r["lambda_value"] for r in hybrid_results) / len(hybrid_results)
            avg_cost = sum(r["cost"] for r in hybrid_results) / len(hybrid_results)
            
            hybrid_accuracy = hybrid_correct_count / len(test_data)
            large_model_usage = large_model_used_count / len(test_data)
            small_model_usage = 1.0 - large_model_usage
            
            hybrid_accuracies[alpha].append(hybrid_accuracy)
            large_model_calls[alpha].append(large_model_usage)
            lambda_values[alpha].append(avg_lambda)
            hybrid_costs[alpha].append(avg_cost)
            
            print(f"Hybrid model accuracy (α={alpha}): {hybrid_accuracy:.4f}")
            print(f"Large model usage: {large_model_usage:.2%}")
            print(f"Small model usage: {small_model_usage:.2%}")
            print(f"Average lambda: {avg_lambda:.4f}")
            print(f"Average cost per example: ${avg_cost:.6f}")
            
            # Store results for this alpha
            trial_hybrid_results[alpha] = {
                "accuracy": hybrid_accuracy,
                "large_model_usage": large_model_usage,
                "small_model_usage": small_model_usage,
                "avg_lambda": avg_lambda,
                "avg_cost": avg_cost
            }
            
            # Calculate unrestricted hybrid metrics (if enabled)
            if include_unrestricted_hybrid:
                unrestricted_correct_count = sum(1 for r in unrestricted_hybrid_results if r["hybrid_correct"])
                unrestricted_large_model_used_count = sum(1 for r in unrestricted_hybrid_results if r["used_large_model"])
                unrestricted_avg_cost = sum(r["cost"] for r in unrestricted_hybrid_results) / len(unrestricted_hybrid_results)
                
                unrestricted_accuracy = unrestricted_correct_count / len(test_data)
                unrestricted_large_model_usage = unrestricted_large_model_used_count / len(test_data)
                
                unrestricted_hybrid_accuracies[alpha].append(unrestricted_accuracy)
                unrestricted_hybrid_costs[alpha].append(unrestricted_avg_cost)
                
                print(f"Unrestricted hybrid accuracy (α={alpha}): {unrestricted_accuracy:.4f}")
                print(f"Unrestricted large model usage: {unrestricted_large_model_usage:.2%}")
                print(f"Unrestricted avg cost per example: ${unrestricted_avg_cost:.6f}")
                
                # Add unrestricted hybrid results
                trial_hybrid_results[alpha]["unrestricted"] = {
                    "accuracy": unrestricted_accuracy,
                    "large_model_usage": unrestricted_large_model_usage,
                    "avg_cost": unrestricted_avg_cost
                }
            
            # Calculate cost-matched random baseline for conformal method
            if include_random_baseline:
                print(f"Calculating cost-matched random baseline for conformal method (α={alpha})")
                (random_avg, random_std), (random_cost_avg, random_cost_std), sm_fraction = calculate_cost_matched_random_baseline(
                    test_data, avg_cost, small_avg_cost, large_avg_cost, num_samples=10
                )
                random_baseline_accuracies[alpha].append(random_avg)
                random_baseline_costs[alpha].append(random_cost_avg)
                random_baseline_sm_fractions[alpha].append(sm_fraction)
                
                print(f"Random baseline accuracy: {random_avg:.4f} ± {random_std:.4f}")
                print(f"Random baseline avg cost: ${random_cost_avg:.6f} ± ${random_cost_std:.6f}")
                print(f"Small model fraction: {sm_fraction:.2%}")
                
                # Add to trial results
                trial_hybrid_results[alpha]["random_baseline"] = {
                    "accuracy": random_avg,
                    "accuracy_std": random_std,
                    "avg_cost": random_cost_avg,
                    "cost_std": random_cost_std,
                    "small_model_fraction": sm_fraction
                }
                
                # If unrestricted hybrid is enabled, calculate a baseline for it too
                if include_unrestricted_hybrid:
                    print(f"Calculating cost-matched random baseline for unrestricted hybrid (α={alpha})")
                    (unrestricted_random_avg, unrestricted_random_std), (unrestricted_random_cost_avg, unrestricted_random_cost_std), unrestricted_sm_fraction = calculate_cost_matched_random_baseline(
                        test_data, unrestricted_avg_cost, small_avg_cost, large_avg_cost, num_samples=10
                    )
                    unrestricted_random_baseline_accuracies[alpha].append(unrestricted_random_avg)
                    unrestricted_random_baseline_costs[alpha].append(unrestricted_random_cost_avg)
                    unrestricted_random_baseline_sm_fractions[alpha].append(unrestricted_sm_fraction)
                    
                    print(f"Unrestricted random baseline accuracy: {unrestricted_random_avg:.4f} ± {unrestricted_random_std:.4f}")
                    print(f"Unrestricted random baseline avg cost: ${unrestricted_random_cost_avg:.6f} ± ${unrestricted_random_cost_std:.6f}")
                    print(f"Unrestricted small model fraction: {unrestricted_sm_fraction:.2%}")
                    
                    # Add to trial results
                    trial_hybrid_results[alpha]["unrestricted_random_baseline"] = {
                        "accuracy": unrestricted_random_avg,
                        "accuracy_std": unrestricted_random_std,
                        "avg_cost": unrestricted_random_cost_avg,
                        "cost_std": unrestricted_random_cost_std,
                        "small_model_fraction": unrestricted_sm_fraction
                    }
        
        # Save all results for this trial
        trial_results.append({
            "trial": trial_idx + 1,
            "trial_seed": trial_seed,
            "subject": subject_name,
            "method": "standard",
            "calibration_size": len(calibration_examples),
            "test_size": len(test_examples),
            "small_model_accuracy": small_acc,
            "large_model_accuracy": large_acc,
            "small_model_cost": small_avg_cost,
            "large_model_cost": large_avg_cost,
            "hybrid_results": trial_hybrid_results
        })
        
        # Save detailed trial results
        detailed_trial_file = os.path.join(results_dir, f"detailed_trial_{trial_idx+1}_results.json")
        with open(detailed_trial_file, "w") as f:
            json.dump(trial_results[-1], f, default=custom_json_serializer)
    
    # Compute summary statistics across all trials
    small_avg = np.mean(small_model_accuracies)
    small_std = np.std(small_model_accuracies)
    large_avg = np.mean(large_model_accuracies)
    large_std = np.std(large_model_accuracies)
    small_cost_avg = np.mean(small_model_costs)
    small_cost_std = np.std(small_model_costs)
    large_cost_avg = np.mean(large_model_costs)
    large_cost_std = np.std(large_model_costs)
    
    print(f"\n===== Summary Statistics Across {num_trials} Trials =====")
    print(f"Small model: {small_avg:.4f} ± {small_std:.4f}, Cost: ${small_cost_avg:.6f} ± ${small_cost_std:.6f}")
    print(f"Large model: {large_avg:.4f} ± {large_std:.4f}, Cost: ${large_cost_avg:.6f} ± ${large_cost_std:.6f}")
    
    hybrid_avg = {alpha: np.mean(hybrid_accuracies[alpha]) for alpha in alphas}
    hybrid_std = {alpha: np.std(hybrid_accuracies[alpha]) for alpha in alphas}
    calls_avg = {alpha: np.mean(large_model_calls[alpha]) for alpha in alphas}
    calls_std = {alpha: np.std(large_model_calls[alpha]) for alpha in alphas}
    lambda_avg = {alpha: np.mean(lambda_values[alpha]) for alpha in alphas}
    lambda_std = {alpha: np.std(lambda_values[alpha]) for alpha in alphas}
    hybrid_cost_avg = {alpha: np.mean(hybrid_costs[alpha]) for alpha in alphas}
    hybrid_cost_std = {alpha: np.std(hybrid_costs[alpha]) for alpha in alphas}
    
    # Unrestricted hybrid statistics (if enabled)
    if include_unrestricted_hybrid:
        unrestricted_avg = {alpha: np.mean(unrestricted_hybrid_accuracies[alpha]) for alpha in alphas}
        unrestricted_std = {alpha: np.std(unrestricted_hybrid_accuracies[alpha]) for alpha in alphas}
        unrestricted_cost_avg = {alpha: np.mean(unrestricted_hybrid_costs[alpha]) for alpha in alphas}
        unrestricted_cost_std = {alpha: np.std(unrestricted_hybrid_costs[alpha]) for alpha in alphas}
    
    for alpha in alphas:
        print(f"Alpha={alpha}:")
        print(f"  Hybrid accuracy: {hybrid_avg[alpha]:.4f} ± {hybrid_std[alpha]:.4f}")
        print(f"  Large model usage: {calls_avg[alpha]:.2%} ± {calls_std[alpha]:.2%}")
        print(f"  Lambda threshold: {lambda_avg[alpha]:.4f} ± {lambda_std[alpha]:.4f}")
        print(f"  Hybrid cost: ${hybrid_cost_avg[alpha]:.6f} ± ${hybrid_cost_std[alpha]:.6f}")
        
        if include_unrestricted_hybrid:
            print(f"  Unrestricted hybrid accuracy: {unrestricted_avg[alpha]:.4f} ± {unrestricted_std[alpha]:.4f}")
            print(f"  Unrestricted cost: ${unrestricted_cost_avg[alpha]:.6f} ± ${unrestricted_cost_std[alpha]:.6f}")
    
    # Process random baseline data if available
    random_baseline_data = {}
    if include_random_baseline:
        for alpha in alphas:
            if alpha in random_baseline_accuracies and random_baseline_accuracies[alpha]:
                random_avg = np.mean(random_baseline_accuracies[alpha])
                random_std = np.std(random_baseline_accuracies[alpha])
                random_cost_avg = np.mean(random_baseline_costs[alpha])
                random_cost_std = np.std(random_baseline_costs[alpha])
                random_sm_frac_avg = np.mean(random_baseline_sm_fractions[alpha])
                random_sm_frac_std = np.std(random_baseline_sm_fractions[alpha])
                
                random_baseline_data[alpha] = {
                    "avg_accuracy": random_avg,
                    "std_accuracy": random_std,
                    "avg_cost": random_cost_avg,
                    "std_cost": random_cost_std,
                    "avg_small_model_fraction": random_sm_frac_avg,
                    "std_small_model_fraction": random_sm_frac_std
                }
                print(f"  Random baseline (α={alpha}): {random_avg:.4f} ± {random_std:.4f}, Cost: ${random_cost_avg:.6f} ± ${random_cost_std:.6f}")
    
    # Process unrestricted random baseline data if available
    unrestricted_random_baseline_data = {}
    if include_random_baseline and include_unrestricted_hybrid:
        for alpha in alphas:
            if alpha in unrestricted_random_baseline_accuracies and unrestricted_random_baseline_accuracies[alpha]:
                unrestricted_random_avg = np.mean(unrestricted_random_baseline_accuracies[alpha])
                unrestricted_random_std = np.std(unrestricted_random_baseline_accuracies[alpha])
                unrestricted_random_cost_avg = np.mean(unrestricted_random_baseline_costs[alpha])
                unrestricted_random_cost_std = np.std(unrestricted_random_baseline_costs[alpha])
                unrestricted_random_sm_frac_avg = np.mean(unrestricted_random_baseline_sm_fractions[alpha])
                unrestricted_random_sm_frac_std = np.std(unrestricted_random_baseline_sm_fractions[alpha])
                
                unrestricted_random_baseline_data[alpha] = {
                    "avg_accuracy": unrestricted_random_avg,
                    "std_accuracy": unrestricted_random_std,
                    "avg_cost": unrestricted_random_cost_avg,
                    "std_cost": unrestricted_random_cost_std,
                    "avg_small_model_fraction": unrestricted_random_sm_frac_avg,
                    "std_small_model_fraction": unrestricted_random_sm_frac_std
                }
                print(f"  Unrestricted random baseline (α={alpha}): {unrestricted_random_avg:.4f} ± {unrestricted_random_std:.4f}, Cost: ${unrestricted_random_cost_avg:.6f} ± ${unrestricted_random_cost_std:.6f}")
    
    # Compile final results
    final_results = {
        "subject": subject_name,
        "method": "standard",
        "iterations": num_trials,
        "random_seed": random_seed,
        "calibration_size": calibration_size,
        "total_examples": total_examples,
        "small_model": {
            "avg_accuracy": small_avg,
            "std_accuracy": small_std,
            "avg_cost": small_cost_avg,
            "std_cost": small_cost_std,
            "x_position": 1.0,  # 100% small model
            "x_std": 0.0        # No variation in x position
        },
        "large_model": {
            "avg_accuracy": large_avg,
            "std_accuracy": large_std,
            "avg_cost": large_cost_avg,
            "std_cost": large_cost_std,
            "x_position": 0.0,  # 0% small model
            "x_std": 0.0        # No variation in x position
        },
        "hybrid_models": {
            alpha: {
                "avg_accuracy": hybrid_avg[alpha],
                "std_accuracy": hybrid_std[alpha],
                "avg_large_model_usage": calls_avg[alpha],
                "std_large_model_usage": calls_std[alpha],
                "avg_lambda": lambda_avg[alpha],
                "std_lambda": lambda_std[alpha],
                "avg_cost": hybrid_cost_avg[alpha],
                "std_cost": hybrid_cost_std[alpha]
            } for alpha in alphas
        },
        "all_trials": trial_results
    }
    
    # Add unrestricted hybrid data (if enabled)
    if include_unrestricted_hybrid:
        final_results["unrestricted_hybrid"] = {
            alpha: {
                "avg_accuracy": unrestricted_avg[alpha],
                "std_accuracy": unrestricted_std[alpha],
                "avg_cost": unrestricted_cost_avg[alpha],
                "std_cost": unrestricted_cost_std[alpha]
            } for alpha in alphas
        }
    
    # Add random baseline data if available
    if random_baseline_data:
        final_results["random_baseline"] = random_baseline_data
    
    # Add unrestricted random baseline data if available
    if include_unrestricted_hybrid and include_random_baseline and unrestricted_random_baseline_data:
        final_results["unrestricted_random_baseline"] = unrestricted_random_baseline_data
    
    # Save final results to the results directory
    final_file = os.path.join(results_dir, f"truthqa_{subject_name}_{method_name}_final_results_{num_trials}_trials.json")
    with open(final_file, "w") as f:
        json.dump(final_results, f, default=custom_json_serializer)
    
    # Plot the results
    try:
        from plotting import plot_cost_vs_accuracy_simple, plot_enhanced_performance_simple
        
        # Plot cost vs accuracy (simplified)
        cost_plot = plot_cost_vs_accuracy_simple(final_results, output_dir=results_dir)
        print(f"Created cost vs accuracy plot: {cost_plot}")
        
        # Plot enhanced performance (simplified)
        perf_plot = plot_enhanced_performance_simple(final_results, output_dir=results_dir)
        print(f"Created enhanced performance plot: {perf_plot}")
    except ImportError:
        print("Could not import plotting module for generating plots.")
        print("Results are saved to JSON files for later analysis.")
    
    return final_results

# -----------------------------------------------------------------------------
# Main: CLI with run / plot_only support
# -----------------------------------------------------------------------------
def build_truthqa_arg_parser():
    """Build the command-line parser for the TruthfulQA evaluation script.

    Returns:
        argparse.ArgumentParser: parser exposing the run / plot_only options.
        ``include_random`` and ``include_unrestricted`` both default to True
        and are switched off with ``--no_random`` / ``--no_unrestricted``.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate TruthfulQA with conformal alignment and cost tracking"
    )
    parser.add_argument("--api_key", type=str, default=None,
                        help="OpenAI API key (or env OPENAI_API_KEY)")
    parser.add_argument("--num_trials", type=int, default=30,
                        help="Number of independent trials")
    parser.add_argument("--alphas", type=float, nargs="+",
                        default=[0.05, 0.1, 0.15, 0.2, 0.25],
                        help="Risk levels for conformal calibration")
    parser.add_argument("--max_workers", type=int, default=4,
                        help="Parallel workers for scoring")
    parser.add_argument("--calibration_size", type=int, default=400,
                        help="Calibration set size")
    parser.add_argument("--total_examples", type=int, default=684,
                        help="Total TruthfulQA examples to sample")
    parser.add_argument("--random_seed", type=int, default=42,
                        help="Random seed for splitting")
    parser.add_argument("--results_dir", type=str, default="truthqa_results",
                        help="Directory to save results")
    parser.add_argument("--evaluation_mode",
                        choices=["run", "plot_only"],
                        default="run",
                        help="Mode: run (full evaluation) or plot_only (load & plot only)")
    parser.add_argument("--standard_plot",
                        action="store_true",
                        default=False,
                        help="In plot_only mode, use the original plotting style")
    # Both "include" flags default to True, so on their own they are no-ops;
    # they are kept for backward compatibility and paired with a "--no_*"
    # switch that actually lets the user turn the feature off.
    parser.add_argument("--include_random", action="store_true", default=True,
                        help="Also compute a random‐baseline routing")
    parser.add_argument("--no_random", action="store_false", dest="include_random",
                        help="Skip the random-baseline routing")
    parser.add_argument("--include_unrestricted", action="store_true", default=True,
                        help="Include unrestricted hybrid in plots")
    parser.add_argument("--no_unrestricted", action="store_false", dest="include_unrestricted",
                        help="Exclude unrestricted hybrid from plots")
    return parser


def find_latest_final_results(results_dir):
    """Return the path of the most recently modified final-results JSON file.

    A file qualifies when its name contains ``final_results`` and ends with
    ``_trials.json``.

    Args:
        results_dir: directory to scan; may be empty/None or not exist.

    Returns:
        The full path of the newest matching file (by modification time, ties
        broken by path), or None when the directory is missing or holds no
        matching file.
    """
    if not results_dir or not os.path.isdir(results_dir):
        return None
    candidates = [
        os.path.join(results_dir, name)
        for name in os.listdir(results_dir)
        if name.endswith("_trials.json") and "final_results" in name
    ]
    if not candidates:
        return None
    # A lexicographic sort would rank "..._30_trials" after "..._100_trials",
    # so pick by mtime to really get the latest run.
    return max(candidates, key=lambda path: (os.path.getmtime(path), path))


def format_saved_results_summary(final_results):
    """Format the summary statistics of a previously saved final-results dict.

    Args:
        final_results: dict loaded from a final-results JSON file. Must contain
            ``iterations``, ``small_model``, ``large_model`` and
            ``hybrid_models``; ``subject``, ``unrestricted_hybrid`` and
            ``random_baseline`` are optional.

    Returns:
        list[str]: the lines to print, in order (some start with a newline to
        visually separate sections).
    """
    small = final_results["small_model"]
    large = final_results["large_model"]
    lines = [
        f"\n===== Summary Statistics From {final_results['iterations']} Trials =====",
        f"Subject: {final_results.get('subject', 'TruthfulQA')}",
        f"Small model: {small['avg_accuracy']:.4f} ± {small['std_accuracy']:.4f}, "
        f"Cost: ${small['avg_cost']:.6f} ± ${small['std_cost']:.6f}",
        f"Large model: {large['avg_accuracy']:.4f} ± {large['std_accuracy']:.4f}, "
        f"Cost: ${large['avg_cost']:.6f} ± ${large['std_cost']:.6f}",
    ]

    unrestricted_by_alpha = final_results.get("unrestricted_hybrid", {})
    random_by_alpha = final_results.get("random_baseline", {})

    # Alpha keys come back from JSON as strings; they are only used for
    # lookup and display here.
    for alpha, hybrid in final_results["hybrid_models"].items():
        lines.append(f"\nAlpha={alpha}:")
        lines.append(f"  Hybrid accuracy: {hybrid['avg_accuracy']:.4f} ± {hybrid['std_accuracy']:.4f}")
        lines.append(f"  Large model usage: {hybrid['avg_large_model_usage']:.2%} ± {hybrid['std_large_model_usage']:.2%}")
        lines.append(f"  Lambda threshold: {hybrid['avg_lambda']:.4f} ± {hybrid['std_lambda']:.4f}")
        lines.append(f"  Hybrid cost: ${hybrid['avg_cost']:.6f} ± ${hybrid['std_cost']:.6f}")

        if alpha in unrestricted_by_alpha:
            unrestricted = unrestricted_by_alpha[alpha]
            lines.append(f"  Unrestricted hybrid accuracy: {unrestricted['avg_accuracy']:.4f} ± {unrestricted['std_accuracy']:.4f}")
            lines.append(f"  Unrestricted cost: ${unrestricted['avg_cost']:.6f} ± ${unrestricted['std_cost']:.6f}")

        if alpha in random_by_alpha:
            # Named "baseline" rather than "random" to avoid shadowing the
            # imported random module.
            baseline = random_by_alpha[alpha]
            lines.append(
                f"  Random baseline: {baseline['avg_accuracy']:.4f} ± {baseline['std_accuracy']:.4f}, "
                f"Cost: ${baseline['avg_cost']:.6f} ± ${baseline['std_cost']:.6f}"
            )
    return lines


if __name__ == "__main__":
    args = build_truthqa_arg_parser().parse_args()

    # ─── PLOT-ONLY BRANCH ─────────────────────────────────────────────────────
    if args.evaluation_mode == "plot_only":
        if not args.results_dir:
            print("Error: --results_dir must be specified for plot_only mode")
            raise SystemExit(1)
        print(f"Loading and plotting results from {args.results_dir}...")
        # NOTE(review): --standard_plot is parsed but not forwarded here;
        # create_simple_plots is only ever called with the results directory.
        try:
            from plotting import create_simple_plots
            create_simple_plots(args.results_dir)
        except ImportError:
            print("Error: Could not import plotting module for plot_only mode")
            print("Make sure the plotting.py file is available")
            # Signal the failure to the calling shell instead of exiting 0.
            raise SystemExit(1) from None
        raise SystemExit(0)

    # ─── FULL EVALUATION ───────────────────────────────────────────────────────
    # If the results directory already holds a final-results file, show its
    # statistics and plots instead of re-running the evaluation.
    latest_results_path = find_latest_final_results(args.results_dir)
    if latest_results_path is not None:
        print(f"Results directory {args.results_dir} already contains results files.")
        print("Loading existing results to display statistics...")

        with open(latest_results_path, "r", encoding="utf-8") as f:
            final_results = json.load(f)

        print("\n".join(format_saved_results_summary(final_results)))

        # Plotting is best-effort here: the statistics were already shown.
        try:
            from plotting import create_simple_plots
            create_simple_plots(args.results_dir)
        except ImportError:
            print("\nWarning: Could not import plotting module for generating plots.")
            print("Results statistics have been displayed but plots could not be generated.")

        raise SystemExit(0)

    # If no existing results, run full evaluation
    run_conformal_evaluation(
        alphas=args.alphas,
        api_key=args.api_key,
        max_workers=args.max_workers,
        num_trials=args.num_trials,
        calibration_size=args.calibration_size,
        total_examples=args.total_examples,
        random_seed=args.random_seed,
        results_dir=args.results_dir,
        subject=None,  # TruthfulQA has no subject split
        include_random_baseline=args.include_random,
        include_unrestricted_hybrid=args.include_unrestricted
    )