import unittest

import torch

from test_hip import print_errors
from quick_extend.ref_impl.full_forward import reference_attn_flash
from quick_extend.ref_impl.full_backward import reference_attn_flash_bwd
from quick_extend.triton_impl.full_forward import attn_flash_triton
from quick_extend.triton_impl.full_backward import attn_flash_bwd_triton
from quick_extend.triton_impl.mixed_fwd_bwd import mixed_attn, HipAttention
from quick_extend.triton_impl.full_fwd_bwd import flash_attn
from quick_extend.ref_impl.ste_backward import ste_sparse_attn_ref


@unittest.skipUnless(torch.cuda.is_available(), "CUDA device required")
class TestFullAttn(unittest.TestCase):
    """Checks of the reference and Triton attention implementations.

    Both tests allocate tensors on / synchronize with a CUDA device, so the
    whole class is skipped (rather than erroring) when CUDA is unavailable.

    NOTE(review): neither test asserts on numerical accuracy; results are
    handed to ``print_errors`` (imported from ``test_hip``) or discarded.
    Whether ``print_errors`` raises on mismatch is not visible from here.
    """

    def test_full_attn(self):
        """Compare forward outputs and input gradients against PyTorch SDPA.

        The baseline is ``torch.nn.functional.scaled_dot_product_attention``
        with a boolean mask that lets query ``i`` attend to key ``j`` when
        ``j <= QUERY_OFFSET + i``. ``reference_attn_flash`` and ``flash_attn``
        (and, when ``do_mixed`` is enabled, ``mixed_attn``) are compared to
        that baseline, for both the outputs and the gradients w.r.t. q/k/v.
        """
        BATCH_SIZE = 1
        NUM_HEADS = 1
        # Lengths that are not multiples of 16 (the value passed to the
        # reference implementation below).
        QUERY_LEN = 331
        KEY_VALUE_LEN = 343
        HEAD_DIM = 16
        VALUE_DIM = 32
        # Position of the first query relative to the key sequence: the
        # queries are aligned with the last QUERY_LEN keys.
        QUERY_OFFSET = KEY_VALUE_LEN - QUERY_LEN

        # Flip to True to additionally exercise ``mixed_attn``.
        do_mixed = False

        torch.random.manual_seed(0)
        dev = torch.device("cuda")
        dtype = torch.float16
        q = torch.randn(BATCH_SIZE, NUM_HEADS, QUERY_LEN, HEAD_DIM, dtype=dtype, device=dev)
        k = torch.randn(BATCH_SIZE, NUM_HEADS, KEY_VALUE_LEN, HEAD_DIM, dtype=dtype, device=dev)
        v = torch.randn(BATCH_SIZE, NUM_HEADS, KEY_VALUE_LEN, VALUE_DIM, dtype=dtype, device=dev)

        q.requires_grad = True
        k.requires_grad = True
        v.requires_grad = True

        # Boolean (QUERY_LEN, KEY_VALUE_LEN) mask; True marks the positions
        # that take part in attention.
        attn_mask = (
            torch.arange(KEY_VALUE_LEN, device=dev)[None, :]
            <= QUERY_OFFSET + torch.arange(QUERY_LEN, device=dev)[:, None]
        )
        ref_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        # NOTE(review): the two 16s are presumably block sizes, and ``L`` the
        # auxiliary tensor the backward pass needs -- confirm against
        # ``reference_attn_flash``.
        my_output, L = reference_attn_flash(q, k, v, 16, 16, query_offset=QUERY_OFFSET)

        print("Errors: ")
        print_errors(ref_output, my_output)

        triton_output = flash_attn(q, k, v, QUERY_OFFSET)
        print("Triton errors: ")
        print_errors(ref_output, triton_output)

        # Random upstream gradient shared by every backward pass below.
        noise = torch.randn_like(ref_output, dtype=dtype, device=dev)
        ref_output.backward(noise)

        # Grab the baseline gradients, then clear ``.grad`` so the next
        # backward pass does not accumulate into them.
        ref_dq, ref_dk, ref_dv = q.grad, k.grad, v.grad
        q.grad = k.grad = v.grad = None

        my_dq, my_dk, my_dv = reference_attn_flash_bwd(q, k, v, ref_output, noise, L, 16, 16, QUERY_OFFSET)

        triton_output.backward(noise)
        tr_dq, tr_dk, tr_dv = q.grad, k.grad, v.grad
        q.grad = k.grad = v.grad = None

        if do_mixed:
            test_output = mixed_attn(q, k, v, 16, 2, 64, QUERY_OFFSET, 0, 32, 32)
            test_output.backward(noise)

            mix_dq, mix_dk, mix_dv = q.grad, k.grad, v.grad
            q.grad = k.grad = v.grad = None

        print("Grad errors: ")
        print_errors(ref_dq, my_dq)
        print_errors(ref_dk, my_dk)
        print_errors(ref_dv, my_dv)

        print("Grad errors (Triton): ")
        print_errors(ref_dq, tr_dq)
        print_errors(ref_dk, tr_dk)
        print_errors(ref_dv, tr_dv)

        if do_mixed:
            print("Attn errors (mixed): ")
            print_errors(ref_output, test_output)

            print("Grad errors (mixed): ")
            print_errors(ref_dq, mix_dq)
            print_errors(ref_dk, mix_dk)
            print_errors(ref_dv, mix_dv)

    def test_extra_tokens(self):
        """Run the Triton backward kernel on saved states (smoke test).

        Loads ``(q, k, v, o, do, L, ka, va)`` from ``../states.pth`` (resolved
        against the current working directory) and calls
        ``attn_flash_bwd_triton`` once. The returned gradients are not
        checked; the test only verifies that the call and the subsequent
        device synchronization complete without raising. The test is skipped
        when the states file is not present.
        """
        # Local import keeps this block self-contained.
        import os

        states_path = "../states.pth"
        if not os.path.exists(states_path):
            self.skipTest(f"saved states not found at {states_path!r}")

        states = torch.load(states_path)
        q, k, v, o, do, L, ka, va = states
        QUERY_OFFSET = 0

        # Disabled alternative: synthetic inputs instead of the saved states
        # (kept for reference).
        #BATCH_SIZE = 2
        #NUM_HEADS = 32
        #QUERY_LEN = 4096
        #KEY_VALUE_LEN = 4096
        #HEAD_DIM = 128
        #VALUE_DIM = 128
        #EXTRA_TOKENS = 32
        #QUERY_OFFSET = KEY_VALUE_LEN - QUERY_LEN
        #
        #torch.random.manual_seed(0)
        #dev = torch.device("cuda")
        #dtype = torch.float32
        #q = torch.randn(BATCH_SIZE, NUM_HEADS, QUERY_LEN, EXTRA_TOKENS, HEAD_DIM, dtype=dtype, device=dev)
        #k = torch.randn(BATCH_SIZE, NUM_HEADS, KEY_VALUE_LEN, HEAD_DIM, dtype=dtype, device=dev).to(torch.float8_e5m2)
        #v = torch.randn(BATCH_SIZE, NUM_HEADS, KEY_VALUE_LEN, VALUE_DIM, dtype=dtype, device=dev).to(torch.float8_e5m2)
        #ka = torch.randn(BATCH_SIZE, NUM_HEADS, QUERY_LEN, EXTRA_TOKENS, HEAD_DIM, dtype=dtype, device=dev)
        #va = torch.randn(BATCH_SIZE, NUM_HEADS, QUERY_LEN, EXTRA_TOKENS, VALUE_DIM, dtype=dtype, device=dev)

        q.requires_grad = True
        k.requires_grad = True
        v.requires_grad = True

        # Disabled alternative: go through the autograd wrapper instead of
        # calling the backward kernel directly.
        #triton_output = flash_attn(q, k, v, QUERY_OFFSET, ka, va)
        #triton_output.backward(torch.randn_like(triton_output))
        # Gradients are intentionally unused: this is a "does it run" check.
        (dq, dk, dv), _ = attn_flash_bwd_triton(q, k, v, o, do, L, QUERY_OFFSET, ka, va)

        # Wait for the launched GPU work to finish so that asynchronous
        # kernel failures surface inside this test.
        torch.cuda.synchronize()


# Run the test cases defined in this module when the file is executed as a
# script (no effect when the module is merely imported).
if __name__ == "__main__":
    unittest.main()
