{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6db35712-44f4-4e69-9589-67a5cdab7f5d",
   "metadata": {},
   "source": [
    "## Load Libraries and Base Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cc50c35e-0dbe-4ffd-993d-280b28d4ce8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from mario_gpt import MarioDataset, MarioLM, TrainingConfig, MarioGPTTrainer\n",
    "from mario_gpt.utils import view_level, convert_level_to_png, join_list_of_list, characterize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b44930dd-c364-4c8a-8cb1-7be444e05471",
   "metadata": {},
   "outputs": [],
   "source": [
    "BASE = \"distilgpt2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e14649e1-07a3-4549-b6ab-6d574c464455",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using distilgpt2 lm\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/kokkgoblin/miniconda3/envs/py39/lib/python3.9/site-packages/transformers/models/auto/modeling_auto.py:1177: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n",
      "  warnings.warn(\n",
      "Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at distilgpt2 and are newly initialized: ['transformer.h.4.ln_cross_attn.weight', 'transformer.h.5.crossattention.bias', 'transformer.h.2.crossattention.bias', 'transformer.h.5.ln_cross_attn.weight', 'transformer.h.3.crossattention.c_proj.weight', 'transformer.h.0.crossattention.q_attn.weight', 'transformer.h.2.crossattention.c_proj.bias', 'transformer.h.3.ln_cross_attn.weight', 'transformer.h.3.crossattention.masked_bias', 'transformer.h.1.crossattention.c_proj.weight', 'transformer.h.0.crossattention.c_attn.weight', 'transformer.h.3.crossattention.q_attn.weight', 'transformer.h.4.crossattention.c_attn.weight', 'transformer.h.3.crossattention.c_attn.weight', 'transformer.h.0.crossattention.bias', 'transformer.h.1.crossattention.bias', 'transformer.h.3.crossattention.c_proj.bias', 'transformer.h.2.crossattention.q_attn.weight', 'transformer.h.2.crossattention.c_attn.weight', 'transformer.h.4.crossattention.q_attn.weight', 'transformer.h.0.crossattention.c_proj.bias', 'transformer.h.1.crossattention.masked_bias', 'transformer.h.4.crossattention.masked_bias', 'transformer.h.3.crossattention.bias', 'transformer.h.5.crossattention.masked_bias', 'transformer.h.2.ln_cross_attn.weight', 'transformer.h.0.crossattention.c_proj.weight', 'transformer.h.5.crossattention.c_attn.weight', 'transformer.h.5.crossattention.c_proj.weight', 'transformer.h.5.crossattention.q_attn.weight', 'transformer.h.0.ln_cross_attn.weight', 'transformer.h.4.crossattention.c_proj.weight', 'transformer.h.1.ln_cross_attn.weight', 'transformer.h.1.crossattention.q_attn.weight', 'transformer.h.4.crossattention.bias', 'transformer.h.2.crossattention.masked_bias', 'transformer.h.5.crossattention.c_proj.bias', 'transformer.h.1.crossattention.c_proj.bias', 'transformer.h.4.crossattention.c_proj.bias', 'transformer.h.1.crossattention.c_attn.weight', 'transformer.h.0.crossattention.masked_bias', 
'transformer.h.2.crossattention.c_proj.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using distilgpt2 tokenizer\n"
     ]
    }
   ],
   "source": [
    "mario_lm = MarioLM(lm_path=BASE, tokenizer_path=BASE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2a9fc71-206f-4c60-800a-d72dc33eeeba",
   "metadata": {},
   "source": [
    "### Load Dataset (used for training below)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1ee9ec86-1d5b-45cb-ada6-2b52d10edf5e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No level string specified, using default string FULL_LEVEL_STR_WITH_PATHS...\n",
      "\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (102116 > 1024). Running this sequence through the model will result in indexing errors\n"
     ]
    }
   ],
   "source": [
    "dataset = MarioDataset(mario_lm.tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6bdf49ba-0bd0-4bf0-a8f8-fd7230689481",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['--------------------------------------------------',\n",
       " '--------------------------------------------------',\n",
       " '--------------------------------------------------',\n",
       " '--------------------------------------------------',\n",
       " '-------------------------------------------------o',\n",
       " '--------XSSSSS---------------------------------SSS',\n",
       " '--------X-----------------------------------------',\n",
       " '--------X-----------------------------------------',\n",
       " '-------EX--E-X---------------xxxx-?-----------xxxx',\n",
       " '--------XSS?SX---QQ?QQ------xx<>-x-----------xx--?',\n",
       " '---------------------------xx-[]--x---------xx----',\n",
       " '--------------------------xx--[]---x-------xx-----',\n",
       " 'xxxxxxxxxxxxxxxxxxxxxxxxxxx---[]----xxxxxxxx------',\n",
       " 'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX---XXX']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "view_level(dataset.input_ids[:700], dataset.tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d7082d58-d8a6-4e02-95d5-dcf2728ef54a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/kokkgoblin/miniconda3/envs/py39/lib/python3.9/site-packages/Pillow-9.1.1-py3.9-linux-x86_64.egg/PIL/Image.py:992: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAyAAAADgCAIAAAB0EpUWAAAYPUlEQVR4nO3dT2wcVZ7A8fe6q9OOsWNv4mwDQXHsvRhWyybZbAA7M4qGS9QoWhGFA1oJgmwNITkwBywuHo1G8S0ckBgFZXGCJZgDYsgcMmmNOIyi7MYBokSMUFgIKAECJuGPkzh2/Kf/1B6abZz+U1Wv+3X5VdX3o9EInG93qitdzo/X5Sq551BWeCalsG3vOT09PT09PT19FPuYQi7Unp2enp6enp6ePpq92oAFAAAAVwxYAAAAmjFgAQAAaMaABQAAoBkDFgAAgGYMWAAAAJoxYAEAAGjGgAUAAKAZAxYAAIBmDFgAAACaMWABAABoxoAFAACgWUxKtQfQ09PT09PT09M7s5KWsG1RsEWu4H6zaCkFPT09PT09PT29cy/3HMq6VHc+wPVJ6enp6enp6ekj3qudg6X07PT09PT09PT00ew5yR0AAEAzBiwAAADNGLAAAAA0Y8ACAADQjAELAABAMwYsAAAAzRiwAAAANGPAAgAA0IwBCwAAQDMGLAAAAM0YsAAAADRjwAIAANAsJqXaA+jp6enp6enp6Z1ZSUvYtijYIldwv1m0lIKenp6enp6ent65l3sOZV2qOx/g+qT09PT09PT09BHv1c7BUnp2enp6enp6evpo9pzkDgAAoBkDFgAAgGYMWAAAAJoxYAEAAGjGgAUAAKAZAxYAAIBmDFgAAACaMWABAABoxoAFAACgGQMWAACAZgxYAAAAmjFgAQAAaBaTUu0B9PT09PT09PT0zqykJWxbFGyRK7jfLFpKQU9PT09PT09P79zLPYeyLtWdD3B9Unp6enp6enr6iPdq52ApPTs9PT09PT09fTR7TnIHAADQjAELAABAMwYsAAAAzRiwAAAANGPAAgAA0IwBCwAAQDMGLAAAAM0YsAAAADRjwAIAANCMAQsAAEAzBiwAAADNGLAAAAA0i0mp9gB6enp6enp6enpnVtISti0KtsgV3G8WLaWgp6enp6enp6d37uWeQ1mX6s4HuD4pPT09PT09PX3I+gdmXlZ4gOo5WEpbQ09PT09PT08fwV4IYSk/AgAAINr+8x8Olf753qHLQoiDBw8uDfgpQgAAAAVLpyshxORYjxBieHh46RcZsAAAAJRljl3OHLtc61cZsAAAALwqLl9ljl1O7+pJ7+opfj5YuYjFgAUAAKAZAxYAAIBmDFgAAACaMWABAABoxoAFAACgGQMWAACAZgxYAAAAmjFgAQAAePXH6/uEEMUrYN07dLl4BazKu+VYBt6wmp6enp6enp7eqL5ScbSqxUpawrZFwRa5gvvvJKWgp6enp6enp49av9Qfr+9zvdmz3HMo6/Ksd/4GRo2TIeiPPnT2yEj/4OiEe7y5Xwhhn3cvw+SZ9/7de2zgny89PT09vT/90YfOKjyg+dQGLGj3+sNnhRCuM1ZxuhJCjKWr/OpQJrRf5/0JAPBi218SlV9cxr+/OMndCIOjE0dG+mv9amm6AgAAgcAK1jIrrmAV1ZqxhjJ+bY15eH8CALwY31dlBWsZWcu9AVBTWoo08OO8Znz9f6okAACUM+3vRz4iBAAA0IwBCwAAQDPOwVpmnIPljPcnAMALzsFCQ0z4XNnPrwMA4IVpf3/xESEAAIBmDFgAAACacQ7WMiudg+VwNdGxdHQ/LOP9CQDwgnOwUIXztdqLH/EWZyzTrvPR7K9zHSwAgBem/f3IR4TLz8udcGr9cQIAAAPFpFR7AL3m3vN9BqM5Yxn350VPT09Pb2RvGrn3taxti4ItcgVh2261FElL0GvsX3/OrM+MTcP7k56enp4+iH+fqp3kLqX7i6RX6o8+dPbISP/g6IR7vLlfKJ70beDrpaenp6enj0Kvdg6W0rPTe+wHRydqXcO9xPsniY1vDz09PT09PX2DPSe5G8F5xqpvugIAAMuFyzSYwss6FgAACARWsAAA
ADRjwAIAANCMAQsAAEAzBiwAAADNGLAAAAA0Y8ACAADQjAELAABAM66DZQqHq4mOpcVQxs9tAQAADWEFywjO12ofyoixtG/bAgAAGsWAtfy83AmHGQsAgACJSan2AHrNvef7DBZnLOO2n56enp6enr6ClbSEbYuCLXIF95tFSyno9fZKhjJi7+NmbT89PT09PT19ZW/NZ12iMvR6e/v8xJGR/sHRCdeyuNZl2vbT09PT09PTV1I7B8t1ZKOvox8cnTgy4vJBofdPEhvfHnp6enp6evoGe05yN4LzjFXfdAUAAJYL18EyhZd1LAAAEAisYAEAAGjGgAUAAKAZAxYAAIBmDFgAAACaMWABAABoxoAFAACgGQMWAACAZlwHyxQOVxMdS4uhjJ/bAgAAGsIKlhGcr9U+lBFjad+2BQAANIoBa/l5uRMOMxYAAAESk1LtAfSae8/3GSzOWMZtPz09PT09PX0FK2kJ2xYFW+QK7jeLllLQ6+2VDGXE3sfN2n56enp6enr6yt6az7pEZeid/dONl8u+surk70v/PL39d2W/ap+fODLSPzg64frMxbUu07b/wl2/UXr+V//trNLr3XNI7QWY9n6gp49yv+XPCe/x4OjpZ97b2tTtoaf3s1f7KULXkY2+zNLppPivlTPK4Kj7jOX9k0Tn7VHtvWy/6vOb/Hrp6en19o/t6vHYy80DSv9BZebrpacv9Zzk3kRl04nDF4szR63nqW/aaJz37Vdl5usFfBYv5NbMXl01P7X69rXUrStrZq8mc/MB6j26Z+hygwEgzHv/u/ZcB6tZHAYRh3WsJm+UAtXtV2Xa6wX8t/r2tf2ZJ7/p2tgxO5maunCz7b6TD+57vzedjScD0Xv07dhPi1gnjt0xSJUWt0oB4MC0979rzwpWU6w6+ftnT8w8e2Km8peKX9eyDtQ8Qd9+IBAS+YXuyTObLr6xkGj7oO/plsXp3aee77t6Nii9kuJ0NZT56X+iYt4CnJn2/nftGbD0Kw4fhx9rs9/7a9mM8uyJGfu9vx5+rE1o+qytGYK+/UCA5Kz4xfXpV3aMv7N1ePzRw4nc7Uc+fStAfR2+HespLVnx4SCUmPb+d+75iLDpqq4DBUjQtx8wWdZqubJ2oxCiIOPTK7tutPd0zE4GqAf8ZNr737lnBUuzynWd4pJPcUHINV52Qd9+IFgSufnu786tmp+6a3E6Nf1F561LU+3dAeq9e2xXT+meqnw4iPqY9v537lnBaq6yoaRyRjFc0LcfMFw8n98weeqJMwe+7+zd/Nnbt1rXne/dGaBeSfGTwdLJWIAq097/zj0DVhP914HdDv9qvqBvP2C+rLXiq7sHUtc/uf/LzExr6t0tL364fnuA+kbww4NQZdr737lnwNJsevvvln5w9uvf/qksWDqmNH6xA+2Cvv1AsOTjyc/XbTu+aX/H3I8zyc5FtwsimNbX4bFdPZzbjvqY9v537hmw9CuNHS/84l8qf/XXv/3TS//9UeXXHa6uOZb2dTm9vu1XZc7rBZZdXlpTrang9kpOHLs8lla4wjtQxrT3f62ek9yXwfDw8PDw8NKvOF+7fCgjxtJN3iYVlduvKlivF2gSW8YWEu25mNcb9pnWN4JFLKgy7f3v2jNgNcvw8LB8eEfZF+XDO0pfLM0oXu4M4//M4X37VZn5egH/zSXazvU99fWaBwLae1Q2Sy39ccKqAVCVae9/114+82pW6RaGUqrd8jBq/T/Pvlz8h9L8UfrJu8p55eDBg0pjylhanN7Z3D8v1e3/uO03Ss8/vk/hP459eL309PTN67f9JeH9o8B7hi7z9xF9mHoraQnbFgVb5Aruj5RS0Dv3JaXhST68o+rVDQ4ePOjydBWGMmLv42Ztf93P74UPr5eenr55/VBGiIzC6pRp209P30gv9xzKulR3PsCo8TAE/dGHzh4Z6R8cnXCPN/cLIezz7mUjlG7APDh6WgjFoUkIpddr2vtz4LjCCtzg6OkjIwOB
7p95b6v33sDjK+j9688193Qo044venqTe9Xv/2o/Rai0NfQe+8HRCdeZo3TeUtUBqNYZS/V93fuSvtw8oPr89vkJpderxJ8/L9X9E+he6S9gM4+voPcvfaz2KO9eUDzVysz9Q0/vZ6/0/ZPLNBjBeeaob9poxD1Dl52vAegaODPt9apS3T9B7+FRvJDrnPshG19hFbKJ/GIulphJdi5YLY30o3+74yEjvyp/krLAT6qvF1iqGceLD7x//2TAMkVx5nDNal0gStfXi0rvnrJbhpWG91JQ3/MLz6/XTN73Tzh6eLT69rX9mSe/6drYMTuZmrpws+2+kw/ue783na1xuUIvfWmichikio3zpDXyKzH6t5//XwjxgtqLq3P7gVqacbz4wPv3TwasgCl99Kb3Y8Gqym4ZNpYWJ45dLlsgrXt7QsDL/glTD1eJ/EL35Jl/nProfzf8xwd9T//rpT/vPvX8zda1H63bpqVvRNl0pYWf24/wMfl4ceXl+yfXwYK7b8d6SiM5V6yppLp/gt7DQc6KX1yffmXH+Dtbh8cfPZzI3X7k07c09nVbuoKlkW/bj1Ay9njxzuH7JwMWAGiTtVqurN0ohCjI+PTKrhvtPR2zkxr7ujVjBUv4uP0IJWOPFy34iDBglp7b1IzzsZZ+nLf0nqxlHzbXvT2Do9WbIPKyf8LUw4tEbr77u3Or5qfyMSs1/UXnrUuX7v2Fxr5uledgaeHb9iOUjD1eXHn5/smAFTBNPe+qckIqrnyWfdjcyPaEjOv+CVkPV/F8fsPkqSfOHPi+s3fzZ2/fal13vnenxr5uTVrB8m37EUrGHi9euH7/ZMCCAn64zJnq/gl6j0pZa8VXdw+krn9y/5eZmdbUu1te/HD9do193Zq0guXb9iOUjD1e6lD5/ZMBC+6WroWikur+CXoPB/l48vN1245v2t8x9+NMsnPR7QfIVfu6NWkFy7ftRygZe7x45/D9kwHLFA5X1xxL/7z86M91sMqcOHa51hXA635+j683EBz2Tyh7uMpLa6o11by+Dk1awSryYfsRYgYeL945fP9kwDKC87XLi6c0FWcOP6+DVabq5Wvr2x7vrzdAVK9+HvQelWwZW0i052Jeb1im2jeiGdOVn9uP8DH5eFFV9fsnl2lYfl7uDOPz5TrLFjwf29VTNu408omSga9Xler+CXoPj+YSbef6nvp6jdeb/Kn2jWjGdbD83H6Ej8nHiwPv3z8tA29YHa3e8333/Jw5KifxWvdaUaX6ek8b9udVpLp/gtsbd7yY3d9Y2fXmwAGNvai4JXPV+9t4uenNCxX/X7k9rhp8varPTx/uXvvxYtr3f7n3taxti4ItcgX330lKkbQEvcb+9edMXO00h2nvz6j9eZm2/6PWH91rCSGklC6pOtu2hRDPjeWMer309Cb3qt//5Z5DWYXa7HGYnp6ePkx96Rv6Sx//9MXiB3xl51GVzqyq9fFf2dlXpVUxvv/T0zevVzsHS3UxjZ6enp7en96BlrOvTHu99PSG9/wUIQBUFy/kOud+yMZXWIVsIr+YiyVmkp0LVouu3jdNujqDKmP3D7QIzfGiCwMWAFS3+va1/Zknv+na2DE7mZq6cLPtvpMP7nu/N52tcXlD1d43zbsClhJj9w+0CM3xoguXaQCA6hL5he7JM5suvrGQaPug7+mWxendp57vu3pWV+8bE6YrYfD+gRahOV50YcACgJpyVvzi+vQrO8bf2To8/ujhRO72I5++pbH3h/YrYNXNzP0DXcJxvOjCgAUANWWtlitrNwohCjI+vbLrRntPx+ykxt4fhqxgCVP3D3QJx/GiCwMWANSUyM13f3du1fzUXYvTqekvOm9dmmrv1tj7w5wVLDP3D3QJx/GiCye5A0BN8Xx+w+SpJ84c+L6zd/Nnb99qXXe+d6fG3h/mrGCZuX+gSziOF11YwQKAmrLWiq/uHkhd/+SXf/+DEOLdLS9+uH67xt4f5qxgmbl/oEs4jhdd
WMECgJry8eTn67Yd37S/Y+7HmWTnotsPkKv2/jBnBcvM/QNdwnG86MIKFgC4yEtrqjXl/bu/at9s5qxgFZm2f6BX0I8XXRiwAKA6W8YWEu25mNc7vKr2vjFkBcvY/QMtQnO86MKABQDVzSXazvU99fWaB9zTunrfGLKCZez+gRahOV50sQy8ATU9PT29Cf2NlV1vDhxoXu+bWitYhu/PZm8Pvd5+2Y8X03oraQnbFgVb5Aruj5RS0NPT09P70+tS616Epr1eevow9dZ81iUqQ09PT0/vT69LrRUs014vPX2YerVzsFQXt+np6enp/ekdaDkHy7TXS09veM91sABERbyQ65z7IRtfYRWyifxiLpaYSXYuWC26emMZ8lOEqkKz/wMqsseLLgxYAKJi9e1r+zNPftO1sWN2MjV14WbbfScf3Pd+bzpb4wI8qr2xap2DZbjQ7P+AiuzxoguXaQAQFYn8QvfkmU0X31hItH3Q93TL4vTuU8/3XT2rqzdWEKcrEaL9H1CRPV50YcACECE5K35xffqVHePvbB0ef/RwInf7kU/f0tibyZDrYNUhHPs/uKJ5vOjCgAUgQrJWy5W1G4UQBRmfXtl1o72nY3ZSY2+mgK5gibDs/+CK5vGiCwMWgAhJ5Oa7vzu3an7qrsXp1PQXnbcuTbV3a+zNFNwVrHDs/+CK5vGiCye5A4iQeD6/YfLUE2cOfN/Zu/mzt2+1rjvfu1Njb6bgrmCFY/8HVzSPF11YwQIQIVlrxVd3D6Suf/LLv/9BCPHulhc/XL9dY2+m4K5ghWP/B1c0jxddWMECECH5ePLzdduOb9rfMffjTLJz0e0HyFV7MwV3BSsc+z+4onm86MIKFoDIyUtrqjXl/bu/am+a4K5gFQV9/wdd1I4XXRiwAESFLWMLifZcLNGk3lgBXcEKzf4PqMgeL7owYAGIirlE27m+p75e80CTemMFdAUrNPs/oCJ7vOhiSal2C0N6enr6gPY3Vna9OXCgeb3q9vim1gpWyP68mr09UesDd7yY1ltJS9i2KNgiV3B/pJSCnp6ent6fXpda9yI07fXS04ept+azLlEZenp6enp/el1qrWCZ9nrp6cPUq52Dpbq4TU9PT0/vT+9AyzlYpr1eenrD+/LrYMULuc65H7LxFVYhm8gv5mKJmWTngtVS6yno6enpw9qHRkB/ilCVae+fqPUoUz5grb59bX/myW+6NnbMTqamLtxsu+/kg/ve701na1zQgp6enj6sfWjUOgcrZEx7/0StR5nyjwgT+YXuyTObLr6xkGj7oO/plsXp3aee77t6ttbj6enp6cPah0YUpith3vsnaj3KVDkHK2fFL65Pv7Jj/J2tw+OPHk7kbj/y6VsOT0FPT08f1j4cAnodrDqY9v6JWo+lqgxYWavlytqNQoiCjE+v7LrR3tMxO+nwFPT09PRh7cMhIitYwrz3T9R6LFVlwErk5ru/O7dqfuquxenU9Bedty5NtXc7PAU9PT19WPtwiM4Klmnvn6j1WKr8JHchRDyf3zB56okzB77v7N382du3Wted793p8BT09PT0Ye3DITorWKa9f6LWY6kqA1bWWvHV3QOp65/c/2VmpjX17pYXP1y/3eEp6Onp6cPah0NEfopQmPf+iVqPpaoMWPl48vN1245v2t8x9+NMsnPR7Qcy6enp6cPah0NEpith3vsnaj2WqjJgFeWlNdWa8v5E9PT09GHtgy46K1hFpr1/otajqPwkd1vGFhLtuVjC4+Pp6enpw9qHRkSmK9PeP1HrUaZ8wJpLtJ3re+rrNQ94fDw9PT19WPvQiMhPEZr2/olajzLymVezSrcwlFLtlof09PT09PX14/t+Wjx46eOfvlIcksrWokqrU7VGqLIVrBf+/29Mvv/T0zevt5KWsG1RsEWu4P5IKQU9PT09vT+9LrXOwTLt9dLTh6mX9vmJIyP9g6MTLq0QcnO/EIKenp6enp6efs+hrGv580Oav7x09KGz
Ru2fmBBicHTiyEi/l7qInp6enp6ent47pWmp7t6o/RPz8pjKvUlPT09PT09Pbxpz9s/P18HyMpd5/z3o6enp6enpo9CbxpD9U+VmzwAAAGgEAxYAAIBmDFgAAACaMWABAABoxoAFAACgGQMWAACAZgxYAAAAmv18HSyHq4eNpcVQpvyL9PT09PT09PSmMWT/xFxrIcRQRoylvT47PT09PT09fUR605izf2KudeVj6Onp6enp6elNY9T+ka7pUqprg/T09PT09PSh7E/vzCrdkllKtVs4q/bj+xLeYx/2j9qABQAAIITY+1rWtkXBFrmC+yQkpUhaoqn9688pDFg++D9omqPgoC0DtAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<PIL.Image.Image image mode=RGB size=800x224>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img = convert_level_to_png(dataset.input_ids[:700],  dataset.tokenizer)[0]\n",
    "img"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b0186ee-8f66-435f-965b-3de9e344afd1",
   "metadata": {},
   "source": [
    "### Set Up Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "36e44f8c-878c-463e-a481-36229f295572",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = TrainingConfig(save_iteration=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "faf6fa43-e970-47a5-87c5-78bb7c2672e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = MarioGPTTrainer(mario_lm, dataset, config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8d8e037e-fae6-4c04-914c-3736d422c9ee",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training for 100 Iterations and batch_size 1\n",
      "================== Training Config ==================\n",
      "gradient_accumulation_steps -- 1\n",
      "mixed_precision -- no\n",
      "output_dir -- Mario-GPT2-700-context-length\n",
      "learning_rate -- 0.0005\n",
      "epsilon -- 1e-09\n",
      "lr_warmup_steps -- 1000\n",
      "batch_size -- 4\n",
      "total_steps -- 50000\n",
      "mask_proportion -- 0.0\n",
      "eval_iteration -- 1000\n",
      "save_iteration -- 10\n",
      "================== MarioLM ==================\n",
      "Follow tensorboard with: python -m tensorboard.main --logdir /home/kokkgoblin/Code/mario-gpt/notebooks/Mario-GPT2-700-context-length/logs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "{'loss': 0.5392114520072937, 'last_lr': 5e-05}: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:21<00:00,  4.68it/s]\n"
     ]
    }
   ],
   "source": [
    "trainer.train(100, batch_size=1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:py39] *",
   "language": "python",
   "name": "conda-env-py39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
