#!/usr/bin/env python3
"""Test backward timing system"""

import torch
import time
from torchtitan.models.llama3.model import Attention, TransformerModelArgs, _record_backward_event, _analyze_backward_timing, _backward_timing_state

def test_backward_timing():
    """Run one forward/backward pass through ``Attention`` and print hook timings.

    Builds a small attention layer, registers a gradient hook on the scalar
    loss that records a "LOSS BACKWARD" event, runs ``loss.backward()``, and
    then prints the interval between each pair of consecutive events found in
    ``_backward_timing_state['hook_times']`` together with the first-to-last
    total.

    This is an inspection script: it prints its findings and makes no
    assertions. Returns ``None``.
    """
    # Keep the model config and the tensor shapes derived from the same values.
    dim, n_heads, seq_len, batch = 128, 4, 10, 2

    args = TransformerModelArgs(dim=dim, n_heads=n_heads, n_layers=8)
    attn = Attention(args)
    x = torch.randn(batch, seq_len, dim, requires_grad=True)
    head_dim = dim // n_heads
    # NOTE(review): freqs_cis is a real-valued (seq_len, head_dim // 2) tensor
    # here -- confirm this matches what Attention.forward expects.
    freqs_cis = torch.randn(seq_len, head_dim // 2)

    out = attn(x, freqs_cis, layer_id=7)
    loss = out.sum()

    # A tensor hook runs when the gradient w.r.t. that tensor is computed.
    # The loss is the root of the graph, so its gradient is produced first:
    # this event marks where the backward pass *begins*, not where it ends.
    def loss_backward_hook(grad):
        _record_backward_event("LOSS BACKWARD")
        return grad

    loss.register_hook(loss_backward_hook)

    print('About to run backward...')
    loss.backward()
    print('Backward completed')

    # Snapshot the recorded (event_name, timestamp) pairs into a list so that
    # adjacent pairs can be formed whatever container the module uses.
    # NOTE(review): the state is read as-is; events recorded by earlier
    # backward passes in this process would also show up -- confirm whether
    # it should be reset before the run.
    hook_times = list(_backward_timing_state['hook_times'])

    if len(hook_times) < 2:
        # Make a missing/misconfigured hook visible instead of printing nothing.
        print(
            f"Only {len(hook_times)} backward timing event(s) recorded; "
            "need at least 2 to compute intervals"
        )
        return

    print("\n=== MANUAL BACKWARD TIMING ANALYSIS ===")
    for (current_event, current_time), (next_event, next_time) in zip(
        hook_times, hook_times[1:]
    ):
        # Timestamps are assumed to be in seconds; report milliseconds.
        interval = (next_time - current_time) * 1000
        print(f"{current_event} -> {next_event}: {interval:.2f}ms")

    # Span from the first recorded event to the last one.
    total_time = (hook_times[-1][1] - hook_times[0][1]) * 1000
    print(f"=== TOTAL BACKWARD TIME: {total_time:.2f}ms ===")

# Run the backward-timing check when this file is executed directly as a script.
if __name__ == "__main__":
    test_backward_timing()