#!/usr/bin/env python3
"""
Test script to demonstrate layer-level profiling using actual TransformerBlock.
This will trigger all the timing statements we added for Layer 7 profiling.
"""

import os
import sys
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh


from torchtitan.models.llama3.model import (
    TransformerBlock, 
    TransformerModelArgs, 
    precompute_freqs_cis,
    _backward_timing_state,
    _record_backward_event,
    _analyze_backward_timing
)

def test_layer_profiling():
    """Run one forward and one backward pass through a single TransformerBlock.

    The block is built with ``layer_id=7`` so that the profiling
    instrumentation living in the model module (imported at the top of this
    file) is exercised. The function prints wall-clock timings for both
    passes, the intervals between the backward events recorded in
    ``_backward_timing_state['hook_times']``, and CUDA memory statistics.

    When ``LOCAL_RANK`` is present in the environment (e.g. under torchrun)
    an NCCL process group is initialised if none exists yet and the GPU is
    selected by local rank. Otherwise the test runs in a single process on
    CUDA if available, falling back to CPU.
    """

    print("🔥 Testing Layer 7 Profiling with TransformerBlock...")

    # Setup distributed if running with torchrun, otherwise single process
    if "LOCAL_RANK" in os.environ:
        # The GPU index on this host is the *local* rank. The global rank only
        # coincides with it on a single node, so it must not be used to build
        # the device.
        local_rank = int(os.environ["LOCAL_RANK"])
        # Always make this the current device: torch.cuda.synchronize() and
        # torch.cuda.Event below act on the current device.
        torch.cuda.set_device(local_rank)
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl")

        rank = dist.get_rank()
        world_size = dist.get_world_size()
        device = torch.device(f"cuda:{local_rank}")
        print(f"[Rank {rank}] Distributed setup - World size: {world_size}")
    else:
        rank = 0
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Single process setup - Device: {device}")

    # CUDA events can only be used when the work actually runs on a GPU.
    use_cuda = device.type == "cuda"

    # Create model configuration
    model_args = TransformerModelArgs(
        dim=512,           # Smaller for testing
        n_heads=8,
        n_layers=8,
        n_kv_heads=8,
        vocab_size=1000,
        max_seq_len=1024
    )

    print(f"[Rank {rank}] Model config: dim={model_args.dim}, heads={model_args.n_heads}")

    # Create Layer 7 (the layer id the profiling instrumentation targets)
    layer = TransformerBlock(layer_id=7, model_args=model_args).to(device)

    # Create test input
    batch_size = 2
    seq_len = 512

    x = torch.randn(batch_size, seq_len, model_args.dim, device=device, requires_grad=True)

    # Precompute frequency embeddings (per-head dim, max sequence length, theta)
    freqs_cis = precompute_freqs_cis(
        model_args.dim // model_args.n_heads,
        model_args.max_seq_len,
        model_args.rope_theta
    ).to(device)

    print(f"[Rank {rank}] Created test tensors: x={x.shape}, freqs_cis={freqs_cis.shape}")
    print(f"[Rank {rank}] Memory allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")

    print(f"\n[Rank {rank}] ==========================================")
    print(f"[Rank {rank}] 🚀 STARTING LAYER 7 FORWARD PASS")
    print(f"[Rank {rank}] ==========================================")

    # Forward pass
    output, forward_time = _time_call_ms(lambda: layer(x, freqs_cis), use_cuda)

    print(f"\n[Rank {rank}] Forward pass completed in {forward_time:.2f}ms")
    print(f"[Rank {rank}] Output shape: {output.shape}")

    # Create loss for backward pass
    loss = output.sum()

    # Hook on the loss: its gradient is the first one autograd produces.
    def loss_backward_hook(grad):
        _record_backward_event("LOSS BACKWARD START")
        return grad

    # Hook on the input leaf: runs once the gradient w.r.t. x is available.
    # NOTE(review): parameter gradients may still be accumulated after this
    # point, so treating it as the very end of backward is approximate.
    def final_backward_hook(grad):
        _record_backward_event("FINAL BACKWARD END")
        # Trigger analysis after all hooks
        _analyze_backward_timing()
        return grad

    loss.register_hook(loss_backward_hook)
    x.register_hook(final_backward_hook)

    print(f"\n[Rank {rank}] ==========================================")
    print(f"[Rank {rank}] ⚡ STARTING LAYER 7 BACKWARD PASS")
    print(f"[Rank {rank}] ==========================================")

    # Backward pass
    _, backward_time = _time_call_ms(loss.backward, use_cuda)

    print(f"\n[Rank {rank}] Backward pass completed in {backward_time:.2f}ms")
    # The value printed is backward time divided by forward time; skip it if
    # the forward measurement is zero rather than dividing by zero.
    if forward_time > 0:
        print(f"[Rank {rank}] Backward/Forward ratio: {backward_time/forward_time:.2f}x")

    # Manual analysis of recorded timing events
    # presumably a sequence of (event_name, timestamp_in_seconds) pairs —
    # verify against _record_backward_event in the model module.
    hook_times = list(_backward_timing_state['hook_times'])
    if len(hook_times) >= 2:
        print(f"\n[Rank {rank}] ==========================================")
        print(f"[Rank {rank}] 📊 MANUAL BACKWARD TIMING ANALYSIS")
        print(f"[Rank {rank}] ==========================================")

        # Walk consecutive pairs of events and report the gap between them.
        for (current_event, current_time), (next_event, next_time) in zip(
            hook_times, hook_times[1:]
        ):
            interval = (next_time - current_time) * 1000  # Convert to ms
            print(f"[Rank {rank}] {current_event} -> {next_event}: {interval:.2f}ms")

        # Total backward time from hooks
        total_hook_time = (hook_times[-1][1] - hook_times[0][1]) * 1000
        print(f"[Rank {rank}] === TOTAL HOOK BACKWARD TIME: {total_hook_time:.2f}ms ===")

    print(f"\n[Rank {rank}] ==========================================")
    print(f"[Rank {rank}] ✅ PROFILING TEST COMPLETED")
    print(f"[Rank {rank}] ==========================================")
    print(f"[Rank {rank}] Memory peak: {torch.cuda.max_memory_allocated()/1e9:.2f} GB")


def _time_call_ms(fn, use_cuda):
    """Call ``fn()`` and return ``(result, elapsed_milliseconds)``.

    When ``use_cuda`` is true the call is bracketed by CUDA events with a
    device synchronisation before and after, so queued GPU work is included
    in the measurement. Otherwise ``time.perf_counter`` is used, which keeps
    the CPU fallback usable on machines without a GPU.
    """
    if use_cuda:
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        result = fn()
        end.record()
        # elapsed_time() requires both events to have completed.
        torch.cuda.synchronize()
        return result, start.elapsed_time(end)

    import time  # local import: only needed on the CPU path

    begin = time.perf_counter()
    result = fn()
    return result, (time.perf_counter() - begin) * 1000.0

def test_with_context_parallel():
    """Run a forward/backward pass of one TransformerBlock under context parallel.

    Requires an already-initialised process group; otherwise the test prints
    a message and returns. Any exception raised while building or running the
    test is reported with a traceback instead of being propagated, so a
    failure here does not abort the rest of the script.
    """

    print("\n🔥🔥 Testing Layer 7 Profiling with Context Parallel...")

    if not dist.is_initialized():
        print("❌ Context parallel test requires distributed setup. Skipping.")
        return

    # Resolve rank information before the try block: the except handler
    # formats its message with `rank`, so it must exist even when the very
    # first statement inside the try (the import) fails.
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    try:
        from torch.distributed.tensor.experimental import context_parallel

        # The GPU index on this host is the local rank; fall back to the
        # global rank when LOCAL_RANK is not set (single-node launch).
        local_rank = int(os.environ.get("LOCAL_RANK", rank))
        device = torch.device(f"cuda:{local_rank}")

        # Create 1D mesh for context parallel
        mesh = DeviceMesh("cuda", list(range(world_size)))
        print(f"[Rank {rank}] Created CP device mesh: {mesh}")

        # Create model configuration
        model_args = TransformerModelArgs(
            dim=512,
            n_heads=8,
            n_layers=8,
            n_kv_heads=8,
            vocab_size=1000,
            max_seq_len=1024
        )

        # Create Layer 7
        layer = TransformerBlock(layer_id=7, model_args=model_args).to(device)

        # Create test input
        batch_size = 2
        seq_len = 1024  # Larger for CP

        x = torch.randn(batch_size, seq_len, model_args.dim, device=device, requires_grad=True)
        freqs_cis = precompute_freqs_cis(
            model_args.dim // model_args.n_heads,
            model_args.max_seq_len,
            model_args.rope_theta
        ).to(device)

        print(f"[Rank {rank}] CP test - Input: {x.shape}")

        # Test with context parallel
        # NOTE(review): only x is listed in `buffers`, so freqs_cis is not
        # sharded along the sequence dimension with it — confirm this is the
        # intended setup.
        with context_parallel(mesh, buffers=[x], buffer_seq_dims=[1]):
            print(f"[Rank {rank}] Inside context parallel...")
            output = layer(x, freqs_cis)
            loss = output.sum()

            print(f"[Rank {rank}] CP forward complete, starting backward...")
            loss.backward()
            print(f"[Rank {rank}] CP backward complete!")

    except Exception as e:
        # Best-effort test: report the failure and carry on.
        print(f"[Rank {rank}] Context parallel test failed: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    try:
        # Test 1: Basic layer profiling
        test_layer_profiling()

        # Test 2: Context parallel profiling (if distributed)
        if "LOCAL_RANK" in os.environ:
            test_with_context_parallel()

        print("\n🎉 All profiling tests completed!")
    finally:
        # test_layer_profiling may have created a process group; tear it down
        # explicitly so the backend is shut down cleanly even when a test
        # raised.
        if dist.is_initialized():
            dist.destroy_process_group()