import time
import torch


def _time_slicer(slicer, IF, sc_config, n_warmup, n_measure, is_cuda):
    """Warm up ``slicer``, then return its per-call durations in seconds."""
    # Warm-up calls are excluded from timing so one-off first-call costs
    # (if any) do not skew the measurements.
    for _ in range(n_warmup):
        slicer(IF, sc_config)
        if is_cuda:
            torch.cuda.synchronize()

    durations = []
    for _ in range(n_measure):
        # CUDA work is queued asynchronously, so synchronize both before
        # starting and before stopping the clock to time only this call.
        if is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        slicer(IF, sc_config)
        if is_cuda:
            torch.cuda.synchronize()
        durations.append(time.perf_counter() - start)
    return durations


def _format_stats(label, durations, avg):
    """Format a one-line avg/min/max summary (in milliseconds) for ``durations``."""
    return (
        f"[{label}] runs={len(durations)} avg={avg * 1e3:.3f} ms"
        f" | min={min(durations) * 1e3:.3f} ms | max={max(durations) * 1e3:.3f} ms"
    )


def compare_slicers_time(
    IF, sc_config, slicer_old, slicer_opt, n_warmup=3, n_measure=1
):
    """Benchmark two slicer implementations on the same input and print stats.

    Each slicer is called as ``slicer(IF, sc_config)``: first ``n_warmup``
    untimed times, then ``n_measure`` timed times. The old slicer is fully
    benchmarked before the optimized one. When ``IF`` lives on a CUDA device,
    the device is synchronized around every timed call.

    Args:
        IF: Input tensor handed to both slicers (its ``.device`` decides
            whether CUDA synchronization is used).
        sc_config: Opaque configuration forwarded unchanged to both slicers.
        slicer_old: Baseline implementation.
        slicer_opt: Optimized implementation.
        n_warmup: Number of untimed warm-up calls per slicer.
        n_measure: Number of timed calls per slicer; must be >= 1.

    Returns:
        ``(avg_old, avg_opt)``: mean call duration of each slicer in seconds.

    Raises:
        ValueError: If ``n_measure`` is smaller than 1.
    """
    if n_measure < 1:
        raise ValueError(f"n_measure must be >= 1, got {n_measure}")

    is_cuda = IF.device.type == "cuda"

    times_old = _time_slicer(slicer_old, IF, sc_config, n_warmup, n_measure, is_cuda)
    times_opt = _time_slicer(slicer_opt, IF, sc_config, n_warmup, n_measure, is_cuda)

    avg_old = sum(times_old) / len(times_old)
    avg_opt = sum(times_opt) / len(times_opt)

    print(_format_stats("OLD", times_old, avg_old))
    print(_format_stats("OPT", times_opt, avg_opt))
    # Guard against a zero average (e.g. a trivially fast slicer).
    if avg_opt != 0:
        speedup = avg_old / avg_opt
        print(f"Speedup ~ x{speedup:.2f}")
    return avg_old, avg_opt


def verify_same_results(
    IF, sc_config, slicer_orig, slicer_opt, rtol=1e-4, atol=1e-5, equal_nan=False
):
    """Check that two slicer implementations produce numerically matching output.

    ``slicer_orig`` is called on ``IF`` and ``slicer_opt`` on a clone of ``IF``,
    so a slicer that mutates its input in place cannot affect the other's
    result.

    Args:
        IF: Input tensor passed to both slicers (must support ``.clone()``).
        sc_config: Opaque configuration forwarded unchanged to both slicers.
        slicer_orig: Reference implementation, called as ``slicer_orig(IF, sc_config)``.
        slicer_opt: Candidate implementation, called on the clone of ``IF``.
        rtol: Relative tolerance passed to ``torch.allclose``.
        atol: Absolute tolerance passed to ``torch.allclose``.
        equal_nan: If True, NaNs at the same positions compare as equal.

    Returns:
        True if the output shapes match and all elements are close within
        tolerance, False otherwise. A diagnostic line is printed either way.
    """
    IF_clone = IF.clone()

    out_orig = slicer_orig(IF, sc_config)
    out_opt = slicer_opt(IF_clone, sc_config)

    if out_orig.shape != out_opt.shape:
        print("[Mismatch] shape differs!")
        return False

    # Compare on the reference output's device.
    if out_opt.device != out_orig.device:
        out_opt = out_opt.to(out_orig.device)

    # torch.allclose refuses inputs of different dtypes (e.g. float32 vs
    # float64 raises RuntimeError), so promote both to a common dtype first.
    common = torch.promote_types(out_orig.dtype, out_opt.dtype)
    out_orig = out_orig.to(common)
    out_opt = out_opt.to(common)

    equal_flag = torch.allclose(
        out_orig, out_opt, rtol=rtol, atol=atol, equal_nan=equal_nan
    )
    if equal_flag:
        print("[OK] Both slicers produce same result within tolerance.")
    else:
        # Subtraction is not defined for bool tensors, so compute the
        # diagnostic difference in a floating dtype for non-float outputs.
        if not (common.is_floating_point or common.is_complex):
            out_orig = out_orig.to(torch.get_default_dtype())
            out_opt = out_opt.to(torch.get_default_dtype())
        diff_max = (out_orig - out_opt).abs().max()
        print(f"[Mismatch] max diff= {diff_max.item()}")
    return equal_flag
