import unittest
import torch

from quick_extend.triton_impl.mask_gen import mask_gen_triton
from quick_extend.triton_impl.sparse_attn import sparse_attn_triton
from quick_extend.ref_impl.sparse_attn import sparse_attn_ref, sparse_attn_flash_ref

from test_hip import print_errors


def test_causality(call_type):
    """Check that attention outputs do not depend on later key positions.

    Random ``q``/``k``/``v`` tensors are created on the CUDA device, a sparse
    mask is generated once from the unmodified ``q`` and ``k``, and the
    selected attention implementation is run twice: once as-is, and once
    after zeroing every key from position ``i`` onward. The outputs for
    query positions before ``i`` must match between the two runs.

    Only the keys are perturbed; ``v`` and the mask are identical in both
    runs.

    Args:
        call_type: Which implementation to exercise: ``"triton"``, ``"ref"``
            or ``"ref_flash"``.

    Raises:
        ValueError: If ``call_type`` is not one of the supported names.
        AssertionError: If the outputs before position ``i`` differ.
    """
    supported_call_types = ("triton", "ref", "ref_flash")
    # Fail fast with a clear message; without this check an unknown name
    # would only surface later as an UnboundLocalError inside hip_fwd.
    if call_type not in supported_call_types:
        raise ValueError(
            f"unknown call_type {call_type!r}; expected one of {supported_call_types}"
        )

    BATCH_SIZE = 1
    NUM_HEADS = 1
    QUERY_LEN = 1024
    KEY_VALUE_LEN = 1024
    HEAD_DIM = 128
    VALUE_DIM = 128
    BLOCK_SIZE_Q = 16
    BLOCK_SIZE_K = 2
    TOP_K_ELEMS = 64
    QUERY_OFFSET = 0
    START_SINK_TOKENS = 32
    END_SINK_TOKENS = 32
    # The mask generator is given a count of key blocks, not of key elements.
    TOP_K = TOP_K_ELEMS // BLOCK_SIZE_K

    torch.random.manual_seed(0)
    dev = torch.device("cuda")
    dtype = torch.float16
    q = torch.randn(BATCH_SIZE, NUM_HEADS, QUERY_LEN, HEAD_DIM, dtype=dtype, device=dev)
    k = torch.randn(BATCH_SIZE, NUM_HEADS, KEY_VALUE_LEN, HEAD_DIM, dtype=dtype, device=dev)
    v = torch.randn(BATCH_SIZE, NUM_HEADS, KEY_VALUE_LEN, VALUE_DIM, dtype=dtype, device=dev)

    # The mask is built once from the original keys and shared by both runs,
    # so any difference in the outputs comes from the attention call itself.
    m_, _ = mask_gen_triton(q, k, TOP_K, BLOCK_SIZE_Q, BLOCK_SIZE_K, QUERY_OFFSET,
                            START_SINK_TOKENS, END_SINK_TOKENS)

    def hip_fwd(q, k, v):
        """Run the implementation selected by ``call_type`` and return its output."""
        if call_type == "triton":
            r_, _, _ = sparse_attn_triton(q, k, m_, v, BLOCK_SIZE_Q, BLOCK_SIZE_K, QUERY_OFFSET,
                                          START_SINK_TOKENS, END_SINK_TOKENS)
        elif call_type == "ref":
            r_ = sparse_attn_ref(q, k, m_, v, BLOCK_SIZE_Q, BLOCK_SIZE_K, QUERY_OFFSET,
                                 START_SINK_TOKENS, END_SINK_TOKENS)
        else:  # "ref_flash" -- call_type was validated above
            # NOTE(review): the two literal 32s are extra arguments specific
            # to the flash reference; their meaning is not visible here --
            # confirm against sparse_attn_flash_ref.
            r_, _ = sparse_attn_flash_ref(q, k, m_, v, BLOCK_SIZE_Q, BLOCK_SIZE_K, 32, 32, QUERY_OFFSET,
                                          START_SINK_TOKENS, END_SINK_TOKENS)
        return r_

    # Cut-off position; 201 is not a multiple of either block size.
    i = 201
    r_1 = hip_fwd(q, k, v)
    # Wipe every key at or after the cut-off (in place; k is local to this test).
    k[:, :, i:] = 0.0
    r_2 = hip_fwd(q, k, v)

    # print_errors is only evaluated when the comparison fails.
    assert torch.allclose(r_1[:, :, :i], r_2[:, :, :i]), print_errors(r_1[:, :, :i], r_2[:, :, :i])


# This helper takes a required argument and is driven by TestHipCausality.
# pytest would otherwise collect it by its ``test_`` prefix and error out
# trying to resolve ``call_type`` as a fixture; unittest is unaffected.
test_causality.__test__ = False


# Every case allocates its tensors on the CUDA device, so report a skip
# rather than an error on hosts without one.
@unittest.skipUnless(torch.cuda.is_available(), "CUDA device required")
class TestHipCausality(unittest.TestCase):
    """Run the causality check once per sparse-attention implementation."""

    def test_causality_triton(self):
        """Causality check for the Triton implementation."""
        test_causality("triton")

    def test_causality_ref(self):
        """Causality check for the reference implementation."""
        test_causality("ref")

    def test_causality_ref_flash(self):
        """Causality check for the flash-style reference implementation."""
        test_causality("ref_flash")


# Allow running this file directly as a script; unittest.main() discovers and
# runs the TestCase classes defined in this module.
if __name__ == "__main__":
    unittest.main()
