#!/usr/bin/env python3
"""
Analyze saved generation timings from a JSON or pickle file.

Usage:
    python analyze_timings.py generation_timings.json
    python analyze_timings.py generation_timings.pkl
"""

import argparse
import json
import pickle
import statistics
from collections import defaultdict
from pathlib import Path


def load_timings(filename):
    """Load timing data from a JSON or pickle file.

    The format is picked from the file extension: ``.json`` is parsed as
    JSON, ``.pkl`` / ``.pickle`` as pickle. For any other extension the file
    is tried as JSON first and, if it is not valid JSON text, as pickle.

    Args:
        filename: Path to the timings file (``str`` or path-like).

    Returns:
        Whatever object the file deserializes to.

    Raises:
        FileNotFoundError: If ``filename`` does not exist.
        json.JSONDecodeError: If a ``.json`` file is not valid JSON.
        pickle.UnpicklingError: If the file is read as pickle but is not a
            valid pickle stream (other exceptions are possible as well).
    """
    path = Path(filename)

    if not path.exists():
        raise FileNotFoundError(f"File not found: {filename}")

    def read_json():
        # JSON text is UTF-8 by specification; do not depend on the locale's
        # default encoding.
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def read_pickle():
        # SECURITY: unpickling can execute arbitrary code. Only load pickle
        # files that come from a trusted source.
        with open(path, 'rb') as f:
            return pickle.load(f)

    if path.suffix == '.json':
        return read_json()
    if path.suffix in ('.pkl', '.pickle'):
        return read_pickle()

    # Unknown extension: try JSON first, then pickle. Only "this is not JSON
    # text" errors trigger the fallback, so KeyboardInterrupt, permission
    # errors, etc. propagate instead of being masked by a pickle attempt.
    try:
        return read_json()
    except (json.JSONDecodeError, UnicodeDecodeError):
        return read_pickle()


def _print_step_breakdown(title, step, step_total):
    """Print one step's operations, slowest first, with each one's share of that step."""
    print(title)
    for op, time_ms in sorted(step.items(), key=lambda item: item[1], reverse=True):
        pct = (time_ms / step_total * 100) if step_total > 0 else 0
        print(f"  {op:<25}: {time_ms:>8.4f} ms ({pct:>5.1f}%)")


def analyze_timings(timings_list):
    """Print a timing report for a sequence of per-step timing dicts.

    The report has three parts: overall totals across all steps, a table with
    one row per operation (mean, total, min, max and share of the overall
    total, sorted by mean descending), and - when there are at least two
    steps - a breakdown of the first and last step plus their ratio.

    Args:
        timings_list: Sequence of dicts, one per step, mapping an operation
            name to a numeric duration. Values are reported as milliseconds.

    Returns:
        None. All output goes to stdout. An empty/falsy ``timings_list``
        prints an error line and returns without a report.
    """
    if not timings_list:
        print("Error: No timing data found")
        return

    # Group every recorded duration under its operation name.
    aggregated = defaultdict(list)
    for timing_dict in timings_list:
        for op, time_ms in timing_dict.items():
            aggregated[op].append(time_ms)

    # Per-step totals; non-empty here because timings_list is non-empty.
    total_times = [sum(t.values()) for t in timings_list]
    overall_total = sum(total_times)
    overall_mean = statistics.mean(total_times)
    overall_min = min(total_times)
    overall_max = max(total_times)

    # Compute each operation's statistics once, then order by mean (descending).
    op_rows = sorted(
        (
            (op, statistics.mean(times), sum(times), min(times), max(times))
            for op, times in aggregated.items()
        ),
        key=lambda row: row[1],
        reverse=True,
    )

    # The rules are sized to the header so the whole table has one width, and
    # the data rows below use the same column widths as the header so the
    # '|' separators line up.
    header = (
        f"{'Operation':<25} | {'Mean (ms)':<12} | {'Total (ms)':<12} | "
        f"{'Min':<10} | {'Max':<10} | {'% of Total':<12}"
    )
    heavy_rule = '=' * len(header)
    light_rule = '-' * len(header)

    print(f"\n{heavy_rule}")
    print("Generation Step Timing Analysis")
    print(heavy_rule)
    print(f"Total steps: {len(timings_list)}")
    print(f"Total time: {overall_total:.4f} ms")
    print(f"Mean time per step: {overall_mean:.4f} ms")
    print(f"Min time per step: {overall_min:.4f} ms")
    print(f"Max time per step: {overall_max:.4f} ms")
    print(heavy_rule)
    print(header)
    print(light_rule)

    for op, mean_time, total_time, min_time, max_time in op_rows:
        pct_of_total = (total_time / overall_total * 100) if overall_total > 0 else 0
        # The percentage field is 11 wide plus the trailing '%' = 12 columns.
        print(
            f"{op:<25} | {mean_time:>12.4f} | {total_time:>12.4f} | "
            f"{min_time:>10.4f} | {max_time:>10.4f} | {pct_of_total:>11.2f}%"
        )

    print(light_rule)
    print(
        f"{'TOTAL':<25} | {overall_mean:>12.4f} | {overall_total:>12.4f} | "
        f"{overall_min:>10.4f} | {overall_max:>10.4f} | {'100.00%':>12}"
    )
    print(f"{heavy_rule}\n")

    # First vs last step comparison needs at least two steps.
    if len(timings_list) >= 2:
        first = timings_list[0]
        last = timings_list[-1]
        first_total = sum(first.values())
        last_total = sum(last.values())
        # Ratio of first-step time to last-step time; 0 when the last step
        # has no positive total (avoids division by zero).
        speedup = first_total / last_total if last_total > 0 else 0

        _print_step_breakdown("First step breakdown:", first, first_total)
        _print_step_breakdown("\nLast step breakdown:", last, last_total)

        print(f"\nFirst step total: {first_total:.4f} ms")
        print(f"Last step total:  {last_total:.4f} ms")
        print(f"Speedup: {speedup:.2f}x")
        print()


def _load_with_format(filename, fmt):
    """Load timing data, honoring an explicit ``--format`` choice.

    ``"json"`` and ``"pickle"`` force that parser regardless of the file
    extension; any other value defers to ``load_timings`` auto-detection.
    """
    if fmt == "json":
        with open(filename, "r", encoding="utf-8") as f:
            return json.load(f)
    if fmt == "pickle":
        # SECURITY: unpickling can execute arbitrary code; only use this on
        # files from a trusted source.
        with open(filename, "rb") as f:
            return pickle.load(f)
    return load_timings(filename)


def main():
    """Command-line entry point: load a timings file and print its analysis.

    Parses ``filename`` and the optional ``--format`` (``json``, ``pickle``
    or ``auto``) from the command line.

    Returns:
        0 on success, 1 if loading or analysis raised an exception (the
        error message and traceback are written to stderr).
    """
    import sys
    import traceback

    parser = argparse.ArgumentParser(description="Analyze saved generation timings")
    parser.add_argument("filename", type=str, help="Path to JSON or pickle file with timing data")
    parser.add_argument("--format", type=str, choices=["json", "pickle", "auto"], default="auto",
                        help="File format (auto-detect if not specified)")

    args = parser.parse_args()

    try:
        # Previously --format was accepted but ignored; it now selects the parser.
        timings = _load_with_format(args.filename, args.format)
        analyze_timings(timings)
    except Exception as e:
        # Top-level boundary: report everything and turn it into an exit code.
        print(f"Error: {e}", file=sys.stderr)
        traceback.print_exc()
        return 1

    return 0


if __name__ == "__main__":
    # Raise SystemExit directly instead of calling the builtin exit(): that
    # helper is installed by the site module for interactive use and is not
    # available when Python runs with -S or in some frozen/embedded builds.
    raise SystemExit(main())
