# Copyright (c) 2026 Anonymous
# All Rights Reserved
# This codebase is provided for peer review purposes only.

import torch
import time

# Shared, module-level scratch space that pairs each timer_start with its
# timer_end. Keyed by the timer name; each value is a dict holding:
#   'forward_start'  - time.perf_counter() stamp written by TimerStart.forward
#   'backward_start' - time.perf_counter() stamp written by TimerEnd.backward
# TimerStart.forward replaces the whole entry for its name, so an entry only
# ever describes the most recent region started under that name.
_timer_storage = {}

class TimerStart(torch.autograd.Function):
    """Identity autograd op marking the start of a named timed region.

    ``forward`` stores the region's start time under ``name`` in the
    module-level ``_timer_storage``.  ``backward`` reads the
    ``'backward_start'`` stamp stored under the same name (written by
    ``TimerEnd.backward``) and prints the elapsed backward time.
    """

    @staticmethod
    def _sync_if_cuda():
        """Wait for queued CUDA work; do nothing when CUDA is unavailable."""
        # torch.cuda.synchronize() raises on CPU-only builds / hosts without
        # a usable driver, which would make the timer unusable there.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    @staticmethod
    def forward(ctx, x, name):
        """Record the forward start time for ``name`` and return ``x`` unchanged.

        Args:
            ctx: autograd context; ``name`` is kept on it for ``backward``.
            x: tensor passed through untouched.
            name: key identifying the timed region in ``_timer_storage``.
        """
        ctx.name = name
        # Drain pending GPU work first so it is not billed to this region.
        TimerStart._sync_if_cuda()
        start_time = time.perf_counter()
        # Fresh entry per forward pass; TimerEnd.forward reads 'forward_start'.
        _timer_storage[name] = {'forward_start': start_time}
        return x

    @staticmethod
    def backward(ctx, grad):
        """Print the region's backward time in ms and pass ``grad`` through.

        Raises:
            KeyError: if no ``'backward_start'`` stamp exists for this name.
        """
        TimerStart._sync_if_cuda()
        backward_end = time.perf_counter()
        # perf_counter() is in seconds; scale to milliseconds for the report.
        backward_time = (backward_end - _timer_storage[ctx.name]['backward_start']) * 1000
        print(f"{ctx.name} backward: {backward_time:.3f} ms")
        # One gradient per forward input: ``x`` gets ``grad``, ``name`` gets None.
        return grad, None

class TimerEnd(torch.autograd.Function):
    """Identity autograd op marking the end of a named timed region.

    ``forward`` reads the ``'forward_start'`` stamp stored under ``name`` in
    the module-level ``_timer_storage`` (written by ``TimerStart.forward``)
    and prints the elapsed forward time.  ``backward`` stores a
    ``'backward_start'`` stamp under the same name for
    ``TimerStart.backward`` to read.
    """

    @staticmethod
    def _sync_if_cuda():
        """Wait for queued CUDA work; do nothing when CUDA is unavailable."""
        # torch.cuda.synchronize() raises on CPU-only builds / hosts without
        # a usable driver, which would make the timer unusable there.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    @staticmethod
    def forward(ctx, x, name):
        """Print the region's forward time in ms and return ``x`` unchanged.

        Args:
            ctx: autograd context; ``name`` is kept on it for ``backward``.
            x: tensor passed through untouched.
            name: key identifying the timed region in ``_timer_storage``.

        Raises:
            KeyError: if no ``'forward_start'`` stamp exists for ``name``.
        """
        ctx.name = name
        # Wait for the region's GPU work to finish before reading the clock.
        TimerEnd._sync_if_cuda()
        forward_end = time.perf_counter()
        # perf_counter() is in seconds; scale to milliseconds for the report.
        forward_time = (forward_end - _timer_storage[name]['forward_start']) * 1000
        print(f"{name} forward: {forward_time:.3f} ms")
        return x

    @staticmethod
    def backward(ctx, grad):
        """Stamp the start of the region's backward pass and pass ``grad`` through.

        Raises:
            KeyError: if ``_timer_storage`` has no entry for this name.
        """
        TimerEnd._sync_if_cuda()
        backward_start = time.perf_counter()
        # TimerStart.backward subtracts this stamp from its own clock reading.
        _timer_storage[ctx.name]['backward_start'] = backward_start
        # One gradient per forward input: ``x`` gets ``grad``, ``name`` gets None.
        return grad, None

def timer_start(x, name):
    """Open the timed region ``name`` by routing ``x`` through ``TimerStart``.

    Returns whatever ``TimerStart.apply(x, name)`` returns.
    """
    started = TimerStart.apply(x, name)
    return started

def timer_end(x, name):
    """Close the timed region ``name`` by routing ``x`` through ``TimerEnd``.

    Returns whatever ``TimerEnd.apply(x, name)`` returns.
    """
    finished = TimerEnd.apply(x, name)
    return finished
