import torch
import time


def get_time(fn, pre=None, *, n_iters=10, n_avg=4):
    """Time ``fn`` with CUDA events and return the mean of the last runs.

    ``fn`` is called ``n_iters`` times. Before each call, ``pre()`` (if given)
    is invoked outside the timed region and its return value is passed to
    ``fn``; otherwise ``fn(None)`` is called. Only the last ``n_avg`` calls
    are averaged, so the earlier ``n_iters - n_avg`` calls act as warm-up.

    Args:
        fn: Callable taking one argument (the result of ``pre`` or None).
        pre: Optional zero-argument setup callable run before each timed call.
        n_iters: Total number of timed calls (default 10).
        n_avg: Number of trailing calls averaged into the result (default 4).

    Returns:
        Mean elapsed time between the start/end events of the last ``n_avg``
        calls, in seconds.

    Raises:
        ValueError: If ``n_avg`` is not in ``[1, n_iters]``.
    """
    # Validate before touching CUDA so misuse fails fast and clearly.
    if not 1 <= n_avg <= n_iters:
        raise ValueError(f'n_avg must be in [1, n_iters={n_iters}], got {n_avg}')
    events = []
    for _ in range(n_iters):
        vals = pre() if pre is not None else None
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        fn(vals)
        end_event.record()
        events.append((start_event, end_event))
    # Events complete asynchronously; wait for all of them before reading.
    torch.cuda.synchronize()
    # elapsed_time() reports milliseconds; convert to seconds.
    times = [t0.elapsed_time(t1) * 1e-3 for t0, t1 in events]
    return sum(times[-n_avg:]) / n_avg


def check_close(ref, x, atol=5e-4, rtol=5e-4, name='sometensor', print_details=False, forgive_thres=0):
    """Compare ``x`` against the reference tensor ``ref`` and report mismatches.

    ``x`` is first converted to ``ref``'s dtype and device. If
    ``torch.allclose`` accepts the pair, True is returned silently. Otherwise
    an "Inaccurate" report is printed (max/mean absolute and relative
    differences, plus the full tensors when ``print_details`` is set). The
    result is then decided by the mean absolute difference.

    Args:
        ref: Reference tensor.
        x: Tensor under test; converted to ``ref.dtype`` / ``ref.device``.
        atol: Absolute tolerance forwarded to ``torch.allclose``.
        rtol: Relative tolerance forwarded to ``torch.allclose``.
        name: Label used in the printed report.
        print_details: Also print ``ref``, ``x`` and the absolute difference.
        forgive_thres: A mismatch is tolerated (True is returned, but the
            report is still printed) when the mean absolute difference is at
            most this value.

    Returns:
        True if the tensors are close or the mismatch is forgiven, else False.
        A NaN mean difference (e.g. ``x`` has a NaN where ``ref`` does not)
        is always treated as an error.
    """
    x = x.to(device=ref.device, dtype=ref.dtype)
    if torch.allclose(ref, x, rtol=rtol, atol=atol):
        return True
    adiff = (ref - x).abs()
    # Identical entries (including equal infinities, where ref - x is NaN)
    # and entries that are NaN in both tensors count as zero difference.
    same = (ref == x) | (ref.isnan() & x.isnan())
    adiff = adiff.masked_fill(same, 0)
    if not adiff.is_floating_point():
        # mean() is not defined for integer tensors.
        adiff = adiff.to(torch.float64)
    rdiff = adiff / torch.maximum(ref.abs(), x.abs())
    # Relative error is not meaningful for near-zero outputs; ignore them
    # (this also clears the 0/0 NaNs where both values are zero).
    rdiff[x.abs() < 1e-2] = 0
    mean_adiff = adiff.mean().item()
    # Written as `not <=` so a NaN mean counts as an error (NaN > t is False).
    is_error = not (mean_adiff <= forgive_thres)
    print(f'Inaccurate {name}')
    if print_details:
        print('===== Ref =====')
        print(ref)
        print('----- Ours ----')
        print(x)
        print('===== diff ====')
        print(adiff)
    print(f'    Max atol {adiff.max().item():.3f} rtol {rdiff.max().item():.3f}')
    print(f'    Mean atol {mean_adiff:.3f} rtol {rdiff.mean().item():.3f}')
    return not is_error

