import torch
from transformers import GPTNeoForCausalLM, GPT2Tokenizer
from datasets import load_dataset, load_from_disk
import evaluate
from tqdm import tqdm

def evaluate_model_wikipedia(
    model, device: str = "cuda:0", max_num_samples: int = None, tokenized_dataset_path: str = None, progress = True
):
    """Return the average next-token cross-entropy loss of ``model`` on Wikipedia.

    Each example's ``input_ids`` are fed through ``model``. The loss compares
    the logits at positions ``0..T-2`` with the tokens at positions
    ``1..T-1``. The mean loss of each batch is then averaged over batches,
    which is not weighted by token count. Exponentiate the result to get a
    (batch-averaged) perplexity figure.

    Args:
        model: A causal LM whose forward call returns an object with a
            ``.logits`` attribute. It is switched to eval mode here. It is not
            moved to ``device``, so the caller must place it there.
        device: Device the input tensors are moved to.
        max_num_samples: If given, evaluate at most this many examples.
        tokenized_dataset_path: Path to a dataset saved on disk whose
            ``"train"`` split has an ``input_ids`` column. If ``None``, the
            raw ``wikipedia``/``20220301.en`` train split is loaded instead.
        progress: Whether to show a tqdm progress bar.

    Returns:
        A scalar tensor holding the average loss over the evaluated batches.

    Raises:
        ValueError: If no batch was evaluated, either because the dataset is
            empty or because ``max_num_samples <= 0``.
    """
    # Hardcoded. Batching would need equal-length rows (or padding) for torch.tensor.
    batch_size = 1
    model.eval()

    if tokenized_dataset_path is None:
        # NOTE(review): nothing here tokenizes this raw dump, yet the loop below
        # reads an ``input_ids`` column. Confirm this path works, or always pass
        # ``tokenized_dataset_path``.
        dataset = load_dataset("wikipedia", "20220301.en", split="train")
    else:
        print(f"Loading tokenized dataset from disk: {tokenized_dataset_path}")
        dataset = load_from_disk(dataset_path=tokenized_dataset_path)["train"]

    # Decide how many batches to run up front. The sample cap is then exact for
    # any batch_size, and the progress bar shows the real total.
    num_batches_to_run = len(dataset) // batch_size
    if max_num_samples is not None:
        max_batches = (max_num_samples + batch_size - 1) // batch_size  # ceil division
        num_batches_to_run = max(0, min(num_batches_to_run, max_batches))

    # Built once. "mean" averages over every token position in the batch.
    loss_fct = torch.nn.CrossEntropyLoss(reduction="mean")

    total_loss = 0.0
    num_batches = 0
    for i in tqdm(range(num_batches_to_run), desc="Running Eval", disable=not progress):
        input_ids = dataset[i * batch_size : (i + 1) * batch_size]["input_ids"]
        inputs = torch.tensor(input_ids).to(device)

        with torch.no_grad():
            logits = model(inputs).logits
            # Position t predicts token t+1, so drop the last logit and the first label.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = inputs[:, 1:].contiguous()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        total_loss += loss
        num_batches += 1

    if num_batches == 0:
        raise ValueError(
            "No batches were evaluated: the dataset is empty or max_num_samples <= 0."
        )

    # Mean of the per-batch mean losses.
    return total_loss / num_batches
