# profiling/measure_flops.py
"""
Simple, coarse timing-based profiler. For rigorous FLOP counts use fvcore or ptflops.
This module times repeated calls (e.g. model forward passes) and reports per-call
latency statistics in seconds; any FLOP-computing hooks are optional.
"""
import time
import numpy as np

def profile_fn(fn, args=(), kwargs=None, n_runs=10, *, warmup=0):
    """Time repeated calls to ``fn`` and summarize per-call latency.

    Args:
        fn: Callable to benchmark. Its return value is discarded.
        args: Positional arguments passed to ``fn`` on every call.
        kwargs: Keyword arguments passed to ``fn`` on every call
            (``None`` means no keyword arguments).
        n_runs: Number of timed calls; must be >= 1.
        warmup: Number of untimed calls made before timing starts, e.g. to
            absorb lazy initialization or cache effects. Defaults to 0
            (no warm-up), matching the previous behaviour.

    Returns:
        dict with keys ``"mean_s"``, ``"median_s"`` and ``"p95_s"``: the mean,
        median and 95th percentile of the per-call elapsed times, in seconds,
        as Python floats.

    Raises:
        ValueError: If ``n_runs < 1`` or ``warmup < 0``.

    Note:
        Only host-side elapsed time is measured. If ``fn`` launches
        asynchronous work (for example GPU kernels), ``fn`` itself must
        synchronize before returning, otherwise the timings will be too low.
    """
    # An empty sample would make mean() return NaN and percentile() raise an
    # obscure IndexError; fail fast with a clear message instead.
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")
    if warmup < 0:
        raise ValueError(f"warmup must be >= 0, got {warmup}")
    kwargs = kwargs or {}

    for _ in range(warmup):
        fn(*args, **kwargs)

    times = np.empty(n_runs, dtype=np.float64)
    for i in range(n_runs):
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which follows the system clock and can jump or be coarse.
        t0 = time.perf_counter()
        fn(*args, **kwargs)
        times[i] = time.perf_counter() - t0

    return {
        "mean_s": float(times.mean()),
        "median_s": float(np.median(times)),
        "p95_s": float(np.percentile(times, 95)),
    }
