{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "B2EOnCShVSJ4"
      },
      "source": [
        "# Set-Up\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4siUQ4-eU0qC",
        "outputId": "15b980fe-1030-4ecf-eb19-d30f6ba5afc7",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Found existing installation: tensorflow 2.8.3\n",
            "Uninstalling tensorflow-2.8.3:\n",
            "  Successfully uninstalled tensorflow-2.8.3\n"
          ]
        }
      ],
      "source": [
        "!pip uninstall tensorflow -y"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "xBfZZK217Xpp",
        "outputId": "63e2d570-895c-413f-e9d1-7ef73f96a255",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
            "Collecting tensorflow==2.8.3\n",
            "  Using cached tensorflow-2.8.3-cp38-cp38-manylinux2010_x86_64.whl (498.4 MB)\n",
            "Requirement already satisfied: keras-preprocessing>=1.1.1 in /usr/local/lib/python3.8/dist-packages (from tensorflow==2.8.3) (1.1.2)\n",
            "Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.8/dist-packages (from tensorflow==2.8.3) (3.3.0)\n",
            "Requirement already satisfied: libclang>=9.0.1 in /usr/local/lib/python3.8/dist-packages (from tensorflow==2.8.3) (15.0.6.1)\n",
            "Requirement already satisfied: tensorboard<2.9,>=2.8 in /usr/local/lib/python3.8/dist-packages (from tensorflow==2.8.3) (2.8.0)\n",
            "Requirement already satisfied: absl-py>=0.4.0 in /usr/local/lib/python3.8/dist-packages (from tensorflow==2.8.3) (1.4.0)\n",
            "Requirement already satisfied: flatbuffers>=1.12 in /usr/local/lib/python3.8/dist-packages (from tensorflow==2.8.3) (23.1.21)\n",
            "Requirement already satisfied: astunparse>=1.6.0 in /usr/local/lib/python3.8/dist-packages (from tensorflow==2.8.3) (1.6.3)\n",
            "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /usr/local/lib/python3.8/dist-packages (from tensorflow==2.8.3) (0.31.0)\n",
            "Requirement already satisfied: wrapt>=1.11.0 in /usr/local/lib/python3.8/dist-packages (from tensorflow==2.8.3) (1.15.0)\n",
            "Requirement already satisfied: tensorflow-estimator<2.9,>=2.8 in /usr/local/lib/python3.8/dist-packages (from tensorflow==2.8.3) (2.8.0)\n",
            "Requirement already satisfied: protobuf<3.20,>=3.9.2 in /usr/local/lib/python3.8/dist-packages (from tensorflow==2.8.3) (3.19.6)\n",
            "Requirement already satisfied: setuptools in /usr/local/lib/python3.8/dist-packages (from tensorflow==2.8.3) (57.4.0)\n",
            "Requirement already satisfied: gast>=0.2.1 in /usr/local/lib/python3.8/dist-packages (from tensorflow==2.8.3) (0.4.0)\n",
            "Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.8/dist-packages (from tensorflow==2.8.3) (2.2.0)\n",
            "Requirement already satisfied: google-pasta>=0.1.1 in /usr/local/lib/python3.8/dist-packages (from tensorflow==2.8.3) (0.2.0)\n",
            "Requirement already satisfied: h5py>=2.9.0 in /usr/local/lib/python3.8/dist-packages (from tensorflow==2.8.3) (3.1.0)\n",
            "Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.8/dist-packages (from tensorflow==2.8.3) (1.15.0)\n",
            "Requirement already satisfied: typing-extensions>=3.6.6 in /usr/local/lib/python3.8/dist-packages (from tensorflow==2.8.3) (4.5.0)\n",
            "Requirement already satisfied: keras<2.9,>=2.8.0rc0 in /usr/local/lib/python3.8/dist-packages (from tensorflow==2.8.3) (2.8.0)\n",
            "Requirement already satisfied: numpy>=1.20 in /usr/local/lib/python3.8/dist-packages (from tensorflow==2.8.3) (1.22.4)\n",
            "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /usr/local/lib/python3.8/dist-packages (from tensorflow==2.8.3) (1.51.3)\n",
            "Requirement already satisfied: wheel<1.0,>=0.23.0 in /usr/local/lib/python3.8/dist-packages (from astunparse>=1.6.0->tensorflow==2.8.3) (0.38.4)\n",
            "Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.8/dist-packages (from tensorboard<2.9,>=2.8->tensorflow==2.8.3) (0.6.1)\n",
            "Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.8/dist-packages (from tensorboard<2.9,>=2.8->tensorflow==2.8.3) (2.16.1)\n",
            "Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.8/dist-packages (from tensorboard<2.9,>=2.8->tensorflow==2.8.3) (2.25.1)\n",
            "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.8/dist-packages (from tensorboard<2.9,>=2.8->tensorflow==2.8.3) (2.2.3)\n",
            "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.8/dist-packages (from tensorboard<2.9,>=2.8->tensorflow==2.8.3) (1.8.1)\n",
            "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.8/dist-packages (from tensorboard<2.9,>=2.8->tensorflow==2.8.3) (0.4.6)\n",
            "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.8/dist-packages (from tensorboard<2.9,>=2.8->tensorflow==2.8.3) (3.4.1)\n",
            "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.8/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.9,>=2.8->tensorflow==2.8.3) (0.2.8)\n",
            "Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.8/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.9,>=2.8->tensorflow==2.8.3) (4.9)\n",
            "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.9,>=2.8->tensorflow==2.8.3) (5.3.0)\n",
            "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.8/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.9,>=2.8->tensorflow==2.8.3) (1.3.1)\n",
            "Requirement already satisfied: importlib-metadata>=4.4 in /usr/local/lib/python3.8/dist-packages (from markdown>=2.6.8->tensorboard<2.9,>=2.8->tensorflow==2.8.3) (6.0.0)\n",
            "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests<3,>=2.21.0->tensorboard<2.9,>=2.8->tensorflow==2.8.3) (1.26.14)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests<3,>=2.21.0->tensorboard<2.9,>=2.8->tensorflow==2.8.3) (2.10)\n",
            "Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests<3,>=2.21.0->tensorboard<2.9,>=2.8->tensorflow==2.8.3) (4.0.0)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests<3,>=2.21.0->tensorboard<2.9,>=2.8->tensorflow==2.8.3) (2022.12.7)\n",
            "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.8/dist-packages (from werkzeug>=0.11.15->tensorboard<2.9,>=2.8->tensorflow==2.8.3) (2.1.2)\n",
            "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.8/dist-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard<2.9,>=2.8->tensorflow==2.8.3) (3.15.0)\n",
            "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.8/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<2.9,>=2.8->tensorflow==2.8.3) (0.4.8)\n",
            "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.8/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.9,>=2.8->tensorflow==2.8.3) (3.2.2)\n",
            "Installing collected packages: tensorflow\n",
            "Successfully installed tensorflow-2.8.3\n"
          ]
        }
      ],
      "source": [
        "!pip install tensorflow==2.8.3"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "UYOX3i2JVG-a",
        "outputId": "338794d8-f81b-4d20-e3be-6bf5b5382276",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[33mWARNING: Skipping import as it is not installed.\u001b[0m\u001b[33m\n",
            "\u001b[0mFound existing installation: tensorflow-probability 0.16.0\n",
            "Uninstalling tensorflow-probability-0.16.0:\n",
            "  Successfully uninstalled tensorflow-probability-0.16.0\n"
          ]
        }
      ],
      "source": [
        "!pip uninstall import tensorflow_probability -y"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "u0o1OfUk9BQh",
        "outputId": "7f0b9103-dcef-444b-fadc-05d3d919c90d",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
            "Collecting tensorflow_probability==0.16.0\n",
            "  Using cached tensorflow_probability-0.16.0-py2.py3-none-any.whl (6.3 MB)\n",
            "Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.8/dist-packages (from tensorflow_probability==0.16.0) (1.15.0)\n",
            "Requirement already satisfied: decorator in /usr/local/lib/python3.8/dist-packages (from tensorflow_probability==0.16.0) (4.4.2)\n",
            "Requirement already satisfied: numpy>=1.13.3 in /usr/local/lib/python3.8/dist-packages (from tensorflow_probability==0.16.0) (1.22.4)\n",
            "Requirement already satisfied: gast>=0.3.2 in /usr/local/lib/python3.8/dist-packages (from tensorflow_probability==0.16.0) (0.4.0)\n",
            "Requirement already satisfied: cloudpickle>=1.3 in /usr/local/lib/python3.8/dist-packages (from tensorflow_probability==0.16.0) (2.2.1)\n",
            "Requirement already satisfied: dm-tree in /usr/local/lib/python3.8/dist-packages (from tensorflow_probability==0.16.0) (0.1.8)\n",
            "Requirement already satisfied: absl-py in /usr/local/lib/python3.8/dist-packages (from tensorflow_probability==0.16.0) (1.4.0)\n",
            "Installing collected packages: tensorflow_probability\n",
            "Successfully installed tensorflow_probability-0.16.0\n"
          ]
        }
      ],
      "source": [
        "!pip install tensorflow_probability==0.16.0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "DiLjNLa2Fxln",
        "outputId": "c8c2b802-96d8-47db-ff2e-7724a3332c42",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Colab only includes TensorFlow 2.x; %tensorflow_version has no effect.\n"
          ]
        }
      ],
      "source": [
        "%tensorflow_version 2.x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "kFbyHYemoTLm",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "# Load the TensorBoard notebook extension\n",
        "%load_ext tensorboard"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "lXsO6OyvFiJz",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "import datetime\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_probability as tfp\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "from tensorflow.keras import Model\n",
        "from tensorflow.keras.models import Sequential\n",
        "from tensorflow.keras.losses import categorical_crossentropy\n",
        "from tensorflow.keras.layers import Dense, Flatten, Conv2D, AveragePooling2D, Lambda\n",
        "\n",
        "from tensorflow.keras import datasets\n",
        "from tensorflow.keras.utils import to_categorical\n",
        "\n",
        "from tqdm import tqdm\n",
        "\n",
        "\n",
        "from __future__ import absolute_import, division, print_function, unicode_literals"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "ezfO-JPQ9JNs",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "assert(tf.test.gpu_device_name())\n",
        "tf.keras.backend.clear_session()\n",
        "tf.config.optimizer.set_jit(True) # Enable XLA."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "rvBgUpdYhcTe",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "from google.colab import files\n",
        "import pickle"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "DhtuinGb7p7X",
        "outputId": "9a324154-42bc-4c5e-fcdb-656731bd20e2",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2.8.3\n"
          ]
        }
      ],
      "source": [
        "print(tf.__version__)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "Nyjb_iYxcdCR"
      },
      "source": [
        "# Data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "SXIhRtCfxCFB",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "num_classes = 10\n",
        "\n",
        "def load_data(data_set):\n",
        "  if data_set=='fashion_mnist':\n",
        "    (x_train, y_train), (x_test, y_test) = datasets.fashion_mnist.load_data()\n",
        "  else:\n",
        "    (x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()\n",
        "  \n",
        "\n",
        "  tf.keras.utils.set_random_seed(123)\n",
        "  a = np.random.permutation(60000) #60000 original\n",
        "  x_train = x_train[a,...]\n",
        "  y_train = y_train[a]\n",
        "  # Add a new axis\n",
        "  x_train = x_train[..., np.newaxis]\n",
        "  x_test = x_test[..., np.newaxis]\n",
        "\n",
        "\n",
        "  # Convert class vectors to binary class matrices.\n",
        "  y_train = to_categorical(y_train, num_classes)\n",
        "  y_test = to_categorical(y_test, num_classes)\n",
        "\n",
        "  # Data normalization\n",
        "  x_train = x_train.astype('float32')\n",
        "  x_test = x_test.astype('float32')\n",
        "  x_train /= 255\n",
        "  x_test /= 255\n",
        "\n",
        "  return x_train, y_train, x_test, y_test"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "KrpnyCdItgBZ",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "num_classes = 10\n",
        "\n",
        "def load_random_labelled_data(data_set):\n",
        "  if data_set=='fashion_mnist':\n",
        "    (x_train, y_train), (x_test, y_test) = datasets.fashion_mnist.load_data()\n",
        "  else:\n",
        "    (x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()\n",
        "  \n",
        "\n",
        "  tf.keras.utils.set_random_seed(123)\n",
        "  a = np.random.permutation(60000) #60000 original\n",
        "  x_train = x_train[a,...]\n",
        "  #y_train = y_train[a]\n",
        "  # Add a new axis\n",
        "  x_train = x_train[..., np.newaxis]\n",
        "  x_test = x_test[..., np.newaxis]\n",
        "\n",
        "\n",
        "  # Convert class vectors to binary class matrices.\n",
        "  y_train = to_categorical(y_train, num_classes)\n",
        "  y_test = to_categorical(y_test, num_classes)\n",
        "\n",
        "  # Data normalization\n",
        "  x_train = x_train.astype('float32')\n",
        "  x_test = x_test.astype('float32')\n",
        "  x_train /= 255\n",
        "  x_test /= 255\n",
        "\n",
        "  return x_train, y_train, x_test, y_test"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "rdc9t3j5os2T"
      },
      "source": [
        "# Prior Definition"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "bi3i0Qya01CV"
      },
      "source": [
        "Define functions to initialize different types of prior distributions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "1P33CQe3MYeo",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "from tensorflow_probability.python.distributions import normal as normal_lib\n",
        "from tensorflow_probability.python.distributions import independent as independent_lib\n",
        "from tensorflow_probability.python.distributions import kullback_leibler as kl_lib\n",
        "\n",
        "def make_normal_prior(scale):\n",
        "  \"\"\"\n",
        "  Defines a function that returns a Gaussian prior with the given standard\n",
        "  deviation.\n",
        "  \"\"\"\n",
        "  def _fn(dtype, shape, name, trainable,add_variable_fn):\n",
        "    del name, trainable, add_variable_fn   # unused\n",
        "    dist = normal_lib.Normal(loc=tf.zeros(shape, dtype), \n",
        "                             scale=dtype.as_numpy_dtype(scale))\n",
        "    batch_ndims = tf.size(dist.batch_shape_tensor())\n",
        "    return independent_lib.Independent(dist, \n",
        "                                       reinterpreted_batch_ndims=batch_ndims)\n",
        "\n",
        "  return _fn"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "b9E9K-6kuPq0"
      },
      "source": [
        "# Posteriors Definition"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "m6FwtaZJ3EYu"
      },
      "source": [
        "Define posterior distributions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "QoVqrmiDi5r1",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "import tensorflow.compat.v1 as tf1\n",
        "\n",
        "mean_field_init_untransformed_scale = -7.0\n",
        "\n",
        "# TODO(trandustin): Remove need for this boilerplate code.\n",
        "def mean_field_fn(empirical_bayes=False,\n",
        "                  initializer=tf1.initializers.he_normal()):\n",
        "  \"\"\"Constructors for Gaussian prior and posterior distributions.\n",
        "  Args:\n",
        "    empirical_bayes (bool): Whether to train the variance of the prior or not.\n",
        "    initializer (tf1.initializer): Initializer for the posterior means.\n",
        "  Returns:\n",
        "    prior, posterior (tfp.distribution): prior and posterior\n",
        "    to be fed into a Bayesian Layer.\n",
        "  \"\"\"\n",
        "\n",
        "  def prior(dtype, shape, name, trainable, add_variable_fn):\n",
        "    \"\"\"Returns the prior distribution (tfp.distributions.Independent).\"\"\"\n",
        "    softplus_inverse_scale = np.log(np.exp(1.) - 1.)\n",
        "\n",
        "    istrainable = add_variable_fn(\n",
        "        name=name + '_istrainable',\n",
        "        shape=(),\n",
        "        initializer=tf1.constant_initializer(1.),\n",
        "        dtype=dtype,\n",
        "        trainable=False)\n",
        "\n",
        "    untransformed_scale = add_variable_fn(\n",
        "        name=name + '_untransformed_scale',\n",
        "        shape=(),\n",
        "        initializer=tf1.constant_initializer(softplus_inverse_scale),\n",
        "        dtype=dtype,\n",
        "        trainable=empirical_bayes and trainable)\n",
        "    scale = (\n",
        "        np.finfo(dtype.as_numpy_dtype).eps +\n",
        "        tf.nn.softplus(untransformed_scale * istrainable + (1. - istrainable) *\n",
        "                       tf1.stop_gradient(untransformed_scale)))\n",
        "    loc = add_variable_fn(\n",
        "        name=name + '_loc',\n",
        "        shape=shape,\n",
        "        initializer=tf1.constant_initializer(0.),\n",
        "        dtype=dtype,\n",
        "        trainable=False)\n",
        "    dist = tfp.distributions.Normal(loc=loc, scale=scale)\n",
        "    dist.istrainable = istrainable\n",
        "    dist.untransformed_scale = untransformed_scale\n",
        "    batch_ndims = tf1.size(input=dist.batch_shape_tensor())\n",
        "    return tfp.distributions.Independent(dist,\n",
        "                                         reinterpreted_batch_ndims=batch_ndims)\n",
        "\n",
        "  def posterior(dtype, shape, name, trainable, add_variable_fn):\n",
        "    \"\"\"Returns the posterior distribution (tfp.distributions.Independent).\"\"\"\n",
        "    untransformed_scale = add_variable_fn(\n",
        "        name=name + '_untransformed_scale',\n",
        "        shape=shape,\n",
        "        initializer=tf1.initializers.random_normal(\n",
        "            mean=mean_field_init_untransformed_scale, stddev=0.1),\n",
        "        dtype=dtype,\n",
        "        trainable=trainable)\n",
        "    scale = (\n",
        "        np.finfo(dtype.as_numpy_dtype).eps +\n",
        "        tf.nn.softplus(untransformed_scale))\n",
        "    loc = add_variable_fn(\n",
        "        name=name + '_loc',\n",
        "        shape=shape,\n",
        "        initializer=initializer,\n",
        "        dtype=dtype,\n",
        "        trainable=trainable)\n",
        "    dist = tfp.distributions.Normal(loc=loc, scale=scale)\n",
        "    dist.untransformed_scale = untransformed_scale\n",
        "    batch_ndims = tf1.size(input=dist.batch_shape_tensor())\n",
        "    return tfp.distributions.Independent(dist,\n",
        "                                         reinterpreted_batch_ndims=batch_ndims)\n",
        "\n",
        "  return prior, posterior\n",
        "\n",
        "prior_fn, posterior_fn = mean_field_fn(empirical_bayes=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "uKuT191d3Xzh",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "from tensorflow_probability.python.distributions import deterministic as deterministic_lib\n",
        "\n",
        "def make_normal_posterior(\n",
        "    is_singular=False,\n",
        "    loc_initializer=tf1.initializers.random_normal(stddev=0.1),\n",
        "    untransformed_scale_initializer=tf1.initializers.random_normal(\n",
        "        mean=-3., stddev=0.1),\n",
        "    loc_regularizer=None,\n",
        "    untransformed_scale_regularizer=None,\n",
        "    loc_constraint=None,\n",
        "    untransformed_scale_constraint=None):\n",
        "\n",
        "  loc_scale_fn = tfp.layers.util.default_loc_scale_fn(\n",
        "      is_singular=is_singular,\n",
        "      loc_initializer=loc_initializer,\n",
        "      untransformed_scale_initializer=untransformed_scale_initializer,\n",
        "      loc_regularizer=loc_regularizer,\n",
        "      untransformed_scale_regularizer=untransformed_scale_regularizer,\n",
        "      loc_constraint=loc_constraint,\n",
        "      untransformed_scale_constraint=untransformed_scale_constraint)\n",
        "  def _fn(dtype, shape, name, trainable, add_variable_fn):\n",
        "    loc, scale = loc_scale_fn(dtype, shape, name, trainable, add_variable_fn)\n",
        "    if scale is None:\n",
        "      dist = deterministic_lib.Deterministic(loc=loc)\n",
        "    else:\n",
        "      dist = normal_lib.Normal(loc=loc, scale=scale)\n",
        "    batch_ndims = tf.size(dist.batch_shape_tensor())\n",
        "    return independent_lib.Independent(\n",
        "        dist, reinterpreted_batch_ndims=batch_ndims)\n",
        "  return _fn\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "BW9rqXi-QlKc",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "import tensorflow.compat.v1 as tf1\n",
        "from tensorflow_probability.python.distributions import deterministic as deterministic_lib\n",
        "\n",
        "def default_mean_field_mixture_fn(\n",
        "    is_singular=False,\n",
        "    loc_initializer=tf1.initializers.random_normal(stddev=0.1),\n",
        "    untransformed_scale_initializer=tf1.initializers.random_normal(\n",
        "        mean=-3., stddev=0.1),\n",
        "    loc_regularizer=None,\n",
        "    untransformed_scale_regularizer=None,\n",
        "    loc_constraint=None,\n",
        "    untransformed_scale_constraint=None):\n",
        " \n",
        "  loc_scale_fn = tfp.layers.util.default_loc_scale_fn(\n",
        "      is_singular=is_singular,\n",
        "      loc_initializer=loc_initializer,\n",
        "      untransformed_scale_initializer=untransformed_scale_initializer,\n",
        "      loc_regularizer=loc_regularizer,\n",
        "      untransformed_scale_regularizer=untransformed_scale_regularizer,\n",
        "      loc_constraint=loc_constraint,\n",
        "      untransformed_scale_constraint=untransformed_scale_constraint)\n",
        "  def _fn(dtype, shape, name, trainable, add_variable_fn):\n",
        "    loc, scale = loc_scale_fn(dtype, shape, name, trainable, add_variable_fn)\n",
        "    #loc_2, scale_2 = loc_scale_fn(dtype, shape, name, trainable, add_variable_fn)\n",
        "    if scale is None:\n",
        "      dist = deterministic_lib.Deterministic(loc=loc)\n",
        "    else:\n",
        "      dist = tfp.distributions.Normal(loc=loc, scale=scale)\n",
        "      #dist2 = tfp.distributions.Normal(loc=tf.stop_gradient(loc_2), scale=scale_2)\n",
        "\n",
        "    batch_ndims = tf.size(dist.batch_shape_tensor())\n",
        "\n",
        "    return tfp.distributions.Mixture(\n",
        "        cat=tfp.distributions.Categorical(probs=[0.5, 0.5]),\n",
        "        components=[\n",
        "            tfp.distributions.Independent(dist,\n",
        "                            reinterpreted_batch_ndims=batch_ndims),\n",
        "            tfp.distributions.Independent(tfp.distributions.Normal(\n",
        "                loc=tf.zeros(shape, dtype=dtype), \n",
        "                scale=1.0*tf.ones(shape, dtype=dtype)),\n",
        "                            reinterpreted_batch_ndims=batch_ndims)],\n",
        "        name='spike_and_slab')\n",
        "\n",
        "  return _fn\n"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "NR-MGa04cmVh"
      },
      "source": [
        "# Model"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "XJyqZswKTPil"
      },
      "source": [
        "## Set-up"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "VMA-Cx8KULte",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "def conv2d(filters, kernel_size, padding='SAME', l2_val=2e-4):\n",
        "    return tf.keras.layers.Conv2D(\n",
        "          filters,\n",
        "          kernel_size,\n",
        "          padding=padding,\n",
        "          activation=tf.nn.relu,\n",
        "          bias_regularizer=tf.keras.regularizers.L2(l2_val),\n",
        "          kernel_regularizer=tf.keras.regularizers.L2(l2_val))\n",
        "\n",
        "def max_pool():\n",
        "    return tf.keras.layers.MaxPooling2D(\n",
        "        pool_size=[2, 2],\n",
        "        strides=[2, 2],\n",
        "        #padding='same',\n",
        "    )\n",
        "def avg_pool():\n",
        "    return tf.keras.layers.AveragePooling2D(\n",
        "        pool_size=[2, 2],\n",
        "        strides=[2, 2],\n",
        "        #padding='same',\n",
        "    )\n",
        "\n",
        "def dense(units, activation, l2_val=2e-4):\n",
        "    return tf.keras.layers.Dense(\n",
        "        units,\n",
        "        activation,\n",
        "        bias_regularizer=tf.keras.regularizers.L2(l2_val),\n",
        "        kernel_regularizer=tf.keras.regularizers.L2(l2_val))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "oCbBDK6PTRNf",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "def conv2d_variational(filters, kernel_size, strides=(1, 1), padding='SAME', kernel_posterior_fn=None, kernel_prior_fn=None, bias_prior_fn=None, kernel_divergence_fn=None):\n",
        "    \"\"\"Convenience wrapper for conv layers.\"\"\"\n",
        "    if bias_prior_fn==None:\n",
        "      bias_prior_fn = kernel_prior_fn\n",
        "    \n",
        "    return tfp.layers.Convolution2DReparameterization(\n",
        "          filters,\n",
        "          kernel_size,\n",
        "          padding=padding,\n",
        "          activation=tf.nn.relu,\n",
        "          kernel_posterior_fn=kernel_posterior_fn,\n",
        "          kernel_prior_fn=kernel_prior_fn,\n",
        "          kernel_divergence_fn=kernel_divergence_fn,\n",
        "          bias_posterior_fn = kernel_posterior_fn,\n",
        "          bias_prior_fn = bias_prior_fn,\n",
        "          bias_divergence_fn=kernel_divergence_fn)\n",
        "\n",
        "\n",
        "def dense_variational(units, activation, kernel_posterior_fn=None, kernel_prior_fn=None, bias_prior_fn=None, kernel_divergence_fn=None):\n",
        "    if bias_prior_fn==None:\n",
        "      bias_prior_fn = kernel_prior_fn\n",
        "\n",
        "    return tfp.layers.DenseReparameterization(\n",
        "        units,\n",
        "        activation,\n",
        "        kernel_posterior_fn=kernel_posterior_fn,\n",
        "        kernel_prior_fn=kernel_prior_fn,\n",
        "        kernel_divergence_fn=kernel_divergence_fn,\n",
        "        bias_posterior_fn = kernel_posterior_fn,\n",
        "        bias_prior_fn = bias_prior_fn,\n",
        "        bias_divergence_fn=kernel_divergence_fn)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "CHhNiKBIwziq"
      },
      "source": [
        "## LeNet"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "qMo13XRwNtzv"
      },
      "source": [
        "### Variational"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "RE-IGUORIEDn",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "class LeNet_Variational(Sequential):\n",
        "  def __init__(self, input_shape, nb_classes,\n",
        "               kernel_posterior_fn=make_normal_posterior(),\n",
        "               kernel_prior_fn=make_normal_prior(1.0),\n",
        "               kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p)/(x_train.shape[0]*lamb),\n",
        "               gamma=1.0):\n",
        "    \n",
        "    super().__init__()\n",
        "\n",
        "    self.weight_layers = [0, 2, 4, 6, 7]\n",
        "\n",
        "    self.add(tf.keras.layers.Input(shape=input_shape))\n",
        "\n",
        "    self.add(conv2d_variational(6, 5, padding = \"same\", kernel_posterior_fn=kernel_posterior_fn, \n",
        "                                kernel_prior_fn=kernel_prior_fn, \n",
        "                                kernel_divergence_fn=kernel_divergence_fn))\n",
        "    self.add(avg_pool())\n",
        "\n",
        "    self.add(conv2d_variational(16, 5, padding=\"same\", kernel_posterior_fn=kernel_posterior_fn, \n",
        "                                kernel_prior_fn=kernel_prior_fn, \n",
        "                                kernel_divergence_fn=kernel_divergence_fn))\n",
        "    self.add(avg_pool())\n",
        "\n",
        "    self.add(Flatten())\n",
        "\n",
        "\n",
        "    self.add(dense_variational(120, tf.nn.relu, \n",
        "                               kernel_posterior_fn=kernel_posterior_fn, \n",
        "                               kernel_prior_fn=kernel_prior_fn, \n",
        "                               kernel_divergence_fn=kernel_divergence_fn))\n",
        "\n",
        "\n",
        "    self.add(dense_variational(84, tf.nn.relu, \n",
        "                               kernel_posterior_fn=kernel_posterior_fn, \n",
        "                               kernel_prior_fn=kernel_prior_fn, \n",
        "                               kernel_divergence_fn=kernel_divergence_fn))\n",
        "\n",
        "    self.add(dense_variational(nb_classes, None, \n",
        "                               kernel_posterior_fn=kernel_posterior_fn, \n",
        "                               kernel_prior_fn=kernel_prior_fn, \n",
        "                               kernel_divergence_fn=kernel_divergence_fn))\n",
        "    \n",
        "    self.add(Lambda(lambda x: x*gamma))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "USVR5Cnxmoth",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "class LeNet_Variational_Large(Sequential):\n",
        "  def __init__(self, input_shape, nb_classes,\n",
        "               kernel_posterior_fn=make_normal_posterior(),\n",
        "               kernel_prior_fn=make_normal_prior(1.0),\n",
        "               kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p)/(x_train.shape[0]*lamb),\n",
        "               gamma=1.0):\n",
        "    \n",
        "    super().__init__()\n",
        "\n",
        "    self.weight_layers = [0, 2, 4, 6, 7]\n",
        "\n",
        "    self.add(tf.keras.layers.Input(shape=input_shape))\n",
        "\n",
        "    self.add(conv2d_variational(6, 5, kernel_posterior_fn=kernel_posterior_fn, \n",
        "                                kernel_prior_fn=kernel_prior_fn, \n",
        "                                kernel_divergence_fn=kernel_divergence_fn))\n",
        "    self.add(max_pool())\n",
        "\n",
        "    self.add(conv2d_variational(16, 5, kernel_posterior_fn=kernel_posterior_fn, \n",
        "                                kernel_prior_fn=kernel_prior_fn, \n",
        "                                kernel_divergence_fn=kernel_divergence_fn))\n",
        "    self.add(max_pool())\n",
        "\n",
        "    self.add(conv2d_variational(120, 5, kernel_posterior_fn=kernel_posterior_fn, \n",
        "                                kernel_prior_fn=kernel_prior_fn, \n",
        "                                kernel_divergence_fn=kernel_divergence_fn))\n",
        "\n",
        "    self.add(Flatten())\n",
        "\n",
        "    self.add(dense_variational(84, tf.nn.relu, \n",
        "                               kernel_posterior_fn=kernel_posterior_fn, \n",
        "                               kernel_prior_fn=kernel_prior_fn, \n",
        "                               kernel_divergence_fn=kernel_divergence_fn))\n",
        "\n",
        "    self.add(dense_variational(nb_classes, None, \n",
        "                               kernel_posterior_fn=kernel_posterior_fn, \n",
        "                               kernel_prior_fn=kernel_prior_fn, \n",
        "                               kernel_divergence_fn=kernel_divergence_fn))\n",
        "    \n",
        "    self.add(Lambda(lambda x: x*gamma))\n",
        "     "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9r6WjDAwXU6h",
        "outputId": "d0bf0fe6-a76d-4b33-f6db-189503aeea30",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.8/dist-packages/tensorflow_probability/python/layers/util.py:95: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.\n",
            "  loc = add_variable_fn(\n",
            "/usr/local/lib/python3.8/dist-packages/tensorflow_probability/python/layers/util.py:105: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.\n",
            "  untransformed_scale = add_variable_fn(\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Model: \"le_net__variational__large\"\n",
            "_________________________________________________________________\n",
            " Layer (type)                Output Shape              Param #   \n",
            "=================================================================\n",
            " conv2d_reparameterization (  (None, 28, 28, 6)        312       \n",
            " Conv2DReparameterization)                                       \n",
            "                                                                 \n",
            " max_pooling2d (MaxPooling2D  (None, 14, 14, 6)        0         \n",
            " )                                                               \n",
            "                                                                 \n",
            " conv2d_reparameterization_1  (None, 14, 14, 16)       4832      \n",
            "  (Conv2DReparameterization)                                     \n",
            "                                                                 \n",
            " max_pooling2d_1 (MaxPooling  (None, 7, 7, 16)         0         \n",
            " 2D)                                                             \n",
            "                                                                 \n",
            " conv2d_reparameterization_2  (None, 7, 7, 120)        96240     \n",
            "  (Conv2DReparameterization)                                     \n",
            "                                                                 \n",
            " flatten (Flatten)           (None, 5880)              0         \n",
            "                                                                 \n",
            " dense_reparameterization (D  (None, 84)               988008    \n",
            " enseReparameterization)                                         \n",
            "                                                                 \n",
            " dense_reparameterization_1   (None, 10)               1700      \n",
            " (DenseReparameterization)                                       \n",
            "                                                                 \n",
            " lambda (Lambda)             (None, 10)                0         \n",
            "                                                                 \n",
            "=================================================================\n",
            "Total params: 1,091,092\n",
            "Trainable params: 1,091,092\n",
            "Non-trainable params: 0\n",
            "_________________________________________________________________\n"
          ]
        }
      ],
      "source": [
        "x_train, y_train, x_test, y_test = load_data('mnist')\n",
        "lamb=1\n",
        "model_variational = LeNet_Variational_Large(x_train[0].shape, num_classes)\n",
        "model_variational.summary()"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "UL-Qp9o7Blp_"
      },
      "source": [
        "## MLP"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "VQjILFcYBnkH",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "class MLP_Variational(Sequential):\n",
        "  \"\"\"\n",
        "  Defines a Bayesian Multilayer perceptron model with prior and posterior \n",
        "  distribution determined by the given parameters.\n",
        "  \"\"\"\n",
        "  def __init__(self, input_shape, nb_classes, hidden, \n",
        "               kernel_posterior_fn=make_normal_posterior(),\n",
        "               kernel_prior_fn=make_normal_prior(1.0),\n",
        "               kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p)/(x_train.shape[0]*lamb),\n",
        "               gamma=1.0):\n",
        "    \n",
        "    super().__init__()\n",
        "\n",
        "    self.weight_layers = [1,2,3]\n",
        "    \n",
        "    self.add(Flatten())\n",
        "    \n",
        "    self.add(tfp.layers.DenseReparameterization(hidden, \n",
        "                      activation='relu',\n",
        "                      kernel_posterior_fn=kernel_posterior_fn,\n",
        "                      kernel_prior_fn=kernel_prior_fn,\n",
        "                      bias_prior_fn=kernel_prior_fn,\n",
        "                      bias_posterior_fn=kernel_posterior_fn,\n",
        "                      kernel_divergence_fn=kernel_divergence_fn,\n",
        "                      bias_divergence_fn=kernel_divergence_fn))\n",
        "\n",
        "    self.add(tfp.layers.DenseReparameterization(hidden, \n",
        "                      activation='relu',\n",
        "                      kernel_posterior_fn=kernel_posterior_fn,\n",
        "                      kernel_prior_fn=kernel_prior_fn,\n",
        "                      bias_prior_fn=kernel_prior_fn,\n",
        "                      bias_posterior_fn=kernel_posterior_fn,\n",
        "                      kernel_divergence_fn=kernel_divergence_fn,\n",
        "                      bias_divergence_fn=kernel_divergence_fn))\n",
        "\n",
        "\n",
        "    self.add(tfp.layers.DenseReparameterization(nb_classes, \n",
        "                      activation='linear',\n",
        "                      kernel_posterior_fn=kernel_posterior_fn,\n",
        "                      kernel_prior_fn=kernel_prior_fn,\n",
        "                      bias_prior_fn=kernel_prior_fn,\n",
        "                      bias_posterior_fn=kernel_posterior_fn,\n",
        "                      kernel_divergence_fn=kernel_divergence_fn,\n",
        "                      bias_divergence_fn=kernel_divergence_fn))\n",
        "    self.add(Lambda(lambda x: x*gamma))    "
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "iwGNoPYQQdNE"
      },
      "source": [
        "# Training Set-up"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "qXUVl6tWoO00",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.0)\n",
        "\n",
        "@tf.function\n",
        "def ece(y_true,y_pred):\n",
        "  return tfp.stats.expected_calibration_error(10,logits=y_pred, labels_true=tf.argmax(y_true,axis=1))\n",
        "\n",
        "# Place the logs in a timestamped subdirectory\n",
        "# This allows to easy select different training runs\n",
        "# In order not to overwrite some data, it is useful to have a name with a timestamp\n",
        "log_dir=\"logs/fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
        "# Specify the callback object\n",
        "tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)\n",
        "\n",
        "# tf.keras.callback.TensorBoard ensures that logs are created and stored\n",
        "# We need to pass callback object to the fit method\n",
        "# The way to do this is by passing the list of callback objects, which is in our case just one\n",
        "\n",
        "early_stopping = tf.keras.callbacks.EarlyStopping(monitor=\"val_categorical_crossentropy\", patience=10, restore_best_weights=True)\n",
        "\n",
        "adam_opt = tf.keras.optimizers.Adam(0.001)\n",
        "\n",
        "def schedule(epoch, lr):\n",
        "  if (epoch<25):\n",
        "    return 0.001\n",
        "  elif epoch<50:\n",
        "    return 0.0001\n",
        "  else:\n",
        "    return 0.00001\n",
        "\n",
        "lr_scheduler = tf.keras.callbacks.LearningRateScheduler(schedule)\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "id": "iCCKNSDXfTis",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "class BMA_Model(tf.keras.Model):\n",
        "\n",
        "  def __init__(self, model, mc_samples):\n",
        "    super().__init__()\n",
        "    self.model = model\n",
        "    self.mc_samples = mc_samples\n",
        "  \n",
        "  def call(self, inputs, training=False):\n",
        "    list_models =[self.model(inputs) for i in range(self.mc_samples)]\n",
        "    stack_models = tf.stack(list_models,axis=1)\n",
        "    return tfp.math.reduce_logmeanexp(stack_models,axis=1)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "TPt69y9ZcusY"
      },
      "source": [
        "# Variational Learning"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 31,
      "metadata": {
        "id": "7OWzof1fmtBb",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "# experiment settings\n",
        "data_use=\"mnist\"\n",
        "model_used=\"lenet\"\n",
        "num_posterior_samples=10\n",
        "n_epochs=10000\n",
        "batch_size=100\n",
        "lr=0.0001\n",
        "\n",
        "seeds=[15]#,24]\n",
        "\n",
        "# lambs for likelihood tempering\n",
        "lambs=np.linspace(0.5, 3.5, 7)\n",
        "#lambs = [0.5, 1.0, 1.5]\n",
        "\n",
        "# label_smoothing\n",
        "# label_smoothing=0.0\n",
        "# cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=label_smoothing)\n",
        "\n",
        "# label noise\n",
        "#label_noises=[0.0,0.1,0.3,0.5]\n",
        "label_noises=[0.0]\n",
        "\n",
        "# smooth softmax\n",
        "gammas=[1.0, 10.0]\n",
        "\n",
        "# data augmentation:\n",
        "#   - None: No data augmentation\n",
        "#   - Standard: \n",
        "#   - Replacement\n",
        "#   - Noise\n",
        "#d_augmentation = [\"Standard\", \"Replacement\", \"Noise\"]\n",
        "d_augmentation = [\"None\"]\n",
        "\n",
        "# prior\n",
        "prior_scales=[0.01, 1.0]\n",
        "#prior_scales=[1.0]\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 32,
      "metadata": {
        "id": "RbTGSR9jQje6",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "# x_train_fashion, y_train_fashion, x_test_fashion, y_test_fashion = load_data('fashion_mnist')\n",
        "x_train_mnist, y_train_mnist, x_test_mnist, y_test_mnist = load_data('mnist')\n",
        "#x_train_mnist, y_train_mnist, x_test_mnist, y_test_mnist = load_random_labelled_data('mnist')\n",
        "\n",
        "\n",
        "x_train_mnist = x_train_mnist[:100]\n",
        "y_train_mnist = y_train_mnist[:100]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 33,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "n4i8AnQ6Qt-3",
        "outputId": "f380c4c8-797e-4a80-f6af-6c03b3f51981",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "15\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r  0%|          | 0/7 [00:00<?, ?it/s]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "0.5\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.8/dist-packages/tensorflow_probability/python/layers/util.py:95: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.\n",
            "  loc = add_variable_fn(\n",
            "/usr/local/lib/python3.8/dist-packages/tensorflow_probability/python/layers/util.py:105: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.\n",
            "  untransformed_scale = add_variable_fn(\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1/1 [==============================] - 2s 2s/step - loss: 14125.7275 - accuracy: 0.1100 - categorical_crossentropy: 2.3075\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 14125.7314 - accuracy: 0.1046 - categorical_crossentropy: 2.3026\n",
            "1/1 [==============================] - 0s 18ms/step - loss: 14125.7236 - accuracy: 0.1700 - categorical_crossentropy: 2.3042\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 14125.7334 - accuracy: 0.0999 - categorical_crossentropy: 2.3027\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 14125.7295 - accuracy: 0.0600 - categorical_crossentropy: 2.3095\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 14125.7402 - accuracy: 0.1033 - categorical_crossentropy: 2.3039\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 14125.7207 - accuracy: 0.0800 - categorical_crossentropy: 2.3024\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 14125.7354 - accuracy: 0.1011 - categorical_crossentropy: 2.3033\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 14125.7168 - accuracy: 0.1000 - categorical_crossentropy: 2.2985\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 14125.7354 - accuracy: 0.0973 - categorical_crossentropy: 2.3034\n",
            "1/1 [==============================] - 0s 25ms/step - loss: 14125.7168 - accuracy: 0.1500 - categorical_crossentropy: 2.2984\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 14125.7334 - accuracy: 0.1046 - categorical_crossentropy: 2.3024\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 14125.7207 - accuracy: 0.1400 - categorical_crossentropy: 2.3021\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 14125.7354 - accuracy: 0.1039 - categorical_crossentropy: 2.3029\n",
            "1/1 [==============================] - 0s 27ms/step - loss: 14125.7197 - accuracy: 0.1500 - categorical_crossentropy: 2.3005\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 14125.7334 - accuracy: 0.1003 - categorical_crossentropy: 2.3033\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 14125.7236 - accuracy: 0.1200 - categorical_crossentropy: 2.3042\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 14125.7354 - accuracy: 0.0986 - categorical_crossentropy: 2.3029\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 14125.7246 - accuracy: 0.1200 - categorical_crossentropy: 2.3066\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 14125.7363 - accuracy: 0.1036 - categorical_crossentropy: 2.3035\n",
            "1/1 [==============================] - 9s 9s/step - loss: 141234.0625 - accuracy: 0.1100 - categorical_crossentropy: 2.3035\n",
            "100/100 [==============================] - 0s 4ms/step - loss: 141234.0000 - accuracy: 0.1035 - categorical_crossentropy: 2.3027\n",
            "Tue Apr 25 09:19:13 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   70C    P0    51W /  70W |   8975MiB / 15360MiB |     57%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 14%|█▍        | 1/7 [01:47<10:46, 107.74s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1.0\n",
            "1/1 [==============================] - 2s 2s/step - loss: 7064.0435 - accuracy: 0.1100 - categorical_crossentropy: 2.3075\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7064.0405 - accuracy: 0.1049 - categorical_crossentropy: 2.3026\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 7064.0400 - accuracy: 0.1700 - categorical_crossentropy: 2.3042\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7064.0386 - accuracy: 0.0999 - categorical_crossentropy: 2.3027\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 7064.0454 - accuracy: 0.0600 - categorical_crossentropy: 2.3094\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7064.0396 - accuracy: 0.1034 - categorical_crossentropy: 2.3039\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 7064.0381 - accuracy: 0.0800 - categorical_crossentropy: 2.3023\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7064.0396 - accuracy: 0.1013 - categorical_crossentropy: 2.3033\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 7064.0342 - accuracy: 0.1000 - categorical_crossentropy: 2.2985\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7064.0410 - accuracy: 0.0974 - categorical_crossentropy: 2.3034\n",
            "1/1 [==============================] - 0s 24ms/step - loss: 7064.0342 - accuracy: 0.1500 - categorical_crossentropy: 2.2983\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7064.0396 - accuracy: 0.1052 - categorical_crossentropy: 2.3024\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 7064.0381 - accuracy: 0.1400 - categorical_crossentropy: 2.3021\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7064.0405 - accuracy: 0.1046 - categorical_crossentropy: 2.3029\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 7064.0361 - accuracy: 0.1500 - categorical_crossentropy: 2.3004\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7064.0396 - accuracy: 0.1002 - categorical_crossentropy: 2.3033\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 7064.0400 - accuracy: 0.1200 - categorical_crossentropy: 2.3041\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7064.0381 - accuracy: 0.0981 - categorical_crossentropy: 2.3029\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 7064.0430 - accuracy: 0.1200 - categorical_crossentropy: 2.3065\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7064.0405 - accuracy: 0.1040 - categorical_crossentropy: 2.3035\n",
            "1/1 [==============================] - 8s 8s/step - loss: 70617.3125 - accuracy: 0.1100 - categorical_crossentropy: 2.3034\n",
            "100/100 [==============================] - 0s 5ms/step - loss: 70617.3594 - accuracy: 0.1038 - categorical_crossentropy: 2.3027\n",
            "Tue Apr 25 09:21:00 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   72C    P0    34W /  70W |   8975MiB / 15360MiB |     46%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 29%|██▊       | 2/7 [03:33<08:54, 106.82s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1.5\n",
            "1/1 [==============================] - 1s 1s/step - loss: 4710.1538 - accuracy: 0.1000 - categorical_crossentropy: 2.3074\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4710.1514 - accuracy: 0.1048 - categorical_crossentropy: 2.3025\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 4710.1509 - accuracy: 0.1700 - categorical_crossentropy: 2.3041\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4710.1523 - accuracy: 0.0997 - categorical_crossentropy: 2.3027\n",
            "1/1 [==============================] - 0s 18ms/step - loss: 4710.1558 - accuracy: 0.0600 - categorical_crossentropy: 2.3094\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4710.1543 - accuracy: 0.1039 - categorical_crossentropy: 2.3039\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 4710.1489 - accuracy: 0.0800 - categorical_crossentropy: 2.3023\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4710.1528 - accuracy: 0.1013 - categorical_crossentropy: 2.3033\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 4710.1450 - accuracy: 0.1000 - categorical_crossentropy: 2.2984\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4710.1533 - accuracy: 0.0973 - categorical_crossentropy: 2.3034\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 4710.1450 - accuracy: 0.1500 - categorical_crossentropy: 2.2983\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4710.1523 - accuracy: 0.1048 - categorical_crossentropy: 2.3024\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 4710.1484 - accuracy: 0.1400 - categorical_crossentropy: 2.3020\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4710.1528 - accuracy: 0.1052 - categorical_crossentropy: 2.3029\n",
            "1/1 [==============================] - 0s 18ms/step - loss: 4710.1470 - accuracy: 0.1500 - categorical_crossentropy: 2.3004\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4710.1528 - accuracy: 0.0999 - categorical_crossentropy: 2.3033\n",
            "1/1 [==============================] - 0s 18ms/step - loss: 4710.1509 - accuracy: 0.1200 - categorical_crossentropy: 2.3041\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4710.1523 - accuracy: 0.0984 - categorical_crossentropy: 2.3029\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 4710.1528 - accuracy: 0.1200 - categorical_crossentropy: 2.3065\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4710.1528 - accuracy: 0.1041 - categorical_crossentropy: 2.3035\n",
            "1/1 [==============================] - 9s 9s/step - loss: 47078.4414 - accuracy: 0.1100 - categorical_crossentropy: 2.3034\n",
            "100/100 [==============================] - 0s 5ms/step - loss: 47078.4844 - accuracy: 0.1028 - categorical_crossentropy: 2.3027\n",
            "Tue Apr 25 09:22:47 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   73C    P0    42W /  70W |   8977MiB / 15360MiB |     59%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 43%|████▎     | 3/7 [05:21<07:08, 107.21s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2.0\n",
            "1/1 [==============================] - 2s 2s/step - loss: 3533.2124 - accuracy: 0.1100 - categorical_crossentropy: 2.3074\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3533.2053 - accuracy: 0.1048 - categorical_crossentropy: 2.3025\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 3533.2090 - accuracy: 0.1700 - categorical_crossentropy: 2.3041\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3533.2053 - accuracy: 0.1000 - categorical_crossentropy: 2.3027\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 3533.2144 - accuracy: 0.0600 - categorical_crossentropy: 2.3093\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3533.2085 - accuracy: 0.1044 - categorical_crossentropy: 2.3039\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 3533.2068 - accuracy: 0.0800 - categorical_crossentropy: 2.3022\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3533.2070 - accuracy: 0.1013 - categorical_crossentropy: 2.3033\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 3533.2031 - accuracy: 0.1100 - categorical_crossentropy: 2.2984\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3533.2073 - accuracy: 0.0976 - categorical_crossentropy: 2.3034\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 3533.2031 - accuracy: 0.1500 - categorical_crossentropy: 2.2983\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3533.2058 - accuracy: 0.1049 - categorical_crossentropy: 2.3023\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 3533.2068 - accuracy: 0.1400 - categorical_crossentropy: 2.3020\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3533.2058 - accuracy: 0.1054 - categorical_crossentropy: 2.3029\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 3533.2048 - accuracy: 0.1500 - categorical_crossentropy: 2.3003\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3533.2058 - accuracy: 0.0988 - categorical_crossentropy: 2.3033\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 3533.2090 - accuracy: 0.1200 - categorical_crossentropy: 2.3040\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3533.2065 - accuracy: 0.0985 - categorical_crossentropy: 2.3029\n",
            "1/1 [==============================] - 0s 18ms/step - loss: 3533.2114 - accuracy: 0.1200 - categorical_crossentropy: 2.3064\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3533.2083 - accuracy: 0.1046 - categorical_crossentropy: 2.3035\n",
            "1/1 [==============================] - 9s 9s/step - loss: 35309.0312 - accuracy: 0.1100 - categorical_crossentropy: 2.3033\n",
            "100/100 [==============================] - 0s 5ms/step - loss: 35309.0039 - accuracy: 0.1033 - categorical_crossentropy: 2.3027\n",
            "Tue Apr 25 09:24:34 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   74C    P0    35W /  70W |   8977MiB / 15360MiB |     36%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 57%|█████▋    | 4/7 [07:08<05:21, 107.14s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2.5\n",
            "1/1 [==============================] - 2s 2s/step - loss: 2827.0505 - accuracy: 0.1100 - categorical_crossentropy: 2.3073\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2827.0447 - accuracy: 0.1051 - categorical_crossentropy: 2.3025\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 2827.0471 - accuracy: 0.1700 - categorical_crossentropy: 2.3040\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2827.0442 - accuracy: 0.1004 - categorical_crossentropy: 2.3027\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 2827.0522 - accuracy: 0.0600 - categorical_crossentropy: 2.3093\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2827.0464 - accuracy: 0.1049 - categorical_crossentropy: 2.3039\n",
            "1/1 [==============================] - 0s 27ms/step - loss: 2827.0452 - accuracy: 0.0800 - categorical_crossentropy: 2.3022\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2827.0442 - accuracy: 0.1017 - categorical_crossentropy: 2.3033\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 2827.0413 - accuracy: 0.1000 - categorical_crossentropy: 2.2983\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2827.0454 - accuracy: 0.0970 - categorical_crossentropy: 2.3034\n",
            "1/1 [==============================] - 0s 24ms/step - loss: 2827.0413 - accuracy: 0.1500 - categorical_crossentropy: 2.2982\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2827.0442 - accuracy: 0.1049 - categorical_crossentropy: 2.3023\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 2827.0452 - accuracy: 0.1400 - categorical_crossentropy: 2.3019\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2827.0449 - accuracy: 0.1053 - categorical_crossentropy: 2.3029\n",
            "1/1 [==============================] - 0s 29ms/step - loss: 2827.0430 - accuracy: 0.1500 - categorical_crossentropy: 2.3003\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2827.0454 - accuracy: 0.0984 - categorical_crossentropy: 2.3033\n",
            "1/1 [==============================] - 0s 25ms/step - loss: 2827.0471 - accuracy: 0.1200 - categorical_crossentropy: 2.3040\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2827.0447 - accuracy: 0.0982 - categorical_crossentropy: 2.3029\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 2827.0496 - accuracy: 0.1200 - categorical_crossentropy: 2.3064\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2827.0457 - accuracy: 0.1043 - categorical_crossentropy: 2.3035\n",
            "1/1 [==============================] - 9s 9s/step - loss: 28247.4199 - accuracy: 0.1100 - categorical_crossentropy: 2.3033\n",
            "100/100 [==============================] - 0s 5ms/step - loss: 28247.4238 - accuracy: 0.1043 - categorical_crossentropy: 2.3027\n",
            "Tue Apr 25 09:26:21 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   75C    P0    37W /  70W |   8979MiB / 15360MiB |      3%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 71%|███████▏  | 5/7 [08:55<03:34, 107.00s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "3.0\n",
            "1/1 [==============================] - 2s 2s/step - loss: 2356.2776 - accuracy: 0.1100 - categorical_crossentropy: 2.3073\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2356.2739 - accuracy: 0.1045 - categorical_crossentropy: 2.3025\n",
            "1/1 [==============================] - 0s 18ms/step - loss: 2356.2744 - accuracy: 0.1700 - categorical_crossentropy: 2.3040\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2356.2747 - accuracy: 0.1001 - categorical_crossentropy: 2.3027\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 2356.2795 - accuracy: 0.0600 - categorical_crossentropy: 2.3092\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2356.2759 - accuracy: 0.1052 - categorical_crossentropy: 2.3039\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 2356.2725 - accuracy: 0.0800 - categorical_crossentropy: 2.3021\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2356.2751 - accuracy: 0.1016 - categorical_crossentropy: 2.3033\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 2356.2686 - accuracy: 0.1000 - categorical_crossentropy: 2.2983\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2356.2744 - accuracy: 0.0966 - categorical_crossentropy: 2.3034\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 2356.2686 - accuracy: 0.1500 - categorical_crossentropy: 2.2982\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2356.2739 - accuracy: 0.1050 - categorical_crossentropy: 2.3023\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 2356.2722 - accuracy: 0.1400 - categorical_crossentropy: 2.3019\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2356.2749 - accuracy: 0.1055 - categorical_crossentropy: 2.3029\n",
            "1/1 [==============================] - 0s 17ms/step - loss: 2356.2708 - accuracy: 0.1500 - categorical_crossentropy: 2.3002\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2356.2744 - accuracy: 0.0982 - categorical_crossentropy: 2.3033\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 2356.2744 - accuracy: 0.1200 - categorical_crossentropy: 2.3039\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2356.2749 - accuracy: 0.0985 - categorical_crossentropy: 2.3029\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 2356.2766 - accuracy: 0.1200 - categorical_crossentropy: 2.3063\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2356.2751 - accuracy: 0.1041 - categorical_crossentropy: 2.3035\n",
            "1/1 [==============================] - 10s 10s/step - loss: 23539.6816 - accuracy: 0.1100 - categorical_crossentropy: 2.3032\n",
            "100/100 [==============================] - 0s 5ms/step - loss: 23539.6641 - accuracy: 0.1045 - categorical_crossentropy: 2.3027\n",
            "Tue Apr 25 09:28:09 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   75C    P0    36W /  70W |   8979MiB / 15360MiB |     49%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 86%|████████▌ | 6/7 [10:43<01:47, 107.27s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "3.5\n",
            "1/1 [==============================] - 2s 2s/step - loss: 2020.0135 - accuracy: 0.1100 - categorical_crossentropy: 2.3072\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2020.0092 - accuracy: 0.1043 - categorical_crossentropy: 2.3025\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 2020.0101 - accuracy: 0.1700 - categorical_crossentropy: 2.3039\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2020.0092 - accuracy: 0.1001 - categorical_crossentropy: 2.3027\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 2020.0154 - accuracy: 0.0600 - categorical_crossentropy: 2.3092\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2020.0096 - accuracy: 0.1056 - categorical_crossentropy: 2.3039\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 2020.0084 - accuracy: 0.0800 - categorical_crossentropy: 2.3021\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2020.0095 - accuracy: 0.1017 - categorical_crossentropy: 2.3033\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 2020.0045 - accuracy: 0.1000 - categorical_crossentropy: 2.2982\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2020.0099 - accuracy: 0.0966 - categorical_crossentropy: 2.3034\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 2020.0042 - accuracy: 0.1500 - categorical_crossentropy: 2.2981\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2020.0083 - accuracy: 0.1051 - categorical_crossentropy: 2.3023\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 2020.0081 - accuracy: 0.1400 - categorical_crossentropy: 2.3018\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2020.0092 - accuracy: 0.1050 - categorical_crossentropy: 2.3028\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 2020.0063 - accuracy: 0.1500 - categorical_crossentropy: 2.3002\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2020.0092 - accuracy: 0.0984 - categorical_crossentropy: 2.3033\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 2020.0101 - accuracy: 0.1200 - categorical_crossentropy: 2.3039\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2020.0087 - accuracy: 0.0988 - categorical_crossentropy: 2.3029\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 2020.0125 - accuracy: 0.1200 - categorical_crossentropy: 2.3063\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2020.0099 - accuracy: 0.1043 - categorical_crossentropy: 2.3035\n",
            "1/1 [==============================] - 9s 9s/step - loss: 20177.0703 - accuracy: 0.1100 - categorical_crossentropy: 2.3032\n",
            "100/100 [==============================] - 0s 5ms/step - loss: 20177.0938 - accuracy: 0.1043 - categorical_crossentropy: 2.3027\n",
            "Tue Apr 25 09:29:56 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   76C    P0    46W /  70W |   8979MiB / 15360MiB |     57%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 7/7 [12:30<00:00, 107.17s/it]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "15\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r  0%|          | 0/7 [00:00<?, ?it/s]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "0.5\n",
            "1/1 [==============================] - 2s 2s/step - loss: 14127.1611 - accuracy: 0.1100 - categorical_crossentropy: 2.3803\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 14127.1270 - accuracy: 0.1030 - categorical_crossentropy: 2.3462\n",
            "1/1 [==============================] - 0s 27ms/step - loss: 14127.1270 - accuracy: 0.1700 - categorical_crossentropy: 2.3458\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 14127.1250 - accuracy: 0.1005 - categorical_crossentropy: 2.3453\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 14127.2158 - accuracy: 0.0600 - categorical_crossentropy: 2.4348\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 14127.1494 - accuracy: 0.1057 - categorical_crossentropy: 2.3640\n",
            "1/1 [==============================] - 0s 24ms/step - loss: 14127.1025 - accuracy: 0.0800 - categorical_crossentropy: 2.3210\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 14127.1396 - accuracy: 0.1013 - categorical_crossentropy: 2.3557\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 14127.1123 - accuracy: 0.1000 - categorical_crossentropy: 2.3310\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 14127.1367 - accuracy: 0.0970 - categorical_crossentropy: 2.3556\n",
            "1/1 [==============================] - 0s 24ms/step - loss: 14127.0918 - accuracy: 0.1500 - categorical_crossentropy: 2.3112\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 14127.1270 - accuracy: 0.1053 - categorical_crossentropy: 2.3440\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 14127.0957 - accuracy: 0.1400 - categorical_crossentropy: 2.3153\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 14127.1289 - accuracy: 0.1051 - categorical_crossentropy: 2.3509\n",
            "1/1 [==============================] - 0s 24ms/step - loss: 14127.0957 - accuracy: 0.1500 - categorical_crossentropy: 2.3145\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 14127.1348 - accuracy: 0.0976 - categorical_crossentropy: 2.3526\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 14127.1494 - accuracy: 0.1200 - categorical_crossentropy: 2.3678\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 14127.1348 - accuracy: 0.0991 - categorical_crossentropy: 2.3489\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 14127.2305 - accuracy: 0.1200 - categorical_crossentropy: 2.4491\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 14127.1396 - accuracy: 0.1036 - categorical_crossentropy: 2.3580\n",
            "1/1 [==============================] - 9s 9s/step - loss: 141247.9531 - accuracy: 0.1100 - categorical_crossentropy: 2.3172\n",
            "100/100 [==============================] - 0s 5ms/step - loss: 141247.9844 - accuracy: 0.1009 - categorical_crossentropy: 2.3086\n",
            "Tue Apr 25 09:31:50 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   76C    P0    38W /  70W |   8981MiB / 15360MiB |     19%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 14%|█▍        | 1/7 [01:53<11:22, 113.72s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1.0\n",
            "1/1 [==============================] - 2s 2s/step - loss: 7066.4697 - accuracy: 0.1100 - categorical_crossentropy: 2.3772\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7066.4409 - accuracy: 0.1038 - categorical_crossentropy: 2.3471\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 7066.4341 - accuracy: 0.1600 - categorical_crossentropy: 2.3419\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7066.4409 - accuracy: 0.1015 - categorical_crossentropy: 2.3463\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 7066.5244 - accuracy: 0.0600 - categorical_crossentropy: 2.4319\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7066.4580 - accuracy: 0.1058 - categorical_crossentropy: 2.3648\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 7066.4097 - accuracy: 0.0900 - categorical_crossentropy: 2.3172\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7066.4497 - accuracy: 0.1007 - categorical_crossentropy: 2.3567\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 7066.4219 - accuracy: 0.0900 - categorical_crossentropy: 2.3292\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7066.4492 - accuracy: 0.0978 - categorical_crossentropy: 2.3565\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 7066.4043 - accuracy: 0.1400 - categorical_crossentropy: 2.3114\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7066.4390 - accuracy: 0.1063 - categorical_crossentropy: 2.3448\n",
            "1/1 [==============================] - 0s 18ms/step - loss: 7066.4048 - accuracy: 0.1300 - categorical_crossentropy: 2.3124\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7066.4443 - accuracy: 0.1049 - categorical_crossentropy: 2.3518\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 7066.4048 - accuracy: 0.1600 - categorical_crossentropy: 2.3126\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7066.4458 - accuracy: 0.0971 - categorical_crossentropy: 2.3537\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 7066.4575 - accuracy: 0.1200 - categorical_crossentropy: 2.3648\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7066.4429 - accuracy: 0.0989 - categorical_crossentropy: 2.3499\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 7066.5430 - accuracy: 0.1200 - categorical_crossentropy: 2.4504\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7066.4531 - accuracy: 0.1044 - categorical_crossentropy: 2.3588\n",
            "1/1 [==============================] - 10s 10s/step - loss: 70640.9062 - accuracy: 0.1100 - categorical_crossentropy: 2.3135\n",
            "100/100 [==============================] - 0s 5ms/step - loss: 70640.9609 - accuracy: 0.1030 - categorical_crossentropy: 2.3087\n",
            "Tue Apr 25 09:33:38 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   77C    P0    36W /  70W |   8981MiB / 15360MiB |     32%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 29%|██▊       | 2/7 [03:42<09:13, 110.61s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1.5\n",
            "1/1 [==============================] - 2s 2s/step - loss: 4713.4175 - accuracy: 0.1100 - categorical_crossentropy: 2.3753\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4713.3921 - accuracy: 0.1049 - categorical_crossentropy: 2.3486\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 4713.3809 - accuracy: 0.1500 - categorical_crossentropy: 2.3389\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4713.3906 - accuracy: 0.1015 - categorical_crossentropy: 2.3478\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 4713.4731 - accuracy: 0.0600 - categorical_crossentropy: 2.4310\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4713.4092 - accuracy: 0.1058 - categorical_crossentropy: 2.3662\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 4713.3564 - accuracy: 0.0900 - categorical_crossentropy: 2.3144\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4713.3989 - accuracy: 0.1020 - categorical_crossentropy: 2.3582\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 4713.3701 - accuracy: 0.0900 - categorical_crossentropy: 2.3280\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4713.4004 - accuracy: 0.0987 - categorical_crossentropy: 2.3579\n",
            "1/1 [==============================] - 0s 24ms/step - loss: 4713.3540 - accuracy: 0.1400 - categorical_crossentropy: 2.3123\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4713.3887 - accuracy: 0.1067 - categorical_crossentropy: 2.3462\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 4713.3521 - accuracy: 0.1300 - categorical_crossentropy: 2.3102\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4713.3955 - accuracy: 0.1044 - categorical_crossentropy: 2.3533\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 4713.3530 - accuracy: 0.1500 - categorical_crossentropy: 2.3111\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4713.3970 - accuracy: 0.0967 - categorical_crossentropy: 2.3553\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 4713.4048 - accuracy: 0.1200 - categorical_crossentropy: 2.3629\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4713.3950 - accuracy: 0.0982 - categorical_crossentropy: 2.3513\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 4713.4946 - accuracy: 0.1200 - categorical_crossentropy: 2.4526\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4713.4038 - accuracy: 0.1039 - categorical_crossentropy: 2.3602\n",
            "1/1 [==============================] - 9s 9s/step - loss: 47110.4453 - accuracy: 0.1100 - categorical_crossentropy: 2.3105\n",
            "100/100 [==============================] - 1s 5ms/step - loss: 47110.4844 - accuracy: 0.1042 - categorical_crossentropy: 2.3090\n",
            "Tue Apr 25 09:35:26 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   77C    P0    36W /  70W |   8981MiB / 15360MiB |     49%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 43%|████▎     | 3/7 [05:30<07:17, 109.36s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2.0\n",
            "1/1 [==============================] - 2s 2s/step - loss: 3537.2124 - accuracy: 0.1100 - categorical_crossentropy: 2.3741\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3537.1890 - accuracy: 0.1041 - categorical_crossentropy: 2.3504\n",
            "1/1 [==============================] - 0s 24ms/step - loss: 3537.1748 - accuracy: 0.1400 - categorical_crossentropy: 2.3366\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3537.1890 - accuracy: 0.1022 - categorical_crossentropy: 2.3496\n",
            "1/1 [==============================] - 0s 25ms/step - loss: 3537.2693 - accuracy: 0.0600 - categorical_crossentropy: 2.4310\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3537.2065 - accuracy: 0.1060 - categorical_crossentropy: 2.3680\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 3537.1506 - accuracy: 0.0900 - categorical_crossentropy: 2.3121\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3537.1982 - accuracy: 0.1016 - categorical_crossentropy: 2.3600\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 3537.1655 - accuracy: 0.1100 - categorical_crossentropy: 2.3273\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3537.1995 - accuracy: 0.1011 - categorical_crossentropy: 2.3597\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 3537.1521 - accuracy: 0.1400 - categorical_crossentropy: 2.3136\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3537.1863 - accuracy: 0.1076 - categorical_crossentropy: 2.3479\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 3537.1472 - accuracy: 0.1200 - categorical_crossentropy: 2.3086\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3537.1946 - accuracy: 0.1045 - categorical_crossentropy: 2.3550\n",
            "1/1 [==============================] - 0s 25ms/step - loss: 3537.1487 - accuracy: 0.1500 - categorical_crossentropy: 2.3099\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3537.1953 - accuracy: 0.0973 - categorical_crossentropy: 2.3571\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 3537.2000 - accuracy: 0.1200 - categorical_crossentropy: 2.3617\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3537.1921 - accuracy: 0.0990 - categorical_crossentropy: 2.3531\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 3537.2937 - accuracy: 0.1200 - categorical_crossentropy: 2.4554\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3537.2007 - accuracy: 0.1046 - categorical_crossentropy: 2.3620\n",
            "1/1 [==============================] - 9s 9s/step - loss: 35348.3789 - accuracy: 0.1100 - categorical_crossentropy: 2.3081\n",
            "100/100 [==============================] - 1s 5ms/step - loss: 35348.4258 - accuracy: 0.1054 - categorical_crossentropy: 2.3095\n",
            "Tue Apr 25 09:37:17 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   77C    P0    37W /  70W |   8983MiB / 15360MiB |     25%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 57%|█████▋    | 4/7 [07:21<05:29, 110.00s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2.5\n",
            "1/1 [==============================] - 2s 2s/step - loss: 2831.7163 - accuracy: 0.1100 - categorical_crossentropy: 2.3733\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2831.6948 - accuracy: 0.1029 - categorical_crossentropy: 2.3523\n",
            "1/1 [==============================] - 0s 24ms/step - loss: 2831.6780 - accuracy: 0.1400 - categorical_crossentropy: 2.3350\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2831.6948 - accuracy: 0.1039 - categorical_crossentropy: 2.3515\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 2831.7749 - accuracy: 0.0600 - categorical_crossentropy: 2.4316\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2831.7139 - accuracy: 0.1085 - categorical_crossentropy: 2.3700\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 2831.6533 - accuracy: 0.0900 - categorical_crossentropy: 2.3102\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2831.7053 - accuracy: 0.1028 - categorical_crossentropy: 2.3621\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 2831.6699 - accuracy: 0.1100 - categorical_crossentropy: 2.3269\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2831.7051 - accuracy: 0.1013 - categorical_crossentropy: 2.3617\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 2831.6580 - accuracy: 0.1400 - categorical_crossentropy: 2.3151\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2831.6929 - accuracy: 0.1073 - categorical_crossentropy: 2.3498\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 2831.6506 - accuracy: 0.1200 - categorical_crossentropy: 2.3075\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2831.7007 - accuracy: 0.1044 - categorical_crossentropy: 2.3570\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 2831.6523 - accuracy: 0.1500 - categorical_crossentropy: 2.3092\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2831.7024 - accuracy: 0.0964 - categorical_crossentropy: 2.3591\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 2831.7041 - accuracy: 0.1200 - categorical_crossentropy: 2.3610\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2831.6982 - accuracy: 0.0998 - categorical_crossentropy: 2.3550\n",
            "1/1 [==============================] - 0s 18ms/step - loss: 2831.8018 - accuracy: 0.1200 - categorical_crossentropy: 2.4588\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2831.7063 - accuracy: 0.1054 - categorical_crossentropy: 2.3639\n",
            "1/1 [==============================] - 10s 10s/step - loss: 28293.4160 - accuracy: 0.1100 - categorical_crossentropy: 2.3062\n",
            "100/100 [==============================] - 0s 5ms/step - loss: 28293.4238 - accuracy: 0.1053 - categorical_crossentropy: 2.3100\n",
            "Tue Apr 25 09:39:04 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   77C    P0    38W /  70W |   8983MiB / 15360MiB |     55%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 71%|███████▏  | 5/7 [09:07<03:37, 108.90s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "3.0\n",
            "1/1 [==============================] - 2s 2s/step - loss: 2361.5552 - accuracy: 0.1100 - categorical_crossentropy: 2.3727\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2361.5366 - accuracy: 0.1029 - categorical_crossentropy: 2.3544\n",
            "1/1 [==============================] - 0s 24ms/step - loss: 2361.5164 - accuracy: 0.1400 - categorical_crossentropy: 2.3338\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2361.5366 - accuracy: 0.1045 - categorical_crossentropy: 2.3535\n",
            "1/1 [==============================] - 0s 24ms/step - loss: 2361.6150 - accuracy: 0.0600 - categorical_crossentropy: 2.4325\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2361.5552 - accuracy: 0.1092 - categorical_crossentropy: 2.3721\n",
            "1/1 [==============================] - 0s 26ms/step - loss: 2361.4912 - accuracy: 0.1000 - categorical_crossentropy: 2.3088\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2361.5469 - accuracy: 0.1036 - categorical_crossentropy: 2.3642\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 2361.5093 - accuracy: 0.1200 - categorical_crossentropy: 2.3267\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2361.5466 - accuracy: 0.1018 - categorical_crossentropy: 2.3638\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 2361.4993 - accuracy: 0.1400 - categorical_crossentropy: 2.3168\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2361.5344 - accuracy: 0.1086 - categorical_crossentropy: 2.3517\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 2361.4893 - accuracy: 0.1200 - categorical_crossentropy: 2.3066\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2361.5420 - accuracy: 0.1051 - categorical_crossentropy: 2.3590\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 2361.4915 - accuracy: 0.1700 - categorical_crossentropy: 2.3088\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2361.5442 - accuracy: 0.0971 - categorical_crossentropy: 2.3612\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 2361.5432 - accuracy: 0.1200 - categorical_crossentropy: 2.3606\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2361.5393 - accuracy: 0.1003 - categorical_crossentropy: 2.3570\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 2361.6450 - accuracy: 0.1200 - categorical_crossentropy: 2.4624\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2361.5479 - accuracy: 0.1060 - categorical_crossentropy: 2.3659\n",
            "1/1 [==============================] - 9s 9s/step - loss: 23591.8262 - accuracy: 0.1100 - categorical_crossentropy: 2.3045\n",
            "100/100 [==============================] - 1s 5ms/step - loss: 23591.8340 - accuracy: 0.1067 - categorical_crossentropy: 2.3107\n",
            "Tue Apr 25 09:40:53 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   77C    P0    38W /  70W |   8985MiB / 15360MiB |      0%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 86%|████████▌ | 6/7 [10:56<01:48, 108.86s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "3.5\n",
            "1/1 [==============================] - 2s 2s/step - loss: 2025.8572 - accuracy: 0.1100 - categorical_crossentropy: 2.3724\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2025.8420 - accuracy: 0.1027 - categorical_crossentropy: 2.3564\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 2025.8176 - accuracy: 0.1400 - categorical_crossentropy: 2.3329\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2025.8403 - accuracy: 0.1047 - categorical_crossentropy: 2.3555\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 2025.9182 - accuracy: 0.0600 - categorical_crossentropy: 2.4336\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2025.8593 - accuracy: 0.1087 - categorical_crossentropy: 2.3742\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 2025.7924 - accuracy: 0.1000 - categorical_crossentropy: 2.3076\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2025.8511 - accuracy: 0.1046 - categorical_crossentropy: 2.3664\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 2025.8114 - accuracy: 0.1200 - categorical_crossentropy: 2.3266\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2025.8510 - accuracy: 0.1016 - categorical_crossentropy: 2.3658\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 2025.8032 - accuracy: 0.1400 - categorical_crossentropy: 2.3185\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2025.8383 - accuracy: 0.1095 - categorical_crossentropy: 2.3537\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 2025.7908 - accuracy: 0.1200 - categorical_crossentropy: 2.3060\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2025.8451 - accuracy: 0.1052 - categorical_crossentropy: 2.3611\n",
            "1/1 [==============================] - 0s 24ms/step - loss: 2025.7933 - accuracy: 0.1700 - categorical_crossentropy: 2.3087\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2025.8485 - accuracy: 0.0971 - categorical_crossentropy: 2.3633\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 2025.8451 - accuracy: 0.1200 - categorical_crossentropy: 2.3605\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2025.8439 - accuracy: 0.1001 - categorical_crossentropy: 2.3590\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 2025.9507 - accuracy: 0.1200 - categorical_crossentropy: 2.4661\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 2025.8525 - accuracy: 0.1060 - categorical_crossentropy: 2.3679\n",
            "1/1 [==============================] - 9s 9s/step - loss: 20234.8418 - accuracy: 0.1100 - categorical_crossentropy: 2.3031\n",
            "100/100 [==============================] - 0s 5ms/step - loss: 20234.8496 - accuracy: 0.1067 - categorical_crossentropy: 2.3113\n",
            "Tue Apr 25 09:42:43 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   77C    P0    39W /  70W |   8985MiB / 15360MiB |     55%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 7/7 [12:46<00:00, 109.54s/it]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "15\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r  0%|          | 0/7 [00:00<?, ?it/s]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "0.5\n",
            "1/1 [==============================] - 1s 1s/step - loss: 21118.7852 - accuracy: 0.8700 - categorical_crossentropy: 0.4433\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 21120.8086 - accuracy: 0.5578 - categorical_crossentropy: 2.4698\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 21118.7520 - accuracy: 0.8800 - categorical_crossentropy: 0.4102\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 21120.5938 - accuracy: 0.5689 - categorical_crossentropy: 2.2550\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 21118.9941 - accuracy: 0.8400 - categorical_crossentropy: 0.6544\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 21120.9121 - accuracy: 0.5298 - categorical_crossentropy: 2.5673\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 21118.8379 - accuracy: 0.9300 - categorical_crossentropy: 0.4969\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 21120.8027 - accuracy: 0.5576 - categorical_crossentropy: 2.4638\n",
            "1/1 [==============================] - 0s 24ms/step - loss: 21118.5762 - accuracy: 0.9100 - categorical_crossentropy: 0.2361\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 21120.8359 - accuracy: 0.5473 - categorical_crossentropy: 2.4882\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 21118.6523 - accuracy: 0.8500 - categorical_crossentropy: 0.3102\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 21120.7695 - accuracy: 0.5475 - categorical_crossentropy: 2.4266\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 21118.7520 - accuracy: 0.8800 - categorical_crossentropy: 0.4094\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 21120.7793 - accuracy: 0.5532 - categorical_crossentropy: 2.4321\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 21118.9590 - accuracy: 0.8300 - categorical_crossentropy: 0.6175\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 21120.8281 - accuracy: 0.5485 - categorical_crossentropy: 2.4804\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 21118.8516 - accuracy: 0.8600 - categorical_crossentropy: 0.5112\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 21120.7539 - accuracy: 0.5617 - categorical_crossentropy: 2.4086\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 21118.8242 - accuracy: 0.8400 - categorical_crossentropy: 0.4851\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 21120.7969 - accuracy: 0.5494 - categorical_crossentropy: 2.4573\n",
            "1/1 [==============================] - 9s 9s/step - loss: 211183.2656 - accuracy: 0.9800 - categorical_crossentropy: 0.0426\n",
            "100/100 [==============================] - 0s 5ms/step - loss: 211183.3594 - accuracy: 0.6711 - categorical_crossentropy: 1.5595\n",
            "Tue Apr 25 09:44:29 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   77C    P0    41W /  70W |   8985MiB / 15360MiB |     59%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 14%|█▍        | 1/7 [01:46<10:36, 106.12s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1.0\n",
            "1/1 [==============================] - 1s 1s/step - loss: 10560.6025 - accuracy: 0.9300 - categorical_crossentropy: 0.2148\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 10563.0273 - accuracy: 0.5936 - categorical_crossentropy: 2.6394\n",
            "1/1 [==============================] - 0s 24ms/step - loss: 10560.6162 - accuracy: 0.9300 - categorical_crossentropy: 0.2287\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 10562.7783 - accuracy: 0.6044 - categorical_crossentropy: 2.3916\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 10560.8711 - accuracy: 0.8900 - categorical_crossentropy: 0.4828\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 10563.0742 - accuracy: 0.5736 - categorical_crossentropy: 2.6872\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 10560.6396 - accuracy: 0.9400 - categorical_crossentropy: 0.2521\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 10562.9902 - accuracy: 0.5962 - categorical_crossentropy: 2.6032\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 10560.5381 - accuracy: 0.9500 - categorical_crossentropy: 0.1511\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 10563.0195 - accuracy: 0.5857 - categorical_crossentropy: 2.6350\n",
            "1/1 [==============================] - 0s 26ms/step - loss: 10560.5146 - accuracy: 0.9500 - categorical_crossentropy: 0.1275\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 10562.9688 - accuracy: 0.5917 - categorical_crossentropy: 2.5808\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 10560.5889 - accuracy: 0.9700 - categorical_crossentropy: 0.2011\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 10562.9668 - accuracy: 0.5913 - categorical_crossentropy: 2.5793\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 10560.7363 - accuracy: 0.8800 - categorical_crossentropy: 0.3481\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 10563.0303 - accuracy: 0.5902 - categorical_crossentropy: 2.6426\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 10560.5645 - accuracy: 0.9500 - categorical_crossentropy: 0.1765\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 10562.9385 - accuracy: 0.5952 - categorical_crossentropy: 2.5513\n",
            "1/1 [==============================] - 0s 26ms/step - loss: 10560.7207 - accuracy: 0.9200 - categorical_crossentropy: 0.3335\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 10562.9697 - accuracy: 0.5896 - categorical_crossentropy: 2.5754\n",
            "1/1 [==============================] - 8s 8s/step - loss: 105603.8203 - accuracy: 1.0000 - categorical_crossentropy: 0.0103\n",
            "100/100 [==============================] - 0s 5ms/step - loss: 105603.9141 - accuracy: 0.6869 - categorical_crossentropy: 1.7907\n",
            "Tue Apr 25 09:46:16 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   77C    P0    41W /  70W |   8987MiB / 15360MiB |     53%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 29%|██▊       | 2/7 [03:33<08:54, 106.85s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1.5\n",
            "1/1 [==============================] - 1s 1s/step - loss: 7041.3379 - accuracy: 0.9300 - categorical_crossentropy: 0.1363\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7043.9810 - accuracy: 0.6024 - categorical_crossentropy: 2.7778\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 7041.3711 - accuracy: 0.9500 - categorical_crossentropy: 0.1695\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7043.7241 - accuracy: 0.6195 - categorical_crossentropy: 2.5226\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 7041.6211 - accuracy: 0.9300 - categorical_crossentropy: 0.4194\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7044.0249 - accuracy: 0.5923 - categorical_crossentropy: 2.8230\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 7041.3359 - accuracy: 0.9600 - categorical_crossentropy: 0.1341\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7043.9292 - accuracy: 0.6113 - categorical_crossentropy: 2.7265\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 7041.3237 - accuracy: 0.9500 - categorical_crossentropy: 0.1222\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7043.9683 - accuracy: 0.5991 - categorical_crossentropy: 2.7643\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 7041.2925 - accuracy: 0.9500 - categorical_crossentropy: 0.0908\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7043.9194 - accuracy: 0.6074 - categorical_crossentropy: 2.7184\n",
            "1/1 [==============================] - 0s 25ms/step - loss: 7041.3594 - accuracy: 0.9700 - categorical_crossentropy: 0.1577\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7043.8906 - accuracy: 0.6081 - categorical_crossentropy: 2.6900\n",
            "1/1 [==============================] - 0s 25ms/step - loss: 7041.4312 - accuracy: 0.9100 - categorical_crossentropy: 0.2294\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7043.9883 - accuracy: 0.6030 - categorical_crossentropy: 2.7853\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 7041.2988 - accuracy: 0.9700 - categorical_crossentropy: 0.0973\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7043.8794 - accuracy: 0.6128 - categorical_crossentropy: 2.6775\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 7041.4497 - accuracy: 0.9200 - categorical_crossentropy: 0.2486\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7043.8931 - accuracy: 0.6068 - categorical_crossentropy: 2.6911\n",
            "1/1 [==============================] - 8s 8s/step - loss: 70412.0781 - accuracy: 1.0000 - categorical_crossentropy: 0.0043\n",
            "100/100 [==============================] - 1s 5ms/step - loss: 70412.0312 - accuracy: 0.6896 - categorical_crossentropy: 1.9623\n",
            "Tue Apr 25 09:48:04 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   77C    P0    37W /  70W |   8987MiB / 15360MiB |     54%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 43%|████▎     | 3/7 [05:20<07:08, 107.14s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2.0\n",
            "1/1 [==============================] - 1s 1s/step - loss: 5281.7988 - accuracy: 0.9500 - categorical_crossentropy: 0.0837\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 5284.6138 - accuracy: 0.6107 - categorical_crossentropy: 2.8982\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 5281.8398 - accuracy: 0.9700 - categorical_crossentropy: 0.1244\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 5284.3521 - accuracy: 0.6243 - categorical_crossentropy: 2.6359\n",
            "1/1 [==============================] - 0s 18ms/step - loss: 5282.0659 - accuracy: 0.9300 - categorical_crossentropy: 0.3509\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 5284.6572 - accuracy: 0.6029 - categorical_crossentropy: 2.9412\n",
            "1/1 [==============================] - 0s 26ms/step - loss: 5281.7969 - accuracy: 0.9700 - categorical_crossentropy: 0.0812\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 5284.5601 - accuracy: 0.6218 - categorical_crossentropy: 2.8449\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 5281.8193 - accuracy: 0.9700 - categorical_crossentropy: 0.1041\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 5284.5913 - accuracy: 0.6103 - categorical_crossentropy: 2.8773\n",
            "1/1 [==============================] - 0s 28ms/step - loss: 5281.7817 - accuracy: 0.9700 - categorical_crossentropy: 0.0663\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 5284.5508 - accuracy: 0.6150 - categorical_crossentropy: 2.8357\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 5281.8604 - accuracy: 0.9800 - categorical_crossentropy: 0.1449\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 5284.5093 - accuracy: 0.6175 - categorical_crossentropy: 2.7946\n",
            "1/1 [==============================] - 0s 18ms/step - loss: 5281.9092 - accuracy: 0.9300 - categorical_crossentropy: 0.1938\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 5284.6250 - accuracy: 0.6122 - categorical_crossentropy: 2.9101\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 5281.7773 - accuracy: 0.9800 - categorical_crossentropy: 0.0622\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 5284.5142 - accuracy: 0.6255 - categorical_crossentropy: 2.8002\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 5281.9097 - accuracy: 0.9500 - categorical_crossentropy: 0.1947\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 5284.5107 - accuracy: 0.6137 - categorical_crossentropy: 2.7964\n",
            "1/1 [==============================] - 9s 9s/step - loss: 52817.1445 - accuracy: 1.0000 - categorical_crossentropy: 0.0019\n",
            "100/100 [==============================] - 1s 5ms/step - loss: 52817.1484 - accuracy: 0.6911 - categorical_crossentropy: 2.0913\n",
            "Tue Apr 25 09:49:49 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   77C    P0    35W /  70W |   8987MiB / 15360MiB |     40%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 57%|█████▋    | 4/7 [07:05<05:18, 106.26s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2.5\n",
            "1/1 [==============================] - 1s 1s/step - loss: 4226.1553 - accuracy: 0.9800 - categorical_crossentropy: 0.0592\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4229.0850 - accuracy: 0.6173 - categorical_crossentropy: 2.9886\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 4226.1958 - accuracy: 0.9700 - categorical_crossentropy: 0.1001\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4228.8247 - accuracy: 0.6303 - categorical_crossentropy: 2.7292\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 4226.4082 - accuracy: 0.9300 - categorical_crossentropy: 0.3123\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4229.1328 - accuracy: 0.6068 - categorical_crossentropy: 3.0359\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 4226.1514 - accuracy: 0.9800 - categorical_crossentropy: 0.0555\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4229.0322 - accuracy: 0.6289 - categorical_crossentropy: 2.9358\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 4226.1855 - accuracy: 0.9700 - categorical_crossentropy: 0.0897\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4229.0659 - accuracy: 0.6154 - categorical_crossentropy: 2.9696\n",
            "1/1 [==============================] - 0s 18ms/step - loss: 4226.1450 - accuracy: 0.9900 - categorical_crossentropy: 0.0492\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4229.0200 - accuracy: 0.6212 - categorical_crossentropy: 2.9238\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 4226.2241 - accuracy: 0.9800 - categorical_crossentropy: 0.1284\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4228.9790 - accuracy: 0.6254 - categorical_crossentropy: 2.8822\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 4226.2646 - accuracy: 0.9400 - categorical_crossentropy: 0.1688\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4229.1001 - accuracy: 0.6186 - categorical_crossentropy: 3.0045\n",
            "1/1 [==============================] - 0s 26ms/step - loss: 4226.1426 - accuracy: 0.9800 - categorical_crossentropy: 0.0468\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4228.9956 - accuracy: 0.6285 - categorical_crossentropy: 2.8997\n",
            "1/1 [==============================] - 0s 24ms/step - loss: 4226.2593 - accuracy: 0.9600 - categorical_crossentropy: 0.1637\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4228.9712 - accuracy: 0.6200 - categorical_crossentropy: 2.8744\n",
            "1/1 [==============================] - 9s 9s/step - loss: 42260.9336 - accuracy: 1.0000 - categorical_crossentropy: 9.6759e-04\n",
            "100/100 [==============================] - 1s 5ms/step - loss: 42260.9492 - accuracy: 0.6925 - categorical_crossentropy: 2.1888\n",
            "Tue Apr 25 09:51:36 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   77C    P0    35W /  70W |   8989MiB / 15360MiB |     37%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 71%|███████▏  | 5/7 [08:53<03:33, 106.65s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "3.0\n",
            "1/1 [==============================] - 1s 1s/step - loss: 3522.4443 - accuracy: 0.9800 - categorical_crossentropy: 0.0396\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3525.4749 - accuracy: 0.6219 - categorical_crossentropy: 3.0701\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 3522.4668 - accuracy: 0.9700 - categorical_crossentropy: 0.0621\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3525.2195 - accuracy: 0.6332 - categorical_crossentropy: 2.8150\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 3522.6980 - accuracy: 0.9300 - categorical_crossentropy: 0.2931\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3525.5317 - accuracy: 0.6129 - categorical_crossentropy: 3.1279\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 3522.4524 - accuracy: 0.9900 - categorical_crossentropy: 0.0476\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3525.4338 - accuracy: 0.6332 - categorical_crossentropy: 3.0285\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 3522.4893 - accuracy: 0.9700 - categorical_crossentropy: 0.0847\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3525.4604 - accuracy: 0.6194 - categorical_crossentropy: 3.0553\n",
            "1/1 [==============================] - 0s 24ms/step - loss: 3522.4502 - accuracy: 0.9900 - categorical_crossentropy: 0.0457\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3525.4050 - accuracy: 0.6266 - categorical_crossentropy: 2.9997\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 3522.5195 - accuracy: 0.9800 - categorical_crossentropy: 0.1150\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3525.3738 - accuracy: 0.6280 - categorical_crossentropy: 2.9696\n",
            "1/1 [==============================] - 0s 24ms/step - loss: 3522.5505 - accuracy: 0.9400 - categorical_crossentropy: 0.1460\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3525.4971 - accuracy: 0.6225 - categorical_crossentropy: 3.0914\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 3522.4365 - accuracy: 0.9800 - categorical_crossentropy: 0.0318\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3525.3989 - accuracy: 0.6312 - categorical_crossentropy: 2.9940\n",
            "1/1 [==============================] - 0s 24ms/step - loss: 3522.5608 - accuracy: 0.9600 - categorical_crossentropy: 0.1559\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3525.3601 - accuracy: 0.6234 - categorical_crossentropy: 2.9550\n",
            "1/1 [==============================] - 8s 8s/step - loss: 35224.0312 - accuracy: 1.0000 - categorical_crossentropy: 4.3785e-04\n",
            "100/100 [==============================] - 0s 5ms/step - loss: 35224.0039 - accuracy: 0.6940 - categorical_crossentropy: 2.2720\n",
            "Tue Apr 25 09:53:22 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   77C    P0    38W /  70W |   8989MiB / 15360MiB |     55%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 86%|████████▌ | 6/7 [10:39<01:46, 106.44s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "3.5\n",
            "1/1 [==============================] - 1s 1s/step - loss: 3019.8442 - accuracy: 0.9800 - categorical_crossentropy: 0.0289\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3022.9604 - accuracy: 0.6253 - categorical_crossentropy: 3.1448\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 3019.8608 - accuracy: 0.9800 - categorical_crossentropy: 0.0453\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3022.7090 - accuracy: 0.6372 - categorical_crossentropy: 2.8933\n",
            "1/1 [==============================] - 0s 25ms/step - loss: 3020.0886 - accuracy: 0.9400 - categorical_crossentropy: 0.2735\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3023.0234 - accuracy: 0.6170 - categorical_crossentropy: 3.2081\n",
            "1/1 [==============================] - 0s 25ms/step - loss: 3019.8638 - accuracy: 0.9900 - categorical_crossentropy: 0.0483\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3022.9307 - accuracy: 0.6372 - categorical_crossentropy: 3.1146\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 3019.8989 - accuracy: 0.9700 - categorical_crossentropy: 0.0837\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3022.9534 - accuracy: 0.6240 - categorical_crossentropy: 3.1384\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 3019.8630 - accuracy: 0.9800 - categorical_crossentropy: 0.0477\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3022.8909 - accuracy: 0.6311 - categorical_crossentropy: 3.0762\n",
            "1/1 [==============================] - 0s 26ms/step - loss: 3019.9167 - accuracy: 0.9800 - categorical_crossentropy: 0.1015\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3022.8550 - accuracy: 0.6308 - categorical_crossentropy: 3.0398\n",
            "1/1 [==============================] - 0s 26ms/step - loss: 3019.9409 - accuracy: 0.9500 - categorical_crossentropy: 0.1255\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3022.9871 - accuracy: 0.6262 - categorical_crossentropy: 3.1714\n",
            "1/1 [==============================] - 0s 25ms/step - loss: 3019.8442 - accuracy: 0.9900 - categorical_crossentropy: 0.0287\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3022.8977 - accuracy: 0.6334 - categorical_crossentropy: 3.0820\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 3019.9436 - accuracy: 0.9600 - categorical_crossentropy: 0.1285\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3022.8406 - accuracy: 0.6276 - categorical_crossentropy: 3.0244\n",
            "1/1 [==============================] - 8s 8s/step - loss: 30198.1445 - accuracy: 1.0000 - categorical_crossentropy: 2.6410e-04\n",
            "100/100 [==============================] - 1s 5ms/step - loss: 30198.1660 - accuracy: 0.6974 - categorical_crossentropy: 2.3468\n",
            "Tue Apr 25 09:55:05 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   75C    P0    36W /  70W |   8991MiB / 15360MiB |     32%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 7/7 [12:22<00:00, 106.09s/it]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "15\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r  0%|          | 0/7 [00:00<?, ?it/s]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "0.5\n",
            "1/1 [==============================] - 1s 1s/step - loss: 21165.5371 - accuracy: 0.9800 - categorical_crossentropy: 0.1516\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 21189.8066 - accuracy: 0.6174 - categorical_crossentropy: 24.4193\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 21165.5039 - accuracy: 0.9800 - categorical_crossentropy: 0.1187\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 21187.7617 - accuracy: 0.6292 - categorical_crossentropy: 22.3726\n",
            "1/1 [==============================] - 0s 18ms/step - loss: 21166.6621 - accuracy: 0.9500 - categorical_crossentropy: 1.2770\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 21190.5918 - accuracy: 0.6093 - categorical_crossentropy: 25.2034\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 21165.5488 - accuracy: 0.9900 - categorical_crossentropy: 0.1627\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 21189.7285 - accuracy: 0.6250 - categorical_crossentropy: 24.3402\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 21166.0039 - accuracy: 0.9700 - categorical_crossentropy: 0.6202\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 21190.0898 - accuracy: 0.6124 - categorical_crossentropy: 24.6975\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 21166.1113 - accuracy: 0.9400 - categorical_crossentropy: 0.7242\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 21189.4082 - accuracy: 0.6248 - categorical_crossentropy: 24.0242\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 21166.1348 - accuracy: 0.9800 - categorical_crossentropy: 0.7502\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 21189.2500 - accuracy: 0.6199 - categorical_crossentropy: 23.8702\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 21166.3965 - accuracy: 0.9500 - categorical_crossentropy: 1.0129\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 21190.2754 - accuracy: 0.6143 - categorical_crossentropy: 24.8871\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 21165.9336 - accuracy: 0.9700 - categorical_crossentropy: 0.5510\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 21189.4629 - accuracy: 0.6246 - categorical_crossentropy: 24.0773\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 21166.2246 - accuracy: 0.9700 - categorical_crossentropy: 0.8397\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 21189.0918 - accuracy: 0.6230 - categorical_crossentropy: 23.7108\n",
            "1/1 [==============================] - 9s 9s/step - loss: 211653.8906 - accuracy: 1.0000 - categorical_crossentropy: 0.0000e+00\n",
            "100/100 [==============================] - 0s 5ms/step - loss: 211653.9688 - accuracy: 0.6931 - categorical_crossentropy: 18.0656\n",
            "Tue Apr 25 09:56:51 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   75C    P0    38W /  70W |   8991MiB / 15360MiB |     56%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 14%|█▍        | 1/7 [01:45<10:34, 105.81s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1.0\n",
            "1/1 [==============================] - 1s 1s/step - loss: 10621.7871 - accuracy: 0.9800 - categorical_crossentropy: 0.1470\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 10649.3721 - accuracy: 0.6356 - categorical_crossentropy: 27.7310\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 10621.6396 - accuracy: 1.0000 - categorical_crossentropy: 0.0000e+00\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 10647.4912 - accuracy: 0.6445 - categorical_crossentropy: 25.8519\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 10622.6562 - accuracy: 0.9500 - categorical_crossentropy: 1.0161\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 10649.9863 - accuracy: 0.6206 - categorical_crossentropy: 28.3474\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 10621.7227 - accuracy: 0.9900 - categorical_crossentropy: 0.0835\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 10649.3252 - accuracy: 0.6347 - categorical_crossentropy: 27.6877\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 10622.5322 - accuracy: 0.9700 - categorical_crossentropy: 0.8925\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 10649.5342 - accuracy: 0.6304 - categorical_crossentropy: 27.8957\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 10622.2949 - accuracy: 0.9400 - categorical_crossentropy: 0.6548\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 10648.6846 - accuracy: 0.6375 - categorical_crossentropy: 27.0474\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 10621.9473 - accuracy: 0.9900 - categorical_crossentropy: 0.3072\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 10648.5059 - accuracy: 0.6354 - categorical_crossentropy: 26.8646\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 10621.9326 - accuracy: 0.9700 - categorical_crossentropy: 0.2931\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 10649.8926 - accuracy: 0.6300 - categorical_crossentropy: 28.2517\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 10622.3438 - accuracy: 0.9800 - categorical_crossentropy: 0.7047\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 10649.0791 - accuracy: 0.6366 - categorical_crossentropy: 27.4363\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 10622.0059 - accuracy: 0.9800 - categorical_crossentropy: 0.3667\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 10648.1553 - accuracy: 0.6372 - categorical_crossentropy: 26.5181\n",
            "1/1 [==============================] - 8s 8s/step - loss: 106216.4531 - accuracy: 1.0000 - categorical_crossentropy: 0.0000e+00\n",
            "100/100 [==============================] - 0s 5ms/step - loss: 106216.3906 - accuracy: 0.6975 - categorical_crossentropy: 21.0807\n",
            "Tue Apr 25 09:58:37 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   74C    P0    37W /  70W |   8991MiB / 15360MiB |      0%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 29%|██▊       | 2/7 [03:31<08:49, 105.93s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1.5\n",
            "1/1 [==============================] - 1s 1s/step - loss: 7109.9805 - accuracy: 1.0000 - categorical_crossentropy: 0.0050\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7138.8154 - accuracy: 0.6399 - categorical_crossentropy: 28.8399\n",
            "1/1 [==============================] - 0s 27ms/step - loss: 7109.9756 - accuracy: 1.0000 - categorical_crossentropy: 0.0000e+00\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7136.9541 - accuracy: 0.6488 - categorical_crossentropy: 26.9781\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 7111.0605 - accuracy: 0.9500 - categorical_crossentropy: 1.0849\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7139.2412 - accuracy: 0.6260 - categorical_crossentropy: 29.2668\n",
            "1/1 [==============================] - 0s 26ms/step - loss: 7110.0361 - accuracy: 0.9900 - categorical_crossentropy: 0.0608\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7138.7300 - accuracy: 0.6394 - categorical_crossentropy: 28.7533\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 7110.8579 - accuracy: 0.9700 - categorical_crossentropy: 0.8835\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7139.0898 - accuracy: 0.6339 - categorical_crossentropy: 29.1134\n",
            "1/1 [==============================] - 0s 29ms/step - loss: 7110.1299 - accuracy: 0.9800 - categorical_crossentropy: 0.1546\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7138.0586 - accuracy: 0.6417 - categorical_crossentropy: 28.0840\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 7109.9785 - accuracy: 1.0000 - categorical_crossentropy: 0.0033\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7137.9258 - accuracy: 0.6445 - categorical_crossentropy: 27.9510\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 7110.4448 - accuracy: 0.9600 - categorical_crossentropy: 0.4704\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7139.2363 - accuracy: 0.6375 - categorical_crossentropy: 29.2609\n",
            "1/1 [==============================] - 0s 25ms/step - loss: 7110.1748 - accuracy: 0.9900 - categorical_crossentropy: 0.1995\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7138.4629 - accuracy: 0.6418 - categorical_crossentropy: 28.4860\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 7110.2173 - accuracy: 0.9900 - categorical_crossentropy: 0.2429\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 7137.2354 - accuracy: 0.6474 - categorical_crossentropy: 27.2597\n",
            "1/1 [==============================] - 8s 8s/step - loss: 71099.7891 - accuracy: 1.0000 - categorical_crossentropy: 0.0000e+00\n",
            "100/100 [==============================] - 1s 5ms/step - loss: 71099.8594 - accuracy: 0.6978 - categorical_crossentropy: 22.2098\n",
            "Tue Apr 25 10:00:22 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   75C    P0    37W /  70W |   8993MiB / 15360MiB |     57%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 43%|████▎     | 3/7 [05:16<07:01, 105.43s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2.0\n",
            "1/1 [==============================] - 1s 1s/step - loss: 5354.4766 - accuracy: 1.0000 - categorical_crossentropy: 5.8042e-05\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 5383.8237 - accuracy: 0.6422 - categorical_crossentropy: 29.3466\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 5354.4766 - accuracy: 1.0000 - categorical_crossentropy: 1.1921e-09\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 5381.9761 - accuracy: 0.6525 - categorical_crossentropy: 27.4987\n",
            "1/1 [==============================] - 0s 25ms/step - loss: 5356.1118 - accuracy: 0.9500 - categorical_crossentropy: 1.6347\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 5384.3721 - accuracy: 0.6293 - categorical_crossentropy: 29.8955\n",
            "1/1 [==============================] - 0s 24ms/step - loss: 5354.4766 - accuracy: 1.0000 - categorical_crossentropy: 5.6952e-05\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 5383.6558 - accuracy: 0.6460 - categorical_crossentropy: 29.1801\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 5355.6011 - accuracy: 0.9700 - categorical_crossentropy: 1.1241\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 5384.1118 - accuracy: 0.6365 - categorical_crossentropy: 29.6350\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 5354.7373 - accuracy: 0.9700 - categorical_crossentropy: 0.2611\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 5383.1128 - accuracy: 0.6430 - categorical_crossentropy: 28.6356\n",
            "1/1 [==============================] - 0s 26ms/step - loss: 5354.4766 - accuracy: 1.0000 - categorical_crossentropy: 0.0000e+00\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 5383.1221 - accuracy: 0.6466 - categorical_crossentropy: 28.6445\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 5355.5005 - accuracy: 0.9500 - categorical_crossentropy: 1.0237\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 5384.4053 - accuracy: 0.6350 - categorical_crossentropy: 29.9275\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 5354.7798 - accuracy: 0.9900 - categorical_crossentropy: 0.3029\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 5383.7964 - accuracy: 0.6429 - categorical_crossentropy: 29.3190\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 5354.8149 - accuracy: 0.9800 - categorical_crossentropy: 0.3376\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 5382.4771 - accuracy: 0.6454 - categorical_crossentropy: 27.9991\n",
            "1/1 [==============================] - 10s 10s/step - loss: 53544.7656 - accuracy: 1.0000 - categorical_crossentropy: 0.0000e+00\n",
            "100/100 [==============================] - 1s 5ms/step - loss: 53544.8047 - accuracy: 0.6971 - categorical_crossentropy: 22.7353\n",
            "Tue Apr 25 10:02:08 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   75C    P0    37W /  70W |   8993MiB / 15360MiB |     53%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 57%|█████▋    | 4/7 [07:02<05:16, 105.63s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2.5\n",
            "1/1 [==============================] - 1s 1s/step - loss: 4301.1797 - accuracy: 1.0000 - categorical_crossentropy: 0.0029\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4330.8140 - accuracy: 0.6389 - categorical_crossentropy: 29.6367\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 4301.1768 - accuracy: 1.0000 - categorical_crossentropy: 3.7545e-05\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4328.7646 - accuracy: 0.6539 - categorical_crossentropy: 27.5871\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 4303.1021 - accuracy: 0.9500 - categorical_crossentropy: 1.9254\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4331.2520 - accuracy: 0.6295 - categorical_crossentropy: 30.0729\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 4301.3115 - accuracy: 0.9900 - categorical_crossentropy: 0.1348\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4330.3242 - accuracy: 0.6470 - categorical_crossentropy: 29.1474\n",
            "1/1 [==============================] - 0s 19ms/step - loss: 4302.2197 - accuracy: 0.9700 - categorical_crossentropy: 1.0430\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4330.9180 - accuracy: 0.6358 - categorical_crossentropy: 29.7415\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 4301.5923 - accuracy: 0.9600 - categorical_crossentropy: 0.4154\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4329.9038 - accuracy: 0.6440 - categorical_crossentropy: 28.7270\n",
            "1/1 [==============================] - 0s 24ms/step - loss: 4301.1768 - accuracy: 1.0000 - categorical_crossentropy: 1.6332e-07\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4330.0713 - accuracy: 0.6467 - categorical_crossentropy: 28.8933\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 4301.7139 - accuracy: 0.9600 - categorical_crossentropy: 0.5368\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4331.3862 - accuracy: 0.6366 - categorical_crossentropy: 30.2097\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 4301.5469 - accuracy: 0.9800 - categorical_crossentropy: 0.3698\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4330.6621 - accuracy: 0.6415 - categorical_crossentropy: 29.4842\n",
            "1/1 [==============================] - 0s 18ms/step - loss: 4301.2095 - accuracy: 0.9800 - categorical_crossentropy: 0.0324\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 4329.3325 - accuracy: 0.6434 - categorical_crossentropy: 28.1560\n",
            "1/1 [==============================] - 9s 9s/step - loss: 43011.7539 - accuracy: 1.0000 - categorical_crossentropy: 0.0000e+00\n",
            "100/100 [==============================] - 0s 5ms/step - loss: 43011.7578 - accuracy: 0.6929 - categorical_crossentropy: 23.0426\n",
            "Tue Apr 25 10:03:56 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   76C    P0    37W /  70W |   8993MiB / 15360MiB |     57%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 71%|███████▏  | 5/7 [08:50<03:32, 106.34s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "3.0\n",
            "1/1 [==============================] - 1s 1s/step - loss: 3598.5093 - accuracy: 1.0000 - categorical_crossentropy: 0.0000e+00\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3627.7266 - accuracy: 0.6448 - categorical_crossentropy: 29.2175\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 3598.5425 - accuracy: 0.9900 - categorical_crossentropy: 0.0335\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3626.2097 - accuracy: 0.6525 - categorical_crossentropy: 27.7003\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 3599.8508 - accuracy: 0.9500 - categorical_crossentropy: 1.3417\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3628.3113 - accuracy: 0.6354 - categorical_crossentropy: 29.8017\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 3598.5093 - accuracy: 1.0000 - categorical_crossentropy: 0.0000e+00\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3627.2393 - accuracy: 0.6521 - categorical_crossentropy: 28.7300\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 3598.9419 - accuracy: 0.9700 - categorical_crossentropy: 0.4326\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3628.0361 - accuracy: 0.6385 - categorical_crossentropy: 29.5270\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 3598.6182 - accuracy: 0.9800 - categorical_crossentropy: 0.1092\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3627.0674 - accuracy: 0.6510 - categorical_crossentropy: 28.5585\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 3598.5093 - accuracy: 1.0000 - categorical_crossentropy: 0.0000e+00\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3627.1299 - accuracy: 0.6514 - categorical_crossentropy: 28.6201\n",
            "1/1 [==============================] - 0s 20ms/step - loss: 3598.5742 - accuracy: 0.9900 - categorical_crossentropy: 0.0653\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3628.3025 - accuracy: 0.6424 - categorical_crossentropy: 29.7929\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 3598.6318 - accuracy: 0.9900 - categorical_crossentropy: 0.1225\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3627.7461 - accuracy: 0.6467 - categorical_crossentropy: 29.2357\n",
            "1/1 [==============================] - 0s 24ms/step - loss: 3598.5193 - accuracy: 0.9900 - categorical_crossentropy: 0.0102\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3626.3198 - accuracy: 0.6506 - categorical_crossentropy: 27.8115\n",
            "1/1 [==============================] - 8s 8s/step - loss: 35985.0781 - accuracy: 1.0000 - categorical_crossentropy: 0.0000e+00\n",
            "100/100 [==============================] - 1s 5ms/step - loss: 35985.0547 - accuracy: 0.6955 - categorical_crossentropy: 22.9028\n",
            "Tue Apr 25 10:05:45 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   76C    P0    37W /  70W |   8995MiB / 15360MiB |     10%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 86%|████████▌ | 6/7 [10:39<01:47, 107.33s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "3.5\n",
            "1/1 [==============================] - 1s 1s/step - loss: 3096.3062 - accuracy: 1.0000 - categorical_crossentropy: 0.0000e+00\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3124.9592 - accuracy: 0.6452 - categorical_crossentropy: 28.6536\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 3096.3083 - accuracy: 1.0000 - categorical_crossentropy: 0.0020\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3123.5205 - accuracy: 0.6548 - categorical_crossentropy: 27.2134\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 3097.5322 - accuracy: 0.9400 - categorical_crossentropy: 1.2258\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3125.4272 - accuracy: 0.6390 - categorical_crossentropy: 29.1209\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 3096.3062 - accuracy: 1.0000 - categorical_crossentropy: 0.0000e+00\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3124.4431 - accuracy: 0.6562 - categorical_crossentropy: 28.1372\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 3097.0742 - accuracy: 0.9700 - categorical_crossentropy: 0.7678\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3125.1277 - accuracy: 0.6437 - categorical_crossentropy: 28.8216\n",
            "1/1 [==============================] - 0s 30ms/step - loss: 3096.4053 - accuracy: 0.9900 - categorical_crossentropy: 0.0991\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3124.2102 - accuracy: 0.6539 - categorical_crossentropy: 27.9036\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 3096.3062 - accuracy: 1.0000 - categorical_crossentropy: 0.0000e+00\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3124.1252 - accuracy: 0.6553 - categorical_crossentropy: 27.8192\n",
            "1/1 [==============================] - 0s 18ms/step - loss: 3096.3462 - accuracy: 0.9900 - categorical_crossentropy: 0.0399\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3125.5503 - accuracy: 0.6453 - categorical_crossentropy: 29.2439\n",
            "1/1 [==============================] - 0s 18ms/step - loss: 3096.5330 - accuracy: 0.9900 - categorical_crossentropy: 0.2268\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3124.9880 - accuracy: 0.6501 - categorical_crossentropy: 28.6808\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 3096.3081 - accuracy: 1.0000 - categorical_crossentropy: 0.0018\n",
            "100/100 [==============================] - 0s 3ms/step - loss: 3123.4578 - accuracy: 0.6534 - categorical_crossentropy: 27.1507\n",
            "1/1 [==============================] - 9s 9s/step - loss: 30963.0547 - accuracy: 1.0000 - categorical_crossentropy: 0.0000e+00\n",
            "100/100 [==============================] - 0s 5ms/step - loss: 30963.0195 - accuracy: 0.6964 - categorical_crossentropy: 22.3851\n",
            "Tue Apr 25 10:07:30 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   76C    P0    36W /  70W |   8995MiB / 15360MiB |     40%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 7/7 [12:24<00:00, 106.42s/it]\n"
          ]
        }
      ],
      "source": [
        "model_variational=[]\n",
        "for prior_scale in prior_scales: \n",
        "  for gamma in gammas:\n",
        "    for label_noise in label_noises:\n",
        "      for da in d_augmentation:\n",
        "\n",
        "        setting_name=f\"prior_scale_{prior_scale}_label_noise_{label_noise}_smooth_softmax_{gamma}_data_augmentation_{da}_MLP_100\" \n",
        "\n",
        "        x_train = x_train_mnist\n",
        "        y_train = y_train_mnist\n",
        "        x_test = x_test_mnist\n",
        "        y_test = y_test_mnist\n",
        "\n",
        "        input_shape=x_train[0].shape\n",
        "\n",
        "        if label_noise > 0.0:\n",
        "          import random\n",
        "          # train\n",
        "          ind_noisy=random.sample(range(1, y_train.shape[0]), int(y_train.shape[0]*label_noise))\n",
        "          for ind in ind_noisy:\n",
        "            false_label=random.sample(range(1,y_train.shape[1]),1)[0]\n",
        "            while y_train[ind, false_label] != 1.0:\n",
        "              temp=np.zeros((y_train.shape[1]))\n",
        "              temp[false_label]=1.0\n",
        "              y_train[ind, :]=temp\n",
        "              break\n",
        "            else:\n",
        "              false_label=random.sample(range(1,y_train.shape[1]),1)[0]\n",
        "          # test\n",
        "          ind_noisy=random.sample(range(1, y_test.shape[0]), int(y_test.shape[0]*label_noise))\n",
        "          for ind in ind_noisy:\n",
        "            false_label=random.sample(range(1,y_test.shape[1]),1)[0]\n",
        "            while y_test[ind, false_label] != 1.0:\n",
        "              temp=np.zeros((y_test.shape[1]))\n",
        "              temp[false_label]=1.0\n",
        "              y_test[ind, :]=temp\n",
        "              break\n",
        "            else:\n",
        "              false_label=random.sample(range(1,y_test.shape[1]),1)[0]\n",
        "\n",
        "        from tensorflow import keras\n",
        "        from tensorflow.keras import layers\n",
        "\n",
        "        if da == \"None\":\n",
        "\n",
        "            train_loader = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
        "            train_loader = train_loader.batch(batch_size)\n",
        "\n",
        "        if da == \"Standard\":\n",
        "            # Create a data augmentation stage with horizontal flipping, rotations, zooms\n",
        "            data_augmentation = keras.Sequential(\n",
        "                [\n",
        "                    layers.RandomFlip(\"horizontal\"),\n",
        "                    layers.RandomRotation(0.1),\n",
        "                    layers.RandomZoom(0.1),\n",
        "                ]\n",
        "            )\n",
        "            # Create a tf.data pipeline of augmented images (and their labels)\n",
        "            train_loader = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
        "            train_loader = train_loader.batch(batch_size).map(lambda x, y: (data_augmentation(x), y))\n",
        "\n",
        "\n",
        "        if da == \"Noise\":\n",
        "\n",
        "            rng = np.random.default_rng(1234)\n",
        "            n_features = x_train.shape[1]*x_train.shape[2]*x_train.shape[3]\n",
        "            idx = np.random.choice(np.arange(0, n_features, 1), n_features // 5 * 4, replace = False)\n",
        "            perm = rng.permutation(idx)\n",
        "\n",
        "            aux =  x_train.reshape(x_train.shape[0], -1)\n",
        "            aux[:, idx] = aux[:, perm]\n",
        "\n",
        "            x_train = aux.reshape(x_train.shape[0], x_train.shape[1], x_train.shape[2], x_train.shape[3])\n",
        "\n",
        "            # Create a tf.data pipeline of augmented images (and their labels)\n",
        "            train_loader = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
        "            train_loader = train_loader.batch(batch_size)\n",
        "\n",
        "        if da == \"Replacement\":\n",
        "\n",
        "            rng = np.random.default_rng(1234)\n",
        "\n",
        "            idx = np.random.choice(np.arange(0, x_train.shape[0], 1), int(x_train.shape[0]*0.05), replace= False)\n",
        "            idx = np.random.choice(idx, x_train.shape[0], replace= True)\n",
        "\n",
        "            # Create a tf.data pipeline of augmented images (and their labels)\n",
        "            train_loader = tf.data.Dataset.from_tensor_slices((x_train[idx], y_train[idx]))\n",
        "            train_loader = train_loader.batch(batch_size)\n",
        "        \n",
        "        \n",
        "        input_shape = x_train.shape[1:]\n",
        "        results={}\n",
        "\n",
        "        for seed in seeds:\n",
        "          print(seed)\n",
        "          results[seed]={} # a dict, with the structure, dict_results[lamb]=[log_ps_train, log_ps_test, metrics_bma]\n",
        "          for lamb in tqdm(lambs):\n",
        "            \n",
        "            print(lamb)\n",
        "            results[seed][lamb]=[]\n",
        "\n",
        "            tf.keras.utils.set_random_seed(seed)\n",
        "\n",
        "            #model_variational = LeNet_Variational_Large(input_shape, num_classes,\n",
        "            #                                          kernel_prior_fn=make_normal_prior(prior_scale),\n",
        "            #                                          kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p)/(x_train.shape[0]*lamb),\n",
        "            #                                          gamma=gamma)\n",
        "\n",
        "            model_variational = MLP_Variational(input_shape, num_classes,512,\n",
        "                                                      kernel_prior_fn=make_normal_prior(prior_scale),\n",
        "                                                      kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p)/(x_train.shape[0]*lamb),\n",
        "                                                      gamma=gamma)\n",
        "            \n",
        "            model_variational.compile(optimizer=tf.keras.optimizers.Adam(lr),\n",
        "                            loss = cce,\n",
        "                            metrics=['accuracy', cce]\n",
        "                            )\n",
        "\n",
        "            model_variational.fit(train_loader, \n",
        "                    epochs=n_epochs, \n",
        "                    #validation_data = (x_test, y_test),\n",
        "                    # workers=8,\n",
        "                    verbose = 0)\n",
        "            #model_variational.summary()\n",
        "\n",
        "\n",
        "            log_ps_train=[]\n",
        "            log_ps_test=[]\n",
        "            metrics_gibbs_train=[]\n",
        "            metrics_gibbs_test=[]\n",
        "            metrics_bayes_train=[]\n",
        "            metrics_bayes_test=[]\n",
        "\n",
        "            # loop through posterior samples\n",
        "            for i in range(num_posterior_samples):\n",
        "              \n",
        "              # compute the categorical dist. and log_p for train data\n",
        "              x_train_augmented = np.concatenate([x for x, y in train_loader.as_numpy_iterator()])\n",
        "              y_train_augmented = np.concatenate([y for x, y in train_loader.as_numpy_iterator()])\n",
        "              p_train=tfp.distributions.Categorical(logits=model_variational(x_train_augmented))\n",
        "              log_p_train=p_train.log_prob(tf.argmax(y_train_augmented, axis=1))\n",
        "\n",
        "              log_ps_train.append(log_p_train)\n",
        "\n",
        "              # compute the categorical dist. and log_p for test data\n",
        "              p_test=tfp.distributions.Categorical(logits=model_variational(x_test)) \n",
        "              log_p_test=p_test.log_prob(tf.argmax(y_test, axis=1))\n",
        "              log_ps_test.append(log_p_test)\n",
        "\n",
        "\n",
        "\n",
        "              # compute gibbs-based metrics\n",
        "              metrics_gibbs_train.append(model_variational.evaluate(x_train, y_train, batch_size=batch_size))\n",
        "              metrics_gibbs_test.append(model_variational.evaluate(x_test, y_test, batch_size=batch_size))\n",
        "\n",
        "            # compute bayes-based metric\n",
        "            bma_model=BMA_Model(model_variational, num_posterior_samples)\n",
        "            bma_model.compile(metrics=['accuracy', cce])\n",
        "            metrics_bayes_train.append(bma_model.evaluate(x_train, y_train, batch_size=batch_size))\n",
        "            metrics_bayes_test.append(bma_model.evaluate(x_test, y_test, batch_size=batch_size))\n",
        "\n",
        "            !nvidia-smi\n",
        "            \n",
        "            # save log_p for train data\n",
        "            results[seed][lamb].append(tf.stack(log_ps_train,axis=1).numpy())\n",
        "            # save log_p for test data \n",
        "            results[seed][lamb].append(tf.stack(log_ps_test,axis=1).numpy())\n",
        "            # save gibbs-based metric\n",
        "            results[seed][lamb].append(metrics_gibbs_train)\n",
        "            results[seed][lamb].append(metrics_gibbs_test)\n",
        "            # save bayes-based metric\n",
        "            results[seed][lamb].append(metrics_bayes_train)\n",
        "            results[seed][lamb].append(metrics_bayes_test)\n",
        "\n",
        "        with open(f'{setting_name}.pickle', 'wb') as handle:\n",
        "          pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MjOWNM9Cla0u",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "# if if_da:\n",
        "#   from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
        "#   datagen = ImageDataGenerator(\n",
        "#           rotation_range=10,  # Rotate the image randomly by up to 10 degrees\n",
        "#           zoom_range=0.1,     # Zoom in or out of the image randomly by up to 10%\n",
        "#           width_shift_range=0.1,  # Shift the image horizontally by up to 10%\n",
        "#           height_shift_range=0.1, # Shift the image vertically by up to 10%\n",
        "#           )\n",
        "#   datagen.fit(x_train)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TB60tuuQzZ3q",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "# if if_da:\n",
        "#   from tensorflow import keras\n",
        "#   from tensorflow.keras import layers\n",
        "\n",
        "#   # Create a data augmentation stage with horizontal flipping, rotations, zooms\n",
        "#   data_augmentation = keras.Sequential(\n",
        "#       [\n",
        "#           layers.RandomFlip(\"horizontal\"),\n",
        "#           layers.RandomRotation(0.1),\n",
        "#           layers.RandomZoom(0.1),\n",
        "#       ]\n",
        "#   )\n",
        "\n",
        "#   input_shape = x_train.shape[1:]\n",
        "#   classes = 10\n",
        "\n",
        "#   # Create a tf.data pipeline of augmented images (and their labels)\n",
        "#   train_augmented = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
        "#   train_augmented = train_augmented.batch(batch_size).map(lambda x, y: (data_augmentation(x), y))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7bAlp9m5Zg9q",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "# !nvidia-smi"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KJEY_ldvaQ83",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "# files.download('prior_scale_1.0_label_noise_0.0_smooth_softmax_3.0_data_augmentation_False.pickle') "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "itHnSRJmgOXp",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "# from google.colab import drive\n",
        "# drive.mount('/content/drive')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sqmjiCxDXxUO",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "# from google.colab import runtime\n",
        "# runtime.unassign()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "include_colab_link": true,
      "machine_shape": "hm",
      "provenance": []
    },
    "gpuClass": "premium",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
