# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""GPU-native kernels for Gram2Token using Triton."""

import torch
import triton
import triton.language as tl

@triton.jit
def g2t_expand_mask_kernel(
    mask_table_ptr,      # [num_states, num_cats] (uint8/bool)
    token_to_cat_ptr,    # [vocab_size] (int32)
    output_mask_ptr,     # [vocab_size] (uint8) or packed
    state_id,
    vocab_size,
    num_cats,
    stride_mask_s,
    stride_mask_c,
    BLOCK_SIZE: tl.constexpr,
):
    """Expand one state's per-category mask into a per-token mask.

    Each program instance handles ``BLOCK_SIZE`` consecutive token ids and
    writes ``output_mask[t] = mask_table[state_id, token_to_cat[t]]`` for
    every token ``t < vocab_size``. Tokens whose category id falls outside
    ``[0, num_cats)`` are written as 0 instead of reading past the table.

    Must be compiled with ``triton.jit`` so it can be launched with the
    ``kernel[grid](...)`` syntax.

    Args:
        mask_table_ptr: Base pointer of the compressed mask table.
        token_to_cat_ptr: Base pointer of the token -> category id map;
            indexed with unit stride.
        output_mask_ptr: Base pointer of the output row; indexed with unit
            stride.
        state_id: Row of the mask table to expand.
        vocab_size: Number of tokens to process.
        num_cats: Number of columns in the mask table.
        stride_mask_s: Element stride between mask table rows.
        stride_mask_c: Element stride between mask table columns.
        BLOCK_SIZE: Number of tokens handled per program (compile-time).
    """
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # The last block may extend past the vocabulary; mask off the tail lanes.
    in_vocab = offsets < vocab_size

    # 1. Load category ID for these tokens
    cat_ids = tl.load(token_to_cat_ptr + offsets, mask=in_vocab, other=0)

    # Only dereference the table for category ids that are inside it; an
    # unmapped or corrupt id would otherwise read out of bounds.
    cat_in_range = in_vocab & (cat_ids >= 0) & (cat_ids < num_cats)

    # 2. Lookup validity in the compressed table
    # mask_table[state_id, cat_id]
    lookup_ptr = mask_table_ptr + state_id * stride_mask_s + cat_ids * stride_mask_c
    is_valid = tl.load(lookup_ptr, mask=cat_in_range, other=0)

    # 3. Store to output vocab mask
    tl.store(output_mask_ptr + offsets, is_valid, mask=in_vocab)

def fill_g2t_vocab_mask(vocab_mask, idx, state_id, token_to_cat, mask_table):
    """
    Expands compressed category mask to full vocab mask for a single request.

    Launches ``g2t_expand_mask_kernel`` over row ``idx`` of ``vocab_mask`` so
    that each token's entry is taken from
    ``mask_table[state_id, token_to_cat[token]]``. The row is overwritten in
    place; nothing is returned.

    Args:
        vocab_mask: [batch, vocab_size] (uint8)
        idx: request index in batch
        state_id: current PDA state for this request
        token_to_cat: [vocab_size] (int32) map from token id to category id
        mask_table: [num_states, num_cats] (uint8/bool) compressed mask table

    Raises:
        ValueError: If ``state_id`` is not a valid row of ``mask_table``, if
            ``token_to_cat`` has fewer than ``vocab_size`` entries, or if
            the output row or ``token_to_cat`` is not contiguous.
    """
    state = int(state_id)
    num_states = mask_table.shape[0]
    num_cats = mask_table.shape[1]
    vocab_size = vocab_mask.shape[1]

    # The kernel does raw pointer arithmetic, so anything out of range here
    # becomes a silent out-of-bounds GPU access rather than a Python error.
    if not 0 <= state < num_states:
        raise ValueError(
            f"state_id {state} is out of range for mask_table with "
            f"{num_states} states"
        )
    if token_to_cat.numel() < vocab_size:
        raise ValueError(
            f"token_to_cat has {token_to_cat.numel()} entries but "
            f"vocab_size is {vocab_size}"
        )

    out_row = vocab_mask[idx]
    # The kernel indexes both buffers as base pointer + token offset.
    if not out_row.is_contiguous():
        raise ValueError("vocab_mask rows must be contiguous")
    if not token_to_cat.is_contiguous():
        raise ValueError("token_to_cat must be contiguous")

    if vocab_size == 0:
        # Nothing to expand.
        return

    BLOCK_SIZE = 1024
    grid = ((vocab_size + BLOCK_SIZE - 1) // BLOCK_SIZE, )

    g2t_expand_mask_kernel[grid](
        mask_table,
        token_to_cat,
        out_row,
        state,
        vocab_size,
        num_cats,
        mask_table.stride(0),
        mask_table.stride(1),
        BLOCK_SIZE=BLOCK_SIZE,
    )
