import utils
from utils import move_to_cuda, decode, CACHE_LOCATION, load_model

import time
from tqdm import tqdm
import torch
from IPython import embed
import logging
import gc
from timer import Timer

# Sizes of the ByT5 family to iterate over. For the OPT family the equivalent
# list was: "125m", "350m", "1.3b", "2.7b", "6.7b", "13b", "30b", "66b".
MODEL_SIZES = ["small", "base", "large", "xl", "xxl"]

# Lower the root logger's threshold so the logging.info(...) calls are emitted.
logging.root.setLevel(logging.INFO)

# Shared timer; each timer.snap(...) call records (and prints) a labelled step.
timer = Timer(print_results=True)

# When True, run a short generation smoke test on each model after loading it.
# When False (the previous hard-coded behaviour), only load/cleanup is timed.
RUN_GENERATION = False

# Pass the list itself to tqdm so the progress bar knows the total length
# (wrapping enumerate(...) hides len() from tqdm).
for model_size in tqdm(MODEL_SIZES):

    model, tokenizer = load_model(model_size, family='byt5')
    timer.snap("Loaded {} model".format(model_size))

    if RUN_GENERATION:
        # NOTE(review): the prompt still says "OPT" although family='byt5';
        # left unchanged here, update if the generated text matters.
        inputs = move_to_cuda(tokenizer("I am the {} OPT model and".format(model_size), return_tensors="pt"))
        library_outputs = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True)
        logging.info("#############################################")
        timer.snap("Library outputs: {}".format(decode(tokenizer, library_outputs["sequences"], delete_after_period=False)))
        # Drop generation-only references so they don't pin memory below.
        del inputs, library_outputs

    # Always release this model before the next iteration's load_model call.
    # Otherwise `model`/`tokenizer` stay referenced until they are reassigned,
    # i.e. until after the next (larger) model has finished loading, so two
    # models would be held in memory at once.
    del model, tokenizer
    gc.collect()
    torch.cuda.empty_cache()  # no-op if CUDA was never initialised
    timer.snap("Cleaned up {} model.".format(model_size))
