import torch
from flash_attn import flash_attn_func, flash_attn_with_kvcache
import dp_ext 

"""
Flash Attention
"""
@torch.library.custom_op("plugin::flash_attn", mutates_args=())
def flash_attn_forward(q: torch.Tensor,k: torch.Tensor,v:torch.Tensor, softmax_scale:float) -> torch.Tensor:
    return flash_attn_func(q,k,v,causal=True,softmax_scale=softmax_scale)

@torch.library.custom_op("plugin::flash_attn_with_cache", mutates_args=())
def flash_attn_with_cache_forward(q: torch.Tensor, k_cache: torch.Tensor, v_cache:torch.Tensor, cache_seqlens: torch.Tensor, softmax_scale:float) -> torch.Tensor:
    return flash_attn_with_kvcache(q, k_cache, v_cache, causal=True, cache_seqlens=cache_seqlens, softmax_scale=softmax_scale)

@flash_attn_forward.register_fake
def _(q, k, v, softmax_scale):
    return q.new_empty(q.shape)

@flash_attn_with_cache_forward.register_fake
def _(q, k, v, cache_seqlens, softmax_scale):
    return q.new_empty(q.shape)

"""
Sqllm
"""
@torch.library.custom_op("plugin::sqllm_gemv", mutates_args={"output"})
def sqllm_gemv(x: torch.Tensor, q_weight: torch.Tensor, lut: torch.Tensor, output:torch.Tensor, bitwidth:int) -> None:
    dp_ext.sqllm_gemv(x, q_weight, lut, output, bitwidth)

@sqllm_gemv.register_fake
def _(x, q_weight, lut, output, bitwidth):
    return None

@torch.library.custom_op("plugin::dec_sqllm", mutates_args={"output"})
def dec_sqllm(dec_config: int, x: torch.Tensor, q_weight: torch.Tensor, lut: torch.Tensor, output:torch.Tensor, bitwidth:int) -> None:
    dp_ext.dec_sqllm(dec_config, x, q_weight, lut, output, bitwidth)

@dec_sqllm.register_fake
def _(dec_config, x, q_weight, lut, output, bitwidth):
    return None

"""
Any-Precision
"""
@torch.library.custom_op("plugin::anyprec_gemv_sel_fake", mutates_args={"output"})
def anyprec_gemv_sel_fake(x: torch.Tensor, q_weight: torch.Tensor, lut: torch.Tensor, output:torch.Tensor, bitwidth:int, bsel:torch.Tensor) -> None:
    dp_ext.anyprec_gemv_sel_fake(x, output, q_weight, lut, bitwidth, bsel)

@anyprec_gemv_sel_fake.register_fake
def _(x, q_weight, lut, output, bitwidth, bsel):
    return None

@torch.library.custom_op("plugin::gemvNormTH", mutates_args={"bsel"})
def gemvNormTH(x: torch.Tensor, jl:torch.Tensor, bsel:torch.Tensor, low:torch.Tensor, high:torch.Tensor, threshold:float, sne:int) -> None:
    dp_ext.gemvNormTH(x, jl, bsel, low, high, threshold, sne)

@gemvNormTH.register_fake
def _(x, jl, bsel, low, high, threshold, sne):
    return None


"""
qkvgu
"""

@torch.library.custom_op("plugin::gemvNormTHq", mutates_args={"bsel"})
def gemvNormTHq(x: torch.Tensor, jl:torch.Tensor, bsel:torch.Tensor, low:torch.Tensor, high:torch.Tensor, threshold:float, sne:int) -> None:
    dp_ext.gemvNormTHq(x, jl, bsel, low, high, threshold, sne)

@gemvNormTHq.register_fake
def _(x, jl, bsel, low, high, threshold, sne):
    return None

@torch.library.custom_op("plugin::gemvNormTHk", mutates_args={"bsel"})
def gemvNormTHk(x: torch.Tensor, jl:torch.Tensor, bsel:torch.Tensor, low:torch.Tensor, high:torch.Tensor, threshold:float, sne:int) -> None:
    dp_ext.gemvNormTHk(x, jl, bsel, low, high, threshold, sne)

@gemvNormTHk.register_fake
def _(x, jl, bsel, low, high, threshold, sne):
    return None

@torch.library.custom_op("plugin::gemvNormTHv", mutates_args={"bsel"})
def gemvNormTHv(x: torch.Tensor, jl:torch.Tensor, bsel:torch.Tensor, low:torch.Tensor, high:torch.Tensor, threshold:float, sne:int) -> None:
    dp_ext.gemvNormTHv(x, jl, bsel, low, high, threshold, sne)

@gemvNormTHv.register_fake
def _(x, jl, bsel, low, high, threshold, sne):
    return None

@torch.library.custom_op("plugin::gemvNormTHg", mutates_args={"bsel"})
def gemvNormTHg(x: torch.Tensor, jl:torch.Tensor, bsel:torch.Tensor, low:torch.Tensor, high:torch.Tensor, threshold:float, sne:int) -> None:
    dp_ext.gemvNormTHg(x, jl, bsel, low, high, threshold, sne)

@gemvNormTHg.register_fake
def _(x, jl, bsel, low, high, threshold, sne):
    return None

@torch.library.custom_op("plugin::gemvNormTHu", mutates_args={"bsel"})
def gemvNormTHu(x: torch.Tensor, jl:torch.Tensor, bsel:torch.Tensor, low:torch.Tensor, high:torch.Tensor, threshold:float, sne:int) -> None:
    dp_ext.gemvNormTHu(x, jl, bsel, low, high, threshold, sne)

@gemvNormTHu.register_fake
def _(x, jl, bsel, low, high, threshold, sne):
    return None

@torch.library.custom_op("plugin::gemvNormTHqkv", mutates_args={"bsel"})
def gemvNormTHqkv(x: torch.Tensor, jl:torch.Tensor, bsel:torch.Tensor, low:torch.Tensor, high:torch.Tensor, threshold:float, sne:int) -> None:
    dp_ext.gemvNormTHqkv(x, jl, bsel, low, high, threshold, sne)

@gemvNormTHqkv.register_fake
def _(x, jl, bsel, low, high, threshold, sne):
    return None

@torch.library.custom_op("plugin::gemvNormTHgu", mutates_args={"bsel"})
def gemvNormTHgu(x: torch.Tensor, jl:torch.Tensor, bsel:torch.Tensor, low:torch.Tensor, high:torch.Tensor, threshold:float, sne:int) -> None:
    dp_ext.gemvNormTHgu(x, jl, bsel, low, high, threshold, sne)

@gemvNormTHgu.register_fake
def _(x, jl, bsel, low, high, threshold, sne):
    return None

@torch.library.custom_op("plugin::normTHq", mutates_args={"bsel"})
def normTHq(x: torch.Tensor, a:float, b:float, threshold:float, bsel:torch.Tensor, low:torch.Tensor, high:torch.Tensor, sne:int) -> None:
    dp_ext.normTHq(x, a, b, threshold, bsel, low, high, sne)

@normTHq.register_fake
def _(x, a, b, threshold, bsel, low, high, sne):
    return None

@torch.library.custom_op("plugin::normTHk", mutates_args={"bsel"})
def normTHk(x: torch.Tensor, a:float, b:float, threshold:float, bsel:torch.Tensor, low:torch.Tensor, high:torch.Tensor, sne:int) -> None:
    dp_ext.normTHk(x, a, b, threshold, bsel, low, high, sne)

@normTHk.register_fake
def _(x, a, b, threshold, bsel, low, high, sne):
    return None


@torch.library.custom_op("plugin::normTHv", mutates_args={"bsel"})
def normTHv(x: torch.Tensor, a:float, b:float, threshold:float, bsel:torch.Tensor, low:torch.Tensor, high:torch.Tensor, sne:int) -> None:
    dp_ext.normTHv(x, a, b, threshold, bsel, low, high, sne)

@normTHv.register_fake
def _(x, a, b, threshold, bsel, low, high, sne):
    return None


@torch.library.custom_op("plugin::normTHg", mutates_args={"bsel"})
def normTHg(x: torch.Tensor, a:float, b:float, threshold:float, bsel:torch.Tensor, low:torch.Tensor, high:torch.Tensor, sne:int) -> None:
    dp_ext.normTHg(x, a, b, threshold, bsel, low, high, sne)

@normTHg.register_fake
def _(x, a, b, threshold, bsel, low, high, sne):
    return None


@torch.library.custom_op("plugin::normTHu", mutates_args={"bsel"})
def normTHu(x: torch.Tensor, a:float, b:float, threshold:float, bsel:torch.Tensor, low:torch.Tensor, high:torch.Tensor, sne:int) -> None:
    dp_ext.normTHu(x, a, b, threshold, bsel, low, high, sne)

@normTHu.register_fake
def _(x, a, b, threshold, bsel, low, high, sne):
    return None

@torch.library.custom_op("plugin::normTHqkv", mutates_args={"bsel"})
def normTHqkv(x: torch.Tensor, a:float, b:float, threshold:float, bsel:torch.Tensor, low:torch.Tensor, high:torch.Tensor, sne:int) -> None:
    dp_ext.normTHqkv(x, a, b, threshold, bsel, low, high, sne)

@normTHqkv.register_fake
def _(x, a, b, threshold, bsel, low, high, sne):
    return None

@torch.library.custom_op("plugin::normTHgu", mutates_args={"bsel"})
def normTHgu(x: torch.Tensor, a:float, b:float, threshold:float, bsel:torch.Tensor, low:torch.Tensor, high:torch.Tensor, sne:int) -> None:
    dp_ext.normTHgu(x, a, b, threshold, bsel, low, high, sne)

@normTHgu.register_fake
def _(x, a, b, threshold, bsel, low, high, sne):
    return None






@torch.library.custom_op("plugin::gemvNormTH2", mutates_args={"bsel", "bsel2"})
def gemvNormTH2(x: torch.Tensor, jl:torch.Tensor, jl2:torch.Tensor, bsel:torch.Tensor, bsel2:torch.Tensor, threshold:float, threshold2:float, sne:int) -> None:
    dp_ext.gemvNormTH2(x, jl, jl2, bsel, bsel2, threshold, threshold2, sne)

@gemvNormTH2.register_fake
def _(x, jl, jl2, bsel, bsel2, threshold, threshold2, sne):
    return None

@torch.library.custom_op("plugin::gemvNormTH3", mutates_args={"bsel", "bsel2", "bsel3"})
def gemvNormTH3(x: torch.Tensor, jl:torch.Tensor, jl2:torch.Tensor, jl3:torch.Tensor, 
                bsel:torch.Tensor, bsel2:torch.Tensor, bsel3:torch.Tensor, 
                threshold:float, threshold2:float, threshold3:float, sne:int) -> None:
    dp_ext.gemvNormTH3(x, jl, jl2, jl3, bsel, bsel2, bsel3, threshold, threshold2, threshold3, sne)

@gemvNormTH3.register_fake
def _(x, jl, jl2, jl3, bsel, bsel2, bsel3, threshold, threshold2, threshold3, sne):
    return None

@torch.library.custom_op("plugin::gemvNormTH3Full", mutates_args={"bsel", "bsel2", "bsel3", "res", "res2", "res3"})
def gemvNormTH3Full(x: torch.Tensor, jl:torch.Tensor, jl2:torch.Tensor, jl3:torch.Tensor, 
                    res:torch.Tensor, res2:torch.Tensor, res3:torch.Tensor, 
                bsel:torch.Tensor, bsel2:torch.Tensor, bsel3:torch.Tensor, 
                threshold:float, threshold2:float, threshold3:float, sne:int) -> None:
    dp_ext.gemvNormTH3Full(x, jl, jl2, jl3, res, res2, res3, bsel, bsel2, bsel3, threshold, threshold2, threshold3, sne)

@gemvNormTH3Full.register_fake
def _(x, jl, jl2, jl3, res, res2, res3, bsel, bsel2, bsel3, threshold, threshold2, threshold3, sne):
    return None

@torch.library.custom_op("plugin::normTH", mutates_args={"bsel"})
def normTH(x: torch.Tensor, a:float, b:float, threshold:float, bsel:torch.Tensor, sne:int) -> None:
    dp_ext.normTH(x, a, b, threshold, bsel, sne)

@normTH.register_fake
def _(x, a, b, threshold, bsel, sne):
    return None

@torch.library.custom_op("plugin::normTH2", mutates_args={"bsel"})
def normTH2(x: torch.Tensor, a:float, a2:float, b:float, b2:float, threshold:float, threshold2:float, 
            bsel:torch.Tensor, bsel2:torch.Tensor, sne:int) -> None:
    dp_ext.normTH2(x, a, a2, b, b2, threshold, threshold2, bsel, bsel2, sne)

@normTH2.register_fake
def _(x, a, a2, b, b2, threshold, threshold2, bsel, bsel2, sne):
    return None

@torch.library.custom_op("plugin::lnNormTH2", mutates_args={"res", "bsel"})
def lnNormTH2(x: torch.Tensor, normW: torch.Tensor, res: torch.Tensor, a:float, a2:float, b:float, b2:float, threshold:float, threshold2:float, 
            bsel:torch.Tensor, bsel2:torch.Tensor, sne:int) -> None:
    dp_ext.lnNormTH2(x, normW, res, a, a2, b, b2, threshold, threshold2, bsel, bsel2, sne)

@lnNormTH2.register_fake
def _(x, normW, res, a, a2, b, b2, threshold, threshold2, bsel, bsel2, sne):
    return None

@torch.library.custom_op("plugin::lnGemvNormTH", mutates_args={"res", "bsel"})
def lnGemvNormTH(x: torch.Tensor, normW: torch.Tensor, res: torch.Tensor, jl:torch.Tensor, bsel:torch.Tensor, threshold:float, sne:int) -> None:
    dp_ext.lnGemvNormTH(x, normW, res, jl, bsel, threshold, sne)

@lnGemvNormTH.register_fake
def _(x, normW, res, jl, bsel, threshold, sne):
    return None

@torch.library.custom_op("plugin::fakeTrigger", mutates_args={})
def fakeTrigger(sne:int) -> None:
    dp_ext.fakeTrigger(sne)

@fakeTrigger.register_fake
def _(sne):
    return None



# @torch.library.custom_op("plugin::create_streamNevent", mutates_args={})
# def create_streamNevent() -> int:
#     dec_context = dp_ext.create_streamNevent()
#     return dec_context

# @create_streamNevent.register_fake
# def _():
#     return None

@torch.library.custom_op("plugin::create_streamNevent_full", mutates_args={})
def create_streamNevent_full() -> int:
    dec_context = dp_ext.create_streamNevent_full()
    return dec_context

@create_streamNevent_full.register_fake
def _():
    return None

# @torch.library.custom_op("plugin::create_streamNevent_intra", mutates_args={})
# def create_streamNevent_intra() -> int:
#     dec_context = dp_ext.create_streamNevent_intra()
#     return dec_context

# @create_streamNevent_intra.register_fake
# def _():
#     return None



@torch.library.custom_op("plugin::anyprec_gemv_sel_two", mutates_args={"output"})
def anyprec_gemv_sel_two(x: torch.Tensor, q_weight: torch.Tensor, lut: torch.Tensor, output:torch.Tensor, bitwidth:int, 
                         jl:torch.Tensor, res:torch.Tensor) -> None:
    dp_ext.anyprec_gemv_sel_two(x, output, q_weight, lut, bitwidth, jl, res)

@anyprec_gemv_sel_two.register_fake
def _(x, q_weight, lut, output, bitwidth, jl, res):
    return None

@torch.library.custom_op("plugin::anyprec_gemv_sel", mutates_args={"output"})
def anyprec_gemv_sel(x: torch.Tensor, q_weight: torch.Tensor, lut3: torch.Tensor, lut4: torch.Tensor, lut5: torch.Tensor, lut6: torch.Tensor, output:torch.Tensor, bitwidth:int, bsel:torch.Tensor, sne:int) -> None:
    dp_ext.anyprec_gemv_sel(x, output, q_weight, lut3, lut4, lut5, lut6, bitwidth, bsel, sne)

@anyprec_gemv_sel.register_fake
def _(x, q_weight, lut3, lut4, lut5, lut6, output, bitwidth, bsel, sne):
    return None

@torch.library.custom_op("plugin::anyprec_gemv", mutates_args={"output"})
def anyprec_gemv(x: torch.Tensor, q_weight: torch.Tensor, lut: torch.Tensor, output:torch.Tensor, bitwidth:int) -> None:
    dp_ext.anyprec_gemv(x, output, q_weight, lut, bitwidth)

@anyprec_gemv.register_fake
def _(x, q_weight, lut, output, bitwidth):
    return None

@torch.library.custom_op("plugin::dec_anyprec", mutates_args={"output"})
def dec_anyprec(dec_config: int, x: torch.Tensor, q_weight: torch.Tensor, lut: torch.Tensor, output:torch.Tensor, bitwidth:int) -> None:
    dp_ext.dec_anyprec(dec_config, x, output, q_weight, lut, bitwidth)

@dec_anyprec.register_fake
def _(dec_config, x, q_weight, lut, output, bitwidth):
    return None

"""
LUTGEMM
"""
@torch.library.custom_op("plugin::lutgemm_gemv", mutates_args={"output"})
def lutgemm_gemv(x: torch.Tensor, q_weight: torch.Tensor, alpha: torch.Tensor, q_bias: torch.Tensor, output: torch.Tensor, bitwidth: int, group_size: int) -> None:
    dp_ext.lutgemm_gemv(x, q_weight, alpha, q_bias, output, bitwidth, group_size)

@lutgemm_gemv.register_fake
def _(x, q_weight, alpha, q_bias, output, bitwidth, group_size):
    return None

@torch.library.custom_op("plugin::dec_lutgemm", mutates_args={"output"})
def dec_lutgemm(dec_config: int, x: torch.Tensor, q_weight: torch.Tensor, alpha: torch.Tensor, q_bias: torch.Tensor, output: torch.Tensor, bitwidth: int, group_size: int) -> None:
    dp_ext.dec_lutgemm(dec_config, x, q_weight, alpha, q_bias, output, bitwidth, group_size)

@dec_lutgemm.register_fake
def _(dec_config, x, q_weight, alpha, q_bias, output, bitwidth, group_size):
    return None


"""
DECConfig, DECContext
"""

@torch.library.custom_op("plugin::create_dec_context", mutates_args={})
def create_dec_context(n_tb: int, index_buffer: torch.Tensor, activation_buffer: torch.Tensor) -> int:
    dec_context = dp_ext.create_dec_context(n_tb, index_buffer, activation_buffer)
    return dec_context

@create_dec_context.register_fake
def _(n_tb, index_buffer, activation_buffer):
    return None

@torch.library.custom_op("plugin::create_dec_config", mutates_args={})
def create_dec_config(dec_context: int, k_chunk: int, q_residual: torch.Tensor, scales: torch.Tensor, thresholds: torch.Tensor) -> int:
    dec_config = dp_ext.create_dec_config(dec_context, k_chunk, q_residual, scales, thresholds)
    return dec_config

@create_dec_config.register_fake
def _(dec_context, k_chunk, q_residual, scales, thresholds):
    return None

@torch.library.custom_op("plugin::update_dec_config", mutates_args={})
def update_dec_config(dec_config: int, dec_context: int, k_chunk: int) -> None:
    dp_ext.update_dec_config(dec_config, dec_context, k_chunk)

@update_dec_config.register_fake
def _(dec_config, dec_context, k_chunk):
    return None
