"""
Entropy-Based Early Stopping Framework for LLM Reasoning
Implementation of the "Think Just Enough" paper methods

This module implements Shannon entropy-based confidence estimation for
early stopping in large language model reasoning tasks.
"""

import json
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional
import math
import logging
from dataclasses import dataclass
from scipy import stats

# Set up logging.
# NOTE(review): basicConfig() runs at import time and configures the *root*
# logger for whatever program imports this module (it is a no-op if the root
# logger already has handlers). If this file is used as a library, consider
# moving it under the `if __name__ == "__main__":` guard.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the calibration code for progress messages.
logger = logging.getLogger(__name__)

@dataclass
class EntropyResult:
    """Outcome of one entropy-based early-stopping check.

    Instances are produced by ``EarlyStoppingFramework.should_stop_early``.

    Attributes:
        entropy: Mean per-token Shannon entropy (bits) of the sequence.
        confidence: ``max(0, 1 - entropy / threshold)`` for a positive
            threshold, otherwise 0.
        should_stop: Whether ``entropy <= threshold_used``.
        threshold_used: Calibrated threshold the entropy was compared against.
    """

    entropy: float  # mean per-token entropy, in bits
    confidence: float  # higher = further below the threshold
    should_stop: bool  # the early-stopping decision
    threshold_used: float  # threshold the decision was based on

class EntropyCalculator:
    """
    Calculate Shannon entropy from token logprobs for confidence estimation.

    For every token position the first ``k`` logprobs are exponentiated and
    renormalised into a probability distribution. The Shannon entropy of that
    distribution (in bits) is the uncertainty signal used for early-stopping
    decisions: lower entropy means a more peaked, i.e. more confident,
    distribution. The value is *not* divided by ``log2(k)``, so it lies in
    ``[0, log2(n)]`` where ``n`` is the number of logprobs actually used.
    """

    def __init__(self, k: int = 20):
        """
        Initialize entropy calculator.

        Args:
            k: Number of leading logprobs per token to use (``logprobs[:k]``).
                Expected to be a positive integer. The list is presumably
                ordered most-likely first, as top-k logprob APIs return it;
                confirm against the data source.
        """
        self.k = k

    def calculate_entropy(self, logprobs: List[float]) -> float:
        """
        Calculate Shannon entropy from logprobs.

        The logprobs are shifted by their maximum before exponentiating
        (log-sum-exp trick). Renormalisation makes the shift cancel out
        mathematically, but it prevents underflow to an all-zero distribution
        for very negative logprobs and ``OverflowError`` for large ones.

        Args:
            logprobs: Natural-log probabilities for one token's candidates,
                as a list or 1-D NumPy array.

        Returns:
            Shannon entropy in bits. Returns ``inf`` when there is nothing
            usable: empty/None input, ``k`` selecting nothing, all values
            ``-inf``, or any NaN/``+inf`` value. ``inf`` never satisfies
            ``entropy <= threshold`` for a finite threshold, so invalid input
            cannot trigger an early stop.
        """
        # Length check instead of truthiness so NumPy arrays are accepted.
        if logprobs is None or len(logprobs) == 0:
            return float('inf')

        selected = [float(lp) for lp in logprobs[:self.k]]
        if not selected or any(math.isnan(lp) for lp in selected):
            return float('inf')

        peak = max(selected)
        if not math.isfinite(peak):
            # All -inf (zero probability mass) or a +inf logprob (invalid).
            return float('inf')

        weights = [math.exp(lp - peak) for lp in selected]  # largest is exactly 1.0
        total = math.fsum(weights)  # >= 1.0, so never zero

        entropy = 0.0
        for weight in weights:
            p = weight / total
            if p > 0:  # convention: 0 * log(0) contributes 0
                entropy -= p * math.log2(p)
        return entropy

    def calculate_sequence_entropy(self, token_logprobs: List[List[float]]) -> float:
        """
        Calculate mean entropy across a sequence of tokens.

        Args:
            token_logprobs: One logprob list per token (a list of lists or a
                2-D NumPy array).

        Returns:
            Mean of the per-token entropies as a plain ``float``. Returns
            ``inf`` for an empty sequence, or when any token's entropy is
            ``inf``.
        """
        if token_logprobs is None or len(token_logprobs) == 0:
            return float('inf')

        entropies = [self.calculate_entropy(lps) for lps in token_logprobs]
        return float(np.mean(entropies))

class ThresholdCalculator:
    """
    Calculate optimal thresholds for entropy-based early stopping.

    Implements four threshold methods from the paper:
    - Entropy Mean
    - Information-Theoretic Optimal
    - Bayesian Optimal
    - Scale-Invariant Universal

    All statistics are population statistics (NumPy's default ``ddof=0``).
    Every method raises ``ValueError`` when a sample it needs is empty,
    instead of silently returning NaN, which would disable early stopping
    downstream.
    """

    def __init__(self):
        # Not used by any method of this class; kept for backward compatibility.
        self.thresholds = {}

    @staticmethod
    def _as_array(values: List[float], name: str) -> np.ndarray:
        """Convert ``values`` to a 1-D float array, rejecting empty input."""
        arr = np.asarray(values, dtype=float).ravel()
        if arr.size == 0:
            raise ValueError(f"{name} must contain at least one value")
        return arr

    @staticmethod
    def _cohens_d(correct: np.ndarray, incorrect: np.ndarray) -> float:
        """Absolute Cohen's d, pooling the two population variances equally."""
        pooled_std = np.sqrt((np.var(correct) + np.var(incorrect)) / 2)
        if pooled_std > 0:
            return float(abs(np.mean(incorrect) - np.mean(correct)) / pooled_std)
        return 0.0

    def entropy_mean_threshold(self, correct_entropies: List[float]) -> float:
        """
        Conservative threshold using mean entropy of correct answers.

        Args:
            correct_entropies: Entropy values for correct predictions

        Returns:
            Threshold value

        Raises:
            ValueError: If ``correct_entropies`` is empty.
        """
        correct = self._as_array(correct_entropies, "correct_entropies")
        return float(np.mean(correct))

    def information_theoretic_optimal(self,
                                      correct_entropies: List[float],
                                      incorrect_entropies: List[float]) -> float:
        """
        Information-theoretic optimal threshold using effect size scaling.

        The threshold moves from the correct-class mean toward the
        incorrect-class mean by ``0.3 * ln(1 + d)`` of the gap, where ``d``
        is Cohen's d between the two samples.

        Args:
            correct_entropies: Entropy values for correct predictions
            incorrect_entropies: Entropy values for incorrect predictions

        Returns:
            Optimal threshold value

        Raises:
            ValueError: If either sample is empty.
        """
        correct = self._as_array(correct_entropies, "correct_entropies")
        incorrect = self._as_array(incorrect_entropies, "incorrect_entropies")
        correct_mean = float(np.mean(correct))
        incorrect_mean = float(np.mean(incorrect))

        # Information-theoretic scaling: grows with separability, but slowly.
        scaling_factor = math.log1p(self._cohens_d(correct, incorrect))
        return correct_mean + scaling_factor * (incorrect_mean - correct_mean) * 0.3

    def bayesian_optimal(self,
                         correct_entropies: List[float],
                         incorrect_entropies: List[float]) -> float:
        """
        Bayesian optimal threshold minimizing classification error.

        Fits a Gaussian to each sample and, assuming equal priors, returns
        the intersection of the two densities closest to the midpoint of the
        means. It falls back to that midpoint when either standard deviation
        is zero or the densities do not intersect.

        Args:
            correct_entropies: Entropy values for correct predictions
            incorrect_entropies: Entropy values for incorrect predictions

        Returns:
            Bayesian optimal threshold

        Raises:
            ValueError: If either sample is empty.
        """
        correct = self._as_array(correct_entropies, "correct_entropies")
        incorrect = self._as_array(incorrect_entropies, "incorrect_entropies")
        correct_mean, correct_std = float(np.mean(correct)), float(np.std(correct))
        incorrect_mean, incorrect_std = float(np.mean(incorrect)), float(np.std(incorrect))
        mid_point = (correct_mean + incorrect_mean) / 2

        if correct_std == 0 or incorrect_std == 0:
            return mid_point

        # Equal log-densities  <=>  a*x^2 + b*x + c == 0
        a = 1 / (2 * correct_std**2) - 1 / (2 * incorrect_std**2)
        b = incorrect_mean / incorrect_std**2 - correct_mean / correct_std**2
        c = (correct_mean**2 / (2 * correct_std**2) -
             incorrect_mean**2 / (2 * incorrect_std**2) -
             math.log(incorrect_std / correct_std))

        if a == 0:
            return -c / b if b != 0 else mid_point

        discriminant = b**2 - 4 * a * c
        if discriminant < 0:
            return mid_point

        # Numerically stable roots: the textbook (-b +/- sqrt(D)) / 2a loses
        # precision to cancellation when 4ac << b^2 (nearly equal variances),
        # which is exactly when the meaningful root lies between the means.
        q = -0.5 * (b + math.copysign(math.sqrt(discriminant), b))
        roots = [q / a]
        if q != 0:
            roots.append(c / q)
        return min(roots, key=lambda root: abs(root - mid_point))

    def scale_invariant_universal(self,
                                  correct_entropies: List[float],
                                  incorrect_entropies: List[float]) -> float:
        """
        Scale-invariant universal threshold with coefficient of variation adjustment.

        Cohen's d is damped by the average coefficient of variation of the
        two samples, divided by 3, and clamped to ``[0.1, 0.8]``. The result
        is the fraction of the gap between the means added to the
        correct-class mean.

        Args:
            correct_entropies: Entropy values for correct predictions
            incorrect_entropies: Entropy values for incorrect predictions

        Returns:
            Scale-invariant threshold

        Raises:
            ValueError: If either sample is empty.
        """
        correct = self._as_array(correct_entropies, "correct_entropies")
        incorrect = self._as_array(incorrect_entropies, "incorrect_entropies")
        correct_mean = float(np.mean(correct))
        incorrect_mean = float(np.mean(incorrect))
        correct_std = float(np.std(correct))
        incorrect_std = float(np.std(incorrect))

        effect_size = self._cohens_d(correct, incorrect)

        # Coefficient of variation adjustment (undefined for non-positive means -> 0)
        cv_correct = correct_std / correct_mean if correct_mean > 0 else 0.0
        cv_incorrect = incorrect_std / incorrect_mean if incorrect_mean > 0 else 0.0
        avg_cv = (cv_correct + cv_incorrect) / 2

        normalized_effect = effect_size / (1 + avg_cv) if avg_cv > 0 else effect_size
        # Clamp so the threshold always moves at least 10% and at most 80% of
        # the way from the correct mean toward the incorrect mean.
        scaling = max(0.1, min(0.8, normalized_effect / 3.0))

        return correct_mean + scaling * (incorrect_mean - correct_mean)

class EarlyStoppingFramework:
    """
    Main framework for entropy-based early stopping in LLM reasoning.

    Combines entropy calculation with threshold-based stopping decisions
    to achieve computational savings while preserving accuracy. Call
    ``calibrate`` for a threshold method on labelled data first, then
    ``should_stop_early`` with the same method name.
    """

    def __init__(self,
                 entropy_calculator: Optional[EntropyCalculator] = None,
                 threshold_calculator: Optional[ThresholdCalculator] = None):
        """
        Initialize the framework.

        Args:
            entropy_calculator: Custom entropy calculator (uses default if None)
            threshold_calculator: Custom threshold calculator (uses default if None)
        """
        # `is None` (not `or`) so a supplied calculator is used even if falsy.
        self.entropy_calc = (entropy_calculator if entropy_calculator is not None
                             else EntropyCalculator())
        self.threshold_calc = (threshold_calculator if threshold_calculator is not None
                               else ThresholdCalculator())
        # Method name -> calibrated threshold.
        self.calibrated_thresholds: Dict[str, float] = {}

    def calibrate(self,
                  calibration_data: List[Dict],
                  method: str = "entropy_mean") -> Dict[str, float]:
        """
        Calibrate a threshold for ``method`` using labelled validation data.

        Args:
            calibration_data: List of dicts with 'logprobs' (per-token logprob
                lists) and 'correct' (truthy when the answer was correct) keys.
            method: One of "entropy_mean", "information_theoretic",
                "bayesian" or "scale_invariant".

        Returns:
            Dictionary with 'threshold', 'correct_mean', 'correct_std',
            'incorrect_mean', 'incorrect_std' and 'cohens_d'. Mean/std of a
            class with no examples is NaN.

        Raises:
            ValueError: If ``method`` is unknown. This is checked before any
                entropy is computed.
        """
        calc = self.threshold_calc
        # Dispatch table. Lambdas defer the attribute lookup so a custom
        # threshold calculator only needs the method actually requested.
        threshold_fns = {
            "entropy_mean": lambda c, i: calc.entropy_mean_threshold(c),
            "information_theoretic": lambda c, i: calc.information_theoretic_optimal(c, i),
            "bayesian": lambda c, i: calc.bayesian_optimal(c, i),
            "scale_invariant": lambda c, i: calc.scale_invariant_universal(c, i),
        }
        threshold_fn = threshold_fns.get(method)
        if threshold_fn is None:
            raise ValueError(f"Unknown threshold method: {method}")

        # Separate correct and incorrect examples
        correct_entropies: List[float] = []
        incorrect_entropies: List[float] = []
        for example in calibration_data:
            entropy = self.entropy_calc.calculate_sequence_entropy(example['logprobs'])
            if example['correct']:
                correct_entropies.append(entropy)
            else:
                incorrect_entropies.append(entropy)

        logger.info("Calibrating with %d correct, %d incorrect examples",
                    len(correct_entropies), len(incorrect_entropies))

        threshold = float(threshold_fn(correct_entropies, incorrect_entropies))
        if not math.isfinite(threshold):
            # A NaN/inf threshold makes should_stop_early degenerate (it never
            # or always stops); surface this instead of failing silently.
            logger.warning("Calibrated %s threshold is not finite (%s); "
                           "check the calibration data", method, threshold)
        self.calibrated_thresholds[method] = threshold

        stats_info = {
            'threshold': threshold,
            'correct_mean': self._summary(np.mean, correct_entropies),
            'correct_std': self._summary(np.std, correct_entropies),
            'incorrect_mean': self._summary(np.mean, incorrect_entropies),
            'incorrect_std': self._summary(np.std, incorrect_entropies),
            'cohens_d': self._calculate_cohens_d(correct_entropies, incorrect_entropies),
        }

        logger.info("Calibrated %s threshold: %.3f", method, threshold)
        logger.info("Cohen's d effect size: %.3f", stats_info['cohens_d'])

        return stats_info

    def should_stop_early(self,
                          token_logprobs: List[List[float]],
                          method: str = "entropy_mean") -> EntropyResult:
        """
        Determine whether to stop early based on entropy.

        Args:
            token_logprobs: Logprobs for the current reasoning sequence
            method: Threshold method to use (must have been calibrated)

        Returns:
            EntropyResult with plain-Python ``float``/``bool`` fields. Stopping
            occurs when the mean entropy is ``<=`` the threshold. Confidence
            is ``max(0, 1 - entropy / threshold)`` for a positive threshold,
            else 0.

        Raises:
            ValueError: If ``method`` has not been calibrated.
        """
        if method not in self.calibrated_thresholds:
            raise ValueError(f"Threshold for method '{method}' not calibrated. "
                             f"Call calibrate() first.")

        entropy = float(self.entropy_calc.calculate_sequence_entropy(token_logprobs))
        threshold = float(self.calibrated_thresholds[method])

        # bool(): the comparison may yield numpy.bool_, which is not a bool
        # (breaks `is True` checks and json.dumps).
        should_stop = bool(entropy <= threshold)
        confidence = max(0.0, 1.0 - entropy / threshold) if threshold > 0 else 0.0

        return EntropyResult(
            entropy=entropy,
            confidence=confidence,
            should_stop=should_stop,
            threshold_used=threshold
        )

    @staticmethod
    def _summary(reducer, values: List[float]) -> float:
        """Apply a NumPy reduction; NaN for empty input (without a RuntimeWarning)."""
        return float(reducer(values)) if len(values) > 0 else float('nan')

    def _calculate_cohens_d(self,
                            correct_entropies: List[float],
                            incorrect_entropies: List[float]) -> float:
        """Absolute Cohen's d; 0.0 if either sample is empty or pooled std is 0."""
        if len(correct_entropies) == 0 or len(incorrect_entropies) == 0:
            return 0.0

        correct_mean = np.mean(correct_entropies)
        incorrect_mean = np.mean(incorrect_entropies)

        # Pooled std as the root of the mean of the two population variances.
        pooled_std = np.sqrt((np.var(correct_entropies) + np.var(incorrect_entropies)) / 2)

        if pooled_std == 0:
            return 0.0

        return float(abs(incorrect_mean - correct_mean) / pooled_std)

def load_experiment_data(file_path: str) -> List[Dict]:
    """
    Load experiment data from a JSON file.

    Args:
        file_path: Path to a UTF-8 encoded JSON file (any path-like object
            accepted by ``open`` also works). A leading UTF-8 BOM is tolerated.

    Returns:
        The decoded JSON document. Callers in this module expect a list of
        example dicts with 'logprobs' and 'correct' keys; the structure is not
        validated here.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Pin the encoding instead of relying on the platform locale default.
    # 'utf-8-sig' decodes plain UTF-8 and also strips a BOM, which json.load
    # would otherwise reject.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        return json.load(f)

def evaluate_framework(framework: EarlyStoppingFramework,
                       test_data: List[Dict],
                       method: str = "entropy_mean",
                       *,
                       savings_per_stop: float = 0.75) -> Dict[str, float]:
    """
    Evaluate a calibrated framework on labelled test data.

    Args:
        framework: Framework already calibrated for ``method``.
        test_data: Test examples with 'logprobs' and 'correct' keys.
        method: Threshold method to use.
        savings_per_stop: Assumed fraction of reasoning tokens saved each time
            a question is stopped early. Keyword-only; the default 0.75 is the
            previously hard-coded assumption. Must be in [0, 1].

    Returns:
        Dict with:
        - 'early_stop_rate': fraction of questions stopped early;
        - 'threshold_accuracy': fraction of early-stopped questions that were
          correct (0 when none stopped);
        - 'avg_token_savings': estimated fraction of tokens saved per question;
        - 'total_questions': number of test examples;
        - 'early_stopped': number of questions stopped early.

    Raises:
        ValueError: If ``savings_per_stop`` is outside [0, 1], or propagated
            from ``framework.should_stop_early`` when ``method`` is not
            calibrated.
    """
    if not 0.0 <= savings_per_stop <= 1.0:
        raise ValueError(f"savings_per_stop must be in [0, 1], got {savings_per_stop}")

    total_questions = len(test_data)
    early_stopped = 0
    early_stopped_correct = 0

    for example in test_data:
        result = framework.should_stop_early(example['logprobs'], method)
        if result.should_stop:
            early_stopped += 1
            if example['correct']:
                early_stopped_correct += 1

    early_stop_rate = early_stopped / total_questions if total_questions > 0 else 0.0
    threshold_accuracy = (early_stopped_correct / early_stopped
                          if early_stopped > 0 else 0.0)
    # Each early stop saves `savings_per_stop` of a question's tokens;
    # non-stopped questions save nothing.
    avg_token_savings = (savings_per_stop * early_stopped / total_questions
                         if total_questions > 0 else 0.0)

    return {
        'early_stop_rate': early_stop_rate,
        'threshold_accuracy': threshold_accuracy,
        'avg_token_savings': avg_token_savings,
        'total_questions': total_questions,
        'early_stopped': early_stopped
    }

# Example usage and testing
if __name__ == "__main__":
    # Initialize framework with the default entropy/threshold calculators
    framework = EarlyStoppingFramework()

    # Toy calibration data (replace with real data, e.g. load_experiment_data()).
    # Each example holds one logprob list per token plus a correctness flag.
    sample_calibration_data = [
        {
            'logprobs': [[-0.1, -0.5, -1.2], [-0.2, -0.8, -1.5]],
            'correct': True
        },
        {
            'logprobs': [[-0.8, -1.2, -2.1], [-1.1, -1.8, -2.5]],
            'correct': False
        }
    ]

    # Calibrate thresholds. Named `calibration_stats` (not `stats`) so it
    # does not rebind the module-level `scipy.stats` import.
    calibration_stats = framework.calibrate(sample_calibration_data, method="entropy_mean")
    print(f"Calibration stats: {calibration_stats}")

    # Test stopping decision
    test_logprobs = [[-0.1, -0.3, -0.7], [-0.2, -0.4, -0.9]]
    result = framework.should_stop_early(test_logprobs, method="entropy_mean")

    print(f"Entropy: {result.entropy:.3f}")
    print(f"Should stop: {result.should_stop}")
    print(f"Confidence: {result.confidence:.3f}")