import torch

class BytePDA:
    """A simple byte-level Pushdown Automaton for JSON-like structures.

    Accepts a flat object of quoted keys and quoted values, e.g.
    ``{"k":"v","kk":"vvv"}``. Values must start with ``"``, so objects
    cannot nest and the stack depth only moves between 0 and 1.

    States:
        0 = start (expects ``{``)
        1 = expects the opening quote of a key
        2 = inside a key
        3 = after a key (expects ``:``)
        4 = after the colon (expects the opening quote of a value)
        5 = inside a value
        6 = after a value (expects ``,`` or ``}``)
        7 = end (object closed; no outgoing transitions)
    """

    def __init__(self, key_chars=b"k", value_chars=b"v"):
        """Build the transition table.

        Args:
            key_chars: bytes-like (or iterable of ints) accepted inside a key.
                Defaults to ``b"k"``.
            value_chars: bytes-like (or iterable of ints) accepted inside a
                value. Defaults to ``b"v"``.

        Raises:
            ValueError: if either character set contains ``"``. That byte
                would overwrite the key/value closing-quote transition.
        """
        quote = ord('"')
        if quote in key_chars or quote in value_chars:
            raise ValueError('key_chars/value_chars must not contain the quote byte (")')

        # (state, byte) -> (next_state, stack action). Entries are inserted in
        # state order.
        table = {}
        table[(0, ord('{'))] = (1, 'push_obj')
        table[(1, quote)] = (2, 'none')  # key start
        for byte_val in key_chars:
            table[(2, byte_val)] = (2, 'none')  # key body
        table[(2, quote)] = (3, 'none')  # key end
        table[(3, ord(':'))] = (4, 'none')  # colon
        table[(4, quote)] = (5, 'none')  # value start
        for byte_val in value_chars:
            table[(5, byte_val)] = (5, 'none')  # value body
        table[(5, quote)] = (6, 'none')  # value end
        table[(6, ord(','))] = (1, 'none')  # comma, next key
        table[(6, ord('}'))] = (7, 'pop_obj')  # end of object
        self.transitions = table
        self.start_state = 0

    def step(self, current_state, byte_val, stack_depth):
        """Consume one byte.

        Args:
            current_state: current automaton state.
            byte_val: the byte to consume, as an int.
            stack_depth: current stack depth.

        Returns:
            ``(next_state, new_depth)``. On a rejected byte this is
            ``(None, stack_depth)``.

        Note:
            A pop at depth 0 yields a negative depth; underflow is not rejected.
        """
        entry = self.transitions.get((current_state, byte_val))
        if entry is None:
            return None, stack_depth
        next_state, action = entry
        if action == 'push_obj':
            return next_state, stack_depth + 1
        if action == 'pop_obj':
            return next_state, stack_depth - 1
        return next_state, stack_depth

class TokenizerStub:
    """Toy tokenizer exposing a tiny fixed vocabulary as raw byte sequences.

    Attributes:
        vocab: token string -> list of byte values it encodes to.
        id_to_bytes: token id -> byte list; ids follow ``vocab`` insertion order.
        vocab_size: number of tokens in ``vocab``.
    """

    def __init__(self):
        # Small vocabulary for demo
        self.vocab = {
            "{": [ord('{')],
            "\"": [ord('"')],
            "key": [ord('k'), ord('e'), ord('y')],
            "val": [ord('v'), ord('a'), ord('l')],
            ":": [ord(':')],
            ",": [ord(',')],
            "}": [ord('}')],
            "abc": [ord('a'), ord('b'), ord('c')]  # intended as a token the PDA rejects
        }
        # Ids are assigned in dict insertion order; only the byte lists are kept.
        self.id_to_bytes = dict(enumerate(self.vocab.values()))
        self.vocab_size = len(self.vocab)

class Gram2TokenCompiler:
    """Pre-align a tokenizer's vocabulary with a byte-level PDA.

    Tokens are grouped into categories. Two tokens share a category when,
    from every analysed PDA state, feeding their bytes gives the same
    outcome: accepted or rejected, plus the resulting state.
    """

    def __init__(self, pda, tokenizer):
        # pda: needs step(state, byte, depth) -> (next_state | None, depth).
        #      It also needs start_state when compile() derives the states.
        # tokenizer: needs vocab_size and id_to_bytes[token_id] -> byte sequence.
        self.pda = pda
        self.tokenizer = tokenizer

    def compile(self, states=None):
        """Pre-align tokenizer with PDA and categorize tokens.

        Args:
            states: PDA states to evaluate each token from. ``None`` (the
                default) explores every byte value starting from
                ``pda.start_state``. Pass an iterable to override.

        Returns:
            ``(token_to_cat, categories)``:

            - ``token_to_cat`` maps token id -> category id.
            - ``categories`` maps signature -> category id.

            A signature is a tuple with one ``(valid, end_state)`` pair per
            evaluated state, in ``states`` order. ``end_state`` is ``None``
            when the token is rejected from that state. Category ids are
            assigned in first-seen order over token ids
            ``0..vocab_size-1``. Stack depth is reset to 0 for every
            (token, state) pair and is not part of the signature.
        """
        states = self._reachable_states() if states is None else list(states)

        categories = {}  # signature -> cat_id
        token_to_cat = {}
        for token_id in range(self.tokenizer.vocab_size):
            sig = self._signature(self.tokenizer.id_to_bytes[token_id], states)
            # A new signature gets the next free id (the current count).
            token_to_cat[token_id] = categories.setdefault(sig, len(categories))
        return token_to_cat, categories

    def _reachable_states(self):
        """Return the sorted states reachable from pda.start_state via single bytes."""
        # Assumes transitions do not depend on stack depth, so depth 0 is used
        # throughout. Also assumes states are mutually comparable (ints here).
        start = self.pda.start_state
        seen = {start}
        frontier = [start]
        while frontier:
            state = frontier.pop()
            for byte_val in range(256):
                nxt, _ = self.pda.step(state, byte_val, 0)
                if nxt is not None and nxt not in seen:
                    seen.add(nxt)
                    frontier.append(nxt)
        return sorted(seen)

    def _signature(self, bytes_seq, states):
        """Return one (valid, end_state) pair per state for a token's bytes."""
        sig = []
        for state in states:
            curr, depth = state, 0  # simplified: depth starts at 0
            for b in bytes_seq:
                curr, depth = self.pda.step(curr, b, depth)
                if curr is None:
                    break
            sig.append((curr is not None, curr))
        return tuple(sig)

if __name__ == "__main__":
    # Demo: compile the toy vocabulary against the JSON-like PDA and report
    # the vocabulary size, the number of distinct categories and the
    # per-token category assignment.
    automaton = BytePDA()
    vocab_source = TokenizerStub()
    token_to_cat, categories = Gram2TokenCompiler(automaton, vocab_source).compile()
    report_lines = (
        f"Vocab size: {vocab_source.vocab_size}",
        f"Num Categories: {len(categories)}",
        f"Token to Category mapping: {token_to_cat}",
    )
    print("\n".join(report_lines))
