import os
import sys 
sys.path.append("../lm-watermarking/")
import argparse
from argparse import Namespace
from pprint import pprint
from functools import partial
import numpy as np
import torch
from transformers import (AutoTokenizer,
                          AutoModelForSeq2SeqLM,
                          AutoModelForCausalLM,
                          LogitsProcessorList)
from watermark_processor import WatermarkLogitsProcessor, WatermarkDetector
from utils import load_model, parse_args, detect
from datasets import load_dataset
from six.moves import cPickle as pkl

# Run configuration, produced by the project-local `parse_args` helper.
args = parse_args()

# Directory containing the pickled text collections listed in `files`.
# NOTE(review): the value looks like a placeholder — set it before running.
DUMP = "DIR/TO/CACHE"

# Pickle files to score, in the order their rows appear in the saved arrays.
files = [
    "clean_xsum_truth_1000.pkl",
    "clean_llm_watermarked_1000.pkl",
]
files.extend("R{}_P_60_0_1000.pkl".format(i) for i in range(1, 5))
files.append("R5_P_60_40_1000.pkl")

# Only the tokenizer and device are used by the rest of this script, so the
# reference to the model returned by `load_model` is dropped right away.
model, tokenizer, device = load_model(args)
del model

# One inner list per entry of `files`: detection scores and token counts.
Scores = []
Tokens = []

# For every pickle file: run detection on each text and record its score and
# its token count, then append the per-file lists to `Scores` / `Tokens`.
for fname in files:

    print(fname)
    file_scores = []
    file_tokens = []

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    pickle_path = os.path.join(DUMP, fname)
    with open(pickle_path, "rb") as handle:
        texts = pkl.load(handle)

    for text in texts:
        # `detect` returns a sequence; only its first element is used here.
        detection = detect(text, args, device=device, tokenizer=tokenizer)[0]
        # NOTE(review): presumably row 3 holds the z-score as a
        # (label, value) pair — confirm against `utils.detect`.
        score = float(detection[3][1])
        file_scores.append(score)

        # Token count including special tokens, truncated to at most 1000.
        encoded = tokenizer(
            text,
            return_tensors="pt",
            add_special_tokens=True,
            truncation=True,
            max_length=1000,
        )
        file_tokens.append(encoded.input_ids.shape[1])

    Scores.append(file_scores)
    Tokens.append(file_tokens)

# Stack the per-file lists into arrays with one row per entry of `files`.
# `np.stack` requires every row to have the same length, so this raises if
# the pickle files did not all contain the same number of texts.
Scores = np.stack(Scores)
Tokens = np.stack(Tokens)

# `np.save` does not create missing directories; make sure the output folder
# exists so the results of the (long) detection pass are not lost at the end.
os.makedirs("scores", exist_ok=True)

# Output files are tagged with the stem of the last input file name.
run_tag = files[-1].split(".")[0]
np.save("scores/scores_{}.npy".format(run_tag), Scores)
np.save("scores/tokens_{}.npy".format(run_tag), Tokens)