"""
AIME'25 Experiment Template

Template for running entropy-based early-stopping experiments on the AIME'25 dataset.
Demonstrates the framework on a small set of AIME-style sample problems.
"""

import sys
import os
import json
import numpy as np
import pandas as pd
from typing import Dict, List, Optional
import openai
from datetime import datetime

# Add parent directory to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from entropy_framework import EarlyStoppingFramework, EntropyCalculator, ThresholdCalculator

# OpenRouter API configuration.
# The key is read from the OPENROUTER_API_KEY environment variable so it does
# not have to be committed; the placeholder remains as a template fallback.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "your_openrouter_key")
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"

# AIME-style sample problems for demonstration (illustrative; not the official
# AIME 2025 set). Answers are strings and are graded by integer comparison.
SAMPLE_AIME25_PROBLEMS = [
    {
        "problem": "Let $n$ be the number of integer solutions to $x^3 + y^3 = 2025^3$ where $|x| \\leq 2025$ and $|y| \\leq 2025$. Find the remainder when $n$ is divided by $1000$.",
        "answer": "2",
        "explanation": "By Fermat's Last Theorem for exponent 3 there are no solutions with x and y both nonzero, so only (2025, 0) and (0, 2025) qualify."
    },
    {
        # NOTE(review): the stored answer is not correct. With BC = sqrt(57),
        # BD = 7*sqrt(57)/15 and BE = 21/sqrt(57), so DE = 28*sqrt(57)/285 ~ 0.742.
        # That is not an integer, so this item cannot be graded by the
        # integer-based answer extractor. TODO: replace with an integer-answer problem.
        "problem": "In triangle $ABC$, $\\angle A = 60°$, $AB = 7$, and $AC = 8$. Point $D$ is on side $BC$ such that $AD$ bisects $\\angle A$. If $E$ is the foot of the altitude from $A$ to $BC$, find $DE$.",
        "answer": "1",
        "explanation": "Using angle bisector theorem and coordinate geometry."
    },
    {
        "problem": "A sequence is defined by $a_1 = 1$, $a_2 = 1$, and $a_n = a_{n-1} + a_{n-2} + n$ for $n \\geq 3$. Find the remainder when $a_{100}$ is divided by $10007$.",
        "answer": "1011",
        "explanation": "This is a modified Fibonacci sequence with an additional linear term: b_n = a_n + n + 3 satisfies b_n = b_{n-1} + b_{n-2}, so a_n = 5F_{n-2} + 6F_{n-1} - n - 3."
    },
    {
        "problem": "Let $f(x) = x^4 - 4x^3 + 6x^2 - 4x + 1$. Find the number of distinct real roots of $f(x) = 0$.",
        "answer": "1",
        "explanation": "f(x) = (x - 1)^4, so x = 1 is the only real root (with multiplicity 4)."
    },
    {
        "problem": "In a regular dodecagon, how many diagonals pass through the center?",
        "answer": "6",
        "explanation": "Geometric analysis of regular 12-sided polygon symmetries."
    }
]

class AIME25Experiment:
    """
    Main experiment class for AIME'25 entropy analysis.

    Runs the complete pipeline over a list of problems: multi-step sequential
    reasoning through an OpenAI-compatible chat API, answer grading, step-1
    entropy statistics, threshold analysis, and a JSON dump of the results.
    """

    def __init__(self, model_name: str, api_key: Optional[str] = None):
        """
        Initialize the AIME'25 experiment.

        Args:
            model_name: Name of the model to test (sent as ``model`` in requests).
            api_key: API key for the model service; falls back to the
                module-level ``OPENROUTER_API_KEY`` when None or empty.
        """
        self.model_name = model_name
        self.api_key = api_key or OPENROUTER_API_KEY
        self.framework = EarlyStoppingFramework()
        self.results = []

        # Configure the module-level (legacy, pre-1.0 style) OpenAI client so
        # that requests are routed to OpenRouter.
        if self.api_key:
            openai.api_key = self.api_key
            openai.api_base = OPENROUTER_BASE_URL

    def generate_reasoning(self,
                           problem: str,
                           max_tokens: int = 8192,
                           temperature: float = 0.7) -> Dict:
        """
        Generate step-by-step reasoning for a mathematical problem.

        Args:
            problem: AIME problem statement (or a refinement prompt).
            max_tokens: Maximum tokens for generation.
            temperature: Sampling temperature.

        Returns:
            Dictionary with ``reasoning``, ``logprobs`` (one list of top-k
            log-probabilities per generated token), ``model``, ``step_tokens``
            and ``timestamp``. If the request or parsing fails, ``reasoning``
            holds the error text and an extra ``error`` key is set; all other
            keys are still present so callers can index them safely.
        """
        try:
            prompt = f"""Solve this AIME competition problem step by step. 
Show your mathematical reasoning clearly and provide the final numerical answer.

Problem: {problem}

Solution:"""

            response = openai.ChatCompletion.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": "You are an expert mathematician specializing in competition problems. Solve step by step with rigorous mathematical reasoning."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=max_tokens,
                temperature=temperature,
                # Chat Completions takes a boolean `logprobs` plus `top_logprobs`
                # (0-20); `echo` and integer `logprobs` are Completions-only.
                logprobs=True,
                top_logprobs=20,  # Top-20 for entropy calculation
            )

            choice = response.choices[0]
            # Message content may be None (e.g. filtered output); keep a str.
            reasoning_text = choice.message.content or ""

            return {
                'reasoning': reasoning_text,
                'logprobs': self._extract_top_logprobs(choice),
                'model': self.model_name,
                'step_tokens': max_tokens,
                'timestamp': datetime.now().isoformat()
            }

        except Exception as e:
            print(f"Error in reasoning generation: {e}")
            return {
                'reasoning': f"Error: {str(e)}",
                'logprobs': [],
                'model': self.model_name,
                'step_tokens': max_tokens,
                # Same keys as the success path so callers never hit KeyError.
                'timestamp': datetime.now().isoformat(),
                'error': str(e)
            }

    @staticmethod
    def _field(obj, name: str):
        """Read ``name`` from a dict-like or attribute-style response object (None if absent)."""
        if obj is None:
            return None
        if isinstance(obj, dict):
            return obj.get(name)
        return getattr(obj, name, None)

    @classmethod
    def _extract_top_logprobs(cls, choice, k: int = 20) -> List[List[float]]:
        """
        Collect the top-``k`` log-probabilities of every generated token.

        Two response layouts are supported:
          * Chat Completions: ``choice.logprobs.content`` is a list of token
            entries, each carrying a ``top_logprobs`` list of ``{token, logprob}``.
          * Legacy Completions: ``choice.logprobs.top_logprobs`` is a list of
            ``{token: logprob}`` dicts.

        Args:
            choice: One element of ``response.choices``.
            k: Maximum number of alternatives kept per token.

        Returns:
            A list with one list of log-probabilities per token, or an empty
            list when the response carries no log-probabilities.
        """
        logprobs_obj = cls._field(choice, 'logprobs')
        if not logprobs_obj:
            return []

        result: List[List[float]] = []

        content = cls._field(logprobs_obj, 'content')
        if content:
            for token_entry in content:
                alternatives = cls._field(token_entry, 'top_logprobs') or []
                values = [cls._field(alt, 'logprob') for alt in list(alternatives)[:k]]
                values = [v for v in values if v is not None]
                if values:
                    result.append(values)
            return result

        for token_data in cls._field(logprobs_obj, 'top_logprobs') or []:
            if token_data:
                result.append(list(token_data.values())[:k])
        return result

    def extract_numerical_answer(self, reasoning: str) -> Optional[str]:
        """
        Extract the final numerical answer from reasoning text.

        Patterns are tried in priority order: ``\\boxed{}`` first, then explicit
        "answer" phrases, then a bare number ending the text. For each pattern
        the LAST match is used, because solutions state their conclusion at
        the end. The fallback is the last standalone integer in the text.

        Args:
            reasoning: Generated reasoning text.

        Returns:
            Extracted numerical answer (a digit string) or None.
        """
        import re

        if not reasoning:
            return None

        patterns = [
            r'\\boxed\{(\d+)\}',  # LaTeX boxed format (most explicit signal)
            r'[Ff]inal answer[:\s]+(\d+)',
            r'[Aa]nswer[:\s]+(\d+)',
            r'[Tt]herefore[,\s]+the answer is[:\s]+(\d+)',
            r'[Ss]o the answer is[:\s]+(\d+)',
            r'[Tt]he answer is[:\s]+(\d+)',
            # Number at the end of the text (optionally followed by a period).
            # `$` is an end anchor, so it cannot precede the digits.
            r'(\d+)\s*\.?\s*$',
        ]

        for pattern in patterns:
            matches = re.findall(pattern, reasoning)
            if matches:
                return matches[-1]

        # Fallback: last standalone integer anywhere in the text.
        numbers = re.findall(r'\b\d+\b', reasoning)
        if numbers:
            return numbers[-1]

        return None

    def evaluate_correctness(self, reasoning: str, correct_answer: str) -> bool:
        """
        Evaluate whether the reasoning leads to the correct answer.

        Args:
            reasoning: Generated reasoning text.
            correct_answer: Expected correct answer.

        Returns:
            True if the extracted answer equals ``correct_answer`` as integers
            (so "042" == "42"), or as stripped strings when either is not an
            integer.
        """
        extracted_answer = self.extract_numerical_answer(reasoning)

        if extracted_answer is None:
            return False

        try:
            return int(extracted_answer) == int(correct_answer)
        except ValueError:
            return extracted_answer.strip() == correct_answer.strip()

    def run_sequential_reasoning(self,
                                 problem: Dict,
                                 num_steps: int = 4,
                                 step_tokens: int = 8192) -> Dict:
        """
        Run a multi-step sequential reasoning process (4 steps by default).

        Step 1 solves the problem from scratch. Each later step is asked to
        review and refine the previous step's output.

        Args:
            problem: Problem dictionary with ``problem`` and ``answer`` keys.
            num_steps: Number of reasoning steps (must be >= 1).
            step_tokens: Max tokens per step.

        Returns:
            Complete reasoning sequence data, including per-step results, final
            correctness, and step-1 data used for entropy analysis.

        Raises:
            ValueError: If ``num_steps`` is less than 1.
        """
        if num_steps < 1:
            raise ValueError("num_steps must be at least 1")

        print(f"  Running {num_steps}-step reasoning...")

        step_results = []
        # Holds only the most recent step's output (not a concatenation).
        cumulative_reasoning = ""

        for step in range(num_steps):
            print(f"    Step {step+1}/{num_steps}")

            if step == 0:
                # First step: solve from scratch
                step_problem = problem['problem']
            else:
                # Subsequent steps: refine previous reasoning
                step_problem = f"""Previous reasoning:
{cumulative_reasoning}

Please review and improve the above solution to this problem:
{problem['problem']}

Provide a refined solution:"""

            step_result = self.generate_reasoning(
                step_problem,
                max_tokens=step_tokens,
                temperature=0.7
            )
            reasoning_text = step_result.get('reasoning') or ""

            step_results.append({
                'step': step + 1,
                'reasoning': reasoning_text,
                'logprobs': step_result.get('logprobs', []),
                # Whitespace-delimited word count, a cheap proxy for tokens.
                'tokens': len(reasoning_text.split()),
                'timestamp': step_result.get('timestamp')
            })

            # The next refinement step sees this step's output.
            cumulative_reasoning = reasoning_text

        # Evaluate final answer
        final_reasoning = step_results[-1]['reasoning']
        is_correct = self.evaluate_correctness(final_reasoning, problem['answer'])
        extracted_answer = self.extract_numerical_answer(final_reasoning)

        return {
            'problem_statement': problem['problem'],
            'correct_answer': problem['answer'],
            'extracted_answer': extracted_answer,
            'is_correct': is_correct,
            'step_results': step_results,
            'total_tokens': sum(s['tokens'] for s in step_results),
            'step1_entropy_data': {
                'logprobs': step_results[0]['logprobs'],
                'reasoning': step_results[0]['reasoning'],
                'is_correct_step1': self.evaluate_correctness(step_results[0]['reasoning'], problem['answer'])
            }
        }

    def calculate_entropy_statistics(self, experiment_data: List[Dict]) -> Dict:
        """
        Calculate step-1 entropy statistics split by answer correctness.

        Args:
            experiment_data: List of per-problem results from
                ``run_sequential_reasoning``.

        Returns:
            Means/stds/values of the correct and incorrect groups, Cohen's d, and
            an independent two-sample t-test. Returns an empty dict if either
            group has no entropy values.
        """
        print("Calculating entropy statistics...")

        entropy_calc = EntropyCalculator(k=20)
        correct_entropies = []
        incorrect_entropies = []

        for data in experiment_data:
            step1_data = data['step1_entropy_data']
            if step1_data['logprobs']:
                entropy = entropy_calc.calculate_sequence_entropy(step1_data['logprobs'])

                # NOTE(review): step-1 entropies are labelled by FINAL-answer
                # correctness (`is_correct`), although `is_correct_step1` is also
                # recorded. For an early-stopping decision made after step 1,
                # step-1 correctness may be the intended label; confirm.
                if data['is_correct']:
                    correct_entropies.append(entropy)
                else:
                    incorrect_entropies.append(entropy)

        if not correct_entropies or not incorrect_entropies:
            print("Warning: Insufficient data for entropy analysis")
            return {}

        # Cohen's d with the pooled SD taken as the root of the mean of the
        # two (population, ddof=0) variances.
        correct_mean = float(np.mean(correct_entropies))
        incorrect_mean = float(np.mean(incorrect_entropies))
        pooled_std = float(np.sqrt((np.var(correct_entropies) + np.var(incorrect_entropies)) / 2))
        cohens_d = abs(incorrect_mean - correct_mean) / pooled_std if pooled_std > 0 else 0.0

        # Statistical significance test (p is NaN when a group is too small,
        # in which case `significant` is False).
        from scipy.stats import ttest_ind
        t_stat, p_value = ttest_ind(correct_entropies, incorrect_entropies)
        t_stat = float(t_stat)
        p_value = float(p_value)

        return {
            'correct_entropies': {
                'mean': correct_mean,
                'std': float(np.std(correct_entropies)),
                'count': len(correct_entropies),
                'values': correct_entropies
            },
            'incorrect_entropies': {
                'mean': incorrect_mean,
                'std': float(np.std(incorrect_entropies)),
                'count': len(incorrect_entropies),
                'values': incorrect_entropies
            },
            'effect_size': {
                'cohens_d': cohens_d,
                'interpretation': self._interpret_cohens_d(cohens_d)
            },
            'statistical_test': {
                't_statistic': t_stat,
                'p_value': p_value,
                # Plain bool so json.dump does not stringify a numpy bool.
                'significant': bool(p_value < 0.05)
            }
        }

    def _interpret_cohens_d(self, cohens_d: float) -> str:
        """Map Cohen's d to the conventional 0.2 / 0.5 / 0.8 size labels."""
        if cohens_d < 0.2:
            return "negligible"
        elif cohens_d < 0.5:
            return "small"
        elif cohens_d < 0.8:
            return "medium"
        else:
            return "large"

    def run_threshold_analysis(self, entropy_stats: Dict) -> Dict:
        """
        Evaluate several entropy thresholds for early stopping.

        A problem "stops early" when its step-1 entropy is <= the threshold.

        Args:
            entropy_stats: Output of ``calculate_entropy_statistics``.

        Returns:
            Per-method dict with the threshold, stop rates, the fraction of
            stops that were correct (``threshold_accuracy``), and an estimated
            ``token_savings``. Empty if ``entropy_stats`` is empty.
        """
        print("Running threshold analysis...")

        if not entropy_stats:
            return {}

        threshold_calc = ThresholdCalculator()
        correct_entropies = entropy_stats['correct_entropies']['values']
        incorrect_entropies = entropy_stats['incorrect_entropies']['values']

        methods = {
            'entropy_mean': threshold_calc.entropy_mean_threshold(correct_entropies),
            'information_theoretic': threshold_calc.information_theoretic_optimal(
                correct_entropies, incorrect_entropies),
            'bayesian': threshold_calc.bayesian_optimal(
                correct_entropies, incorrect_entropies),
            'scale_invariant': threshold_calc.scale_invariant_universal(
                correct_entropies, incorrect_entropies)
        }

        # Calculate stopping rates for each threshold
        threshold_analysis = {}
        for method, threshold in methods.items():
            # Calculate how many correct/incorrect would stop early
            correct_stops = sum(1 for e in correct_entropies if e <= threshold)
            incorrect_stops = sum(1 for e in incorrect_entropies if e <= threshold)

            total_correct = len(correct_entropies)
            total_incorrect = len(incorrect_entropies)

            # Calculate metrics
            correct_stop_rate = correct_stops / total_correct if total_correct > 0 else 0
            incorrect_stop_rate = incorrect_stops / total_incorrect if total_incorrect > 0 else 0

            total_stops = correct_stops + incorrect_stops
            threshold_accuracy = correct_stops / total_stops if total_stops > 0 else 0

            # Estimate token savings (assumes each early stop saves 75% of tokens,
            # i.e. 3 of 4 steps are skipped).
            total_problems = total_correct + total_incorrect
            token_savings = total_stops / total_problems * 0.75 if total_problems > 0 else 0

            threshold_analysis[method] = {
                'threshold': threshold,
                'correct_stop_rate': correct_stop_rate,
                'incorrect_stop_rate': incorrect_stop_rate,
                'threshold_accuracy': threshold_accuracy,
                'token_savings': token_savings,
                'total_stops': total_stops,
                'correct_stops': correct_stops,
                'incorrect_stops': incorrect_stops
            }

        return threshold_analysis

    def run_full_experiment(self,
                            problems: Optional[List[Dict]] = None,
                            num_steps: int = 4,
                            step_tokens: int = 8192) -> Dict:
        """
        Run the complete AIME'25 experiment and save the results as JSON.

        Args:
            problems: Problems to test (uses ``SAMPLE_AIME25_PROBLEMS`` if None).
            num_steps: Number of reasoning steps.
            step_tokens: Max tokens per reasoning step.

        Returns:
            Complete experiment results (also written to
            ``aime25_results_<model>_<timestamp>.json`` in the working directory).
        """
        if problems is None:
            problems = SAMPLE_AIME25_PROBLEMS

        print(f"Running AIME'25 Experiment")
        print(f"Model: {self.model_name}")
        print(f"Problems: {len(problems)}")
        print(f"Steps: {num_steps}, Tokens per step: {step_tokens}")
        print("="*50)

        # Run reasoning for all problems
        experiment_data = []
        for i, problem in enumerate(problems):
            print(f"\nProblem {i+1}/{len(problems)}: {problem['problem'][:60]}...")

            problem_result = self.run_sequential_reasoning(
                problem, num_steps, step_tokens
            )
            problem_result['problem_id'] = i
            experiment_data.append(problem_result)

            print(f"  Result: {'✓' if problem_result['is_correct'] else '✗'} "
                  f"(Answer: {problem_result['extracted_answer']} vs {problem['answer']})")

        # Calculate statistics
        print(f"\nCalculating experiment statistics...")

        # Basic accuracy metrics
        total_problems = len(experiment_data)
        step1_correct = sum(1 for d in experiment_data if d['step1_entropy_data']['is_correct_step1'])
        final_correct = sum(1 for d in experiment_data if d['is_correct'])

        step1_accuracy = step1_correct / total_problems if total_problems > 0 else 0
        final_accuracy = final_correct / total_problems if total_problems > 0 else 0

        # Entropy analysis
        entropy_stats = self.calculate_entropy_statistics(experiment_data)

        # Threshold analysis
        threshold_analysis = self.run_threshold_analysis(entropy_stats)

        # Compile results
        results = {
            'experiment_info': {
                'model': self.model_name,
                'dataset': 'AIME_25',
                'num_problems': total_problems,
                'num_steps': num_steps,
                'step_tokens': step_tokens,
                'timestamp': datetime.now().isoformat()
            },
            'accuracy_metrics': {
                'step1_accuracy': step1_accuracy,
                'final_accuracy': final_accuracy,
                'step1_correct': step1_correct,
                'final_correct': final_correct,
                'total_problems': total_problems
            },
            'entropy_statistics': entropy_stats,
            'threshold_analysis': threshold_analysis,
            'problem_data': experiment_data,
            'summary': self._generate_experiment_summary(
                entropy_stats, threshold_analysis, step1_accuracy, final_accuracy
            )
        }

        # Save results ('/' in model ids such as "org/model" would create dirs).
        output_file = f"aime25_results_{self.model_name.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, default=str)

        print(f"\nResults saved to: {output_file}")

        return results

    def _generate_experiment_summary(self,
                                     entropy_stats: Dict,
                                     threshold_analysis: Dict,
                                     step1_acc: float,
                                     final_acc: float) -> Dict:
        """
        Build a human-readable summary of the experiment.

        Accuracies are formatted as percentages. The "best" threshold method is
        the one with the highest estimated token savings (None if all are 0).
        """
        summary = {
            'step1_accuracy': f"{step1_acc:.1%}",
            'final_accuracy': f"{final_acc:.1%}",
            'accuracy_improvement': f"{final_acc - step1_acc:.1%}"
        }

        if entropy_stats:
            summary.update({
                'cohens_d': entropy_stats.get('effect_size', {}).get('cohens_d', 0),
                'effect_interpretation': entropy_stats.get('effect_size', {}).get('interpretation', 'unknown'),
                'entropy_separation_significant': entropy_stats.get('statistical_test', {}).get('significant', False)
            })

        if threshold_analysis:
            # Find best performing method
            best_method = None
            best_savings = 0
            best_accuracy = 0

            for method, analysis in threshold_analysis.items():
                if analysis['token_savings'] > best_savings:
                    best_method = method
                    best_savings = analysis['token_savings']
                    best_accuracy = analysis['threshold_accuracy']

            summary.update({
                'best_threshold_method': best_method,
                'best_token_savings': f"{best_savings:.1%}",
                'best_threshold_accuracy': f"{best_accuracy:.1%}"
            })

        return summary

def main():
    """
    Command-line entry point: parse arguments, run the experiment, print a summary.

    ``--num-problems N`` limits the run to the first N sample problems (0 runs
    none); when omitted, all samples are used. ``--steps`` must be at least 1.
    Invalid values are rejected with an argparse usage error (exit code 2).
    """
    import argparse

    parser = argparse.ArgumentParser(description='Run AIME\'25 entropy experiment')
    parser.add_argument('--model', default='gpt-4',
                        help='Model name to test')
    parser.add_argument('--api-key', default=None,
                        help='API key for the model service')
    parser.add_argument('--num-problems', type=int, default=None,
                        help='Number of problems to test')
    parser.add_argument('--steps', type=int, default=4,
                        help='Number of reasoning steps')
    parser.add_argument('--step-tokens', type=int, default=8192,
                        help='Tokens per reasoning step')

    args = parser.parse_args()

    # Validate before any work: a negative count would silently slice problems
    # off the end of the list, and fewer than 1 step leaves no answer to grade.
    if args.num_problems is not None and args.num_problems < 0:
        parser.error('--num-problems must be non-negative')
    if args.steps < 1:
        parser.error('--steps must be at least 1')

    # Initialize experiment
    experiment = AIME25Experiment(args.model, args.api_key)

    # Select problems (compare against None so that 0 is honoured, not "all").
    problems = SAMPLE_AIME25_PROBLEMS
    if args.num_problems is not None:
        problems = problems[:args.num_problems]

    # Run experiment
    results = experiment.run_full_experiment(
        problems, args.steps, args.step_tokens
    )

    # Print summary
    print("\n" + "="*60)
    print("AIME'25 EXPERIMENT SUMMARY")
    print("="*60)
    summary = results['summary']

    for key, value in summary.items():
        print(f"{key.replace('_', ' ').title()}: {value}")

# Run the experiment only when executed as a script, not when imported.
if "__main__" == __name__:
    main()