import os
from typing import Literal

import cv2
import numpy as np
import torch
from tqdm import tqdm

from hip_attn.v1_3.attention import HiPAttentionArgs, ScanStage, hip_attention
from hip_attn.v1_3.attention_tune import hip_tune_attention
from hip_attn.v1_3.kernels.block_sparse_attention import block_sparse_attention
from hip_research.utils.load_checkouts import load_checkouts
from hip_research.utils.seed import seed
from hip_research.utils.triton_flash import attention as triton_flash


def main_debug():
    """Smoke-test and gradient-check the HiP attention kernels on random CUDA tensors.

    The function runs three phases:

    1. Build random q/k/v (and RoPE cos/sin) tensors on ``cuda:0`` and the
       ``HiPAttentionArgs`` for the selected preset.
    2. Run one forward and one backward pass through the backend chosen by
       ``BACKEND`` and print a ``[PASS]`` line for each.
    3. For every kernel in ``["hip-tune", "hip", "flash"]`` compare the autograd
       vector-Jacobian product against a central finite-difference estimate and
       print the mean squared error for q, k and v.

    Configuration is read from environment variables:

    * ``PRESET``  - ``"default"`` (only supported value; ``"minimal"`` raises).
    * ``BACKEND`` - ``"hip"`` or ``"hip-tune"`` (used by the smoke test only).
    * ``SEQ_LEN`` - sequence length (default 16384, or 1024 for ``minimal``).
    * ``DUPS`` / ``Q_DUPS`` - how many times k/v and q are repeated along the
      sequence axis; ``Q_DUPS`` defaults to ``DUPS`` when negative.
    * ``DTYPE``   - ``"fp32"``, ``"fp16"``; anything else selects bfloat16.
    """
    seed()

    preset_name = os.getenv("PRESET", "default")

    # Backend used for the fwd/bwd smoke test. An unrecognised BACKEND value
    # fails here with a KeyError.
    attn_backend = {
        "hip": hip_attention,
        "hip-tune": hip_tune_attention,
    }[os.getenv("BACKEND", "hip")]

    if preset_name == "minimal":
        seq_len = int(os.getenv("SEQ_LEN", "1024"))
    else:
        # seq_len = int(os.getenv("SEQ_LEN", "32768"))
        seq_len = int(os.getenv("SEQ_LEN", "16384"))

    # A negative Q_DUPS (the default) means "use the same value as DUPS".
    query_seq_dups = int(os.getenv("Q_DUPS", "-1"))
    seq_dups = int(os.getenv("DUPS", "1"))
    if query_seq_dups < 0:
        query_seq_dups = seq_dups

    assert seq_dups > 0

    # NOTE: the hidden size (D) should be larger than 32; otherwise the result will be incorrect.
    H, S, D = 1, seq_len, 128
    # Map the DTYPE string to a torch dtype; any unrecognised value falls back to bfloat16.
    dtype = os.getenv("DTYPE", "bf16")
    if dtype == "fp32":
        dtype = torch.float32
    elif dtype == "fp16":
        dtype = torch.float16
    else:
        dtype = torch.bfloat16
    q = torch.randn(H, S, D, device="cuda:0", dtype=dtype)
    k = torch.randn(H, S, D, device="cuda:0", dtype=dtype)
    v = torch.randn(H, S, D, device="cuda:0", dtype=dtype)
    # Random stand-ins for the RoPE tables; they are replaced by real Llama
    # rotary tables before the gradient check further down.
    cos = torch.randn(S, D, device="cuda:0", dtype=dtype)
    sin = torch.randn(S, D, device="cuda:0", dtype=dtype)

    # (H, S, D) -> repeat along the sequence axis -> (S * dups, H, D) -> (1, S * dups, H, D)
    q = q.repeat(1, query_seq_dups, 1).permute(1, 0, 2).contiguous().unsqueeze(0)
    k = k.repeat(1, seq_dups, 1).permute(1, 0, 2).contiguous().unsqueeze(0)
    v = v.repeat(1, seq_dups, 1).permute(1, 0, 2).contiguous().unsqueeze(0)
    # cos was created unconditionally above, so this branch is always taken here.
    # The tables are repeated with the key/value factor (seq_dups), not query_seq_dups.
    if cos is not None:
        cos = cos.repeat(seq_dups, 1)
        sin = sin.repeat(seq_dups, 1)

    if preset_name == "default":
        # Three scan stages with shrinking chunk sizes (128 -> 32 -> 8).
        # NOTE(review): the exact meaning of stage_k / stage_stride is defined by
        # ScanStage in hip_attn and is not visible from this file.
        stages = [
            ScanStage(
                stage_block_size_q=64,
                stage_block_stride_q=1,
                stage_chunk_size=128,
                stage_k=None,
                stage_stride=1,
            ),
            ScanStage(
                stage_block_size_q=64,
                stage_block_stride_q=1,
                stage_chunk_size=32,
                stage_k=16384,
                stage_stride=1,
            ),
            ScanStage(
                stage_block_size_q=64,
                stage_block_stride_q=1,
                stage_chunk_size=8,
                stage_k=4096,
                stage_stride=1,
            ),
        ]

        # Arguments for the smoke test; the block-sparse query block size is tied
        # to the last stage's query block size.
        args = HiPAttentionArgs(
            sliding_window_size=1024,
            sink_token_size=256,
            using_extend=True,
            need_apply_rope=True,
            rope_cos=cos,
            rope_sin=sin,
            second_stage_k=1024,
            stages=stages,
            model_context_length=65536,
            scan_extend_backend="relative",
            sa_extend_backend="streaming",
            block_sparse_block_size_q=stages[-1].stage_block_size_q,
            enable_hip_tune=False,
        )
    elif preset_name == "minimal":
        # The minimal preset is currently disabled; the configuration below is
        # kept (commented out) as a starting point for re-enabling it.
        raise NotImplementedError(
            f"the current setup uses args designed for default length. reconfigure these args if you need to use minimal setup"
        )
        # stages = [
        #     ScanStage(
        #         stage_block_size_q=16,
        #         stage_block_stride_q=1,
        #         stage_chunk_size=4,
        #         stage_k=None,
        #         stage_stride=1,
        #     ),
        #     ScanStage(
        #         stage_block_size_q=16,
        #         stage_block_stride_q=1,
        #         stage_chunk_size=2,
        #         stage_k=64,
        #         stage_stride=1,
        #     ),
        #     ScanStage(
        #         stage_block_size_q=16,
        #         stage_block_stride_q=1,
        #         stage_chunk_size=1,
        #         stage_k=32,
        #         stage_stride=1,
        #     ),
        # ]

        # args = HiPAttentionArgs(
        #     sliding_window_size=32,
        #     sink_token_size=32,
        #     using_extend=True,
        #     need_apply_rope=True,
        #     rope_cos=cos,
        #     rope_sin=sin,
        #     second_stage_k=16,
        #     stages=stages,
        #     model_context_length=64,
        #     scan_extend_backend="relative",
        #     sa_extend_backend="streaming",
        #     enable_hip_tune=False,
        #     block_sparse_block_size_q=32,
        # )
    else:
        # Any PRESET other than "default" / "minimal" is rejected (without a message).
        raise Exception()

    class SimpleTokenPooler(torch.nn.Module):
        """Mean-pooling token pooler used by the hip-tune debug path.

        On construction it allocates, for every scan stage, one small random
        weight tensor for queries (32 x 128 x 128) and one for keys
        (8 x 128 x 128). ``forward`` currently ignores those weights and only
        averages its input along the requested axis.
        """

        def __init__(self, num_stages: int):
            super().__init__()
            self.num_stages = num_stages

            def _weight_bank(leading: int) -> torch.nn.ParameterList:
                # One randn * 0.01 weight tensor per stage.
                bank = []
                for _ in range(num_stages):
                    weight = torch.randn((leading, 128, 128)) * 0.01
                    bank.append(torch.nn.Parameter(weight))
                return torch.nn.ParameterList(bank)

            # The query bank is created before the key bank so the random
            # number generator is consumed in the same order as before.
            query_bank = _weight_bank(32)
            key_bank = _weight_bank(8)
            self.poolers = torch.nn.ParameterDict(
                {"query": query_bank, "key": key_bank}
            )

        def forward(
            self,
            x: torch.Tensor,
            dim: int,
            i_stage: int,
            tensor_type: Literal["query", "key", "value"],
        ):
            """Return the mean of ``x`` over axis ``dim``.

            ``i_stage`` and ``tensor_type`` are accepted for interface
            compatibility; ``tensor_type`` is only validated, and ``i_stage``
            is not used.
            """
            assert tensor_type in ("query", "key", "value")

            # NOTE: a learned per-position linear projection (added residually
            # to x, using self.poolers[tensor_type][i_stage]) was applied here
            # for non-value tensors at some point; it is disabled, so pooling
            # is a plain average.
            return torch.mean(x, dim=dim)

    class SimpleOutputUnpooler(torch.nn.Module):
        """Expand a pooled tensor by duplicating entries along a fixed axis."""

        def __init__(self, dim=2):
            super().__init__()
            # Axis along which forward() repeats elements.
            self.pooling_dim = dim

        def forward(self, x: torch.Tensor, rate: int):
            """Repeat every slice of ``x`` along ``pooling_dim`` ``rate`` times in a row."""
            expanded = torch.repeat_interleave(x, rate, dim=self.pooling_dim)
            return expanded

    # Helper modules passed to the attention args; both are moved to q's device.
    token_pooler_fn = SimpleTokenPooler(num_stages=len(stages)).to(q.device)
    output_unpooler_fn = SimpleOutputUnpooler().to(q.device)
    # Random gate scores: standard-normal samples shifted down by 1.0, with one
    # extra gate on top of the per-stage gates.
    random_gate_probs = (
        torch.randn(
            # NOTE: [BSZ, N_Q, HEAD, N_GATES]
            (
                q.shape[0],
                q.shape[1],
                q.shape[2],
                len(stages) + 1,
            ),
            device=q.device,
            dtype=q.dtype,
        )
        - 1.0
    )
    # Give the last gate a much larger score than the others.
    # NOTE(review): presumably this makes the backend favour the last gate - confirm.
    random_gate_probs[..., -1] = 10

    args.token_pooler_fn = token_pooler_fn
    args.output_unpooler_fn = output_unpooler_fn
    args.gate_probs = random_gate_probs

    # Wrap the inputs as Parameters so that backward() populates their gradients.
    _q = torch.nn.Parameter(q)
    _k = torch.nn.Parameter(k)
    _v = torch.nn.Parameter(v)
    output, metadata = attn_backend(
        q=_q,
        k=_k,
        v=_v,
        args=args,
    )

    # Wait for the GPU so that kernel errors surface before "[PASS]" is printed.
    torch.cuda.synchronize()

    print("[PASS] fwd smoketest")

    output.sum().backward()

    torch.cuda.synchronize()

    print("[PASS] bwd smoketest")

    def attn(_q, _k, _v, _cos, _sin, kwargs=None, kernel="flash"):
        """Run one attention forward pass with the selected kernel.

        Args:
            _q, _k, _v: Query / key / value tensors, forwarded unchanged to the
                kernel. The caller in this script passes ``(1, H, T, D)`` for
                ``"flash"`` and ``(1, T, H, D)`` for the hip kernels.
            _cos, _sin: RoPE tables forwarded as ``rope_cos`` / ``rope_sin`` to
                the hip kernels; unused by ``"flash"``.
            kwargs: Cached dictionary returned by an earlier call for the same
                kernel. When ``None`` it is computed first by calling the hip
                backend with ``HIP_RETURN_KWARGS_ONLY=1``.
            kernel: ``"flash"``, ``"hip"`` or ``"hip-tune"``.

        Returns:
            ``(out, kwargs)``. ``kwargs`` is ``None`` for ``"flash"`` and the
            (possibly freshly computed) cache dictionary otherwise.

        Raises:
            NotImplementedError: If ``kernel`` is not one of the known names.
            AssertionError: If an input or a cached tensor contains NaN
                (hip kernels only).
        """

        def _assert_no_nans(cached):
            # Fail fast when the inputs or any cached tensor already hold NaNs.
            assert not torch.any(torch.isnan(_q))
            assert not torch.any(torch.isnan(_k))
            assert not torch.any(torch.isnan(_v))
            for value in cached.values():
                if isinstance(value, torch.Tensor):
                    assert not torch.any(torch.isnan(value))

        def _hip_args(**overrides):
            # Settings shared by the "hip" and "hip-tune" kernels.
            return HiPAttentionArgs(
                sliding_window_size=1024,
                sink_token_size=256,
                using_extend=True,
                need_apply_rope=True,
                rope_cos=_cos,
                rope_sin=_sin,
                second_stage_k=1024,
                stages=stages,
                model_context_length=65536,
                scan_extend_backend="relative",
                sa_extend_backend="streaming",
                block_sparse_block_size_q=stages[-1].stage_block_size_q,
                **overrides,
            )

        if kernel == "flash":
            # The scale used to be read from a name that this function never
            # defined, so it resolved to an unrelated variable of the enclosing
            # function (a progress-bar string) or raised NameError. Use the
            # conventional 1 / sqrt(head_dim) softmax scale instead.
            # NOTE(review): confirm this matches the default scale of the hip
            # kernels if the kernels are ever compared against each other.
            sm_scale = float(_q.shape[-1]) ** -0.5
            out = triton_flash(_q, _k, _v, True, sm_scale)
            return out, None

        if kernel == "hip":
            if kwargs is None:
                # NOTE(review): the flag is left at "1" afterwards, as before;
                # only the hip-tune path resets it.
                os.environ["HIP_RETURN_KWARGS_ONLY"] = "1"
                kwargs = hip_attention(
                    q=_q,
                    k=_k,
                    v=_v,
                    args=_hip_args(enable_hip_tune=False),
                )

            _assert_no_nans(kwargs)

            # The cached values are passed positionally, in dictionary order.
            out, _ = block_sparse_attention(_q, _k, _v, *kwargs.values())
            return out, kwargs

        if kernel == "hip-tune":
            hip_static_args = _hip_args(
                enable_hip_tune=True,
                token_pooler_fn=token_pooler_fn,
                output_unpooler_fn=output_unpooler_fn,
                gate_probs=random_gate_probs,
            )
            if kwargs is None:
                os.environ["HIP_RETURN_KWARGS_ONLY"] = "1"
                kwargs = hip_tune_attention(q=_q, k=_k, v=_v, args=hip_static_args)

            _assert_no_nans(kwargs)

            # Switch the backend back to computing outputs, in testing mode,
            # and reuse the cached kwargs.
            os.environ["HIP_RETURN_KWARGS_ONLY"] = "0"
            os.environ["HIP_TESTING"] = "1"
            out, _ = hip_tune_attention(_q, _k, _v, hip_static_args, kwargs=kwargs)
            return out, kwargs

        raise NotImplementedError(f"unknown kernel: {kernel=}")

    # Imported here so that transformers is only required for the gradient check.
    import transformers
    from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

    # Replace the random cos/sin tables with Llama rotary tables for positions
    # [0, seq_len).
    # NOTE(review): unlike the random tables above, these are not repeated
    # seq_dups times (the repeat inside the loop is commented out) - confirm this
    # is intended when DUPS > 1.
    config = transformers.AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B")
    rope = LlamaRotaryEmbedding(config, device="cuda:0")
    cos, sin = rope(q, torch.arange(0, seq_len, device="cuda:0")[None, :])
    cos = cos.squeeze(0).to(dtype)
    sin = sin.squeeze(0).to(dtype)

    # Gradient check: for each kernel compare u^T J v obtained from autograd with
    # a central finite-difference estimate. The true_grad=False pass is a control
    # run in which the autograd gradient is replaced by a wrong direction of the
    # same per-row norm.
    for kernel in ["hip-tune", "hip", "flash"]:
        for true_grad in [False, True]:
            # Squared errors per trial, one list per perturbed input.
            q_diffs, k_diffs, v_diffs = [], [], []
            with tqdm(range(128)) as pbar:
                for i in pbar:
                    q = torch.randn(H, S, D, device="cuda:0", dtype=dtype)
                    k = torch.randn(H, S, D, device="cuda:0", dtype=dtype)
                    v = torch.randn(H, S, D, device="cuda:0", dtype=dtype)

                    # Finite-difference step size (vec below has unit norm).
                    EPS = 1e0

                    # (H, S, D) -> (1, H, S * dups, D)
                    q = q.repeat(1, query_seq_dups, 1).unsqueeze(0).contiguous()
                    k = k.repeat(1, seq_dups, 1).unsqueeze(0).contiguous()
                    v = v.repeat(1, seq_dups, 1).unsqueeze(0).contiguous()
                    if "hip" in kernel:
                        # The hip kernels are fed (1, S * dups, H, D) instead.
                        q = q.permute(0, 2, 1, 3).contiguous()
                        k = k.permute(0, 2, 1, 3).contiguous()
                        v = v.permute(0, 2, 1, 3).contiguous()

                    # if cos is not None:
                    #     cos = cos.repeat(seq_dups, 1)
                    #     sin = sin.repeat(seq_dups, 1)

                    # Fresh Parameters every trial, so .grad starts out empty.
                    q = torch.nn.Parameter(q)
                    k = torch.nn.Parameter(k)
                    v = torch.nn.Parameter(v)

                    with torch.no_grad():
                        # Random 0/1 masks over sequence positions; each position
                        # is kept when its uniform sample exceeds 0.80.
                        # NOTE(review): the masks have length S, so the views
                        # below appear to assume DUPS == Q_DUPS == 1 - confirm.
                        u_mask = (
                            torch.rand(S, device=q.data.device, dtype=q.data.dtype)
                            > 0.80
                        ).float()
                        vec_mask = (
                            torch.rand(S, device=q.data.device, dtype=q.data.dtype)
                            > 0.80
                        ).float()

                        # Broadcast the masks along the sequence axis, which is
                        # axis 2 for flash and axis 1 for the hip layouts.
                        if kernel == "flash":
                            u = torch.ones_like(q.data) * u_mask.view(1, 1, -1, 1)
                            vec = torch.ones_like(q.data) * vec_mask.view(1, 1, -1, 1)
                        elif "hip" in kernel:
                            u = torch.ones_like(q.data) * u_mask.view(1, -1, 1, 1)
                            vec = torch.ones_like(q.data) * vec_mask.view(1, -1, 1, 1)

                        # Scale both directions to unit norm.
                        u = u / torch.norm(u)
                        vec = vec / torch.norm(vec)

                        # Only u is checked here; vec is not.
                        if torch.any(torch.isnan(u)):
                            raise Exception(f"nans found after norm {u}")
                        if torch.any(torch.isinf(u)):
                            raise Exception(f"infs found after norm {u}")

                    # Forward pass; the returned kwargs are reused for the
                    # perturbed evaluations below.
                    out, kwargs = attn(q, k, v, cos, sin, kernel=kernel)

                    assert not torch.any(torch.isnan(out))

                    # Compute the gradient of g with respect to x using autograd:
                    # backpropagating sum(out * u) leaves u^T J in each .grad.
                    (out * u).sum().backward()

                    for tns, name, diff in zip(
                        (q, k, v), ("q", "k", "v"), (q_diffs, k_diffs, v_diffs)
                    ):
                        uJ = tns.grad
                        assert not torch.any(torch.isnan(uJ))

                        # if we are not testing the true grad, make a random gradient with the same norm
                        # as the true grad
                        if not true_grad:
                            # uJ_temp = torch.randn_like(uJ)
                            if name == "v":
                                # this is because d AV / dv = uA = sum(A, dim=0) and u is all ones. This means
                                # that the wrong gradient will be close if we use all ones
                                uJ_temp = torch.randn_like(uJ)
                            else:
                                uJ_temp = torch.ones_like(uJ)

                            # Keep the norm of every last-axis row, replace its direction.
                            uJ = torch.norm(uJ, dim=-1, keepdim=True) * (
                                uJ_temp / torch.norm(uJ_temp, dim=-1, keepdim=True)
                            )

                        # Autograd side of the comparison: <u^T J, vec>.
                        with torch.no_grad():
                            uJv = (uJ * vec).sum()

                        # Compute the finite difference approximation:
                        with torch.no_grad():
                            # Detached copies, so the perturbed runs do not touch
                            # the Parameters or the shared cos/sin tables.
                            tns_ = tns.data.clone()
                            q_ = q.data.clone()
                            k_ = k.data.clone()
                            v_ = v.data.clone()
                            cos_ = cos.clone()
                            sin_ = sin.clone()

                            x_plus = (tns_ + EPS * vec).to(dtype)
                            x_minus = (tns_ - EPS * vec).to(dtype)

                            # Evaluate the kernel at x +/- EPS * vec, perturbing
                            # only the input that is currently being checked.
                            if name == "q":
                                g_plus, _ = attn(
                                    x_plus,
                                    k_,
                                    v_,
                                    cos_,
                                    sin_,
                                    kwargs=kwargs,
                                    kernel=kernel,
                                )
                                g_minus, _ = attn(
                                    x_minus,
                                    k_,
                                    v_,
                                    cos_,
                                    sin_,
                                    kwargs=kwargs,
                                    kernel=kernel,
                                )
                            elif name == "k":
                                g_plus, _ = attn(
                                    q_,
                                    x_plus,
                                    v_,
                                    cos_,
                                    sin_,
                                    kwargs=kwargs,
                                    kernel=kernel,
                                )
                                g_minus, _ = attn(
                                    q_,
                                    x_minus,
                                    v_,
                                    cos_,
                                    sin_,
                                    kwargs=kwargs,
                                    kernel=kernel,
                                )
                            elif name == "v":
                                g_plus, _ = attn(
                                    q_,
                                    k_,
                                    x_plus,
                                    cos_,
                                    sin_,
                                    kwargs=kwargs,
                                    kernel=kernel,
                                )
                                g_minus, _ = attn(
                                    q_,
                                    k_,
                                    x_minus,
                                    cos_,
                                    sin_,
                                    kwargs=kwargs,
                                    kernel=kernel,
                                )

                            # NOTE: d is not used afterwards.
                            d = g_plus - g_minus
                            # Central difference: J v ~= (g(x + EPS v) - g(x - EPS v)) / (2 EPS)
                            Jv_taylor = (g_plus - g_minus) / (2 * EPS)
                            uJv_taylor = (u * Jv_taylor).sum()

                        # Squared error between the autograd and finite-difference values.
                        diff.append((uJv.item() - uJv_taylor.item()) ** 2)

                        # During the first trial the k/v lists are still empty while
                        # q is processed, so the running means are only shown from
                        # the second trial on (avoids a division by zero).
                        if i > 0:
                            s = (
                                f"q mean: {sum(q_diffs) / len(q_diffs):.12f} "
                                + f"k mean: {sum(k_diffs) / len(k_diffs):.12f} "
                                + f"v mean: {sum(v_diffs) / len(v_diffs):.12f} "
                                + f"{len(q_diffs)=}"
                            )
                            pbar.set_description(s)

            # Final mean squared error over all trials of this (kernel, true_grad) pair.
            print(f"q mse {kernel=} {true_grad=}: {sum(q_diffs) / len(q_diffs):.12f}")
            print(f"k mse {kernel=} {true_grad=}: {sum(k_diffs) / len(k_diffs):.12f}")
            print(f"v mse {kernel=} {true_grad=}: {sum(v_diffs) / len(v_diffs):.12f}")


# Run the debug routine only when the file is executed as a script.
if __name__ == "__main__":
    main_debug()
