import os
from typing import Literal

import cv2
import numpy as np
import torch

from hip_attn.v1_3.attention import HiPAttentionArgs, ScanStage, hip_attention
from hip_attn.v1_3.attention_tune import hip_tune_attention
from hip_research.utils.load_checkouts import load_checkouts
from hip_research.utils.seed import seed

# Opt-in debug switch: exporting TRITON_DEBUG=1 turns on autograd anomaly
# detection for the whole process (the second flag is torch's NaN check).
TRITON_DEBUG = "1" == os.environ.get("TRITON_DEBUG", "0")
if TRITON_DEBUG:
    torch.set_anomaly_enabled(True, True)


def main_debug():

    seed()

    preset_name = os.getenv("PRESET", "default")

    attn_backend = {
        "hip": hip_attention,
        "hip-tune": hip_tune_attention,
    }[os.getenv("BACKEND", "hip")]

    if preset_name == "minimal":
        seq_len = int(os.getenv("SEQ_LEN", "1024"))
    else:
        seq_len = int(os.getenv("SEQ_LEN", "32768"))
    query_seq_dups = int(os.getenv("Q_DUPS", "-1"))
    seq_dups = int(os.getenv("DUPS", "1"))
    if query_seq_dups < 0:
        query_seq_dups = seq_dups

    assert seq_dups > 0

    q, k, v, out, cos, sin = load_checkouts(
        idx=0,
        window=32,
        seq_len=seq_len,
        return_cos_sin=True,
        derope=True,
        dtype=torch.bfloat16,
        sm_scale=1.0,
    )
    seq_len = seq_len * seq_dups

    q = q.repeat(1, query_seq_dups, 1).permute(1, 0, 2).contiguous().unsqueeze(0)
    k = k.repeat(1, seq_dups, 1).permute(1, 0, 2).contiguous().unsqueeze(0)
    v = v.repeat(1, seq_dups, 1).permute(1, 0, 2).contiguous().unsqueeze(0)
    if cos is not None:
        cos = cos.repeat(seq_dups, 1)
        sin = sin.repeat(seq_dups, 1)

    print(q.shape, k.shape, v.shape)

    from flash_attn import flash_attn_func

    torch.cuda.synchronize()
    _q = torch.nn.Parameter(q)
    _k = torch.nn.Parameter(k)
    _v = torch.nn.Parameter(v)
    output = flash_attn_func(_q, _k, _v, causal=True)
    output.sum().backward()
    torch.cuda.synchronize()

    if preset_name == "default":
        stages = [
            ScanStage(
                stage_block_size_q=64,
                stage_block_stride_q=4,
                stage_chunk_size=128,
                stage_k=None,
                stage_stride=1,
            ),
            ScanStage(
                stage_block_size_q=64,
                stage_block_stride_q=4,
                stage_chunk_size=32,
                stage_k=min(seq_len, 16384),
                stage_stride=1,
            ),
            ScanStage(
                stage_block_size_q=64,
                stage_block_stride_q=1,
                stage_chunk_size=8,
                stage_k=min(seq_len, 4096),
                stage_stride=1,
            ),
        ]

        args = HiPAttentionArgs(
            sliding_window_size=1024,
            sink_token_size=256,
            using_extend=True,
            need_apply_rope=True,
            rope_cos=cos,
            rope_sin=sin,
            second_stage_k=1024,
            stages=stages,
            model_context_length=65536,
            scan_extend_backend="relative",
            sa_extend_backend="streaming",
            block_sparse_block_size_q=stages[-1].stage_block_size_q,
            enable_hip_tune=attn_backend == hip_tune_attention,
            block_sparse_bwd_block_size_q=64,
        )
    elif preset_name == "minimal":
        stages = [
            ScanStage(
                stage_block_size_q=32,
                stage_block_stride_q=2,
                stage_chunk_size=4,
                stage_k=None,
                stage_stride=1,
            ),
            ScanStage(
                stage_block_size_q=32,
                stage_block_stride_q=2,
                stage_chunk_size=2,
                stage_k=64,
                stage_stride=1,
            ),
            ScanStage(
                stage_block_size_q=32,
                stage_block_stride_q=1,
                stage_chunk_size=1,
                stage_k=32,
                stage_stride=1,
            ),
        ]

        args = HiPAttentionArgs(
            sliding_window_size=32,
            sink_token_size=32,
            using_extend=True,
            need_apply_rope=True,
            rope_cos=cos,
            rope_sin=sin,
            second_stage_k=16,
            stages=stages,
            model_context_length=64,
            scan_extend_backend="relative",
            sa_extend_backend="streaming",
            enable_hip_tune=False,
            block_sparse_block_size_q=32,
        )
    else:
        raise Exception()

    class SimpleTokenPooler(torch.nn.Module):
        def __init__(self, num_stages: int):
            super().__init__()
            self.num_stages = num_stages
            self.poolers = torch.nn.ParameterDict(
                {
                    "query": torch.nn.ParameterList(
                        [
                            torch.nn.Parameter(torch.randn((32, 128, 128)) * 0.01)
                            for i in range(num_stages)
                        ]
                    ),
                    "key": torch.nn.ParameterList(
                        [
                            torch.nn.Parameter(torch.randn((8, 128, 128)) * 0.01)
                            for i in range(num_stages)
                        ]
                    ),
                    # 'value': torch.nn.ParameterList([
                    #     torch.nn.Parameter(torch.randn((8, 128, 128)) * 0.01)
                    #     for i in range(num_stages)
                    # ])
                }
            )

        def forward(
            self,
            x: torch.Tensor,
            dim: int,
            i_stage: int,
            tensor_type: Literal["query", "key", "value"],
        ):
            assert tensor_type in ["query", "key", "value"]

            # if tensor_type != 'value':
            #     w = self.poolers[tensor_type][i_stage]
            #     xs = []
            #     for i in range(x.shape[-2]):
            #         xs.append(torch.nn.functional.linear(x[..., i, :], w[i].to(x.dtype)))
            #     x = torch.stack(xs, dim=-2) + x

            x_pooled = x.mean(dim=dim)

            return x_pooled

    class SimpleOutputUnpooler(torch.nn.Module):
        def __init__(self, dim=2):
            super().__init__()
            self.pooling_dim = dim

        def forward(self, x: torch.Tensor, rate: int):
            return x.repeat_interleave(rate, dim=self.pooling_dim)

    token_pooler_fn = SimpleTokenPooler(num_stages=len(stages)).to(q.device)
    output_unpooler_fn = SimpleOutputUnpooler().to(q.device)
    random_gate_probs = (
        torch.randn(
            # NOTE: [BSZ, N_Q, HEAD, N_GATES]
            (
                q.shape[0],
                q.shape[1],
                q.shape[2],
                len(stages) + 1,
            ),
            device=q.device,
            dtype=q.dtype,
        )
        - 1.0
    )
    random_gate_probs[..., -1] = 10

    args.token_pooler_fn = token_pooler_fn
    args.output_unpooler_fn = output_unpooler_fn
    args.gate_probs = random_gate_probs.sigmoid().to(torch.bfloat16)

    _q = torch.nn.Parameter(q.float())
    _k = torch.nn.Parameter(k.float())
    _v = torch.nn.Parameter(v.float())
    output, _ = attn_backend(
        q=_q.to(torch.bfloat16),
        k=_k.to(torch.bfloat16),
        v=_v.to(torch.bfloat16),
        args=args,
    )

    assert not torch.any(torch.isnan(output))

    torch.cuda.synchronize()

    print("[PASS] fwd")

    output[:, :].mean().backward()

    assert not torch.any(torch.isnan(_q.grad))
    assert not torch.any(torch.isnan(_k.grad))
    assert not torch.any(torch.isnan(_v.grad))

    torch.cuda.synchronize()

    print("[PASS] bwd")

    # NOTE: make easy solution for solver. Find values between 100-200
    # v = v.abs()

    # start_idx = v.shape[1] - 1000 # to test sw bwd
    # start_idx = 0 # to test sink token bwd
    start_idx = 8192  # to test middle bwd
    v_truth = torch.repeat_interleave(
        v[:, start_idx : start_idx + 1, :, :].clone(), 4, dim=2
    )
    v[:, start_idx : start_idx + 1024, :, :] = v[
        :, start_idx : start_idx + 1, :, :
    ].clone()

    q = torch.nn.Parameter(q.float())
    k = torch.nn.Parameter(k.float())
    v = torch.nn.Parameter(v.float())
    random_gate_probs = torch.nn.Parameter(random_gate_probs)
    # params = [q, k, v, random_gate_probs] # NOTE: this will be so easy, because value will be zero
    params = [
        q,
        k,
        # v,
        random_gate_probs,
        *token_pooler_fn.parameters(),
    ]

    import bitsandbytes as bnb

    lr = 1e-2
    optimizer = bnb.optim.Adam32bit(params, lr=lr)

    for istep in range(1000):
        random_gate_probs_tensor = random_gate_probs.sigmoid().to(
            torch.bfloat16
        )  # torch.softmax(random_gate_probs, dim=-1)
        q_tensor = q.to(torch.bfloat16)
        k_tensor = k.to(torch.bfloat16)
        v_tensor = v.to(torch.bfloat16)

        _args = args.clone()
        _args.gate_probs = random_gate_probs_tensor

        # output = flash_attn_func(q_tensor, k_tensor, v_tensor, causal=True)
        output, _ = attn_backend(
            q=q_tensor,
            k=k_tensor,
            v=v_tensor,
            args=_args,
        )

        assert not torch.any(torch.isnan(output))

        loss = (output[:, -512:] - v_truth).square().sum(dim=-1).mean()
        # loss = torch.nn.functional.cross_entropy(logits, label)
        (loss * 16000).backward()

        # print(q.grad.isnan().nonzero())

        # print(q.grad.view(-1).norm().item(), k.grad.view(-1).norm().item(), v.grad.view(-1).norm().item())
        # print(random_gate_probs[0, 0, 0])
        # print(random_gate_probs.grad[0, 0, 0])

        # torch.nn.utils.clip_grad_norm_(params, 2.0)

        # q.data.add_(-q.grad * lr)
        # k.data.add_(-k.grad * lr)
        # v.data.add_(-v.grad * lr)   # NOTE: comment this update, to prevent too easy optimization problem.
        # random_gate_probs.data.add(-random_gate_probs.grad * lr)
        # q.grad.zero_()
        # k.grad.zero_()
        # v.grad.zero_()
        # random_gate_probs.grad.zero_()

        assert not torch.any(torch.isnan(q.grad))
        assert not torch.any(torch.isnan(k.grad))
        assert not torch.any(torch.isnan(v.grad))

        optimizer.step()
        optimizer.zero_grad()

        if (istep % 10) == 0:
            print(istep, loss.item())


if __name__ == "__main__":
    # Run the debug harness only when executed as a script, never on import.
    main_debug()
