
     
# NOTE(review): `MetaModel` is bound here without being called. If it is the
# model *class* rather than an already-constructed instance, the `.to()` call
# below fails (there is no `self`); instantiate it with its constructor
# arguments instead — TODO confirm against where MetaModel is defined.
selector = MetaModel

# Switch to inference mode. Prefer the GPU when one is present, but do not
# require it: the benchmark below explicitly supports CPU-only machines, and
# `measure()` moves the weights to each device it times anyway.
selector = selector.to('cuda' if torch.cuda.is_available() else 'cpu').eval()


# ---------------------------------------------------------------
# Synthetic single-sample inputs for the latency benchmark below.
BATCH       = 1
# NOTE(review): this was annotated "top-5 + skew + kurtosis", which would be
# 7 features, yet the value is 4. It must match the width of the model's
# numeric input — TODO confirm which one is correct.
NUM_FEATS   = 4
# NOTE(review): not referenced in this section; presumably the model's
# class-embedding width — confirm.
EMB_DIM     = 10
# Exclusive upper bound for the random class ids fed to the model
# (presumably the model's number of classes — confirm).
NUM_CLASSES = 196

numeric = torch.randn(BATCH, NUM_FEATS, dtype=torch.float32)
classes = torch.randint(0, NUM_CLASSES, (BATCH,), dtype=torch.long)

def measure(
    device: torch.device,
    n_iter: int = 10000,
    *,
    n_warmup: int = 50,
    model=None,
    numeric_input=None,
    class_input=None,
):
    """Benchmark the per-call forward latency of a model on *device*.

    Runs ``n_warmup`` untimed forward passes, then ``n_iter`` timed ones, all
    under ``torch.no_grad()``. On CUDA each call is bracketed by CUDA events
    and the device is synchronised before the elapsed time is read. On any
    other device, wall-clock time from ``time.perf_counter`` is used.

    Args:
        device: Device to benchmark on; the model and inputs are moved there.
        n_iter: Number of timed forward passes. Must be at least 2 because a
            sample standard deviation is returned.
        n_warmup: Number of untimed passes run before timing (keyword-only).
        model: Callable with ``.to(device)`` to benchmark (keyword-only).
            Defaults to the module-level ``selector``. If it is an
            ``nn.Module``, ``.to`` moves it in place, so it stays on *device*
            after this call.
        numeric_input: First forward argument (keyword-only). Defaults to the
            module-level ``numeric``.
        class_input: Second forward argument (keyword-only). Defaults to the
            module-level ``classes``.

    Returns:
        ``(mean_ms, stdev_ms)`` over the timed calls, in milliseconds.

    Raises:
        ValueError: if ``n_iter < 2``.
    """
    if n_iter < 2:
        # stdev() needs at least two samples; fail before running the whole
        # benchmark rather than after it.
        raise ValueError(f"n_iter must be >= 2 to compute a stdev, got {n_iter}")

    # Fall back to the module-level benchmark objects when not overridden.
    if model is None:
        model = selector
    if numeric_input is None:
        numeric_input = numeric
    if class_input is None:
        class_input = classes

    sel = model.to(device)          # in place for nn.Module: moves the weights
    x_num = numeric_input.to(device)
    x_cls = class_input.to(device)
    is_cuda = device.type == 'cuda'

    times = []
    with torch.no_grad():
        # warm-up ----------------------------------------------------
        for _ in range(n_warmup):
            _ = sel(x_num, x_cls)
        if is_cuda:
            # Drain queued warm-up kernels so timing starts from an idle GPU.
            torch.cuda.synchronize(device)

        if is_cuda:
            starter = torch.cuda.Event(enable_timing=True)
            ender = torch.cuda.Event(enable_timing=True)
            for _ in range(n_iter):
                starter.record()
                _ = sel(x_num, x_cls)
                ender.record()
                # Both events must have completed before elapsed_time() is valid.
                torch.cuda.synchronize()
                times.append(starter.elapsed_time(ender))             # ms
        else:  # CPU timing
            for _ in range(n_iter):
                t0 = time.perf_counter()
                _ = sel(x_num, x_cls)
                times.append((time.perf_counter() - t0) * 1_000)      # s -> ms

    return mean(times), stdev(times)

# ---------------------------------------------------------------
# CPU timing always runs; CUDA timing runs only when PyTorch reports a GPU.
cpu_mean, cpu_std = measure(torch.device('cpu'))
print(f"CPU latency : {cpu_mean:.3f} ± {cpu_std:.3f} ms")

if not torch.cuda.is_available():
    print("GPU not detected – skipping CUDA timing.")
else:
    gpu_mean, gpu_std = measure(torch.device('cuda'))
    print(f"GPU latency : {gpu_mean:.3f} ± {gpu_std:.3f} ms")





