import os
from six.moves import cPickle as pkl
import nltk.data

# Value substituted into the "{}" placeholder of every cache file name used by
# this script; it is also the maximum number of cleaned records written out.
LENGTH = 1000
# Directory that holds the input pickles and receives the cleaned pickles.
# NOTE(review): "DIR/TO/CACHE" looks like a placeholder -- confirm it is set to
# a real path before running.
DUMP = "DIR/TO/CACHE"
# English Punkt sentence tokenizer loaded from NLTK's data path; used to split
# a passage into sentences for the repeated-sentence check.
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

def _load_cached(template):
    """Unpickle and return the cache file named by *template*.

    *template* is a file name containing one "{}" placeholder, which is filled
    with LENGTH; the file is looked up inside the DUMP directory.
    """
    path = os.path.join(DUMP, template.format(LENGTH))
    with open(path, "rb") as handle:
        return pkl.load(handle)


# The four inputs are read in this order: plain LLM output, watermarked LLM
# output, XSum prompts, XSum ground truth (per the file names).
s1 = _load_cached("llm_output_{}.pkl")
s2 = _load_cached("llm_watermarked_{}.pkl")
p = _load_cached("xsum_prompt_{}.pkl")
t = _load_cached("xsum_truth_{}.pkl")
    
def _is_clean_passage(passage):
    """Return True if *passage* passes both quality checks.

    A passage is rejected when:
      * any of its non-empty lines has two or fewer space-separated tokens, or
      * the sentence tokenizer yields the same sentence string more than once.
    """
    lines = [line for line in passage.split("\n") if line != ""]

    # Very short lines (one or two words): discard the whole passage.
    if any(len(line.split(" ")) <= 2 for line in lines):
        return False

    # Lines are re-joined with ". " before sentence splitting, presumably so
    # that each line break is presented to the tokenizer as a sentence end.
    sentences = tokenizer.tokenize(". ".join(lines))

    # Repeated sentences: discard the whole passage. A duplicate exists
    # exactly when de-duplicating the list makes it shorter.
    # NOTE(review): a passage with no non-empty lines is not rejected by the
    # short-line check above -- confirm that keeping such passages is intended.
    return len(set(sentences)) == len(sentences)


tmp2, tmp1 = [], []
p_, t_ = [], []

# Only the passages in s2 are inspected; the entries at the same index in s1,
# p and t are kept or dropped together with the s2 entry.
for i, passage in enumerate(s2):
    if not _is_clean_passage(passage):
        continue
    tmp1.append(s1[i])
    tmp2.append(passage)
    p_.append(p[i])
    t_.append(t[i])

# Report how many records survived the filtering.
print("\n", len(tmp1))

# Each cleaned list is paired with the file-name template it is written to.
# The "{}" placeholder is filled with LENGTH and the file is created in DUMP.
_clean_outputs = (
    ("clean_llm_output_{}.pkl", tmp1),
    ("clean_llm_watermarked_{}.pkl", tmp2),
    ("clean_xsum_prompt_{}.pkl", p_),
    ("clean_xsum_truth_{}.pkl", t_),
)

for _template, _records in _clean_outputs:
    _target = os.path.join(DUMP, _template.format(LENGTH))
    with open(_target, "wb") as f:
        # At most the first LENGTH records are stored; fewer are written if
        # fewer survived the filtering.
        pkl.dump(_records[: LENGTH], f)