import torch
import torch.nn as nn
from flashinfer import get_batch_indices_positions, get_seq_lens, append_paged_kv_cache

# Declare the schema of the custom operator ``mylib::update_kv``.
# ``Tensor(a!) kv_cache`` marks ``kv_cache`` as mutated in place, and ``-> ()``
# declares that the operator produces no outputs.
# NOTE(review): the seventh argument is spelled ``cachelen`` in this schema,
# while the Python implementations registered below name it
# ``kv_page_last_len``; the schema also gives ``page_size`` no default even
# though those implementations default it to 128 — confirm which is intended.
torch.library.define(
    "mylib::update_kv",
    "(Tensor k, Tensor v, Tensor kv_append_indptr, Tensor(a!) kv_cache, Tensor kv_page_indices, Tensor kv_page_indptr, Tensor cachelen, int page_size) -> ()",
)

@torch.library.impl("mylib::update_kv", "cuda")
def update_kv(k, v, kv_append_indptr, kv_cache, kv_page_indices, kv_page_indptr, kv_page_last_len, page_size=128):
    """CUDA implementation of ``mylib::update_kv``.

    Appends ``k`` and ``v`` into ``kv_cache`` through FlashInfer's
    ``append_paged_kv_cache``; ``kv_cache`` is modified in place and nothing
    is returned.

    Args:
        k: Tensor forwarded to ``append_paged_kv_cache`` as the keys to append.
        v: Tensor forwarded to ``append_paged_kv_cache`` as the values to append.
        kv_append_indptr: Index-pointer tensor for the appended entries; its
            last element is used as the total entry count.
        kv_cache: Paged cache tensor that receives the appended entries.
        kv_page_indices: Page index tensor, passed through to FlashInfer.
        kv_page_indptr: Page index-pointer tensor, passed through to FlashInfer.
        kv_page_last_len: Passed through to FlashInfer; presumably the number
            of used slots in each sequence's last page — TODO confirm.
        page_size: Page size handed to ``get_seq_lens``.
    """
    # ``.item()`` converts the last indptr entry to a Python number.
    # NOTE(review): on a CUDA tensor this presumably forces a device-to-host
    # sync — confirm that is acceptable on this path.
    total_appended = kv_append_indptr[-1].item()
    seq_lens = get_seq_lens(kv_page_indptr, kv_page_last_len, page_size)
    entry_batch_ids, entry_positions = get_batch_indices_positions(
        kv_append_indptr, seq_lens, total_appended
    )
    append_paged_kv_cache(
        k,
        v,
        entry_batch_ids,
        entry_positions,
        kv_cache,
        kv_page_indices,
        kv_page_indptr,
        kv_page_last_len,
    )

@torch.library.register_fake("mylib::update_kv")
def update_kv_abstract(k, v, kv_append_indptr, kv_cache, kv_page_indices, kv_page_indptr, kv_page_last_len, page_size=128):
    """Fake implementation of ``mylib::update_kv``.

    Performs no work and returns ``None``: all arguments are accepted and
    ignored, so there is no output to describe.
    """

class StandardKVCache(nn.Module):
    """Module that owns a paged KV-cache buffer and appends new entries to it.

    The buffer ``kv_cache`` is zero-initialised with shape
    ``(max_num_pages, 2, page_size, n_heads, head_dim)``; the axis of size 2
    presumably separates keys from values — TODO confirm.
    """

    def __init__(self, max_num_pages, page_size, n_heads, head_dim, dtype=torch.bfloat16):
        """Allocate the zero-filled cache buffer and remember the page size.

        Args:
            max_num_pages: Size of the first buffer dimension.
            page_size: Size of the third buffer dimension; also stored on the
                instance and forwarded by ``update``.
            n_heads: Size of the fourth buffer dimension.
            head_dim: Size of the last buffer dimension.
            dtype: Element type of the buffer.
        """
        super().__init__()
        self.page_size = page_size
        storage = torch.zeros(max_num_pages, 2, page_size, n_heads, head_dim, dtype=dtype)
        # Registered as a buffer so it belongs to the module without being a parameter.
        self.register_buffer("kv_cache", storage)

    def update(self, k, v, kv_append_indptr, kv_page_indices, kv_page_indptr, kv_page_lastlen):
        """Append ``k``/``v`` via ``torch.ops.mylib.update_kv`` and return the cache.

        All arguments are forwarded unchanged to the custom operator together
        with this module's ``kv_cache`` buffer and ``page_size``.

        Returns:
            The module's ``kv_cache`` buffer.
        """
        cache = self.kv_cache
        torch.ops.mylib.update_kv(
            k,
            v,
            kv_append_indptr,
            cache,
            kv_page_indices,
            kv_page_indptr,
            kv_page_lastlen,
            self.page_size,
        )
        return cache