import torch
from datasets import load_dataset, load_from_disk
from tqdm import tqdm

def evaluate_model_on_list_of_text(
    dataset: list[str],
    model,
    tokenizer,
    device: str = "cuda:0",
    progress: bool = True,
    max_length: int = 1024,
) -> float:
    """Return the mean next-token cross-entropy of ``model`` over ``dataset``.

    Each string is tokenized and scored on its own (batch size 1). For every
    sequence, the cross-entropy of predicting token ``t + 1`` from tokens
    ``0..t`` is averaged over its positions; the returned value is the
    unweighted mean of those per-sequence averages, so every sequence counts
    equally regardless of its length. Note this is a mean over sequences, not
    over tokens, so exponentiating it is not the token-weighted corpus
    perplexity.

    Args:
        dataset: Texts to evaluate.
        model: Called as ``model(input_ids)``; its return value must expose a
            ``.logits`` tensor (presumably shaped (1, seq_len, vocab_size) —
            TODO confirm for non-HF models).
        tokenizer: Called as ``tokenizer(text, return_tensors="pt")``; must
            return a mapping whose ``"input_ids"`` is a (1, seq_len) tensor.
        device: Device the token ids are moved to. The model is assumed to
            already be on this device.
        progress: Whether to display a tqdm progress bar.
        max_length: Sequences longer than this are truncated to their first
            ``max_length`` tokens before scoring. Defaults to 1024.

    Returns:
        The mean loss as a Python float, or ``float("nan")`` if no sequence
        had at least two tokens to score (e.g. an empty ``dataset``).

    Raises:
        AssertionError: If ``dataset`` is not a list, or the tokenizer does
            not return a single-row 2-D ``input_ids`` tensor.
    """
    # Switch off dropout etc. so the evaluation is deterministic.
    model.eval()

    assert isinstance(dataset, list)

    # Accumulate in a Python float rather than a device tensor: with fp16/bf16
    # models a tensor sum of many losses loses precision and can overflow.
    total_loss = 0.0
    num_scored = 0

    with torch.no_grad():
        for input_string in tqdm(dataset, desc="Running Eval", disable=not progress):
            input_ids = tokenizer(input_string, return_tensors="pt")["input_ids"].to(device)

            # Expect exactly one (1, seq_len) row; seq_len may be 0 for empty
            # text, which is skipped below rather than treated as an error.
            assert input_ids.dim() == 2 and input_ids.shape[0] == 1, (
                f"Unexpected shape: {input_ids.shape}"
            )

            # Keep only the first `max_length` tokens (model context limit).
            input_ids = input_ids[:, :max_length]

            # With fewer than two tokens there is no next token to predict.
            if input_ids.shape[1] < 2:
                continue

            logits = model(input_ids).logits

            # Position t predicts token t + 1: drop the last logit and the
            # first label so the two line up.
            shift_logits = logits[:, :-1, :]
            shift_labels = input_ids[:, 1:]

            # Mean cross-entropy over this sequence's predicted positions.
            loss = torch.nn.functional.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
            )

            total_loss += loss.item()
            num_scored += 1

    # Nothing scoreable: the average is undefined, so report NaN instead of
    # failing with a division by zero.
    if num_scored == 0:
        return float("nan")

    return total_loss / num_scored