from transformers import AutoTokenizer
from tqdm import tqdm
from nltk.tokenize import sent_tokenize
import nltk
import pandas as pd
from datasets import load_from_disk

# Download NLTK's 'punkt' resource whenever this module is imported or run
# (presumably the sentence-tokenizer data that sent_tokenize relies on).
nltk.download('punkt')

def pack_sentences(corpus: str, tokenizer, max_length: int, *,
                   sentence_splitter=None, show_progress: bool = True) -> pd.DataFrame:
    """Greedily pack consecutive sentences of ``corpus`` into token-bounded chunks.

    The corpus is split into sentences, each sentence's token count is taken
    as ``len(tokenizer(sentence)['input_ids'])``, and sentences are appended
    to the current chunk for as long as the summed count stays strictly below
    ``max_length``. Sentences within a chunk are joined with a single space.

    Args:
        corpus: The raw text to split and pack.
        tokenizer: Callable returning a mapping with an ``'input_ids'``
            sequence for a given string.
        max_length: Token budget; a chunk's summed per-sentence token count
            is kept strictly below this value.
        sentence_splitter: Optional callable ``str -> list[str]`` used to
            split the corpus. Defaults to ``nltk``'s ``sent_tokenize``.
        show_progress: When true (the default), wrap the sentence loop in a
            ``tqdm`` progress bar.

    Returns:
        A DataFrame with a single ``'text'`` column, one row per chunk, in
        corpus order. An empty corpus yields an empty DataFrame.

    Notes:
        * A single sentence whose own token count reaches ``max_length`` is
          emitted as a chunk by itself, so such a chunk can exceed the budget.
        * Counts are summed per sentence; tokenizing the joined chunk may
          give a slightly different total.
    """
    # use nltk==3.6.5!! on 3.7 there is a bug that makes tokenization very slow
    if sentence_splitter is None:
        sentence_splitter = sent_tokenize
    sentences = sentence_splitter(corpus)
    iterator = tqdm(sentences) if show_progress else sentences

    chunks = []
    current = []
    current_len = 0
    for sentence in iterator:
        sent_len = len(tokenizer(sentence)['input_ids'])

        # Close the running chunk before it would reach the budget. The
        # `current` guard prevents emitting an empty chunk when the very
        # first sentence is already over the limit.
        if current and current_len + sent_len >= max_length:
            chunks.append(' '.join(current))
            current = []
            current_len = 0

        current.append(sentence)
        current_len += sent_len

    # Flush the trailing, partially filled chunk so the end of the corpus
    # is not silently dropped.
    if current:
        chunks.append(' '.join(current))

    return pd.DataFrame({'text': chunks})


if __name__ == '__main__':
    # Name of the Pile subset; used for both the input directory and the output file.
    pile = 'enron_emails'

    tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-125M')

    # Alternative input (disabled): read a wikitext split from disk instead.
    # with open(f'../data/wikitext-103/wiki.{split}.tokens', 'r') as file:
    #     corpus = file.read()

    # Load the subset, draw a reproducible sample of 2000 documents and
    # concatenate them into one space-separated corpus string.
    frame = load_from_disk(f'/mnt/nfs/dongkeun/the_pile_40M/{pile}').to_pandas()
    sampled_docs = frame['text'].sample(n=2000, random_state=42)
    corpus = sampled_docs.str.cat(sep=' ')

    def _squeeze_whitespace(chunk):
        """Collapse every run of whitespace in ``chunk`` to a single space."""
        return ' '.join(chunk.split())

    packed = pack_sentences(corpus, tokenizer, 512)
    packed['text'] = packed['text'].apply(_squeeze_whitespace)
    packed.to_csv(f'/mnt/nfs/dongkeun/the_pile_512/{pile}.csv')

    # Quick sanity check: number of chunks and the first packed chunk.
    print(packed.shape)
    print(packed.iloc[0]['text'])

