import os
from datasets import load_dataset
from six.moves import cPickle as pkl
from transformers import AutoTokenizer
import nltk.data

# Minimum number of OPT tokens each extracted fragment must reach.
TOKENS = 300
# Target sample count; used in the output file names and to size the
# document pool and the collection limit below.
LENGTH = 1000
# Output directory for the pickled fragment lists (placeholder path).
DUMP = "DIR/TO/CACHE"

# NLTK Punkt sentence splitter (English model).
# NOTE(review): recent NLTK releases dropped the pickled Punkt models in favour
# of `punkt_tab`; confirm this load still works with the installed NLTK.
tokenizer = nltk.data.load("tokenizers/punkt/english.pickle")

# Raw article text from the XSum test split, capped at 10 * LENGTH documents.
# Presumably the headroom is because documents too short to yield two
# fragments are skipped later — confirm.
dataset = load_dataset("xsum")["test"]["document"][: LENGTH * 10]
IND = []  # not referenced anywhere in this script

# Accumulators: s1 gets the leading fragment of each kept document,
# s2 the fragment that immediately follows it.
s1, s2 = [], []

# OPT tokenizer, used only to count tokens when choosing split points.
lm_tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")

def _find_split_points(sentences):
    """Find where to cut a sentence list into two fragments of >= TOKENS tokens.

    Scans greedily left to right. The first fragment is ``sentences[:i2]`` and
    the second is ``sentences[i2:i3]``. Each cut index is the smallest one at
    which the fragment's OPT token count reaches ``TOKENS``.

    Args:
        sentences: list of sentence strings from the Punkt splitter.

    Returns:
        ``(i2, i3)`` on success, or ``None`` if the document is too short to
        yield two fragments of at least ``TOKENS`` tokens each.
    """
    start, i2 = 0, None
    # j runs up to len(sentences) *inclusive* so that the final sentence can
    # count toward a fragment; stopping at len(sentences) - 1 silently
    # ignored it and discarded documents that only qualify with it.
    for j in range(1, len(sentences) + 1):
        # Only the length is needed, so take len() of the plain id list.
        # Truncating at 2 * TOKENS is harmless because we only compare
        # against TOKENS.
        n_tokens = len(
            lm_tokenizer(
                " ".join(sentences[start:j]),
                add_special_tokens=True,
                truncation=True,
                max_length=2 * TOKENS,
            ).input_ids
        )
        if n_tokens < TOKENS:
            continue
        if i2 is None:
            # First fragment is complete; start counting the second one here.
            i2 = start = j
        else:
            return i2, j
    return None


# Split each document into two consecutive fragments of at least TOKENS
# tokens each, stopping once 2 * LENGTH pairs have been collected.
for document in dataset:
    sentences = tokenizer.tokenize(document)
    split = _find_split_points(sentences)
    if split is None:
        continue  # too short to provide two full fragments

    i2, i3 = split
    s1.append(" ".join(sentences[:i2]))
    s2.append(" ".join(sentences[i2:i3]))

    if len(s1) == 2 * LENGTH:
        break

# Persist both fragment lists. Create the output directory first so that a
# missing DUMP path does not raise only after the slow tokenization pass,
# throwing that work away.
# Note: the lists may hold up to 2 * LENGTH entries even though the file
# names carry LENGTH.
os.makedirs(DUMP, exist_ok=True)

# Leading fragments (presumably used as model prompts), one per kept document.
with open(os.path.join(DUMP, "xsum_prompt_{}.pkl".format(LENGTH)), "wb") as f:
    pkl.dump(s1, f)

# Continuation fragments, index-aligned with the prompt list above.
with open(os.path.join(DUMP, "xsum_truth_{}.pkl".format(LENGTH)), "wb") as f:
    pkl.dump(s2, f)