{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c6df025c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "\n",
    "\n",
    "from utils import load\n",
    "from models.flax_models import LeNetLarge\n",
    "from jax import jit, grad, value_and_grad\n",
    "from jax.example_libraries import optimizers\n",
    "import jax.numpy as np\n",
    "import jax\n",
    "from matplotlib import pyplot as plt\n",
    "import numpy as onp\n",
    "import pyt\n",
    "import pickle\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def load_params_at_epoch(run, epoch):\n",
    "    params_ckpt = f\"/homes/ag2198/euclid-scratch/run-outputs/outputs/{run}/checkpoints/params_{epoch:08d}.pkl\"\n",
    "    if Path(params_ckpt).exists():\n",
    "        print(\"Already converted\")\n",
    "        with open(params_ckpt, \"rb\") as f:\n",
    "            params = pickle.load(f)\n",
    "        return pyt.tree_map(np.array, params)\n",
    "    \n",
    "    checkpoint = f\"/homes/ag2198/euclid-scratch/run-outputs/outputs/{run}/checkpoints/checkpoint_{epoch:08d}.pkl\"\n",
    "    checkpoint_state = load(checkpoint)\n",
    "    start_epoch, step, il_state, batch_stats, opt_state = checkpoint_state\n",
    "    assert not batch_stats\n",
    "    opt_init, opt_update, get_params = optimizers.sgd(1)\n",
    "    opt_state = optimizers.pack_optimizer_state(opt_state)\n",
    "    params = get_params(opt_state)\n",
    "    np_params = pyt.tree_map(onp.array, params)\n",
    "    with open(params_ckpt, \"wb\") as f:\n",
    "        pickle.dump(np_params, f)\n",
    "    return params\n",
    "\n",
    "model = LeNetLarge(num_outputs=10)\n",
    "forward = lambda x, p: model.apply({\"params\": p}, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0711e014",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Already converted\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Params(Shape:FrozenDict({\n",
       "    BasicMLP_0: {\n",
       "        Dense_0: {\n",
       "            bias: (120,),\n",
       "            kernel: (3200, 120),\n",
       "        },\n",
       "        Dense_1: {\n",
       "            bias: (84,),\n",
       "            kernel: (120, 84),\n",
       "        },\n",
       "        Dense_2: {\n",
       "            bias: (10,),\n",
       "            kernel: (84, 10),\n",
       "        },\n",
       "    },\n",
       "    Conv_0: {\n",
       "        bias: (50,),\n",
       "        kernel: (5, 5, 3, 50),\n",
       "    },\n",
       "    Conv_1: {\n",
       "        bias: (50,),\n",
       "        kernel: (5, 5, 50, 50),\n",
       "    },\n",
       "}))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAjA0lEQVR4nO2deZyN9fv/XxeGMYx9X8cuy9iGhBSpJIVCtBBqZIsP0pRESd8o68dSSNYP2UJkS0Kp7PuarWiGsYyxheH9++Mcv4f6Xi8mY874fO/r+Xh4OHO9znXu97nPuc59zn3d13WJcw6GYfzfJ1VKL8AwjMBgwW4YHsGC3TA8ggW7YXgEC3bD8AgW7IbhEdIkxVlE6gMYDiA1gPHOuY9udf8cmVO5wrlTq9rliyHU72DWP1V7yK6y1Ccu01aqlTiXn2oxpROodv2i/tlY7PfT1Gd31pJUSxvL1xhUJYhqmY+Wptp1XFHtl64e5OvIUZ5qlxM2U+381XCqZZDLqj0u6wnqk+pQXqoh0y4unctItVwFr6n2+L1h1Cf0Kk9H762wn2pFTmanWkIq/rzjj1dW7TmK8dfs6O4cqv0yjiPBnRVNkzvNs4tIagD7ADwK4CiA9QBaOufoq1KlZJBbO0rfIYc3V6TbavbMbtUeUVG3A8C8R/kbZ+nq96k28Mc4qp3fGKza5/acRn2qNV1JtQKj81AtX0Juqj3R62eqXZBDqn3nsZbUp9Arv1Lt0IlMVFsT8wfVqgfrjzmn6b+pT0irvlRLU5d/sNRbVZtqrw+NU+2L63xOfeoc5R/4D8c0oNrUiS9RLTb9CKotH3pRtbed3YL6vHl/pGrfldARF9w+NdiT8jW+GoBfnXMHnXNXAMwA0CgJj2cYRjKSlGDPD+D3m/4+6rcZhnEPkuwn6EQkUkQ2iMiG2LPXk3tzhmEQkhLsxwAUvOnvAn7bX3DOjXXORTjnInJmtpP/hpFSJCX61gMoISJFRCQtgBYAFtydZRmGcbe549Sbcy5BRDoDWApf6m2Cc27nrXwOHSiNl5/RPw/C6vEz5GU26ym2ORtPUZ/QoJNUe9dNpdqgkFlUGzqyuGov8N5n1KfvN/yMakg8/6yN3/4J1Z5YzNNQwWXrqfbNuY5Tn9BNA6jWLqoW1aYW4OnS8H5PqfYtWYdRn3mzClCtyi9Vqbak2zqqZXldf613bnqR+rxz6CuqtexckGqyrALVph79lmoFftHPxjf8bTL1qdNrn2rfN4GnbJOUZ3fOfQPgm6Q8hmEYgcF+RBuGR7BgNwyPYMFuGB7Bgt0wPIIFu2F4hCSdjf+nxOc6j+Vtf1C1KUf0C/sBIMf4t1X7iGXpqM+rufZQ7fAknuaLmhZHtVk5x6v2LPddoj49B0dQbULfFVR7pRRPDUWOHke1k2OPqPaR1/VUDQC02c8Ligac44VSxe6bTbWhc4uo9o4F+bamzUxPtY+W8FTq2Qo89dZ6q/5af9+Bp8KK7plPtet9tlOt3pjRVMvzIb96NMuh31R7ze/4e2dExzdV+9X5epUfYEd2w/AMFuyG4REs2A3DI1iwG4ZHsGA3DI9wx22p7oTspVK5x0frrZ1SRfEL+EcF91HtH49uRX2mPM21+W3OUG1h+/VUy5/2qGqv0u8w9fm+T1Oq5e4RS7UzJ9NS7drAKVRrvLSMan/yDX4WPOG76lSbO/5Tqi19pDHVFrZao9q/qP8I9bmQtgfVDm3XnxcAFL2udmHybe81/bm9nZ63S+wTxvd9/td51mjvtfuolidW76MIAPm76IUwL7YvRH3q1tGzDDtOfYsLV0/f9bZUhmH8F2HBbhgewYLdMDyCBbtheAQLdsPwCBbshuERApp6yxxSydUq8b2qtSQ9tQAgJvUvqj1vRz6aaNTMGlSbPWsD1d4IKkW1uINfqPa2z/IJIh1/aE21dJE8HTN+OF/HnO7/q4nv/6dVx4Gqff4xPb0DAIv+LEe1F8OHUK15Pt6R7LEVerHO5Ml6oQ4AFIm6SrWWM6ZT7eT0aKqV7aq/D7p05OnLXPE9qXb2M96f7qfc3am2rDV/3tMW/6Ta+x/kE3cGzhmm2u8fEI+NhxMs9WYYXsaC3TA8ggW7YXgEC3bD8AgW7IbhESzYDcMjJKkHnYgcBnAOwDUACc453jQLQKGSsRi2XO/TVStqDPWb3CxGtX/xDh/FM+TXf1Gt+bb6VDvwjd7vDgAm1nhZtZfdson67BySg2rb3uA93N6NzUq18qV4qqlvW31fxY3j1WvpOt1PtTMVrlAtX3H+cv/PQ7lU+7Xp/HnlPMknfkdn4JVola/o/e4A4OUZD6n2lev6UZ83e/P31dVM86g28oOlVBtWsArVOs0+r9on12lAffrMy6Ta/xgxmPrcjYaTdZxzvBugYRj3BPY13jA8QlKD3QFYJiIbRYRX9RuGkeIk9Wt8LefcMRHJBWC5iOxxzq2++Q7+D4FIAMhXIEsSN2cYxp2SpCO7c+6Y//8TAL4CUE25z1jnXIRzLiJb9gxJ2ZxhGEngjoNdRDKISOiN2wAeA7Djbi3MMIy7S1K+xucG8JWI3Hic/zjnltzKIdVvOZCJNOzLnJGf0E9fMbe+gAg+pmcQ9lNt6bhnqPbrF1updrJLTtXeskM36rN4+yiqRRWsQ7Xz+/XUFQCsfqkX1eLrdVDt1QdMpT4DavHqtWXj36Ba8dN6I1AA+Pm7bKo9djAfkfRDdr7GNNeWUe342GFU65Qlo2q/9OFr1GfW6aFUm1hHH18GAAsmHaDaxTZdqBb2lH6MXFT1S+rz8lE9lbrgKm9iesfB7pw7CKDCnfobhhFYLPVmGB7Bgt0wPIIFu2F4BAt2w/AIFuyG4RHuRiFMotmX6jDqhrysaiEZq1K/Ri+8qtqbB71Afc73SUe1+Ql8ptiqarxa7pc39OacwXt5pdHGGrzCrlLQi1Trt483L7wvhjdELPtnuGrPnJk3tzySm88oG9lwAtVa9ehGtRaP6RVsZSedoj55ozdS7cdFRak24ufhVOseeVy153pwL/X5V/S3VMs9dznVYtrzarlyP+vvYQAYW01POxdr+Rj16bjmoGo/dY430rQju2F4BAt2w/AIFuyG4REs2A3DI1iwG4ZHCOj4p8qVCrsfV/VWtdhwXgjTu4FeFDLz7CXqk+cMH1v0+Bn+GfdOXBjV/jU9n2rPVH4E9Xk2fBvV8g3h2YRKz/IxQ/Hp+Pq/WKCfpX0wPS/SOP+fH6nWoBofDRXfviLVSh3No9pj5vIedDGTeB/CEWdnUe1cQV6i8dF3ek++Lo73i/upIj8bf+g3KmFVk0F8HQ/y13qD6M/tbKfT1Oel7zOr9knP/IqY7Zds/JNheBkLdsPwCBbshuERLNgNwyNYsBuGR7BgNwyPENBCmP2XruCJrXru4mI5nirL3mOuam/5Dk/XjR7OCy4GnWtPtUd3paZa6OdlVPvSQUeoz4w0/6Fa2u0fUG17hZ+oFvJuTaq1ePsP1T5sPy8M+l4OUe2RJ/THA4AR63kB0JMJeu+3tlv58+o6fzzV4j9tRrXRRbk2P4ueto1+gfe7uzKQp6NnTP2Oags78v566dqVpVrPGXoPusW/8aKWBn31gqd5t5jNZEd2w/AIFuyG4REs2A3DI1iwG4ZHsGA3DI9gwW4YHuG2qTcRmQCgIYATzrlyfls2AF8CCANwGEBz59yZ2z1WhkNpUbW1Xjk2rffb1O/PMu+o9j3nm1OfNfX5aKhuqXl/txc7LaTa1zniVfuwXbz6LvOlVVQrUKgU1d5Kd55qcY2CqZanRUPVHvm7PkILAIal4qnISf35qKmwodep5p77RLWXC+1HfT6txlOYz9fiaa34Jp2plrpeFdU+eDgfAVblWiuqfdWF96ArGcn7DQ5pq7+HASDDkSDV3n8iHw/2eNosqv3KaZ46TsyRfSKAvz+LKAArnHMlAKzw/20Yxj3MbYPdP2/974W1jQBM8t+eBKDx3V2WYRh3mzv9zZ7bORftvx0D30RXwzDuYZJ8gs75Wt3Q6wtFJFJENojIhovX+O9QwzCSlzsN9uMikhcA/P+fYHd0zo11zkU45yJCUuvXSxuGkfzcabAvANDaf7s1AH7q2zCMe4LEpN6mA3gYQA4ROQqgL4CPAMwUkXYAjgDgObCbCM4filIf1FW1x3fzE/ozdunNEh+dyLfVviofrTQriFfYnep/gWrLPtDTUBNm8tRVzYPrqPZuKb1pIADkm/sS1WKW8KaNjy6op9rXXmlLfeZenEi15tUiqFZkF09D1S+hN7F8a/lH1GfkC/w90Ofxd6n2VbUsVPuskl7d1nThHOqzYthsqhU9vYxqteJ4KrJYWd6p8sK0Dqr90KIE6jPrOz2F/dap3dTntsHunGtJpEdu52sYxr2DXUFnGB7Bgt0wPIIFu2F4BAt2w/AIFuyG4REC2nAyY8bfUKN2F1ULacyvuO2fu5Nqr/cjrzZbs/QbqqVr9ArVMoXrM8oA4K3JejrvwdDWqh0A2sfz9FQYnqJavYSOVPtsS2mqnVpAZss9uJX6vBPFUzy1Pt1ItfYZefPF1k3fU+1ngvg8tFOfxlLtRHv+elbNp6cbAeC5pfq+2v0xbyrZc3YGqmXJUJRqOWbz9+P6WzTM7Bo0SrXH9OXz4S7U0puEXsNl6mNHdsPwCBbshuERLNgNwyNYsBuGR7BgNwyPYMFuGB4hoKm31KfSI9uEcqpW4WGhfkPzj1XtvwwIpT4bE6ZTreTS4VSrl0dv/gcA/7p/hmqfyAuacHhDV6qNy6/PsAOAwt/ySrqhFcdRbVYXfQ7cxeN/Up9Js3kaqlIoT1O61/l+rLJAn+nWu24I9an2gj6/DACO5WtHtVODeeoTZQqp5m9716IuHz3L59t99yJPewYPLUm16B94qM3I95lqj50YR32KNVig2lM9wFN8dmQ3DI9gwW4YHsGC3TA8ggW7YXgEC3bD8AgBPRv/5x/Xsaev3k76kdb7qd+QjStV++75lajPxvjFVFs/YyDVTqRrTLVUW9eq9sdf70N9qk9ZQ7UGI/5FtbcHZKHaqZGPUu3rYF2b8G4D6rOzU22qfTzuWaq1WsjHaH198D7VnvM7/prd/yCVEFaKn3Fv0rwf1bLFZ1ftMQV4QUvpPbwQpk2ByVS7cp3vx5xhF6m2Mnqzal+3Ixf1Wfuf9ao99BDvoWhHdsPwCBbshuERLNgNwyNYsBuGR7BgNwyPYMFuGB4hMeOfJgBoCOCEc66c39YPwKsAbjQNe9s5x5uE+TmQLxTPvqaPf8rYYSb125xHT6MdyqindwCgfDs+Pmn+yj1Uu9iAV7XMflwvkqm/4zHq88Itplm/kV8vrAGAGvMqUq3z9CpUKx2ij3lqUJfv35JpeOqw1xtpqfbeSj4aqtmC06p915gnqc/1onzKb8nG+amWfjn367V4mmpf2IGPcUo1O4xq+6J4KvLL1cWo1vqx76m2cNIU1T7m38HU56lF+vt0f9or1CcxR/aJAOor9qHOuYr+f7cNdMMwUpbbBrtzbjUA/WPaMIz/GpLym72ziGwTkQkiwseKGoZxT3CnwT4GQDEAFQFEAxjM7igikSKyQUQ2uAvn7nBzhmEklTsKdufccefcNefcdQDjAFS7xX3HOucinHMRkoF3ljEMI3m5o2AXkbw3/dkEwI67sxzDMJKLxKTepgN4GEAOETkKoC+Ah0WkIgAH4DAA3jDtJtLhGIrKW6rWpSmvGDrR/3HVfiUyhvr8z06+jpPdeUrj/aBbpJNi9Qq28uXKU5+SP6yjWp6a/ah26jJPDUV259VQ5xfq+yTyWX3EEACUHEt/hWFhuapUK7SXp/Ou1tFTn3/s4ZV+IwpSCWu68ITP9e5RVLt8VD+3/N0rPP1atnwbqqUfU5lqq/Lxsr1Fy3hfu4/L6uHToGZf6lPu0tOqvdU83nvxtsHunGupmD+/nZ9hGPcWdgWdYXgEC3bD8AgW7IbhESzYDcMjWLAbhkcIaMPJa7FXcObT31Vt5XU9vQYAw2rpzQY3hpagPiWqn6Da10fyUq1jlvFUe6Cknq5JN/cj6jOvG08Z5TrCU0ZXEi5R7WRavbINAAYc0iuogvO+SX0qD9ZHcgFAfNNJVAsZo6dRAWDuA/p1Vn/W5o1FB5RYTrUS4z+lWvNFX1Dt++FaDRewoF0c9ekVdJVqXbb1olrovF1UO17kW6rtzJhHtfd4+Efqc3FSDtUenY5XANqR3TA8ggW7YXgEC3bD8AgW7IbhESzYDcMjWLAbhkcQ51zANpaxVGZXbuwDqvbKv/V5bgDwdJqGqv2hcm9Qn+o79dQEAFS8eJ1q49bos8EAoPt5PeU1vOI71KftcZ7K27U/lmphn1yjWsOv+HOr3beFal+fcyz1ea4+r+Tq0IOnyvpmPUK1KwO0+ikg1SR9RhkArG/WhGqLm+rpKQBYGM/L5fqc1JuV5mnzMvUZuZZXU/bedohqJTOsoNqTI3dTrexH1VX7uOJ8vl3IqEKqfU9sHC5eSRBNsyO7YXgEC3bD8AgW7IbhESzYDcMjWLAbhkcIaCFMkfT5MLXse6qWq3lJ6hdV7HnVvjejfpYeAILHT6Ba2Lu8Qd19o/lopfrt9bOjQ9vrZ54B4PeCvABiVS0+NmrU4sZUm7+V90jbm3uOar+6sR71+VL0M9YA0Pl5XlBUoGQRqs07p48nqtmej+Va9RDPMqQtwotM2vZITbWHrxVQ7T2DvqY+O4o/SrVadV6mWodO+vsUAFpNvkUvvzQDVHv+/HycVNXQp1R741P8edmR3TA8ggW7YXgEC3bD8AgW7IbhESzYDcMjWLAbhkdIzPinggAmA8gN37insc654SKSDcCXAMLgGwHV3Dl35laPdenI79jevpuqLbl8gfo9f1EvPnjtj0rUJ6Jlb6pd2B1EtTffCqHaoUo1dZ+1FalPht/bUW3mc7z3W5+VvDBoav3iVDtTY7Nq75JpIvXZO5antfaN7k61V0r2p9rz/R9W7Qv68sKrIXXDqXbsRf7Wkrf4qKzCuUup9hbN+doTMj5Btd9Spafam2FnqZapFR+HGNX+M9Vesjn3Ofv8KdV+rTovNErMkT0BQA/nXBkA1QF0EpEyAKIArHDOlQCwwv+3YRj3KLcNdudctHNuk//2OQC7AeQH0AjAjdajkwA0TqY1GoZxF/hHv9lFJAxAJQC/AMjtnIv2SzHwfc03DOMeJdHBLiIZAcwB0M05F3+z5nwdMNQfYyISKSIbRGRD/GXej9swjOQlUcEuIkHwBfo059xcv/m4iOT163kBqBdRO+fGOucinHMRmdLxE2OGYSQvtw12ERH45rHvds4NuUlaAOBGZUhrAPPv/vIMw7hbJKbqrSaAlwBsF5EtftvbAD4CMFNE2gE4AqD57R7oysWr+GPTcVV7qXBO6je945OqfdnIjNQnPE0tqlWa8gHVnr+fV3K9uFRPODRrx9OGUet6UK1OIV7l9VZxng47l4ZX7a1YofdjC3qHj4wKm8orpWJq61WKADA4H++T16qbnhr6cjHvG3j9Gd6n7fMLfDzY8AU1qBb+zCjV/s2+rtRnY0ImqnU4zt9zDZfzNOXyuR2p1ur7XKq9TQgfa/VsFz3dGHxUrzYEEhHszrkfAKgN7AA8cjt/wzDuDewKOsPwCBbshuERLNgNwyNYsBuGR7BgNwyPENCGk6mKZEf6CS/oC+nwOfVbs1uvAPt3xfepT+4ar1DtzFh+ScDEYaFU299BTzU9uIk3PEx1iDdYbNG3DNWO9uIjpeaE8Zftwx/1NGD9d3lF1pxpPB22+r1Iqo0bP5lqnzXOrNq7tvmB+ow/ztOv67bkp9rcT3nqrfniCqo901e8Omxk5EKq5cvdmGrTs/Imp88tS0u1w5n1Cs1FL2ynPk226M1WD1ziaVk7shuGR7BgNwyPYMFuGB7Bgt0wPIIFu2F4BAt2w/AIAU29hcTGovIovblehdW80eOcbHraJejASOqTPlMw1Zod4mmXj9oPptrRR/Wqt7ZL+Dy0oN7VqXY5ZjzVxizlDRF7T+HPLef2WNVe5AOephzdn6feSjzejGrXvh1LtQmnn1Pty9rwOWody+vNMgHgyqCLVBtUSE+vAcC50gtU+4Fo3jgyfRSv9DtYrz7VMj7QmGozBvOUWKmn9Uq18OrrqE/0T6NV+9Xrl6mPHdkNwyNYsBuGR7BgNwyPYMFuGB7Bgt0wPIL4ukAHhkKpM7qeGfQzp13e42OSJn35kGp/rS0/mz0s7kOq5YkaRLWvI+dQrXptvSDnqUFdqM/6Ap2otrXrcqrVOH2dat+NyUq1fcX1s8+ld/AMxEbhZ9UHPMvHUAUv4z3XBg9vrNrLTx1Gfe5v05Rq6X54jWrduz5AtTd7zVXtzx/n48HWXn2eaqX716PahQZZqLYn+DTVCkTqmYYSafS1A0Dx4Fn6dnq2wYVfd6tt5OzIbhgewYLdMDyCBbtheAQLdsPwCBbshuERLNgNwyPcthBGRAoCmAzfSGYHYKxzbriI9APwKoAblRdvO+e+udVjZXCZUPWynrqY0qU29QsP/1i1V9i/ivpMGMSLCOpkfZVqH8wNp9q6Ut/qPlPOUJ8rvfjYn1eXl6Xa044Xp+xc8iPVHhgTp9oXd+lMfTrn0vuZAUDj479QbdTEwlTbX19/zfZHb6Q+CSP4aKWB63hxyrTfeX+9fO4P1d49rhT12VTyJNVqH9BTmwDQ6qkpVHs+6EGqLUqnFz1tnVCa+vTMpo9/OngmXrUDiat6SwDQwzm3SURCAWwUkRsJ4qHOuU8S8RiGYaQwiZn1Fg0g2n/7nIjsBsBbfRqGcU/yj36zi0gYgEoAbny36ywi20Rkgojwy7oMw0hxEh3sIpIRwBwA3Zxz8QDGACgGoCJ8R36164OIRIrIBhHZcMbxBgSGYSQviQp2EQmCL9CnOefmAoBz7rhz7ppz7jqAcQCqab7OubHOuQjnXERW4d1oDMNIXm4b7CIiAD4HsNs5N+Qme96b7tYEwI67vzzDMO4WiTkbXxPASwC2i8gWv+1tAC1FpCJ86bjDANrf7oEypPoT1dPvUbXf3sxA/bJm1VM8pco/RX2yTelAtfBwPt7n4qKZVPum3mOqff7UbtSn1jK9JxwAFF66hGqzZ+2i2mtt9V54ALB9rp5q6ppvLfUJTeBrfOQnPv7p37UrU23U1L2qfeIpffwXAPQup/diA4BsA56hWvxyfVsAsHRyJtXe5os46vNm3MtUi/lMT78CwIQevH/hG9UOUa1hK72C7elCn1Kf4D/18VpXJAmpN+fcDwC0krlb5tQNw7i3sCvoDMMjWLAbhkewYDcMj2DBbhgewYLdMDxCQMc/7cySHuUb6pVezZrwzN0HD1dS7Xmab6I+ufr8SbUacbx5YZOwMVRb0ECvvJrRtCD1OfEJb154oAgfCXSpJR8zdJ2kXQBgwJAfVHtYj+LUJ28nXplXeDMfNbV7Dvc7slVPUW0ez5tDxsXXoFqZonoVHQAsr8KberYcNU21136Vjw57v3EbqtWcySv94lOnplrxTDzdG/6JfqV5vp1vU5/z2f+j2uUsX4Md2Q3DI1iwG4ZHsGA3DI9gwW4YHsGC3TA8ggW7YXiEgM56Sx2W2qXvq9e0P7LgSeq35Kw+p2x+1ZLUp9Fp3qwvZ9ZJVPvktQFUS51uiGrfvJZ36Tq25D6qZRpzhGov1OaVbf0b5aNa8zz6fmx1jvcSeOBcDqplCa1OtfCa+v4AgCU/67P7VhdsTH3SLWxCtd07nqNavzovU23tXL3yulfU+9RnakOeUtwykGvDpvEqxq2lq1Ktcs13VHufYTzN9/MmfVvNm+zDju0XbdabYXgZC3bD8AgW7IbhESzYDcMjWLAbhkewYDcMjxDQ1Fua+8JdlomLVG3pK8epnwvWU2XRuXh6bffV8lTr3I2nk9o15DPi+o3WK/OWTm5LfdbX+YJqpTZyv6s5+Ry41auuUe3C/DWqPeTMeeozN/vnVNvZTk+hAUDluoupdvplveqweZoC1Cck7SWqNXt/INVimvFKr05xZVR7jqp9qE+D33nz00YNeOrt0mDeADXq7AaqVRyQVrWv+OIV6tNmsp5KXbFmPk7HxVrqzTC8jAW7YXgEC3bD8AgW7IbhESzYDcMj3LYHnYgEA1gNIJ3//rOdc31FpAiAGQCyA9gI4CXn3JVbPVb2tAloGXZC1Vb9MYH6vf7ZJ6o9aP3P1GfT2a+olko6U+3VlcOoFtFaP+tbayE/0z2l7WSqLW3Hp9oWnlWTarlfvUq1+998WrVf7lmK+rjXeCFP81L9qBaaoM7yBAA0XKZnPPLX0EcdAcDwd/mIpyeP8H6DsW22Ue3Dafpr9kYcH3k1bsgCqpX5JI5q2ffy99X9q/T+hQAwr1CEai/adDz1+aTqVNW+4enl1CcxR/bLAOo65yrAN565vohUBzAQwFDnXHEAZwDwHI1hGCnObYPd+bhx6Ary/3MA6gKY7bdPAtA4ORZoGMbdIbHz2VP7J7ieALAcwAEAcc65BP9djgLg3wUNw0hxEhXszrlrzrmKAAoAqAaAX7r2N0QkUkQ2iMiGS6fP3NkqDcNIMv/obLxzLg7ASgAPAMgiIjdO8BUAcIz4jHXORTjnItJn05vhG4aR/Nw22EUkp4hk8d9OD+BRALvhC/qm/ru1BjA/mdZoGMZd4LaFMCISDt8JuNTwfTjMdM69LyJF4Uu9ZQOwGcCLzrnLt3qsdNlyu3yPtlC12al5/666e/U0VOsB6vX+AICts3hqpetlPT0FAO/nXUq1rOO6q/YuszpQn4oZWlEtT32ehpqeKp5qW7aGUu3zRno67Os2PE02/jJ/2WqVf5BqJ4rwNU4KGaHavzzBU1C7jvWg2m/BPIVZZdMeqlX4qp9qr9l5CvXpVT831db14Gv8+BleRHXh4B9UK/zsT6r99Qtrqc+U96br2/n2G1w7fUoNjNvm2Z1z2wD8r2FrzrmD8P1+NwzjvwC7gs4wPIIFu2F4BAt2w/AIFuyG4REs2A3DIwS0B52IxAK4MfMoB4CTAds4x9bxV2wdf+W/bR2FnXM5NSGgwf6XDYtscM7ptX22DluHreOur8O+xhuGR7BgNwyPkJLBPjYFt30zto6/Yuv4K/9n1pFiv9kNwwgs9jXeMDxCigS7iNQXkb0i8quIRKXEGvzrOCwi20Vki4jw+Tx3f7sTROSEiOy4yZZNRJaLyH7//8le/E/W0U9Ejvn3yRYRaRCAdRQUkZUisktEdopIV789oPvkFusI6D4RkWARWSciW/3reM9vLyIiv/jj5ksR0edGMZxzAf0HX6nsAQBFAaQFsBVAmUCvw7+WwwBypMB2awOoDGDHTbZBAKL8t6MADEyhdfQD0DPA+yMvgMr+26EA9gEoE+h9cot1BHSfABAAGf23gwD8AqA6gJkAWvjtnwLo8E8eNyWO7NUA/OqcO+h8radnAGiUAutIMZxzqwGc/pu5EXx9A4AANfAk6wg4zrlo59wm/+1z8DVHyY8A75NbrCOgOB93vclrSgR7fgC/3/R3SjardACWichGEYlMoTXcILdzLtp/OwYA76CQ/HQWkW3+r/kB7SUmImHw9U/4BSm4T/62DiDA+yQ5mrx6/QRdLedcZQBPAOgkIrVTekGA75Mdvg+ilGAMgGLwzQiIBjA4UBsWkYwA5gDo5pz7SxucQO4TZR0B3ycuCU1eGSkR7McAFLzpb9qsMrlxzh3z/38CwFdI2c47x0UkLwD4/9dH5yQzzrnj/jfadQDjEKB9IiJB8AXYNOfcXL854PtEW0dK7RP/tuPwD5u8MlIi2NcDKOE/s5gWQAsAfN5OMiEiGUQk9MZtAI8B2HFrr2RlAXyNO4EUbOB5I7j8NEEA9omICIDPAex2zg25SQroPmHrCPQ+SbYmr4E6w/i3s40N4DvTeQBA7xRaQ1H4MgFbAewM5DoATIfv6+BV+H57tYNvZt4KAPsBfAsgWwqtYwqA7QC2wRdseQOwjlrwfUXfBmCL/1+DQO+TW6wjoPsEQDh8TVy3wffB8u5N79l1AH4FMAtAun/yuHYFnWF4BK+foDMMz2DBbhgewYLdMDyCBbtheAQLdsPwCBbshuERLNgNwyNYsBuGR/h/4nvH98MaZnkAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "run = \"peach-river-1270\"\n",
    "params = load_params_at_epoch(run, 10000)\n",
    "\n",
    "\n",
    "init = jax.random.uniform(jax.random.PRNGKey(0), shape=(32, 32, 3))\n",
    "plt.imshow(init)\n",
    "\n",
    "\n",
    "params\n",
    "# print(forward(init, params))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8538507c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Dataset CIFAR10:\n",
      "32 x 32 x 3 images with 10 classes\n",
      "50000 train points and 10000 test points with a batch size of 250.\n"
     ]
    }
   ],
   "source": [
    "import datasets\n",
    "dataset = datasets.CIFAR10(\n",
    "    batch_size=250,\n",
    "    data_location=Path(\"/homes/ag2198/data\"),\n",
    "    include_flip=False,\n",
    "    data_aug=None,\n",
    "    key=None,\n",
    "    randomise=False,\n",
    "    data_limit=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "2a00531d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7fcf4c4348b0>"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAf+UlEQVR4nO2da5DcZ5Xen9OXmZ77aGY0o5E00kiyJGTLtmyEYmPHkCVgw5Iy1G5c8IH4A7XaSkElVDYfXGxVIFX5wKYCFB+2SJngWpMQDFlgcRl2F68BG2JsI99kybJ1v8+MrqOeS9/75MO0q2Tnfd4ZS5oewf/5VanU8z7z9v/tf/fpf8/79DnH3B1CiD98Uku9ACFEc1CwC5EQFOxCJAQFuxAJQcEuREJQsAuREDJXM9nM7gPwDQBpAP/D3b8S+/3O7h7vHxwKapXiLJ1XLRWD4zHTMNuSo1pLaxvV0tkWqqVSFhwvFqbpnHKpQDWv1agGDx8LANKZNNXMwu/fHZ1ddE5rjp8rr1WpVijw54w9O3Wv0xnFAj9Xtcg6YvYxO1y1ytdRq8deWXxeJsPDKZ2JXVfD9xlzxetkjbOzRZRL5eCL54qD3czSAP4awIcBnATwOzN73N1fZ3P6B4fwl1/966B28o0X6bHOHtkXHK/V+PKH1ryHams2bKHashVrqJZrCx9v/95n6ZxjB3dTrTLF3yTSkcfWvayHaplce3B8x1330Dk3bOLnqnjpAtX27nmZavV6OTheroTfuAHg9b2vUS0/eY5qpXKJapVy+I3xwnn+RjU9y9dYrfFjLV/eR7VlfZ1Uq/lU+FgVOgXFQjjYn/7Vc3TO1XyM3wHgoLsfdvcygMcA3H8V9yeEWESuJthXAThx2c8nG2NCiOuQRd+gM7OdZrbLzHZN5y8t9uGEEISrCfZTAEYu+3l1Y+xtuPvD7r7d3bd3dvO/NYUQi8vVBPvvAGw0s3Vm1gLgUwAevzbLEkJca654N97dq2b2eQD/iDnr7RF33xubU6/VkL8Y3t3t7+U7mb48bNd5ppvOGV6znmq1Ot/mTNX5Lm19Nmz/FC+ep3O8wHd2Vw0MUm3NyA1UG7lhLdVWrlodHB8klicAZLOtVKv2hnf3AWBk9Qo+rxrejS8Wub02eZG7E+fOcVcgE7FZYeHd+GX9/DHnOvgaL+UvUq01x8Op7tw6zGbCa8lfmqRzyqXwbrxHbMOr8tnd/WcAfnY19yGEaA76Bp0QCUHBLkRCULALkRAU7EIkBAW7EAnhqnbj3zXuQCVse5VL3A6bnQ3bOKOb+Ldzp2dmqBZLxugbiCSZZMPvjRs3bqJz3n/HdqqtGgrbZADQ07OcapUMz5Zrz4VtnEwkg8qqkcy2GW6HlchzCQDtbWHLblkvtxs3rL+Ravv2vUk1GF9HqRS2Unu6l9E5kcRHXMpPUM0Rfp0CPEsNAC5eDL9WC7M86YZlxHkkq1BXdiESgoJdiISgYBciISjYhUgICnYhEkJTd+O9XkeVJEJYle8wt7aEa8ZdOsdLFfWv4Dvda27iSSaDIyuplmXbtJH6QZUq3/l/Y4wn0MwePsvvM8V3fd987dXg+Pu28J3ue3a8j2qx+m75SH2C48dOB8dbspHagC08sWlgOXdejp84wO+TlOmaLnC3Jp/nr6tMltcG7O7mSUOxen2svF6sTl5ra/i1aMbXpyu7EAlBwS5EQlCwC5EQFOxCJAQFuxAJQcEuREJoqvVWr9dQmg1bHp1t3JLp7gsnhdx+6zY6Z2T9RqpNRRI/3jx8gmr52bB9Mj05Seecn+T22tg4r2fWHUmEQYonSDzx/R8Gx7MP8Pf1D9x5N9WyWW4rrljBbUp42L6avBjufgIAL73Mu+dkInXyOrq4ZVetha3D8vQknZOOXAJjXV9qNW6Jnr/A7bwUwpZdrJ1Ub284YSud5q3BdGUXIiEo2IVICAp2IRKCgl2IhKBgFyIhKNiFSAhXZb2Z2VEAUwBqAKruzguuAUiZobU1G9Qq6S46r9AWbmR/JM/b9LzymxeoduE8r6t26jSvMZZNhzOKsimenVQibZAAoFjk2vBy/tScGT9GtW6SDTU1madz9h85wtcxPEC1bJavcXgk3BpqJRkHgOPj3PZ88zWuDQ5zm/LocWJ5VfhzVi9zrRap/5dr4fZgayb8ugeAQjF8n93d3FLMkJZRZvz6fS189n/hTkxVIcR1gz7GC5EQrjbYHcDPzexFM9t5LRYkhFgcrvZj/N3ufsrMBgE8aWZvuPszl/9C401gJwD0LuO1uoUQi8tVXdnd/VTj/zMAfgxgR+B3Hnb37e6+vaMjvNEmhFh8rjjYzazDzLreug3gIwD2XKuFCSGuLVfzMX4IwI8bBe4yAP63u/9DbEIqlUV7+1BQOzPJM9EOngjbLq/v5e8tqYgtVIu0mipM8UKEaWKxFUrc1pqc4tpUpLXS0ZP7qNbRxm3KzRs2h4WIBfh/f/0rqq1dt45qmzbztlf9/eGsrNYcf156url1lary4pYzJX7NYi2UCpM8+65W40VCc23cQpvO8/vsjmTmtebCmWrlcqwlWjgDs17ntuEVB7u7HwZw65XOF0I0F1lvQiQEBbsQCUHBLkRCULALkRAU7EIkhKYWnExnMujtC2dRHTyxn84bOxrOymrP8sKLl2Z4Mcfp/BmqWcS6mJwKW2WTBW7VZEiWHwAMDA1Sra0rbF0BwKpRboKMEBvnyKu/pXPSxm25So1neZ09x4tp3nzzluD4DRvX0zkjkey1zjtuo9ruN45TrVQMFzItZSNZb+A2Wd25RTw+Hu5vBwAtrdxW7FnGXgfcBi4UwhmfdeePS1d2IRKCgl2IhKBgFyIhKNiFSAgKdiESQlN340ulGRw6FK4N98ahg3Te6bFDwfFaJGmlq6eDaps3jlJt65atVBs7G94BPXaWr2P5inDiDwCs3cCTTLr6+U79xEV+PD8Xdi6OH+M71mcjLaq23EglfHhTeMcdAGamyW4x39yHl7krsPc57iZs3LyNakOreoPjz73wTHAcAMYnePJSpcJ344sFvv6LkbZXbZ29wfHYzvoMaaMWS4TRlV2IhKBgFyIhKNiFSAgKdiESgoJdiISgYBciITTVepuZyuO5Z54ML2SI1E4DsGHLzcHxtkibni03bqTa5k2rqVYrhhNJAMBTYTtpBrwhTiYbTsQAgHS6l2qVKk+cmJm6QLWectgaqtaczjl+hicN5TpP8WN189Lg6zeMBsc9cn0pTIbrqgHAG8+/QjUv8NfB1nvvC47ffAtPyCns4tbboYNHqdbezqsn9/T2U22ue9r/Tz7Pn5dSKXyuXNabEELBLkRCULALkRAU7EIkBAW7EAlBwS5EQpjXejOzRwB8HMAZd9/aGOsD8H0AowCOAnjA3blP0KBSqeLMibBNddutf0zntbaGa5P1cZcMwyt5HbELkdY/Jw5yW6tcD9thKeOpXOkMt0JqzmvooRprXxW2AAHAa+HjdfaEa/8BwPlpnkWXauHZg3Xndt5cN+/QJD6jM8efs9GVI1TLpfk6UgjXDbx5K8847O3tpdrjhZ9TbXyMh8CqwZVUq1m4hmE20sIsnw/bg29kT9I5C7my/w2Ad5qVDwF4yt03Aniq8bMQ4jpm3mBv9Ft/5+XufgCPNm4/CuAT13ZZQohrzZX+zT7k7mON2+OY6+gqhLiOueqvy7q7mxn9o8nMdgLYCQDZLK+hLoRYXK70yj5hZsMA0Pifdl1w94fdfbu7b89kmvpVfCHEZVxpsD8O4MHG7QcB/OTaLEcIsVgsxHr7HoAPAhgws5MAvgTgKwB+YGafBXAMwAMLOVgqlUF7Z19Qy0ZcnMnJ8AeH1r5eOme2yj2eIu/WhLZlXVRrrRu5Q269eeQMFys8yyvXxiemIu2a6qnwvM5+bv20OLcb0208s81buPdZt/Bjsxq38lJp/pizHS1Ua+vkWrUUtlnPn5qgc/o7eBuq+z92L9V2vXqUatORYpTF0tngeIm0eAKA3q7e4HgmzZ+TeYPd3T9NpA/NN1cIcf2gb9AJkRAU7EIkBAW7EAlBwS5EQlCwC5EQmvotl2xLK4bXhLONLMXfd4rFcIbPRJ4vv6WXZ3lVqtyqsci3/ArT4QyqivO1ZzK8cGQ1zbX2bp4BNtg/STW/ELZrypEeZVbn629ra6NaKpJ1WPfw8Wo1blOmspFin2m+xukZnsVopABja+T1lj/Lbbm29rB1DAD33HkL1d48dIxqe14fD45P53k2YgspZFqPFBbVlV2IhKBgFyIhKNiFSAgKdiESgoJdiISgYBciITQ9wdwtbK9UItbQ7FTYWmmN2EJT+UjhyCIv9Dib5zZOliS9dXVwC235Mm7VdPfxDLDlvfyx1TI9VCu0hs/jhbU8661UG6MaIpl5tWok+45kCNZSPBvRItZbbx/PvqvXImskr6ueHn5+W3gtFkxOTVLNK2FrFgC2bVlBtd6u8OvniSd4ccuzE+HCrZUqjyNd2YVICAp2IRKCgl2IhKBgFyIhKNiFSAhN3o13gOzgZup8Z7cn/J1/jPSQ7XEA71nfS7XOHN+JTRt//5vJTwbHi7OX6Jy2jgrVNm/kO/Uja1dTLZVdS7Xpycnw/Q0P83UcocWB0d1HTj6AvmU8WSeTCScb1SO1Bj2SWJPraKdatRjZgSbHy8YSr8Ddmv6BTqpNz3JXYGYynOwCAKuWh2vefeJffYTO+buf/lNwPJvhJ1FXdiESgoJdiISgYBciISjYhUgICnYhEoKCXYiEsJD2T48A+DiAM+6+tTH2ZQB/BuCtvjVfdPefzXdfXR3t+MCd7w1q62+8lc47fepUcHzVSm5dbdq4gWorlg9SLe3czpsiSRClSLKIpfj9dXbwRJjOTm55pVu4dZglFmZhJtxiCABu38qtvNFNo1Sr1Lmt6OQ6Uq1zm8zT/Fyls/ylWilyP69OEmFSGX6dsxxfByLzShV+PjJpXtuwVp4Mji+P2Hx3//P3Bcefe2EPnbOQK/vfALgvMP51d9/W+DdvoAshlpZ5g93dnwHA80WFEL8XXM3f7J83s91m9oiZ8WRjIcR1wZUG+zcBbACwDcAYgK+yXzSznWa2y8x2Tc/w5H4hxOJyRcHu7hPuXnP3OoBvAdgR+d2H3X27u2/v7OAbDkKIxeWKgt3MLs+q+CQAvgUohLguWIj19j0AHwQwYGYnAXwJwAfNbBsAB3AUwJ8v5GDtbTm895b3BLWbbuPWW2Fr2Ebr6OFZV7zSGeDGrZVUxCLp6wjXEYt0f4q+m9ZJayIAqEZq8iFi8ZRK4fZPG25YQ+e0tXALsDDDM/o8FXn5WFjzSH23unOtFnnO6pFUunIhfD5qdf6YU5nI6yPyjE6d5xbssSMnqHbX3bcFx2crvB5iO7EHI8l88we7u386MPzt+eYJIa4v9A06IRKCgl2IhKBgFyIhKNiFSAgKdiESQlMLTqbSabSRTK/OHG+h1NFOlhkprhcrbGgx6y1m8XjYKqtXuIUWs5Ms4pNUI+ZhJJEOTgpmdvbyDMFqjR+rVo9UgSQtngDAUQuOp2KLr3GtluGWqCPyZJMCp1YPrw8AWiOPOVvjz1lHkc/zibAFCABnD08Ex1dv5kVHz6XC30aNXb11ZRciISjYhUgICnYhEoKCXYiEoGAXIiEo2IVICE213tLpNLp6whaQR7LNZkth+8RLvCdXicwBgJnpGaqVK3xeqRTONqtWuXVViWSoVSLHmo30DZud4dlQVZJJ19XXQ+d09fRSrbdrgGq5lnA/NwCosd59FunLBq51dfECnOfP8PNYLIQtqnqdF1cy8MdVr/HXXHcXt4/XrhmiWmE2/Hr0SHHOnq6whZ1O8+u3ruxCJAQFuxAJQcEuREJQsAuREBTsQiSEpu7GX5y8hL97/O+DWi37az7vYjhRYPrSOTonFcmNiO3UT0yEjwUANZJd0xdpJ7VsoJ9qrWl++mcuTFJt/4F9VMtPh3efR9bxFk/pLHdCurv4+tet43XtVo+E6/WtW7+Kzulr5YkwXTm+xnqkFiHS4eSUSo3vdKcjLZ7SkTUOjUaci26+U1/xcFJOmpsC6OsLP+ZMJDlMV3YhEoKCXYiEoGAXIiEo2IVICAp2IRKCgl2IhLCQ9k8jAL4DYAhz7Z4edvdvmFkfgO8DGMVcC6gH3P1i7L7yUzN48pfPBrXe1ZvpPK+F7aSXn/0lnbN2Na/fNdDP7aRTJ8epViV1y9r7eumccoonyUyc5C2BPrTjTqptu+Umqs2WisHxVJY/1UeOH6Pa/gOHqPbanpep1tsTbuL5J3/6STrnrps2Ua0l0mNr9fAI1crEerNILbxY3cAKqa0HAKlMpK5dL0/kaSO1COtpbhEzIzJSQnFBV/YqgL9w9xsB3AHgc2Z2I4CHADzl7hsBPNX4WQhxnTJvsLv7mLu/1Lg9BWAfgFUA7gfwaOPXHgXwiUVaoxDiGvCu/mY3s1EAtwF4HsCQu481pHHMfcwXQlynLDjYzawTwA8BfMHd85dr7u5AuHi3me00s11mtqtS5on/QojFZUHBbmZZzAX6d939R43hCTMbbujDAM6E5rr7w+6+3d23Z1v494OFEIvLvMFuc+1Tvg1gn7t/7TLpcQAPNm4/COAn1355QohrxUKy3u4C8BkAr5nZK42xLwL4CoAfmNlnARwD8MB8d9TX149//el/E9RaBzfSebNTYTvswGuv0jnDK7gdk4q0XWrL8Qyqcj3cwmfTVr72ZcM8I252gNdB+/hH/yXV2rvaqDZDrLdIpyZUSVsrAChWw/cHAGfOXKDasSOng+Pt7fz8jp88T7Wjew9QLVXkazw8HvzAiR0f2U7nrB1dSbVYtlwqF0lTy3JbzlitOeNzWiz8nMWst3mD3d1/A4DdxYfmmy+EuD7QN+iESAgKdiESgoJdiISgYBciISjYhUgITS04aQa0toTfX/a/sYfOy18KW28ey04q84yh6Uj7J4t4F7nWcK5RZZa3Y7p0lq9x4jjPevv7fwwX5gSAi1OR401fCo53dXPLq2dZuCUXAHRECiWePBm21wBgcCBcWDLXza3IX/+UP+YLB3ZTrVbmLbYOjocLiJ6MtNDauIVbqT3d7VxbxltstbXzrLeejvDrKpvjxSPb28PPizt//erKLkRCULALkRAU7EIkBAW7EAlBwS5EQlCwC5EQmmq91aoVTJ0P22i/+MlP6bwT4yeD46lKOAsNAHbvzlMtlhpUrfKsJpBMoyef+AWd0pLl1tW2226nWrmli2r50izVDh8PZ3mdP8/7w5WLPOvt9PhRqh05yu9z+23vDY7/u8/9Bzrnhed+S7XqJZ4Rly/xoiiFcE0VHN7Fbc9fvzhGtY4Mt/myLdwqS7fy10EXsd5Wrx2lc+7/k08Fx8tVfv3WlV2IhKBgFyIhKNiFSAgKdiESgoJdiITQ1N34lmwWw0PDQW3j6Do6zxHeLc5EWiulIzvuqTR/j/M6T1xpyXWEhSxPcli5MpwQAgAfvPdeqnW1RxIucrx23et7wnX59h/kbZxWrBqlWjHSdindxte4Z/8bwfHX9++nc9pHt1Dt9Gn+mJf1cm2wJVwXrr2T1/G7MM7bYZ0/dZBqZ8+Fk24AoFiLJG2RAoFjkzw83/+h8JyYmaQruxAJQcEuREJQsAuREBTsQiQEBbsQCUHBLkRCmNd6M7MRAN/BXEtmB/Cwu3/DzL4M4M8AnG386hfd/Wex+6pWarhwNtwy6I5/9n467/0f+EBwvLWVJx5kIvZarP1TPdIKKY3w8Spl3qanUOZJK+dPHqHahSJPuLhwjrddOkwsttNnwglIANA5yNsdoZXbitbCrbdyNZyc8uTTv6Fz1m64mWojfdzCzKX4y7idJCKVirwG3eH8Xqp1dvFafjXnvtf4xWmqDQyMBsdnK/y1+IunXwiO5yP1FRfis1cB/IW7v2RmXQBeNLMnG9rX3f2/LeA+hBBLzEJ6vY0BGGvcnjKzfQD426wQ4rrkXf3NbmajAG4D8Hxj6PNmttvMHjEz/jUmIcSSs+BgN7NOAD8E8AV3zwP4JoANALZh7sr/VTJvp5ntMrNdU9ORghJCiEVlQcFuZlnMBfp33f1HAODuE+5ec/c6gG8B2BGa6+4Pu/t2d9/e1ck3N4QQi8u8wW5zLVK+DWCfu3/tsvHLM1o+CYC3dBFCLDkL2Y2/C8BnALxmZq80xr4I4NNmtg1zdtxRAH8+3x2lUoYO0rbmfL5I5728+8Xg+OAg3yYYGhygWqXCba2LFyephmJ4jZk6v79V67itNbKM15k7tZ/XQZuZ5jXXBodWBMfb+3vpnHSOf+KaLfDnZXh4DdXGT4frBp47H25PBQDDKyNtuSKtvqZL/PwjE369VercLm1tI9mNAFoj2ZTl82ephlS4zhwADJGsw3KJtzCLnA7KQnbjfwMg9AijnroQ4vpC36ATIiEo2IVICAp2IRKCgl2IhKBgFyIhNLXgZCoFtGbDmTyl4iSd9+yzTwXHvcJtoe52XlCwUuHZScUCbymVIe+Na0dH6Jytd9xItQ1ruC03eSJsXQHA+MVzVGtpC1tNG/rDlhwAnD3LM7Ju3ryVajfdvJlqj/2v7wTHMwgXgASAygx/PstlrnmV22jIhZ/rWDum0XXrqXbmxJv8WCmehdnWwY+3Zcum4Hhxlj8vI8ODwfGnW7jFpyu7EAlBwS5EQlCwC5EQFOxCJAQFuxAJQcEuREJoqvVWr9cxWyAFGCNFIO/96MfD91fmWVLpiL1Wr/FCfp7m9kk6E7aNch288OL4JLfypiZ537MLBb5+y/EikG++cjg4fv63PCNr/Tpuob3vho1UK0cy4tpawlaTRzIOYxl2qTR/qZJWaQCAQp30Cazx87t2NbfeitPnqXZjN8+We+HFl6l2+ljYzivM8Ne3z14MjpdLkXNIFSHEHxQKdiESgoJdiISgYBciISjYhUgICnYhEkKTs94MHZ1h+6onUkCva3k4K6hU4oUXc5H3sRbjmVfexrPlWtvD8+pFnp00NcVr5afbeaHHwQ29VNvQzrPeDhwJ93qDcUsxS4qAAsCpseNU6x/gBT+ZVi5wO6lU4sUoZyIZcaVIdlilFLZ6Mzlulw6tXE61Y2MTVJs4Ts49gOI0f2yH9r4SHO/v5+vwZX1EoFN0ZRciKSjYhUgICnYhEoKCXYiEoGAXIiHMuxtvZjkAzwBobfz+37r7l8xsHYDHAPQDeBHAZ9yd96sBUK8VMTtFkj/q/H0na53B8YkJvsN54PWjVMtl+I57S08v1QZIu6mVAz10TiaS4NPf00+1SK4OioVwEgQADA6Gd/hXrSS7twDGxseptn//PqqNltdRjTklU1P8OZud5Tvd+Uvc1YjtxtfK4USkdCtPWtm7h7cOi7VkGhwcotqqW3gtv8Hl4XkDy3ndwBxZ/y+e/RWds5ArewnAH7n7rZhrz3yfmd0B4K8AfN3dbwBwEcBnF3BfQoglYt5g9zneeuvMNv45gD8C8LeN8UcBfGIxFiiEuDYstD97utHB9QyAJwEcAjDp7m8lBZ8EsGpRViiEuCYsKNjdvebu2wCsBrADwHsWegAz22lmu8xsV36aFK4QQiw672o33t0nAfwSwJ0Aes3srQ2+1QBOkTkPu/t2d9/e3cm/oiiEWFzmDXYzW25mvY3bbQA+DGAf5oL+Txu/9iCAnyzSGoUQ14CFJMIMA3jUzNKYe3P4gbs/YWavA3jMzP4LgJcBfHu+O3J31Ekbn1TkfSdTCSdxdJNWUgDw4nNPU218gieSWJYnhezY8d7g+N13bqdzLl3iVtPul56n2kyRJ37sP36CaoePHg2OF2b5n1DuvIhbrpsnY+TzU1SbIi2qZvLcNoyUkkMmzdWeLv6JceW6sD24rH+YzhlcyS2vlbfdTLW+SA26llhtQ6ZFkpfg4XhJRazeeYPd3XcDuC0wfhhzf78LIX4P0DfohEgICnYhEoKCXYiEoGAXIiEo2IVICOYeKVp1rQ9mdhbAscaPAwC4B9Y8tI63o3W8nd+3dax196Bf2tRgf9uBzXa5OzeotQ6tQ+u4puvQx3ghEoKCXYiEsJTB/vASHvtytI63o3W8nT+YdSzZ3+xCiOaij/FCJIQlCXYzu8/M3jSzg2b20FKsobGOo2b2mpm9Yma7mnjcR8zsjJntuWysz8yeNLMDjf95b6XFXceXzexU45y8YmYfa8I6Rszsl2b2upntNbN/3xhv6jmJrKOp58TMcmb2gpm92ljHf26MrzOz5xtx832zSB+zEO7e1H8A0pgra7UeQAuAVwHc2Ox1NNZyFMDAEhz3HgC3A9hz2dh/BfBQ4/ZDAP5qidbxZQD/scnnYxjA7Y3bXQD2A7ix2eckso6mnhPMZft2Nm5nATwP4A4APwDwqcb4fwfwb9/N/S7FlX0HgIPuftjnSk8/BuD+JVjHkuHuzwC48I7h+zFXuBNoUgFPso6m4+5j7v5S4/YU5oqjrEKTz0lkHU3F57jmRV6XIthXAbi8+sJSFqt0AD83sxfNbOcSreEthtx9rHF7HAAvQr74fN7Mdjc+5i/6nxOXY2ajmKuf8DyW8Jy8Yx1Ak8/JYhR5TfoG3d3ufjuAjwL4nJnds9QLAube2RFtvruofBPABsz1CBgD8NVmHdjMOgH8EMAX3P1tXSGaeU4C62j6OfGrKPLKWIpgPwVg5LKfabHKxcbdTzX+PwPgx1jayjsTZjYMAI3/zyzFItx9ovFCqwP4Fpp0Tswsi7kA+667/6gx3PRzElrHUp2TxrEn8S6LvDKWIth/B2BjY2exBcCnADze7EWYWYeZdb11G8BHAOyJz1pUHsdc4U5gCQt4vhVcDT6JJpwTMzPM1TDc5+5fu0xq6jlh62j2OVm0Iq/N2mF8x27jxzC303kIwF8u0RrWY84JeBXA3mauA8D3MPdxsIK5v70+i7meeU8BOADgnwD0LdE6/ieA1wDsxlywDTdhHXdj7iP6bgCvNP59rNnnJLKOpp4TALdgrojrbsy9sfyny16zLwA4COD/AGh9N/erb9AJkRCSvkEnRGJQsAuREBTsQiQEBbsQCUHBLkRCULALkRAU7EIkBAW7EAnh/wFg4oMsAJ/FrwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(dataset.train_data[1] / 256)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
