import pandas as pd
from transformers import AutoTokenizer
from tqdm import tqdm
from os import listdir
from os.path import isfile, join
import itertools
# Register tqdm's `progress_apply` on pandas objects; within this file only the
# commented-out whitespace-normalization line in the main loop refers to it.
tqdm.pandas()


def pack(ids, max_length, tokenizer=None):
    """Split a token-id sequence into decoded chunks of exactly ``max_length`` ids.

    Args:
        ids: Sequence or iterable of token ids.
        max_length: Number of ids per chunk; must be positive.
        tokenizer: Object exposing ``decode(list_of_ids) -> str``. When omitted,
            the module-level ``tokenizer`` global is used, which keeps the
            original two-argument call style working.

    Returns:
        List of decoded strings, one per *full* chunk. Trailing ids that do not
        fill a complete chunk are dropped.

    Raises:
        ValueError: If ``max_length`` is not positive.
        NameError: If no tokenizer is passed and no module-level one exists.
    """
    if max_length <= 0:
        raise ValueError(f"max_length must be positive, got {max_length!r}")
    if tokenizer is None:
        # Backward compatibility: earlier callers relied on the script's global.
        tokenizer = globals().get("tokenizer")
        if tokenizer is None:
            raise NameError(
                "pack() needs a tokenizer: pass one explicitly or define a "
                "module-level 'tokenizer'"
            )

    ids = list(ids)
    # Only complete chunks are emitted; the short remainder beyond the last
    # multiple of max_length is never decoded, so nothing needs to be popped.
    n_full = len(ids) // max_length
    return [
        tokenizer.decode(ids[start:start + max_length])
        for start in range(0, n_full * max_length, max_length)
    ]


def pack_max_len(corpus: str, tokenizer, max_length):
    """Tokenize ``corpus`` in one call and pack it into ``max_length``-id chunks.

    The same ``tokenizer`` is used for encoding and for decoding the chunks.
    Returns a list of decoded strings (see ``pack``).
    """
    ids = tokenizer.encode(corpus)
    return pack(ids, max_length, tokenizer)


def pack_max_len_batch(corpus, tokenizer, max_length, batch_size=20000):
    """Tokenize a long ``corpus`` in character batches and pack it into chunks.

    The text is cut into ``batch_size``-character slices so the tokenizer is
    called on a list of moderate-size strings. The resulting id lists are
    concatenated before packing, so chunks may span slice boundaries.

    Args:
        corpus: Text to tokenize.
        tokenizer: Callable returning a mapping with an ``'input_ids'`` list of
            id lists for a list of strings, and exposing ``decode``.
        max_length: Number of ids per packed chunk.
        batch_size: Characters per tokenizer input slice; must be positive.

    Returns:
        List of decoded strings of ``max_length`` ids each; only the final
        incomplete tail of the whole corpus is dropped.

    Raises:
        ValueError: If ``batch_size`` or ``max_length`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size!r}")
    if not corpus:
        # Avoid calling the tokenizer on an empty batch list.
        return []

    batches = [corpus[i:i + batch_size] for i in range(0, len(corpus), batch_size)]
    id_batches = tokenizer(batches)['input_ids']
    # Pack the concatenated ids rather than each batch separately; per-batch
    # packing would discard up to max_length - 1 tokens at every boundary.
    ids = list(itertools.chain.from_iterable(id_batches))
    return pack(ids, max_length, tokenizer)


if __name__ == '__main__':
    from os import makedirs

    # Kept as a module-level name because pack() may look it up as a global.
    tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-125M')
    path = '/mnt/nfs/dongkeun/the_pile_raw/valid'
    out_path = '/mnt/nfs/dongkeun/the_pile_200/valid'
    max_length = 128  # tokens per packed chunk

    # Sorted so runs process files in a reproducible order.
    onlyfiles = sorted(f for f in listdir(path) if isfile(join(path, f)))
    makedirs(out_path, exist_ok=True)

    print(onlyfiles)

    for file in tqdm(onlyfiles):
        print(file)
        df = pd.read_csv(join(path, file))
        # Drop missing texts *before* casting: astype(str) would turn NaN into
        # the literal string 'nan' and leak it into the corpus.
        df = df.dropna(subset=['text']).astype({'text': str})

        df['corpus'] = file.split('.')[0]
        df = df[['corpus', 'text']]
        # Every row carries the same corpus label, so this joins the whole file
        # into a single space-separated document.
        df = df.groupby(['corpus'], as_index=False).agg({'text': ' '.join})

        df['text'] = df['text'].apply(lambda x: pack_max_len_batch(x, tokenizer, max_length))
        df = df[['corpus', 'text']].explode('text')
        # explode() yields a NaN row when a corpus produced no complete chunk.
        df = df.dropna(subset=['text'])
        df = df.reset_index(drop=True)
        print(df.head(5))
        print(df.shape)
        df.to_csv(join(out_path, '_'.join(file.lower().split())), index=False)