from six.moves import cPickle as pkl
from utils import parse_args, tokenize_and_truncate
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np

# Parse command-line arguments with the project's utils.parse_args. The result
# is not referenced anywhere else in this script as visible here — presumably
# kept so this script accepts the same CLI as its siblings; TODO confirm or remove.
args = parse_args()

def combine(s1, s2):
    """Concatenate two strings, inserting a single space only when needed.

    No separator is added when ``s1`` already ends with, or ``s2`` already
    starts with, a space or a newline. If either string is empty the two are
    simply concatenated (indexing ``s1[-1]`` / ``s2[0]`` would raise
    IndexError on empty input).

    Args:
        s1: Leading text (in this script, the prompt).
        s2: Trailing text (in this script, the completion).

    Returns:
        The joined string.
    """
    separators = (" ", "\n")
    if not s1 or not s2 or s1.endswith(separators) or s2.startswith(separators):
        return s1 + s2
    return s1 + " " + s2
    
# Directory holding the cached pickle files (placeholder path).
DUMP = "DIR/TO/CACHE"

# Each group lists the prompt pickle first, followed by the completion pickles
# to be scored against those prompts.
FILES = [
    [
        "clean_xsum_prompt_1000.pkl",
        "clean_llm_watermarked_1000.pkl",
        *("R{}_P_60_0_1000.pkl".format(k) for k in range(1, 5)),
        "R5_P_60_40_1000.pkl",
    ]
]

# Scoring model and its matching tokenizer.
model_name = "facebook/opt-13b"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)


def ppl(c, pc):
    """Return the perplexity of completion ``c`` given the text that precedes it.

    ``pc`` is the full prompt-plus-completion text (the caller builds it with
    ``combine(prompt, c)``). Both strings are tokenized with
    ``tokenize_and_truncate``. Every label position except the trailing
    completion tokens is set to -100, so the model's loss is computed over the
    completion only. The exponential of that loss is the perplexity.

    Args:
        c: Completion text whose perplexity is measured.
        pc: Prompt followed by the completion.

    Returns:
        Perplexity as a Python float.
    """
    max_len = model.config.max_position_embeddings
    with torch.no_grad():
        tok_inputs = tokenize_and_truncate(
            {"text": pc}, completion_length=0, hf_model_name=model_name,
            tokenizer=tokenizer, model_max_seq_len=max_len,
        )["inputs"].to(model.device)
        # Only the length of the standalone completion is used, so it stays on CPU.
        tok_suffix = tokenize_and_truncate(
            {"text": c}, completion_length=0, hf_model_name=model_name,
            tokenizer=tokenizer, model_max_seq_len=max_len,
        )["inputs"]
        # Number of leading positions excluded from the loss. The +1 presumably
        # compensates for a BOS token that tokenize_and_truncate also prepends to
        # the standalone completion — TODO confirm against utils.tokenize_and_truncate.
        n_masked = tok_inputs.shape[1] - tok_suffix.shape[1] + 1
        tok_labels = tok_inputs.clone()
        tok_labels[:, :n_masked] = -100  # -100 labels are ignored by the HF LM loss
        loss = model(input_ids=tok_inputs, labels=tok_labels).loss
    # Exponentiate in torch and convert to a Python float, rather than calling
    # np.exp on a torch tensor (which can return a 0-d tensor, not a number).
    return torch.exp(loss).item()

for files in FILES:
    # files[0] is the prompt pickle; each later entry is a completion pickle
    # whose i-th item is scored against the i-th prompt.
    Scores = []

    # load the prompt first
    # NOTE(review): pickle.load can execute arbitrary code; only point DUMP at
    # trusted caches.
    print(files[0])
    with open(os.path.join(DUMP, files[0]), "rb") as f:
        P = pkl.load(f)

    # score completions
    for file in files[1:]:
        print(file)
        with open(os.path.join(DUMP, file), "rb") as f:
            C = pkl.load(f)
        # Fail fast rather than hitting IndexError on P[i] after many forward passes.
        if len(C) > len(P):
            raise ValueError(
                "{} has {} completions but only {} prompts were loaded".format(
                    file, len(C), len(P)
                )
            )
        scores = []
        for i, completion in enumerate(C):
            scores.append(ppl(completion, combine(P[i], completion)))
            # Carriage return keeps the progress display on a single line.
            print("{:3d} {:.4f}".format(i, scores[-1]), flush=True, end="\r")
        Scores.append(scores)
        # Avoid ZeroDivisionError on an empty completion file.
        if scores:
            print("Mean:", sum(scores) / len(scores))

    # np.stack requires every completion file to yield the same number of scores.
    Scores = np.stack(Scores)
    # np.save does not create missing directories.
    os.makedirs("scores", exist_ok=True)
    np.save("scores/ppl13b_{}.npy".format(files[-1].split(".")[0]), Scores)
    print(Scores.mean(1))