import torch
import time
from implementation_core import BytePDA, TokenizerStub, Gram2TokenCompiler
from implementation_processors import Gram2TokenProcessor, Pre3Processor, FormatronProcessor

def _time_steps(processor, states, tokens, logits, num_steps):
    """Return the seconds spent on ``num_steps`` calls to ``processor.step``.

    Each call receives the states returned by the previous call together
    with the same ``tokens`` and ``logits`` objects.
    """
    # perf_counter is monotonic and high resolution; time.time() can tick too
    # coarsely to register a short loop and may jump if the clock is adjusted.
    start = time.perf_counter()
    for _ in range(num_steps):
        states = processor.step(states, tokens, logits)
    return time.perf_counter() - start


def _speedup(baseline_time, reference_time):
    """Return ``baseline_time / reference_time``, or ``inf`` if the divisor is not positive."""
    if reference_time <= 0:
        # A zero measurement would otherwise abort the whole benchmark with
        # ZeroDivisionError just before the results are printed.
        return float("inf")
    return baseline_time / reference_time


def run_replication_test(batch_sizes=(16, 64, 256), num_steps: int = 50) -> None:
    """Benchmark the Gram2Token processor against Pre3 and Formatron.

    The Gram2Token tables are compiled once up front and that preprocessing
    time is printed. Then, for every batch size, each processor's ``step``
    is called ``num_steps`` times on the same randomly generated tokens and
    logits, and the elapsed times plus the Gram2Token speedup ratios are
    printed. Inputs are synthetic; no real serving runtime is involved.

    Args:
        batch_sizes: Batch sizes to benchmark, in order.
        num_steps: Number of ``step`` calls timed per processor and batch size.
    """
    tok = TokenizerStub()
    pda = BytePDA()

    # 1. Preprocessing: compile the grammar tables once, before any decoding.
    print("--- [Preprocessing Stage] ---")
    start_pre = time.perf_counter()
    compiler = Gram2TokenCompiler(pda, tok)
    token_to_cat, categories = compiler.compile()
    g2t = Gram2TokenProcessor(token_to_cat, categories, tok.vocab_size)
    print(f"Gram2Token Preprocessing Done: {time.perf_counter() - start_pre:.4f}s")

    # 2. Runtime decoding: repeated logits-processing steps per processor.
    print("\n--- [Runtime Decoding Stage] ---")
    pre3 = Pre3Processor(pda, tok)
    formatron = FormatronProcessor(pda, tok)

    # Sweep batch sizes so the printed numbers show how each processor scales.
    # NOTE(review): the original comments expected near-constant cost for
    # Gram2Token and linear growth for the baselines -- confirm against output.
    for b_size in batch_sizes:
        # The same random inputs are shared by all three processors.
        tokens = torch.randint(0, tok.vocab_size, (b_size,))
        logits = torch.randn(b_size, tok.vocab_size)

        # Every processor starts from its own all-zero state vector.
        g2t_time = _time_steps(
            g2t, torch.zeros(b_size, dtype=torch.long), tokens, logits, num_steps
        )
        pre3_time = _time_steps(
            pre3, torch.zeros(b_size, dtype=torch.long), tokens, logits, num_steps
        )
        fmt_time = _time_steps(
            formatron, torch.zeros(b_size, dtype=torch.long), tokens, logits, num_steps
        )

        print(f"Batch {b_size:3} | G2T: {g2t_time:.5f}s | Pre3: {pre3_time:.5f}s | Formatron: {fmt_time:.5f}s")
        print(f"          | Speedup (vs Pre3): {_speedup(pre3_time, g2t_time):.1f}x | Speedup (vs Formatron): {_speedup(fmt_time, g2t_time):.1f}x")

# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_replication_test()
