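### SCRIPT TO GRADE THE PERPLEXITY OF GENERATIONS USING A LLAMA MODEL
### NOTE(review): this header originally said LLAMA2-13B, but the checkpoint loaded below is "./70B/" -- confirm which model is intended.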

# imports
import pickle
import os
from heapq import heappush

import torch
from tqdm import tqdm

from eval import *
from llama.metrics import *
from llama.generation import Llama

# Run configuration.
# `mode` selects which branch runs further down: "tune" or "eval".
mode = "eval"
# Forwarded to calculate_perplexity as `prompt_len`; in the tune branch it is
# also the offset used to slice the prompt off before calculate_diversity.
prompt_len = 15
# Path of the generations to grade (".pt" and ".pkl" are handled when loading).
result_file = "./owt/p15_d3_ngram4_llama7B_owt.pkl"

# Echo the input path so the run log records which file was graded.
print(result_file)
# load models
# Every Llama.build argument is collected in one mapping so the checkpoint
# directory, tokenizer path and size limits can be read at a glance.
_llama_build_kwargs = {
    "ckpt_dir": "./70B/",
    "tokenizer_path": './7B/tokenizer.model',
    "max_seq_len": 100,
    "max_batch_size": 64,
    "device": None,
    "model_parallel_size": 8,
}
# Scoring model handed to calculate_perplexity further down.
reg_model = Llama.build(**_llama_build_kwargs)

# load result file
# `filename` (the path without its extension) is reused below as the prefix of
# every output file; `ext` selects the deserializer.
filename, ext = os.path.splitext(result_file)
if ext == ".pt":
    r = torch.load(result_file)
elif ext == ".pkl":
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    with open(result_file, "rb") as f:
        r = pickle.load(f)
else:
    # Fail here with a clear message instead of hitting a NameError on `r`
    # further down.
    raise ValueError(
        f"Unsupported result file extension {ext!r} for {result_file!r}; "
        "expected '.pt' or '.pkl'"
    )

print(mode)
if mode == "tune":
    # `r` is iterated as a mapping {param: sequences}; each entry is scored and
    # pushed onto a min-heap keyed by its mean perplexity.
    heap = []
    # Running minimum of the diversity scores. Starting at 1 presumably assumes
    # scores never exceed 1 -- TODO confirm against calculate_diversity.
    min_div = 1
    # The progress bar is only meaningful in this branch, so it is created (and
    # closed) here. Scoring needs no gradients, matching the eval branch.
    with torch.no_grad(), tqdm(total=len(r), position=0, leave=True) as loop:
        for param, seqs in r.items():
            # Unpacking also enforces that the sequences are 3-dimensional.
            n_prompts, n_drafts, gen_len = seqs.shape
            print(seqs.shape)
            # ppl: flatten the first two axes so every draft is scored as its own row
            output_ppl = calculate_perplexity(
                reg_model,
                seqs.reshape(n_prompts * n_drafts, -1),
                prompt_len=prompt_len,
                bsz=64,
                marker=False,
            )
            mean_ppl = torch.mean(output_ppl).item()
            heappush(heap, (mean_ppl, param))

            # div: computed on everything after the first `prompt_len` positions
            output_div = calculate_diversity(seqs[:, :, prompt_len:].tolist())
            if output_div < min_div:
                min_div = output_div
            # float() accepts a plain float, a NumPy scalar or a 0-d tensor,
            # whereas .item() would fail on a plain Python float.
            loop.set_description(
                f"Average Perplexity: {mean_ppl:.4f}, Div: {float(output_div):.4f}"
            )
            loop.update(1)
    print(f"Min Div: {min_div}")

    # Save the heap of (mean perplexity, param) pairs next to the input file.
    print("Saving...")
    with open(f"{filename}_llama_tune.pkl", "wb") as f:
        pickle.dump(heap, f)
    print("Saved")

elif mode == "eval":
    with torch.no_grad():
        # `r` is unpacked as a 3-d tensor; this also enforces that rank.
        n_prompts, n_drafts, gen_len = r.shape
        print(r.shape)
        # NOTE(review): the original comment here said "bsz must be 1 for pads",
        # yet bsz=64 is what is passed -- confirm which is correct.
        output_ppl = calculate_perplexity(
            reg_model,
            r.reshape(n_prompts * n_drafts, -1),
            prompt_len=prompt_len,
            bsz=64,
            marker=True,
        )
        # Back to one row per prompt, one column per draft.
        output_ppl = output_ppl.reshape(n_prompts, n_drafts)
        # Per-draft statistics, reduced over the prompt axis.
        draft_avg = torch.mean(output_ppl, dim=0)
        draft_std = torch.std(output_ppl, dim=0)

        # Lowest perplexity per prompt (`.values`) and which draft had it
        # (`.indices`).
        best_ppl = output_ppl.min(dim=-1)
        best_avg = torch.mean(best_ppl.values)
        best_std = torch.std(best_ppl.values)

        print(f"Draft Avg: {draft_avg} Draft Std: {draft_std}")
        print(f"Best Avg: {best_avg} Best Std: {best_std}")

        # Save the full (n_prompts, n_drafts) perplexity table.
        torch.save(output_ppl, f"{filename}_ppl.pt")

        # Get and save drafts w/best ppl: one gather picks, for every prompt,
        # the draft with the lowest perplexity. The result is cast to long to
        # match the dtype of the buffer the per-prompt copy loop used to fill.
        best_idx = best_ppl.indices.cpu()
        r = r.cpu()
        best_gens = r[torch.arange(n_prompts), best_idx].to(torch.long)
        torch.save(best_gens, f"{filename}_best.pt")
else:
    print("Invalid mode")
