import os
import sys 
sys.path.append("../lm-watermarking/")
from pprint import pprint
from functools import partial
import numpy as np
from transformers import (AutoTokenizer,
                          AutoModelForSeq2SeqLM,
                          AutoModelForCausalLM,
                          LogitsProcessorList)
from watermark_processor import WatermarkLogitsProcessor, WatermarkDetector
from utils import load_model, parse_args, generate
from six.moves import cPickle as pkl

# Suffix shared by every pickle this script reads and writes
# (presumably the number of XSum examples in the dump — TODO confirm).
LENGTH = 1000

args = parse_args()
# Override the CLI generation settings for this experiment.
args.prompt_max_length = 300
args.max_new_tokens = 300
args.min_new_tokens = 200


def _dump_path(prefix):
    """Return the path of ``<prefix>_<LENGTH>.pkl`` inside ``args.DUMP``."""
    return os.path.join(args.DUMP, "{}_{}.pkl".format(prefix, LENGTH))


# NOTE(review): pickle.load can execute arbitrary code; only point args.DUMP
# at dump files produced by a trusted pipeline.
with open(_dump_path("xsum_prompt"), "rb") as f:
    s1 = pkl.load(f)  # prompts fed to `generate`, one per example

with open(_dump_path("xsum_truth"), "rb") as f:
    s2 = pkl.load(f)  # reference texts; loaded but not used by this script

model, tokenizer, device = load_model(args)

# o1[i] / o2[i] are the generations without / with watermark for prompt s1[i].
o1, o2 = [], []

total = len(s1)
for i, prompt in enumerate(s1):
    (
        _,
        _,
        decoded_output_without_watermark,
        decoded_output_with_watermark,
        _,
    ) = generate(prompt, args, model=model, device=device, tokenizer=tokenizer)

    # Progress line "<done>/<total>, <tokens in watermarked output>".
    # truncation=True with max_length=1000 caps the reported count at 1000.
    n_tokens = tokenizer(
        decoded_output_with_watermark,
        return_tensors="pt",
        add_special_tokens=True,
        truncation=True,
        max_length=1000,
    ).input_ids.shape[1]
    print("{:4d}/{:4d}, {:4d}".format(i + 1, total, n_tokens))

    o1.append(decoded_output_without_watermark)
    o2.append(decoded_output_with_watermark)

with open(_dump_path("llm_output"), "wb") as f:
    pkl.dump(o1, f)

with open(_dump_path("llm_watermarked"), "wb") as f:
    pkl.dump(o2, f)