from src.daemon import DaemonDecoding
from tqdm import tqdm

import json
import csv
import os
import argparse

# ---------------------------------------------------------------------------
# Fixed evaluation settings (not exposed on the command line).
# ---------------------------------------------------------------------------
# Device string placed in the "device" entry of the coherence and information
# scorer configurations further down in this script.
EVAL_MODEL_CUDA = "cuda:5"

# Checkpoint identifier handed to DaemonDecoding; the model name is derived
# from its last "/"-separated component.
ckpt_path = "gpt2-xl"

# max_length truncates every tokenized text and is forwarded to
# compute_perplexity; eval_bsz is forwarded to compute_perplexity as well.
max_length, eval_bsz = 256, 32

# ---------------------------------------------------------------------------
# Command-line interface.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
    "--data",
    required=True,
    type=str,
    help="data path to be evaluate (.csv)",
)
parser.add_argument(
    "--mu",
    required=True,
    type=str,
    help="optimal mu path (.json)",
)
parser.add_argument(
    "--tau",
    default=0.99,
    type=float,
    help="temperature",
)
args = parser.parse_args()

# Short aliases for the two input paths used by the rest of the script.
eval_data_path, optimal_solution_path = args.data, args.mu


# Results are written to "ppl_results/<mu file stem>_ppl.json".
os.makedirs("ppl_results/", exist_ok=True)

# Derive the stem from the *file name* of the mu path rather than from its
# second "/"-separated component: the old expression raised IndexError for a
# bare "mu.json", picked a directory name for paths such as "./res/mu.json" or
# "a/b/mu.json", and cut the stem at the first dot (so "mu_0.97.json" and
# "mu_0.99.json" collided). For the plain "dir/name.json" case the result is
# unchanged.
_mu_stem = os.path.splitext(os.path.basename(optimal_solution_path))[0]
save_ppl_path = os.path.join("ppl_results", _mu_stem + "_ppl.json")



# Per-scorer keyword configuration passed to compute_perplexity.
# Insertion order is kept identical to the hand-written table:
# seq_rep/2..4, tok_rep/8..32, coherence, diversity, information.
scorer_configs = {
    # Sequence-level repetition scorers, one per "n" value.
    **{
        f"seq_rep/{ngram}": {"n": ngram, "requires_detokenize": True}
        for ngram in (2, 3, 4)
    },
    # Token-level repetition scorers, one per "l" value.
    **{
        f"tok_rep/{span}": {"l": span, "requires_detokenize": True}
        for span in (8, 16, 32)
    },
    "coherence": {
        "prefix_len": 32,
        "max_length": 256,
        "device": EVAL_MODEL_CUDA,
        "simcse_model_path": "sup-simcse-roberta-base",
    },
    "diversity": {"requires_detokenize": True},
    "information": {
        "max_length": 256,
        "lm_model_path": "gpt2-xl",
        "device": EVAL_MODEL_CUDA,
    },
}

# The model name is the last "/"-separated component of the checkpoint path.
model_name = ckpt_path.split("/")[-1]
# NOTE(review): the last two positional arguments look like the decoder's
# device and dtype -- confirm against DaemonDecoding's signature.
geo_decoder = DaemonDecoding(model_name, ckpt_path, "cuda:6", "float16")

# Load the optimal mu from JSON; the context manager closes the handle, which
# the previous bare open() call never did.
with open(optimal_solution_path, "r") as mu_file:
    optimal_mu = json.load(mu_file)

# Read the first column of every CSV row as one reference text.
# newline="" is what the csv module requires so that quoted fields containing
# line breaks are parsed correctly. Blank lines come back from csv.reader as
# empty lists and are skipped instead of raising IndexError on row[0].
with open(eval_data_path, "r", newline="") as data_file:
    oracle_texts = [row[0] for row in csv.reader(data_file, delimiter=",") if row]

# Tokenize each text and keep at most max_length token ids.
oracle_token_ids = [
    geo_decoder.tokenizer.encode(text)[:max_length]
    for text in tqdm(oracle_texts, desc="tokenizing...")
]

ppl_res = geo_decoder.compute_perplexity(
    oracle_token_ids, eval_bsz, max_length, scorer_configs, optimal_mu, args.tau
)

# Write the result as JSON; leaving the with-block flushes and closes the file
# deterministically rather than relying on garbage collection.
with open(save_ppl_path, "w") as out_file:
    json.dump(ppl_res, out_file)
