import torch
import time
from implementation_core import BytePDA, TokenizerStub

class BaseLogitsProcessor:
    """Common base class for the logits processors defined in this module."""

    def process_logits(self, batch_ids, logits, state):
        """Default hook: ignores every argument and returns ``None``."""
        return None

class Gram2TokenProcessor(BaseLogitsProcessor):
    """Tracks grammar states with precomputed (state, category) lookup tables.

    Every token id is assigned a category, and every category carries one
    ``(valid, next_state)`` entry per grammar state.  Advancing a batch is
    then two tensor lookups: token -> category, followed by
    (state, category) -> next state.

    Attributes:
        token_to_cat: 1-D long tensor holding the category id of each token id.
        state_table: ``(num_states, num_cats)`` long tensor of next states;
            ``DEAD_STATE`` wherever the transition is not valid.
        mask_table: ``(num_states, num_cats)`` bool tensor, True wherever the
            transition is valid.
    """

    # Fill value of ``state_table`` meaning "no valid transition".
    DEAD_STATE = -1

    def __init__(self, token_to_cat, categories, vocab_size, device='cpu'):
        """Build the lookup tables.

        Args:
            token_to_cat: Indexable by token id ``0 .. vocab_size - 1``,
                giving that token's category id.
            categories: Mapping from a category signature to its category id.
                A signature is a sequence with one ``(valid, next_state)``
                pair per grammar state.  Category ids are used as column
                indices, so they must lie in ``0 .. len(categories) - 1``.
            vocab_size: Number of token ids to tabulate.
            device: Device on which all three tables are placed.  Defaults to
                ``'cpu'``; the tensors passed to :meth:`step` must be
                compatible with it.
        """
        # Explicit dtype so an empty vocabulary still yields an index tensor
        # (torch infers float32 from an empty list).
        self.token_to_cat = torch.tensor(
            [token_to_cat[i] for i in range(vocab_size)],
            dtype=torch.long,
            device=device,
        )

        # Size by the longest signature; an empty mapping gives empty tables
        # instead of an IndexError.
        num_states = max((len(sig) for sig in categories), default=0)
        num_cats = len(categories)

        # Fill on the CPU (element-wise writes), then move once to `device`.
        state_table = torch.full(
            (num_states, num_cats), self.DEAD_STATE, dtype=torch.long, device='cpu'
        )
        mask_table = torch.zeros((num_states, num_cats), dtype=torch.bool, device='cpu')

        for sig, cat_id in categories.items():
            for state_idx, (valid, next_state) in enumerate(sig):
                if valid:
                    state_table[state_idx, cat_id] = next_state
                    mask_table[state_idx, cat_id] = True

        self.state_table = state_table.to(device)
        self.mask_table = mask_table.to(device)

    def step(self, current_states, last_tokens, logits):
        """Advance every sequence by the token it just emitted.

        Args:
            current_states: Integer tensor of state ids, one per sequence.
                Negative entries are treated as already rejected.
            last_tokens: Integer tensor of token ids, one per sequence.
            logits: Accepted for interface compatibility; it is neither read
                nor modified.  Masking (setting a token's logit to ``-inf``
                when ``mask_table[next_state, token_to_cat[token]]`` is
                False) is NOT applied by this method.

        Returns:
            Long tensor of next state ids, ``DEAD_STATE`` for sequences whose
            transition is invalid or that were already rejected.
        """
        current_states = torch.as_tensor(current_states, device=self.state_table.device)

        # 1. Map tokens to categories.
        cats = self.token_to_cat[last_tokens]

        # 2. Look up the next state.  A negative state used directly as a row
        #    index would wrap around to the last state's row and could revive
        #    a rejected sequence, so dead rows are looked up at row 0 and then
        #    overwritten with DEAD_STATE again.
        dead = current_states < 0
        rows = current_states.clamp(min=0)
        next_states = self.state_table[rows, cats]
        return next_states.masked_fill(dead, self.DEAD_STATE)

class Pre3Processor(BaseLogitsProcessor):
    """Simulates a deterministic PDA walk for each token.

    Each sequence is advanced by feeding the bytes of its last token to
    ``pda.step`` one at a time.  A full scan over the vocabulary is then
    performed per live sequence to stand in for the cost of validating every
    candidate next token; that scan computes nothing.
    """

    def __init__(self, pda, tokenizer):
        """Store the automaton and tokenizer.

        Args:
            pda: Object exposing ``step(state, byte, 0)`` that returns a
                2-tuple whose first item is the next state.
            tokenizer: Object exposing ``id_to_bytes`` (indexable by token
                id) and ``vocab_size``.
        """
        self.pda = pda
        self.tokenizer = tokenizer

    def step(self, current_states, last_tokens, logits):
        """Advance every sequence by the bytes of its last token.

        Args:
            current_states: Iterable of scalar tensors holding state ids.
            last_tokens: Iterable of scalar tensors holding token ids.
            logits: Unused.

        Returns:
            Long tensor with one state id per sequence.
        """
        # Update current state based on last token.
        new_states = []
        for state, token_id in zip(current_states, last_tokens):
            curr = state.item()
            for b in self.tokenizer.id_to_bytes[token_id.item()]:
                # NOTE(review): the second return value is discarded and a
                # constant 0 is passed as the third argument on every byte -
                # confirm this is what BytePDA expects.
                curr, _ = self.pda.step(curr, b, 0)
                if curr is None:
                    # Presumably "no valid transition": stop walking instead
                    # of feeding None back into the automaton.
                    break
            new_states.append(curr)

        # Validation for NEXT step (simulating bottleneck: linear scan of vocab).
        for state in new_states:
            if state is None:
                continue
            for tid in range(self.tokenizer.vocab_size):
                # Per-token byte lookup stands in for the expensive check;
                # the result is intentionally discarded.
                _ = self.tokenizer.id_to_bytes[tid]

        # NOTE(review): a None state is reported as 0, which cannot be told
        # apart from a genuine state id 0 - confirm this is intended.
        # Explicit dtype keeps an empty batch usable as an index tensor.
        return torch.tensor(
            [0 if s is None else s for s in new_states], dtype=torch.long
        )

class FormatronProcessor(BaseLogitsProcessor):
    """Cost model standing in for Formatron-style pruning/masking.

    No grammar work is performed: states are passed through unchanged while
    fixed-size loops simulate the per-sequence and per-vocabulary-item
    overhead.
    """

    def __init__(self, pda, tokenizer):
        """Store the automaton and tokenizer.

        Args:
            pda: Stored but not used by :meth:`step`.
            tokenizer: Object exposing ``vocab_size``.
        """
        self.pda = pda
        self.tokenizer = tokenizer

    def step(self, current_states, last_tokens, logits):
        """Return the incoming states unchanged after simulated overhead.

        Args:
            current_states: Iterable of state ids (scalar tensors or ints).
            last_tokens: Iterable paired with ``current_states``; only its
                length matters (the shorter of the two bounds the batch).
            logits: Unused.

        Returns:
            Long tensor holding the same state ids that were passed in.
        """
        new_states = []
        for state, _token in zip(current_states, last_tokens):
            # Formatron involves more complex DFA/PDA intersections;
            # this empty loop simulates that extra validation overhead.
            for _ in range(5):
                pass
            new_states.append(state)

        # Validation for NEXT step (very heavy bitmask generation).
        for _state in new_states:
            for tid in range(self.tokenizer.vocab_size):
                # Simulated bitwise logic per vocabulary item.
                _ = (1 << (tid % 64)) & 0xFFFFFFFFFFFFFFFF

        # Explicit dtype: without it an empty batch comes back as float32,
        # which cannot be used as a state index.
        return torch.tensor(
            [s.item() if isinstance(s, torch.Tensor) else s for s in new_states],
            dtype=torch.long,
        )

def benchmark():
    """Time Gram2Token against Pre3 on a fixed synthetic batch.

    Builds the tokenizer stub, the byte PDA and the compiled category tables,
    then runs 1000 Gram2Token steps and 100 Pre3 steps on a batch of 64
    sequences (all-zero initial states and tokens) and prints the elapsed
    time of each run.
    """
    tok = TokenizerStub()
    pda = BytePDA()
    from implementation_core import Gram2TokenCompiler
    compiler = Gram2TokenCompiler(pda, tok)
    token_to_cat, categories = compiler.compile()

    g2t = Gram2TokenProcessor(token_to_cat, categories, tok.vocab_size)
    pre3 = Pre3Processor(pda, tok)

    batch_size = 64
    tokens = torch.zeros(batch_size, dtype=torch.long)
    logits = torch.randn(batch_size, tok.vocab_size)

    print(f"Benchmarking with batch size {batch_size}...")

    # G2T benchmark.  perf_counter() is the clock meant for measuring short
    # intervals; time.time() can jump with system clock adjustments.
    states = torch.zeros(batch_size, dtype=torch.long)
    start = time.perf_counter()
    for _ in range(1000):
        states = g2t.step(states, tokens, logits)
    elapsed = time.perf_counter() - start
    print(f"Gram2Token 1000 steps time: {elapsed:.4f}s (O(1) native lookup)")

    # Pre3 benchmark.  Start again from the all-zero batch so Pre3 is not fed
    # whatever states the Gram2Token loop ended in (those can include the -1
    # "no transition" fill value of its state table).
    states = torch.zeros(batch_size, dtype=torch.long)
    start = time.perf_counter()
    for _ in range(100):  # fewer steps for pre3 as it's slow
        states = pre3.step(states, tokens, logits)
    elapsed = time.perf_counter() - start
    print(f"Pre3 100 steps time: {elapsed:.4f}s (CPU-heavy trie walks)")

# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    benchmark()
