import torch
import time
from backends import Gram2TokenGrammarObject, Pre3GrammarObject, FormatronGrammarObject
from implementation_core import BytePDA, TokenizerStub, Gram2TokenCompiler

def _build_g2t_tables(categories, token_to_cat, vocab_size):
    """Build the dense lookup tensors handed to ``Gram2TokenGrammarObject``.

    Args:
        categories: Mapping from a category signature to its integer category
            id. A signature is a sequence of ``(valid, next_state)`` pairs,
            indexed by state.
        token_to_cat: Indexable by token id, yielding that token's category id.
        vocab_size: Number of token ids to cover (``0 .. vocab_size - 1``).

    Returns:
        A ``(state_table, mask_table, t_to_c)`` tuple. ``state_table`` is a
        ``(num_states, num_cats)`` long tensor of next states, with ``-1``
        wherever the signature entry is not valid. ``mask_table`` is the
        matching bool tensor marking valid entries. ``t_to_c`` holds the
        category id of every token id.

    Raises:
        ValueError: If ``categories`` is empty, since the table dimensions
            cannot be derived.
    """
    if not categories:
        raise ValueError("cannot build Gram2Token tables: no token categories")

    # The table is sized from the first signature; this assumes every
    # signature has the same length (one entry per state).
    num_states = len(next(iter(categories)))
    num_cats = len(categories)

    state_table = torch.full((num_states, num_cats), -1, dtype=torch.long)
    mask_table = torch.zeros((num_states, num_cats), dtype=torch.bool)
    t_to_c = torch.tensor([token_to_cat[i] for i in range(vocab_size)])

    for sig, cat_id in categories.items():
        for s_idx, (valid, next_state) in enumerate(sig):
            if valid:
                state_table[s_idx, cat_id] = next_state
                mask_table[s_idx, cat_id] = True

    return state_table, mask_table, t_to_c


def sglang_simulation(backend_type="g2t", batch_size=64, num_steps=20):
    """Time a simulated SGLang decoding loop for one grammar backend.

    One grammar object is created per request in the batch. Each step fills
    the vocabulary mask for every request, applies the mask to the logits,
    and advances every grammar with a randomly drawn token. Total and
    per-step wall-clock times are printed.

    Args:
        backend_type: One of ``"g2t"``, ``"pre3"`` or ``"formatron"``.
        batch_size: Number of simulated requests; must be at least 1.
        num_steps: Number of decoding steps to time; must be at least 1.

    Raises:
        ValueError: If ``backend_type`` is not recognized, or if
            ``batch_size`` or ``num_steps`` is smaller than 1.
    """
    known_backends = ("g2t", "pre3", "formatron")
    # Validate before building anything, so bad arguments fail fast with a
    # clear message instead of an UnboundLocalError/IndexError further down.
    if backend_type not in known_backends:
        raise ValueError(
            f"unknown backend_type {backend_type!r}; expected one of {known_backends}"
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    tok = TokenizerStub()
    pda = BytePDA()

    # 1. Initialize Backend
    if backend_type == "g2t":
        compiler = Gram2TokenCompiler(pda, tok)
        token_to_cat, categories = compiler.compile()

        # Build tables (normally done on GPU initialization)
        state_table, mask_table, t_to_c = _build_g2t_tables(
            categories, token_to_cat, tok.vocab_size
        )

        # Every request gets its own grammar object; all of them receive the
        # same table tensors.
        grammars = [
            Gram2TokenGrammarObject(state_table, mask_table, t_to_c)
            for _ in range(batch_size)
        ]
        name = "Gram2Token (O(1) GPU Table)"

    elif backend_type == "pre3":
        grammars = [Pre3GrammarObject(pda, tok) for _ in range(batch_size)]
        name = "Pre3 (CPU/Trie Walk)"

    else:  # "formatron" -- the only remaining validated value
        grammars = [FormatronGrammarObject(pda, tok) for _ in range(batch_size)]
        name = "Formatron (Heavy Logic)"

    # 2. Simulate SGLang ModelRunner loop
    print(f"--- Running SGLang {name} simulation (BatchSize={batch_size}) ---")

    vocab_mask = torch.zeros((batch_size, tok.vocab_size), dtype=torch.bool)
    logits = torch.randn(batch_size, tok.vocab_size)

    # perf_counter is a monotonic, high-resolution clock meant for intervals.
    start = time.perf_counter()
    for _ in range(num_steps):
        # SGLang logic: fill mask for each request in batch
        for row, grammar in enumerate(grammars):
            grammar.fill_vocab_mask(vocab_mask, row)

        # Apply mask to logits (one call, made through the first grammar,
        # receives the whole batch's logits and mask).
        grammars[0].apply_vocab_mask(logits, vocab_mask)

        # Update states based on dummy sampled tokens; convert to Python
        # ints once per step rather than once per request.
        dummy_tokens = torch.randint(0, tok.vocab_size, (batch_size,)).tolist()
        for grammar, token in zip(grammars, dummy_tokens):
            grammar.update_state(token)

    elapsed = time.perf_counter() - start
    print(f"Total time for {num_steps} steps: {elapsed:.4f}s")
    print(f"Avg time per step: {elapsed / num_steps * 1000:.2f}ms\n")

if __name__ == "__main__":
    # Run every backend at a small and a large batch size, smallest batch
    # first, so the printed timings can be compared side by side.
    backend_kinds = ("g2t", "pre3", "formatron")
    batch_sizes = (16, 128)
    for batch in batch_sizes:
        for backend in backend_kinds:
            sglang_simulation(backend, batch)
