class CFGTokenizer:
    """Tokenizer for CFG sequences written as space-separated integers.

    Converts strings such as ``"14 15 14 13"`` to lists of ints, optionally
    pads them to a multiple of a block length, and converts them back.
    """

    def __init__(self, context_length=512, pad_token=0):
        """
        Simple tokenizer for CFG sequences (space-separated integers).

        Args:
            context_length (int): The base context length for padding
            pad_token (int): Token used for padding (default: 0)
        """
        self.context_length = context_length
        self.pad_token = pad_token

    def tokenize(self, sequence):
        """
        Tokenize a space-separated integer sequence.

        Args:
            sequence (str): Space-separated integer string like "14 15 14 13"

        Returns:
            list: List of integers (empty for an empty or whitespace-only input)

        Raises:
            ValueError: If any whitespace-separated field is not an integer
        """
        if not sequence or not sequence.strip():
            return []

        try:
            # str.split() with no argument already ignores leading/trailing
            # whitespace and collapses runs of whitespace.
            return [int(field) for field in sequence.split()]
        except ValueError as e:
            # Chain the cause explicitly so the traceback reads as one error.
            raise ValueError(
                f"Invalid sequence format. Expected space-separated integers: {e}"
            ) from e

    def pad_to_multiple(self, tokens, multiple=None):
        """
        Pad tokens to the nearest multiple of context_length (or custom multiple).

        An empty sequence is padded to exactly one full block. The input list
        is not modified.

        Args:
            tokens (list): List of integer tokens
            multiple (int, optional): Custom multiple to pad to. If None (or 0),
                uses context_length

        Returns:
            tuple: (padded_tokens, attention_mask)
                - padded_tokens: List of tokens padded to multiple length
                - attention_mask: List of 1s for real tokens, 0s for padding

        Raises:
            ValueError: If the effective multiple is not positive
        """
        block = multiple or self.context_length
        if block <= 0:
            raise ValueError(
                f"Padding multiple must be a positive integer, got {block}"
            )

        real_tokens = list(tokens) if tokens else []
        real_count = len(real_tokens)

        # Ceiling division; an empty sequence still occupies one full block.
        block_count = max(1, -(-real_count // block))
        pad_count = block_count * block - real_count

        padded_tokens = real_tokens + [self.pad_token] * pad_count
        # Build the mask from positions, not values: a real token may
        # legitimately have the same id as pad_token and must stay unmasked.
        attention_mask = [1] * real_count + [0] * pad_count

        return padded_tokens, attention_mask

    def encode(self, sequence, return_attention_mask=True, pad_multiple=None):
        """
        Full encoding pipeline: tokenize and pad sequence.

        Args:
            sequence (str): Space-separated integer string
            return_attention_mask (bool): Whether to return attention mask
            pad_multiple (int, optional): Custom multiple to pad to

        Returns:
            dict: Dictionary containing 'input_ids' and optionally 'attention_mask'

        Raises:
            ValueError: If the sequence is malformed or the multiple is not positive
        """
        tokens = self.tokenize(sequence)
        padded_tokens, attention_mask = self.pad_to_multiple(tokens, pad_multiple)

        result = {'input_ids': padded_tokens}

        if return_attention_mask:
            result['attention_mask'] = attention_mask

        return result

    def encode_no_pad(self, sequence):
        """
        Tokenize a sequence without padding it.

        Args:
            sequence (str): Space-separated integer string

        Returns:
            dict: Dictionary containing only 'input_ids' (the unpadded tokens)

        Raises:
            ValueError: If the sequence is malformed
        """
        return {'input_ids': self.tokenize(sequence)}

    def decode(self, token_ids, skip_padding=True):
        """
        Decode token IDs back to space-separated string.

        Without an attention mask, padding cannot be told apart from real
        trailing tokens that equal pad_token, so with skip_padding=True every
        trailing pad_token value is dropped. The input is not modified.

        Args:
            token_ids (list): List of integer token IDs
            skip_padding (bool): Whether to skip padding tokens at the end

        Returns:
            str: Space-separated integer string
        """
        if skip_padding:
            # Find the cut point first and slice once, instead of re-slicing
            # the list for every trailing pad token.
            end = len(token_ids)
            while end > 0 and token_ids[end - 1] == self.pad_token:
                end -= 1
            token_ids = token_ids[:end]

        return ' '.join(map(str, token_ids))


# Example usage and testing
if __name__ == "__main__":
    # Demo / smoke test: encode a few sample strings, print the padded
    # result, decode it again and report whether the text round-trips.
    divider = "=" * 50
    demo_tokenizer = CFGTokenizer(context_length=8, pad_token=0)

    # Covers a short input, one longer than a block, a single token, and empty.
    samples = ["14 15 14 13", "1 2 3 4 5 6 7 8 9 10", "42", ""]

    print("Testing CFG Tokenizer:")
    print(divider)

    for test_number, raw in enumerate(samples, start=1):
        print(f"\nTest {test_number}: '{raw}'")

        result = demo_tokenizer.encode(raw)
        input_ids = result['input_ids']
        token_count = len(raw.split()) if raw else 0

        print(f"Input length: {token_count}")
        print(f"Padded length: {len(input_ids)}")
        print(f"Input IDs: {input_ids}")
        print(f"Attention mask: {result['attention_mask']}")

        round_tripped = demo_tokenizer.decode(input_ids)
        print(f"Decoded: '{round_tripped}'")

        # The decoded text should equal the stripped input.
        expected = raw.strip() if raw else ""
        print(f"Round-trip success: {expected == round_tripped}")

    # Same input under several block sizes to show how the padded length grows.
    print("\n" + divider)
    print("Testing different context lengths:")

    fixed_input = "1 2 3 4 5"
    for block_size in (4, 8, 16):
        sized_result = CFGTokenizer(context_length=block_size).encode(fixed_input)
        print(f"Context length {block_size}: padded to {len(sized_result['input_ids'])} tokens")
        print(f"  Input IDs: {sized_result['input_ids']}")
        print(f"  Attention: {sized_result['attention_mask']}")