#!/usr/bin/env python3
"""
Timing Analysis Helper Script for Ring vs Ulysses Attention Performance

This script helps parse and analyze the timing outputs from the profiling
to identify performance bottlenecks between ring and Ulysses attention.

Usage:
    python timing_analysis.py <log_file>
    
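    Or pipe logs directly (the summary is printed only once the input
    reaches end-of-file, so the piped command must terminate):
    tail -n 10000 train.log | python timing_analysis.py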
"""

import sys
import re
from collections import defaultdict
import statistics

def parse_timing_logs(lines):
    """Parse timing measurements and run configuration from log lines.

    Args:
        lines: Iterable of log lines (a list of strings, an open text file,
            ``sys.stdin``, ...). It is iterated exactly once. A leading
            ``[rankN]:`` prefix on a line is stripped before matching.

    Returns:
        A ``(timing_data, config_info)`` tuple:

        * ``timing_data`` -- ``defaultdict(list)`` mapping a timing category
          (e.g. ``'step_forward'``, ``'layer7_sdpa'``) to the list of values
          in milliseconds, in the order they were encountered.
        * ``config_info`` -- plain dict that may contain ``'cp_ring'``,
          ``'cp_ulysses'``, ``'seq_len'``, ``'batch_size'`` (ints) and
          ``'attention_mode'`` (the matched banner text). When a setting
          appears several times, the last parsable occurrence wins.

    Lines that match nothing are ignored. A captured number that is not a
    valid float (``[\\d.]+`` also matches text such as ``1.2.3``) is skipped
    instead of aborting the whole parse.
    """

    # Pattern matchers for different timing categories
    patterns = {
        'layer7_total': r'\[LAYER-7-TIMING\] === TOTAL LAYER 7 TIME: ([\d.]+)ms ===',
        'layer7_attention_total': r'\[LAYER-7-TIMING\] Total Attention Module: ([\d.]+)ms',
        'layer7_sdpa': r'\[LAYER-7-TIMING\] \*\*\* SDPA \(Core Attention\): ([\d.]+)ms \*\*\*',
        'layer7_qkv_linear': r'\[LAYER-7-TIMING\] QKV Linear projections: ([\d.]+)ms',
        'layer7_rope': r'\[LAYER-7-TIMING\] RoPE embedding: ([\d.]+)ms',
        'layer7_wo_proj': r'\[LAYER-7-TIMING\] Output projection: ([\d.]+)ms',
        'layer7_feedforward': r'\[LAYER-7-TIMING\] FeedForward: ([\d.]+)ms',
        'layer7_attn_norm': r'\[LAYER-7-TIMING\] Attention LayerNorm: ([\d.]+)ms',
        'layer7_ffn_norm': r'\[LAYER-7-TIMING\] FFN LayerNorm: ([\d.]+)ms',

        # New backward timing patterns
        'layer7_total_backward': r'\[LAYER-7-BACKWARD-TIMING\] === TOTAL LAYER 7 BACKWARD: ([\d.]+)ms ===',
        'layer7_sdpa_backward': r'\[LAYER-7-BACKWARD-TIMING\] \*\*\* SDPA BACKWARD: ([\d.]+)ms \*\*\*',
        'layer7_wo_backward': r'\[LAYER-7-BACKWARD-TIMING\] WO Projection Backward: ([\d.]+)ms',
        'layer7_attention_backward': r'\[LAYER-7-BACKWARD-TIMING\] Total Attention Backward: ([\d.]+)ms',
        'layer7_ffn_backward': r'\[LAYER-7-BACKWARD-TIMING\] FFN Backward: ([\d.]+)ms',

        # Manual analysis patterns
        'manual_total_backward': r'=== TOTAL BACKWARD TIME: ([\d.]+)ms ===',
        'manual_wo_to_sdpa': r'WO PROJECTION BACKWARD -> SDPA BACKWARD: ([\d.]+)ms',
        'manual_loss_to_wo': r'LOSS BACKWARD -> WO PROJECTION BACKWARD: ([\d.]+)ms',

        'step_forward': r'\[STEP-TIMING\] Forward pass: ([\d.]+)ms',
        'step_backward': r'\[STEP-TIMING\] Backward pass: ([\d.]+)ms',
        'step_optimizer': r'\[STEP-TIMING\] Optimizer step: ([\d.]+)ms',
        'step_total': r'\[STEP-TIMING\] === TOTAL STEP TIME: ([\d.]+)ms ===',
        'attention_mode': r'(🔥 ATTENTION MODE: Hybrid Ulysses \+ Ring Attention|🚀 ATTENTION MODE: Ulysses Attention Only|⭕ ATTENTION MODE: Ring Attention Only|💻 ATTENTION MODE: Standard Attention \(No CP\))',
    }
    # Compile once per call instead of looking every pattern up per line.
    compiled = {category: re.compile(pattern) for category, pattern in patterns.items()}

    # (label searched for in the line, key stored in config_info). Order
    # matters: only the first label found in a line is considered.
    config_labels = (
        ('Context Parallel Ring Degree:', 'cp_ring'),
        ('Context Parallel Ulysses Degree:', 'cp_ulysses'),
        ('Sequence Length:', 'seq_len'),
        ('Batch Size:', 'batch_size'),
    )

    timing_data = defaultdict(list)
    config_info = {}

    for line in lines:
        line = line.strip()

        # Remove rank prefix if present (e.g., "[rank0]:")
        if line.startswith('[rank') and ']:' in line:
            line = line.split(']:', 1)[1]

        # Extract configuration info
        for label, key in config_labels:
            if label in line:
                try:
                    config_info[key] = int(line.split(label)[1].strip())
                except ValueError:
                    # Best effort: the text after the label is not a plain integer.
                    pass
                break

        # Extract timing data
        for category, regex in compiled.items():
            match = regex.search(line)
            if match is None:
                continue
            if category == 'attention_mode':
                config_info['attention_mode'] = match.group(1)
                continue
            # Convert before touching timing_data so that a malformed number
            # does not leave an empty list behind in the defaultdict.
            try:
                value = float(match.group(1))
            except ValueError:
                continue
            timing_data[category].append(value)

    return timing_data, config_info

def print_summary(timing_data, config_info):
    """Print a human-readable summary of parsed timing data to stdout.

    Args:
        timing_data: Mapping from timing category (e.g. ``'step_forward'``,
            ``'layer7_sdpa'``) to a list of measurements in milliseconds.
            Categories that are missing or empty are skipped.
        config_info: Mapping with optional run configuration. Recognised
            keys are ``'attention_mode'``, ``'cp_ring'``, ``'cp_ulysses'``,
            ``'seq_len'`` and ``'batch_size'``; each is printed only when
            present.

    Every reported figure is the arithmetic mean of the samples collected
    for that category. The final ratio line is backward time divided by
    forward time; it is reported as ``n/a`` when the average forward time
    is zero, so an all-zero forward measurement cannot raise
    ``ZeroDivisionError``.
    """

    def average(key):
        """Return the mean of the samples for *key*, or None without samples."""
        samples = timing_data.get(key)
        return statistics.mean(samples) if samples else None

    def sample_count(key):
        """Return how many samples were collected for *key*."""
        return len(timing_data.get(key, []))

    print("\n" + "="*60)
    print("PERFORMANCE ANALYSIS SUMMARY")
    print("="*60)

    if config_info.get('attention_mode'):
        print(f"Attention Mode: {config_info['attention_mode']}")

    # Report the remaining run configuration when it was found in the logs.
    config_labels = (
        ('cp_ring', 'Context Parallel Ring Degree'),
        ('cp_ulysses', 'Context Parallel Ulysses Degree'),
        ('seq_len', 'Sequence Length'),
        ('batch_size', 'Batch Size'),
    )
    for key, label in config_labels:
        if config_info.get(key) is not None:
            print(f"{label}: {config_info[key]}")

    print(f"\nStep-level Timing (averaged across {sample_count('step_total')} steps):")
    for key in ['step_total', 'step_forward', 'step_backward', 'step_optimizer']:
        avg = average(key)
        if avg is not None:
            print(f"  {key.replace('step_', '').replace('_', ' ').title()}: {avg:.2f}ms")

    print(f"\nLayer 7 Forward Timing (averaged across {sample_count('layer7_total')} measurements):")
    forward_keys = ['layer7_total', 'layer7_attention_total', 'layer7_sdpa', 'layer7_qkv_linear',
                    'layer7_rope', 'layer7_wo_proj', 'layer7_feedforward', 'layer7_attn_norm',
                    'layer7_ffn_norm']
    for key in forward_keys:
        avg = average(key)
        if avg is None:
            continue
        name = key.replace('layer7_', '').replace('_', ' ').title()
        if 'sdpa' in key:
            # Highlight the core attention kernel, the main quantity of interest.
            name = "*** SDPA (Core Attention) ***"
        print(f"  {name}: {avg:.2f}ms")

    print(f"\nLayer 7 Backward Timing (averaged across {sample_count('layer7_total_backward')} measurements):")
    backward_keys = ['layer7_total_backward', 'layer7_sdpa_backward', 'layer7_wo_backward',
                     'layer7_attention_backward', 'layer7_ffn_backward']
    for key in backward_keys:
        avg = average(key)
        if avg is None:
            continue
        name = key.replace('layer7_', '').replace('_backward', '').replace('_', ' ').title() + " Backward"
        if 'sdpa' in key:
            name = "*** SDPA Backward ***"
        print(f"  {name}: {avg:.2f}ms")

    # Manually instrumented backward timings, shown only when any were logged.
    manual_labels = {
        'manual_total_backward': "Total Backward Time",
        'manual_wo_to_sdpa': "WO -> SDPA Backward",
        'manual_loss_to_wo': "Loss -> WO Backward",
    }
    manual_averages = [(label, average(key)) for key, label in manual_labels.items()]
    if any(avg is not None for _, avg in manual_averages):
        print("\nManual Backward Analysis:")
        for label, avg in manual_averages:
            if avg is not None:
                print(f"  {label}: {avg:.2f}ms")

    # Performance ratios
    forward_avg = average('step_forward')
    backward_avg = average('step_backward')
    if forward_avg is not None and backward_avg is not None:
        if forward_avg > 0:
            ratio = backward_avg / forward_avg
            print(f"\nBackward/Forward Ratio: {ratio:.2f}x (Backward takes {ratio:.1f}x as long as forward)")
        else:
            # A zero forward average would otherwise divide by zero.
            print("\nBackward/Forward Ratio: n/a (average forward time is 0ms)")

    print("="*60)

def main():
    """Parse timing logs from a file or stdin and print the summary.

    When a command-line argument is given it is used as the path of the log
    file; otherwise the log is read from standard input until end-of-file.
    If the file cannot be opened or read, an error message is written to
    stderr and the process exits with a non-zero status.
    """

    if len(sys.argv) > 1:
        log_path = sys.argv[1]
        try:
            # The attention-mode banners contain emoji, so decode as UTF-8
            # explicitly rather than relying on the platform default;
            # undecodable bytes are replaced so one bad line cannot abort
            # the analysis.
            with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
                # Stream the file line by line instead of loading it whole.
                timing_data, config_info = parse_timing_logs(f)
        except OSError as err:
            sys.exit(f"Error: cannot read log file {log_path!r}: {err}")
    else:
        # Read from stdin
        timing_data, config_info = parse_timing_logs(sys.stdin)

    print_summary(timing_data, config_info)

# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()