import torch


def calculate_ppl(sentence, model, tokenizer, device='cuda'):
    """Return the perplexity of ``model`` on a single ``sentence``.

    Args:
        sentence: Text to score; passed directly to ``tokenizer``.
        model: Language model invoked as
            ``model(**encodings, labels=encodings['input_ids'])``; its first
            output (``outputs[0]``) is read as the loss. Assumes a
            HuggingFace-style causal LM whose loss is the *mean* per-token
            cross-entropy — TODO confirm for non-HF models. The model is not
            moved; it must already live on ``device``.
        tokenizer: Callable that, given ``return_tensors='pt'``, returns a
            mapping of tensors that includes ``'input_ids'``.
        device: Device the encoded tensors are moved to. Defaults to
            ``'cuda'``.

    Returns:
        The perplexity as a Python float.

    Note:
        Puts ``model`` into eval mode as a side effect.
    """
    model.eval()
    encodings = tokenizer(sentence, return_tensors='pt')
    # Honour the ``device`` argument rather than hard-coding CUDA.
    encodings = {key: val.to(device) for key, val in encodings.items()}

    with torch.no_grad():
        outputs = model(**encodings, labels=encodings['input_ids'])
        # With HF causal LMs this loss is already averaged over the predicted
        # tokens, i.e. it is the mean negative log-likelihood per token.
        mean_nll = outputs[0]

    # Perplexity = exp(mean NLL). Dividing by the sequence length again would
    # double-normalize and drive the result toward 1 for long inputs.
    return torch.exp(mean_nll).item()
#
# import pandas as pd
# from datasets import Dataset
#
# wiki_path = '/home/chenyuheng/chenyuheng/NIPS2024/Datasets/EXP3/test-wikitext-2-v1.parquet'
# data = pd.read_parquet(wiki_path)
# hf_dataset_full = Dataset.from_pandas(data)["text"]
# hf_dataset_full_filter = [s for s in hf_dataset_full if 100<len(s)<200]
# hf_dataset = hf_dataset_full_filter[:5]