from datasets import load_dataset
from llama3 import tokenizer
import torch
from torch.utils.data import DataLoader

# Loaded once at import time; get_loaders() and the __main__ section below
# both read the "train" split from this object (get_loaders also reads "test").
dataset = load_dataset("Trelis/tiny-shakespeare")

def get_loaders(batch_size: int):
    """Build DataLoaders over the module-level ``dataset``.

    Args:
        batch_size: Number of dataset rows per batch, used for both loaders.

    Returns:
        A ``(trainloader, validloader)`` tuple. The first wraps the
        ``"train"`` split with shuffling enabled; the second wraps the
        ``"test"`` split with shuffling disabled.
    """
    train_split = dataset["train"]
    train_dl = DataLoader(train_split, batch_size=batch_size, shuffle=True)

    # The "test" split serves as the validation set, so its order is kept fixed.
    valid_split = dataset["test"]
    valid_dl = DataLoader(valid_split, batch_size=batch_size, shuffle=False)

    return train_dl, valid_dl

if __name__ == "__main__":
    # Script mode: print the longest tokenized entry length and the total
    # token count across the "train" split.
    max_token = 0
    total_token = 0
    for entry in dataset["train"]:
        # Strings are immutable, so the end-of-text marker has to be
        # concatenated into a new string and that string passed to the
        # tokenizer; rebinding a loop variable over entry["Text"] would leave
        # the tokenized text unchanged.
        # NOTE(review): assumes entry["Text"] is a single str per row, and
        # that the tokenizer maps "<|end_of_text|>" to its special token --
        # confirm against the llama3 tokenizer configuration.
        text = entry["Text"] + "<|end_of_text|>"
        inputs = tokenizer(text, return_tensors="pt")
        # The last dimension of input_ids is the sequence length.
        num_tokens = inputs.input_ids.shape[-1]
        max_token = max(max_token, num_tokens)
        total_token += num_tokens

    print(max_token, total_token)
