import torch
import numpy as np
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Dict

# BytePDA and TokenizerStub are reused from the earlier implementation_core module
from implementation_core import BytePDA, TokenizerStub

# --- Mocking SGLang structures if needed ---
# In a real setup, we'd import from sglang.srt.constrained.base_grammar_backend
# Here we define the interface to maintain compatibility.

class BaseGrammarObject(ABC):
    """Local stand-in for the SGLang grammar-object interface.

    Declares the three mask-related hooks every grammar object in this file
    implements. All of them are abstract, so this class cannot be
    instantiated directly.
    """

    @abstractmethod
    def allocate_vocab_mask(self, vocab_size: int, batch_size: int, device) -> torch.Tensor:
        """Create the mask tensor for a batch (layout is up to the subclass)."""
        return None

    @abstractmethod
    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        """Write this object's mask into ``vocab_mask`` for batch entry ``idx``."""
        return None

    @staticmethod
    @abstractmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        """Apply ``vocab_mask`` to ``logits``."""
        return None

# --- Gram2Token SGLang Implementation ---

class Gram2TokenGrammarObject(BaseGrammarObject):
    """Table-driven grammar object.

    Every token id belongs to a category (``token_to_cat``). ``state_table``
    gives the next state for a (state, category) pair, and ``mask_table``
    gives the validity of each category in each state.
    """

    def __init__(self, state_table, mask_table, token_to_cat, current_state=0):
        self.state_table = state_table  # [states, cats]
        self.mask_table = mask_table  # [states, cats]
        self.token_to_cat = token_to_cat  # [vocab]
        self.current_state = current_state

    def update_state(self, last_token_id: int):
        """Advance ``current_state`` using the category of ``last_token_id``."""
        cat = self.token_to_cat[last_token_id]
        self.current_state = self.state_table[self.current_state, cat]

    def allocate_vocab_mask(self, vocab_size: int, batch_size: int, device) -> torch.Tensor:
        """Return an all-False bool mask of shape ``(batch_size, vocab_size)``."""
        # G2T doesn't need to rebuild the mask per batch, but follows the API.
        return torch.zeros((batch_size, vocab_size), dtype=torch.bool, device=device)

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        """Write the per-token validity of the current state into row ``idx``.

        The validity row of ``mask_table`` for ``current_state`` is indexed by
        category; it is expanded to one entry per token id through
        ``token_to_cat``. True means the token stays allowed, matching
        ``apply_vocab_mask``, which blocks every position where the mask is
        False.

        Args:
            vocab_mask: Mask tensor as produced by ``allocate_vocab_mask``.
            idx: Batch row to overwrite.
        """
        device = vocab_mask.device
        # as_tensor accepts tensors, numpy arrays and nested lists alike.
        state_validity = torch.as_tensor(self.mask_table[self.current_state]).to(
            device=device, dtype=torch.bool
        )
        token_to_cat = torch.as_tensor(self.token_to_cat).to(
            device=device, dtype=torch.long
        )
        allowed = state_validity[token_to_cat]  # [vocab]

        row = vocab_mask[idx]
        # The mask may be wider than the mapping (e.g. a padded vocab size):
        # token ids without a category stay blocked.
        width = min(allowed.numel(), row.shape[-1])
        row[:width] = allowed[:width]
        row[width:] = False

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        """Set ``logits`` to -inf, in place, wherever ``vocab_mask`` is False."""
        logits.masked_fill_(~vocab_mask, -float("inf"))

    def copy(self):
        """Return a new object sharing the same tables, at the same state."""
        return Gram2TokenGrammarObject(self.state_table, self.mask_table, self.token_to_cat, self.current_state)

# --- Pre3 SGLang Implementation (Baseline) ---

class Pre3GrammarObject(BaseGrammarObject):
    """Baseline that walks the byte-level PDA at runtime for every token."""

    def __init__(self, pda, tokenizer, current_state=0):
        self.pda = pda
        self.tokenizer = tokenizer
        self.current_state = current_state

    def _advance(self, state, token_id: int):
        """Feed the bytes of ``token_id`` to the PDA, starting from ``state``.

        Returns the state reached after the last byte, or None as soon as
        ``pda.step`` yields a None state, so a rejected state is never fed
        back into the PDA.
        """
        # NOTE(review): the third argument to pda.step is always 0 and the
        # second returned value is discarded, exactly as in the original
        # update_state; confirm this is the intended use of the PDA.
        for byte in self.tokenizer.id_to_bytes[token_id]:
            state, _ = self.pda.step(state, byte, 0)
            if state is None:
                return None
        return state

    def update_state(self, last_token_id: int):
        """Advance by the bytes of ``last_token_id``; fall back to state 0 on rejection."""
        next_state = self._advance(self.current_state, last_token_id)
        self.current_state = next_state if next_state is not None else 0

    def allocate_vocab_mask(self, vocab_size: int, batch_size: int, device) -> torch.Tensor:
        """Return an all-False bool mask of shape ``(batch_size, vocab_size)``."""
        return torch.zeros((batch_size, vocab_size), dtype=torch.bool, device=device)

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        """Mark in row ``idx`` every token whose bytes the PDA accepts from the current state.

        PRE3 iterates over the whole vocabulary at runtime. True means the
        token stays allowed, matching ``apply_vocab_mask``, which blocks every
        position where the mask is False.

        Args:
            vocab_mask: Mask tensor as produced by ``allocate_vocab_mask``.
            idx: Batch row to overwrite.
        """
        start = self.current_state
        allowed = [
            self._advance(start, token_id) is not None
            for token_id in range(len(self.tokenizer.vocab))
        ]

        row = vocab_mask[idx]
        # The mask may be wider than the tokenizer vocab (padding): ids the
        # tokenizer does not know stay blocked.
        width = min(len(allowed), row.shape[-1])
        # One bulk assignment instead of a per-token tensor write.
        row[:width] = torch.tensor(allowed[:width], dtype=torch.bool, device=row.device)
        row[width:] = False

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        """Set ``logits`` to -inf, in place, wherever ``vocab_mask`` is False."""
        logits.masked_fill_(~vocab_mask, -float("inf"))

    def copy(self):
        """Return a new object sharing the PDA and tokenizer, at the same state."""
        return Pre3GrammarObject(self.pda, self.tokenizer, self.current_state)

# --- Formatron SGLang Implementation (Baseline) ---

class FormatronGrammarObject(BaseGrammarObject):
    """Formatron baseline; state update and mask filling are placeholders.

    ``pda`` and ``tokenizer`` are stored and handed on by ``copy`` but are not
    consulted by any method of this class.
    """

    def __init__(self, pda, tokenizer, current_state=0):
        self.pda, self.tokenizer = pda, tokenizer
        self.current_state = current_state

    def update_state(self, last_token_id: int):
        """Placeholder transition: the new state is ``last_token_id`` modulo 7."""
        remainder = last_token_id % 7  # placeholder
        self.current_state = remainder

    def allocate_vocab_mask(self, vocab_size: int, batch_size: int, device) -> torch.Tensor:
        """Return an all-False bool mask of shape ``(batch_size, vocab_size)``."""
        shape = (batch_size, vocab_size)
        return torch.zeros(shape, dtype=torch.bool, device=device)

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        """Placeholder: spins through 100 empty iterations; ``vocab_mask`` is not modified."""
        # Formatron: stands in for even heavier logic intersections.
        spins = 0
        while spins < 100:  # Simulated complexity
            spins += 1

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        """Set ``logits`` to -inf, in place, wherever ``vocab_mask`` is False."""
        blocked = ~vocab_mask
        logits.masked_fill_(blocked, -float("inf"))

    def copy(self):
        """Return a new object sharing the PDA and tokenizer, at the same state."""
        return FormatronGrammarObject(
            pda=self.pda,
            tokenizer=self.tokenizer,
            current_state=self.current_state,
        )
