"""
AIME'24 Experiment Template

Template for running entropy-based early-stopping experiments on the AIME'24 dataset.
Demonstrates the framework on mathematical competition problems.
"""

import json
import os
import re
import sys
from datetime import datetime
from typing import Dict, List, Optional

import numpy as np
import openai
import pandas as pd

# Add parent directory to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from entropy_framework import EarlyStoppingFramework, EntropyCalculator, ThresholdCalculator

# OpenRouter API configuration.
# The key is read from the OPENROUTER_API_KEY environment variable so real
# credentials do not have to be committed to source; the placeholder string is
# kept as the fallback (replace it or set the variable).
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "your_openrouter_key")
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"

# Sample competition-style problems for demonstration.
# NOTE(review): these are illustrative placeholders, not verified AIME 2024
# problems; replace them with the real dataset before drawing conclusions.
SAMPLE_AIME24_PROBLEMS = [
    {
        "problem": "Let $p$ be the least prime factor of $1009^2 + 1$. Find $p$.",
        # 1009^2 + 1 = 1018082 is even, so the least prime factor is 2
        # (1009 itself can never divide 1009^2 + 1: the remainder is 1).
        "answer": "2",
        "explanation": "$1009^2 + 1 = 1018082$ is even, so its least prime factor is $2$."
    },
    {
        "problem": "Let $ABCD$ be a convex quadrilateral with $AB = 10$, $BC = 14$, $CD = 20$, and $DA = 12$. Suppose that the diagonals of $ABCD$ intersect at point $P$, and that the sum of the areas of triangles $APB$ and $CPD$ is equal to the sum of the areas of triangles $BPC$ and $DPA$. Find the area of quadrilateral $ABCD$.",
        # NOTE(review): the area condition forces P to bisect one diagonal; a
        # Heron's-formula check of either case does not appear to yield 156 —
        # verify this answer before relying on it.
        "answer": "156",
        "explanation": "Using properties of quadrilateral areas and diagonal intersections."
    },
    {
        # NOTE(review): the statement never specifies which squares are red, so
        # it is not well-posed as written; the answer "1680" is unverified.
        "problem": "Find the number of ways to place 8 non-attacking rooks on the squares of an $8 \\times 8$ chessboard that are colored red.",
        "answer": "1680",
        "explanation": "This is a combinatorial problem involving non-attacking rook placements."
    }
]

class AIModel:
    """
    AI Model interface for generating reasoning with logprobs.

    Wraps the module-level ``openai`` client (``openai.ChatCompletion`` /
    ``openai.api_base``, i.e. the pre-1.0 SDK interface) pointed at OpenRouter.
    It generates a solution together with per-token top-k log probabilities
    for entropy analysis, and provides a heuristic answer checker.
    """

    # Integers in free text, optionally with thousands separators ("1,680").
    _INT_RE = re.compile(r"-?\d+(?:,\d{3})*")
    # A gold answer is treated as numeric only when it is entirely an integer.
    _GOLD_INT_RE = re.compile(r"[-+]?\d+(?:,\d{3})*")
    _BOXED_PREFIX = "\\boxed{"

    def __init__(self, model_name: str, api_key: Optional[str] = None):
        """
        Initialize the model interface.

        Args:
            model_name: Name of the model to use
            api_key: API key for the service; falls back to OPENROUTER_API_KEY
        """
        self.model_name = model_name
        self.api_key = api_key or OPENROUTER_API_KEY

        # Configure the (module-global) OpenAI client for OpenRouter. Test the
        # resolved key, not the argument, so the OPENROUTER_API_KEY fallback is
        # actually applied when no key is passed explicitly.
        if self.api_key:
            openai.api_key = self.api_key
            openai.api_base = OPENROUTER_BASE_URL

    def generate_reasoning(self,
                           problem: str,
                           max_tokens: int = 8192,
                           temperature: float = 0.7,
                           top_logprobs: int = 20) -> Dict:
        """
        Generate reasoning for a problem with logprobs.

        Args:
            problem: Problem statement
            max_tokens: Maximum tokens for generation
            temperature: Sampling temperature
            top_logprobs: Number of alternative tokens whose log probabilities
                are requested per generated token (the Chat API allows <= 20)

        Returns:
            Dictionary with keys 'reasoning' (str), 'logprobs' (one list of up
            to ``top_logprobs`` floats per generated token), 'model' and
            'timestamp'. On any failure the error is printed and a fallback
            dict with empty 'logprobs' and an 'error' message is returned.
        """
        try:
            # Create reasoning prompt
            prompt = f"""Solve this mathematical problem step by step. Show all your work and reasoning clearly.

Problem: {problem}

Solution:"""

            # Chat Completions take a boolean ``logprobs`` flag plus an integer
            # ``top_logprobs``; ``logprobs=<int>`` and ``echo`` belong to the
            # legacy Completions endpoint and are not valid here.
            response = openai.ChatCompletion.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": "You are an expert mathematician. Solve problems step by step with clear reasoning."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=max_tokens,
                temperature=temperature,
                logprobs=True,
                top_logprobs=top_logprobs,
            )

            choice = response.choices[0]
            # ``content`` may be None (e.g. an empty completion); keep it a str
            # so downstream answer checking never crashes.
            reasoning_text = choice.message.content or ""
            logprobs = self._extract_top_logprobs(choice, top_logprobs)

            return {
                'reasoning': reasoning_text,
                'logprobs': logprobs,
                'model': self.model_name,
                'timestamp': datetime.now().isoformat()
            }

        except Exception as e:
            # Best-effort boundary: one failed call must not abort the whole
            # experiment, so report it and return an empty result instead.
            print(f"Error generating reasoning: {e}")
            return {
                'reasoning': "Error generating reasoning",
                'logprobs': [],
                'model': self.model_name,
                'error': str(e),
                'timestamp': datetime.now().isoformat()
            }

    @staticmethod
    def _field(obj, name):
        """Read ``name`` from a dict-style or attribute-style response object."""
        if obj is None:
            return None
        if isinstance(obj, dict):
            return obj.get(name)
        return getattr(obj, name, None)

    @classmethod
    def _extract_top_logprobs(cls, choice, top_k: int = 20) -> List[List[float]]:
        """
        Collect, per generated token, up to ``top_k`` alternative log probabilities.

        Understands the Chat Completions shape
        (``logprobs.content[i].top_logprobs[j].logprob``) and falls back to
        the legacy Completions shape (``logprobs.top_logprobs``: a list of
        ``{token: logprob}`` dicts). Tokens without alternatives are skipped;
        returns [] when the choice carries no logprobs.
        """
        logprobs_obj = cls._field(choice, 'logprobs')
        if not logprobs_obj:
            return []

        result = []
        content = cls._field(logprobs_obj, 'content')
        if content:
            for token_info in content:
                alternatives = cls._field(token_info, 'top_logprobs') or []
                values = [cls._field(alt, 'logprob') for alt in alternatives[:top_k]]
                values = [float(v) for v in values if v is not None]
                if values:
                    result.append(values)
            return result

        for token_alternatives in cls._field(logprobs_obj, 'top_logprobs') or []:
            if token_alternatives:
                result.append([float(v) for v in list(token_alternatives.values())[:top_k]])
        return result

    @classmethod
    def _last_boxed(cls, text: str) -> Optional[str]:
        """Return the contents of the last ``\\boxed{...}`` in ``text`` (brace-balanced), or None."""
        start = text.rfind(cls._BOXED_PREFIX)
        if start == -1:
            return None
        begin = start + len(cls._BOXED_PREFIX)
        depth = 1
        for pos in range(begin, len(text)):
            ch = text[pos]
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    return text[begin:pos].strip()
        return None  # unbalanced braces: treat as no boxed answer

    @classmethod
    def _last_int(cls, text: str) -> Optional[int]:
        """Return the last integer appearing in ``text`` (commas removed), or None."""
        matches = cls._INT_RE.findall(text)
        return int(matches[-1].replace(',', '')) if matches else None

    def evaluate_answer(self, reasoning: str, correct_answer: str) -> bool:
        """
        Evaluate if the reasoning leads to the correct answer.

        The final answer is taken from the last ``\\boxed{...}`` if present,
        otherwise (for integer gold answers) from the last integer in the
        text. A plain substring test is not used for numeric answers because
        it yields false positives: "1" matches "10", and numbers from the
        restated problem (e.g. "1009" in "$1009^2 + 1$") would always match.

        Args:
            reasoning: Generated reasoning text (None/empty counts as wrong)
            correct_answer: Expected correct answer

        Returns:
            Whether the answer is correct
        """
        if not reasoning or correct_answer is None:
            return False
        gold = str(correct_answer).strip()
        candidate = self._last_boxed(reasoning)

        if self._GOLD_INT_RE.fullmatch(gold):
            # Integer answers (AIME style): compare numerically so "007" == "7".
            predicted = self._last_int(candidate if candidate is not None else reasoning)
            return predicted is not None and predicted == int(gold.replace(',', ''))

        if candidate is not None:
            return candidate.lower() == gold.lower()
        # Non-numeric gold answer with no boxed answer: fall back to substring match.
        return bool(gold) and gold.lower() in reasoning.lower()

class AIME24Experiment:
    """
    Main experiment class for AIME'24 entropy analysis.

    Runs the complete experiment pipeline: data generation, calibration-data
    collection, threshold calibration, evaluation, and saving results as JSON.
    """

    # Threshold methods passed to EarlyStoppingFramework.calibrate / should_stop_early.
    THRESHOLD_METHODS = ("entropy_mean", "information_theoretic", "bayesian", "scale_invariant")

    def __init__(self, model_name: str, api_key: Optional[str] = None):
        """
        Initialize the experiment.

        Args:
            model_name: Name of the model to test
            api_key: API key for the model service
        """
        self.model = AIModel(model_name, api_key)
        self.framework = EarlyStoppingFramework()
        # NOTE(review): not read or written by any method in this class.
        self.results = []

    def run_full_experiment(self,
                            problems: Optional[List[Dict]] = None,
                            num_steps: int = 4,
                            step_tokens: int = 8192) -> Dict:
        """
        Run the complete AIME'24 experiment.

        Args:
            problems: List of problems to test (uses samples if None)
            num_steps: Number of reasoning steps (4-step sequential); must be >= 1
            step_tokens: Tokens per reasoning step

        Returns:
            Complete experiment results (also written to a timestamped JSON file)

        Raises:
            ValueError: If ``num_steps`` is less than 1.
        """
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")
        if problems is None:
            problems = SAMPLE_AIME24_PROBLEMS

        print(f"Running AIME'24 experiment with {len(problems)} problems")
        print(f"Model: {self.model.model_name}")
        print(f"Steps: {num_steps}, Tokens per step: {step_tokens}")

        # Step 1: Generate reasoning for all problems
        print("\n1. Generating reasoning sequences...")
        all_data = []
        for i, problem in enumerate(problems):
            all_data.append(
                self._generate_problem_data(i, problem, len(problems), num_steps, step_tokens)
            )

        # Step 2: Keep step-1 logprobs (labelled by final correctness) as calibration data
        print("\n2. Calculating entropies...")
        calibration_data = [
            {'logprobs': data['step1_logprobs'], 'correct': data['is_correct']}
            for data in all_data
            if data['step1_logprobs']  # only problems for which logprobs came back
        ]
        skipped = len(all_data) - len(calibration_data)
        if skipped:
            print(f"  Skipped {skipped} problem(s) with no logprobs")

        # Step 3: Calibrate thresholds
        print("\n3. Calibrating thresholds...")
        calibration_results = self._calibrate_all(calibration_data)

        # Step 4: Evaluate performance
        # NOTE(review): evaluation reuses the calibration data, so these
        # metrics are in-sample.
        print("\n4. Evaluating performance...")
        # Stopping after step 1 skips the remaining num_steps - 1 steps
        # (0.75 of the token budget for the default 4 steps).
        stop_savings = (num_steps - 1) / num_steps
        evaluation_results = self._evaluate_all(calibration_data, calibration_results, stop_savings)

        # Compile final results
        final_results = {
            'experiment_info': {
                'model': self.model.model_name,
                'dataset': 'AIME_24',
                'num_problems': len(problems),
                'num_steps': num_steps,
                'step_tokens': step_tokens,
                'timestamp': datetime.now().isoformat()
            },
            'problem_data': all_data,
            'calibration_results': calibration_results,
            'evaluation_results': evaluation_results,
            'summary': self._generate_summary(calibration_results, evaluation_results)
        }

        output_file = self._save_results(final_results)
        print(f"\n5. Results saved to {output_file}")

        return final_results

    def _generate_problem_data(self, index: int, problem: Dict, total: int,
                               num_steps: int, step_tokens: int) -> Dict:
        """Generate ``num_steps`` reasoning results for one problem and score the last one."""
        print(f"Processing problem {index+1}/{total}")

        # NOTE(review): every step re-sends only the problem statement; earlier
        # steps' output is not fed back, so steps are independent samples rather
        # than a sequential continuation — confirm this is intended.
        step_results = []
        for step in range(num_steps):
            print(f"  Step {step+1}/{num_steps}")
            step_results.append(self.model.generate_reasoning(
                problem['problem'],
                max_tokens=step_tokens,
                temperature=0.7
            ))

        # Correctness is judged on the final step's reasoning
        final_reasoning = step_results[-1]['reasoning']
        is_correct = self.model.evaluate_answer(final_reasoning, problem['answer'])

        return {
            'problem_id': index,
            'problem': problem['problem'],
            'correct_answer': problem['answer'],
            'step_results': step_results,
            'is_correct': is_correct,
            'step1_logprobs': step_results[0]['logprobs'],
            'step1_reasoning': step_results[0]['reasoning']
        }

    def _calibrate_all(self, calibration_data: List[Dict]) -> Dict:
        """Calibrate every threshold method; a method that fails (or whose stats lack expected keys) maps to None."""
        calibration_results = {}
        for method in self.THRESHOLD_METHODS:
            try:
                stats = self.framework.calibrate(calibration_data, method=method)
                calibration_results[method] = stats
                print(f"  {method}: threshold={stats['threshold']:.3f}, Cohen's d={stats['cohens_d']:.3f}")
            except Exception as e:
                print(f"  Error calibrating {method}: {e}")
                calibration_results[method] = None
        return calibration_results

    def _evaluate_all(self, calibration_data: List[Dict], calibration_results: Dict,
                      stop_savings: float) -> Dict:
        """Evaluate each successfully calibrated method; failures map to None, uncalibrated methods are omitted."""
        evaluation_results = {}
        for method in self.THRESHOLD_METHODS:
            if calibration_results.get(method) is None:
                continue
            try:
                eval_metrics = self.evaluate_framework(calibration_data, method, stop_savings=stop_savings)
                evaluation_results[method] = eval_metrics
                print(f"  {method}: early_stop_rate={eval_metrics['early_stop_rate']:.1%}, "
                      f"threshold_acc={eval_metrics['threshold_accuracy']:.1%}, "
                      f"token_savings={eval_metrics['avg_token_savings']:.1%}")
            except Exception as e:
                print(f"  Error evaluating {method}: {e}")
                evaluation_results[method] = None
        return evaluation_results

    def _save_results(self, final_results: Dict) -> str:
        """Write ``final_results`` to a timestamped JSON file and return its name."""
        safe_model = self.model.model_name.replace('/', '_')
        output_file = f"aime24_results_{safe_model}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        # Serialize fully before opening the file so an unserializable value
        # cannot leave a truncated results file behind.
        payload = json.dumps(final_results, indent=2, default=self._json_default)
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(payload)
        return output_file

    @staticmethod
    def _json_default(obj):
        """
        Convert NumPy scalars/arrays, which the stdlib JSON encoder rejects, to builtins.

        Presumably needed for framework calibration stats — confirm what
        ``calibrate`` returns. Any other unsupported type still raises TypeError.
        """
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def evaluate_framework(self, test_data: List[Dict], method: str,
                           stop_savings: float = 0.75) -> Dict[str, float]:
        """
        Evaluate framework performance.

        Args:
            test_data: Items with 'logprobs' and boolean 'correct' keys
            method: Threshold method name passed to ``should_stop_early``
            stop_savings: Fraction of the token budget saved by each early
                stop (default 0.75, i.e. stopping after 1 of 4 steps)

        Returns:
            early_stop_rate, threshold_accuracy (fraction of early stops that
            were correct; 0 when nothing stopped), avg_token_savings,
            total_questions and early_stopped. Examples that raise are
            reported, skipped, and still counted in total_questions.
        """
        total_questions = len(test_data)
        early_stopped = 0
        early_stopped_correct = 0
        token_savings = 0

        for example in test_data:
            try:
                result = self.framework.should_stop_early(example['logprobs'], method)

                if result.should_stop:
                    early_stopped += 1
                    if example['correct']:
                        early_stopped_correct += 1
                    token_savings += stop_savings
            except Exception as e:
                print(f"Error evaluating example: {e}")
                continue

        early_stop_rate = early_stopped / total_questions if total_questions > 0 else 0
        threshold_accuracy = (early_stopped_correct / early_stopped
                              if early_stopped > 0 else 0)
        avg_token_savings = token_savings / total_questions if total_questions > 0 else 0

        return {
            'early_stop_rate': early_stop_rate,
            'threshold_accuracy': threshold_accuracy,
            'avg_token_savings': avg_token_savings,
            'total_questions': total_questions,
            'early_stopped': early_stopped
        }

    def _generate_summary(self, calibration_results: Dict, evaluation_results: Dict) -> Dict:
        """
        Generate experiment summary.

        The best method is the one with the strictly highest avg_token_savings
        (first wins ties; None if no method saved anything). 'cohens_d_range'
        stays [inf, -inf] when no evaluated method has calibration stats.
        """
        summary = {
            'best_method': None,
            'best_token_savings': 0,
            'best_threshold_accuracy': 0,
            'cohens_d_range': [float('inf'), -float('inf')],
            'methods_evaluated': len([m for m in evaluation_results.values() if m is not None])
        }

        for method, eval_result in evaluation_results.items():
            if eval_result is not None:
                if eval_result['avg_token_savings'] > summary['best_token_savings']:
                    summary['best_method'] = method
                    summary['best_token_savings'] = eval_result['avg_token_savings']
                    summary['best_threshold_accuracy'] = eval_result['threshold_accuracy']

                # Track Cohen's d range
                if calibration_results[method]:
                    cohens_d = calibration_results[method]['cohens_d']
                    summary['cohens_d_range'][0] = min(summary['cohens_d_range'][0], cohens_d)
                    summary['cohens_d_range'][1] = max(summary['cohens_d_range'][1], cohens_d)

        return summary

def main():
    """
    Main function to run AIME'24 experiment.

    Usage:
        python aime24_experiment.py [--model MODEL] [--api-key KEY] [--problems N]
    """
    import argparse

    parser = argparse.ArgumentParser(description='Run AIME\'24 entropy experiment')
    parser.add_argument('--model', default='gpt-3.5-turbo',
                        help='Model name to test')
    parser.add_argument('--api-key', default=None,
                        help='API key for the model service')
    parser.add_argument('--problems', type=int, default=None,
                        help='Number of problems to test (default: all samples)')

    args = parser.parse_args()
    # Reject 0/negative counts explicitly: a falsy 0 used to mean "all problems"
    # and negative values silently sliced from the end of the list.
    if args.problems is not None and args.problems < 1:
        parser.error('--problems must be a positive integer')

    # Initialize experiment
    experiment = AIME24Experiment(args.model, args.api_key)

    # Select problems
    problems = SAMPLE_AIME24_PROBLEMS
    if args.problems is not None:
        problems = problems[:args.problems]

    # Run experiment
    results = experiment.run_full_experiment(problems)

    # Print summary
    print("\n" + "="*60)
    print("EXPERIMENT SUMMARY")
    print("="*60)
    summary = results['summary']
    print(f"Best method: {summary['best_method']}")
    print(f"Token savings: {summary['best_token_savings']:.1%}")
    print(f"Threshold accuracy: {summary['best_threshold_accuracy']:.1%}")
    low, high = summary['cohens_d_range']
    if low <= high:
        print(f"Cohen's d range: {low:.3f} - {high:.3f}")
    else:
        # The range stays at [inf, -inf] when no method produced a Cohen's d.
        print("Cohen's d range: n/a")
    print(f"Methods evaluated: {summary['methods_evaluated']}")

if __name__ == "__main__":
    main()