# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import logging
from typing import List, Optional, Tuple, Union

import torch

from sglang.srt.constrained.base_grammar_backend import (
    BaseGrammarBackend,
    BaseGrammarObject,
)
from sglang.srt.constrained.gram2token_kernels import fill_g2t_vocab_mask, apply_g2t_mask

logger = logging.getLogger(__name__)

class Gram2TokenGrammar(BaseGrammarObject):
    """Table-driven grammar matcher that tracks one state id per batch row.

    A token is first mapped to a category through ``token_to_cat``; the
    pair (current state, category) then selects the next state from
    ``state_table``. ``mask_table`` is handed to the masking kernels
    together with the current state ids.
    """

    def __init__(
        self,
        state_table: torch.Tensor,
        mask_table: torch.Tensor,
        token_to_cat: torch.Tensor,
        batch_state_ids: Optional[torch.Tensor] = None,
        vocab_size: int = 128256,
        batch_size: int = 1,
    ) -> None:
        """Store the lookup tables and set up the per-row state vector.

        Args:
            state_table: Transition table, indexed ``[state, category]``.
            mask_table: Table with the same layout as ``state_table``,
                consumed by the masking kernels.
            token_to_cat: Per-token category ids, indexed by token id.
            batch_state_ids: Existing state vector to adopt as-is (it is
                NOT copied). When ``None``, a zero-initialised int32
                vector of length ``batch_size`` is created on
                ``state_table``'s device.
            vocab_size: Vocabulary size recorded on the object.
            batch_size: Number of rows tracked by ``state_ids``.
        """
        # Let the base class set up whatever bookkeeping it owns before this
        # object assigns its own attributes.
        super().__init__()

        self.state_table = state_table  # [num_states, num_cats]
        self.mask_table = mask_table  # [num_states, num_cats]
        self.token_to_cat = token_to_cat  # [vocab_size]
        self.vocab_size = vocab_size
        self.batch_size = batch_size

        if batch_state_ids is not None:
            self.state_ids = batch_state_ids
        else:
            self.state_ids = torch.zeros(
                (batch_size,), dtype=torch.int32, device=state_table.device
            )

        self.finished = False

    def accept_token(self, token: Union[int, torch.Tensor], idx: int = 0) -> None:
        """Advance the matcher by one token.

        Args:
            token: A Python ``int`` advances only row ``idx``. Anything
                else is used as an index tensor holding one token per
                row, and every row advances by its own token.
            idx: Row to update when ``token`` is an ``int``; ignored
                otherwise.
        """
        # NOTE(review): a negative entry read from ``state_table`` (the backend
        # in this file fills its table with -1) is stored as the new state and
        # later used as an index, where PyTorch counts it from the end of the
        # dimension. Confirm how invalid transitions are meant to be handled.
        if isinstance(token, int):
            cat = self.token_to_cat[token]
            self.state_ids[idx] = self.state_table[self.state_ids[idx], cat]
        else:
            # Batch update: element i of `token` drives row i of `state_ids`.
            cats = self.token_to_cat[token]
            self.state_ids = self.state_table[self.state_ids, cats]

    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        """Always return ``None``; this object never proposes a jump-forward."""
        return None

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        """Return the fixed pair ``("", -1)``; ``helper`` is not consulted."""
        return "", -1

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        """Reset to state 0 and replay ``new_output_ids`` in order.

        ``old_output_ids`` and ``next_state`` are not consulted. Every row
        of ``state_ids`` is reset and then advanced by the same token
        sequence, one token at a time. The update is done in place so the
        shape and dtype of ``state_ids`` are preserved.

        Args:
            old_output_ids: Previous output tokens (unused).
            new_output_ids: Tokens to replay from the initial state. An
                empty list simply leaves the matcher at state 0.
            next_state: Unused.
        """
        self.state_ids.fill_(0)
        # The tokens form a sequence, so each transition has to start from the
        # state the previous token produced. Feeding them all at once through
        # the batch path would take a single step from state 0 for every token
        # in parallel and resize `state_ids` to the sequence length.
        for token_id in new_output_ids:
            category = self.token_to_cat[token_id]
            self.state_ids.copy_(self.state_table[self.state_ids, category])

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        """Return an all-``False`` bool tensor of shape ``(batch_size, vocab_size)``."""
        return torch.zeros((batch_size, vocab_size), dtype=torch.bool, device=device)

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        """Delegate filling of ``vocab_mask`` for row ``idx`` to the kernel helper.

        The helper receives the mask, the row index, this object's state id
        for that row and both lookup tables.
        """
        # NOTE(review): presumably the helper writes row `idx` of `vocab_mask`
        # in place; confirm against fill_g2t_vocab_mask.
        fill_g2t_vocab_mask(
            vocab_mask=vocab_mask,
            idx=idx,
            state_id=self.state_ids[idx],
            token_to_cat=self.token_to_cat,
            mask_table=self.mask_table,
        )

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        """Return ``vocab_mask`` transferred to ``device`` with ``non_blocking=True``."""
        return vocab_mask.to(device, non_blocking=True)

    @staticmethod
    def apply_vocab_mask(
        logits: torch.Tensor,
        vocab_mask: torch.Tensor,
        grammar_objs: List["Gram2TokenGrammar"],
    ) -> None:
        """Run the masking kernel on ``logits`` using the grammars' current states.

        Args:
            logits: Tensor passed to the kernel helper.
            vocab_mask: Accepted but not read by this method.
            grammar_objs: Grammars supplying the state ids, one per entry.
                Only the first element of each grammar's ``state_ids`` is
                used, and the lookup tables are taken from the first
                grammar only. An empty list is a no-op.
        """
        if not grammar_objs:
            # Nothing to constrain; indexing element 0 below would fail.
            return

        # NOTE(review): the tables of grammar_objs[0] are applied to every
        # entry, which assumes all grammars share the same tables; confirm.
        obj = grammar_objs[0]
        state_ids = torch.stack(
            [g.state_ids[0] if g.state_ids.dim() > 0 else g.state_ids for g in grammar_objs]
        ).to(torch.int32)

        apply_g2t_mask(
            logits=logits,
            token_to_cat=obj.token_to_cat,
            mask_table=obj.mask_table,
            state_ids=state_ids,
        )

    def copy(self):
        """Return a new grammar sharing the tables but owning a cloned state vector."""
        return Gram2TokenGrammar(
            state_table=self.state_table,
            mask_table=self.mask_table,
            token_to_cat=self.token_to_cat,
            batch_state_ids=self.state_ids.clone(),
            vocab_size=self.vocab_size,
            batch_size=self.batch_size,
        )
class Gram2TokenGrammarBackend(BaseGrammarBackend):
    """Grammar backend that hands out ``Gram2TokenGrammar`` objects.

    All three dispatch entry points (JSON, EBNF, regex) funnel into
    ``_compile_to_g2t``.
    """

    def __init__(self, tokenizer, vocab_size: int):
        """Record the tokenizer and the vocabulary size used to size the tables.

        Args:
            tokenizer: Stored on the instance; not otherwise used by the
                methods shown here.
            vocab_size: Length of the ``token_to_cat`` table built for each
                grammar, and the vocabulary size reported by the grammar.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.vocab_size = vocab_size
        # Metadata storage for compiled grammars.
        # NOTE(review): this dict is initialised here but not read or written
        # by any method of this class as shown.
        self.compiled_cache = {}

    def _compile_to_g2t(self, key_string: str, key_type: str) -> Gram2TokenGrammar:
        """Build the token-level tables for a grammar and wrap them in a grammar object.

        The intended pipeline is:
          1. Convert ``key_string`` to a byte-level PDA.
          2. Align the tokenizer with the PDA.
          3. Categorize tokens.

        None of those steps is implemented in this method yet: it returns
        fixed-size placeholder tables, and neither ``key_string`` nor
        ``key_type`` influences the result.

        Args:
            key_string: Raw grammar text (currently unused).
            key_type: One of ``"json"``, ``"ebnf"`` or ``"regex"`` as passed
                by the dispatch methods (currently unused).

        Returns:
            A fresh ``Gram2TokenGrammar`` over the placeholder tables.
        """
        # Placeholder table dimensions for the integration demo.
        num_states = 100
        num_cats = 50
        state_table = torch.full((num_states, num_cats), -1, dtype=torch.int32)
        mask_table = torch.zeros((num_states, num_cats), dtype=torch.bool)
        token_to_cat = torch.zeros((self.vocab_size,), dtype=torch.int32)

        # Pass the configured vocabulary size through so the grammar's
        # `vocab_size` matches the length of `token_to_cat` instead of
        # silently falling back to the constructor default.
        return Gram2TokenGrammar(
            state_table,
            mask_table,
            token_to_cat,
            vocab_size=self.vocab_size,
        )

    def dispatch_json(self, key_string: str) -> Optional[Gram2TokenGrammar]:
        """Return a grammar for the JSON specification ``key_string``."""
        return self._compile_to_g2t(key_string, "json")

    def dispatch_ebnf(self, key_string: str) -> Optional[Gram2TokenGrammar]:
        """Return a grammar for the EBNF specification ``key_string``."""
        return self._compile_to_g2t(key_string, "ebnf")

    def dispatch_regex(self, key_string: str) -> Optional[Gram2TokenGrammar]:
        """Return a grammar for the regular expression ``key_string``."""
        return self._compile_to_g2t(key_string, "regex")
