import unittest
from pathlib import Path
import numpy as np

import torch
import torch.utils.checkpoint

from quick_extend.ref_impl.mask_gen import mask_gen_ref
from quick_extend.ref_impl.sparse_attn import sparse_attn_ref, sparse_attn_flash_ref
from quick_extend.ref_impl.sparse_attn_bwd import sparse_attn_bwd_ref

from quick_extend.triton_impl.mask_gen import mask_gen_triton
from quick_extend.triton_impl.sparse_attn import sparse_attn_triton
from quick_extend.triton_impl.sparse_attn_bwd import sparse_attn_bwd_triton
from quick_extend.triton_impl.mixed_fwd_bwd import mask_gen_orig


def print_errors(a, b):
    """Print and return a summary of the element-wise error between two tensors.

    Both tensors are converted with ``.float()`` before comparison. The
    relative error is normalised by ``|a|`` (plus ``1e-6`` to avoid division
    by zero), so ``a`` acts as the reference value.

    Tensors with fewer than 16,000,000 elements report the 99% quantile of
    the absolute and relative error; larger tensors report the mean instead.
    The maximum absolute and relative errors are always reported.

    Args:
        a: Reference tensor.
        b: Tensor compared against ``a``; must have the same shape.

    Returns:
        The formatted message that was printed.

    Raises:
        AssertionError: If ``a`` and ``b`` have different shapes.
    """
    assert a.shape == b.shape, f"Shapes do not match: {a.shape} != {b.shape}"
    a = a.float()
    b = b.float()

    # Compute each error tensor once and reuse it for every statistic instead
    # of rebuilding the same (potentially very large) temporaries three times.
    abs_err = torch.abs(a - b)
    rel_err = abs_err / (1e-6 + torch.abs(a))

    if a.numel() < 16_000_000:
        # NOTE(review): the size cut-off presumably exists because
        # torch.quantile rejects very large inputs -- confirm.
        quantile = 0.99
        percent = f"{quantile * 100:.0f}%"
        abs_q = torch.quantile(abs_err, quantile).item()
        rel_q = torch.quantile(rel_err, quantile).item()
        msg = f"abserr ({percent}): {abs_q:.6f}, relerr ({percent}): {rel_q:.6f}\n"
    else:
        abs_avg = torch.mean(abs_err).item()
        rel_avg = torch.mean(rel_err).item()
        msg = f"abserr (avg): {abs_avg:.6f}, relerr (avg): {rel_avg:.6f}\n"

    abs_max = torch.amax(abs_err).item()
    rel_max = torch.amax(rel_err).item()
    msg += f"abserr (max): {abs_max:.6f}, relerr (max): {rel_max:.6f}\n"
    print(msg)
    return msg


def _dump_tensor(path, tensor):
    """Write the NumPy string representation of ``tensor`` to ``path``."""
    with open(path, "w") as f:
        f.write(str(tensor.detach().cpu().numpy()))


def compare_attention_results(
        test_case,
        run_ref_kernel,
        run_kernel_funcs,
        q, k, v,
        print_output=False,
        test_gradients=True,
        n_repeat=3,
        atol=5e-2,
        rtol=1e-3,
):
    """Run, time and compare attention kernels against an optional reference.

    Args:
        test_case: ``unittest.TestCase`` used for sub-tests and assertions.
        run_ref_kernel: Optional callable invoked as ``run_ref_kernel(q, k, v)``
            returning a 5-tuple whose first two items are the reference mask
            and the reference output. ``None`` disables every comparison.
        run_kernel_funcs: Iterable of ``(name, fn)`` pairs. Each ``fn`` is
            called as ``fn(q, k, v, grad_noise=...)`` for the warm-up and as
            ``fn(q, k, v, mask, grad_noise=...)`` for the timed runs, and must
            return ``(mask, result, dq, dk, dv)``.
        q, k, v: Input tensors forwarded to the kernels.
        print_output: When true, dump results under ``./output`` and print
            the sorted mask.
        test_gradients: When true, enable gradients on ``q``/``k``/``v`` and
            compare the kernels' gradients with the autograd gradients of the
            reference output.
        n_repeat: Number of timed invocations per kernel.
        atol: Absolute tolerance passed to ``torch.allclose``.
        rtol: Relative tolerance passed to ``torch.allclose``.
    """
    if print_output:
        output_dir = Path("output")
        output_dir.mkdir(exist_ok=True)
        np.set_printoptions(threshold=np.inf, linewidth=200)

    if test_gradients:
        q.requires_grad = True
        k.requires_grad = True
        v.requires_grad = True

    m_ref = None
    grad_noise = None
    if run_ref_kernel is not None:
        print("Running reference")
        m_ref, r_ref, _, _, _ = run_ref_kernel(q, k, v)

        if print_output:
            _dump_tensor(output_dir / "ref.txt", r_ref)

        test_case.assertTrue(
            torch.all(m_ref >= 0).item(),
            "Reference mask contains negative entries",
        )

        if test_gradients:
            grad_noise = torch.randn_like(r_ref)
            r_ref.backward(grad_noise)
            # Keep the autograd gradients, then clear .grad so later backward
            # passes do not accumulate on top of them.
            dq_auto, dk_auto, dv_auto = q.grad, k.grad, v.grad
            q.grad, k.grad, v.grad = None, None, None

            if print_output:
                for grad_name, grad in (("dq", dq_auto), ("dk", dk_auto), ("dv", dv_auto)):
                    _dump_tensor(output_dir / f"auto_flash_{grad_name}.txt", grad)

    for method_name, run_new_kernel in run_kernel_funcs:
        print(f"Running {method_name}")
        # Warmup. No mask is passed here, so the mask returned by this call
        # (``m``) is the one compared against the reference mask below.
        m, r, dq, dk, dv = run_new_kernel(q, k, v, grad_noise=grad_noise)

        # Measure time. The timed runs receive the reference mask (when one
        # exists); their outputs are the ones compared with the reference.
        start_time = torch.cuda.Event(enable_timing=True)
        end_time = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        start_time.record()
        for _ in range(n_repeat):
            _, r, dq, dk, dv = run_new_kernel(q, k, v, m_ref, grad_noise=grad_noise)
        end_time.record()
        torch.cuda.synchronize()
        print(f"Memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB")
        print("Finished")

        batch_size = r.size(0)
        print(f'{method_name} Elapsed time: {start_time.elapsed_time(end_time) / n_repeat / batch_size:.2f}ms')

        print(f"{method_name} <-> ref mask")
        with test_case.subTest(msg=f"{method_name} <-> ref mask"):
            if m_ref is not None and not isinstance(m, tuple):
                test_case.assertEqual(m.shape, m_ref.shape, f"Shapes do not match: {m.shape} != {m_ref.shape}")
                top_k = m_ref.size(-1)
                np_m = torch.sort(m, dim=-1)[0].reshape(-1, top_k).data.cpu().numpy()
                np_m_ref = torch.sort(m_ref, dim=-1)[0].reshape(-1, top_k).data.cpu().numpy()
                # Per row, count the entries present in one mask but not in
                # the other (taking the larger of the two directions).
                diff_count = sum(
                    max(np.setdiff1d(row, row_ref).shape[0], np.setdiff1d(row_ref, row).shape[0])
                    for row, row_ref in zip(np_m, np_m_ref)
                )
                print(f"Masks {diff_count} ({diff_count / m.numel() * 100:.2f}%) Different")
                test_case.assertLess(diff_count / m.numel(), 0.03, "Masks differ too much")

        if print_output:
            # A tuple mask cannot be sorted; skip printing it, mirroring the
            # tuple guards used for the other mask checks in this function.
            if not isinstance(m, tuple):
                print("result:\n", torch.sort(m, dim=-1)[0].detach().cpu().numpy())
            _dump_tensor(output_dir / f"{method_name}.txt", r)

            if test_gradients:
                for grad_name, grad in (("dq", dq), ("dk", dk), ("dv", dv)):
                    _dump_tensor(output_dir / f"{method_name}_{grad_name}.txt", grad)

        if not isinstance(m, tuple):
            test_case.assertTrue(
                torch.all(m >= 0).item(),
                f"{method_name} mask contains negative entries",
            )

        # Print absolute and relative errors
        if run_ref_kernel is not None:
            print(f"{method_name} <-> ref")
            with test_case.subTest(msg=f"{method_name} <-> ref"):
                print_errors(r_ref, r)
                test_case.assertTrue(torch.allclose(r_ref, r, rtol=rtol, atol=atol))

            if test_gradients:
                grad_pairs = (("dq", dq, dq_auto), ("dk", dk, dk_auto), ("dv", dv, dv_auto))
                for grad_name, grad, grad_auto in grad_pairs:
                    label = f"{method_name} <-> auto {grad_name}"
                    print(label)
                    with test_case.subTest(msg=label):
                        print_errors(grad, grad_auto)
                        test_case.assertTrue(torch.allclose(grad, grad_auto, rtol=rtol, atol=atol))


def run_test(
        test_case,
        BATCH_SIZE=1,
        NUM_HEADS=32,
        QUERY_LEN=32768,
        KEY_VALUE_LEN=32768,
        HEAD_DIM=128,
        VALUE_DIM=128,
        BLOCK_SIZE_Q=16,
        BLOCK_SIZE_K=2,
        TOP_K_ELEMS=1024,
        QUERY_OFFSET=0,
        START_SINK_TOKENS=2,
        END_SINK_TOKENS=16,

        run_new_impl=True,
        run_ref_impl=False,
        run_softmask_impl=False,
        run_old_impl=False,
        print_output=False,
        test_gradients=True,
        dtype=torch.float16):
    """Create random CUDA inputs and compare the selected attention implementations.

    Random ``q``/``k``/``v`` tensors are generated with a fixed seed and
    handed, together with the selected implementations, to
    ``compare_attention_results``.

    Args:
        test_case: ``unittest.TestCase`` forwarded to ``compare_attention_results``.
        BATCH_SIZE, NUM_HEADS, QUERY_LEN, KEY_VALUE_LEN, HEAD_DIM, VALUE_DIM:
            Dimensions of the generated tensors: ``q`` is
            ``(BATCH_SIZE, NUM_HEADS, QUERY_LEN, HEAD_DIM)``, ``k`` is
            ``(BATCH_SIZE, NUM_HEADS, KEY_VALUE_LEN, HEAD_DIM)`` and ``v`` is
            ``(BATCH_SIZE, NUM_HEADS, KEY_VALUE_LEN, VALUE_DIM)``.
        BLOCK_SIZE_Q, BLOCK_SIZE_K, QUERY_OFFSET, START_SINK_TOKENS,
        END_SINK_TOKENS: Forwarded unchanged to the mask-generation and
            sparse-attention kernels.
        TOP_K_ELEMS: The kernels receive ``TOP_K_ELEMS // BLOCK_SIZE_K`` as
            their top-k argument.
        run_new_impl: Include the Triton implementation ("New Impl").
        run_ref_impl: Use ``ref_impl`` as the reference and also include
            "Ref Flash Impl".
        run_softmask_impl: Include "Ref Softmask Impl".
        run_old_impl: Include the ``hip`` package implementation ("Old Impl").
        print_output: Forwarded to ``compare_attention_results``.
        test_gradients: Also compute and compare gradients.
        dtype: dtype of the generated tensors.
    """
    TOP_K = TOP_K_ELEMS // BLOCK_SIZE_K

    torch.random.manual_seed(0)
    dev = torch.device("cuda")
    q = torch.randn(BATCH_SIZE, NUM_HEADS, QUERY_LEN, HEAD_DIM, dtype=dtype, device=dev)
    k = torch.randn(BATCH_SIZE, NUM_HEADS, KEY_VALUE_LEN, HEAD_DIM, dtype=dtype, device=dev)
    v = torch.randn(BATCH_SIZE, NUM_HEADS, KEY_VALUE_LEN, VALUE_DIM, dtype=dtype, device=dev)
    # All-true attention mask; only consumed by ``original_impl``.
    mask = torch.ones(BATCH_SIZE, KEY_VALUE_LEN, dtype=torch.bool, device=dev)

    if test_gradients:
        q.requires_grad = True
        k.requires_grad = True
        v.requires_grad = True

    def ref_impl(q, k, v, m=None, grad_noise=None):
        # Reference forward pass only; gradients are obtained by the caller
        # through autograd on the returned output.
        if m is None:
            m = mask_gen_ref(q, k, TOP_K, BLOCK_SIZE_Q, BLOCK_SIZE_K, QUERY_OFFSET,
                             START_SINK_TOKENS, END_SINK_TOKENS, soft_sort=False)
        r = sparse_attn_ref(q, k, m, v, BLOCK_SIZE_Q, BLOCK_SIZE_K, QUERY_OFFSET,
                            START_SINK_TOKENS, END_SINK_TOKENS)
        return m, r, None, None, None

    def ref_softmask_impl(q, k, v, m=None, grad_noise=None):
        # Always recompute mask
        m = torch.utils.checkpoint.checkpoint(
            mask_gen_ref,
            q, k, TOP_K, BLOCK_SIZE_Q, BLOCK_SIZE_K, QUERY_OFFSET,
            START_SINK_TOKENS, END_SINK_TOKENS, True,
            use_reentrant=True,
        )
        r, l = sparse_attn_flash_ref(q, k, m, v, BLOCK_SIZE_Q, BLOCK_SIZE_K, 32, 32, QUERY_OFFSET,
                                     START_SINK_TOKENS, END_SINK_TOKENS)
        dq, dk, dv = None, None, None
        if test_gradients:
            # Fall back to random upstream gradients when no reference run
            # supplied any, matching the Triton wrappers below.
            if grad_noise is None:
                grad_noise = torch.randn_like(r)
            r.backward(grad_noise)
            # Collect the gradients and clear .grad to avoid accumulation.
            dq, dk, dv = q.grad, k.grad, v.grad
            q.grad, k.grad, v.grad = None, None, None
        return m, r, dq, dk, dv

    def ref_flash_impl(q, k, v, m=None, grad_noise=None):
        if m is None:
            m = mask_gen_ref(q, k, TOP_K, BLOCK_SIZE_Q, BLOCK_SIZE_K, QUERY_OFFSET,
                             START_SINK_TOKENS, END_SINK_TOKENS, soft_sort=False)
        r, l = sparse_attn_flash_ref(q, k, m, v, BLOCK_SIZE_Q, BLOCK_SIZE_K, 32, 32, QUERY_OFFSET,
                                     START_SINK_TOKENS, END_SINK_TOKENS)
        dq_ref, dk_ref, dv_ref = None, None, None
        if test_gradients:
            if grad_noise is None:
                grad_noise = torch.randn_like(r)
            dq_ref, dk_ref, dv_ref = sparse_attn_bwd_ref(
                q, k, m, v, r, grad_noise, l, BLOCK_SIZE_Q, BLOCK_SIZE_K, 32, 32, QUERY_OFFSET,
                START_SINK_TOKENS, END_SINK_TOKENS)
        return m, r, dq_ref, dk_ref, dv_ref

    def original_impl(q, k, v, m=None, grad_noise=None):
        # ``grad_noise`` is accepted (and ignored) because
        # compare_attention_results always passes it as a keyword argument.
        # NOTE(review): the caller unpacks five return values -- confirm that
        # hip_attention_mask's return value matches.
        import hip  # imported lazily: only needed when run_old_impl is set
        return hip.hip_attention_mask(
            queries=q.reshape(BATCH_SIZE * NUM_HEADS, QUERY_LEN, HEAD_DIM),
            keys=k.reshape(BATCH_SIZE * NUM_HEADS, KEY_VALUE_LEN, HEAD_DIM),
            attention_mask=mask[:, None].expand(-1, NUM_HEADS, -1).reshape(-1, KEY_VALUE_LEN),
            kv_repeat_interleave=1,

            w_start=TOP_K_ELEMS * 2,
            n_patches=TOP_K_ELEMS // 2,
            mask_k=TOP_K_ELEMS,
            scale_up=2,
            is_causal=True,

            block_size_q=BLOCK_SIZE_Q,
            block_size_k=BLOCK_SIZE_K,
            reduce_method='max',
            reduce_stride=1,

            is_flash=False,
            enable_sparq=False,
            sampling_method='first',

            using_sliding_window=False,
            sliding_window_size=128,

            rope_method='none',
        )

    # NOTE: ``new_kernel`` is currently not registered in ``func_list``;
    # "New Impl" below maps to ``new_kernel_2``.
    def new_kernel(q, k, v, m_=None, grad_noise=None):
        if m_ is None:
            m_, _ = mask_gen_triton(q, k, TOP_K, BLOCK_SIZE_Q, BLOCK_SIZE_K, QUERY_OFFSET,
                                    START_SINK_TOKENS, END_SINK_TOKENS)
        r_, l_ = sparse_attn_triton(q, k, m_, v, BLOCK_SIZE_Q, BLOCK_SIZE_K, QUERY_OFFSET,
                                    START_SINK_TOKENS, END_SINK_TOKENS)
        dq, dk, dv = None, None, None
        if test_gradients:
            if grad_noise is None:
                grad_noise = torch.randn_like(r_)
            dq, dk, dv = sparse_attn_bwd_triton(
                q, k, m_, v, r_, grad_noise, l_, BLOCK_SIZE_Q, BLOCK_SIZE_K, QUERY_OFFSET,
                START_SINK_TOKENS, END_SINK_TOKENS)
        return m_, r_, dq, dk, dv

    def new_kernel_2(q, k, v, m_=None, grad_noise=None):
        # The incoming mask is ignored: the mask is always regenerated with
        # mask_gen_orig, and the attention kernels are called with a start
        # sink of 0 rather than START_SINK_TOKENS.
        m_ = mask_gen_orig(q, k, TOP_K, BLOCK_SIZE_Q, BLOCK_SIZE_K, QUERY_OFFSET,
                           START_SINK_TOKENS, END_SINK_TOKENS)
        r_, l_ = sparse_attn_triton(q, k, m_, v, BLOCK_SIZE_Q, BLOCK_SIZE_K, QUERY_OFFSET,
                                    0, END_SINK_TOKENS)
        dq, dk, dv = None, None, None
        if test_gradients:
            if grad_noise is None:
                grad_noise = torch.randn_like(r_)
            dq, dk, dv = sparse_attn_bwd_triton(
                q, k, m_, v, r_, grad_noise, l_, BLOCK_SIZE_Q, BLOCK_SIZE_K, QUERY_OFFSET,
                0, END_SINK_TOKENS)
        return m_, r_, dq, dk, dv

    func_list = []
    if run_ref_impl:
        func_list.append(("Ref Flash Impl", ref_flash_impl))
    if run_softmask_impl:
        func_list.append(("Ref Softmask Impl", ref_softmask_impl))
    if run_new_impl:
        func_list.append(("New Impl", new_kernel_2))
    if run_old_impl:
        func_list.append(("Old Impl", original_impl))
    compare_attention_results(
        test_case,
        ref_impl if run_ref_impl else None,
        func_list,
        q, k, v,
        print_output=print_output,
        test_gradients=test_gradients,
    )


@unittest.skipUnless(torch.cuda.is_available(), "CUDA device required")
class TestHiP(unittest.TestCase):
    """Runs ``run_test`` over a range of problem shapes and dtypes.

    Every case allocates its inputs on a CUDA device, so the whole class is
    skipped (instead of erroring) when CUDA is unavailable.
    """

    def test_hip_big_shape(self):
        """32k-token float16 run of the new implementation only, with gradients."""
        run_test(
            self,
            BATCH_SIZE=1,
            NUM_HEADS=32,
            QUERY_LEN=32768,
            KEY_VALUE_LEN=32768,
            HEAD_DIM=128,
            VALUE_DIM=128,
            BLOCK_SIZE_Q=16,
            BLOCK_SIZE_K=2,
            TOP_K_ELEMS=1024,
            QUERY_OFFSET=0,
            START_SINK_TOKENS=2,
            END_SINK_TOKENS=16,

            run_new_impl=True,
            run_ref_impl=False,
            run_old_impl=False,
            print_output=False,
            test_gradients=True,
            dtype=torch.float16,
        )

    def test_hip_middle_shape(self):
        """4k-token float16 run of the new and soft-mask implementations, no gradients."""
        run_test(
            self,
            BATCH_SIZE=2,
            NUM_HEADS=32,
            QUERY_LEN=4096,
            KEY_VALUE_LEN=4096,
            HEAD_DIM=128,
            VALUE_DIM=128,
            BLOCK_SIZE_Q=16,
            BLOCK_SIZE_K=2,
            TOP_K_ELEMS=512,
            QUERY_OFFSET=0,
            START_SINK_TOKENS=2,
            END_SINK_TOKENS=16,

            run_new_impl=True,
            run_softmask_impl=True,
            run_ref_impl=False,
            run_old_impl=False,
            print_output=False,
            test_gradients=False,
            dtype=torch.float16,
        )

    def test_hip_llama7b(self):
        """128k-token bfloat16 run with 32 heads and batch size 4."""
        run_test(
            self,
            BATCH_SIZE=4,
            NUM_HEADS=32,
            QUERY_LEN=32768 * 4,
            KEY_VALUE_LEN=32768 * 4,
            HEAD_DIM=128,
            VALUE_DIM=128,
            BLOCK_SIZE_Q=16,
            BLOCK_SIZE_K=2,
            TOP_K_ELEMS=512,
            QUERY_OFFSET=0,
            START_SINK_TOKENS=32,
            END_SINK_TOKENS=128,

            run_new_impl=True,
            run_ref_impl=False,
            run_old_impl=False,
            print_output=False,
            test_gradients=True,
            dtype=torch.bfloat16,
        )

    def test_hip_llama13b(self):
        """128k-token bfloat16 run with 40 heads and batch size 2."""
        run_test(
            self,
            BATCH_SIZE=2,
            NUM_HEADS=40,
            QUERY_LEN=32768 * 4,
            KEY_VALUE_LEN=32768 * 4,
            HEAD_DIM=128,
            VALUE_DIM=128,
            BLOCK_SIZE_Q=16,
            BLOCK_SIZE_K=2,
            TOP_K_ELEMS=512,
            QUERY_OFFSET=0,
            START_SINK_TOKENS=32,
            END_SINK_TOKENS=128,

            run_new_impl=True,
            run_ref_impl=False,
            run_old_impl=False,
            print_output=False,
            test_gradients=True,
            dtype=torch.bfloat16,
        )

    def test_hip_small_shape(self):
        """Small float32 comparison against the reference with a non-zero query offset."""
        run_test(
            self,
            BATCH_SIZE=2,
            NUM_HEADS=2,
            QUERY_LEN=895,
            KEY_VALUE_LEN=1024,
            HEAD_DIM=32,
            VALUE_DIM=32,
            BLOCK_SIZE_Q=16,
            BLOCK_SIZE_K=2,
            TOP_K_ELEMS=128,
            QUERY_OFFSET=1024 - 895,
            START_SINK_TOKENS=32,
            END_SINK_TOKENS=32,

            run_new_impl=True,
            run_ref_impl=True,
            run_old_impl=False,
            print_output=False,
            test_gradients=True,
            dtype=torch.float32,
        )

    def test_hip_softmask(self):
        """Small float32 comparison that also runs the soft-mask reference."""
        run_test(
            self,
            BATCH_SIZE=2,
            NUM_HEADS=2,
            QUERY_LEN=512,
            KEY_VALUE_LEN=512,
            HEAD_DIM=32,
            VALUE_DIM=32,
            BLOCK_SIZE_Q=16,
            BLOCK_SIZE_K=2,
            TOP_K_ELEMS=128,
            QUERY_OFFSET=512 - 512,
            START_SINK_TOKENS=32,
            END_SINK_TOKENS=32,

            run_new_impl=True,
            run_ref_impl=True,
            run_softmask_impl=True,
            run_old_impl=False,
            print_output=False,
            test_gradients=True,
            dtype=torch.float32,
        )

    def test_hip_small_shape_2(self):
        """Small float32 comparison using larger block sizes (32 / 4)."""
        run_test(
            self,
            BATCH_SIZE=2,
            NUM_HEADS=2,
            QUERY_LEN=1024,
            KEY_VALUE_LEN=1024,
            HEAD_DIM=32,
            VALUE_DIM=32,
            BLOCK_SIZE_Q=32,
            BLOCK_SIZE_K=4,
            TOP_K_ELEMS=128,
            QUERY_OFFSET=1024 - 1024,
            START_SINK_TOKENS=64,
            END_SINK_TOKENS=32,

            run_new_impl=True,
            run_ref_impl=True,
            run_old_impl=False,
            print_output=False,
            test_gradients=True,
            dtype=torch.float32,
        )

    def test_hip_small_shape_3(self):
        """Small float32 comparison with TOP_K_ELEMS=512 over 896 keys."""
        run_test(
            self,
            BATCH_SIZE=2,
            NUM_HEADS=2,
            QUERY_LEN=895,
            KEY_VALUE_LEN=896,
            HEAD_DIM=32,
            VALUE_DIM=32,
            BLOCK_SIZE_Q=16,
            BLOCK_SIZE_K=2,
            TOP_K_ELEMS=512,
            QUERY_OFFSET=896 - 895,
            START_SINK_TOKENS=32,
            END_SINK_TOKENS=32,

            run_new_impl=True,
            run_ref_impl=True,
            run_old_impl=False,
            print_output=False,
            test_gradients=True,
            dtype=torch.float32,
        )

    def test_hip_decode(self):
        """Single-query (QUERY_LEN=1) float32 comparison against the reference."""
        run_test(
            self,
            BATCH_SIZE=2,
            NUM_HEADS=2,
            QUERY_LEN=1,
            KEY_VALUE_LEN=896,
            HEAD_DIM=32,
            VALUE_DIM=32,
            BLOCK_SIZE_Q=16,
            BLOCK_SIZE_K=2,
            TOP_K_ELEMS=512,
            QUERY_OFFSET=896 - 1,
            START_SINK_TOKENS=32,
            END_SINK_TOKENS=32,

            run_new_impl=True,
            run_ref_impl=True,
            run_old_impl=False,
            print_output=False,
            test_gradients=True,
            dtype=torch.float32,
        )

    def test_hip_debug(self):
        """Tiny single-head float32 case with ``print_output=True`` (writes dump files)."""
        run_test(
            self,
            BATCH_SIZE=1,
            NUM_HEADS=1,
            QUERY_LEN=331,
            KEY_VALUE_LEN=343,
            HEAD_DIM=16,
            VALUE_DIM=32,
            BLOCK_SIZE_Q=16,
            BLOCK_SIZE_K=2,
            TOP_K_ELEMS=64,
            QUERY_OFFSET=343 - 331,
            START_SINK_TOKENS=2,
            END_SINK_TOKENS=16,

            run_new_impl=True,
            run_ref_impl=True,
            run_old_impl=False,
            print_output=True,
            test_gradients=True,
            dtype=torch.float32,
        )


# Run the unittest test runner when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
