import math
import os
import time
import cuda
import cuda.cudart
import torch
from torch import Tensor
from typing import Optional, Tuple, Union
import tqdm
import triton
import triton.language as tl

from hip.models.hip_attention.offload_runner.tensor_from_pointer import (
    tensor_from_pointer
)
from hip.models.hip_attention.gen3.attention_metadata import (
    HiPAttentionOutputMetadata,
    HiPAttentionCacheAccessStatistics,
)

MAX_INT: tl.constexpr =             tl.constexpr(90000000)
MAX_INT_ACQUIRED: tl.constexpr =    tl.constexpr(90000001)

def sizeof(dtype: Union[Tensor, torch.dtype]) -> int:
    if isinstance(dtype, Tensor):
        return dtype.numel() * sizeof(dtype.dtype)
    
    if dtype in [
        torch.uint8, 
        torch.int8, 
        torch.float8_e4m3fn, 
        torch.float8_e4m3fnuz, 
        torch.float8_e5m2, 
        torch.float8_e5m2fnuz
    ]:
        return 1
    elif dtype in [
        torch.uint16,
        torch.int16,
        torch.float16,
        torch.bfloat16,
    ]:
        return 2
    elif dtype in [
        torch.uint32,
        torch.int32,
        torch.float32,
    ]:
        return 4
    elif dtype in [
        torch.uint64,
        torch.int64,
        torch.float64,
    ]:
        return 8
    else:
        raise Exception()

def format_size_bytes(tensor: Union[Tensor, Union[float, int]]) -> str:
    if isinstance(tensor, Tensor):
        byte_size = sizeof(tensor)
    elif isinstance(tensor, (int, float)):
        byte_size = tensor
    else:
        raise Exception()

    if byte_size < 1024:
        return f'{byte_size} B'
    elif byte_size < (1024 ** 2):
        return f'{byte_size / 1024:.2f} KB'
    elif byte_size < (1024 ** 3):
        return f'{byte_size / (1024 ** 2):.2f} MB'
    else:
        return f'{byte_size / (1024 ** 3):.2f} GB'

def debug_print(*args):
    print(f'[HiPOffloadKVPoolMHA] {" ".join(map(lambda x: str(x), args))}')

###############################################################################
#                               Data Structure
###############################################################################

def uvm_note_cpu(tensor: Tensor, prefetch: bool = False):
    cuda.cudart.cudaMemAdvise(tensor.data_ptr(), tensor.numel() * tensor.element_size(), cuda.cudart.cudaMemoryAdvise.cudaMemAdviseSetPreferredLocation, -1)
    cuda.cudart.cudaMemAdvise(tensor.data_ptr(), tensor.numel() * tensor.element_size(), cuda.cudart.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy, tensor.device.index)
    if prefetch:
        cuda.cudart.cudaMemPrefetchAsync(tensor.data_ptr(), tensor.numel() * tensor.element_size(), -1, 0)

class UVMCache:
    bank_cpu: Tensor
    bank_gpu: Tensor
    metadata: Tensor
    
    def __init__(
        self, 
        layer_id: int,
        max_token_size: int,
        head_num: int,
        head_dim: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.layer_id = layer_id
        self.max_token_size = max_token_size
        self.head_num = head_num
        self.head_dim = head_dim
        self.dtype = dtype
        self.device = device
        if self.device.index is None:
            self.device = torch.get_default_device()
        
        self.bank_cpu, self.bank_gpu = self.alloc_uvm(
            [max_token_size, head_num, head_dim],
            dtype=self.dtype
        )
        
        # {
        #     Token Generation: uint32    # Increase one on every overwrite
        # }
        self.metadata = torch.full(
            [max_token_size, 1], 
            dtype=torch.int32, 
            device=device,
            fill_value=0
        )
        
        self.allocated_cpu_bytes = sizeof(self.bank_cpu)
        self.allocated_gpu_bytes = sizeof(self.metadata)
        
        # debug_print(f'UVMCache: bank={format_size_bytes(self.bank_cpu)}, metadata={format_size_bytes(self.metadata)}')
    
    def alloc_uvm(self, shape, dtype: torch.dtype) -> Tuple[Tensor, Tensor]:
        device = self.device
        if isinstance(device, str):
            device = torch.device(device)
            
        elem_size = sizeof(dtype)
        numel = math.prod(shape)
        align = 4096
        byte_size = elem_size * numel
        byte_size = byte_size + byte_size % align
        
        _result_code, pointer = cuda.cudart.cudaMallocManaged(
            byte_size, 
            cuda.cudart.cudaMemAttachGlobal
        )
        
        t_gpu = tensor_from_pointer(pointer, shape, dtype, device.index)
        t_cpu = tensor_from_pointer(pointer, shape, dtype, -1)
        
        uvm_note_cpu(t_gpu)
        t_cpu.fill_(0)
        
        return t_cpu, t_gpu
    
    def gather_cpu(self, table: Tensor, pin_memory = False) -> Tensor:
        assert table.ndim == 1
        assert table.device == self.bank_cpu.device
        
        # print('gather alloc', flush=True)
        # t = torch.empty(
        #     (table.shape[0], self.bank_cpu.shape[1], self.bank_cpu.shape[2]), 
        #     dtype=self.bank_cpu.dtype, 
        #     device='cpu'
        # )
        
        view_dtype = torch.uint16
        view_dtype_np = np.uint16
        if self.bank_cpu.dtype in [torch.float32]:
            view_dtype = torch.uint32
            view_dtype_np = np.uint32
        elif self.bank_cpu.dtype in [torch.float16, torch.bfloat16]:
            view_dtype = torch.uint16
            view_dtype_np = np.uint16
        elif self.bank_cpu.dtype in [torch.uint8, torch.float8_e5m2]:
            view_dtype = torch.uint8
            view_dtype_np = np.uint8
        else:
            raise Exception()
    
        t = np.empty(
            (table.shape[0], self.bank_cpu.shape[1], self.bank_cpu.shape[2]),
            dtype=view_dtype_np,
        )

        # print('gather pin', flush=True)
        # if pin_memory:
        #     t = t.pin_memory()
        
        # print('gather index_copy', flush=True)
        index_copy(
            self.bank_cpu.view(dtype=view_dtype).numpy(), 
            t, 
            table.numpy(),
            num_thread=os.cpu_count(),
        )
        # print('gather done', flush=True)
        
        t = torch.from_numpy(t).view(self.bank_cpu.dtype)
        # print('convert done', flush=True)
        
        return t

import numba
import numpy as np

@numba.njit(parallel=True)
def index_copy(src: np.ndarray, out: np.ndarray, table: np.ndarray, num_thread: int = 32):
    chunk_size = math.ceil(table.shape[0] / num_thread)
    for ithread in numba.prange(num_thread):
        for i in range(chunk_size):
            t = chunk_size * ithread + i
            if t < table.shape[0]:
                out[t] = src[table[t]]

def pad_to_cacheline(nelem: int, dtype: torch.dtype):
    byte_size = 4
    if dtype in [torch.int32, torch.uint32, torch.float32]:
        byte_size = 4
    elif dtype in [torch.int64, torch.uint64, torch.float64]:
        byte_size = 8
    elif dtype in [torch.int16, torch.uint16, torch.bfloat16, torch.float16]:
        byte_size = 2
    else:
        raise Exception()
    
    assert nelem > 0

    # in bytes
    cacheline_size = 32
    
    step = max(1, cacheline_size // byte_size)
    return nelem if (nelem % step) == 0 else (
        nelem + step - (nelem % step)
    )

class GPUCache:
    global_metadata: Tensor
    bank: Tensor
    metadata: Tensor
    table: Tensor

    def __init__(
        self, 
        k_uvm: UVMCache, 
        v_uvm: Optional[UVMCache],
        max_cache_token_size: int,
        online_cache_update: bool,
    ):
        self.k_uvm = k_uvm
        self.v_uvm = v_uvm
        self.head_num = self.k_uvm.head_num
        self.head_dim = self.k_uvm.head_dim
        self.dtype = self.k_uvm.dtype
        self.device = self.k_uvm.device
        self.kv_packed = self.v_uvm is not None
        if self.kv_packed:
            assert self.head_num == self.v_uvm.head_num
            self.head_dim += self.v_uvm.head_dim
        self.max_cache_token_size = max_cache_token_size
        self.max_uvm_token_size = self.k_uvm.max_token_size
        if self.kv_packed:
            assert self.max_uvm_token_size == self.v_uvm.max_token_size
        
        """
        [
            CachelinePadded { 
                current_tick: int32,
            }
            CachelinePadded {
                random_seed: int32,
            }
        ]
        """
        self.global_metadata = torch.zeros(
            (2, pad_to_cacheline(1, torch.int32)), 
            dtype=torch.int32, 
            device=self.device
        )
        
        self.bank = torch.zeros(
            (self.max_cache_token_size, self.head_dim), 
            dtype=self.dtype, 
            device=self.device
        )
        
        """
        CachelinePadded {
            [0] Back reference to table: int64,         # initial handshake, store token index of UVM bank
            [1] Reference to UVM Cache: int64,          # MAX_TOKEN, for token generation check
            [2] Token Generation of UVM Cache: int64,   # To check the version of cached token
            [3] Last accessed tick: int64,
            [4] Not accessed duration: int64,           # Increse one every step
            [5] Did taken in this kernel step: int64    # Reset to zero on every step
            [6] Token hash: int64                       # for debug
        }
        """
        self.metadata = torch.zeros(
            (self.max_cache_token_size, pad_to_cacheline(7, torch.int64)),
            dtype=torch.int64,
            device=self.device,
        )
        # self.metadata[:, 0].fill_(0)
        # self.metadata[:, 1].fill_(0)
        # self.metadata[:, 2].fill_(0)
        # self.metadata[:, 3].fill_(0)
        self.metadata[:, 4].fill_(1)
        # self.metadata[:, 5].fill_(0)
        
        # NOTE: this table is way too large to pad... sorry
        self.table = torch.full(
            (self.head_num, self.max_uvm_token_size, 1),
            dtype=torch.int32,
            device=self.device,
            fill_value=MAX_INT.value,
        )
        
        self.allocated_gpu_bytes = (
            sizeof(self.global_metadata) +
            sizeof(self.bank) + 
            sizeof(self.metadata) + 
            sizeof(self.table)
        )
        
        self.step = 0

        self.online_update_cache = online_cache_update

    def handle_cache_miss(
        self,
        metadata: HiPAttentionOutputMetadata,
        stats: HiPAttentionCacheAccessStatistics
    ):
        # self._verify_cache()

        # NOTE: increase not accessed timer
        self.metadata[:, 4].add_(1)
        self.metadata[:, 5].fill_(0)
        
        # if id(stats) == id(metadata.mask_cache_statistics): return
        if stats is None: return
        if self.online_update_cache: return
        
        # NOTE: this function should be capturable.
        # NOTE: this function will called only when mask is updated

        uvm_page_count = self.k_uvm.bank_cpu.shape[0]
        gpu_page_count = self.bank.shape[0]

        assert stats.cache_miss_counter.shape[1:] == (self.head_num, uvm_page_count), \
            f'{stats.cache_miss_counter.shape[1:]} == [{self.head_num}, {uvm_page_count}]'
    
        # update LRU recency
        # increase LRU step
        self.global_metadata[0, 0].add_(1)

        if stats.access_counter.shape[0] != 1:
            # assert False
            # NOTE: if paged attention, stats should be single batch.
            accessed = stats.access_counter.sum(0)
        else:
            accessed = stats.access_counter.squeeze(0)

        assert accessed.ndim == 2
        assert accessed.shape == (self.head_num, uvm_page_count)
        assert self.k_uvm.metadata.shape == (uvm_page_count, 1)
        assert self.global_metadata.shape == (2, pad_to_cacheline(1, self.global_metadata.dtype))
        assert self.metadata.shape == (self.bank.shape[0], pad_to_cacheline(5, self.metadata.dtype))
        assert self.table.shape == (self.head_num, uvm_page_count, 1)

        BLOCK_SIZE = 128
        grid = (self.head_num * triton.cdiv(uvm_page_count, BLOCK_SIZE), )
        update_recency[grid](
            accessed, *accessed.stride(),

            self.k_uvm.metadata, *self.k_uvm.metadata.stride(),

            self.global_metadata, *self.global_metadata.stride(),
            self.metadata, *self.metadata.stride(),
            self.table, *self.table.stride(),

            uvm_page_count,
            self.k_uvm.bank_cpu.shape[1],

            BLOCK_SIZE,

            num_warps=4,
        )
        self.step += 1

        # perform LRU
        assert gpu_page_count <= (uvm_page_count * self.head_num), f'{gpu_page_count} <= {(uvm_page_count * self.head_num)}'

        # cache_miss = ((stats.cache_miss_counter > 0) * stats.access_counter).sum(0).view(-1)
        if stats.cache_miss_counter.shape[0 == 1]:
            cache_miss = stats.cache_miss_counter.view(-1)
        else:
            cache_miss = ((stats.cache_miss_counter > 0) * stats.access_counter).sum(0).view(-1)
        put_mask = cache_miss > 0
        # put_priority_list = cache_miss.argsort(-1, descending=True, stable=False)
        # put_priority_list = put_priority_list[:gpu_page_count]
        put_priority_list = cache_miss.topk(k=gpu_page_count, dim=-1, largest=True, sorted=False).indices
        put_mask = put_mask[put_priority_list]

        slot_recency = self.metadata[:, 3]
        evict_priority_list = slot_recency.argsort(dim=-1, descending=False, stable=False)
        
        self.write_cache(
            put_list=put_priority_list,
            put_mask=put_mask,
            evict_list=evict_priority_list,
        )

        # NOTE: for debug
        self.verify_cache(put_mask)
    
    def verify_cache(self, put_mask: Optional[Tensor] = None, force: bool = False, max_verification: int = 10000):
        if (os.getenv('DEBUG_ONLINE_VERIFY', '0') == '0') and (not force):
            return
        if self.k_uvm.layer_id != 3 and (not force):
            return
        
        torch.cuda.synchronize()
        table = self.table.cpu()
        metadata = self.metadata.cpu()
        bank = self.bank.cpu()
        uvm_metadata = self.k_uvm.metadata.cpu()
        uvm_k_bank = self.k_uvm.bank_cpu
        uvm_v_bank = self.v_uvm.bank_cpu if self.kv_packed else None

        total_table_hit = 0
        total_back_ref_hit = 0
        total_uvm_ref_hit = 0
        total_token_gen_hit = 0
        total_hash_hit = 0
        total_cache_hit = 0
        for idx_head in range(table.shape[0]):
            for idx_page in tqdm.tqdm(range(min(table.shape[1], max_verification)), dynamic_ncols=True, leave=False):
                target_slot = table[idx_head, idx_page].item()
                if target_slot < MAX_INT:
                    total_table_hit += 1
                    back_ref, ref_to_uvm, token_gen, last_tick, sleep_tick, is_touched, token_hash = metadata[target_slot, :7]
                    if (back_ref // table.shape[0]) == idx_page:
                        total_back_ref_hit += 1
                        if ref_to_uvm < MAX_INT:
                            total_uvm_ref_hit += 1
                            if uvm_metadata[ref_to_uvm, 0] == token_gen:
                                total_token_gen_hit += 1
                                gpu_value = bank[target_slot]
                                if not self.kv_packed:
                                    cpu_value = uvm_k_bank[idx_page, idx_head]
                                else:
                                    cpu_value = torch.cat([
                                        uvm_k_bank[idx_page, idx_head],
                                        uvm_v_bank[idx_page, idx_head],
                                    ], dim=0)
                                gpu_hash = gpu_value.view(torch.uint16).to(torch.uint32).sum().item()
                                cpu_hash = cpu_value.view(torch.uint16).to(torch.uint32).sum().item()
                                mse = ((gpu_value - cpu_value) ** 2).mean().item()
                                check_pass = (mse < 1e-4) or True
                                if not check_pass:
                                    error_location = (uvm_k_bank.view(torch.uint16).to(torch.uint32).sum(dim=-1) == token_hash).nonzero()
                                    error_cpu_location = (uvm_k_bank.view(torch.uint16).to(torch.uint32).sum(dim=-1) == gpu_hash).nonzero()
                                    error_msg = f"""
cache_hit={total_cache_hit}, token_gen={total_token_gen_hit}, ref_to_uvm={total_uvm_ref_hit}, back_ref={total_back_ref_hit}
GPU = {gpu_value} ({gpu_value.shape})
-----
UVM = {cpu_value} ({cpu_value.shape})
token_hash={token_hash}, gpu_hash={gpu_hash}, cpu_hash={cpu_hash}, token_found={error_location}, uvm_found={error_cpu_location}
head={idx_head}, page={idx_page}, slot={target_slot}, backref={back_ref}, uvmref={ref_to_uvm}, gen={token_gen}, is_touched={is_touched}, mse={mse}"""
                                    assert False, error_msg
                                total_cache_hit += 1
                                if (gpu_hash == cpu_hash):
                                    total_hash_hit += 1
        
        if put_mask is not None:
            print('lastly put', put_mask.sum().item())
        tqdm.tqdm.write(f'verified kv_packed={self.kv_packed}, cache_hit={total_cache_hit}, token_gen={total_token_gen_hit}, ref_to_uvm={total_uvm_ref_hit}, back_ref={total_back_ref_hit}, hash_hit={total_hash_hit}')

    def write_cache(
        self,
        put_list: Tensor,
        put_mask: Tensor,
        evict_list: Tensor,
    ):
        assert put_list.shape == put_mask.shape
        assert evict_list.shape == put_list.shape

        BLOCK_SIZE = 128

        qsize = put_list.shape[0]

        grid = (triton.cdiv(qsize, BLOCK_SIZE),)
        write_cache[grid](
            put_list, *put_list.stride(),
            put_mask, *put_mask.stride(),
            evict_list, *evict_list.stride(),

            self.bank, *self.bank.stride(),
            self.metadata, *self.metadata.stride(),
            self.table, *self.table.stride(),

            self.k_uvm.metadata, *self.k_uvm.metadata.stride(),
            self.k_uvm.bank_gpu, *self.k_uvm.bank_gpu.stride(),
            self.v_uvm.bank_gpu if self.kv_packed else None, 
            *(self.v_uvm.bank_gpu.stride() if self.kv_packed else (0, 0, 0)),

            self.global_metadata, *self.global_metadata.stride(),

            qsize,
            self.k_uvm.bank_gpu.shape[0],
            self.k_uvm.bank_gpu.shape[1],

            self.kv_packed,
            BLOCK_SIZE,
            self.k_uvm.bank_gpu.shape[-1]
        )
    
    def on_set_kv_buffer(
        self,
        table: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
        # NOTE: cache_k, and cache_v might not on GPU (during prefill)
        # if table is allocated to valid slots, copy tensors to bank
        # if slot is not valid, unlink the table
        # NOTE: but for temporary, just unlink table always
        
        assert table.device == self.device
        
        self.table[:, :, 0].index_fill_(dim=1, index=table.to(torch.int64), value=MAX_INT.value)

    def flush(self):
        self.table[:, :, 0].fill_(MAX_INT.value)

class HiPOffloadCache:
    def __init__(
        self,
        layer_id: int, 
        max_token_size: int,
        max_mask_cache_token_size: int,
        max_sa_cache_token_size: int,
        head_num: int,
        head_dim: int,
        dtype: torch.dtype,
        device: torch.device,
        online_cache_update: bool,
    ):
        self.k_uvm = UVMCache(
            layer_id=layer_id,
            max_token_size=max_token_size,
            head_num=head_num,
            head_dim=head_dim,
            dtype=dtype,
            device=device,
        )
        
        self.v_uvm = UVMCache(
            layer_id=layer_id,
            max_token_size=max_token_size,
            head_num=head_num,
            head_dim=head_dim,
            dtype=dtype,
            device=device,
        )
        
        self.mask_k_cache = GPUCache(
            k_uvm=self.k_uvm,
            v_uvm=None,
            max_cache_token_size=max_mask_cache_token_size,
            online_cache_update=online_cache_update,
        )
        
        self.sa_kv_cache = GPUCache(
            k_uvm=self.k_uvm,
            v_uvm=self.v_uvm,
            max_cache_token_size=max_sa_cache_token_size,
            online_cache_update=online_cache_update,
        )
    
    def get_page_count(self):
        assert self.k_uvm.bank_cpu.shape == self.v_uvm.bank_cpu.shape
        return self.k_uvm.bank_cpu.shape[0]
    
    def prefetch_prefix_kv_buffer(
        self,
        table: Tensor,
        device: torch.device
    ) -> Tuple[Tensor, Tensor]:
        if table.device != torch.device('cpu'):
            table = table.to('cpu', non_blocking=False)
        k = self.k_uvm.gather_cpu(table, pin_memory=True)
        v = self.v_uvm.gather_cpu(table, pin_memory=True)
        k = k.to(device, non_blocking=False).unsqueeze(0)
        v = v.to(device, non_blocking=False).unsqueeze(0)
        return k, v
    
    def set_kv_buffer(
        self,
        table: torch.Tensor,
        table_gpu: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
        cache_device = cache_k.device
        assert table.device == cache_device
        assert cache_v.device == cache_device
        
        if cache_device == torch.device('cpu'):
            # self.k_uvm.bank_cpu[table] = cache_k
            # self.v_uvm.bank_cpu[table] = cache_v
            if cache_k.dtype in [torch.float16, torch.bfloat16]:
                view_dtype = torch.uint16
            elif cache_k.dtype in [torch.float32]:
                view_dtype = torch.uint32
            elif cache_k.dtype in [torch.uint8, torch.float8_e5m2]:
                view_dtype = torch.uint8
            else:
                raise Exception()

            set_kv_buffer_(
                self.k_uvm.bank_cpu.view(view_dtype).numpy(), 
                cache_k.view(view_dtype).numpy(),
                table.numpy(),
                os.cpu_count(),
            )
            
            set_kv_buffer_(
                self.v_uvm.bank_cpu.view(view_dtype).numpy(), 
                cache_v.view(view_dtype).numpy(),
                table.numpy(),
                os.cpu_count(),
            )
        else:
            assert cache_device == self.k_uvm.device
            self.k_uvm.bank_gpu[table] = cache_k
            self.v_uvm.bank_gpu[table] = cache_v
        
        self.mask_k_cache.on_set_kv_buffer(
            table=table_gpu,
            cache_k=cache_k,
            cache_v=cache_v,
        )
        self.sa_kv_cache.on_set_kv_buffer(
            table=table_gpu,
            cache_k=cache_k,
            cache_v=cache_v,
        )
        
        self.k_uvm.metadata.index_put_(
            indices=(table, ), 
            values=torch.index_select(self.k_uvm.metadata, index=table_gpu, dim=0) + 1
        )
        self.v_uvm.metadata.index_put_(
            indices=(table, ), 
            values=torch.index_select(self.v_uvm.metadata, index=table_gpu, dim=0) + 1
        )
    
    def handle_cache_miss(self, metadata: HiPAttentionOutputMetadata):
        if metadata.mask_cache_statistics is not None:
            self.mask_k_cache.handle_cache_miss(
                metadata=metadata,
                stats=metadata.mask_cache_statistics
            )
            self.sa_kv_cache.handle_cache_miss(
                metadata=metadata,
                stats=metadata.sa_cache_statistics
            )
        else:
            self.mask_k_cache.handle_cache_miss(
                metadata=metadata,
                stats=None
            )
            self.sa_kv_cache.handle_cache_miss(
                metadata=metadata,
                stats=None
            )

###############################################################################
#                               Kernel Function
###############################################################################

@numba.njit(parallel=True, fastmath=True)
def set_kv_buffer_(
    bank: np.ndarray, 
    cache: np.ndarray, 
    table: np.ndarray, 
    num_threads: int
):
    assert table.shape[0] == cache.shape[0]
    
    chunk_size = math.ceil(table.shape[0] / num_threads)
    for ithread in numba.prange(num_threads):
        for i in range(chunk_size):
            t = ithread * chunk_size + i
            if t < table.shape[0]:
                bank[table[t]] = cache[t]

@triton.jit
def validate_bank_metadata_slots(
    UVM_METADATA,
    stride_uvm_metadata_token, stride_uvm_metadata_k,
    METADATA,
    stride_metadata_slot, stride_metadata_k,
    
    idx_slot,
    idx_page, # this is optional. if given, check backref
    cache_hit,
    
    HEAD_KV,
):
    cache_hit = (idx_slot < MAX_INT) & cache_hit
    
    back_ref = tl.load(
        METADATA +\
            idx_slot.to(tl.int64) * stride_metadata_slot +\
            0 * stride_metadata_k,
        mask=cache_hit,
    )
    
    if idx_page is not None:
        cache_hit = ((back_ref // HEAD_KV) == idx_page) & (back_ref < MAX_INT) & cache_hit
    else:
        cache_hit = (back_ref < MAX_INT) & cache_hit

    ref_to_uvm = tl.load(
        METADATA +\
            idx_slot.to(tl.int64) * stride_metadata_slot +\
            1 * stride_metadata_k,
        mask=cache_hit,
    ).to(tl.int64)
    cache_hit = (ref_to_uvm < MAX_INT) & cache_hit

    uvm_token_gen = tl.load(
        UVM_METADATA +\
            ref_to_uvm.to(tl.int64) * stride_uvm_metadata_token +\
            0 * stride_uvm_metadata_k,
        mask=cache_hit,
    )
    cache_token_gen = tl.load(
        METADATA +\
            idx_slot.to(tl.int64) * stride_metadata_slot +\
            2 * stride_metadata_k,
        mask=cache_hit,
    )
    cache_hit = (
        uvm_token_gen < MAX_INT
    ) & (
        uvm_token_gen == cache_token_gen
    ) & cache_hit
    
    tl.static_assert(cache_hit.dtype == tl.int1)
    
    return cache_hit

@triton.jit
def load_tokens(
    K, 
    stride_k_bsz,
    stride_k_tsrc,
    stride_k_head,
    stride_k_hid,
    
    # paged attention args template
    USING_PAGES: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    K_CACHE, 
    stride_k_cache_page, 
    stride_k_cache_offset, 
    stride_k_cache_kv_head, 
    stride_k_cache_hid,
    BLOCK_TABLE,
    stride_block_table_bsz,
    stride_block_table_page,
    CACHE_SEQ_LENS,
    stride_cache_seq_lens_b,
    
    USING_OFFLOAD_CACHE: tl.constexpr,
    OFFLOAD_CACHE_KV_PACKED: tl.constexpr,
    GPU_BANK_COUNT: int,
    OFFLOAD_CACHE_LOAD_VALUE: tl.constexpr,
    OFFLOAD_CACHE_UVM_METADATA,
    stride_offload_cache_uvm_metadata_token,
    stride_offload_cache_uvm_metadata_k,
    OFFLOAD_CACHE_GPU_GLOBAL_METADATA,
    stride_offload_cache_gpu_global_metadata_k,
    stride_offload_cache_gpu_global_metadata_pad,
    OFFLOAD_CACHE_GPU_BANK,
    stride_offload_cache_gpu_bank_token,
    stride_offload_cache_gpu_bank_hid,
    OFFLOAD_CACHE_GPU_METADATA,
    stride_offload_cache_gpu_metadata_token,
    stride_offload_cache_gpu_metadata_k,
    OFFLOAD_CACHE_GPU_TABLE,
    stride_offload_cache_gpu_table_head_kv,
    stride_offload_cache_gpu_table_token,
    strdie_offload_cache_gpu_table_k,
    
    ACCESS_COUNTER,
    stride_access_counter_bsz,
    stride_access_counter_head_kv,
    stride_access_counter_tsrc,
    CACHE_MISS_COUNTER,
    stride_cache_miss_counter_bsz,
    stride_cache_miss_counter_head_kv,
    stride_cache_miss_counter_tsrc,
    
    idx_bsz,
    idx_tsrc,
    idx_kv_head,
    idx_hid,
    
    mask_keys,
    
    HEAD_KV: int,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_HID: tl.constexpr,
    
    IS_BSA: tl.constexpr = False,
    UPDATE_CACHE: tl.constexpr = False,
    
    V_CACHE = None,
    stride_v_cache_page = None, 
    stride_v_cache_offset = None, 
    stride_v_cache_kv_head = None, 
    stride_v_cache_hid = None,
):
    # DEBUG: to load nothing
    # mask_keys = mask_keys & False
    
    # tl.static_print(OFFLOAD_CACHE_METHOD)
    
    if not USING_PAGES:
        tl.static_assert(not USING_OFFLOAD_CACHE)
        
        if ACCESS_COUNTER is not None:
            # tl.atomic_add(
            #     ACCESS_COUNTER +\
            #         idx_bsz.to(tl.int64) * stride_access_counter_bsz +\
            #         idx_kv_head * stride_access_counter_head_kv +\
            #         idx_tsrc * stride_access_counter_tsrc,
            #     mask=mask_keys,
            #     val=1
            # )
            
            tl.store(
                ACCESS_COUNTER +\
                    idx_bsz.to(tl.int64) * stride_access_counter_bsz +\
                    idx_kv_head * stride_access_counter_head_kv +\
                    idx_tsrc * stride_access_counter_tsrc,
                mask=mask_keys,
                value=1
            )
        
        keys = tl.load(
            K +\
                idx_bsz.to(tl.int64) * stride_k_bsz +\
                idx_tsrc.to(tl.int64) * stride_k_tsrc +\
                idx_kv_head.to(tl.int64) * stride_k_head +\
                idx_hid.to(tl.int64) * stride_k_hid,
            mask = mask_keys,
            other = 0.0,
            # cache_modifier='.cs', # TODO: uncomment this
        )
    else:
        seq_len = tl.load(
            CACHE_SEQ_LENS +\
                idx_bsz.to(tl.int64) * stride_cache_seq_lens_b,
        )
        mask_tsrc = (idx_tsrc >= 0) & (idx_tsrc < seq_len)
        ptrs = BLOCK_TABLE +\
            idx_bsz.to(tl.int64) * stride_block_table_bsz + \
            (idx_tsrc // PAGE_SIZE).to(tl.int64) * stride_block_table_page
        idx_page = tl.load(
            ptrs,
            mask=mask_tsrc,
            other=0,
        ).to(tl.int64)
        offset_page = idx_tsrc % PAGE_SIZE
        
        if ACCESS_COUNTER is not None:
            # tl.atomic_add(
            #     ACCESS_COUNTER +\
            #         idx_bsz.to(tl.int64) * stride_access_counter_bsz +\
            #         idx_kv_head * stride_access_counter_head_kv +\
            #         idx_page * stride_access_counter_tsrc,
            #     mask=mask_keys,
            #     val=1
            # )
            
            tl.store(
                ACCESS_COUNTER +\
                    idx_bsz.to(tl.int64) * stride_access_counter_bsz +\
                    idx_kv_head * stride_access_counter_head_kv +\
                    idx_page * stride_access_counter_tsrc,
                mask=mask_keys,
                value=1
            )
        
        original_mask_keys = mask_keys
        
        if USING_OFFLOAD_CACHE:
            tl.static_assert(PAGE_SIZE == 1)
            
            idx_slots = tl.load(
                OFFLOAD_CACHE_GPU_TABLE +\
                    idx_page.to(tl.int64) * stride_offload_cache_gpu_table_token +\
                    idx_kv_head * stride_offload_cache_gpu_table_head_kv +\
                    0 * strdie_offload_cache_gpu_table_k,
                mask=mask_keys,
            )
            idx_slot_has_reference_to_bank = (idx_slots < MAX_INT) & mask_keys
            idx_slots = idx_slots * idx_slot_has_reference_to_bank
            
            ALWAYS_VALIDATE_LINK: tl.constexpr = False # not UPDATE_CACHE
            
            if ALWAYS_VALIDATE_LINK:
                validated_slots = validate_bank_metadata_slots(
                    OFFLOAD_CACHE_UVM_METADATA,
                    stride_offload_cache_uvm_metadata_token,
                    stride_offload_cache_uvm_metadata_k,
                    
                    OFFLOAD_CACHE_GPU_METADATA,
                    stride_offload_cache_gpu_metadata_token,
                    stride_offload_cache_gpu_metadata_k,
                    
                    idx_slots,
                    idx_page,
                    idx_slot_has_reference_to_bank,
                    
                    HEAD_KV,
                )
                
                mask_slot_cache_hit = validated_slots & idx_slot_has_reference_to_bank
            else:
                mask_slot_cache_hit = idx_slot_has_reference_to_bank
            
            # if OFFLOAD_CACHE_LOAD_VALUE:
            #     mask_slot_cache_hit = mask_slot_cache_hit & False
            
            idx_hid_cached = idx_hid
            if OFFLOAD_CACHE_LOAD_VALUE:
                idx_hid_cached += BLOCK_HID
            keys_cached = tl.load(
                OFFLOAD_CACHE_GPU_BANK +\
                    idx_slots.to(tl.int64) * stride_offload_cache_gpu_bank_token +\
                    idx_hid_cached * stride_offload_cache_gpu_bank_hid,
                mask=mask_slot_cache_hit,
                other=0.0,
            )
            if keys_cached.dtype == tl.uint8:
                keys_cached = keys_cached.to(tl.float8e5, bitcast=True).to(tl.bfloat16)
            if keys_cached.dtype == tl.float8e5:
                keys_cached = keys_cached.to(tl.bfloat16)
            
            if UPDATE_CACHE:
                idx_slots_verify = tl.load(
                    OFFLOAD_CACHE_GPU_TABLE +\
                        idx_kv_head.to(tl.int64) * stride_offload_cache_gpu_table_head_kv +\
                        idx_page * stride_offload_cache_gpu_table_token +\
                        0 * strdie_offload_cache_gpu_table_k,
                    mask=mask_slot_cache_hit,
                )
                mask_slot_cache_hit = (
                    idx_slots_verify < MAX_INT
                ) & (
                    idx_slots == idx_slots_verify
                ) & mask_slot_cache_hit
                
            if mask_slot_cache_hit.shape[0] == 1:
                keys_cached_hash = tl.sum(keys_cached.to(tl.uint16, bitcast=True), axis=0, keep_dims=True).to(tl.uint64)
            elif mask_slot_cache_hit.shape[1] == 1:
                keys_cached_hash = tl.sum(keys_cached.to(tl.uint16, bitcast=True), axis=1, keep_dims=True).to(tl.uint64)
            else:
                raise Exception()
            tl.debug_barrier()
            
            tl.inline_asm_elementwise("MEMBAR.SC.GPU;", "=r", [], dtype=tl.int32, is_pure=True, pack=1)
            
            truth_hash = tl.load(
                OFFLOAD_CACHE_GPU_METADATA +\
                    idx_slots.to(tl.int64) * stride_offload_cache_gpu_metadata_token+\
                    6 * stride_offload_cache_gpu_metadata_k,
                mask=mask_slot_cache_hit,
            ).to(tl.uint64)
            hash_mask = tl.full((1,), value=1, dtype=tl.uint64)
            hash_mask = (hash_mask << 32) - 1
            if OFFLOAD_CACHE_LOAD_VALUE:
                truth_hash = (truth_hash >> 32) & hash_mask
            else:
                truth_hash = truth_hash & hash_mask
            tl.debug_barrier()
            if UPDATE_CACHE:
                mask_slot_cache_hit = (truth_hash == (keys_cached_hash & hash_mask)) & mask_slot_cache_hit
            
            tl.static_assert(mask_slot_cache_hit.dtype == tl.int1)
            
            mask_keys_cache_miss = mask_keys & (~mask_slot_cache_hit)
            mask_slot_cache_hit = mask_keys & (~mask_keys_cache_miss)
            mask_keys = mask_keys_cache_miss
            
            tl.store(
                OFFLOAD_CACHE_GPU_METADATA +\
                    idx_slots.to(tl.int64) * stride_offload_cache_gpu_metadata_token+\
                    4 * stride_offload_cache_gpu_metadata_k,
                value=0,
                mask=mask_slot_cache_hit,
            )
            # tl.store(
            #     CACHE_MISS_COUNTER +\
            #         idx_bsz.to(tl.int64) * stride_cache_miss_counter_bsz +\
            #         idx_kv_head * stride_cache_miss_counter_head_kv +\
            #         idx_page * stride_cache_miss_counter_tsrc,
            #     mask=mask_slot_cache_hit,
            #     value=0,
            # )
        
        keys = tl.load(
            K_CACHE +\
                idx_page.to(tl.int64) * stride_k_cache_page +\
                offset_page.to(tl.int64) * stride_k_cache_offset +\
                idx_kv_head.to(tl.int64) * stride_k_cache_kv_head +\
                idx_hid.to(tl.int64) * stride_k_cache_hid,
            mask=mask_keys,
            other=0.0,
        )
        if keys.dtype == tl.uint8:
            keys = keys.to(tl.float8e5, bitcast=True).to(tl.bfloat16)
        if keys.dtype == tl.float8e5:
            keys = keys.to(tl.bfloat16)
        
        if USING_OFFLOAD_CACHE:
            keys = tl.where(
                mask_keys,
                keys,
                keys_cached,
            )
            
            if CACHE_MISS_COUNTER is not None:
                if UPDATE_CACHE:
                    tl.debug_barrier()
                    tl.store(
                        CACHE_MISS_COUNTER +\
                            idx_bsz.to(tl.int64) * stride_cache_miss_counter_bsz +\
                            idx_kv_head * stride_cache_miss_counter_head_kv +\
                            idx_page * stride_cache_miss_counter_tsrc,
                        mask=mask_keys_cache_miss,
                        value=1,
                    )
                    # mask_victim_slots = (cache_miss_counter != 1) & mask_keys_cache_miss
                    mask_victim_slots = mask_keys_cache_miss # NOTE: init value if cache miss counter is ignored
                    # table is protected by cache miss counter
                    previous_table = tl.atomic_xchg(
                        OFFLOAD_CACHE_GPU_TABLE +\
                            idx_page.to(tl.int64) * stride_offload_cache_gpu_table_token +\
                            idx_kv_head * stride_offload_cache_gpu_table_head_kv +\
                            0 * strdie_offload_cache_gpu_table_k,
                        val=MAX_INT_ACQUIRED,
                        mask=mask_victim_slots,
                    )
                    mask_victim_slots = (previous_table != MAX_INT_ACQUIRED) & mask_victim_slots
                    mask_table_acquired = mask_victim_slots
                    
                    tl.debug_barrier()
                    
                    idx_victim_slots = tl.zeros_like(idx_slots).to(tl.int64) + MAX_INT
                    max_not_accessed_time = tl.zeros_like(idx_slots).to(tl.int64) - 1
                    victim_slot_not_acquired = mask_victim_slots
                    seed = tl.atomic_add(
                        OFFLOAD_CACHE_GPU_GLOBAL_METADATA +\
                            1 * stride_offload_cache_gpu_global_metadata_k +\
                            0 * stride_offload_cache_gpu_global_metadata_pad,
                        val=1,
                    )
                    for i in range(10):
                        pid = tl.program_id(0).to(tl.int64) * tl.num_programs(1) * tl.num_programs(2) + tl.program_id(1) * tl.num_programs(2) + tl.program_id(2)
                        if original_mask_keys.shape[0] == 1:
                            idx_randint = (tl.randint(seed, tl.arange(0, BLOCK_SIZE_K)[None, :] + i * BLOCK_SIZE_K, 10).to(tl.int64) & ((1 << 30) - 1))
                        else:
                            idx_randint = (tl.randint(seed, tl.arange(0, BLOCK_SIZE_K)[:, None] + i * BLOCK_SIZE_K, 10).to(tl.int64) & ((1 << 30) - 1))
                        # if IS_BSA:
                        #     idx_victim_slots_try = idx_randint % (8192 * HEAD_KV)
                        # else:
                        #     idx_victim_slots_try = idx_randint % (8192 * HEAD_KV)
                        # idx_victim_slots_try = idx_victim_slots_try * 128 + tl.extra.cuda.smid()
                        idx_victim_slots_try = idx_randint
                        # idx_victim_slots_try = idx_randint * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
                        idx_victim_slots_try = idx_victim_slots_try % GPU_BANK_COUNT
                        # if IS_BSA:
                        #     idx_victim_slots_try = idx_victim_slots_try % (10000 * HEAD_KV)
                        # else:
                        #     idx_victim_slots_try = idx_victim_slots_try % (32000 * HEAD_KV)
                        
                        acquired = victim_slot_not_acquired
                        
                        not_accessed_time = tl.load(
                            OFFLOAD_CACHE_GPU_METADATA +\
                                idx_victim_slots_try.to(tl.int64) * stride_offload_cache_gpu_metadata_token+\
                                4 * stride_offload_cache_gpu_metadata_k,
                            mask=mask_victim_slots,
                        )
                        new_old_slot = (not_accessed_time > max_not_accessed_time) & mask_victim_slots
                        # if already acquired, release it
                        tl.atomic_xchg(
                            OFFLOAD_CACHE_GPU_METADATA +\
                                idx_victim_slots.to(tl.int64) * stride_offload_cache_gpu_metadata_token+\
                                5 * stride_offload_cache_gpu_metadata_k,
                            val=0,
                            mask=new_old_slot & (~victim_slot_not_acquired),
                        )
                        tl.atomic_xchg(
                            OFFLOAD_CACHE_GPU_METADATA +\
                                idx_victim_slots.to(tl.int64) * stride_offload_cache_gpu_metadata_token+\
                                0 * stride_offload_cache_gpu_metadata_k,
                            val=MAX_INT,
                            mask=new_old_slot & (~victim_slot_not_acquired),
                        )
                        max_not_accessed_time = tl.maximum(max_not_accessed_time, not_accessed_time)
                        victim_slot_not_acquired = victim_slot_not_acquired | new_old_slot
                        acquired = victim_slot_not_acquired & (not_accessed_time > 0)
                        
                        # check already written or not
                        previous_state = tl.atomic_xchg(
                            OFFLOAD_CACHE_GPU_METADATA +\
                                idx_victim_slots_try.to(tl.int64) * stride_offload_cache_gpu_metadata_token+\
                                5 * stride_offload_cache_gpu_metadata_k,
                            val=1, # NOTE: this should be MAX_INT_1, but just for temporary.
                            mask=acquired,
                        )
                        acquired = (previous_state != 1) & acquired

                        # check acquired or not
                        previous_state = tl.atomic_xchg(
                            OFFLOAD_CACHE_GPU_METADATA +\
                                idx_victim_slots_try.to(tl.int64) * stride_offload_cache_gpu_metadata_token+\
                                0 * stride_offload_cache_gpu_metadata_k,
                            val=MAX_INT_ACQUIRED, # NOTE: this should be MAX_INT_1, but just for temporary.
                            mask=acquired,
                        )
                        acquired = (previous_state != MAX_INT_ACQUIRED) & acquired
                        
                        previously_acquired = (previous_state < MAX_INT) & acquired
                        previous_idx_page = previous_state // HEAD_KV
                        previous_idx_head_kv = previous_state % HEAD_KV
                        tl.atomic_xchg(
                            OFFLOAD_CACHE_GPU_TABLE +\
                                previous_idx_page.to(tl.int64) * stride_offload_cache_gpu_table_token +\
                                previous_idx_head_kv * stride_offload_cache_gpu_table_head_kv +\
                                0 * strdie_offload_cache_gpu_table_k,
                            val=MAX_INT,
                            mask=previously_acquired,
                        )
                        
                        idx_victim_slots = tl.where(
                            acquired,
                            idx_victim_slots_try,
                            idx_victim_slots,
                        )
                        
                        victim_slot_not_acquired = (~acquired) & victim_slot_not_acquired
                        tl.debug_barrier()
                    tl.debug_barrier()
                    # mask_victim_slots = mask_victim_slots & (idx_victim_slots < MAX_INT)
                    mask_victim_slots = mask_victim_slots & (~victim_slot_not_acquired) & mask_keys_cache_miss & (idx_victim_slots != MAX_INT) & original_mask_keys
                    # if not IS_BSA:
                    #     mask_victim_slots = mask_victim_slots & (~victim_slot_not_acquired) & (idx_victim_slots < (32000 * HEAD_KV)) & (idx_victim_slots > -1)
                    # else:
                    #     mask_victim_slots = mask_victim_slots & (~victim_slot_not_acquired) & (idx_victim_slots < (10000 * HEAD_KV)) & (idx_victim_slots > -1)
                    idx_victim_slots = idx_victim_slots * mask_victim_slots
                    tl.debug_barrier()
                    
                    tl.store(
                        OFFLOAD_CACHE_GPU_BANK +\
                            idx_victim_slots.to(tl.int64) * stride_offload_cache_gpu_bank_token +\
                            idx_hid_cached * stride_offload_cache_gpu_bank_hid,
                        value=keys,
                        mask=mask_victim_slots,
                    )

                    # take token hash for debug
                    if mask_victim_slots.shape[0] == 1:
                        keys_hash = tl.sum(keys.to(tl.uint16, bitcast=True), axis=0, keep_dims=True).to(tl.uint64)
                    elif mask_victim_slots.shape[1] == 1:
                        keys_hash = tl.sum(keys.to(tl.uint16, bitcast=True), axis=1, keep_dims=True).to(tl.uint64)
                    else:
                        raise Exception()
                    
                    if IS_BSA:
                        values = tl.load(
                            V_CACHE +\
                                idx_page.to(tl.int64) * stride_v_cache_page +\
                                offset_page.to(tl.int64) * stride_v_cache_offset +\
                                idx_kv_head.to(tl.int64) * stride_v_cache_kv_head +\
                                idx_hid.to(tl.int64) * stride_v_cache_hid,
                            mask=mask_victim_slots,
                            other=0.0,
                        )
                        if values.dtype == tl.uint8:
                            values = values.to(tl.float8e5, bitcast=True).to(tl.bfloat16)
                        if values.dtype == tl.float8e5:
                            values = values.to(tl.bfloat16)
                        tl.store(
                            OFFLOAD_CACHE_GPU_BANK +\
                                idx_victim_slots.to(tl.int64) * stride_offload_cache_gpu_bank_token +\
                                # idx_hid * stride_offload_cache_gpu_bank_hid,
                                ((idx_hid_cached + BLOCK_HID) % (BLOCK_HID * 2)) * stride_offload_cache_gpu_bank_hid,
                            value=values,
                            mask=mask_victim_slots,
                        )
                        
                        if mask_victim_slots.shape[0] == 1:
                            values_hash = tl.sum(values.to(tl.uint16, bitcast=True), axis=0, keep_dims=True).to(tl.uint64)
                        elif mask_victim_slots.shape[1] == 1:
                            values_hash = tl.sum(values.to(tl.uint16, bitcast=True), axis=1, keep_dims=True).to(tl.uint64)
                        else:
                            raise Exception()
                        if not OFFLOAD_CACHE_LOAD_VALUE:
                            keys_hash = keys_hash | (values_hash << 32)
                        else:
                            keys_hash = (keys_hash << 32) | values_hash

                    tl.store(
                        OFFLOAD_CACHE_GPU_METADATA +\
                            idx_victim_slots.to(tl.int64) * stride_offload_cache_gpu_metadata_token+\
                            6 * stride_offload_cache_gpu_metadata_k,
                        value=keys_hash,
                        mask=mask_victim_slots,
                    )
                    
                    uvm_token_gen = tl.load(
                        OFFLOAD_CACHE_UVM_METADATA +\
                            idx_page.to(tl.int64) * stride_offload_cache_uvm_metadata_token +\
                            0 * stride_offload_cache_uvm_metadata_k,
                        mask=mask_victim_slots,
                    )
                    
                    tl.store(
                        OFFLOAD_CACHE_GPU_METADATA +\
                            idx_victim_slots.to(tl.int64) * stride_offload_cache_gpu_metadata_token+\
                            1 * stride_offload_cache_gpu_metadata_k,
                        value=idx_page,
                        mask=mask_victim_slots,
                    )
                    tl.store(
                        OFFLOAD_CACHE_GPU_METADATA +\
                            idx_victim_slots.to(tl.int64) * stride_offload_cache_gpu_metadata_token+\
                            2 * stride_offload_cache_gpu_metadata_k,
                        value=uvm_token_gen,
                        mask=mask_victim_slots,
                    )
                    tl.store(
                        OFFLOAD_CACHE_GPU_METADATA +\
                            idx_victim_slots.to(tl.int64) * stride_offload_cache_gpu_metadata_token+\
                            4 * stride_offload_cache_gpu_metadata_k,
                        value=0,
                        mask=mask_victim_slots,
                    )
                    
                    tl.debug_barrier()
                    
                    # core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1,
                    #                    _builder=_builder)
                    tl.inline_asm_elementwise("MEMBAR.SC.GPU;", "=r", [], dtype=tl.int32, is_pure=True, pack=1)
                    
                    # release slot
                    tl.atomic_xchg(
                        OFFLOAD_CACHE_GPU_METADATA +\
                            idx_victim_slots.to(tl.int64) * stride_offload_cache_gpu_metadata_token+\
                            0 * stride_offload_cache_gpu_metadata_k,
                        val=idx_page * HEAD_KV + idx_kv_head,
                        mask=mask_victim_slots,
                    )
                    # release table
                    table_slots = tl.where(
                        mask_table_acquired & (~mask_victim_slots),
                        MAX_INT,
                        idx_victim_slots,
                    )
                    tl.atomic_xchg(
                        OFFLOAD_CACHE_GPU_TABLE +\
                            idx_page.to(tl.int64) * stride_offload_cache_gpu_table_token +\
                            idx_kv_head * stride_offload_cache_gpu_table_head_kv +\
                            0 * strdie_offload_cache_gpu_table_k,
                        val=table_slots,
                        mask=mask_table_acquired,
                    )
                else:
                    tl.store(
                        CACHE_MISS_COUNTER +\
                            idx_bsz.to(tl.int64) * stride_cache_miss_counter_bsz +\
                            idx_kv_head * stride_cache_miss_counter_head_kv +\
                            idx_page * stride_cache_miss_counter_tsrc,
                        mask=mask_keys_cache_miss,
                        value=1,
                    )
    if keys.dtype == tl.uint8:
        keys = keys.to(tl.float8e5, bitcast=True).to(tl.float16)
    if keys.dtype == tl.float8e5:
        keys = keys.to(tl.float16)
    
    return keys

def update_recency_pytorch(
    accessed_ptr: Tensor,
    uvm_metadata: Tensor,
    global_metadata: Tensor,
    metadata: Tensor,
    table: Tensor,
    head_num: int,
    uvm_page_count: int,
):
    for idx_head_kv in range(head_num):
        for idx_token in tqdm.tqdm(range(uvm_page_count), dynamic_ncols=True, leave=False):
            current_tick = global_metadata[0, 0]
            
            accessed = accessed_ptr[idx_head_kv, idx_token] > 0
            cache_hit = True & accessed
            if not cache_hit: continue

            idx_table = table[idx_head_kv, idx_token]
            cache_hit = (idx_table != MAX_INT) & cache_hit
            if not cache_hit: continue
            
            back_ref = metadata[idx_table, 0]
            cache_hit = (back_ref == idx_token) & cache_hit
            if not cache_hit: continue

            ref_to_uvm = metadata[idx_table, 1]
            cache_hit = (ref_to_uvm != MAX_INT) & cache_hit
            if not cache_hit: continue

            uvm_token_gen = uvm_metadata[ref_to_uvm, 0]
            cache_token_gen = metadata[idx_table, 2]
            cache_hit = (
                uvm_token_gen != MAX_INT
            ) & (
                uvm_token_gen == cache_token_gen
            ) & cache_hit
            if not cache_hit: continue

            metadata[idx_table, 3] = current_tick.to(metadata.dtype)

@triton.jit
def update_recency(
    ACCESSED,
    stride_accessed_head_kv, stride_accessed_token,

    UVM_METADATA,
    stride_uvm_metadata_token, stride_uvm_metadata_k,

    GLOBAL_METADTA,
    stride_global_metadata_k, stride_global_metadata_pad,
    METADATA,
    stride_metadata_slot, stride_metadata_k,
    TABLE,
    stride_table_head_kv, stride_table_token, stride_table_k,

    page_count: int,
    HEAD_KV: int,

    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0).to(tl.int64)

    idx_block = pid % tl.cdiv(page_count, BLOCK_SIZE)
    idx_head_kv = pid // tl.cdiv(page_count, BLOCK_SIZE)

    idx_token = idx_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask_token = idx_token < page_count

    current_tick = tl.load(
        GLOBAL_METADTA +\
            0 * stride_global_metadata_k+\
            0 * stride_global_metadata_pad
    )

    #TODO: merge with load tokens, verify cache
    accessed = tl.load(
        ACCESSED +\
            idx_head_kv.to(tl.int64) * stride_accessed_head_kv +\
            idx_token * stride_accessed_token,
        mask=mask_token,
        other=0,
    ) > 0
    cache_hit = mask_token & accessed

    table = tl.load(
        TABLE +\
            idx_head_kv.to(tl.int64) * stride_table_head_kv +\
            idx_token * stride_table_token +\
            0 * stride_table_k,
        mask=cache_hit,
        other=MAX_INT,
    ).to(tl.int64)
    cache_hit = (table < MAX_INT) & cache_hit
    
    ALWAYS_VALIDATE_LINK: tl.constexpr = False # True
    
    if ALWAYS_VALIDATE_LINK:
        validated_cache_hit = validate_bank_metadata_slots(
            UVM_METADATA, 
            stride_uvm_metadata_token, 
            stride_uvm_metadata_k,
            
            METADATA, 
            stride_metadata_slot, 
            stride_metadata_k,
            
            table,
            idx_token,
            cache_hit,
            
            HEAD_KV,
        )
        cache_hit = cache_hit & validated_cache_hit

    tl.store(
        METADATA +\
            table.to(tl.int64) * stride_metadata_slot +\
            3 * stride_metadata_k,
        mask=cache_hit,
        value=current_tick,
    )

@triton.jit
def write_cache(
    PUT, stride_put_t,
    MASK, stride_mask_t,
    EVICT, stride_evict_t,

    BANK, 
    stride_bank_t, 
    stride_bank_hid,
    METADATA, 
    stride_metadata_t, 
    stride_metadata_k,
    TABLE, 
    stride_table_head_kv, 
    stride_table_t, 
    stride_table_k,

    UVM_METADATA,
    stride_uvm_metadata_t, 
    stride_uvm_metadata_k,
    UVM_K_BANK,
    stride_uvm_k_bank_t, 
    stride_uvm_k_bank_head_kv, 
    stride_uvm_k_bank_hid,
    UVM_V_BANK,
    stride_uvm_v_bank_t, 
    stride_uvm_v_bank_head_kv, 
    stride_uvm_v_bank_hid,

    GLOBAL_METADATA,
    stride_global_metadata_t,
    stride_global_metadata_k,

    qsize: int,
    page_count: int,
    HEAD_KV: int,
    
    KV_PACKED: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_HID: tl.constexpr,
):
    pid = tl.program_id(0)
    idx_queue = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask_queue = idx_queue < qsize

    put_list = tl.load(
        PUT + idx_queue.to(tl.int64) * stride_put_t,
        mask=mask_queue,
    )
    idx_page = put_list % page_count
    idx_head_kv = put_list // page_count

    mask_put = (tl.load(
        MASK + idx_queue.to(tl.int64) * stride_mask_t,
        mask=mask_queue,
        other=0
    ) != 0)
    idx_evict = tl.load(
        EVICT + idx_queue.to(tl.int64) * stride_evict_t,
        mask=mask_put
    )
    
    # check still it is cache miss
    idx_slot = tl.load(
        TABLE +\
            idx_head_kv.to(tl.int64) * stride_table_head_kv +\
            idx_page * stride_table_t +\
            0 * stride_table_k,
        mask=mask_put,
        other=MAX_INT,
    )
    is_valid_slot = validate_bank_metadata_slots(
        UVM_METADATA, 
        stride_uvm_metadata_k, stride_uvm_metadata_k,
        METADATA,
        stride_metadata_t, stride_metadata_k,
        idx_slot,
        idx_page,
        mask_put,
        HEAD_KV,
    )
    mask_put = mask_put & (~is_valid_slot)
    
    # unlink bank <-> table
    victim_table_entry = tl.load(
        METADATA +\
            idx_evict.to(tl.int64) * stride_metadata_t +\
            0 * stride_metadata_k,
        mask=mask_put,
    )
    tl.store(
        TABLE +\
            (victim_table_entry.to(tl.int64) % HEAD_KV) * stride_table_head_kv +\
            (victim_table_entry // HEAD_KV) * stride_table_t +\
            0 * stride_table_k,
        value=MAX_INT,
        mask=mask_put & (victim_table_entry < MAX_INT),
    )

    # setup metadata
    tl.store(
        METADATA +\
            idx_evict.to(tl.int64) * stride_metadata_t +\
            0 * stride_metadata_k,
        mask=mask_put,
        value=idx_page * HEAD_KV + idx_head_kv
    )
    tl.store(
        METADATA +\
            idx_evict.to(tl.int64) * stride_metadata_t +\
            1 * stride_metadata_k,
        mask=mask_put,
        value=idx_page,
    )
    token_gen = tl.load(
        UVM_METADATA +\
            idx_page.to(tl.int64) * stride_uvm_metadata_t +\
            0 * stride_uvm_metadata_k,
        mask=mask_put,
    )
    tl.store(
        METADATA +\
            idx_evict.to(tl.int64) * stride_metadata_t +\
            2 * stride_metadata_k,
        mask=mask_put,
        value=token_gen,
    )
    current_tick = tl.load(
        GLOBAL_METADATA +\
            0 * stride_global_metadata_t +\
            0 * stride_global_metadata_k,
    )
    tl.store(
        METADATA +\
            idx_evict.to(tl.int64) * stride_metadata_t +\
            3 * stride_metadata_k,
        mask=mask_put,
        value=current_tick,
    )

    # setup table
    tl.store(
        TABLE +\
            idx_page.to(tl.int64) * stride_table_t +\
            idx_head_kv * stride_table_head_kv +\
            0 * stride_table_k,
        mask=mask_put,
        value=idx_evict,
    )

    # copy values
    idx_hid = tl.arange(0, BLOCK_HID)

    keys = tl.load(
        UVM_K_BANK +\
            idx_page[:, None].to(tl.int64) * stride_uvm_k_bank_t +\
            idx_head_kv[:, None] * stride_uvm_k_bank_head_kv +\
            idx_hid[None, :] * stride_uvm_k_bank_hid,
        mask=mask_put[:, None],
    )
    tl.store(
        BANK +\
            idx_evict[:, None].to(tl.int64) * stride_bank_t +\
            idx_hid[None, :] * stride_bank_hid,
        mask=mask_put[:, None],
        value=keys,
    )

    if KV_PACKED:
        values = tl.load(
            UVM_V_BANK +\
                idx_page[:, None].to(tl.int64) * stride_uvm_v_bank_t +\
                idx_head_kv[:, None] * stride_uvm_v_bank_head_kv +\
                idx_hid[None, :] * stride_uvm_v_bank_hid,
            mask=mask_put[:, None],
        )
        tl.store(
            BANK +\
                idx_evict[:, None].to(tl.int64) * stride_bank_t +\
                (idx_hid + BLOCK_HID)[None, :] * stride_bank_hid,
            mask=mask_put[:, None],
            value=values,
        )