# Copyright (c) 2026 Anonymous
# All Rights Reserved
# This codebase is provided for peer review purposes only.

import time
import torch
import torch.distributed as dist


class TimerStart(torch.autograd.Function):
    """Identity autograd op that opens a timed region.

    Intended to be paired with ``TimerEnd`` applied downstream of this op,
    both sharing the same ``timer_state`` dict:

    * forward: writes the start timestamp (``time.perf_counter()``) to
      ``timer_state["fwd"]``; ``TimerEnd.forward`` later replaces it with
      the elapsed seconds.
    * backward: expects ``timer_state["bwd"]`` to already hold a start
      timestamp (written by ``TimerEnd.backward``, which autograd runs
      first when ``TimerEnd`` is downstream) and replaces it with the
      elapsed seconds.

    Before each timestamp, pending CUDA work is synchronized and all ranks
    call ``dist.barrier()``. Either step is skipped when CUDA or an
    initialized process group is unavailable, so the timer also works on
    CPU and in single-process runs.
    """

    @staticmethod
    def _sync():
        """Wait for queued CUDA work and for all ranks, when applicable."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if dist.is_available() and dist.is_initialized():
            dist.barrier()

    @staticmethod
    def forward(ctx, x, timer_state):
        """Record the forward start time and return ``x`` unchanged.

        Args:
            x: Tensor passed through untouched.
            timer_state: Mutable dict shared with ``TimerEnd``.
        """
        # Keep the dict so backward can write into the same object.
        ctx.timer_state = timer_state
        TimerStart._sync()
        timer_state["fwd"] = time.perf_counter()
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """Close the backward region and pass the gradient through.

        Raises:
            KeyError: if ``timer_state["bwd"]`` was never set (i.e. no
                ``TimerEnd.backward`` ran before this).
        """
        timer_state = ctx.timer_state
        TimerStart._sync()
        # "bwd" holds the start timestamp set by TimerEnd.backward; turn it
        # into an elapsed duration in seconds.
        timer_state["bwd"] = time.perf_counter() - timer_state["bwd"]
        # timer_state is not a tensor, so it receives no gradient.
        return grad_output, None


class TimerEnd(torch.autograd.Function):
    """Identity autograd op that closes a timed region.

    Intended to be paired with ``TimerStart`` applied upstream of this op,
    both sharing the same ``timer_state`` dict:

    * forward: expects ``timer_state["fwd"]`` to hold the start timestamp
      written by ``TimerStart.forward`` and replaces it with the elapsed
      seconds.
    * backward: writes the backward start timestamp
      (``time.perf_counter()``) to ``timer_state["bwd"]``. Because this op
      is downstream, autograd runs this backward before
      ``TimerStart.backward``, which turns the value into elapsed seconds.

    Before each timestamp, pending CUDA work is synchronized and all ranks
    call ``dist.barrier()``. Either step is skipped when CUDA or an
    initialized process group is unavailable, so the timer also works on
    CPU and in single-process runs.
    """

    @staticmethod
    def _sync():
        """Wait for queued CUDA work and for all ranks, when applicable."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if dist.is_available() and dist.is_initialized():
            dist.barrier()

    @staticmethod
    def forward(ctx, x, timer_state):
        """Close the forward region and return ``x`` unchanged.

        Args:
            x: Tensor passed through untouched.
            timer_state: Mutable dict shared with ``TimerStart``.

        Raises:
            KeyError: if ``timer_state["fwd"]`` was never set (i.e. no
                ``TimerStart.forward`` ran before this).
        """
        # Keep the dict so backward can write into the same object.
        ctx.timer_state = timer_state
        TimerEnd._sync()
        # "fwd" holds the start timestamp set by TimerStart.forward; turn it
        # into an elapsed duration in seconds.
        timer_state["fwd"] = time.perf_counter() - timer_state["fwd"]
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """Record the backward start time and pass the gradient through."""
        timer_state = ctx.timer_state
        TimerEnd._sync()
        timer_state["bwd"] = time.perf_counter()
        # timer_state is not a tensor, so it receives no gradient.
        return grad_output, None
