import time
import torch
from functools import wraps

def timing(func):
    """Decorate *func* so each call prints its wall-clock duration.

    The elapsed time is measured with ``time.perf_counter`` and printed to
    stdout as ``[Timing] <name> elapsed: <seconds> s`` after the call
    returns. When CUDA is available, ``torch.cuda.synchronize()`` is called
    before and after the call. This keeps GPU work queued before the call
    out of the measurement and includes the asynchronous GPU work launched
    by *func*.

    If *func* raises, the exception propagates and nothing is printed.

    Args:
        func: The callable to time. It does not need a ``__name__``
            attribute (e.g. ``functools.partial`` objects are accepted).

    Returns:
        A wrapper that accepts the same arguments as *func* and returns
        *func*'s result unchanged.
    """
    # Resolve the display name once, at decoration time. Some callables
    # (functools.partial, callable instances) have no __name__, and reading
    # it after func has run would raise and discard the result.
    name = getattr(func, "__name__", repr(func))

    @wraps(func)
    def wrapper(*args, **kwargs):
        use_cuda = torch.cuda.is_available()
        # Drain previously queued GPU work so it is not billed to this call.
        if use_cuda:
            torch.cuda.synchronize()
        t0 = time.perf_counter()

        result = func(*args, **kwargs)

        # CUDA kernels launch asynchronously; wait for func's GPU work to
        # finish before stopping the clock.
        if use_cuda:
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - t0

        print(f"[Timing] {name} elapsed: {elapsed:.3f} s")
        return result

    return wrapper