import os
from typing import Any, Optional

import einx
from matplotlib import pyplot as plt
import torch
from torch import Tensor
import triton

from .kernels.block_sparse_attention import block_sparse_attention
from .kernels.flash_attention import flash_attention
from .metadata import (
    HiPAttentionArgs,
    HiPAttentionCacheAccessStatistics,
    HiPAttentionOutputMetadata,
)
import cv2
import numpy as np

# Debug switches, evaluated once at import time from the environment.
HIP_DEBUG = os.environ.get("HIP_DEBUG", "0") == "1"
TRITON_DEBUG = os.environ.get("TRITON_DEBUG", "0") == "1"
# Sentinel stored in index tensors to mark an empty / invalid slot.
MAXIMUM_INDEX = 987654321


def hip_tune_attention(
    q: Tensor,
    k: Optional[Tensor],
    v: Optional[Tensor],
    args: HiPAttentionArgs,
    cached_metadata: Optional[HiPAttentionOutputMetadata] = None,
    kwargs: Any = None,
):
    """Run multi-stage HiP-tune attention and gate-mix the per-stage outputs.

    Stage 0 (``hip_first_stage``) attends over pooled queries/keys and picks
    key chunks; each later stage (``hip_stage``) refines the previous stage's
    chunk selection. The last selection is converted to token offsets and fed
    to ``block_sparse_attention``. The output of every stage plus the final
    context are stacked and combined with ``args.gate_probs``.

    Args:
        q: query tensor, unpacked here as ``(BSZ, TDST, HEAD, HID)``.
        k, v: key/value tensors unpacked as ``(BSZ, TSRC, HEAD_KV, HID)``, or
            ``None`` when the KV lives in ``args.k_cache`` /
            ``args.offload_cache`` (only used here to size counters).
        args: HiP configuration. NOTE: ``args.sm_scale`` is filled in place
            when it is ``None``.
        cached_metadata: previous output metadata; only whether it (and its
            ``indices``) is ``None`` is inspected here.
        kwargs: extra keyword arguments, used only when ``HIP_TESTING=1``.

    Returns:
        ``(context, HiPAttentionOutputMetadata)``; or, when
        ``HIP_RETURN_KWARGS_ONLY=1``, the dict of arguments that would have
        been passed to the final ``block_sparse_attention`` call.
    """
    assert args.enable_hip_tune
    assert args.token_pooler_fn is not None
    assert args.output_unpooler_fn is not None
    assert args.gate_probs is not None

    BSZ, TDST, HEAD, HID = q.shape
    if k is not None:
        BSZ, TSRC, HEAD_KV, HID = k.shape
        assert v.shape == k.shape
        MAX_TSRC = TSRC
    else:
        MAX_TSRC = args.extend_context_length
        if args.k_cache is not None:
            HEAD_KV = args.k_cache.shape[-2]
        else:
            HEAD_KV = args.offload_cache.k_uvm.bank_cpu.shape[-2]
        TSRC = MAX_TSRC

    # NOTE: writes into the caller's ``args``.
    if args.sm_scale is None:
        args.sm_scale = 1 / (q.shape[-1] ** 0.5)

    # While a CUDA stream is being captured, positions must be supplied by the
    # caller (asserted); otherwise default to the last TDST of TSRC positions.
    if torch.cuda.is_current_stream_capturing() or args.position_ids is not None:
        assert args.position_ids is not None
        position_ids = args.position_ids
    else:
        position_ids = torch.arange(0, TDST, device=q.device) + (TSRC - TDST)
        position_ids = position_ids[None, :].expand(BSZ, TDST)

    if args.using_paged_cache:
        MAX_PAGE = args.paged_cache_page_count
    else:
        MAX_PAGE = MAX_TSRC

    if args.require_cache_statistics:
        mask_access_counter = torch.zeros(
            (BSZ, HEAD_KV, MAX_PAGE), dtype=torch.int32, device=q.device
        )
        mask_cache_miss_counter = torch.zeros(
            (BSZ, HEAD_KV, MAX_PAGE), dtype=torch.int32, device=q.device
        )
        sa_access_counter = torch.zeros(
            (BSZ, HEAD_KV, MAX_PAGE), dtype=torch.int32, device=q.device
        )
        sa_cache_miss_counter = torch.zeros(
            (BSZ, HEAD_KV, MAX_PAGE), dtype=torch.int32, device=q.device
        )
    else:
        sa_cache_miss_counter = sa_access_counter = mask_cache_miss_counter = (
            mask_access_counter
        ) = None

    # TODO: kwargs should probably be passed through to every function called here.
    stage_caches = []  # never populated here; returned empty in the metadata
    compressed_outputs = []  # per-stage outputs, mixed by gate_probs at the end
    comp_indices = None
    prev_block_size_q = prev_block_size_k = None
    for i_stage, stage_info in enumerate(args.stages):
        block_size_k = stage_info.stage_chunk_size
        block_size_q = stage_info.stage_block_size_q
        next_topk = (
            args.stages[i_stage + 1].stage_k
            if i_stage + 1 < len(args.stages)
            else args.second_stage_k
        )
        cur_topk = stage_info.stage_k

        # NOTE: need to handle pooled paged cache

        if i_stage == 0:
            comp_output, comp_indices = hip_first_stage(
                q=q,
                k=k,
                v=v,
                position_ids=position_ids,
                rope_cos=args.rope_cos,
                rope_sin=args.rope_sin,
                token_pooler_fn=args.token_pooler_fn,
                output_unpooler_fn=args.output_unpooler_fn,
                next_topk=next_topk,
                block_size_q=block_size_q,
                block_size_k=block_size_k,
                sm_scale=args.sm_scale,
                args=args,
            )
            if TRITON_DEBUG:
                assert not torch.any(torch.isnan(comp_output))
            compressed_outputs.append(comp_output)
        else:
            comp_output, comp_indices = hip_stage(
                q=q,
                k=k,
                v=v,
                position_ids=position_ids,
                comp_indices=comp_indices,
                rope_cos=args.rope_cos,
                rope_sin=args.rope_sin,
                token_pooler_fn=args.token_pooler_fn,
                output_unpooler_fn=args.output_unpooler_fn,
                prev_block_size_q=prev_block_size_q,
                prev_block_size_k=prev_block_size_k,
                block_size_q=block_size_q,
                block_size_k=block_size_k,
                cur_topk=cur_topk,
                next_topk=next_topk,
                sm_scale=args.sm_scale,
                i_stage=i_stage,
                args=args,
            )
            if TRITON_DEBUG:
                assert not torch.any(torch.isnan(comp_output))
            compressed_outputs.append(comp_output)

        if HIP_DEBUG:
            from hip_attn.v1_2.attention_extend import render_plot_dynamic

            out_indices_cpu = (
                (comp_indices * block_size_k)
                .permute(0, 2, 1, 3)
                .contiguous()
                .cpu()
                .numpy()
            )
            debug = np.zeros(
                (triton.cdiv(TDST, args.block_size_q), triton.cdiv(TSRC, args.block_size_q))
            )
            render_plot_dynamic(
                out_indices_cpu,
                debug,
                0,
                args.block_size_q,
                out_indices_cpu.shape[-1] * block_size_k,
                block_size_k,
                causal_mask=True,
                sliding_window_size=args.sliding_window_size,
            )
            cv2.imwrite(f"dummy_sampled_stage_{i_stage}.png", debug * 255)

        prev_block_size_q = block_size_q
        prev_block_size_k = block_size_k

    # Final block sparse attention: convert the last stage's chunk indices
    # (units of prev_block_size_k) into token offsets, keeping sentinels.
    indices = torch.where(
        comp_indices < MAXIMUM_INDEX,
        comp_indices * prev_block_size_k,
        MAXIMUM_INDEX,
    )
    indices = torch.sort(indices, dim=-1).values
    # Drop offsets past the key sequence. TSRC equals k.shape[1] when k is
    # given, and does not dereference k when it is None.
    indices = torch.where(
        indices < TSRC,
        indices,
        MAXIMUM_INDEX,
    )

    indices = indices.flatten(0, 1)
    ks = (indices < MAXIMUM_INDEX).sum(-1)
    ks_count = ks.unsqueeze(-1)
    ks_start_end = torch.stack([torch.zeros_like(ks), ks], dim=-1)

    args = args.clone()
    args.block_size_q = prev_block_size_q
    args.block_size_k = prev_block_size_k
    assert args.block_sparse_block_size_q == prev_block_size_q

    if os.environ.get("HIP_RETURN_KWARGS_ONLY", "0") == "1":
        kwargs = {
            "seq_lens": position_ids[:, -q.shape[1] :] + 1,
            "indices": indices,
            "ks": ks,
            "ks_count": ks_count,
            "ks_start_end": ks_start_end,
            "args": args,
            "access_counter": sa_access_counter,
            "cache_miss_counter": sa_cache_miss_counter,
            "extend_backend": args.sa_extend_backend,  # streaming works way much better in Gemma2, than dynamic_extend
            "model_context_length": args.model_context_length,
            "extend_context_length": args.extend_context_length,
            "offload_update_cache": (cached_metadata is None)
            and args.online_update_cache,
            "return_attention_scores": False,
        }
        return kwargs
    elif os.environ.get("HIP_TESTING", "0") == "1":
        if kwargs is None:
            raise ValueError(
                "HIP_TESTING=1 requires `kwargs` for block_sparse_attention"
            )
        context, _ = block_sparse_attention(
            q=q,
            k=k,
            v=v,
            **kwargs,
        )
    else:
        if os.getenv("CHECKOUT_BSA_ARGS", "0") == "1":
            args.token_pooler_fn = None
            args.output_unpooler_fn = None
            torch.save(
                dict(
                    q=q,
                    k=k,
                    v=v,
                    seq_lens=position_ids[:, -q.shape[1] :] + 1,
                    indices=indices,
                    ks=ks,
                    ks_count=ks_count,
                    ks_start_end=ks_start_end,
                    args=args,
                    access_counter=sa_access_counter,
                    cache_miss_counter=sa_cache_miss_counter,
                    extend_backend=args.sa_extend_backend,  # streaming works way much better in Gemma2, than dynamic_extend
                    model_context_length=args.model_context_length,
                    extend_context_length=args.extend_context_length,
                    offload_update_cache=(cached_metadata is None)
                    and args.online_update_cache,
                    return_attention_scores=False,
                ),
                "bsa_args.pt",
            )
            # Same effect as the site-provided exit(0), without relying on it.
            raise SystemExit(0)

        # Recount, per query block, only the selected key offsets below the
        # block's first query position + block_size_q (indices are sorted).
        # ``indices`` is flattened to (BSZ * heads, ...) while position_ids is
        # per batch row, so repeat it across heads; plain broadcasting would
        # fail for BSZ > 1.
        query_block_end = position_ids[:, ::prev_block_size_q] + prev_block_size_q
        query_block_end = query_block_end.repeat_interleave(
            indices.shape[0] // query_block_end.shape[0], dim=0
        )
        ks = (indices < query_block_end[:, :, None]).to(ks.dtype).sum(dim=-1)
        ks_count = ks.unsqueeze(-1)
        ks_start_end[:, :, -1] = ks

        if HIP_DEBUG:
            from hip_attn.v1_2.attention_extend import render_plot

            out_indices_cpu = (
                indices.reshape(BSZ, HEAD, indices.shape[1], indices.shape[2])
                .permute(0, 2, 1, 3)
                .contiguous()
                .cpu()
                .numpy()
            )
            debug = np.zeros(
                (triton.cdiv(TDST, args.block_size_q), triton.cdiv(TSRC, args.block_size_q))
            )
            render_plot(out_indices_cpu, debug, 0, args.block_size_q)
            cv2.imwrite("dummy_sampled_final.png", debug * 255)
            if q.shape[1] == 1:
                cmd = input(f'{q.shape=} {k.shape=} {v.shape=} >>> ')
                if cmd.lower().strip() == 'pdb':
                    breakpoint()

        context, _ = block_sparse_attention(
            q=q,
            k=k,
            v=v,
            seq_lens=position_ids[:, -q.shape[1] :] + 1,
            indices=indices,
            ks=ks,
            ks_count=ks_count,
            ks_start_end=ks_start_end,
            args=args,
            access_counter=sa_access_counter,
            cache_miss_counter=sa_cache_miss_counter,
            extend_backend=args.sa_extend_backend,  # streaming works way much better in Gemma2, than dynamic_extend
            model_context_length=args.model_context_length,
            extend_context_length=args.extend_context_length,
            offload_update_cache=(cached_metadata is None) and args.online_update_cache,
            return_attention_scores=False,
        )

    if TRITON_DEBUG:
        assert not torch.any(torch.isnan(context))
    compressed_outputs.append(context)

    # Gate-weighted sum over [stage outputs..., final context].
    context = torch.stack(compressed_outputs, dim=-2)
    context = (args.gate_probs.unsqueeze(-1) * context).sum(dim=-2)

    return context, HiPAttentionOutputMetadata(
        indices=indices,
        ks=ks,
        ks_count=ks_count,
        ks_start_end=ks_start_end,
        mask_cache_statistics=(
            HiPAttentionCacheAccessStatistics(
                access_counter=mask_access_counter,
                cache_miss_counter=mask_cache_miss_counter,
            )
            if (cached_metadata is None) or (cached_metadata.indices is None)
            else None
        ),
        sa_cache_statistics=HiPAttentionCacheAccessStatistics(
            access_counter=sa_access_counter,
            cache_miss_counter=sa_cache_miss_counter,
        ),
        stage_caches=stage_caches,
    )

def token_sequence_pooling(
    x: torch.Tensor,
    token_pooler_fn,
    block_size: int,
    tensor_type: str,
):
    """Pool a token sequence into chunks of ``block_size`` tokens.

    ``x`` is split along dim 1 into ``T // block_size`` full chunks plus, when
    ``T`` is not a multiple of ``block_size``, one trailing partial chunk. Each
    chunk is reduced over its token axis by ``token_pooler_fn``, so the result
    has ``ceil(T / block_size)`` entries along dim 1.

    Args:
        x: 4-D tensor ``(batch, T, d2, d3)``; dims 2 and 3 are carried through
            unchanged (presumably (heads, head_dim) — confirm with callers).
        token_pooler_fn: callable invoked as
            ``fn(t, dim=2, i_stage=0, tensor_type=tensor_type)`` that reduces
            ``t`` over ``dim``.
        block_size: number of tokens per pooled chunk.
        tensor_type: forwarded to ``token_pooler_fn`` (this file passes
            ``"query"``, ``"key"`` or ``"value"``).

    Returns:
        The pooled chunks concatenated along dim 1.
    """
    T = x.shape[1]
    n_full = T // block_size
    left_over = T % block_size
    pooled = []

    # Full chunks. Also taken when T == 0, so an empty sequence still goes
    # through the pooler as in the divisible case.
    if n_full > 0 or left_over == 0:
        # reshape (not view) so non-contiguous inputs are accepted too.
        x_pack = x[:, : n_full * block_size].reshape(
            x.shape[0],
            n_full,
            block_size,
            x.shape[2],
            x.shape[3],
        )
        pooled.append(
            token_pooler_fn(x_pack, dim=2, i_stage=0, tensor_type=tensor_type)
        )

    # Trailing partial chunk: its `left_over` tokens are pooled as one chunk.
    # (Previously the T < block_size case referenced an unbound local.)
    if left_over > 0:
        x_tail = x[:, n_full * block_size :].contiguous().unsqueeze(1)
        pooled.append(
            token_pooler_fn(x_tail, dim=2, i_stage=0, tensor_type=tensor_type)
        )

    if len(pooled) == 1:
        return pooled[0]
    return torch.cat(pooled, dim=1)
        
def hip_first_stage(
    q,
    k,
    v,
    position_ids,
    rope_cos,
    rope_sin,
    token_pooler_fn,
    output_unpooler_fn,
    next_topk,
    block_size_q,
    block_size_k,
    sm_scale,
    args: HiPAttentionArgs,
):
    """First (coarsest) HiP-tune stage: causal attention over pooled chunks.

    Queries are pooled into chunks of ``block_size_q`` tokens (clamped to
    ``next_power_of_2(q_seq_len)`` for short query sequences). Keys/values are
    pooled into chunks of ``block_size_k``. Causal flash attention over the
    pooled sequences gives a compressed output, which is unpooled back to
    ``q_seq_len`` tokens, and chunk-level scores. After masking sink and
    sliding-window chunks, the top ``next_topk // block_size_k`` key chunks
    are kept per query chunk.

    Returns:
        comp_output: unpooled stage output, truncated to ``q_seq_len`` along dim 1.
        indices: selected key-chunk indices (units of ``block_size_k``); slots
            whose score is ``<= -1e4`` (i.e. masked) hold ``MAXIMUM_INDEX``.
    """
    q_seq_len = q.shape[1]

    # Do not pool queries into chunks larger than a short sequence needs.
    block_size_q = min(block_size_q, triton.next_power_of_2(q_seq_len))

    q_stage = token_sequence_pooling(q, token_pooler_fn, block_size_q, "query")
    k_stage = token_sequence_pooling(k, token_pooler_fn, block_size_k, "key")
    v_stage = token_sequence_pooling(v, token_pooler_fn, block_size_k, "value")

    def rotate_half(vec):
        # (x1, x2) -> (-x2, x1) over the last dimension.
        half = vec.shape[-1] // 2
        return torch.cat((-vec[..., half:], vec[..., :half]), dim=-1)

    def apply_rope(vec, cos, sin):
        return (vec * cos) + (rotate_half(vec) * sin)

    if args.need_apply_rope and args.hip_tune_first_stage_apply_rope:
        # Use the ``position_ids`` argument: ``args.position_ids`` may be None
        # (hip_tune_attention then synthesizes default positions). Each pooled
        # query chunk uses its first token's position, divided by block_size_k;
        # batch row 0's positions are used for every batch.
        q_rope_pos = position_ids[0, ::block_size_q] // block_size_k
        q_stage = apply_rope(
            q_stage,
            rope_cos[None, q_rope_pos, None, :],
            rope_sin[None, q_rope_pos, None, :],
        )
        k_stage = apply_rope(
            k_stage,
            rope_cos[None, : k_stage.shape[1], None, :],
            rope_sin[None, : k_stage.shape[1], None, :],
        )
        v_stage = apply_rope(
            v_stage,
            rope_cos[None, : v_stage.shape[1], None, :],
            rope_sin[None, : v_stage.shape[1], None, :],
        )

    # FIXME: test correctness of backward for q_factor & kv_factor
    comp_output, scores = flash_attention(
        q=q_stage.permute(0, 2, 1, 3).contiguous(),
        k=k_stage.permute(0, 2, 1, 3).contiguous(),
        v=v_stage.permute(0, 2, 1, 3).contiguous(),
        seq_lens=(position_ids[:, ::block_size_q] + 1) // block_size_k,
        causal=True,
        EXCLUDE_LAST_WINDOW=False,  # NOTE: last block is already excluded by seq_lens
        sm_scale=sm_scale,
        RETURN_SCORES=True,
        q_factor=block_size_q,
        kv_factor=block_size_k,
    )
    comp_output = comp_output.permute(0, 2, 1, 3)  # (B, q_chunks, q_heads, dim)

    if HIP_DEBUG:
        plt.clf()
        plt.imshow(scores[0, 0].cpu().float().numpy())
        plt.colorbar()
        plt.savefig('dummy_scores.png')

    comp_output = output_unpooler_fn(
        comp_output.unsqueeze(2),
        rate=block_size_q,
    )
    comp_output = comp_output.flatten(1, 2)
    comp_output = comp_output[:, :q_seq_len, :, :].contiguous()

    # Ensure sink and sliding-window chunks are not selected: mask chunks that
    # start inside the sink, or at/after (query chunk end - sliding window).
    q_start_idx = (
        position_ids[:, ::block_size_q, None]
        .contiguous()
        .expand(position_ids.shape[0], scores.shape[-2], scores.shape[-1])[
            :, None, :, :
        ]
    )
    k_start_idx = (
        torch.arange(scores.shape[-1], device=scores.device)[None, None, :]
        * block_size_k
    ).expand(position_ids.shape[0], scores.shape[-2], scores.shape[-1])[:, None, :, :]
    scores = torch.where(
        (k_start_idx < args.sink_token_size)
        | (k_start_idx >= (q_start_idx + block_size_q - args.sliding_window_size)),
        float("-inf"),
        scores,
    )
    scores, indices = torch.topk(
        scores,
        k=min(scores.shape[-1], next_topk // block_size_k),
        dim=-1,
    )  # (B, H, q_seq // prev_block_size_q, cur_topk // prev_block_size_k)
    # Masked (-inf) picks become the sentinel.
    indices = torch.where(
        scores > -1e4,
        indices,
        MAXIMUM_INDEX,
    )

    return comp_output, indices


def hip_stage(
    q,
    k,
    v,
    position_ids,
    comp_indices,
    rope_cos,
    rope_sin,
    token_pooler_fn,
    output_unpooler_fn,
    prev_block_size_q,
    prev_block_size_k,
    block_size_q,
    block_size_k,
    cur_topk,
    next_topk,
    sm_scale,
    i_stage,
    args: HiPAttentionArgs,
):
    """Refinement stage: re-score the previous selection at a finer chunk size.

    Each previously selected key chunk (``prev_block_size_k`` tokens) is split
    into ``prev_block_size_k // block_size_k`` sub-chunks. Query blocks are
    split the same way when ``prev_block_size_q != block_size_q``. Block-sparse
    attention runs over token-level queries and ``block_size_k``-pooled
    keys/values, returning scores reduced with ``"max"``. Those scores select
    the top ``next_topk // block_size_k`` sub-chunks.

    ``rope_cos``, ``rope_sin``, ``cur_topk`` and ``sm_scale`` are accepted for
    signature symmetry but are not used in this function.

    Returns:
        comp_output: unpooled stage output.
        new_indices: selected key-chunk indices (units of ``block_size_k``);
            slots whose score is ``<= -1e4`` hold ``MAXIMUM_INDEX``.
    """
    # Queries stay token-level (pool size 1); keys/values are pooled.
    q_stage = token_sequence_pooling(q, token_pooler_fn, 1, "query")
    k_stage = token_sequence_pooling(k, token_pooler_fn, block_size_k, "key")
    v_stage = token_sequence_pooling(v, token_pooler_fn, block_size_k, "value")

    # Perform sparse attention; this stage must refine (or keep) the key granularity.
    assert (prev_block_size_k % block_size_k) == 0
    assert prev_block_size_k >= block_size_k

    # Expand every selected previous chunk into its sub-chunks.
    indices = torch.where(
        comp_indices < MAXIMUM_INDEX,
        comp_indices * prev_block_size_k // block_size_k,
        MAXIMUM_INDEX,
    )
    indices = indices[..., None] + torch.arange(
        prev_block_size_k // block_size_k,
        device=indices.device,
    )
    indices = indices.flatten(-2, -1)
    if prev_block_size_q != block_size_q:
        assert (prev_block_size_q % block_size_q) == 0
        assert prev_block_size_q >= block_size_q
        # Each finer query block inherits its parent block's selection.
        # NOTE(review): if the query length is not a multiple of
        # prev_block_size_q this yields more rows than
        # cdiv(q.shape[1], block_size_q) — confirm the kernel tolerates it.
        indices = torch.repeat_interleave(
            indices, repeats=prev_block_size_q // block_size_q, dim=2
        )
    indices = torch.sort(indices, dim=-1).values
    # ``indices`` address rows of ``k_stage`` (chunks of block_size_k tokens),
    # so bound them by the pooled chunk count rather than the token count
    # ``k.shape[1]``. Otherwise sub-chunks past the end of a partial last
    # chunk would survive, be counted in ``ks``, and point past ``k_stage``.
    indices = torch.where(
        indices < k_stage.shape[1],
        indices,
        MAXIMUM_INDEX,
    )

    ks = (indices < MAXIMUM_INDEX).sum(-1)
    ks_count = ks.unsqueeze(-1)
    ks_start_end = torch.stack([torch.zeros_like(ks), ks], dim=-1)

    # Keys are already pooled, so the kernel sees one "token" per chunk and
    # sink / sliding-window sizes are rescaled to chunk units.
    args_stage = args.clone()
    args_stage.block_size_q = block_size_q
    args_stage.block_size_k = 1
    args_stage.sink_token_size = args.sink_token_size // block_size_k
    args_stage.sliding_window_size = args.sliding_window_size // block_size_k

    seq_lens = (position_ids[:, -q.shape[1] :] + 1) // block_size_k

    comp_output, scores = block_sparse_attention(
        q=q_stage,
        k=k_stage,
        v=v_stage,
        seq_lens=seq_lens,
        indices=indices.flatten(0, 1),
        ks=ks.flatten(0, 1),
        ks_count=ks_count.flatten(0, 1),
        ks_start_end=ks_start_end.flatten(0, 1),
        args=args_stage,
        access_counter=None,
        cache_miss_counter=None,
        extend_backend=args_stage.sa_extend_backend,
        model_context_length=args_stage.model_context_length,
        extend_context_length=args_stage.extend_context_length,
        offload_update_cache=args_stage.online_update_cache,
        return_attention_scores=True,
        output_attention_scores_reduce_method="max",
    )
    # BUG: why do we need this permute? the output is already B, QCHUNKS, QHEAD, DIM
    # comp_output = comp_output.permute(0, 2, 1, 3)  # (B, q_chunks, q_heads, dim)

    comp_output = output_unpooler_fn(
        comp_output.unsqueeze(2),
        rate=1,
    )
    comp_output = comp_output.flatten(1, 2)

    scores = scores.permute(0, 2, 1, 3)

    if HIP_DEBUG:
        plt.clf()
        plt.imshow(scores[0, 0].cpu().float().numpy())
        plt.colorbar()
        plt.savefig(f'dummy_scores_stage_{i_stage}.png')

    topk = min(scores.size(-1), next_topk // block_size_k)
    scores, new_indices = torch.topk(
        scores,
        k=topk,
        dim=-1,
    )
    # Map top-k positions (within this stage's candidate list) back to chunk ids.
    new_indices = einx.get_at(
        "b h q [k], b h q i -> b h q i",
        indices,
        new_indices,
    )
    new_indices = torch.where(
        scores > -1e4,
        new_indices,
        MAXIMUM_INDEX,
    )

    return comp_output, new_indices
