import time

import torch
import triton
from triton import language as tl


@triton.jit
def addvv_triton_kernel(
        vec_a_ptr,
        mat_c_ptr,
        size_a: int,
        BLOCK_SIZE_B: tl.constexpr,
):
    """
    In-place symmetric rank-1 downdate ``C[i, j] -= v[i] * v[j]``.

    Each program handles one row ``i`` and one block of ``BLOCK_SIZE_B``
    consecutive columns. The matrix is addressed as ``mat_c_ptr + i * size_a + j``,
    i.e. as a densely packed row-major ``(size_a, size_a)`` buffer.

    Column blocks that start to the right of the diagonal element of the row
    are skipped. A block that straddles the diagonal is processed in full, so
    a few entries above the diagonal are updated as well.

    Args:
        vec_a_ptr: pointer to the vector ``v`` with ``size_a`` elements
        mat_c_ptr: pointer to the matrix ``C``, updated in place
        size_a: length of ``v`` and side length of ``C``
        BLOCK_SIZE_B: number of columns processed per program (compile-time)
    """
    program = tl.program_id(axis=0)
    # The flat program index encodes (column block, row): rows vary fastest.
    row = program % size_a
    col_start = (program // size_a) * BLOCK_SIZE_B
    if col_start > row:
        return  # the whole block lies above the diagonal
    cols = col_start + tl.arange(0, BLOCK_SIZE_B)
    in_range = cols < size_a  # guards the last, possibly partial, block
    tile_ptrs = mat_c_ptr + row * size_a + cols
    v_row = tl.load(vec_a_ptr + row)
    v_cols = tl.load(vec_a_ptr + cols, mask=in_range)
    tile = tl.load(tile_ptrs, mask=in_range)
    tile = tl.fma(-v_row, v_cols, tile)  # tile - v_row * v_cols
    tl.store(tile_ptrs, tile, mask=in_range)


def addvv_triton(
        vec: torch.Tensor,
        mat: torch.Tensor,
) -> None:
    """
    Launch ``addvv_triton_kernel`` to apply ``mat[i, j] -= vec[i] * vec[j]`` in place.

    One program is launched per (row, column block) pair; the kernel itself
    skips the blocks that lie entirely above the diagonal.

    Args:
        vec: (C); the kernel reads it as a flat buffer of ``C`` elements
        mat: (C, C); the kernel addresses it as dense row-major storage and
            updates it in place

    Returns:
        None
    """
    length: int = vec.shape[-1]

    def launch_grid(meta: dict) -> tuple:
        n_column_blocks = triton.cdiv(length, meta['BLOCK_SIZE_B'])
        return (length * n_column_blocks,)

    addvv_triton_kernel[launch_grid](vec, mat, length, BLOCK_SIZE_B=256)


@torch.no_grad()
def min_pivot_order_triton(hessian: torch.Tensor, output: torch.Tensor = None) -> torch.Tensor:
    """
    Greedy minimum-diagonal pivot ordering using preallocated buffers and the
    Triton rank-1 update.

    At step ``k`` the smallest diagonal entry among the not yet selected
    columns is chosen as pivot ``j``; the working matrix is then downdated by
    the outer product of its ``j``-th row/column scaled by ``H[j, j] ** -0.5``.
    Only the lower triangle of the working matrix is kept up to date, so the
    pivot vector is assembled from row ``j`` (entries left of the diagonal)
    and column ``j`` (diagonal and below).

    The pivot index stays on the device for the whole loop (it is never
    converted to a Python number) and every intermediate is written into a
    buffer allocated before the loop.

    Args:
        hessian: (C, C), floating point; not modified, a working copy is made
        output: optional (C), int64 tensor that receives the order

    Returns:
        order: (C), int64 (``output`` itself when it was supplied)
    """

    dtype, device = hessian.dtype, hessian.device
    n_columns: int = hessian.size(-1)
    inf: torch.Tensor = torch.full((), torch.inf, dtype=dtype, device=device)  # ()
    false: torch.Tensor = torch.zeros((), dtype=torch.bool, device=device)  # ()
    order: torch.Tensor = torch.empty(n_columns, dtype=torch.int64, device=device) if output is None else output  # (C), int64
    mask: torch.Tensor = torch.ones(n_columns, dtype=torch.bool, device=device)  # (C), bool, True = still selectable
    d_tmp: torch.Tensor = torch.empty(n_columns, dtype=dtype, device=device)  # (C), fp
    h_tmp_row: torch.Tensor = torch.empty(n_columns, dtype=dtype, device=device)  # (C), fp
    h_tmp_col: torch.Tensor = torch.empty(n_columns, dtype=dtype, device=device)  # (C), fp
    h_tmp_rc: torch.Tensor = torch.empty((), dtype=dtype, device=device)  # (), fp
    h_tmp: torch.Tensor = torch.empty(n_columns, dtype=dtype, device=device)  # (C), fp
    indices: torch.Tensor = torch.arange(n_columns, dtype=torch.int64, device=device)  # (C), int64
    cmp: torch.Tensor = torch.empty(n_columns, dtype=torch.bool, device=device)  # (C), bool
    # The Triton kernel addresses the matrix as dense row-major storage. A plain
    # clone() keeps the strides of the input, so a non-contiguous input (e.g. a
    # transposed view) would have the wrong triangle updated. Force a contiguous copy.
    hessian: torch.Tensor = hessian.clone(memory_format=torch.contiguous_format)  # (C, C), fp
    for k in range(n_columns):
        # Already selected columns are replaced by +inf so argmin ignores them
        torch.where(mask, hessian.diagonal(), inf, out=d_tmp)  # (C)
        # The pivot index is written straight into order[k]; j aliases that element
        j: torch.Tensor = torch.argmin(d_tmp, out=order[k])  # ()
        torch.index_select(hessian, dim=-2, index=j, out=h_tmp_row[None, :])  # (1, C)
        torch.index_select(hessian, dim=-1, index=j, out=h_tmp_col[:, None])  # (C, 1)
        torch.index_select(h_tmp_col, dim=-1, index=j, out=h_tmp_rc[None])  # (1), the pivot H[j, j]
        # Left of the pivot take the row, from the pivot on take the column:
        # both selections only touch lower-triangular entries
        torch.lt(indices, j, out=cmp)  # (C)
        torch.where(cmp, h_tmp_row, h_tmp_col, out=h_tmp)  # (C)
        h_tmp *= h_tmp_rc.pow_(-.5)  # (C)
        addvv_triton(h_tmp, hessian)  # (C, C)
        mask.scatter_(dim=-1, index=j, src=false)  # ()
    return order  # (C), int64


@torch.no_grad()
def min_pivot_order_baseline(hessian: torch.Tensor, output: torch.Tensor = None) -> torch.Tensor:
    """
    Greedy minimum-diagonal pivot ordering written with plain PyTorch tensor ops.

    At every step the smallest diagonal entry among the remaining columns is
    selected as pivot, then the outer product of the pivot row scaled by
    ``pivot ** -0.5`` is subtracted from the whole working matrix.

    Args:
        hessian: (C, C); left untouched, the elimination runs on a copy
        output: optional (C), int64 tensor that receives the order

    Returns:
        order: (C), int64 (``output`` itself when it was supplied)
    """

    device = hessian.device
    n_columns: int = hessian.size(-1)
    work: torch.Tensor = hessian.clone()  # (C, C)
    if output is None:
        order = torch.empty(n_columns, dtype=torch.int64, device=device)  # (C), int64
    else:
        order = output
    remaining = torch.ones(n_columns, dtype=torch.bool, device=device)  # (C), bool
    inf = torch.full((), torch.inf, dtype=hessian.dtype, device=device)  # ()
    false = torch.zeros((), dtype=torch.bool, device=device)  # ()
    for step in range(n_columns):
        # Columns that were already picked are hidden behind +inf
        masked_diagonal = torch.where(remaining, work.diagonal(), inf)  # (C)
        pivot = masked_diagonal.argmin()  # ()
        order[step] = pivot
        pivot_row = work.index_select(dim=-2, index=pivot)  # (1, C)
        pivot_row *= pivot_row.index_select(dim=-1, index=pivot) ** -.5  # (1, C)
        work -= pivot_row.transpose(-2, -1) * pivot_row  # (C, C)
        remaining.scatter_(dim=-1, index=pivot, src=false)
    return order


@torch.compile(fullgraph=True, dynamic=True, mode='reduce-overhead')
def min_pivot_order_compiled(hessian: torch.Tensor) -> torch.Tensor:
    """
    ``torch.compile``-wrapped entry point that forwards to ``min_pivot_order_baseline``.

    Args:
        hessian: (C, C)

    Returns:
        order: (C), int64
    """
    order = min_pivot_order_baseline(hessian=hessian)
    return order


@torch.no_grad()
def min_pivot_order_reference(hessian: torch.Tensor) -> torch.Tensor:
    """
    Straightforward reference implementation of the greedy minimum-diagonal
    pivot ordering, used to check the other variants.

    Args:
        hessian: (C, C); left untouched, the elimination runs on a copy

    Returns:
        order: (C), int64
    """

    device = hessian.device
    n_columns: int = hessian.size(-1)
    work: torch.Tensor = hessian.clone()  # (C, C)
    remaining = torch.ones(n_columns, dtype=torch.bool, device=device)  # (C), bool
    order = torch.empty(n_columns, dtype=torch.int64, device=device)  # (C), int64
    inf = torch.full((), torch.inf, dtype=hessian.dtype, device=device)  # ()
    for step in range(n_columns):
        # Smallest diagonal entry among the columns that are still available
        pivot = torch.where(remaining, work.diagonal(), inf).argmin()  # ()
        order[step] = pivot
        scaled_row = work[pivot, :] * work[pivot, pivot] ** -.5  # (C)
        work -= scaled_row[:, None] * scaled_row  # (C, C)
        remaining[pivot] = False
    return order


def min_pivot_order(
        hessian: torch.Tensor,
        output: torch.Tensor = None,
        direct: bool = False,
        debug_mode: bool = False,
) -> torch.Tensor:
    """
    CUDA Graph wrapper

    By default the Triton implementation is captured once into a CUDA graph per
    ``(C, dtype, device)`` combination and replayed on every later call. The
    captured graphs and their static input/output buffers are cached on the
    function object (``min_pivot_order.graph_info``).

    Args:
        hessian: (C, C)
        output: optional (C), int64 tensor that receives the order; when given
            it is filled on every code path and is also the return value
        direct: run the Triton implementation eagerly, without a CUDA graph
        debug_mode: run the pure PyTorch baseline instead (takes precedence
            over ``direct``)

    Returns:
        order: (C), int64
    """

    if debug_mode:
        return min_pivot_order_baseline(hessian=hessian, output=output)

    if direct:
        # The Triton kernel needs dense row-major storage. contiguous() is a
        # no-op for tensors that already satisfy this and, unlike an assert,
        # is not stripped when Python runs with -O.
        return min_pivot_order_triton(hessian=hessian.contiguous(), output=output)

    n_columns: int = hessian.size(-1)
    dtype: torch.dtype = hessian.dtype
    device: torch.device = hessian.device

    if not hasattr(min_pivot_order, 'graph_info'):
        min_pivot_order.graph_info = {}
    graph_key: tuple = n_columns, dtype, device

    # Make the tensor's device current for warm-up, capture and replay. The
    # context manager restores the previous device even if an exception is raised.
    with torch.cuda.device(device):
        if graph_key not in min_pivot_order.graph_info:
            graph: torch.cuda.CUDAGraph = torch.cuda.CUDAGraph()
            # Static buffers the graph reads from / writes to on every replay.
            # Their contents are uninitialized during warm-up and capture;
            # real data is only copied in right before a replay.
            graph_tensors: dict = {
                'hessian': torch.empty_like(hessian, memory_format=torch.contiguous_format),
                'output': torch.empty(n_columns, dtype=torch.int64, device=device),
            }
            # Warm up on a side stream before capturing
            n_warmups: int = 5
            s: torch.cuda.Stream = torch.cuda.Stream()
            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                for _ in range(n_warmups):
                    min_pivot_order(**graph_tensors, direct=True)
            torch.cuda.current_stream().wait_stream(s)
            with torch.cuda.graph(graph):
                min_pivot_order(**graph_tensors, direct=True)
            min_pivot_order.graph_info[graph_key] = {'graph': graph, 'tensors': graph_tensors}

        graph_entry: dict = min_pivot_order.graph_info[graph_key]
        graph, graph_tensors = graph_entry['graph'], graph_entry['tensors']
        graph_tensors['hessian'].copy_(hessian)
        graph.replay()

        # The static output buffer is overwritten by the next replay, so hand
        # the caller its own copy (or fill the tensor the caller provided).
        if output is None:
            return graph_tensors['output'].clone()
        output.copy_(graph_tensors['output'])
        return output


def _get_random_inputs(size: int = 512, seed: int = 0) -> torch.Tensor:
    torch.manual_seed(seed)
    dtype: torch.dtype = torch.float32
    device: torch.device = torch.device('cuda')
    hessian = torch.randn(size * 2, size, dtype=dtype, device=device)
    hessian = hessian.t() @ hessian
    return hessian


def _unit_test() -> None:
    """
    Print the pivot order of the reference implementation, then for the baseline
    and the default (CUDA graph) implementation print their order and whether
    it equals the reference.
    """
    hessian = _get_random_inputs(size=1024)
    expected = min_pivot_order_reference(hessian)
    print(expected.tolist())
    for candidate_fn in (min_pivot_order_baseline, min_pivot_order):
        candidate = candidate_fn(hessian)
        print(candidate.tolist())
        print(candidate.equal(expected))

    # The torch.compile variant is currently not exercised:
    # min_pivot_order_compiled(torch.randn_like(hessian))
    # order_compiled = min_pivot_order_compiled(hessian)
    # print(order_compiled.tolist())
    # print(order_compiled.equal(order_ref))


def _basic_benchmark(f, n_repeats: int = 2000, n_warmup_repeats: int = 1000, measure_latency: bool = False) -> float:
    if measure_latency:
        for _ in range(n_warmup_repeats):
            f()
            torch.cuda.synchronize()
    else:  # measure throughput
        for _ in range(n_warmup_repeats):
            f()
        torch.cuda.synchronize()
    if measure_latency:
        t_start: float = time.perf_counter()
        for _ in range(n_repeats):
            f()
            torch.cuda.synchronize()
        t_end: float = time.perf_counter()
    else:  # measure throughput
        t_start: float = time.perf_counter()
        for _ in range(n_repeats):
            f()
        torch.cuda.synchronize()
        t_end: float = time.perf_counter()
    t: float = (t_end - t_start) / n_repeats
    return t


def _benchmark() -> None:
    """
    Time the pure PyTorch baseline and the default (CUDA graph) implementation
    on a 4096 x 4096 input and print both average times, baseline first.
    """
    hessian: torch.Tensor = _get_random_inputs(size=4096)
    bench_kwargs = dict(n_repeats=5, n_warmup_repeats=5)
    t_base: float = _basic_benchmark(lambda: min_pivot_order_baseline(hessian), **bench_kwargs)
    t: float = _basic_benchmark(lambda: min_pivot_order(hessian), **bench_kwargs)
    print(t_base, t)


# Script entry point: run the self-check first, then the benchmark.
if __name__ == '__main__':
    for entry_point in (_unit_test, _benchmark):
        entry_point()
