# Adapted from https://github.com/HazyResearch/hippo/blob/datasets/benchmark/utils.py
""" Useful functions for writing test code. """

import torch
import torch.utils.benchmark as benchmark

def benchmark_direction(fn, *inputs, direction='forward', grad=None, repeats=5, desc='', label='', sub_label='',
                      verbose=False, amp=True, amp_dtype=torch.float16,
                      num_threads=torch.get_num_threads(),  **kwinputs):
    """ Use Pytorch Benchmark on a specific direction pass of an arbitrary function.

    ``direction`` selects the benchmark:
      * ``'forward'``  -> ``benchmark_forward`` (``grad`` is not used),
      * ``'backward'`` -> ``benchmark_backward``,
      * anything else  -> ``benchmark_combined`` (forward + backward).

    All remaining options and ``**kwinputs`` are passed through unchanged, and the
    selected benchmark's measurement is returned.
    """
    options = dict(repeats=repeats, desc=desc, label=label, sub_label=sub_label, verbose=verbose,
                   amp=amp, amp_dtype=amp_dtype, num_threads=num_threads)
    if direction != 'forward':
        # Any value other than 'forward' / 'backward' falls through to the combined benchmark.
        runner = benchmark_backward if direction == 'backward' else benchmark_combined
        return runner(fn, *inputs, grad=grad, **options, **kwinputs)
    return benchmark_forward(fn, *inputs, **options, **kwinputs)



def benchmark_forward(fn, *inputs, repeats=5, desc='', label='', sub_label='',
                      verbose=False, amp=True, amp_dtype=torch.float16,
                      num_threads=torch.get_num_threads(),  **kwinputs):
    """ Use Pytorch Benchmark on the forward pass of an arbitrary function.

    Args:
        fn: callable to time, invoked as ``fn(*inputs, **kwinputs)``; its return value is discarded.
        *inputs: positional arguments forwarded to ``fn``.
        repeats: number of timed executions passed to ``Timer.timeit``.
        desc, label, sub_label: metadata attached to the measurement (``desc`` becomes its description).
        verbose: if True, print a header before timing and the measurement afterwards.
        amp: run ``fn`` under CUDA autocast when True; run it without autocast when False.
        amp_dtype: autocast dtype used when ``amp`` is True.
        num_threads: thread count handed to ``benchmark.Timer``. Note that the default is
            evaluated once, when this module is imported.
        **kwinputs: keyword arguments forwarded to ``fn``.

    Returns:
        The measurement returned by ``benchmark.Timer.timeit``.
    """
    if verbose:
        print(desc, '- Forward pass')

    def amp_wrapper(*inputs, **kwinputs):
        # Honour the caller's ``amp`` flag instead of always enabling autocast.
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            fn(*inputs, **kwinputs)

    t = benchmark.Timer(
        stmt='fn_amp(*inputs, **kwinputs)',
        globals={'fn_amp': amp_wrapper,
                 'inputs': inputs, 'kwinputs': kwinputs},
        label=label,
        sub_label=sub_label,
        description=desc,
        num_threads=num_threads,
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return m


def benchmark_backward(fn, *inputs, grad=None, repeats=5, desc='', label='', sub_label='',
                       verbose=False, amp=True,  amp_dtype=torch.float16,
                       num_threads=torch.get_num_threads(), **kwinputs):
    """ Use Pytorch Benchmark on the backward pass of an arbitrary function.

    ``fn(*inputs, **kwinputs)`` is run once, untimed, to build the autograd graph.
    Only ``y.backward(grad, retain_graph=True)`` is timed, where ``y`` is the output
    of ``fn``. When ``fn`` returns a tuple (including named tuples such as
    ``torch.return_types``), ``y`` is its first element.

    Args:
        fn: callable whose output is differentiated.
        *inputs: positional arguments forwarded to ``fn``.
        grad: upstream gradient for ``y``. When None, ``torch.randn_like(y)`` is used.
        repeats: number of timed executions passed to ``Timer.timeit``.
        desc, label, sub_label: metadata attached to the measurement.
        verbose: if True, print a header before timing and the measurement afterwards.
        amp: run the forward call under CUDA autocast when True.
        amp_dtype: autocast dtype used when ``amp`` is True.
        num_threads: thread count handed to ``benchmark.Timer`` (default evaluated at import time).
        **kwinputs: keyword arguments forwarded to ``fn``.

    Returns:
        The measurement returned by ``benchmark.Timer.timeit``.

    Raises:
        RuntimeError: if ``grad`` is given and its shape differs from ``y``'s.
    """
    if verbose:
        print(desc, '- Backward pass')
    with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        # isinstance (rather than an exact type check) also unwraps namedtuple/structseq outputs.
        if isinstance(y, tuple):
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    elif grad.shape != y.shape:
        raise RuntimeError('Grad shape does not match output shape')
    # retain_graph=True lets the same graph be back-propagated on every repeat.
    # Inputs' .grad is not reset between repeats, so gradients accumulate there.
    t = benchmark.Timer(
        stmt='y.backward(grad, retain_graph=True)',
        globals={'y': y, 'grad': grad},
        label=label,
        sub_label=sub_label,
        description=desc,
        num_threads=num_threads,
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return m


def benchmark_combined(fn, *inputs, grad=None, repeats=5, desc='', label='', sub_label='',
                       verbose=False, amp=True,
                       amp_dtype=torch.float16, num_threads=torch.get_num_threads(), **kwinputs):
    """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.

    Each timed step runs ``fn(*inputs, **kwinputs)`` (under CUDA autocast when ``amp``)
    and back-propagates ``grad`` through its output. When the output is a tuple
    (including named tuples), its first element is used.

    The upstream gradient is created or validated once, from a single untimed probe
    call of ``fn``. This keeps ``torch.randn_like`` sampling and the shape check out of
    the timed region.

    Args:
        fn: callable to run forward and backward.
        *inputs: positional arguments forwarded to ``fn``.
        grad: upstream gradient. When None, ``torch.randn_like`` of the output is used.
        repeats: number of timed executions passed to ``Timer.timeit``.
        desc, label, sub_label: metadata attached to the measurement.
        verbose: if True, print a header before timing and the measurement afterwards.
        amp: run the forward call under CUDA autocast when True.
        amp_dtype: autocast dtype used when ``amp`` is True.
        num_threads: thread count handed to ``benchmark.Timer`` (default evaluated at import time).
        **kwinputs: keyword arguments forwarded to ``fn``.

    Returns:
        The measurement returned by ``benchmark.Timer.timeit``.

    Raises:
        RuntimeError: if ``grad`` is given and its shape differs from the output's.
    """
    if verbose:
        print(desc, '- Forward + Backward pass')

    def forward(*inputs, **kwinputs):
        # One forward call; returns the (first) output that will be differentiated.
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            y = fn(*inputs, **kwinputs)
            if isinstance(y, tuple):
                y = y[0]
        return y

    # Untimed probe call: fixes the gradient's shape/dtype/device before timing starts.
    y = forward(*inputs, **kwinputs)
    if grad is None:
        grad = torch.randn_like(y)
    elif grad.shape != y.shape:
        raise RuntimeError('Grad shape does not match output shape')
    del y  # release the probe output and its autograd graph before timing

    def f(grad, *inputs, **kwinputs):
        forward(*inputs, **kwinputs).backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt='f(grad, *inputs, **kwinputs)',
        globals={'f': f, 'inputs': inputs,
                 'grad': grad, 'kwinputs': kwinputs},
        label=label,
        sub_label=sub_label,
        description=desc,
        num_threads=num_threads,
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return m


def benchmark_all(fn, *inputs, grad=None, repeats=5, desc='', label='', sub_label='', verbose=False, amp=True,
                  amp_dtype=torch.float16, num_threads=torch.get_num_threads(), **kwinputs):
    """ Run the forward, backward and combined forward+backward benchmarks of an arbitrary function.

    The three benchmarks run in that order with identical options. ``grad`` is only
    used by the backward and combined benchmarks.

    Returns:
        A 3-tuple ``(forward, backward, combined)`` of benchmark measurements.
    """
    options = dict(repeats=repeats, desc=desc, label=label, sub_label=sub_label, verbose=verbose,
                   amp=amp, amp_dtype=amp_dtype, num_threads=num_threads)
    forward_m = benchmark_forward(fn, *inputs, **options, **kwinputs)
    backward_m = benchmark_backward(fn, *inputs, grad=grad, **options, **kwinputs)
    combined_m = benchmark_combined(fn, *inputs, grad=grad, **options, **kwinputs)
    return forward_m, backward_m, combined_m


def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=True,
                     amp_dtype=torch.float16, cpu=False, verbose=False, **kwinputs):
    """ Wrap benchmark functions in Pytorch profiler to see CUDA information.

    The function runs 30 unprofiled warm-up steps and then profiles a single step.
    A step calls ``fn(*inputs, **kwinputs)`` under CUDA autocast (when ``amp``). When
    ``backward`` is True, it then back-propagates a fixed random upstream gradient
    through the output. When ``fn`` returns a tuple (including named tuples), only
    its first element is used.

    Args:
        fn: callable to profile.
        *inputs: positional arguments forwarded to ``fn``.
        trace_filename: if not None, the profile is exported there as a Chrome trace.
        backward: also run the backward pass. Tensor inputs get ``.grad`` reset to
            None before every step.
        amp: run the forward call under CUDA autocast when True.
        amp_dtype: autocast dtype used when ``amp`` is True.
        cpu: also record CPU activity (CUDA activity is always requested).
        verbose: print the profiler's aggregated table (at most 50 rows).
        **kwinputs: keyword arguments forwarded to ``fn``.
    """
    num_warmup = 30

    def forward():
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if isinstance(out, tuple):
                out = out[0]
        return out

    # Fixed upstream gradient, sampled once so every step does identical work.
    g = torch.randn_like(forward()) if backward else None

    def step():
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        out = forward()
        # Backward should be done outside autocast
        if backward:
            out.backward(g)

    for _ in range(num_warmup):
        step()

    activities = ([torch.profiler.ProfilerActivity.CPU]
                  if cpu else []) + [torch.profiler.ProfilerActivity.CUDA]
    # Pass profile_memory=True below to also record tensor allocations.
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        with_stack=True,
    ) as prof:
        step()
    if verbose:
        # Use .table(sort_by="self_cuda_time_total", ...) to rank by GPU time instead.
        print(prof.key_averages().table(row_limit=50))
    if trace_filename is not None:
        prof.export_chrome_trace(trace_filename)


def benchmark_memory(fn, *inputs, desc='', verbose=False, **kwinputs):
    """ Measure the peak CUDA memory allocated while running ``fn(*inputs, **kwinputs)`` once.

    The peak counter is reset right before the call. The reported figure is
    therefore the peak of all live CUDA allocations during the call, including
    tensors that already existed (such as ``inputs``), not only what ``fn`` allocates.
    The CUDA caching allocator is emptied before and after the measurement.

    Args:
        fn: callable to measure.
        *inputs: positional arguments forwarded to ``fn``.
        desc: label used in the verbose printout.
        verbose: if True, print the measured peak.
        **kwinputs: keyword arguments forwarded to ``fn``.

    Returns:
        Peak allocated memory in GiB (2**30 bytes).
    """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    fn(*inputs, **kwinputs)
    # Wait for queued kernels so their allocations are reflected in the peak.
    torch.cuda.synchronize()
    # Binary GiB; the former 2**20 * 1000 divisor mixed MiB with a decimal factor.
    mem = torch.cuda.max_memory_allocated() / 2 ** 30
    if verbose:
        print(f'{desc} max memory: {mem}GiB')
    torch.cuda.empty_cache()
    return mem
