#!/usr/bin/env python3
"""
Script to analyze token statistics for AIME dataset model generations.
Calculates average, minimum, and maximum tokens for each problem and overall.

Usage:
python analyze_tokens.py <path_to_aime_dataset.json>

Example:
python analyze_tokens.py aime_dataset.json

"""

import json
import tiktoken
import argparse
import os
from typing import List, Dict, Tuple
import statistics

def load_aime_data(file_path: str) -> List[Dict]:
    """Read the JSON document at *file_path* and return its parsed contents.

    Args:
        file_path: Path of the JSON file; it is opened as UTF-8 text.

    Returns:
        The object produced by parsing the file's JSON text.
    """
    with open(file_path, mode='r', encoding='utf-8') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def count_tokens(text: str, encoding) -> int:
    """Return the number of tokens *encoding* produces for *text*.

    Args:
        text: The string to tokenize.
        encoding: Object whose ``encode(text)`` returns a sized sequence
            of tokens (a tiktoken encoding in this script).

    Returns:
        Length of the sequence returned by ``encoding.encode(text)``.
    """
    token_ids = encoding.encode(text)
    return len(token_ids)

def analyze_problem_tokens(problem: Dict, encoding) -> Dict:
    """Summarize the token counts of one problem's responses.

    Args:
        problem: Problem record. Its ``'responses'`` entry is read as the
            collection of response texts, and ``problem['extra_info']['index']``
            is used as the problem label (``'unknown'`` when absent).
        encoding: Tokenizer handed to ``count_tokens`` for each response.

    Returns:
        Dict with keys ``problem_index``, ``num_responses``,
        ``token_counts``, ``avg_tokens``, ``min_tokens`` and ``max_tokens``.
        When the problem has no responses the counts list is empty and the
        three numeric statistics are 0.
    """
    index = problem.get('extra_info', {}).get('index', 'unknown')
    answers = problem.get('responses', [])

    # Start from the "no responses" summary; key order here is the order the
    # keys later appear in the JSON output.
    summary = {
        'problem_index': index,
        'num_responses': 0,
        'token_counts': [],
        'avg_tokens': 0,
        'min_tokens': 0,
        'max_tokens': 0
    }

    if answers:
        lengths = [count_tokens(answer, encoding) for answer in answers]
        summary.update(
            num_responses=len(answers),
            token_counts=lengths,
            avg_tokens=statistics.mean(lengths),
            min_tokens=min(lengths),
            max_tokens=max(lengths),
        )

    return summary

def analyze_all_problems(data: List[Dict], encoding) -> Tuple[List[Dict], Dict]:
    """Analyze token statistics for every problem and for the dataset overall.

    Args:
        data: Problem records, each passed to ``analyze_problem_tokens``.
        encoding: Tokenizer forwarded to ``analyze_problem_tokens``.

    Returns:
        A ``(problem_stats, overall_stats)`` tuple. ``problem_stats`` holds
        one per-problem summary dict per record, in input order.
        ``overall_stats`` holds the problem and response totals, the flat
        list of every token count, and the overall mean, min, max and sample
        standard deviation. When there are no token counts at all (empty
        dataset, or no problem has responses) the mean, min and max are 0
        instead of raising; the standard deviation is 0 whenever fewer than
        two counts exist.
    """
    problem_stats = [analyze_problem_tokens(problem, encoding) for problem in data]
    all_token_counts = [
        count for stats in problem_stats for count in stats['token_counts']
    ]

    # statistics.mean / min / max all raise on empty input, so fall back to 0,
    # matching what the per-problem summary reports for a problem without
    # responses.
    has_counts = bool(all_token_counts)

    overall_stats = {
        'total_problems': len(data),
        'total_responses': len(all_token_counts),
        'all_token_counts': all_token_counts,
        'overall_avg_tokens': statistics.mean(all_token_counts) if has_counts else 0,
        'overall_min_tokens': min(all_token_counts) if has_counts else 0,
        'overall_max_tokens': max(all_token_counts) if has_counts else 0,
        # Sample standard deviation needs at least two data points.
        'overall_std_tokens': statistics.stdev(all_token_counts) if len(all_token_counts) > 1 else 0
    }

    return problem_stats, overall_stats

def _build_report_lines(problem_stats: List[Dict], overall_stats: Dict) -> List[str]:
    """Render the analysis report as a list of lines without trailing newlines.

    Shared by ``print_results`` and ``save_results_to_file`` so the console
    output and the saved report cannot drift apart.
    """
    banner = "=" * 80
    rule = "-" * 80

    lines = [
        banner,
        "AIME DATASET TOKEN ANALYSIS",
        banner,
        "",
        "Overall Statistics:",
        f"  Total Problems: {overall_stats['total_problems']}",
        f"  Total Responses: {overall_stats['total_responses']}",
        f"  Average Tokens: {overall_stats['overall_avg_tokens']:.2f}",
        f"  Minimum Tokens: {overall_stats['overall_min_tokens']}",
        f"  Maximum Tokens: {overall_stats['overall_max_tokens']}",
        f"  Standard Deviation: {overall_stats['overall_std_tokens']:.2f}",
        "",
        "Per-Problem Statistics:",
        rule,
        f"{'Problem':<8} {'Responses':<10} {'Avg':<8} {'Min':<8} {'Max':<8}",
        rule,
    ]

    lines.extend(
        f"{stats['problem_index']:<8} {stats['num_responses']:<10} "
        f"{stats['avg_tokens']:<8.1f} {stats['min_tokens']:<8} {stats['max_tokens']:<8}"
        for stats in problem_stats
    )

    lines += ["", "Token Distribution Analysis:", rule]

    sorted_counts = sorted(overall_stats['all_token_counts'])
    if not sorted_counts:
        # median and the percentile lookups below are undefined without data.
        lines.append("  No responses available for distribution analysis.")
        return lines

    # Percentiles are taken by direct index into the sorted counts
    # (no interpolation).
    lines += [
        f"  Median Tokens: {statistics.median(sorted_counts):.1f}",
        f"  25th Percentile: {sorted_counts[len(sorted_counts)//4]:.1f}",
        f"  75th Percentile: {sorted_counts[3*len(sorted_counts)//4]:.1f}",
    ]

    # Problems without responses carry placeholder min/max values of 0; they
    # must not be reported as owning the shortest (or longest) response.
    answered = [stats for stats in problem_stats if stats['num_responses'] > 0]
    if answered:
        min_problem = min(answered, key=lambda x: x['min_tokens'])
        max_problem = max(answered, key=lambda x: x['max_tokens'])
        lines += [
            "",
            f"Problem with shortest response: Problem {min_problem['problem_index']} "
            f"(min: {min_problem['min_tokens']} tokens)",
            f"Problem with longest response: Problem {max_problem['problem_index']} "
            f"(max: {max_problem['max_tokens']} tokens)",
        ]

    return lines


def print_results(problem_stats: List[Dict], overall_stats: Dict):
    """Print the analysis report to standard output.

    Args:
        problem_stats: Per-problem summary dicts with the keys
            ``problem_index``, ``num_responses``, ``avg_tokens``,
            ``min_tokens`` and ``max_tokens``.
        overall_stats: Dataset-wide summary with the keys ``total_problems``,
            ``total_responses``, ``all_token_counts`` and the
            ``overall_*_tokens`` statistics.
    """
    print("\n".join(_build_report_lines(problem_stats, overall_stats)))


def save_results_to_file(problem_stats: List[Dict], overall_stats: Dict, output_file: str):
    """Write the same report that ``print_results`` prints to *output_file*.

    Args:
        problem_stats: Per-problem summary dicts (see ``print_results``).
        overall_stats: Dataset-wide summary dict (see ``print_results``).
        output_file: Path of the text file to create or overwrite; written
            as UTF-8.
    """
    report = "\n".join(_build_report_lines(problem_stats, overall_stats))
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(report + "\n")

def main():
    """Run the token analysis for the dataset named on the command line.

    Parses the single positional ``file_path`` argument, loads the dataset,
    computes per-problem and overall token statistics, prints the report,
    and writes ``<name>_stats.txt`` and ``<name>_stats.json`` next to the
    input file. If the input path does not exist, an error message is
    printed and the function returns without doing anything else.
    """
    parser = argparse.ArgumentParser(description='Analyze token statistics for AIME dataset')
    parser.add_argument('file_path', help='Path to the AIME JSON file to analyze')
    args = parser.parse_args()

    file_path = args.file_path

    # Check if file exists
    if not os.path.exists(file_path):
        print(f"Error: File '{file_path}' does not exist.")
        return

    # Initialize tiktoken encoder (using cl100k_base which is commonly used)
    try:
        encoding = tiktoken.get_encoding("cl100k_base")
    except Exception as exc:
        # A bare ``except:`` would also swallow KeyboardInterrupt/SystemExit.
        # The fallback tokenizer yields different counts, so say so loudly
        # instead of switching silently.
        print(f"Warning: could not load 'cl100k_base' encoding ({exc}); "
              f"falling back to 'gpt2'. Token counts will differ.")
        encoding = tiktoken.get_encoding("gpt2")

    print(f"Loading AIME dataset from: {file_path}")
    data = load_aime_data(file_path)

    print(f"Found {len(data)} problems in the dataset")

    print("Analyzing token statistics...")
    problem_stats, overall_stats = analyze_all_problems(data, encoding)

    print_results(problem_stats, overall_stats)

    # Generate output file paths in the same directory as the input file
    input_dir = os.path.dirname(file_path)
    input_filename = os.path.basename(file_path)
    filename_without_ext = os.path.splitext(input_filename)[0]
    output_file = os.path.join(input_dir, f"{filename_without_ext}_stats.txt")

    # Save results to text file
    save_results_to_file(problem_stats, overall_stats, output_file)
    print(f"\nDetailed results saved to: {output_file}")

    # Also save detailed JSON results
    json_output_file = os.path.join(input_dir, f"{filename_without_ext}_stats.json")
    results = {
        'problem_statistics': problem_stats,
        'overall_statistics': overall_stats
    }

    # Explicit encoding so the output does not depend on the platform locale.
    with open(json_output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"JSON results saved to: {json_output_file}")

# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()