# https://huggingface.co/docs/transformers/v4.40.2/en/perplexity

import pickle
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import torch
from tqdm import tqdm
from torch.nn import CrossEntropyLoss

# Temperature: the logits are divided by T before the loss is computed
# further down in this script (T = 1.0 leaves them unchanged).
T = 1.0

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained GPT-2 weights and the matching fast tokenizer.
model = GPT2LMHeadModel.from_pretrained("gpt2")
model = model.to(device)
# padding_side is set here, but the tokenizer call below does not request
# padding.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", padding_side="left")

# The evaluation text is stored as a pickled object despite the .txt suffix.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point this at data you produced yourself or otherwise trust.
with open("dext.txt", "rb") as fh:
    text = pickle.load(fh)
encodings = tokenizer(text, return_tensors="pt")

# Default CrossEntropyLoss: mean over targets, skipping ignore_index (-100).
loss_fct = CrossEntropyLoss()

# Window size comes from the model config's n_positions; consecutive windows
# start `stride` tokens apart.
max_length = model.config.n_positions
stride = 512
# Size of dimension 1 of the tokenized input_ids tensor.
seq_len = encodings.input_ids.size(1)

# Sliding-window perplexity.
#
# Each window covers up to `max_length` tokens and windows start `stride`
# tokens apart.  Only the tokens that were not already scored by the previous
# window contribute to the loss; the overlapping prefix is masked with -100
# so that it serves purely as context.
#
# Windows score different numbers of tokens (the first and last windows
# usually differ from the middle ones), so the per-window mean losses are
# combined with a token-count-weighted average.  A plain mean of the window
# means would not equal the per-token average NLL.
nlls = []          # mean NLL over the scored tokens of each window
token_counts = []  # number of scored (non-ignored) targets in each window
prev_end_loc = 0

step = 0
for begin_loc in tqdm(range(0, seq_len, stride)):
    if step % 100 == 0:
        print(f"{begin_loc}/{seq_len}")
    end_loc = min(begin_loc + max_length, seq_len)
    # Number of tokens in this window not covered by the previous one.
    trg_len = end_loc - prev_end_loc
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    # Mask everything except the last `trg_len` positions.  When trg_len
    # equals the window length the slice is empty and nothing is masked.
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        # No `labels` argument: the loss is computed explicitly below on the
        # temperature-scaled logits, so the model's internal loss is unused.
        logits = model(input_ids).logits / T
        # Position i predicts token i + 1: drop the last logit and the first
        # label so the two line up.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = target_ids[..., 1:].contiguous()
        neg_log_likelihood = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )

    nlls.append(neg_log_likelihood)
    # loss_fct averages over exactly these non-ignored targets.
    token_counts.append((shift_labels != -100).sum())

    prev_end_loc = end_loc
    step += 1
    if end_loc == seq_len:
        break

# Token-weighted average NLL over all windows, then exponentiate.
window_nlls = torch.stack(nlls)
window_weights = torch.stack(token_counts).to(window_nlls.dtype)
ppl = torch.exp((window_nlls * window_weights).sum() / window_weights.sum())