import os
import warnings
from typing import Dict, List, Literal, Optional, Tuple

import numba
import numpy as np
import torch
import triton
from torch import Tensor
from triton import cdiv as cdiv_python

from hip_attn.v1_3.metadata import HiPAttentionArgs, safe_stride

from .block_sparse_attention_bwd import (
    block_sparse_attention_backward,
    block_sparse_attention_backward_preprocess,
)
from .block_sparse_attention_fwd import block_sparse_attention_cuda

# Opt-in switch (environment variable TRITON_DEBUG=1) that enables the extra
# NaN/Inf and index sanity assertions guarded by `if TRITON_DEBUG` in this module.
TRITON_DEBUG = os.environ.get("TRITON_DEBUG", "0") == "1"


@numba.njit
def histc(values: np.ndarray, max: int):
    """Per-row histogram of ``values``, shifted right by one slot.

    For every row ``b`` the result satisfies ``output[b, 0] == 0`` and
    ``output[b, v + 1] == count of entries equal to v`` for ``0 <= v < max``.
    The leading zero slot means a cumulative sum over the last axis directly
    yields the start offset of every bin (this is how the backward pass in
    this module consumes the result).

    Args:
        values: 2-D integer array of shape ``(BSZ, NNZ)``; entries must be
            non-negative (asserted).
        max: Number of bins. Entries ``>= max`` are ignored.

    Returns:
        ``int64`` array of shape ``(BSZ, max + 1)``.
    """
    BSZ, NNZ = values.shape
    output = np.zeros((BSZ, max + 1), dtype=np.int64)
    for idx_bsz in numba.prange(BSZ):
        for idx_z in range(NNZ):
            v = values[idx_bsz, idx_z]
            assert v >= 0
            # The value `v` is stored in slot `v + 1`, so the largest countable
            # value is `max - 1`. The previous guard (`v < max + 1`) let
            # `v == max` through and wrote to slot `max + 1`, one element past
            # the end of the row; nopython mode does not bounds-check by
            # default, so that was a silent out-of-bounds write.
            if v < max:
                output[idx_bsz, v + 1] += 1
    return output


def forward(
    q: Tensor,
    k: Optional[Tensor],
    v: Optional[Tensor],
    seq_lens: Tensor,
    indices: Tensor,
    ks: Tensor,
    ks_count: Tensor,
    ks_start_end: Tensor,
    args: "HiPAttentionArgs",
    access_counter: Tensor,
    cache_miss_counter: Tensor,
    extend_backend: str,
    model_context_length: int,
    extend_context_length: int,
    offload_update_cache: bool,
    return_attention_scores: bool,
    output_attention_score_reduce_method: Literal["mean", "max"],
    use_torch_fwd: bool = False,
    batch_chunk_size: Optional[int] = None,
):
    """Run the block-sparse attention forward kernel.

    Validates the ranks of the inputs, allocates the output buffers and
    delegates to ``block_sparse_attention_cuda.forward``.

    Args:
        q: 4-D query tensor, unpacked as ``(BSZ, TDST, HEAD, HID)``.
        k: 4-D key tensor, unpacked as ``(_, TSRC, KV_HEAD, _)``, or ``None``
            when keys come from ``args.k_cache`` / ``args.offload_cache``.
        v: 4-D value tensor; must be given whenever ``k`` is given.
        seq_lens: 2-D tensor forwarded to the kernel.
        indices: 3-D tensor forwarded to the kernel; its last dimension is
            passed as ``BK``.
        ks: Not used by this function.
        ks_count: Not used by this function.
        ks_start_end: 3-D tensor (or ``None``) forwarded to the kernel.
        args: Attention configuration. NOTE: ``args.block_size_q`` is clamped
            in place, so the caller's object is modified.
        access_counter: Forwarded to the kernel together with 3 strides.
        cache_miss_counter: Forwarded to the kernel together with 3 strides.
        extend_backend: Passed to the kernel as ``EXTEND_BACKEND``.
        model_context_length: Forwarded to the kernel.
        extend_context_length: Used as the maximum source length when ``k`` is
            ``None``.
        offload_update_cache: Passed to the kernel as ``UPDATE_CACHE``.
        return_attention_scores: When true, a ``-inf`` filled float32 score
            buffer of shape ``(BSZ, indices.shape[1], HEAD, indices.shape[2])``
            is allocated and handed to the kernel; otherwise ``None`` is.
        output_attention_score_reduce_method: ``"mean"`` or ``"max"``.
        use_torch_fwd: Forwarded to the kernel wrapper.
        batch_chunk_size: Forwarded to the kernel wrapper.

    Returns:
        The ``(context, output_scores, score_maximum)`` triple returned by
        ``block_sparse_attention_cuda.forward``.

    Raises:
        Exception: If ``k`` is ``None`` and ``args.using_paged_cache`` is false.
    """
    assert output_attention_score_reduce_method in ["mean", "max"]

    BSZ, TDST, HEAD, HID = q.shape

    # FIXME: without this, the decoding result is broken.
    #        potentially the padding in Q dim is broken.
    # NOTE: this assignment mutates the caller's `args` object.
    args.block_size_q = min(args.block_size_q, triton.next_power_of_2(TDST))

    if k is not None:
        _, MAX_TSRC, KV_HEAD, _ = k.shape
    else:
        if args.k_cache is not None:
            _, _, KV_HEAD, _ = args.k_cache.shape
        else:
            KV_HEAD = args.offload_cache.k_uvm.bank_cpu.shape[-2]
        # Without an explicit key tensor the source length is bounded by the
        # extended context length (not by NUM_PAGE * PAGE_SIZE).
        MAX_TSRC = extend_context_length

    BDST = cdiv_python(TDST, args.block_size_q)
    KV_HEAD_REPEAT = HEAD // KV_HEAD
    assert KV_HEAD_REPEAT * KV_HEAD == HEAD

    BK = indices.shape[-1]  # cdiv_python(args.mask_k, args.block_size_k)

    if return_attention_scores:
        output_scores = torch.full(
            (
                BSZ,
                indices.shape[1],
                HEAD,
                indices.shape[2],
            ),
            device=q.device,
            dtype=torch.float32,
            fill_value=float("-inf"),
        )
    else:
        output_scores = None

    score_maximum = torch.zeros(
        (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
    )

    context = torch.empty(q.shape, dtype=q.dtype, device=q.device)

    # Tile width along the selected-key-block axis; SA_BLOCK_SIZE sets the
    # token budget and SA_BLOCK_BK overrides the derived value outright.
    if q.dtype == torch.float32:
        max_block_size = int(os.getenv("SA_BLOCK_SIZE", "16"))
    else:
        max_block_size = int(os.getenv("SA_BLOCK_SIZE", "32"))
    BLOCK_BK = max_block_size // args.block_size_k
    BLOCK_BK = max(1, min(max_block_size, BLOCK_BK))
    if "SA_BLOCK_BK" in os.environ:
        BLOCK_BK = int(os.environ["SA_BLOCK_BK"])

    assert BLOCK_BK > 0, BLOCK_BK

    if args.rope_cos is not None:
        assert len(args.rope_cos.stride()) == 2
        assert len(args.rope_sin.stride()) == 2

    # Rank checks: the kernel receives a fixed number of strides per tensor.
    assert context.ndim == 4
    if ks_start_end is not None:
        assert ks_start_end.ndim == 3
    if indices is not None:
        assert indices.ndim == 3
    assert q.ndim == 4
    if k is not None:
        assert k.ndim == 4
        assert v.ndim == 4
    elif args.using_paged_cache:
        if args.k_cache is not None:
            assert args.k_cache.ndim == 4
            assert args.v_cache.ndim == 4
        else:
            assert args.offload_cache.k_uvm.bank_cpu.ndim == 3
            assert args.offload_cache.v_uvm.bank_cpu.ndim == 3
    else:
        raise Exception(
            "forward() needs either explicit k/v tensors or args.using_paged_cache"
        )
    assert seq_lens.ndim == 2

    # Validate before touching global state so a failed check cannot leave the
    # default device switched.
    assert args.sink_token_size is not None
    assert args.sliding_window_size is not None, args.sliding_window_size
    assert args.sm_scale is not None

    grid = (HEAD, BDST, BSZ)
    need_benchmark = os.getenv("NEED_BENCHMARK", "0") == "1"

    pre_device = torch.get_default_device()
    torch.set_default_device(q.device)
    try:
        if need_benchmark:
            start_event = torch.cuda.Event(enable_timing=True)
            start_event.record()

        context, output_scores, score_maximum = block_sparse_attention_cuda.forward(
            grid,
            q,
            *safe_stride(q, 4),
            k,
            *safe_stride(k, 4),
            v,
            *safe_stride(v, 4),
            seq_lens,
            *safe_stride(seq_lens, 2),
            indices,
            *safe_stride(indices, 3),
            ks_start_end,
            *safe_stride(ks_start_end, 3),
            context,
            *safe_stride(context, 4),
            output_scores,
            *safe_stride(output_scores, 4),
            output_attention_score_reduce_method,
            score_maximum,
            *safe_stride(score_maximum, 3),
            HEAD,
            BK,
            TDST,
            MAX_TSRC,
            KV_HEAD_REPEAT,
            args.sliding_window_size,
            args.sink_token_size,
            args.logit_softcap,
            *args.args_extend(),
            model_context_length,
            *args.args_paged_kv_cache(),
            *args.args_offload_cache(is_masking=False),
            access_counter,
            *safe_stride(access_counter, 3),
            cache_miss_counter,
            *safe_stride(cache_miss_counter, 3),
            args.sm_scale,
            args.is_causal,
            args.block_size_q,
            args.block_size_k,
            HID,
            BLOCK_BK=BLOCK_BK,
            EXTEND_BACKEND=extend_backend,
            UPDATE_CACHE=offload_update_cache,
            use_torch_fwd=use_torch_fwd,
            batch_chunk_size=batch_chunk_size,
        )
    finally:
        # Always undo the process-wide default-device override, even when the
        # kernel launch raises.
        torch.set_default_device(pre_device)

    if need_benchmark:
        end_event = torch.cuda.Event(enable_timing=True)
        end_event.record()
        end_event.synchronize()

        # k / v are None on the paged / offload cache path.
        k_shape = None if k is None else k.shape
        v_shape = None if v is None else v.shape
        print(
            f"bwd fwd called {q.shape} {k_shape} {v_shape} {indices.shape} {_attention.call_id} {start_event.elapsed_time(end_event)}"
        )

    return context, output_scores, score_maximum


class _attention(torch.autograd.Function):
    """Autograd function pairing the block-sparse forward pass with its backward kernel.

    ``forward`` delegates to the module-level ``forward`` helper and stores
    what ``backward`` needs on ``ctx``. ``backward`` launches
    ``block_sparse_attention_backward`` and returns gradients for ``q``, ``k``
    and ``v`` only; every other input receives ``None``.
    """

    # Number of forward calls so far; each call records its value in ``ctx.id``.
    call_id: int = 0

    @staticmethod
    def forward(
        ctx,
        q: Tensor,
        k: Optional[Tensor],
        v: Optional[Tensor],
        seq_lens: Tensor,
        indices: Tensor,
        ks: Tensor,
        ks_count: Tensor,
        ks_start_end: Tensor,
        args: "HiPAttentionArgs",
        access_counter: Tensor,
        cache_miss_counter: Tensor,
        extend_backend: str,
        model_context_length: int,
        extend_context_length: int,
        offload_update_cache: bool,
        return_attention_scores: bool,
        output_attention_score_reduce_method: Literal["mean", "max"],
    ):
        """Compute attention and save the tensors needed by ``backward``.

        All arguments are passed through to the module-level ``forward``;
        ``args`` is cloned afterwards and the clone is kept on ``ctx.args`` so
        the backward-specific block size can be adjusted on it.

        Returns:
            ``(context, output_scores)``; ``output_scores`` is whatever the
            module-level ``forward`` returned for it (possibly ``None``).
        """
        # Negative entries are replaced by a huge sentinel; backward treats any
        # index >= TSRC as invalid.
        indices = torch.where(indices >= 0, indices, 987654321)
        if TRITON_DEBUG:
            assert not torch.any(indices < 0)

        context, output_scores, score_maximum = forward(
            q,
            k,
            v,
            seq_lens,
            indices,
            ks,
            ks_count,
            ks_start_end,
            args,
            access_counter,
            cache_miss_counter,
            extend_backend,
            model_context_length,
            extend_context_length,
            offload_update_cache,
            return_attention_scores,
            output_attention_score_reduce_method,
        )

        TDST = q.shape[1]

        args = args.clone()
        ctx.args = args

        if args.block_sparse_bwd_block_size_q is None:
            # Per-device default for the backward query block size. The A100
            # expression is only evaluated on that device, so other devices do
            # not depend on `args.block_sparse_block_size_q`.
            device_name = torch.cuda.get_device_name()
            if device_name == "NVIDIA A100-SXM4-80GB":
                args.block_sparse_bwd_block_size_q = min(
                    64, args.block_sparse_block_size_q
                )
            else:
                args.block_sparse_bwd_block_size_q = 16

        bwd_block_size_q = args.block_sparse_bwd_block_size_q
        if (args.block_size_q > bwd_block_size_q) and (
            triton.cdiv(TDST, bwd_block_size_q)
            != triton.cdiv(TDST, args.block_size_q)
        ):
            # NOTE: break-down to fit BSA block size. Each forward query block
            # is split into `repeats` backward blocks sharing the same indices.
            # Only `indices` is re-tiled because it is the only mask tensor
            # saved for backward.
            assert (args.block_size_q % bwd_block_size_q) == 0
            new_bdst = triton.cdiv(TDST, bwd_block_size_q)
            repeats = args.block_size_q // bwd_block_size_q
            indices = indices.repeat_interleave(repeats, 1)[
                :, :new_bdst
            ].contiguous()
            args.block_size_q = bwd_block_size_q

        ctx.save_for_backward(
            q,
            k,
            v,
            context,
            score_maximum,
            indices,
            seq_lens,
        )

        ctx.id = _attention.call_id
        _attention.call_id += 1

        if TRITON_DEBUG:
            assert not torch.any(torch.isnan(context))
            assert not torch.any(torch.isinf(context))
            if output_scores is not None:
                # -inf is a legal fill value for scores, so only NaN is checked.
                assert not torch.any(torch.isnan(output_scores))

        return context, output_scores

    @staticmethod
    def _norm_clip(grad: Tensor) -> Tensor:
        """Rescale ``grad`` so that its global L2 norm is at most 1.

        Returns ``grad`` untouched when the environment variable
        ``DISABLE_NORM_CLIP`` is ``"1"``.
        """
        if os.getenv("DISABLE_NORM_CLIP", "0") == "1":
            return grad
        # reshape (not view): `zeros_like` preserves the strides of q/k/v, and
        # `view(-1)` raises on such non-contiguous layouts.
        norm = grad.reshape(-1).norm()
        return (grad / (norm + 1e-20)) * norm.clamp_max(1)

    @staticmethod
    def _build_colmajor_index(
        indices: Tensor,
        block_size_q: int,
        block_size_k: int,
        TDST: int,
        TSRC: int,
        device,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Transpose the per-query-block key indices into a key-block-major layout.

        Returns ``(colstarts, rowis, ixs)``: ``colstarts[b, j]`` is the offset
        of the first entry belonging to key block ``j`` in the sorted entry
        list, ``rowis`` holds the query-token offset of each sorted entry
        (``TSRC * TDST * 2`` for entries whose key index is ``>= TSRC``) and
        ``ixs`` holds each entry's original position along the last axis of
        ``indices``.
        """
        N_BLOCKS_KV = TSRC // block_size_k
        assert (TSRC % block_size_k) == 0
        nnz = indices.shape[1] * indices.shape[2]

        indices_x = indices.flatten(-2, -1)
        indices_ix = (
            torch.arange(0, indices.shape[2], device=device)[None, None, :]
            .repeat(indices.shape[0], indices.shape[1], 1)
            .flatten(-2, -1)
        )
        indices_y = (
            (torch.arange(0, indices.shape[1], device=device) * block_size_q)[
                None, :, None
            ]
            .repeat(indices.shape[0], 1, indices.shape[2])
            .flatten(-2, -1)
        )
        # Lets assume the OOM indices are marked with large numbers
        indices_y = torch.where(indices_x < TSRC, indices_y, TSRC * TDST * 2)
        indices_x, sort_map = torch.sort(indices_x, dim=-1)
        indices_ix = torch.gather(indices_ix, dim=-1, index=sort_map)
        indices_y = torch.gather(indices_y, dim=-1, index=sort_map)
        # FIXME: need fix to use GPU
        indices_bx_hist = torch.tensor(
            histc(
                (indices_x // block_size_k).cpu().numpy(),
                N_BLOCKS_KV,
            )
        ).to(indices.device)

        colstarts = indices_bx_hist.cumsum(dim=-1)
        assert colstarts.shape == (
            indices.shape[0],
            N_BLOCKS_KV + 1,
        ), colstarts.shape
        assert colstarts.dtype == indices.dtype, colstarts.dtype
        assert indices_y.shape == (indices.shape[0], nnz), indices_y.shape
        assert indices_ix.shape == (indices.shape[0], nnz), indices_ix.shape
        return colstarts, indices_y, indices_ix

    @staticmethod
    def _compute_grads(ctx, grad_context: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Launch the backward kernels and return ``(grad_q, grad_k, grad_v)``.

        The returned gradients are norm-clipped and cast to the dtypes of
        ``q``, ``k`` and ``v``. Requires explicit ``k``/``v`` tensors,
        ``TDST`` divisible by 128 and by ``args.block_size_q``, and ``TSRC``
        divisible by ``args.block_size_k`` (all asserted).
        """
        (q, k, v, context, score_maximum, indices, seq_lens) = ctx.saved_tensors
        args = ctx.args  # type: HiPAttentionArgs

        assert k is not None
        assert v is not None
        assert k.shape == v.shape
        assert q.shape[-1] == k.shape[-1]
        assert k.ndim == 4
        assert q.ndim == 4
        if args.using_extend:
            assert args.need_apply_rope
            assert args.sa_extend_backend == "streaming"

        grad_context = _attention._norm_clip(grad_context.contiguous())

        # Gradients are accumulated in float32 and cast back at the end.
        grad_q = torch.zeros_like(q, dtype=torch.float32)
        grad_k = torch.zeros_like(k, dtype=torch.float32)
        grad_v = torch.zeros_like(v, dtype=torch.float32)

        BSZ, TDST, HEAD, HID = q.shape
        _, TSRC, HEAD_KV, HID = k.shape
        assert v.shape == (BSZ, TSRC, HEAD_KV, HID)

        BLOCK_PREPROCESS = 128
        assert (TDST % BLOCK_PREPROCESS) == 0
        preprocess_grid = (
            TDST // BLOCK_PREPROCESS,
            BSZ * HEAD,
        )
        delta = torch.empty_like(score_maximum, dtype=torch.float32)
        block_sparse_attention_backward_preprocess[preprocess_grid](
            context,
            *safe_stride(context, 4),
            grad_context,
            *safe_stride(grad_context, 4),
            delta,
            *safe_stride(delta, 3),
            BSZ,
            TDST,
            HEAD,
            BLOCK_TDST=BLOCK_PREPROCESS,
            HEAD_DIM=HID,
        )

        BLOCK_TDST = args.block_size_q
        assert (TDST % BLOCK_TDST) == 0

        # Replace non-finite row maxima by 0 before handing them to the kernel.
        score_maximum = torch.where(
            ~(torch.isnan(score_maximum) | torch.isinf(score_maximum)),
            score_maximum,
            0.0,
        )

        # NOTE: we need to transpose indices for Q scan process. KV index -> Q indices
        (
            indices_colmajor_colstarts,
            indices_colmajor_rowis,
            indices_colmajor_ixs,
        ) = _attention._build_colmajor_index(
            indices,
            args.block_size_q,
            args.block_size_k,
            TDST,
            TSRC,
            q.device,
        )

        # NOTE: somehow we need to prescale the loss
        RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)
        arg_k = k * (args.sm_scale * RCP_LN2)

        device_name = torch.cuda.get_device_name()
        BLOCK_COLWISE_BK = max(
            1,
            {
                "NVIDIA A100-SXM4-80GB": 64,
                "NVIDIA H200": 32,
            }.get(device_name, 16)
            // args.block_size_k,
        )

        assert seq_lens is not None
        # Strides must come from the tensor that is actually passed to the
        # kernel: `seq_lens - 1` is a fresh tensor whose layout can differ from
        # `seq_lens` (e.g. when `seq_lens` is an expanded or sliced view).
        seq_lens_m1 = seq_lens - 1

        need_benchmark = os.getenv("NEED_BENCHMARK", "0") == "1"

        grid = (
            max(
                TDST // args.block_size_q,
                TSRC // (args.block_size_k * BLOCK_COLWISE_BK),
            ),
            BSZ * HEAD,
        )

        # NOTE(review): benchmark mode launches the kernel 5 times on the same
        # grad buffers; if the kernel accumulates into them the gradients are
        # not meaningful in that mode - confirm before training with it on.
        for _ in range(5 if need_benchmark else 1):
            if need_benchmark:
                start_event = torch.cuda.Event(enable_timing=True)
                start_event.record()

            block_sparse_attention_backward[grid](
                q,
                *safe_stride(q, 4),
                arg_k,
                *safe_stride(arg_k, 4),
                v,
                *safe_stride(v, 4),
                args.rope_cos,
                *safe_stride(args.rope_cos, 2),
                args.rope_sin,
                *safe_stride(args.rope_sin, 2),
                grad_context,
                *safe_stride(grad_context, 4),
                grad_q,
                *safe_stride(grad_q, 4),
                grad_k,
                *safe_stride(grad_k, 4),
                grad_v,
                *safe_stride(grad_v, 4),
                score_maximum,
                *safe_stride(score_maximum, 3),
                delta,
                *safe_stride(delta, 3),
                seq_lens_m1,
                *safe_stride(seq_lens_m1, 2),
                indices,
                *safe_stride(indices, 3),
                indices.shape[-1],
                max(1, 32 // args.block_size_k),
                indices_colmajor_colstarts,
                *safe_stride(indices_colmajor_colstarts, 2),
                indices_colmajor_rowis,
                *safe_stride(indices_colmajor_rowis, 2),
                indices_colmajor_ixs,
                *safe_stride(indices_colmajor_ixs, 2),
                indices_colmajor_rowis.shape[-1],
                args.sm_scale,
                HEAD,
                HEAD // HEAD_KV,
                TDST,
                TSRC,
                args.sink_token_size,
                args.sliding_window_size,
                q.shape[-1],
                args.using_extend,
                args.block_size_q,
                args.block_size_k,
                BLOCK_COLWISE_BK,
            )

            if need_benchmark:
                end_event = torch.cuda.Event(enable_timing=True)
                end_event.record()
                end_event.synchronize()

                print(
                    "bwd called",
                    q.shape,
                    k.shape,
                    v.shape,
                    indices.shape,
                    indices_colmajor_rowis.shape,
                    grid,
                    ctx.id,
                    start_event.elapsed_time(end_event),
                )

        grad_q = _attention._norm_clip(grad_q).to(q.dtype)
        grad_k = _attention._norm_clip(grad_k).to(k.dtype)
        grad_v = _attention._norm_clip(grad_v).to(v.dtype)

        if TRITON_DEBUG:
            if torch.any(torch.isnan(grad_q)):
                print(grad_q)
                assert False, "grad_q"
            if torch.any(torch.isnan(grad_k)):
                print(grad_k)
                print(grad_k.isnan().nonzero(), grad_k.isnan().nonzero().shape)
                assert False, "grad_k"
            if torch.any(torch.isnan(grad_v)):
                print(grad_v)
                assert False, "grad_v"
            assert not torch.any(torch.isinf(grad_q))
            assert not torch.any(torch.isinf(grad_k))
            assert not torch.any(torch.isinf(grad_v))

        return grad_q, grad_k, grad_v

    @staticmethod
    def backward(  # noqa
        ctx,
        grad_context: Tensor,
        grad_score: Tensor,
    ):
        """Return gradients for ``q``, ``k`` and ``v``; all other inputs get ``None``.

        ``grad_score`` (the gradient w.r.t. ``output_scores``) is ignored.
        """
        pre_device = torch.get_default_device()
        torch.set_default_device(grad_context.device)
        try:
            grad_q, grad_k, grad_v = _attention._compute_grads(ctx, grad_context)
        finally:
            # Restore the process-wide default device even if an assertion or
            # a kernel launch fails.
            torch.set_default_device(pre_device)

        return (
            # q: Tensor,
            grad_q,
            # k: Optional[Tensor],
            grad_k,
            # v: Optional[Tensor],
            grad_v,
            # seq_lens: Tensor,
            None,
            # indices: Tensor,
            None,
            # ks: Tensor,
            None,
            # ks_count: Tensor,
            None,
            # ks_start_end: Tensor,
            None,
            # args: "HiPAttentionArgs",
            None,
            # access_counter: Tensor,
            None,
            # cache_miss_counter: Tensor,
            None,
            # extend_backend: str,
            None,
            # model_context_length: int,
            None,
            # extend_context_length: int,
            None,
            # offload_update_cache: bool,
            None,
            # return_attention_scores: bool,
            None,
            # output_attention_score_reduce_method: Literal['mean', 'max'],
            None,
        )


def block_sparse_attention(
    q: Tensor,
    k: Optional[Tensor],
    v: Optional[Tensor],
    seq_lens: Tensor,
    indices: Tensor,
    ks: Tensor,
    ks_count: Tensor,
    ks_start_end: Tensor,
    args: "HiPAttentionArgs",
    access_counter: Tensor,
    cache_miss_counter: Tensor,
    extend_backend: str = "streaming",
    model_context_length: int = 131072,
    extend_context_length: int = 131072,
    offload_update_cache: bool = False,
    return_attention_scores: bool = False,
    output_attention_scores_reduce_method: Literal["mean", "max"] = "max",
):
    """Differentiable entry point for block-sparse attention.

    Hands every argument, in declaration order, to ``_attention.apply`` and
    returns its result unchanged, i.e. the ``(context, output_scores)`` pair
    produced by ``_attention.forward``.
    """
    # `Function.apply` only accepts positional arguments, so the order of this
    # tuple must match the signature of `_attention.forward` (after `ctx`).
    autograd_inputs = (
        q,
        k,
        v,
        seq_lens,
        indices,
        ks,
        ks_count,
        ks_start_end,
        args,
        access_counter,
        cache_miss_counter,
        extend_backend,
        model_context_length,
        extend_context_length,
        offload_update_cache,
        return_attention_scores,
        output_attention_scores_reduce_method,
    )
    return _attention.apply(*autograd_inputs)
