from yunchang.ring import (
    ring_flash_attn_func,
    ring_flash_attn_qkvpacked_func,
    ring_flash_attn_func_skipkv,
    zigzag_ring_flash_attn_func,
    zigzag_ring_flash_attn_qkvpacked_func,
    stripe_flash_attn_func,
    stripe_flash_attn_qkvpacked_func,
    ring_pytorch_attn_func,
    ring_flashinfer_attn_func,
    ring_flashinfer_attn_qkvpacked_func,
    dist_flash_attn_func,
    zigzag_ring_flash_attn_func_skip_kv,
    upipe_ring_flash_attn_func,
    # fully_pipelined_long_context_attn_func,
    fully_fused_attn_func,
    staggered_ring_flash_attn_func,
    micro_fused_attn_func,
)

# Registry of ring-attention implementations, keyed by the short name used to
# select one. Every value is a function imported from ``yunchang.ring`` at the
# top of this module. Key spelling and insertion order are preserved exactly;
# these strings are presumably matched against a user-facing option elsewhere
# (TODO confirm against callers), so do not rename them.
#
# NOTE: the key for ``stripe_flash_attn_func`` is spelled "strip" (not
# "stripe"); keep it that way so existing lookups keep working.
# NOTE: a "fullpipe" entry backed by ``fully_pipelined_long_context_attn_func``
# is currently disabled, matching its commented-out import.
RING_IMPL_DICT = dict(
    basic=ring_flash_attn_func,
    basic_skipkv=ring_flash_attn_func_skipkv,
    zigzag=zigzag_ring_flash_attn_func,
    strip=stripe_flash_attn_func,
    basic_pytorch=ring_pytorch_attn_func,
    basic_flashinfer=ring_flashinfer_attn_func,
    dist_flash_attn=dist_flash_attn_func,
    zigzag_skipkv=zigzag_ring_flash_attn_func_skip_kv,
    upipe=upipe_ring_flash_attn_func,
    fullfused=fully_fused_attn_func,
    staggered=staggered_ring_flash_attn_func,
    micro_fused=micro_fused_attn_func,
)

# Registry of the "qkvpacked" ring-attention variants imported from
# ``yunchang.ring``. Only these four names have a packed variant registered
# here; any other name is absent from this table. The packed functions
# presumably take a single stacked qkv input instead of separate q/k/v
# (inferred from their names; confirm against yunchang.ring).
# Key spelling ("strip", not "stripe") and order are preserved for lookups.
RING_IMPL_QKVPACKED_DICT = dict(
    basic=ring_flash_attn_qkvpacked_func,
    zigzag=zigzag_ring_flash_attn_qkvpacked_func,
    strip=stripe_flash_attn_qkvpacked_func,
    basic_flashinfer=ring_flashinfer_attn_qkvpacked_func,
)
