"""
GPQA Diamond Experiment Template

Template for running entropy-based early stopping experiments on GPQA Diamond dataset.
Demonstrates the framework on graduate-level scientific reasoning questions.
"""

import sys
import os
import json
import numpy as np
import pandas as pd
from typing import Dict, List, Optional
import openai
from datetime import datetime

# Add parent directory to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from entropy_framework import EarlyStoppingFramework, EntropyCalculator, ThresholdCalculator

# OpenRouter API configuration. The key is taken from the OPENROUTER_API_KEY
# environment variable when it is set (so real credentials need not be
# committed); otherwise the placeholder below is used and must be replaced.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY") or "your_openrouter_key"
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"

# Sample GPQA Diamond-style problems for demonstration. Each entry holds the
# multiple-choice 'problem' text, the correct letter 'answer', a short
# 'explanation', and the 'subject' used for per-subject analysis.
SAMPLE_GPQA_PROBLEMS = [
    {
        # CH3CH2OCH3 is methoxyethane (ethyl methyl ether), not diethyl ether.
        "problem": "Which of the following compounds would be expected to have the highest boiling point?\n\nA) CH3CH2CH2CH3 (butane)\nB) CH3CH2CH2OH (propanol) \nC) CH3CH2OCH3 (ethyl methyl ether)\nD) CH3CH2CHO (propanal)",
        "answer": "B",
        "explanation": "Propanol has hydrogen bonding due to the OH group, giving it the highest boiling point among these compounds.",
        "subject": "Chemistry"
    },
    {
        # Option A previously read (ħ^2k/m)^(1/3), which does not even have
        # units of energy; the uncertainty-principle estimate for V = kx^4 is
        # E ~ (ħ^4 k / m^2)^(1/3).
        "problem": "A particle of mass m moves in a one-dimensional potential V(x) = kx^4, where k > 0. Using the uncertainty principle, estimate the ground state energy of this system.\n\nA) E ~ (ħ^4k/m^2)^(1/3)\nB) E ~ (ħ^4k/m^3)^(1/5) \nC) E ~ (ħ^2k^2/m)^(1/3)\nD) E ~ (ħ^6k/m^5)^(1/7)",
        "answer": "A",
        "explanation": "Using the uncertainty principle ΔxΔp ≥ ħ/2 (p ~ ħ/x), the total energy is E(x) ~ ħ²/(2mx²) + kx⁴; minimizing over x gives x ~ (ħ²/mk)^(1/6) and E ~ (ħ⁴k/m²)^(1/3).",
        "subject": "Physics"
    },
    {
        "problem": "The enzyme carbonic anhydrase catalyzes the reaction CO₂ + H₂O ⇌ HCO₃⁻ + H⁺. If the pH of blood decreases, what effect would this have on the equilibrium?\n\nA) Shift left, decreasing CO₂ concentration\nB) Shift right, increasing CO₂ concentration  \nC) Shift left, increasing CO₂ concentration\nD) No effect on equilibrium position",
        "answer": "C",
        "explanation": "Lower pH means higher H⁺ concentration. By Le Chatelier's principle, equilibrium shifts left to consume excess H⁺, increasing CO₂.",
        "subject": "Biology"
    },
    {
        "problem": "What is the coordination number of the central metal ion in [Co(NH₃)₄Cl₂]⁺?\n\nA) 4\nB) 5  \nC) 6\nD) 7",
        "answer": "C",
        "explanation": "The coordination number is the total number of ligands directly bonded to the central metal: 4 NH₃ + 2 Cl⁻ = 6.",
        "subject": "Chemistry"
    },
    {
        "problem": "In quantum mechanics, if a particle is in a superposition state |ψ⟩ = α|0⟩ + β|1⟩, what is the probability of measuring the particle in state |1⟩?\n\nA) α\nB) β\nC) |α|²\nD) |β|²",
        "answer": "D",
        "explanation": "In quantum mechanics, the probability of measuring a state is given by the square of the amplitude's magnitude.",
        "subject": "Physics"
    }
]

class GPQAExperiment:
    """
    Main experiment class for GPQA Diamond entropy analysis.

    Runs the complete experiment pipeline on graduate-level scientific reasoning,
    testing cross-domain generalization of entropy-based confidence: multi-step
    generation per problem, answer extraction and grading, per-subject accuracy,
    step-1 entropy statistics and early-stopping threshold analysis.
    """

    # Characters allowed between an answer keyword and the letter, e.g.
    # "Answer: B", "**Final Answer:** (C)", "the answer is [D]".
    _ANSWER_SEP = r'[\s:*(\[]*'

    # Answer-extraction patterns, tried in priority order; within one pattern
    # the LAST match wins (conclusions normally come at the end of the text).
    # Keywords match case-insensitively through scoped ``(?i:...)`` groups, but
    # the letter must be an uppercase A-D followed by a word boundary, so prose
    # like "answer choice" or "option analysis" is not mistaken for an answer.
    _ANSWER_PATTERNS = (
        r'\\boxed\{([ABCDabcd])\}',  # LaTeX boxed format
        r'(?i:final answer)' + _ANSWER_SEP + r'([ABCD])\b',
        r'(?i:answer is)' + _ANSWER_SEP + r'([ABCD])\b',  # "(so/therefore) the answer is X"
        r'(?i:answer)' + _ANSWER_SEP + r'([ABCD])\b',
        r'(?i:choice)' + _ANSWER_SEP + r'([ABCD])\b',
        r'(?i:option)' + _ANSWER_SEP + r'([ABCD])\b',
        r'\b([ABCD])\)?\s*$',  # Letter at the very end of the text
    )

    def __init__(self, model_name: str, api_key: Optional[str] = None):
        """
        Initialize the GPQA experiment.

        Note: this configures the *module-level* ``openai`` settings
        (``openai.api_key`` / ``openai.api_base``), which is the pre-1.0
        ``openai`` SDK style that ``openai.ChatCompletion`` also requires.

        Args:
            model_name: Name of the model to test (passed as ``model`` to the API)
            api_key: API key for the model service; falls back to OPENROUTER_API_KEY
        """
        self.model_name = model_name
        self.api_key = api_key or OPENROUTER_API_KEY
        self.framework = EarlyStoppingFramework()
        self.results = []

        # Configure OpenAI client for OpenRouter
        if self.api_key:
            openai.api_key = self.api_key
            openai.api_base = OPENROUTER_BASE_URL

    @staticmethod
    def _field(obj, name: str):
        """Read ``name`` from a dict-like or attribute-style API object (None if absent)."""
        if obj is None:
            return None
        if isinstance(obj, dict):
            return obj.get(name)
        return getattr(obj, name, None)

    @classmethod
    def _extract_top_logprobs(cls, logprobs_obj, k: int = 20) -> List[List[float]]:
        """
        Convert an API ``logprobs`` payload into one list of top-k logprobs per token.

        Supports the Chat Completions format (``content`` = list of tokens, each
        with ``top_logprobs`` = list of ``{token, logprob}``) and, as a fallback,
        the legacy Completions format (``top_logprobs`` = list of
        ``{token: logprob}`` dicts). Tokens without alternatives are skipped.

        Args:
            logprobs_obj: ``choice.logprobs`` from the API response (may be None)
            k: Maximum number of alternatives kept per token

        Returns:
            List of per-token lists of log-probabilities (empty if unavailable)
        """
        result = []
        if not logprobs_obj:
            return result

        content = cls._field(logprobs_obj, 'content')
        if content:
            for token_entry in content:
                alternatives = cls._field(token_entry, 'top_logprobs') or []
                values = [cls._field(alt, 'logprob') for alt in alternatives[:k]]
                values = [v for v in values if v is not None]
                if values:
                    result.append(values)
            return result

        for token_alternatives in cls._field(logprobs_obj, 'top_logprobs') or []:
            if token_alternatives:
                result.append(list(token_alternatives.values())[:k])
        return result

    def generate_scientific_reasoning(self,
                                      problem: str,
                                      max_tokens: int = 8192,
                                      temperature: float = 0.7) -> Dict:
        """
        Generate scientific reasoning for a GPQA problem.

        Any exception raised while calling the API or parsing the response is
        caught, printed, and reported through an ``'error'`` key. The returned
        dict always has the same core keys, so callers can index it safely.

        Args:
            problem: GPQA problem statement with multiple choice options
            max_tokens: Maximum tokens for generation
            temperature: Sampling temperature

        Returns:
            Dictionary with 'reasoning', 'logprobs' (per-token top-20 lists),
            'model', 'step_tokens' and 'timestamp' (plus 'error' on failure)
        """
        prompt = f"""Solve this graduate-level scientific reasoning problem step by step. 
Analyze each option carefully and provide detailed scientific reasoning.
End with your final answer choice (A, B, C, or D).

Problem: {problem}

Step-by-step solution:"""

        try:
            response = openai.ChatCompletion.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": "You are an expert scientist with graduate-level knowledge across physics, chemistry, and biology. Solve problems with rigorous scientific reasoning."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=max_tokens,
                temperature=temperature,
                # Chat Completions: ``logprobs`` is a boolean switch and the
                # number of alternatives per token (max 20) goes in
                # ``top_logprobs``. The Completions-only ``logprobs=<int>`` /
                # ``echo`` arguments are not valid for this endpoint.
                logprobs=True,
                top_logprobs=20,
            )

            choice = response.choices[0]
            reasoning_text = choice.message.content or ""
            logprobs = self._extract_top_logprobs(self._field(choice, 'logprobs'), k=20)

            return {
                'reasoning': reasoning_text,
                'logprobs': logprobs,
                'model': self.model_name,
                'step_tokens': max_tokens,
                'timestamp': datetime.now().isoformat()
            }

        except Exception as e:
            print(f"Error in reasoning generation: {e}")
            # Keep the same keys as the success path; callers read 'timestamp'.
            return {
                'reasoning': f"Error: {str(e)}",
                'logprobs': [],
                'model': self.model_name,
                'step_tokens': max_tokens,
                'timestamp': datetime.now().isoformat(),
                'error': str(e)
            }

    def extract_answer_choice(self, reasoning: str) -> Optional[str]:
        """
        Extract multiple choice answer (A, B, C, D) from reasoning text.

        Tries ``_ANSWER_PATTERNS`` in order and returns the last match of the
        first pattern that matches. Otherwise it returns the last standalone
        capital A-D in the final three lines.

        Args:
            reasoning: Generated reasoning text

        Returns:
            Extracted answer choice (uppercase) or None
        """
        import re

        if not reasoning:
            return None

        for pattern in self._ANSWER_PATTERNS:
            matches = re.findall(pattern, reasoning)
            if matches:
                return matches[-1].upper()

        # Fallback: standalone capital letter in the last few lines. The text is
        # deliberately NOT upper-cased, otherwise the article "a" becomes "A".
        for line in reversed(reasoning.strip().split('\n')[-3:]):
            letters = re.findall(r'\b([ABCD])\b', line)
            if letters:
                return letters[-1]

        return None

    def evaluate_correctness(self, reasoning: str, correct_answer: str) -> bool:
        """
        Evaluate if reasoning leads to correct answer choice.

        Args:
            reasoning: Generated reasoning text
            correct_answer: Expected correct answer (A, B, C, or D)

        Returns:
            True only if an answer could be extracted and it matches
            ``correct_answer`` (case/whitespace-insensitive)
        """
        extracted_answer = self.extract_answer_choice(reasoning)
        return extracted_answer is not None and extracted_answer == correct_answer.strip().upper()

    def run_sequential_reasoning(self,
                                 problem: Dict,
                                 num_steps: int = 4,
                                 step_tokens: int = 8192) -> Dict:
        """
        Run the multi-step (default 4) sequential reasoning process for GPQA.

        Step 1 solves the problem from scratch. Each later step is shown only
        the immediately preceding step's output and asked to refine it. The
        final step's answer is graded, and step 1 is graded separately for
        entropy analysis.

        Args:
            problem: Problem dict with 'problem' and 'answer' keys and optional
                'subject' / 'explanation' keys
            num_steps: Number of reasoning steps (must be >= 1)
            step_tokens: max_tokens budget per step

        Returns:
            Complete reasoning sequence data

        Raises:
            ValueError: If num_steps < 1
        """
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")

        print(f"  Running {num_steps}-step scientific reasoning...")

        step_results = []
        previous_reasoning = ""

        for step in range(num_steps):
            print(f"    Step {step+1}/{num_steps}")

            if step == 0:
                # First step: solve from scratch
                step_problem = problem['problem']
            else:
                # Subsequent steps: refine previous reasoning
                step_problem = f"""Previous scientific analysis:
{previous_reasoning}

Please review and improve the above analysis for this scientific problem:
{problem['problem']}

Provide a more rigorous and detailed scientific solution:"""

            step_result = self.generate_scientific_reasoning(
                step_problem,
                max_tokens=step_tokens,
                temperature=0.7
            )
            reasoning = step_result.get('reasoning') or ""

            step_results.append({
                'step': step + 1,
                'reasoning': reasoning,
                'logprobs': step_result.get('logprobs', []),
                # Whitespace-separated word count, a cheap proxy for token usage.
                'tokens': len(reasoning.split()),
                'timestamp': step_result.get('timestamp', datetime.now().isoformat())
            })

            previous_reasoning = reasoning

        # Evaluate final answer
        final_reasoning = step_results[-1]['reasoning']
        extracted_answer = self.extract_answer_choice(final_reasoning)
        is_correct = self.evaluate_correctness(final_reasoning, problem['answer'])
        first_step = step_results[0]

        return {
            'problem_statement': problem['problem'],
            'correct_answer': problem['answer'],
            'extracted_answer': extracted_answer,
            'is_correct': is_correct,
            'subject': problem.get('subject', 'Unknown'),
            'explanation': problem.get('explanation', ''),
            'step_results': step_results,
            'total_tokens': sum(s['tokens'] for s in step_results),
            'step1_entropy_data': {
                'logprobs': first_step['logprobs'],
                'reasoning': first_step['reasoning'],
                'is_correct_step1': self.evaluate_correctness(first_step['reasoning'], problem['answer'])
            }
        }

    def analyze_by_subject(self, experiment_data: List[Dict]) -> Dict:
        """
        Analyze results by scientific subject area.

        Args:
            experiment_data: List of problem results from run_sequential_reasoning

        Returns:
            Mapping subject -> {'problems', 'total', 'step1_correct',
            'final_correct', 'step1_accuracy', 'final_accuracy', 'improvement'}
        """
        print("Analyzing results by scientific subject...")

        subject_analysis = {}

        for data in experiment_data:
            subject = data.get('subject', 'Unknown')
            entry = subject_analysis.setdefault(subject, {
                'problems': [],
                'step1_correct': 0,
                'final_correct': 0,
                'total': 0
            })
            entry['problems'].append(data)
            entry['total'] += 1
            if data['step1_entropy_data']['is_correct_step1']:
                entry['step1_correct'] += 1
            if data['is_correct']:
                entry['final_correct'] += 1

        # Calculate subject-wise metrics
        for entry in subject_analysis.values():
            total = entry['total']
            entry['step1_accuracy'] = entry['step1_correct'] / total if total > 0 else 0
            entry['final_accuracy'] = entry['final_correct'] / total if total > 0 else 0
            entry['improvement'] = entry['final_accuracy'] - entry['step1_accuracy']

        return subject_analysis

    @staticmethod
    def _cohens_d(correct: List[float], incorrect: List[float]) -> float:
        """
        Absolute Cohen's d between two groups.

        Pooled std is the root-mean of the two population (ddof=0) variances.
        Returns 0.0 when the pooled std is zero.
        """
        pooled_std = np.sqrt((np.var(correct) + np.var(incorrect)) / 2)
        if not pooled_std > 0:
            return 0.0
        return float(abs(np.mean(incorrect) - np.mean(correct)) / pooled_std)

    def calculate_entropy_statistics(self, experiment_data: List[Dict]) -> Dict:
        """
        Calculate comprehensive entropy statistics for GPQA.

        Step-1 sequence entropy is computed for every problem that returned
        logprobs. The values are split into correct/incorrect groups, overall
        and per subject.

        Args:
            experiment_data: List of problem results

        Returns:
            Entropy statistics and cross-domain analysis, or {} if either the
            correct or the incorrect group is empty
        """
        print("Calculating entropy statistics for scientific reasoning...")

        entropy_calc = EntropyCalculator(k=20)

        correct_entropies = []
        incorrect_entropies = []
        subject_entropies = {}

        for data in experiment_data:
            step1_data = data['step1_entropy_data']
            if not step1_data['logprobs']:
                continue  # No logprobs (e.g. API error): nothing to measure.

            entropy = entropy_calc.calculate_sequence_entropy(step1_data['logprobs'])

            # NOTE(review): step-1 entropy is labelled with the FINAL (post-
            # refinement) correctness, not step1_data['is_correct_step1'];
            # confirm this pairing is intended for early-stopping evaluation.
            bucket = 'correct' if data['is_correct'] else 'incorrect'
            (correct_entropies if bucket == 'correct' else incorrect_entropies).append(entropy)

            subject = data.get('subject', 'Unknown')
            subject_entropies.setdefault(subject, {'correct': [], 'incorrect': []})[bucket].append(entropy)

        if not correct_entropies or not incorrect_entropies:
            print("Warning: Insufficient data for entropy analysis")
            return {}

        correct_mean = float(np.mean(correct_entropies))
        incorrect_mean = float(np.mean(incorrect_entropies))
        cohens_d = self._cohens_d(correct_entropies, incorrect_entropies)

        # Statistical significance test
        from scipy.stats import ttest_ind
        t_stat, p_value = ttest_ind(correct_entropies, incorrect_entropies)
        t_stat, p_value = float(t_stat), float(p_value)

        subject_cohens_d = {
            subject: self._cohens_d(groups['correct'], groups['incorrect'])
            for subject, groups in subject_entropies.items()
            if groups['correct'] and groups['incorrect']
        }

        # Plain Python floats/bools so json.dump emits numbers/booleans, not
        # str()-ed numpy scalars.
        return {
            'overall_statistics': {
                'correct_entropies': {
                    'mean': correct_mean,
                    'std': float(np.std(correct_entropies)),
                    'count': len(correct_entropies),
                    'values': correct_entropies
                },
                'incorrect_entropies': {
                    'mean': incorrect_mean,
                    'std': float(np.std(incorrect_entropies)),
                    'count': len(incorrect_entropies),
                    'values': incorrect_entropies
                },
                'effect_size': {
                    'cohens_d': cohens_d,
                    'interpretation': self._interpret_cohens_d(cohens_d)
                },
                'statistical_test': {
                    't_statistic': t_stat,
                    'p_value': p_value,
                    'significant': bool(p_value < 0.05)
                }
            },
            'subject_wise_analysis': {
                'entropy_data': subject_entropies,
                'cohens_d_by_subject': subject_cohens_d
            }
        }

    def _interpret_cohens_d(self, cohens_d: float) -> str:
        """Interpret Cohen's d with the conventional 0.2 / 0.5 / 0.8 cut-offs."""
        if cohens_d < 0.2:
            return "negligible"
        elif cohens_d < 0.5:
            return "small"
        elif cohens_d < 0.8:
            return "medium"
        else:
            return "large"

    def run_threshold_analysis(self, entropy_stats: Dict, num_steps: int = 4) -> Dict:
        """
        Run comprehensive threshold analysis for GPQA.

        A problem "stops early" when its step-1 entropy is <= the threshold.
        Token savings assume stopping after step 1 skips the remaining
        ``num_steps - 1`` of ``num_steps`` equally sized steps. The default of 4
        reproduces the former fixed 75%.

        Args:
            entropy_stats: Output of calculate_entropy_statistics
            num_steps: Number of reasoning steps the experiment ran

        Returns:
            Mapping threshold method -> threshold, stopping, efficiency and
            classification metrics ({} if no entropy statistics)
        """
        print("Running threshold analysis for scientific reasoning...")

        overall_stats = entropy_stats.get('overall_statistics', {})
        if not overall_stats:
            return {}

        threshold_calc = ThresholdCalculator()
        correct_entropies = overall_stats['correct_entropies']['values']
        incorrect_entropies = overall_stats['incorrect_entropies']['values']

        methods = {
            'entropy_mean': threshold_calc.entropy_mean_threshold(correct_entropies),
            'information_theoretic': threshold_calc.information_theoretic_optimal(
                correct_entropies, incorrect_entropies),
            'bayesian': threshold_calc.bayesian_optimal(
                correct_entropies, incorrect_entropies),
            'scale_invariant': threshold_calc.scale_invariant_universal(
                correct_entropies, incorrect_entropies)
        }

        stop_savings_fraction = (num_steps - 1) / num_steps if num_steps > 0 else 0.0
        total_correct = len(correct_entropies)
        total_incorrect = len(incorrect_entropies)
        total_problems = total_correct + total_incorrect

        threshold_analysis = {}
        for method, threshold in methods.items():
            correct_stops = sum(1 for e in correct_entropies if e <= threshold)
            incorrect_stops = sum(1 for e in incorrect_entropies if e <= threshold)
            total_stops = correct_stops + incorrect_stops

            correct_stop_rate = correct_stops / total_correct if total_correct > 0 else 0
            incorrect_stop_rate = incorrect_stops / total_incorrect if total_incorrect > 0 else 0
            threshold_accuracy = correct_stops / total_stops if total_stops > 0 else 0
            total_stop_rate = total_stops / total_problems if total_problems > 0 else 0
            token_savings = total_stop_rate * stop_savings_fraction

            # Early stopping as binary classification: "stop" = predicted correct.
            true_positives = correct_stops
            false_positives = incorrect_stops
            false_negatives = total_correct - correct_stops

            precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
            recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
            f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

            threshold_analysis[method] = {
                'threshold': threshold,
                'stopping_metrics': {
                    'correct_stop_rate': correct_stop_rate,
                    'incorrect_stop_rate': incorrect_stop_rate,
                    'total_stop_rate': total_stop_rate,
                    'threshold_accuracy': threshold_accuracy
                },
                'efficiency_metrics': {
                    'token_savings': token_savings,
                    'total_stops': total_stops,
                    'correct_stops': correct_stops,
                    'incorrect_stops': incorrect_stops
                },
                'classification_metrics': {
                    'precision': precision,
                    'recall': recall,
                    'f1_score': f1_score
                }
            }

        return threshold_analysis

    def run_full_experiment(self,
                            problems: Optional[List[Dict]] = None,
                            num_steps: int = 4,
                            step_tokens: int = 8192) -> Dict:
        """
        Run complete GPQA Diamond experiment and save it as JSON.

        The results are written to
        ``gpqa_results_<model>_<YYYYmmdd_HHMMSS>.json`` in the current
        working directory.

        Args:
            problems: Problems to test (uses SAMPLE_GPQA_PROBLEMS if None)
            num_steps: Number of reasoning steps
            step_tokens: Tokens per reasoning step

        Returns:
            Complete experiment results
        """
        if problems is None:
            problems = SAMPLE_GPQA_PROBLEMS

        print("Running GPQA Diamond Experiment")
        print(f"Model: {self.model_name}")
        print(f"Problems: {len(problems)}")
        print(f"Steps: {num_steps}, Tokens per step: {step_tokens}")
        print("="*60)

        experiment_data = []
        for i, problem in enumerate(problems):
            print(f"\nProblem {i+1}/{len(problems)} [{problem.get('subject', 'Unknown')}]:")
            print(f"  {problem['problem'][:80]}...")

            problem_result = self.run_sequential_reasoning(problem, num_steps, step_tokens)
            problem_result['problem_id'] = i
            experiment_data.append(problem_result)

            print(f"  Result: {'✓' if problem_result['is_correct'] else '✗'} "
                  f"(Answer: {problem_result['extracted_answer']} vs {problem['answer']})")

        print("\nCalculating experiment statistics...")

        total_problems = len(experiment_data)
        step1_correct = sum(1 for d in experiment_data if d['step1_entropy_data']['is_correct_step1'])
        final_correct = sum(1 for d in experiment_data if d['is_correct'])

        step1_accuracy = step1_correct / total_problems if total_problems > 0 else 0
        final_accuracy = final_correct / total_problems if total_problems > 0 else 0

        subject_analysis = self.analyze_by_subject(experiment_data)
        entropy_stats = self.calculate_entropy_statistics(experiment_data)
        threshold_analysis = self.run_threshold_analysis(entropy_stats, num_steps=num_steps)

        results = {
            'experiment_info': {
                'model': self.model_name,
                'dataset': 'GPQA_Diamond',
                'num_problems': total_problems,
                'num_steps': num_steps,
                'step_tokens': step_tokens,
                'subjects_tested': list(subject_analysis.keys()),
                'timestamp': datetime.now().isoformat()
            },
            'accuracy_metrics': {
                'step1_accuracy': step1_accuracy,
                'final_accuracy': final_accuracy,
                'accuracy_improvement': final_accuracy - step1_accuracy,
                'step1_correct': step1_correct,
                'final_correct': final_correct,
                'total_problems': total_problems
            },
            'subject_wise_results': subject_analysis,
            'entropy_statistics': entropy_stats,
            'threshold_analysis': threshold_analysis,
            'problem_data': experiment_data,
            'summary': self._generate_experiment_summary(
                entropy_stats, threshold_analysis, subject_analysis,
                step1_accuracy, final_accuracy
            )
        }

        output_file = f"gpqa_results_{self.model_name.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        with open(output_file, 'w', encoding='utf-8') as f:
            # default=str: any remaining non-JSON types (e.g. thresholds from
            # ThresholdCalculator) are stringified rather than failing the dump.
            json.dump(results, f, indent=2, default=str)

        print(f"\nResults saved to: {output_file}")

        return results

    def _generate_experiment_summary(self,
                                     entropy_stats: Dict,
                                     threshold_analysis: Dict,
                                     subject_analysis: Dict,
                                     step1_acc: float,
                                     final_acc: float) -> Dict:
        """
        Generate a human-readable experiment summary (percentages as strings).

        The "best" threshold method is the one with the highest token savings.
        On ties, the first such method wins. If every method saves nothing,
        best_method is None.
        """
        summary = {
            'overall_performance': {
                'step1_accuracy': f"{step1_acc:.1%}",
                'final_accuracy': f"{final_acc:.1%}",
                'accuracy_improvement': f"{final_acc - step1_acc:.1%}"
            }
        }

        if subject_analysis:
            summary['subject_performance'] = {
                subject: {
                    'problems': data['total'],
                    'step1_accuracy': f"{data['step1_accuracy']:.1%}",
                    'final_accuracy': f"{data['final_accuracy']:.1%}",
                    'improvement': f"{data['improvement']:.1%}"
                }
                for subject, data in subject_analysis.items()
            }

        overall_stats = entropy_stats.get('overall_statistics', {})
        if overall_stats:
            effect_size = overall_stats.get('effect_size', {})
            summary['entropy_analysis'] = {
                'cohens_d': effect_size.get('cohens_d', 0),
                'effect_interpretation': effect_size.get('interpretation', 'unknown'),
                'entropy_separation_significant': overall_stats.get('statistical_test', {}).get('significant', False)
            }

            # Cross-domain entropy consistency
            subject_cohens = entropy_stats.get('subject_wise_analysis', {}).get('cohens_d_by_subject', {})
            if subject_cohens:
                effect_values = list(subject_cohens.values())
                summary['cross_domain_consistency'] = {
                    'subject_effect_sizes': subject_cohens,
                    'min_cohens_d': min(effect_values),
                    'max_cohens_d': max(effect_values),
                    'mean_cohens_d': float(np.mean(effect_values))
                }

        if threshold_analysis:
            best_method = None
            best_savings = 0
            best_accuracy = 0
            best_f1 = 0

            for method, analysis in threshold_analysis.items():
                savings = analysis['efficiency_metrics']['token_savings']
                if savings > best_savings:
                    best_method = method
                    best_savings = savings
                    best_accuracy = analysis['stopping_metrics']['threshold_accuracy']
                    best_f1 = analysis['classification_metrics']['f1_score']

            summary['threshold_performance'] = {
                'best_method': best_method,
                'best_token_savings': f"{best_savings:.1%}",
                'best_threshold_accuracy': f"{best_accuracy:.1%}",
                'best_f1_score': f"{best_f1:.3f}"
            }

        return summary

def _print_section(title: str, items: Dict, skip: tuple = ()) -> None:
    """
    Print ``title`` then one indented ``Key: value`` line per item.

    Keys are humanised by replacing underscores with spaces and title-casing
    (``best_f1_score`` -> ``Best F1 Score``). Keys listed in ``skip`` are omitted.
    """
    print(f"\n{title}:")
    for key, value in items.items():
        if key not in skip:
            print(f"  {key.replace('_', ' ').title()}: {value}")


def main():
    """
    Parse CLI arguments, run the GPQA Diamond experiment and print its summary.

    Raises:
        SystemExit: Via argparse when an argument is invalid, e.g. a
            non-positive --num-problems, --steps or --step-tokens.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Run GPQA Diamond entropy experiment')
    parser.add_argument('--model', default='gpt-4',
                        help='Model name to test')
    parser.add_argument('--api-key', default=None,
                        help='API key for the model service')
    parser.add_argument('--num-problems', type=int, default=None,
                        help='Number of problems to test')
    parser.add_argument('--steps', type=int, default=4,
                        help='Number of reasoning steps')
    parser.add_argument('--step-tokens', type=int, default=8192,
                        help='Tokens per reasoning step')

    args = parser.parse_args()

    # Reject values that would otherwise be silently misinterpreted (0 used to
    # mean "all problems", negatives sliced from the end) or crash later.
    if args.num_problems is not None and args.num_problems < 1:
        parser.error("--num-problems must be a positive integer")
    if args.steps < 1:
        parser.error("--steps must be at least 1")
    if args.step_tokens < 1:
        parser.error("--step-tokens must be at least 1")

    experiment = GPQAExperiment(args.model, args.api_key)

    problems = SAMPLE_GPQA_PROBLEMS
    if args.num_problems is not None:
        problems = problems[:args.num_problems]

    results = experiment.run_full_experiment(
        problems, args.steps, args.step_tokens
    )

    # Print comprehensive summary
    print("\n" + "="*70)
    print("GPQA DIAMOND EXPERIMENT SUMMARY")
    print("="*70)

    summary = results['summary']

    _print_section("OVERALL PERFORMANCE", summary['overall_performance'])

    # Subject-wise performance is nested one level deeper than other sections.
    if 'subject_performance' in summary:
        print("\nSUBJECT-WISE PERFORMANCE:")
        for subject, metrics in summary['subject_performance'].items():
            print(f"  {subject}:")
            for metric, value in metrics.items():
                print(f"    {metric.replace('_', ' ').title()}: {value}")

    if 'entropy_analysis' in summary:
        _print_section("ENTROPY ANALYSIS", summary['entropy_analysis'])

    if 'cross_domain_consistency' in summary:
        # The per-subject dict is too verbose for the console summary.
        _print_section("CROSS-DOMAIN CONSISTENCY", summary['cross_domain_consistency'],
                       skip=('subject_effect_sizes',))

    if 'threshold_performance' in summary:
        _print_section("THRESHOLD PERFORMANCE", summary['threshold_performance'])

if __name__ == "__main__":
    main()