import json
import os

from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from transformers import PreTrainedTokenizerFast


class CustomTokenizer:
    """
    Custom tokenizer class for handling specific character sets and supporting vocabulary saving and loading.

    The vocabulary is stored as a flat JSON object mapping ``token -> id``. It is
    wrapped in a ``tokenizers`` WordLevel model and exposed as a
    ``PreTrainedTokenizerFast`` via :meth:`get_tokenizer`.
    """

    # Where the generated vocabulary is written when no vocab_path is given.
    DEFAULT_VOCAB_PATH = "data/custom_vocab.json"
    # Where the serialized `tokenizers` JSON (read back by PreTrainedTokenizerFast) is written.
    DEFAULT_TOKENIZER_PATH = "data/custom_tokenizer.json"

    def __init__(self, vocab_path=None, tokenizer_path=None):
        """
        Initialize the CustomTokenizer class.

        Args:
        - vocab_path (str): Path to an existing vocabulary JSON file. If None (or empty),
          a vocabulary is generated and written to ``DEFAULT_VOCAB_PATH``.
        - tokenizer_path (str): Path where the serialized tokenizer JSON is written before
          being loaded into ``PreTrainedTokenizerFast``. Defaults to ``DEFAULT_TOKENIZER_PATH``.

        Raises:
        - OSError: if the vocabulary or tokenizer file cannot be read or written.
        """
        self.vocab_path = vocab_path or self.DEFAULT_VOCAB_PATH
        self.tokenizer_path = tokenizer_path or self.DEFAULT_TOKENIZER_PATH
        self.tokenizer = None

        # Only generate a vocabulary when the caller did not supply one.
        if not vocab_path:
            self._create_vocab()
        self._load_tokenizer()

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the directory that will contain ``path`` if it does not exist yet."""
        parent = os.path.dirname(path)
        # A bare filename has no parent component: it goes in the CWD, nothing to create.
        if parent:
            os.makedirs(parent, exist_ok=True)

    def _create_vocab(self):
        """
        Create a vocabulary file and save it in JSON format at ``self.vocab_path``.

        Also stores the ordered token list on ``self.chars``; a token's index in that
        list is its id.
        """
        # Supported tokens: special tokens, lowercase letters, digits, logical symbols, parentheses.
        self.chars = (
            ['<pad>', '<unk>', '<bos>', '<eos>'] +  # Special tokens (ids 0-3)
            ['a', 'b', 'x'] +  # Lowercase letters
            [str(i) for i in range(10)] +  # Digits 0-9
            ['&', '|', '~', '#', '(', ')']  # Logical symbols and parentheses
        )
        vocab = {char: idx for idx, char in enumerate(self.chars)}

        # The default path lives under data/, which may not exist on a fresh checkout.
        self._ensure_parent_dir(self.vocab_path)
        with open(self.vocab_path, 'w', encoding='utf-8') as f:
            json.dump(vocab, f)

    def _load_tokenizer(self):
        """
        Load the custom tokenizer and set it up in `PreTrainedTokenizerFast` format.

        Reads the vocabulary from ``self.vocab_path``, saves the resulting `tokenizers`
        JSON to ``self.tokenizer_path``, and stores the wrapped tokenizer on
        ``self.tokenizer``.
        """
        with open(self.vocab_path, 'r', encoding='utf-8') as f:
            vocab = json.load(f)

        # Tokens missing from the vocabulary map to <unk>.
        tokenizer = Tokenizer(WordLevel(vocab, unk_token="<unk>"))
        # NOTE(review): Whitespace splits on \w+|[^\w\s]+, so unspaced input such as
        # "a&(b" yields multi-character pieces ("&(") that are not in this
        # per-character vocab and become <unk>. Presumably inputs are space-separated.
        tokenizer.pre_tokenizer = Whitespace()
        # NOTE(review): no post_processor is configured here, so add_special_tokens=True
        # presumably does not insert <bos>/<eos> — confirm whether that is intended.

        # Round-trip through a file because PreTrainedTokenizerFast is built from tokenizer_file.
        self._ensure_parent_dir(self.tokenizer_path)
        tokenizer.save(self.tokenizer_path)

        # Special tokens are passed explicitly so the transformers wrapper knows their roles.
        self.tokenizer = PreTrainedTokenizerFast(
            tokenizer_file=self.tokenizer_path,
            pad_token="<pad>",
            unk_token="<unk>",
            bos_token="<bos>",
            eos_token="<eos>",
        )

    def get_tokenizer(self):
        """
        Get the tokenizer instance.

        Returns:
        - PreTrainedTokenizerFast object (None if loading has not happened yet).
        """
        return self.tokenizer


if __name__ == "__main__":
    # Build the tokenizer; with no vocab_path this (re)writes data/custom_vocab.json.
    custom_tokenizer = CustomTokenizer()
    tokenizer = custom_tokenizer.get_tokenizer()

    # Sample inputs. Symbols are space-separated so the Whitespace pre-tokenizer
    # yields one vocabulary entry per piece. The vocab only has letters a, b, x.
    test_strings = [
        "a & ( b | c ) #",   # Mostly in-vocab; 'c' is out of vocab and maps to <unk>
        "x ~ y | z",         # Logical symbols; 'y' and 'z' are out of vocab -> <unk>
        "undefined_symbol",  # Entirely out-of-vocab piece -> <unk>
    ]

    # Round-trip each sample through encode/decode and show every stage.
    for test_string in test_strings:
        encoded = tokenizer.encode(test_string, add_special_tokens=True)
        decoded = tokenizer.decode(encoded)
        print(f"Original: {test_string}")
        print(f"Encoded: {encoded}")
        print(f"Decoded: {decoded}")