import sys 
import argparse 
from tqdm import tqdm 
from omegaconf import OmegaConf 
from datasets import Dataset 
import pandas as pd 
from torch .utils .data import DataLoader 
import lightning as L 

sys .path .append (".")
from ltm .utils import instantiate_from_config 


def parse_args(argv=None):
    """Parse command-line options for the token-length statistics script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse falls back to ``sys.argv[1:]``, which is
            exactly the previous zero-argument behavior.

    Returns:
        argparse.Namespace with:
            base: path to the OmegaConf config file.
            num_examples: number of training examples to sample (default 1000).
            tokenizer_path: optional override for the tokenizer's
                ``pretrained_model_name_or_path`` (default ``None``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base",
        type=str,
        default="configs/sudoku/ltm-tokenizer.yaml",
        help="path to config",
    )
    parser.add_argument(
        "--num_examples", type=int, default=1000, help="number of examples to count"
    )
    parser.add_argument(
        "--tokenizer_path", type=str, default=None, help="path to tokenizer"
    )
    return parser.parse_args(argv)


def main():
    """Print summary statistics of tokenized sequence lengths.

    Loads the config given by ``--base`` and optionally overrides the
    tokenizer's ``pretrained_model_name_or_path`` with ``--tokenizer_path``.
    It then instantiates the data module from ``config.data``, runs
    ``prepare_data()`` and ``setup(stage="fit")``, and samples up to
    ``--num_examples`` rows from the ``"train"`` dataset using a fixed
    shuffle seed. Finally it prints ``DataFrame.describe()`` over the length
    of each row's ``input_ids``.

    Rows with fewer than two tokens are excluded from the statistics, and
    the number of excluded rows is printed.
    """
    args = parse_args()
    config = OmegaConf.load(args.base)
    if args.tokenizer_path is not None:
        config.data.params.tokenizer_config.params.pretrained_model_name_or_path = (
            args.tokenizer_path
        )
    data: L.LightningDataModule = instantiate_from_config(
        OmegaConf.to_container(config.data, resolve=True)
    )
    data.prepare_data()
    data.setup(stage="fit")

    dataset: Dataset = data.datasets["train"]

    if args.num_examples is not None:
        # Clamp to the split size: selecting indices past the end of the
        # dataset would raise instead of just using every available row.
        n_select = min(args.num_examples, len(dataset))
        dataset = dataset.shuffle(seed=42).select(range(n_select))

    tokens = []
    n_skipped = 0
    # Iterate rows directly. len(row["input_ids"]) equals what
    # batch["input_ids"].size(1) gave with a batch-size-1 DataLoader, but it
    # does not rely on default collation of every column in the row.
    for example in tqdm(dataset):
        n_tokens = len(example["input_ids"])
        if n_tokens < 2:
            n_skipped += 1
            continue
        tokens.append(n_tokens)

    if n_skipped:
        print(f"skipped {n_skipped} examples with fewer than 2 tokens")
    df = pd.DataFrame(tokens, columns=["tokens"])
    print(df.describe())


if __name__ =="__main__":
    # Run the statistics only when executed as a script, not on import.
    main ()
