# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Constrained decoding with Pre3 (Predictive Pertaining) baseline backend."""

import logging
from typing import List, Optional, Tuple, Union

import torch
from sglang.srt.constrained.base_grammar_backend import (
    BaseGrammarBackend,
    BaseGrammarObject,
)

logger = logging.getLogger(__name__)

class Pre3Grammar(BaseGrammarObject):
    """Placeholder grammar object for the Pre3 baseline backend.

    This object does not actually constrain generation. It only counts the
    tokens it has accepted in ``current_state`` and leaves every vocabulary
    entry allowed.

    Vocab-mask convention (fixed by ``apply_vocab_mask`` below): ``True``
    marks a token as allowed, ``False`` marks it as banned (its logit is set
    to ``-inf``).
    """

    def __init__(self, tokenizer, current_state: int = 0):
        self.tokenizer = tokenizer
        # Number of tokens consumed so far; stands in for a real automaton state.
        self.current_state = current_state
        self.finished = False

    def accept_token(self, token: int) -> None:
        """Advance the simulated state by one accepted token.

        The token value itself is not inspected. A real Pre3 implementation
        would step its pushdown automaton here.
        """
        self.current_state += 1

    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        """Return ``None``: this baseline never offers a jump-forward."""
        return None

    def jump_forward_str_state(
        self, helper: Tuple[List[int], str]
    ) -> Tuple[str, int]:
        """Return an empty jump string and the sentinel state ``-1``.

        ``try_jump_forward`` always returns ``None``, so this is not expected
        to be reached. It exists only to satisfy the grammar-object interface.
        """
        return "", -1

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ) -> None:
        """Resynchronise the token counter with the retokenized output.

        ``old_output_ids`` and ``next_state`` are accepted for interface
        compatibility but are not used.
        """
        self.current_state = len(new_output_ids)

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        """Allocate a ``(batch_size, vocab_size)`` boolean mask, all ``True``.

        The mask starts permissive because ``apply_vocab_mask`` bans every
        ``False`` entry. An all-``False`` default would set every logit in
        any row that ``fill_vocab_mask`` does not later open up to ``-inf``
        (presumably rows for requests without a grammar; confirm against
        the caller).
        """
        return torch.ones((batch_size, vocab_size), dtype=torch.bool, device=device)

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        """Mark every token in row ``idx`` of ``vocab_mask`` as allowed."""
        # NOTE(review): a real Pre3 backend would clear the entries of tokens
        # the grammar rejects at ``current_state``. This baseline applies no
        # constraint, so the whole row stays allowed.
        vocab_mask[idx].fill_(True)

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        """Move ``vocab_mask`` to ``device`` using a non-blocking copy."""
        return vocab_mask.to(device, non_blocking=True)

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        """Set ``logits`` to ``-inf`` in place wherever ``vocab_mask`` is ``False``."""
        logits.masked_fill_(~vocab_mask, -float("inf"))

    def copy(self) -> "Pre3Grammar":
        """Return an independent grammar carrying the same state and flag."""
        clone = Pre3Grammar(self.tokenizer, self.current_state)
        # Preserve completion status so a copy is a faithful snapshot.
        clone.finished = self.finished
        return clone

class Pre3GrammarBackend(BaseGrammarBackend):
    """Grammar backend that hands out Pre3 baseline grammar objects.

    Each dispatch method ignores the grammar specification passed in as
    ``key_string``. Each returns a brand-new :class:`Pre3Grammar` bound to
    this backend's tokenizer.
    """

    def __init__(self, tokenizer, vocab_size: int):
        super().__init__()
        self.tokenizer = tokenizer
        # Kept on the instance; not read anywhere in this class.
        self.vocab_size = vocab_size

    def _new_grammar(self) -> Pre3Grammar:
        """Construct a fresh grammar object using the default initial state."""
        return Pre3Grammar(self.tokenizer)

    def dispatch_json(self, key_string: str) -> Optional[Pre3Grammar]:
        """Build a grammar for a JSON-schema request (spec is ignored)."""
        return self._new_grammar()

    def dispatch_ebnf(self, key_string: str) -> Optional[Pre3Grammar]:
        """Build a grammar for an EBNF request (spec is ignored)."""
        return self._new_grammar()

    def dispatch_regex(self, key_string: str) -> Optional[Pre3Grammar]:
        """Build a grammar for a regex request (spec is ignored)."""
        return self._new_grammar()
