"""
Train a byte-level BPE tokenizer on the "text" field of the training split of
the dataset described by the config file.

>>> uv run scripts/train_lltm_tokenizer.py \
    --base configs/sudoku/ltm-tokenizer.yaml \
    --vocab_size 256 \
    --batch_size 1 \
    --save_dir data/ltm_tokenizers/sudoku \
    --num_examples 1000
"""

import sys 

import tokenizers .implementations 

sys .path .append (".")
import os 
import argparse 
from omegaconf import OmegaConf 

import lightning as L 
from transformers import LlamaTokenizerFast ,GPT2TokenizerFast 
from tokenizers import implementations 
from datasets import Dataset 

from ltm .utils import instantiate_from_config 


# Set TOKENIZERS_PARALLELISM for this process before any tokenizer is built.
# NOTE(review): presumably this keeps the Hugging Face `tokenizers` library's
# multi-threaded training enabled and avoids its fork warning — confirm.
os .environ ["TOKENIZERS_PARALLELISM"]="true"


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for tokenizer training.

    Returns:
        The parsed namespace with attributes ``base``, ``vocab_size``,
        ``batch_size``, ``num_examples``, ``save_dir``, ``dataset_path`` and
        ``save_data_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base",
        type=str,
        default="configs/sudoku/ltm-tokenizer.yaml",
        help="path to config",
    )
    parser.add_argument(
        "--vocab_size", type=int, default=8000, help="vocab size for the tokenizer"
    )
    parser.add_argument(
        "--batch_size", type=int, default=1, help="batch size for training"
    )
    parser.add_argument(
        "--num_examples", type=int, default=None, help="number of examples to train"
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="data/ltm_tokenizer",
        help="path to save the tokenizer",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default=None,
        help="path to the saved dataset by `--save_data_dir`",
    )
    parser.add_argument(
        "--save_data_dir",
        type=str,
        default=None,
        # Adjacent literals are concatenated verbatim, so each piece carries its
        # own trailing space; without it the sentences run together in --help.
        help="This will be helpful to skip the data preparation step "
        "when you want to train the tokenizer multiple times. "
        "You can use `--dataset_path` next time.",
    )
    return parser.parse_args()


def main_spm():
    """Train a SentencePiece BPE model on the "text" field of the training split.

    Reads the CLI options via ``parse_args``, builds the data module described
    by the ``data`` section of the config, optionally subsamples the training
    split, and hands the texts to ``SentencePieceTrainer``.

    Only ``--base``, ``--vocab_size`` and ``--num_examples`` are used here;
    ``--batch_size``, ``--save_dir``, ``--dataset_path`` and ``--save_data_dir``
    are not consulted. The trainer is given ``model_prefix="tokenizer"``.
    """
    # Imported lazily so the default entry point does not require sentencepiece.
    import sentencepiece as spm

    args = parse_args()
    config = OmegaConf.load(args.base)

    data: L.LightningDataModule = instantiate_from_config(
        OmegaConf.to_container(config.data, resolve=True)
    )
    data.prepare_data()
    data.setup(stage="fit")

    dataset: Dataset = data.datasets["train"]
    if args.num_examples is not None:
        # Clamp so that asking for more examples than exist uses the whole
        # split instead of failing with an out-of-range index in `select`.
        num_examples = min(args.num_examples, len(dataset))
        dataset = dataset.shuffle(seed=42).select(range(num_examples))

    def sentence_iterator():
        for example in dataset:
            yield example["text"]

    spm.SentencePieceTrainer.Train(
        sentence_iterator=sentence_iterator(),
        character_coverage=0.9995,
        model_type="bpe",
        # Honour --vocab_size; this was previously hard-coded to 800.
        vocab_size=args.vocab_size,
        model_prefix="tokenizer",
        byte_fallback=True,
        max_sentence_length=50_000_000,
        num_threads=16,
        split_digits=True,
        pad_id=0,
        unk_id=1,
        bos_id=2,
        eos_id=3,
        normalization_rule_name="identity",
        allow_whitespace_only_pieces=True,
        remove_extra_whitespaces=False,
    )





def main():
    """Train a byte-level BPE tokenizer and save it as a GPT-2 style fast tokenizer.

    Steps:
      1. Obtain the training dataset, either from ``--dataset_path`` (a dataset
         previously written with ``--save_data_dir``) or by instantiating the
         data module described by the ``data`` section of the ``--base`` config.
      2. Optionally subsample ``--num_examples`` examples (shuffled with seed 42).
      3. Train ``ByteLevelBPETokenizer`` on batches of the "text" field.
      4. Wrap the result in ``GPT2TokenizerFast`` and save it to ``--save_dir``.
    """
    args = parse_args()

    if args.dataset_path is not None:
        dataset = Dataset.load_from_disk(args.dataset_path)
    else:
        # The config is only needed to build the data module, so it is loaded
        # here rather than unconditionally; a cached dataset works without it.
        config = OmegaConf.load(args.base)
        data: L.LightningDataModule = instantiate_from_config(
            OmegaConf.to_container(config.data, resolve=True)
        )
        data.prepare_data()
        data.setup(stage="fit")

        dataset: Dataset = data.datasets["train"]
        if args.save_data_dir is not None:
            # Saved before subsampling, so the full split is cached for reuse.
            dataset.save_to_disk(args.save_data_dir)

    if args.num_examples is not None:
        # Clamp so that asking for more examples than exist uses the whole
        # dataset instead of failing with an out-of-range index in `select`.
        num_examples = min(args.num_examples, len(dataset))
        dataset = dataset.shuffle(seed=42).select(range(num_examples))

    def batch_iterator():
        # Each yielded item is the list of "text" values for one batch.
        for batch in dataset.iter(args.batch_size):
            yield batch["text"]

    tokenizer = implementations.ByteLevelBPETokenizer()

    tokenizer.train_from_iterator(
        batch_iterator(),
        vocab_size=args.vocab_size,
        # Number of examples (not batches) in the dataset being iterated.
        length=len(dataset),
        special_tokens=["<|endoftext|>"],
    )

    # Wrap the underlying `tokenizers.Tokenizer` object in the transformers API.
    hf_tokenizer = GPT2TokenizerFast(tokenizer_object=tokenizer._tokenizer)
    hf_tokenizer.save_pretrained(args.save_dir)

    print("Finished training tokenizer")


# Script entry point: runs the byte-level BPE training path (`main`).
# `main_spm` is defined above but is not invoked from here.
if __name__ =="__main__":
    main ()

