{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [
        {
          "file_id": "16iux6HSF_fav3r88rMsMFkUglLA3MYSj",
          "timestamp": 1674663198394
        },
        {
          "file_id": "1kypcF3svum8PeZelkYZlDj2prLjpbV5o",
          "timestamp": 1665423411412
        },
        {
          "file_id": "1XTDIiAvYlr3PgACkrcYcM7ARUi1bulC0",
          "timestamp": 1665135649078
        },
        {
          "file_id": "1CbVFp70yBaRE8id2tBSsAgF2gTO9DkRo",
          "timestamp": 1660275155859
        }
      ],
      "last_runtime": {
        "build_target": "XXXX",
        "kind": "private"
      },
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "tf.random.set_seed(103847532)"
      ],
      "metadata": {
        "id": "YBKtIwH996tK",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1674703841785,
          "user_tz": 300,
          "elapsed": 1,
          "user": {
            "displayName": "XXX",
            "userId": "0000000"
          }
        }
      },
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Model Definitions"
      ],
      "metadata": {
        "id": "BwiCChDUzGrS"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "tlArshXvkqEv",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1674703841786,
          "user_tz": 300,
          "elapsed": 2,
          "user": {
            "displayName": "XXX",
            "userId": "0000"
          }
        }
      },
      "outputs": [],
      "source": [
        "\"\"\"A collection of Neural Net models to be used for experiments.\"\"\"\n",
        "\n",
        "from typing import Callable, Optional\n",
        "\n",
        "from absl import app\n",
        "\n",
        "\n",
        "def _get_mobilenet(depth_multiplier, num_classes, width_multiplier=1.0):\n",
        "\n",
        "  \"\"\"Loads mobilenet.\n",
        "\n",
        "  Args:\n",
        "    depth_multiplier: the depth_multiplier parameter for mobilenet\n",
        "    num_classes: the number of classes\n",
        "    width_multiplier: the width_multiplier for mobilenet\n",
        "\n",
        "  Returns:\n",
        "    Returns the mobilenet model.\n",
        "  \"\"\"\n",
        "  model = tf.keras.applications.mobilenet.MobileNet(\n",
        "      input_shape=(32, 32, 3),\n",
        "      alpha=width_multiplier,\n",
        "      depth_multiplier=depth_multiplier,\n",
        "      dropout=0.001,\n",
        "      include_top=True,\n",
        "      weights=None,\n",
        "      input_tensor=None,\n",
        "      pooling=None,\n",
        "      classes=num_classes,\n",
        "      classifier_activation='softmax')\n",
        "  return model\n",
        "\n",
        "\n",
        "# Implementation of ResNet models that can be found in\n",
        "# https://keras.io/zh/examples/cifar10_resnet/\n",
        "def _resnet_layer(inputs,\n",
        "                  num_filters=16,\n",
        "                  kernel_size=3,\n",
        "                  strides=1,\n",
        "                  activation='relu',\n",
        "                  batch_normalization=True,\n",
        "                  conv_first=True):\n",
        "  \"\"\"2D Convolution-Batch Normalization-Activation stack builder.\n",
        "\n",
        "  Args:\n",
        "    inputs: input tensor from input image or previous layer\n",
        "    num_filters: Conv2D number of filters\n",
        "    kernel_size: Conv2D square kernel dimensions\n",
        "    strides: Conv2D square stride dimensions\n",
        "    activation: activation name\n",
        "    batch_normalization: whether to include batch normalization\n",
        "    conv_first: conv-bn-activation (True) or bn-activation-conv (False)\n",
        "\n",
        "  Returns:\n",
        "   x: tensor as input to the next layer\n",
        "  \"\"\"\n",
        "\n",
        "  conv = tf.keras.layers.Conv2D(\n",
        "      num_filters,\n",
        "      kernel_size=kernel_size,\n",
        "      strides=strides,\n",
        "      padding='same',\n",
        "      kernel_initializer='he_normal',\n",
        "      kernel_regularizer=tf.keras.regularizers.l2(1e-4))\n",
        "  x = inputs\n",
        "  if conv_first:\n",
        "    x = conv(x)\n",
        "    if batch_normalization:\n",
        "      x = tf.keras.layers.BatchNormalization()(x)\n",
        "    if activation is not None:\n",
        "      x = tf.keras.layers.Activation(activation)(x)\n",
        "  else:\n",
        "    if batch_normalization:\n",
        "      x = tf.keras.layers.BatchNormalization()(x)\n",
        "    if activation is not None:\n",
        "      x = tf.keras.layers.Activation(activation)(x)\n",
        "    x = conv(x)\n",
        "  return x\n",
        "\n",
        "\n",
        "def _resnet_v2(\n",
        "    input_shape=(32, 32, 3), depth=29, num_classes=10, data_augmentation=False):\n",
        "\n",
        "  \"\"\"ResNet Version 2 Model builder.\n",
        "\n",
        "  Args:\n",
        "    input_shape: shape of input image tensor\n",
        "    depth: number of core convolutional layers\n",
        "    num_classes: number of classes\n",
        "    data_augmentation: A boolean variable that determines whether we use data\n",
        "      augmentation or not\n",
        "\n",
        "  Returns:\n",
        "    model (Model): Keras model instance\n",
        "  \"\"\"\n",
        "\n",
        "  if (depth - 2) % 9 != 0:\n",
        "    raise ValueError('depth should be 9n+2 (eg 56 or 110 in [b])')\n",
        "\n",
        "  # Start model definition.\n",
        "  num_filters_in = 16\n",
        "  num_res_blocks = int((depth - 2) / 9)\n",
        "  inputs = tf.keras.layers.Input(shape=input_shape)\n",
        "\n",
        "  # v2 performs Conv2D with BN-ReLU on input before splitting into 2 paths\n",
        "  if data_augmentation:\n",
        "    data_augmentation_module = tf.keras.Sequential([\n",
        "        tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),\n",
        "        tf.keras.layers.experimental.preprocessing.RandomTranslation(\n",
        "            3. / 32, 3. / 32),\n",
        "    ])\n",
        "    x = data_augmentation_module(inputs)\n",
        "    x = _resnet_layer(\n",
        "        inputs=x, num_filters=num_filters_in, conv_first=True)\n",
        "  else:\n",
        "    x = _resnet_layer(\n",
        "        inputs=inputs, num_filters=num_filters_in, conv_first=True)\n",
        "\n",
        "  # Instantiate the stack of residual units\n",
        "  for stage in range(3):\n",
        "    for res_block in range(num_res_blocks):\n",
        "      activation = 'relu'\n",
        "      batch_normalization = True\n",
        "      strides = 1\n",
        "      if stage == 0:\n",
        "        num_filters_out = num_filters_in * 4\n",
        "        if res_block == 0:  # first layer and first stage\n",
        "          activation = None\n",
        "          batch_normalization = False\n",
        "      else:\n",
        "        num_filters_out = num_filters_in * 2\n",
        "        if res_block == 0:  # first layer but not first stage\n",
        "          strides = 2  # downsample\n",
        "\n",
        "      # bottleneck residual unit\n",
        "      y = _resnet_layer(\n",
        "          inputs=x,\n",
        "          num_filters=num_filters_in,\n",
        "          kernel_size=1,\n",
        "          strides=strides,\n",
        "          activation=activation,\n",
        "          batch_normalization=batch_normalization,\n",
        "          conv_first=False)\n",
        "      y = _resnet_layer(inputs=y, num_filters=num_filters_in, conv_first=False)\n",
        "      y = _resnet_layer(\n",
        "          inputs=y,\n",
        "          num_filters=num_filters_out,\n",
        "          kernel_size=1,\n",
        "          conv_first=False)\n",
        "      if res_block == 0:\n",
        "        # linear projection residual shortcut connection to match\n",
        "        # changed dims\n",
        "        x = _resnet_layer(\n",
        "            inputs=x,\n",
        "            num_filters=num_filters_out,\n",
        "            kernel_size=1,\n",
        "            strides=strides,\n",
        "            activation=None,\n",
        "            batch_normalization=False)\n",
        "      x = tf.keras.layers.add([x, y])\n",
        "    num_filters_in = num_filters_out\n",
        "\n",
        "  # Add classifier on top.\n",
        "  # v2 has BN-ReLU before Pooling\n",
        "  x = tf.keras.layers.BatchNormalization()(x)\n",
        "  x = tf.keras.layers.Activation('relu')(x)\n",
        "  x = tf.keras.layers.AveragePooling2D(pool_size=8)(x)\n",
        "  y = tf.keras.layers.Flatten()(x)\n",
        "  outputs = tf.keras.layers.Dense(\n",
        "      num_classes, activation='softmax', kernel_initializer='he_normal')(\n",
        "          y)\n",
        "\n",
        "  # Instantiate model.\n",
        "  model = tf.keras.models.Model(inputs=inputs, outputs=outputs)\n",
        "  return model\n",
        "\n",
        "\n",
        "def load_model(model_architecture: str,\n",
        "               num_classes: int,\n",
        "               optimizer_name: str,\n",
        "               loss_function: Optional[Callable[\n",
        "                   ..., tf.Tensor]] = tf.keras.losses.CategoricalCrossentropy(),\n",
        "               learning_rate: Optional[float] = 0.001,\n",
        "               width_multiplier: Optional[float] = 1.0,\n",
        "               depth_multiplier: Optional[int] = 1,\n",
        "               resnet_depth: Optional[int] = 11) -> tf.keras.Model:\n",
        "\n",
        "  \"\"\"Loads a compiled model.\n",
        "\n",
        "  Args:\n",
        "    model_architecture: The name of the model architecture.\n",
        "    num_classes: The number of classes.\n",
        "    optimizer_name: The name of the optimizer we use.\n",
        "    loss_function: The loss function used in distillation.\n",
        "    learning_rate: The initial learning rate for our optimizer.\n",
        "    width_multiplier: The width_multiplier of the base_CNN/mobilenet model.\n",
        "    depth_multiplier: The depth_multiplier of the mobilenet model.\n",
        "    resnet_depth: The depth of the resnet network which must be of the form 9*d\n",
        "      + 2 for some integer d, e.g., 11, 20.\n",
        "  Returns:\n",
        "    A compiled model to be trained.\n",
        "  \"\"\"\n",
        "\n",
        "  optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n",
        "\n",
        "  if model_architecture == 'mobilenet':\n",
        "    model = _get_mobilenet(\n",
        "        depth_multiplier=depth_multiplier,\n",
        "        num_classes=num_classes,\n",
        "        width_multiplier=width_multiplier)\n",
        "  elif model_architecture == 'resnet':\n",
        "    model = _resnet_v2(depth=resnet_depth, num_classes=num_classes)\n",
        "  else:\n",
        "    raise app.UsageError('Invalid argument to --model')\n",
        "\n",
        "  if optimizer_name == 'adam':\n",
        "    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n",
        "  elif optimizer_name == 'adagrad':\n",
        "    optimizer = tf.keras.optimizers.Adagrad(learning_rate=learning_rate)\n",
        "  elif optimizer_name == 'SGD':\n",
        "    optimizer = tf.keras.optimizers.SGD(\n",
        "        learning_rate=learning_rate, momentum=0.2, nesterov=False)\n",
        "  else:\n",
        "    print('Default settings: Adam')\n",
        "    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n",
        "\n",
        "  model.compile(\n",
        "      loss=loss_function, optimizer=optimizer, metrics='categorical_accuracy')\n",
        "\n",
        "  return model\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Data Utils"
      ],
      "metadata": {
        "id": "dMwnTVBV7uGY"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "\"\"\"Contains some simple data structures for passing around dataset splits.\"\"\"\n",
        "\n",
        "import dataclasses\n",
        "import datetime as dt\n",
        "from typing import Optional, Tuple\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "@dataclasses.dataclass\n",
        "class LabeledExamples:\n",
        "  \"\"\"A LabeledExample dataclass to make passing dataset splits around easier.\n",
        "\n",
        "    Attributes\n",
        "      ----------\n",
        "      examples: tf.Tensor\n",
        "          A tensor containing the examples (x) of the data.\n",
        "      labels: tf.Tensor\n",
        "          A tensor containing the labels (y) of the data.\n",
        "      size: int\n",
        "          The number of examples.\n",
        "      num_classes: int\n",
        "          The number of different labels.\n",
        "      trainable: bool\n",
        "          If true this dataset split is ok to train on (e.g. should be set to\n",
        "          False for the test split).\n",
        "  \"\"\"\n",
        "\n",
        "  examples: tf.Tensor\n",
        "  labels: tf.Tensor\n",
        "  size: int\n",
        "  num_classes: int\n",
        "  trainable: bool\n",
        "\n",
        "  def shuffle(self, seed: Optional[int] = None):\n",
        "    \"\"\"Shuffles the order of the examples.\"\"\"\n",
        "\n",
        "    if seed is None:\n",
        "      rand_seed = int(dt.datetime.now().strftime('%f')[:-5])\n",
        "    else:\n",
        "      rand_seed = seed\n",
        "\n",
        "    shuffled_examples = np.random.RandomState(seed=rand_seed).permutation(\n",
        "        self.examples)\n",
        "    shuffled_labels = np.random.RandomState(seed=rand_seed).permutation(\n",
        "        self.labels)\n",
        "    self.examples = shuffled_examples\n",
        "    self.labels = shuffled_labels\n",
        "\n",
        "\n",
        "def split(data: LabeledExamples,\n",
        "          index: int) -> Tuple[LabeledExamples, LabeledExamples]:\n",
        "  \"\"\"Splits a Labeled Examples instance into two parts.\"\"\"\n",
        "\n",
        "  split_a_examples = data.examples[:-index]\n",
        "  spit_a_labels = data.labels[:-index]\n",
        "\n",
        "  split_a = LabeledExamples(\n",
        "      split_a_examples,\n",
        "      spit_a_labels,\n",
        "      size=len(split_a_examples),\n",
        "      num_classes=data.num_classes,\n",
        "      trainable=True)\n",
        "\n",
        "  split_b_examples = data.examples[-index:]\n",
        "  split_b_labels = data.labels[-index:]\n",
        "\n",
        "  split_b = LabeledExamples(\n",
        "      split_b_examples,\n",
        "      split_b_labels,\n",
        "      size=len(split_b_examples),\n",
        "      num_classes=data.num_classes,\n",
        "      trainable=True)\n",
        "  return split_a, split_b\n",
        "\n",
        "\n",
        "@dataclasses.dataclass\n",
        "class DataSplit:\n",
        "  \"\"\"Class for the dataset splits used in unlabeled distillation.\n",
        "\n",
        "  Attributes\n",
        "    ----------\n",
        "    dataset_a: LabeledExamples\n",
        "        An instance of LabeledExamples containing the (small) set of labeled\n",
        "        training data used to train the teacher.\n",
        "    dataset_b: LabeledExamples\n",
        "        An instance of LabeledExamples containing the (large) set of training\n",
        "        data used to train the teacher.  The labels of these points should *not*\n",
        "        be used to train the student model.\n",
        "    test: LabeledExamples\n",
        "        An instance of LabeledExamples containing the test set for evaluating\n",
        "        the teacher and student models.\n",
        "    validtion: LabeledExamples, Optional\n",
        "        An instance of LabeledExamples containing the validation.\n",
        "\n",
        "  \"\"\"\n",
        "  dataset_a: LabeledExamples\n",
        "  dataset_b: LabeledExamples\n",
        "  test: LabeledExamples\n",
        "  validation: Optional[LabeledExamples] = None\n"
      ],
      "metadata": {
        "id": "4qTmr89n3voa",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1674703841786,
          "user_tz": 300,
          "elapsed": 2,
          "user": {
            "displayName": "XXX",
            "userId": "0000"
          }
        }
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "\"\"\"Methods for loading datasets.\"\"\"\n",
        "\n",
        "from typing import Optional\n",
        "\n",
        "import tensorflow_datasets as tfds\n",
        "\n",
        "\n",
        "def _example_label_split(dataset):\n",
        "  x, y = zip(*dataset)\n",
        "  return tf.convert_to_tensor(x, dtype=tf.float32), tf.convert_to_tensor(y)\n",
        "\n",
        "\n",
        "def _normalize_x(x):\n",
        "  x /= 127.5\n",
        "  x -= 1.0\n",
        "  return x\n",
        "\n",
        "\n",
        "def _normalize_y(y, num_classes):\n",
        "  return tf.keras.utils.to_categorical(y, num_classes)\n",
        "\n",
        "\n",
        "def load_cifar_data(labeled_percentage: Optional[int] = 20,\n",
        "                    num_classes: Optional[int] = 10,\n",
        "                    binary: Optional[bool] = False) -> DataSplit:\n",
        "  \"\"\"Loads cifar10 or cifar100 dataset.\"\"\"\n",
        "\n",
        "  if num_classes == 10:\n",
        "    train1, train2, test = tfds.load(\n",
        "        'cifar10',\n",
        "        split=[\n",
        "            f'train[:{labeled_percentage}%]', f'train[{labeled_percentage}%:]',\n",
        "            'test'\n",
        "        ],\n",
        "        as_supervised=True)\n",
        "  else:\n",
        "    train1, train2, test = tfds.load(\n",
        "        'cifar100',\n",
        "        split=[\n",
        "            f'train[:{labeled_percentage}%]', f'train[{labeled_percentage}%:]',\n",
        "            'test'\n",
        "        ],\n",
        "        as_supervised=True)\n",
        "\n",
        "  x_train1, y_train1 = _example_label_split(train1)\n",
        "  x_train2, y_train2 = _example_label_split(train2)\n",
        "  x_test, y_test = _example_label_split(test)\n",
        "\n",
        "  if binary:\n",
        "    y_train1 = y_train1 % 2\n",
        "    y_train2 = y_train2 % 2\n",
        "    y_test = y_test % 2\n",
        "    num_classes = 2\n",
        "\n",
        "  x_train1 = _normalize_x(x_train1)\n",
        "  y_train1 = _normalize_y(y_train1, num_classes)\n",
        "\n",
        "  x_train2 = _normalize_x(x_train2)\n",
        "  y_train2 = _normalize_y(y_train2, num_classes)\n",
        "\n",
        "  x_test = _normalize_x(x_test)\n",
        "  y_test = _normalize_y(y_test, num_classes)\n",
        "\n",
        "  dataset_a = LabeledExamples(\n",
        "      examples=x_train1,\n",
        "      labels=y_train1,\n",
        "      size=len(x_train1),\n",
        "      num_classes=num_classes,\n",
        "      trainable=True)\n",
        "\n",
        "  dataset_b = LabeledExamples(\n",
        "      examples=x_train2,\n",
        "      labels=y_train2,\n",
        "      size=len(x_train2),\n",
        "      num_classes=num_classes,\n",
        "      trainable=True)\n",
        "\n",
        "  test = LabeledExamples(\n",
        "      examples=x_test,\n",
        "      labels=y_test,\n",
        "      size=len(x_test),\n",
        "      num_classes=num_classes,\n",
        "      trainable=False)\n",
        "\n",
        "  data_split = DataSplit(\n",
        "      dataset_a=dataset_a, dataset_b=dataset_b, validation=None, test=test)\n",
        "\n",
        "  return data_split\n",
        "\n",
        "\n",
        "def load_celeb_a_data(\n",
        "    labeled_percentage: int,\n",
        "    num_classes: Optional[int] = 2,\n",
        "    label_key: Optional[str] = 'Male',\n",
        "    group_key: Optional[str] = 'Young') -> DataSplit:\n",
        "  \"\"\"Loads celeb_a.\"\"\"\n",
        "\n",
        "  def _get_image_and_label(feat_dict):\n",
        "    return (feat_dict['image'], feat_dict['attributes'][label_key])\n",
        "\n",
        "  def _preprocess_input_dict(feat_dict):\n",
        "    # Separate out the image and target variable from the feature dictionary.\n",
        "    image = feat_dict['image']\n",
        "    label = feat_dict['attributes'][label_key]\n",
        "    group = feat_dict['attributes'][group_key]\n",
        "\n",
        "    image = tf.cast(image, tf.float32)\n",
        "    image = tf.image.resize(image, [32, 32])\n",
        "\n",
        "    label = tf.cast(label, tf.float32)\n",
        "    group = tf.cast(group, tf.float32)\n",
        "\n",
        "    feat_dict['image'] = image\n",
        "    feat_dict['attributes'][label_key] = label\n",
        "    feat_dict['attributes'][group_key] = group\n",
        "\n",
        "    return feat_dict\n",
        "\n",
        "  celeb_a_builder = tfds.builder('celeb_a')\n",
        "  celeb_a_builder.download_and_prepare()\n",
        "\n",
        "  train1, train2, test = celeb_a_builder.as_dataset(split=[\n",
        "      f'train[:{labeled_percentage}%]', f'train[{labeled_percentage}%:]', 'test'\n",
        "  ])\n",
        "\n",
        "  x_train1, y_train1 = _example_label_split(\n",
        "      train1.batch(1).map(_preprocess_input_dict).map(_get_image_and_label))\n",
        "  x_train2, y_train2 = _example_label_split(\n",
        "      train2.batch(1).map(_preprocess_input_dict).map(_get_image_and_label))\n",
        "  x_test, y_test = _example_label_split(\n",
        "      test.batch(1).map(_preprocess_input_dict).map(_get_image_and_label))\n",
        "\n",
        "  x_train1 = tf.squeeze(x_train1)\n",
        "  y_train1 = tf.squeeze(y_train1)\n",
        "\n",
        "  x_train2 = tf.squeeze(x_train2)\n",
        "  y_train2 = tf.squeeze(y_train2)\n",
        "\n",
        "  x_test = tf.squeeze(x_test)\n",
        "  y_test = tf.squeeze(y_test)\n",
        "\n",
        "  x_train1 = _normalize_x(x_train1)\n",
        "  y_train1 = _normalize_y(y_train1, num_classes)\n",
        "\n",
        "  x_train2 = _normalize_x(x_train2)\n",
        "  y_train2 = _normalize_y(y_train2, num_classes)\n",
        "\n",
        "  x_test = _normalize_x(x_test)\n",
        "  y_test = _normalize_y(y_test, num_classes)\n",
        "\n",
        "  dataset_a = LabeledExamples(\n",
        "      examples=x_train1,\n",
        "      labels=y_train1,\n",
        "      size=len(x_train1),\n",
        "      num_classes=num_classes,\n",
        "      trainable=True)\n",
        "\n",
        "  dataset_b = LabeledExamples(\n",
        "      examples=x_train2,\n",
        "      labels=y_train2,\n",
        "      size=len(x_train2),\n",
        "      num_classes=num_classes,\n",
        "      trainable=True)\n",
        "\n",
        "  test = LabeledExamples(\n",
        "      examples=x_test,\n",
        "      labels=y_test,\n",
        "      size=len(x_test),\n",
        "      num_classes=num_classes,\n",
        "      trainable=False)\n",
        "\n",
        "  data_split = DataSplit(\n",
        "      dataset_a=dataset_a, dataset_b=dataset_b, validation=None, test=test)\n",
        "\n",
        "  return data_split\n",
        "\n",
        "\n",
        "def load_data(dataset_name: str,\n",
        "              labeled_percentage: int) -> DataSplit:\n",
        "  \"\"\"Loads either cifar10, cifar100, or celeb_a.\"\"\"\n",
        "\n",
        "  if dataset_name == 'cifar10':\n",
        "    num_classes = 10\n",
        "    return load_cifar_data(\n",
        "        labeled_percentage=labeled_percentage, num_classes=num_classes)\n",
        "  elif dataset_name == 'cifar10bin':\n",
        "    num_classes = 10\n",
        "    return load_cifar_data(\n",
        "        labeled_percentage=labeled_percentage,\n",
        "        num_classes=num_classes,\n",
        "        binary=True)\n",
        "  elif dataset_name == 'cifar100':\n",
        "    num_classes = 100\n",
        "    return load_cifar_data(\n",
        "        labeled_percentage=labeled_percentage, num_classes=num_classes)\n",
        "  elif dataset_name == 'celeb_a':\n",
        "    num_classes = 2\n",
        "    return load_celeb_a_data(\n",
        "        labeled_percentage=labeled_percentage, num_classes=num_classes)\n",
        "  else:\n",
        "    raise NameError('Wrong dataset name.')\n"
      ],
      "metadata": {
        "id": "yDtvUA9A2tpC",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1674703841786,
          "user_tz": 300,
          "elapsed": 1,
          "user": {
            "displayName": "",
            "userId": "00000"
          }
        }
      },
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Training Methods"
      ],
      "metadata": {
        "id": "qZ9CyD2Z7_Fh"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "\"\"\"A function that trains models with or without data augmentation.\"\"\"\n",
        "\n",
        "from typing import Optional\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "\n",
        "def _train_model(\n",
        "    model: tf.keras.Model,\n",
        "    train_x: tf.Tensor,\n",
        "    train_y: tf.Tensor,\n",
        "    test_x: tf.Tensor,\n",
        "    test_y: tf.Tensor,\n",
        "    data_augmentation: Optional[bool] = False,\n",
        "    weights: Optional[tf.Tensor] = None,\n",
        "    epochs: Optional[int] = 200,\n",
        "    with_lr_scheduler: Optional[bool] = True,\n",
        "    batch_size: Optional[int] = 128,\n",
        "    epochs_offset: Optional[int] = 0) -> tf.keras.callbacks.History:\n",
        "  \"\"\"Trains a model.\n",
        "\n",
        "  Args:\n",
        "    model: Instance of tf.keras.Model.\n",
        "    train_x: A dataset containing the data to train on.\n",
        "    train_y: A dataset containing the labels of the data to train on.\n",
        "    test_x: A dataset containing the data to test on.\n",
        "    test_y: A dataset containing the labels of the data to test on.\n",
        "    data_augmentation: True if data augmentation is used and False otherwise.\n",
        "    weights: The weight sample-weights to be used.\n",
        "    epochs: The number of epochs to train for.\n",
        "    with_lr_scheduler: Whether we use a schedule for the learning-rate or not.\n",
        "    batch_size: The batch size used for training.\n",
        "    epochs_offset: How many epochs we assume we have performed so far.\n",
        "\n",
        "  Returns:\n",
        "      history: History of trained model.\n",
        "\n",
        "  Raises:\n",
        "    NotImplementedError: if online data augmention is used with weights or if\n",
        "    an augmentation method other than None, 'offline', or 'online' is provided.\n",
        "    To use weights with online data augmentation, weights should be passed as\n",
        "    an extra advice label and a corresponding loss with advice should be used,\n",
        "    see the loss functions defined in loss_functions.py.\n",
        "  \"\"\"\n",
        "\n",
        "  def lr_schedule(epoch):\n",
        "    lr = 1e-3\n",
        "    if epoch + epochs_offset > 180:\n",
        "      lr *= 0.5e-3\n",
        "    elif epoch + epochs_offset > 160:\n",
        "      lr *= 1e-3\n",
        "    elif epoch + epochs_offset > 120:\n",
        "      lr *= 1e-2\n",
        "    elif epoch + epochs_offset > 80:\n",
        "      lr *= 1e-1\n",
        "    return lr\n",
        "\n",
        "  if with_lr_scheduler:\n",
        "    lr_scheduler = tf.keras.callbacks.LearningRateScheduler(lr_schedule)\n",
        "    callbacks = [lr_scheduler]\n",
        "\n",
        "  with_weights = False\n",
        "  if weights is not None:\n",
        "    with_weights = True\n",
        "    train_dataset = tf.data.Dataset.from_tensor_slices(\n",
        "        (train_x, train_y, weights))\n",
        "  else:\n",
        "    with_weights = False\n",
        "    train_dataset = tf.data.Dataset.from_tensor_slices(\n",
        "        (train_x, train_y))\n",
        "\n",
        "  def prepare(ds, shuffle=False, augment=False, with_weights=False):\n",
        "    autotune = tf.data.AUTOTUNE\n",
        "\n",
        "    data_augmentation_layer = tf.keras.Sequential([\n",
        "        tf.keras.layers.RandomFlip('horizontal'),\n",
        "        tf.keras.layers.RandomTranslation(\n",
        "            height_factor=0.1,\n",
        "            width_factor=0.1,\n",
        "            fill_mode='nearest',\n",
        "            interpolation='bilinear',\n",
        "            seed=None,\n",
        "            fill_value=0.0)\n",
        "    ])\n",
        "\n",
        "    if shuffle:\n",
        "      ds = ds.shuffle(len(train_x), reshuffle_each_iteration=True)\n",
        "\n",
        "    # Use data augmentation only on the training set.\n",
        "    if augment:\n",
        "      if with_weights:\n",
        "        ds = ds.map(\n",
        "            lambda x, y, w: (data_augmentation_layer(x, training=True), y, w),\n",
        "            num_parallel_calls=autotune)\n",
        "      else:\n",
        "        ds = ds.map(\n",
        "            lambda x, y: (data_augmentation_layer(x, training=True), y),\n",
        "            num_parallel_calls=autotune)\n",
        "\n",
        "    # Batch the dataset. The drop_remainder is needed to avoid uneven batches.\n",
        "    ds = ds.batch(batch_size, drop_remainder=True)\n",
        "\n",
        "    # Use buffered prefetching on all datasets.\n",
        "    ds = ds.prefetch(buffer_size=autotune)\n",
        "\n",
        "    return ds\n",
        "\n",
        "  train_dataset = prepare(\n",
        "      train_dataset,\n",
        "      shuffle=True,\n",
        "      augment=data_augmentation,\n",
        "      with_weights=with_weights)\n",
        "\n",
        "  history = model.fit(\n",
        "      train_dataset,\n",
        "      validation_data=(test_x, test_y),\n",
        "      epochs=epochs,\n",
        "      verbose=1,\n",
        "      workers=4,\n",
        "      callbacks=callbacks)\n",
        "\n",
        "  return history\n",
        "\n",
        "\n",
        "def train_model(\n",
        "    model: tf.keras.Model,\n",
        "    train_x: tf.Tensor,\n",
        "    train_y: tf.Tensor,\n",
        "    test_x: tf.Tensor,\n",
        "    test_y: tf.Tensor,\n",
        "    data_augmentation: Optional[str] = None,\n",
        "    weights: Optional[tf.Tensor] = None,\n",
        "    epochs: Optional[int] = 200,\n",
        "    with_lr_scheduler: Optional[bool] = True,\n",
        "    batch_size: Optional[int] = 128,\n",
        "    epochs_offset: Optional[int] = 0) -> tf.keras.callbacks.History:\n",
        "  \"\"\"Trains a model.\n",
        "\n",
        "  Args:\n",
        "    model: Instance of tf.keras.Model.\n",
        "    train_x: A dataset containing the data to train on.\n",
        "    train_y: A dataset containing the labels of the data to train on.\n",
        "    test_x: A dataset containing the data to test on.\n",
        "    test_y: A dataset containing the labels of the data to test on.\n",
        "    data_augmentation: Can be either 'no', 'offline', 'online'.\n",
        "    weights: The weight sample-weights to be used.\n",
        "    epochs: The number of epochs to train for.\n",
        "    with_lr_scheduler: Whether we use a schedule for the learning-rate or not.\n",
        "    batch_size: The batch size used for training.\n",
        "    epochs_offset: How many epochs we assume we have performed so far.\n",
        "\n",
        "  Returns:\n",
        "      history: History of trained model.\n",
        "  \"\"\"\n",
        "\n",
        "  if data_augmentation != 'online':\n",
        "\n",
        "    if data_augmentation == 'no' or not data_augmentation:\n",
        "      use_data_augmentation = False\n",
        "    else:\n",
        "      use_data_augmentation = True\n",
        "\n",
        "    return _train_model(\n",
        "        model=model,\n",
        "        train_x=train_x,\n",
        "        train_y=train_y,\n",
        "        test_x=test_x,\n",
        "        test_y=test_y,\n",
        "        epochs=epochs,\n",
        "        data_augmentation=use_data_augmentation,\n",
        "        weights=weights,\n",
        "        with_lr_scheduler=with_lr_scheduler,\n",
        "        batch_size=batch_size,\n",
        "        epochs_offset=epochs_offset)\n",
        "\n",
        "  elif data_augmentation == 'online':\n",
        "    results = []\n",
        "\n",
        "    for i in range(1, epochs+1):\n",
        "\n",
        "      print(f'Epoch {i}/{epochs}')\n",
        "\n",
        "      history_tmp = _train_model(\n",
        "          model=model,\n",
        "          train_x=train_x,\n",
        "          train_y=train_y,\n",
        "          test_x=test_x,\n",
        "          test_y=test_y,\n",
        "          epochs=1,\n",
        "          data_augmentation=True,\n",
        "          weights=weights,\n",
        "          with_lr_scheduler=with_lr_scheduler,\n",
        "          batch_size=batch_size,\n",
        "          epochs_offset=i)\n",
        "\n",
        "      results.append(history_tmp)\n",
        "\n",
        "    history = results[0]\n",
        "\n",
        "    for hist in results[1:]:\n",
        "      for key in hist.history.keys():\n",
        "        history.history[key] += hist.history[key]\n",
        "    return history\n",
        "\n",
        "  else:\n",
        "    raise NotImplementedError()\n"
      ],
      "metadata": {
        "id": "n0BRIb4B4VWQ",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1674703841787,
          "user_tz": 300,
          "elapsed": 2,
          "user": {
            "displayName": "XXXX",
            "userId": "0000"
          }
        }
      },
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Experiment Parameters"
      ],
      "metadata": {
        "id": "3quluVIF76CT"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Dataset used for training.\n",
        "_DATASET =  'cifar100'\n",
        "# Whether the teacher outputs soft-labels or hard-labels.\n",
        "_WITH_SOFT_LABELS = True\n",
        "# The model architecture we are using for the teacher.\n",
        "_TEACHER_MODEL = 'resnet'\n",
        "# The data augmentation method we are using for the teacher.\n",
        "_TEACHER_DATA_AUGMENTATION =  'online'\n",
        "# The data augmentation method we are using for pretraining the student.\n",
        "_STUDENT_PRETRAINING_DATA_AUGMENTATION = 'online'\n",
        "# Data augmentation method we are using for training the student during\n",
        "_STUDENT_DISTILLATION_DATA_AUGMENTATION =  'online'\n",
        "# Student model architecture.\n",
        "_STUDENT_MODEL =  'resnet'\n",
        "# Teacher mobilenet depth multiplier.\n",
        "_TEACHER_MOBILENET_DEPTH_MULTIPLIER = 2\n",
        "# Student mobilenet depth multiplier.\n",
        "_STUDENT_MOBILENET_DEPTH_MULTIPLIER = 1\n",
        "# Teacher resnet depth.\n",
        "_TEACHER_RESNET_DEPTH = 110\n",
        "# Student resnet depth.\n",
        "_STUDENT_RESNET_DEPTH = 56\n",
        "# Student training epochs.\n",
        "_STUDENT_EPOCHS = 200\n",
        "# Teacher training epochs.\n",
        "_TEACHER_EPOCHS = 200\n",
        "# Student pretraining epochs.\n",
        "_STUDENT_OPTIMIZER = 'adam'\n",
        "# Teacher optimizer.\n",
        "_TEACHER_OPTIMIZER = 'adam'\n",
        "# Size of validation dataset.\n",
        "_VALIDATION_DATASET_SIZE = 512\n",
        "# Size of labeled dataset (percentage of labeled examples).\n",
        "_SIZE_OF_DATASET_A = 10\n",
        "# Number of experiment trials.\n",
        "_NUM_TRIALS =  3\n",
        "# Whether to randomize the dataset in each trial.\n",
        "_RANDOMIZE_DATASET = True\n",
        "# Whether to train on validation data.\n",
        "_TRAIN_ON_VALIDATION = True\n",
        "# The batch size to be used.\n",
        "_BATCH_SIZE = 128\n",
        "# The random seed to be used.\n",
        "_SEED = 753410"
      ],
      "metadata": {
        "id": "WBdL4N7G6JC3",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1674703841788,
          "user_tz": 300,
          "elapsed": 3,
          "user": {
            "displayName": "XXX",
            "userId": "000000"
          }
        }
      },
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Load the Dataset"
      ],
      "metadata": {
        "id": "XHVF28nf0ZLY"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Load the dataset\n",
        "data_split = load_data(\n",
        "    dataset_name=_DATASET, labeled_percentage=_SIZE_OF_DATASET_A)\n",
        "\n",
        "# Number of classes should be equal for all three splits (dataset_a,\n",
        "# dataset_b, test) so we set it using the labeled split, dataset_a.\n",
        "num_classes = data_split.dataset_a.num_classes"
      ],
      "metadata": {
        "id": "Fn9Z6_bgDgFx",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1674703841788,
          "user_tz": 300,
          "elapsed": 3,
          "user": {
            "displayName": "XXXX",
            "userId": "00000"
          }
        }
      },
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "\n",
        "# Train/Load the Teacher Model"
      ],
      "metadata": {
        "id": "2hwUpBI-kfqH"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "\n",
        "# Training the teacher\n",
        "\n",
        "teacher_model = load_model(\n",
        "      _TEACHER_MODEL,\n",
        "      num_classes,\n",
        "      _TEACHER_OPTIMIZER,\n",
        "      width_multiplier=1.0,\n",
        "      depth_multiplier=_TEACHER_MOBILENET_DEPTH_MULTIPLIER,\n",
        "      resnet_depth=_TEACHER_RESNET_DEPTH)\n",
        "\n",
        "_ = train_model(\n",
        "      teacher_model,\n",
        "      data_split.dataset_a.examples,\n",
        "      data_split.dataset_a.labels,\n",
        "      data_split.test.examples,\n",
        "      data_split.test.labels,\n",
        "      data_augmentation=_TEACHER_DATA_AUGMENTATION,\n",
        "      epochs=_TEACHER_EPOCHS)\n"
      ],
      "metadata": {
        "id": "RPjarREQjEsl",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1674703841801,
          "user_tz": 300,
          "elapsed": 0,
          "user": {
            "displayName": "XXXXX",
            "userId": "0000000"
          }
        },
        "outputId": "4342f4a6-83d1-4446-a18d-2fa29ef25486"
      },
      "execution_count": 8,
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epoch 1/200\n",
            "39/39 [==============================] - 72s 228ms/step - loss: 6.4264 - categorical_accuracy: 0.0529 - val_loss: 42.0636 - val_categorical_accuracy: 0.0106 - lr: 0.0010\n",
            "Epoch 2/200\n",
            "39/39 [==============================] - 7s 170ms/step - loss: 5.8560 - categorical_accuracy: 0.0978 - val_loss: 6.3051 - val_categorical_accuracy: 0.0457 - lr: 0.0010\n",
            "Epoch 3/200\n",
            "39/39 [==============================] - 7s 170ms/step - loss: 5.5039 - categorical_accuracy: 0.1210 - val_loss: 5.7499 - val_categorical_accuracy: 0.0817 - lr: 0.0010\n",
            "Epoch 4/200\n",
            "39/39 [==============================] - 7s 176ms/step - loss: 5.2332 - categorical_accuracy: 0.1444 - val_loss: 5.5063 - val_categorical_accuracy: 0.0991 - lr: 0.0010\n",
            "Epoch 5/200\n",
            "39/39 [==============================] - 7s 181ms/step - loss: 5.0100 - categorical_accuracy: 0.1565 - val_loss: 5.3868 - val_categorical_accuracy: 0.0982 - lr: 0.0010\n",
            "Epoch 6/200\n",
            "39/39 [==============================] - 7s 178ms/step - loss: 4.7962 - categorical_accuracy: 0.1791 - val_loss: 5.2402 - val_categorical_accuracy: 0.1091 - lr: 0.0010\n",
            "Epoch 7/200\n",
            "39/39 [==============================] - 7s 171ms/step - loss: 4.6837 - categorical_accuracy: 0.1807 - val_loss: 5.0669 - val_categorical_accuracy: 0.1222 - lr: 0.0010\n",
            "Epoch 8/200\n",
            "39/39 [==============================] - 7s 169ms/step - loss: 4.4999 - categorical_accuracy: 0.2127 - val_loss: 4.9362 - val_categorical_accuracy: 0.1272 - lr: 0.0010\n",
            "Epoch 9/200\n",
            "39/39 [==============================] - 7s 170ms/step - loss: 4.3252 - categorical_accuracy: 0.2284 - val_loss: 4.8628 - val_categorical_accuracy: 0.1333 - lr: 0.0010\n",
            "Epoch 10/200\n",
            "39/39 [==============================] - 7s 170ms/step - loss: 4.1985 - categorical_accuracy: 0.2420 - val_loss: 4.7680 - val_categorical_accuracy: 0.1476 - lr: 0.0010\n",
            "Epoch 11/200\n",
            "39/39 [==============================] - 7s 178ms/step - loss: 4.0534 - categorical_accuracy: 0.2612 - val_loss: 4.6304 - val_categorical_accuracy: 0.1559 - lr: 0.0010\n",
            "Epoch 12/200\n",
            "39/39 [==============================] - 6s 168ms/step - loss: 3.9484 - categorical_accuracy: 0.2722 - val_loss: 4.7309 - val_categorical_accuracy: 0.1572 - lr: 0.0010\n",
            "Epoch 13/200\n",
            "39/39 [==============================] - 7s 170ms/step - loss: 3.8395 - categorical_accuracy: 0.2949 - val_loss: 4.5383 - val_categorical_accuracy: 0.1630 - lr: 0.0010\n",
            "Epoch 14/200\n",
            "39/39 [==============================] - 6s 169ms/step - loss: 3.7136 - categorical_accuracy: 0.3179 - val_loss: 4.5459 - val_categorical_accuracy: 0.1684 - lr: 0.0010\n",
            "Epoch 15/200\n",
            "39/39 [==============================] - 7s 173ms/step - loss: 3.6513 - categorical_accuracy: 0.3173 - val_loss: 4.5911 - val_categorical_accuracy: 0.1616 - lr: 0.0010\n",
            "Epoch 16/200\n",
            "39/39 [==============================] - 7s 189ms/step - loss: 3.5444 - categorical_accuracy: 0.3303 - val_loss: 4.3934 - val_categorical_accuracy: 0.1817 - lr: 0.0010\n",
            "Epoch 17/200\n",
            "39/39 [==============================] - 7s 188ms/step - loss: 3.3991 - categorical_accuracy: 0.3608 - val_loss: 4.4685 - val_categorical_accuracy: 0.1690 - lr: 0.0010\n",
            "Epoch 18/200\n",
            "39/39 [==============================] - 7s 187ms/step - loss: 3.2824 - categorical_accuracy: 0.3930 - val_loss: 4.3883 - val_categorical_accuracy: 0.1828 - lr: 0.0010\n",
            "Epoch 19/200\n",
            "39/39 [==============================] - 7s 184ms/step - loss: 3.2042 - categorical_accuracy: 0.4046 - val_loss: 4.4805 - val_categorical_accuracy: 0.1762 - lr: 0.0010\n",
            "Epoch 20/200\n",
            "39/39 [==============================] - 7s 193ms/step - loss: 3.2301 - categorical_accuracy: 0.3898 - val_loss: 4.2759 - val_categorical_accuracy: 0.2127 - lr: 0.0010\n",
            "Epoch 21/200\n",
            "39/39 [==============================] - 7s 179ms/step - loss: 3.0260 - categorical_accuracy: 0.4465 - val_loss: 4.4780 - val_categorical_accuracy: 0.1967 - lr: 0.0010\n",
            "Epoch 22/200\n",
            "39/39 [==============================] - 7s 186ms/step - loss: 3.0364 - categorical_accuracy: 0.4327 - val_loss: 4.3010 - val_categorical_accuracy: 0.2089 - lr: 0.0010\n",
            "Epoch 23/200\n",
            "39/39 [==============================] - 7s 186ms/step - loss: 2.8982 - categorical_accuracy: 0.4639 - val_loss: 4.2092 - val_categorical_accuracy: 0.2220 - lr: 0.0010\n",
            "Epoch 24/200\n",
            "39/39 [==============================] - 7s 180ms/step - loss: 2.7583 - categorical_accuracy: 0.4994 - val_loss: 4.3629 - val_categorical_accuracy: 0.2101 - lr: 0.0010\n",
            "Epoch 25/200\n",
            "39/39 [==============================] - 7s 182ms/step - loss: 2.7180 - categorical_accuracy: 0.5026 - val_loss: 4.1544 - val_categorical_accuracy: 0.2415 - lr: 0.0010\n",
            "Epoch 26/200\n",
            "39/39 [==============================] - 7s 181ms/step - loss: 2.8051 - categorical_accuracy: 0.4782 - val_loss: 4.3106 - val_categorical_accuracy: 0.2364 - lr: 0.0010\n",
            "Epoch 27/200\n",
            "39/39 [==============================] - 7s 181ms/step - loss: 2.5709 - categorical_accuracy: 0.5417 - val_loss: 4.3001 - val_categorical_accuracy: 0.2265 - lr: 0.0010\n",
            "Epoch 28/200\n",
            "39/39 [==============================] - 7s 182ms/step - loss: 2.6145 - categorical_accuracy: 0.5232 - val_loss: 4.4489 - val_categorical_accuracy: 0.2150 - lr: 0.0010\n",
            "Epoch 29/200\n",
            "39/39 [==============================] - 7s 187ms/step - loss: 2.4396 - categorical_accuracy: 0.5715 - val_loss: 4.5482 - val_categorical_accuracy: 0.2038 - lr: 0.0010\n",
            "Epoch 30/200\n",
            "39/39 [==============================] - 7s 182ms/step - loss: 2.3249 - categorical_accuracy: 0.6008 - val_loss: 4.3103 - val_categorical_accuracy: 0.2370 - lr: 0.0010\n",
            "Epoch 31/200\n",
            "39/39 [==============================] - 7s 185ms/step - loss: 2.4677 - categorical_accuracy: 0.5637 - val_loss: 4.7798 - val_categorical_accuracy: 0.2053 - lr: 0.0010\n",
            "Epoch 32/200\n",
            "39/39 [==============================] - 7s 183ms/step - loss: 2.3278 - categorical_accuracy: 0.5927 - val_loss: 4.3724 - val_categorical_accuracy: 0.2306 - lr: 0.0010\n",
            "Epoch 33/200\n",
            "39/39 [==============================] - 7s 181ms/step - loss: 2.2371 - categorical_accuracy: 0.6208 - val_loss: 4.2659 - val_categorical_accuracy: 0.2475 - lr: 0.0010\n",
            "Epoch 34/200\n",
            "39/39 [==============================] - 7s 186ms/step - loss: 2.1569 - categorical_accuracy: 0.6492 - val_loss: 4.7056 - val_categorical_accuracy: 0.2142 - lr: 0.0010\n",
            "Epoch 35/200\n",
            "39/39 [==============================] - 7s 179ms/step - loss: 2.0172 - categorical_accuracy: 0.6887 - val_loss: 4.4976 - val_categorical_accuracy: 0.2273 - lr: 0.0010\n",
            "Epoch 36/200\n",
            "39/39 [==============================] - 7s 183ms/step - loss: 2.1734 - categorical_accuracy: 0.6346 - val_loss: 4.4821 - val_categorical_accuracy: 0.2386 - lr: 0.0010\n",
            "Epoch 37/200\n",
            "39/39 [==============================] - 7s 181ms/step - loss: 2.0182 - categorical_accuracy: 0.6849 - val_loss: 4.5522 - val_categorical_accuracy: 0.2411 - lr: 0.0010\n",
            "Epoch 38/200\n",
            "39/39 [==============================] - 7s 182ms/step - loss: 1.9489 - categorical_accuracy: 0.7073 - val_loss: 4.6027 - val_categorical_accuracy: 0.2371 - lr: 0.0010\n",
            "Epoch 39/200\n",
            "39/39 [==============================] - 7s 182ms/step - loss: 1.8875 - categorical_accuracy: 0.7214 - val_loss: 4.5938 - val_categorical_accuracy: 0.2385 - lr: 0.0010\n",
            "Epoch 40/200\n",
            "39/39 [==============================] - 7s 183ms/step - loss: 1.8599 - categorical_accuracy: 0.7238 - val_loss: 4.7749 - val_categorical_accuracy: 0.2434 - lr: 0.0010\n",
            "Epoch 41/200\n",
            "39/39 [==============================] - 7s 181ms/step - loss: 1.7727 - categorical_accuracy: 0.7496 - val_loss: 4.3486 - val_categorical_accuracy: 0.2714 - lr: 0.0010\n",
            "Epoch 42/200\n",
            "39/39 [==============================] - 7s 179ms/step - loss: 1.7050 - categorical_accuracy: 0.7782 - val_loss: 4.8688 - val_categorical_accuracy: 0.2409 - lr: 0.0010\n",
            "Epoch 43/200\n",
            "39/39 [==============================] - 7s 182ms/step - loss: 1.6123 - categorical_accuracy: 0.8051 - val_loss: 4.6361 - val_categorical_accuracy: 0.2478 - lr: 0.0010\n",
            "Epoch 44/200\n",
            "39/39 [==============================] - 7s 182ms/step - loss: 1.5667 - categorical_accuracy: 0.8199 - val_loss: 4.8400 - val_categorical_accuracy: 0.2407 - lr: 0.0010\n",
            "Epoch 45/200\n",
            "39/39 [==============================] - 7s 187ms/step - loss: 1.5589 - categorical_accuracy: 0.8123 - val_loss: 4.8082 - val_categorical_accuracy: 0.2501 - lr: 0.0010\n",
            "Epoch 46/200\n",
            "39/39 [==============================] - 7s 185ms/step - loss: 1.4886 - categorical_accuracy: 0.8405 - val_loss: 4.9528 - val_categorical_accuracy: 0.2293 - lr: 0.0010\n",
            "Epoch 47/200\n",
            "39/39 [==============================] - 8s 214ms/step - loss: 1.4405 - categorical_accuracy: 0.8538 - val_loss: 4.9247 - val_categorical_accuracy: 0.2536 - lr: 0.0010\n",
            "Epoch 48/200\n",
            "39/39 [==============================] - 7s 193ms/step - loss: 1.4000 - categorical_accuracy: 0.8676 - val_loss: 4.7898 - val_categorical_accuracy: 0.2592 - lr: 0.0010\n",
            "Epoch 49/200\n",
            "39/39 [==============================] - 7s 180ms/step - loss: 1.7323 - categorical_accuracy: 0.7488 - val_loss: 5.2858 - val_categorical_accuracy: 0.2373 - lr: 0.0010\n",
            "Epoch 50/200\n",
            "39/39 [==============================] - 7s 184ms/step - loss: 1.6300 - categorical_accuracy: 0.7885 - val_loss: 5.1657 - val_categorical_accuracy: 0.2444 - lr: 0.0010\n",
            "Epoch 51/200\n",
            "39/39 [==============================] - 7s 186ms/step - loss: 1.5078 - categorical_accuracy: 0.8239 - val_loss: 4.9280 - val_categorical_accuracy: 0.2563 - lr: 0.0010\n",
            "Epoch 52/200\n",
            "39/39 [==============================] - 7s 191ms/step - loss: 1.4089 - categorical_accuracy: 0.8564 - val_loss: 4.9077 - val_categorical_accuracy: 0.2649 - lr: 0.0010\n",
            "Epoch 53/200\n",
            "39/39 [==============================] - 7s 183ms/step - loss: 1.4137 - categorical_accuracy: 0.8564 - val_loss: 4.6934 - val_categorical_accuracy: 0.2656 - lr: 0.0010\n",
            "Epoch 54/200\n",
            "39/39 [==============================] - 7s 185ms/step - loss: 1.3421 - categorical_accuracy: 0.8798 - val_loss: 4.8395 - val_categorical_accuracy: 0.2631 - lr: 0.0010\n",
            "Epoch 55/200\n",
            "39/39 [==============================] - 7s 182ms/step - loss: 1.2398 - categorical_accuracy: 0.9109 - val_loss: 4.9884 - val_categorical_accuracy: 0.2710 - lr: 0.0010\n",
            "Epoch 56/200\n",
            "39/39 [==============================] - 7s 180ms/step - loss: 1.2153 - categorical_accuracy: 0.9143 - val_loss: 5.0360 - val_categorical_accuracy: 0.2628 - lr: 0.0010\n",
            "Epoch 57/200\n",
            "39/39 [==============================] - 7s 180ms/step - loss: 1.3830 - categorical_accuracy: 0.8516 - val_loss: 5.2880 - val_categorical_accuracy: 0.2581 - lr: 0.0010\n",
            "Epoch 58/200\n",
            "39/39 [==============================] - 7s 180ms/step - loss: 1.3134 - categorical_accuracy: 0.8872 - val_loss: 4.7965 - val_categorical_accuracy: 0.2769 - lr: 0.0010\n",
            "Epoch 59/200\n",
            "39/39 [==============================] - 7s 180ms/step - loss: 1.3189 - categorical_accuracy: 0.8804 - val_loss: 5.0649 - val_categorical_accuracy: 0.2661 - lr: 0.0010\n",
            "Epoch 60/200\n",
            "39/39 [==============================] - 7s 185ms/step - loss: 1.3016 - categorical_accuracy: 0.8844 - val_loss: 5.3420 - val_categorical_accuracy: 0.2451 - lr: 0.0010\n",
            "Epoch 61/200\n",
            "39/39 [==============================] - 7s 182ms/step - loss: 1.2602 - categorical_accuracy: 0.8914 - val_loss: 5.0521 - val_categorical_accuracy: 0.2710 - lr: 0.0010\n",
            "Epoch 62/200\n",
            "39/39 [==============================] - 7s 188ms/step - loss: 1.1676 - categorical_accuracy: 0.9265 - val_loss: 5.2738 - val_categorical_accuracy: 0.2617 - lr: 0.0010\n",
            "Epoch 63/200\n",
            "39/39 [==============================] - 7s 184ms/step - loss: 1.1408 - categorical_accuracy: 0.9371 - val_loss: 5.1671 - val_categorical_accuracy: 0.2542 - lr: 0.0010\n",
            "Epoch 64/200\n",
            "39/39 [==============================] - 7s 180ms/step - loss: 1.1193 - categorical_accuracy: 0.9375 - val_loss: 5.0435 - val_categorical_accuracy: 0.2600 - lr: 0.0010\n",
            "Epoch 65/200\n",
            "39/39 [==============================] - 7s 182ms/step - loss: 1.2232 - categorical_accuracy: 0.9004 - val_loss: 5.1308 - val_categorical_accuracy: 0.2678 - lr: 0.0010\n",
            "Epoch 66/200\n",
            "39/39 [==============================] - 7s 180ms/step - loss: 1.2025 - categorical_accuracy: 0.9030 - val_loss: 5.2123 - val_categorical_accuracy: 0.2708 - lr: 0.0010\n",
            "Epoch 67/200\n",
            "39/39 [==============================] - 7s 180ms/step - loss: 1.1054 - categorical_accuracy: 0.9403 - val_loss: 5.1331 - val_categorical_accuracy: 0.2668 - lr: 0.0010\n",
            "Epoch 68/200\n",
            "39/39 [==============================] - 7s 187ms/step - loss: 1.2156 - categorical_accuracy: 0.9010 - val_loss: 5.5604 - val_categorical_accuracy: 0.2381 - lr: 0.0010\n",
            "Epoch 69/200\n",
            "39/39 [==============================] - 7s 182ms/step - loss: 1.2300 - categorical_accuracy: 0.8932 - val_loss: 5.6196 - val_categorical_accuracy: 0.2450 - lr: 0.0010\n",
            "Epoch 70/200\n",
            "39/39 [==============================] - 7s 178ms/step - loss: 1.1845 - categorical_accuracy: 0.9085 - val_loss: 4.9535 - val_categorical_accuracy: 0.2725 - lr: 0.0010\n",
            "Epoch 71/200\n",
            "39/39 [==============================] - 7s 179ms/step - loss: 1.1272 - categorical_accuracy: 0.9275 - val_loss: 5.1439 - val_categorical_accuracy: 0.2699 - lr: 0.0010\n",
            "Epoch 72/200\n",
            "39/39 [==============================] - 7s 186ms/step - loss: 1.0483 - categorical_accuracy: 0.9559 - val_loss: 5.0834 - val_categorical_accuracy: 0.2657 - lr: 0.0010\n",
            "Epoch 73/200\n",
            "39/39 [==============================] - 7s 185ms/step - loss: 1.0342 - categorical_accuracy: 0.9547 - val_loss: 5.3845 - val_categorical_accuracy: 0.2760 - lr: 0.0010\n",
            "Epoch 74/200\n",
            "39/39 [==============================] - 7s 190ms/step - loss: 1.1710 - categorical_accuracy: 0.9052 - val_loss: 5.8841 - val_categorical_accuracy: 0.2349 - lr: 0.0010\n",
            "Epoch 75/200\n",
            "39/39 [==============================] - 7s 184ms/step - loss: 1.1601 - categorical_accuracy: 0.9105 - val_loss: 5.6393 - val_categorical_accuracy: 0.2524 - lr: 0.0010\n",
            "Epoch 76/200\n",
            "39/39 [==============================] - 7s 184ms/step - loss: 1.1892 - categorical_accuracy: 0.8990 - val_loss: 5.3940 - val_categorical_accuracy: 0.2616 - lr: 0.0010\n",
            "Epoch 77/200\n",
            "39/39 [==============================] - 7s 184ms/step - loss: 1.1666 - categorical_accuracy: 0.9075 - val_loss: 5.4885 - val_categorical_accuracy: 0.2651 - lr: 0.0010\n",
            "Epoch 78/200\n",
            "39/39 [==============================] - 7s 185ms/step - loss: 1.1113 - categorical_accuracy: 0.9245 - val_loss: 5.4571 - val_categorical_accuracy: 0.2398 - lr: 0.0010\n",
            "Epoch 79/200\n",
            "39/39 [==============================] - 7s 191ms/step - loss: 1.0651 - categorical_accuracy: 0.9435 - val_loss: 5.3295 - val_categorical_accuracy: 0.2640 - lr: 0.0010\n",
            "Epoch 80/200\n",
            "39/39 [==============================] - 7s 193ms/step - loss: 1.0885 - categorical_accuracy: 0.9295 - val_loss: 5.5011 - val_categorical_accuracy: 0.2605 - lr: 0.0010\n",
            "Epoch 81/200\n",
            "39/39 [==============================] - 7s 185ms/step - loss: 0.9767 - categorical_accuracy: 0.9688 - val_loss: 4.7826 - val_categorical_accuracy: 0.3036 - lr: 1.0000e-04\n",
            "Epoch 82/200\n",
            "39/39 [==============================] - 7s 180ms/step - loss: 0.9056 - categorical_accuracy: 0.9890 - val_loss: 4.6740 - val_categorical_accuracy: 0.3179 - lr: 1.0000e-04\n",
            "Epoch 83/200\n",
            "39/39 [==============================] - 7s 193ms/step - loss: 0.8852 - categorical_accuracy: 0.9958 - val_loss: 4.6300 - val_categorical_accuracy: 0.3232 - lr: 1.0000e-04\n",
            "Epoch 84/200\n",
            "39/39 [==============================] - 7s 187ms/step - loss: 0.8787 - categorical_accuracy: 0.9964 - val_loss: 4.5987 - val_categorical_accuracy: 0.3266 - lr: 1.0000e-04\n",
            "Epoch 85/200\n",
            "39/39 [==============================] - 7s 188ms/step - loss: 0.8724 - categorical_accuracy: 0.9982 - val_loss: 4.5823 - val_categorical_accuracy: 0.3269 - lr: 1.0000e-04\n",
            "Epoch 86/200\n",
            "39/39 [==============================] - 7s 178ms/step - loss: 0.8665 - categorical_accuracy: 0.9984 - val_loss: 4.5982 - val_categorical_accuracy: 0.3280 - lr: 1.0000e-04\n",
            "Epoch 87/200\n",
            "39/39 [==============================] - 7s 184ms/step - loss: 0.8692 - categorical_accuracy: 0.9966 - val_loss: 4.6000 - val_categorical_accuracy: 0.3312 - lr: 1.0000e-04\n",
            "Epoch 88/200\n",
            "39/39 [==============================] - 7s 177ms/step - loss: 0.8619 - categorical_accuracy: 0.9988 - val_loss: 4.6082 - val_categorical_accuracy: 0.3285 - lr: 1.0000e-04\n",
            "Epoch 89/200\n",
            "39/39 [==============================] - 7s 178ms/step - loss: 0.8567 - categorical_accuracy: 0.9992 - val_loss: 4.6023 - val_categorical_accuracy: 0.3313 - lr: 1.0000e-04\n",
            "Epoch 90/200\n",
            "39/39 [==============================] - 7s 189ms/step - loss: 0.8535 - categorical_accuracy: 0.9990 - val_loss: 4.6032 - val_categorical_accuracy: 0.3320 - lr: 1.0000e-04\n",
            "Epoch 91/200\n",
            "39/39 [==============================] - 7s 186ms/step - loss: 0.8494 - categorical_accuracy: 0.9990 - val_loss: 4.5960 - val_categorical_accuracy: 0.3318 - lr: 1.0000e-04\n",
            "Epoch 92/200\n",
            "39/39 [==============================] - 7s 184ms/step - loss: 0.8497 - categorical_accuracy: 0.9984 - val_loss: 4.6017 - val_categorical_accuracy: 0.3330 - lr: 1.0000e-04\n",
            "Epoch 93/200\n",
            "39/39 [==============================] - 7s 182ms/step - loss: 0.8462 - categorical_accuracy: 0.9998 - val_loss: 4.6076 - val_categorical_accuracy: 0.3325 - lr: 1.0000e-04\n",
            "Epoch 94/200\n",
            "39/39 [==============================] - 7s 180ms/step - loss: 0.8450 - categorical_accuracy: 0.9988 - val_loss: 4.6135 - val_categorical_accuracy: 0.3314 - lr: 1.0000e-04\n",
            "Epoch 95/200\n",
            "39/39 [==============================] - 7s 181ms/step - loss: 0.8410 - categorical_accuracy: 0.9990 - val_loss: 4.6077 - val_categorical_accuracy: 0.3332 - lr: 1.0000e-04\n",
            "Epoch 96/200\n",
            "39/39 [==============================] - 7s 183ms/step - loss: 0.8395 - categorical_accuracy: 0.9990 - val_loss: 4.6119 - val_categorical_accuracy: 0.3325 - lr: 1.0000e-04\n",
            "Epoch 97/200\n",
            "39/39 [==============================] - 7s 179ms/step - loss: 0.8364 - categorical_accuracy: 0.9990 - val_loss: 4.6192 - val_categorical_accuracy: 0.3337 - lr: 1.0000e-04\n",
            "Epoch 98/200\n",
            "39/39 [==============================] - 7s 178ms/step - loss: 0.8332 - categorical_accuracy: 0.9994 - val_loss: 4.6262 - val_categorical_accuracy: 0.3355 - lr: 1.0000e-04\n",
            "Epoch 99/200\n",
            "39/39 [==============================] - 7s 180ms/step - loss: 0.8304 - categorical_accuracy: 0.9992 - val_loss: 4.6267 - val_categorical_accuracy: 0.3352 - lr: 1.0000e-04\n",
            "Epoch 100/200\n",
            "39/39 [==============================] - 7s 177ms/step - loss: 0.8306 - categorical_accuracy: 0.9988 - val_loss: 4.6251 - val_categorical_accuracy: 0.3348 - lr: 1.0000e-04\n",
            "Epoch 101/200\n",
            "39/39 [==============================] - 7s 179ms/step - loss: 0.8274 - categorical_accuracy: 0.9992 - val_loss: 4.6378 - val_categorical_accuracy: 0.3355 - lr: 1.0000e-04\n",
            "Epoch 102/200\n",
            "39/39 [==============================] - 7s 182ms/step - loss: 0.8241 - categorical_accuracy: 1.0000 - val_loss: 4.6368 - val_categorical_accuracy: 0.3368 - lr: 1.0000e-04\n",
            "Epoch 103/200\n",
            "39/39 [==============================] - 7s 181ms/step - loss: 0.8210 - categorical_accuracy: 0.9996 - val_loss: 4.6341 - val_categorical_accuracy: 0.3342 - lr: 1.0000e-04\n",
            "Epoch 104/200\n",
            "39/39 [==============================] - 7s 181ms/step - loss: 0.8174 - categorical_accuracy: 1.0000 - val_loss: 4.6499 - val_categorical_accuracy: 0.3318 - lr: 1.0000e-04\n",
            "Epoch 105/200\n",
            "39/39 [==============================] - 7s 180ms/step - loss: 0.8170 - categorical_accuracy: 0.9996 - val_loss: 4.6518 - val_categorical_accuracy: 0.3342 - lr: 1.0000e-04\n",
            "Epoch 106/200\n",
            "39/39 [==============================] - 7s 179ms/step - loss: 0.8139 - categorical_accuracy: 0.9996 - val_loss: 4.6521 - val_categorical_accuracy: 0.3357 - lr: 1.0000e-04\n",
            "Epoch 107/200\n",
            "39/39 [==============================] - 7s 183ms/step - loss: 0.8123 - categorical_accuracy: 0.9994 - val_loss: 4.6527 - val_categorical_accuracy: 0.3354 - lr: 1.0000e-04\n",
            "Epoch 108/200\n",
            "39/39 [==============================] - 7s 179ms/step - loss: 0.8087 - categorical_accuracy: 0.9998 - val_loss: 4.6586 - val_categorical_accuracy: 0.3347 - lr: 1.0000e-04\n",
            "Epoch 109/200\n",
            "39/39 [==============================] - 7s 179ms/step - loss: 0.8060 - categorical_accuracy: 1.0000 - val_loss: 4.6634 - val_categorical_accuracy: 0.3353 - lr: 1.0000e-04\n",
            "Epoch 110/200\n",
            "39/39 [==============================] - 7s 181ms/step - loss: 0.8036 - categorical_accuracy: 0.9992 - val_loss: 4.6653 - val_categorical_accuracy: 0.3350 - lr: 1.0000e-04\n",
            "Epoch 111/200\n",
            "39/39 [==============================] - 7s 180ms/step - loss: 0.8020 - categorical_accuracy: 0.9994 - val_loss: 4.6714 - val_categorical_accuracy: 0.3351 - lr: 1.0000e-04\n",
            "Epoch 112/200\n",
            "39/39 [==============================] - 7s 187ms/step - loss: 0.7993 - categorical_accuracy: 0.9996 - val_loss: 4.6834 - val_categorical_accuracy: 0.3358 - lr: 1.0000e-04\n",
            "Epoch 113/200\n",
            "39/39 [==============================] - 7s 190ms/step - loss: 0.7978 - categorical_accuracy: 0.9996 - val_loss: 4.6950 - val_categorical_accuracy: 0.3342 - lr: 1.0000e-04\n",
            "Epoch 114/200\n",
            "39/39 [==============================] - 7s 188ms/step - loss: 0.7944 - categorical_accuracy: 0.9992 - val_loss: 4.6605 - val_categorical_accuracy: 0.3346 - lr: 1.0000e-04\n",
            "Epoch 115/200\n",
            "39/39 [==============================] - 7s 191ms/step - loss: 0.7920 - categorical_accuracy: 0.9994 - val_loss: 4.6829 - val_categorical_accuracy: 0.3335 - lr: 1.0000e-04\n",
            "Epoch 116/200\n",
            "39/39 [==============================] - 7s 181ms/step - loss: 0.7886 - categorical_accuracy: 0.9998 - val_loss: 4.6745 - val_categorical_accuracy: 0.3361 - lr: 1.0000e-04\n",
            "Epoch 117/200\n",
            "39/39 [==============================] - 7s 179ms/step - loss: 0.7876 - categorical_accuracy: 1.0000 - val_loss: 4.6750 - val_categorical_accuracy: 0.3350 - lr: 1.0000e-04\n",
            "Epoch 118/200\n",
            "39/39 [==============================] - 7s 184ms/step - loss: 0.7820 - categorical_accuracy: 0.9998 - val_loss: 4.6761 - val_categorical_accuracy: 0.3334 - lr: 1.0000e-04\n",
            "Epoch 119/200\n",
            "39/39 [==============================] - 7s 188ms/step - loss: 0.7796 - categorical_accuracy: 0.9998 - val_loss: 4.6854 - val_categorical_accuracy: 0.3340 - lr: 1.0000e-04\n",
            "Epoch 120/200\n",
            "39/39 [==============================] - 7s 185ms/step - loss: 0.7786 - categorical_accuracy: 0.9998 - val_loss: 4.6901 - val_categorical_accuracy: 0.3353 - lr: 1.0000e-04\n",
            "Epoch 121/200\n",
            "39/39 [==============================] - 7s 183ms/step - loss: 0.7755 - categorical_accuracy: 1.0000 - val_loss: 4.6904 - val_categorical_accuracy: 0.3358 - lr: 1.0000e-05\n",
            "Epoch 122/200\n",
            "39/39 [==============================] - 7s 178ms/step - loss: 0.7757 - categorical_accuracy: 0.9998 - val_loss: 4.6917 - val_categorical_accuracy: 0.3367 - lr: 1.0000e-05\n",
            "Epoch 123/200\n",
            "39/39 [==============================] - 7s 186ms/step - loss: 0.7752 - categorical_accuracy: 1.0000 - val_loss: 4.6912 - val_categorical_accuracy: 0.3376 - lr: 1.0000e-05\n",
            "Epoch 124/200\n",
            "39/39 [==============================] - 7s 191ms/step - loss: 0.7745 - categorical_accuracy: 1.0000 - val_loss: 4.6921 - val_categorical_accuracy: 0.3377 - lr: 1.0000e-05\n",
            "Epoch 125/200\n",
            "39/39 [==============================] - 7s 178ms/step - loss: 0.7749 - categorical_accuracy: 1.0000 - val_loss: 4.6932 - val_categorical_accuracy: 0.3379 - lr: 1.0000e-05\n",
            "Epoch 126/200\n",
            "39/39 [==============================] - 7s 179ms/step - loss: 0.7767 - categorical_accuracy: 0.9994 - val_loss: 4.6922 - val_categorical_accuracy: 0.3388 - lr: 1.0000e-05\n",
            "Epoch 127/200\n",
            "39/39 [==============================] - 7s 179ms/step - loss: 0.7736 - categorical_accuracy: 1.0000 - val_loss: 4.6912 - val_categorical_accuracy: 0.3392 - lr: 1.0000e-05\n",
            "Epoch 128/200\n",
            "39/39 [==============================] - 7s 185ms/step - loss: 0.7741 - categorical_accuracy: 0.9998 - val_loss: 4.6925 - val_categorical_accuracy: 0.3385 - lr: 1.0000e-05\n",
            "Epoch 129/200\n",
            "39/39 [==============================] - 7s 188ms/step - loss: 0.7721 - categorical_accuracy: 1.0000 - val_loss: 4.6921 - val_categorical_accuracy: 0.3378 - lr: 1.0000e-05\n",
            "Epoch 130/200\n",
            "39/39 [==============================] - 7s 181ms/step - loss: 0.7737 - categorical_accuracy: 0.9994 - val_loss: 4.6927 - val_categorical_accuracy: 0.3378 - lr: 1.0000e-05\n",
            "Epoch 131/200\n",
            "39/39 [==============================] - 7s 179ms/step - loss: 0.7733 - categorical_accuracy: 0.9996 - val_loss: 4.6927 - val_categorical_accuracy: 0.3375 - lr: 1.0000e-05\n",
            "Epoch 132/200\n",
            "39/39 [==============================] - 7s 179ms/step - loss: 0.7727 - categorical_accuracy: 0.9998 - val_loss: 4.6921 - val_categorical_accuracy: 0.3380 - lr: 1.0000e-05\n",
            "Epoch 133/200\n",
            "39/39 [==============================] - 7s 194ms/step - loss: 0.7714 - categorical_accuracy: 1.0000 - val_loss: 4.6929 - val_categorical_accuracy: 0.3383 - lr: 1.0000e-05\n",
            "Epoch 134/200\n",
            "39/39 [==============================] - 7s 179ms/step - loss: 0.7713 - categorical_accuracy: 0.9998 - val_loss: 4.6906 - val_categorical_accuracy: 0.3386 - lr: 1.0000e-05\n",
            "Epoch 135/200\n",
            "39/39 [==============================] - 7s 178ms/step - loss: 0.7704 - categorical_accuracy: 1.0000 - val_loss: 4.6931 - val_categorical_accuracy: 0.3387 - lr: 1.0000e-05\n",
            "Epoch 136/200\n",
            "39/39 [==============================] - 7s 179ms/step - loss: 0.7708 - categorical_accuracy: 1.0000 - val_loss: 4.6952 - val_categorical_accuracy: 0.3389 - lr: 1.0000e-05\n",
            "Epoch 137/200\n",
            "39/39 [==============================] - 7s 182ms/step - loss: 0.7712 - categorical_accuracy: 1.0000 - val_loss: 4.6928 - val_categorical_accuracy: 0.3382 - lr: 1.0000e-05\n",
            "Epoch 138/200\n",
            "39/39 [==============================] - 7s 184ms/step - loss: 0.7708 - categorical_accuracy: 1.0000 - val_loss: 4.6927 - val_categorical_accuracy: 0.3386 - lr: 1.0000e-05\n",
            "Epoch 139/200\n",
            "39/39 [==============================] - 7s 188ms/step - loss: 0.7694 - categorical_accuracy: 1.0000 - val_loss: 4.6917 - val_categorical_accuracy: 0.3382 - lr: 1.0000e-05\n",
            "Epoch 140/200\n",
            "39/39 [==============================] - 7s 186ms/step - loss: 0.7701 - categorical_accuracy: 0.9996 - val_loss: 4.6915 - val_categorical_accuracy: 0.3384 - lr: 1.0000e-05\n",
            "Epoch 141/200\n",
            "39/39 [==============================] - 7s 188ms/step - loss: 0.7704 - categorical_accuracy: 0.9996 - val_loss: 4.6925 - val_categorical_accuracy: 0.3385 - lr: 1.0000e-05\n",
            "Epoch 142/200\n",
            "39/39 [==============================] - 7s 180ms/step - loss: 0.7695 - categorical_accuracy: 0.9994 - val_loss: 4.6929 - val_categorical_accuracy: 0.3379 - lr: 1.0000e-05\n",
            "Epoch 143/200\n",
            "39/39 [==============================] - 7s 181ms/step - loss: 0.7688 - categorical_accuracy: 1.0000 - val_loss: 4.6927 - val_categorical_accuracy: 0.3390 - lr: 1.0000e-05\n",
            "Epoch 144/200\n",
            "39/39 [==============================] - 7s 181ms/step - loss: 0.7691 - categorical_accuracy: 0.9996 - val_loss: 4.6932 - val_categorical_accuracy: 0.3371 - lr: 1.0000e-05\n",
            "Epoch 145/200\n",
            "39/39 [==============================] - 7s 183ms/step - loss: 0.7680 - categorical_accuracy: 0.9996 - val_loss: 4.6926 - val_categorical_accuracy: 0.3374 - lr: 1.0000e-05\n",
            "Epoch 146/200\n",
            "39/39 [==============================] - 7s 184ms/step - loss: 0.7668 - categorical_accuracy: 0.9998 - val_loss: 4.6899 - val_categorical_accuracy: 0.3372 - lr: 1.0000e-05\n",
            "Epoch 147/200\n",
            "39/39 [==============================] - 7s 182ms/step - loss: 0.7666 - categorical_accuracy: 0.9998 - val_loss: 4.6911 - val_categorical_accuracy: 0.3376 - lr: 1.0000e-05\n",
            "Epoch 148/200\n",
            "39/39 [==============================] - 7s 182ms/step - loss: 0.7667 - categorical_accuracy: 1.0000 - val_loss: 4.6923 - val_categorical_accuracy: 0.3374 - lr: 1.0000e-05\n",
            "Epoch 149/200\n",
            "39/39 [==============================] - 7s 185ms/step - loss: 0.7661 - categorical_accuracy: 1.0000 - val_loss: 4.6934 - val_categorical_accuracy: 0.3376 - lr: 1.0000e-05\n",
            "Epoch 150/200\n",
            "39/39 [==============================] - 7s 186ms/step - loss: 0.7652 - categorical_accuracy: 1.0000 - val_loss: 4.6932 - val_categorical_accuracy: 0.3371 - lr: 1.0000e-05\n",
            "Epoch 151/200\n",
            "39/39 [==============================] - 7s 184ms/step - loss: 0.7642 - categorical_accuracy: 1.0000 - val_loss: 4.6896 - val_categorical_accuracy: 0.3375 - lr: 1.0000e-05\n",
            "Epoch 152/200\n",
            "39/39 [==============================] - 7s 187ms/step - loss: 0.7640 - categorical_accuracy: 1.0000 - val_loss: 4.6925 - val_categorical_accuracy: 0.3376 - lr: 1.0000e-05\n",
            "Epoch 153/200\n",
            "39/39 [==============================] - 7s 177ms/step - loss: 0.7640 - categorical_accuracy: 0.9998 - val_loss: 4.6924 - val_categorical_accuracy: 0.3387 - lr: 1.0000e-05\n",
            "Epoch 154/200\n",
            "39/39 [==============================] - 7s 188ms/step - loss: 0.7640 - categorical_accuracy: 0.9994 - val_loss: 4.6905 - val_categorical_accuracy: 0.3379 - lr: 1.0000e-05\n",
            "Epoch 155/200\n",
            "39/39 [==============================] - 7s 180ms/step - loss: 0.7626 - categorical_accuracy: 1.0000 - val_loss: 4.6910 - val_categorical_accuracy: 0.3380 - lr: 1.0000e-05\n",
            "Epoch 156/200\n",
            "39/39 [==============================] - 7s 181ms/step - loss: 0.7625 - categorical_accuracy: 1.0000 - val_loss: 4.6900 - val_categorical_accuracy: 0.3378 - lr: 1.0000e-05\n",
            "Epoch 157/200\n",
            "39/39 [==============================] - 7s 179ms/step - loss: 0.7621 - categorical_accuracy: 0.9996 - val_loss: 4.6925 - val_categorical_accuracy: 0.3376 - lr: 1.0000e-05\n",
            "Epoch 158/200\n",
            "39/39 [==============================] - 7s 177ms/step - loss: 0.7617 - categorical_accuracy: 1.0000 - val_loss: 4.6932 - val_categorical_accuracy: 0.3382 - lr: 1.0000e-05\n",
            "Epoch 159/200\n",
            "39/39 [==============================] - 7s 182ms/step - loss: 0.7604 - categorical_accuracy: 1.0000 - val_loss: 4.6940 - val_categorical_accuracy: 0.3381 - lr: 1.0000e-05\n",
            "Epoch 160/200\n",
            "39/39 [==============================] - 7s 179ms/step - loss: 0.7601 - categorical_accuracy: 1.0000 - val_loss: 4.6947 - val_categorical_accuracy: 0.3387 - lr: 1.0000e-05\n",
            "Epoch 161/200\n",
            "39/39 [==============================] - 7s 179ms/step - loss: 0.7599 - categorical_accuracy: 0.9998 - val_loss: 4.6960 - val_categorical_accuracy: 0.3390 - lr: 1.0000e-06\n",
            "Epoch 162/200\n",
            "39/39 [==============================] - 7s 185ms/step - loss: 0.7605 - categorical_accuracy: 0.9998 - val_loss: 4.6974 - val_categorical_accuracy: 0.3384 - lr: 1.0000e-06\n",
            "Epoch 163/200\n",
            "39/39 [==============================] - 7s 180ms/step - loss: 0.7598 - categorical_accuracy: 1.0000 - val_loss: 4.6980 - val_categorical_accuracy: 0.3386 - lr: 1.0000e-06\n",
            "Epoch 164/200\n",
            "39/39 [==============================] - 7s 180ms/step - loss: 0.7595 - categorical_accuracy: 1.0000 - val_loss: 4.6988 - val_categorical_accuracy: 0.3390 - lr: 1.0000e-06\n",
            "Epoch 165/200\n",
            "39/39 [==============================] - 7s 178ms/step - loss: 0.7609 - categorical_accuracy: 1.0000 - val_loss: 4.6992 - val_categorical_accuracy: 0.3388 - lr: 1.0000e-06\n",
            "Epoch 166/200\n",
            "39/39 [==============================] - 7s 178ms/step - loss: 0.7604 - categorical_accuracy: 0.9996 - val_loss: 4.6990 - val_categorical_accuracy: 0.3390 - lr: 1.0000e-06\n",
            "Epoch 167/200\n",
            "39/39 [==============================] - 7s 177ms/step - loss: 0.7598 - categorical_accuracy: 0.9996 - val_loss: 4.6988 - val_categorical_accuracy: 0.3385 - lr: 1.0000e-06\n",
            "Epoch 168/200\n",
            "39/39 [==============================] - 7s 177ms/step - loss: 0.7596 - categorical_accuracy: 1.0000 - val_loss: 4.6995 - val_categorical_accuracy: 0.3384 - lr: 1.0000e-06\n",
            "Epoch 169/200\n",
            "39/39 [==============================] - 7s 177ms/step - loss: 0.7592 - categorical_accuracy: 1.0000 - val_loss: 4.6989 - val_categorical_accuracy: 0.3383 - lr: 1.0000e-06\n",
            "Epoch 170/200\n",
            "39/39 [==============================] - 7s 185ms/step - loss: 0.7604 - categorical_accuracy: 0.9998 - val_loss: 4.7001 - val_categorical_accuracy: 0.3386 - lr: 1.0000e-06\n",
            "Epoch 171/200\n",
            "39/39 [==============================] - 7s 179ms/step - loss: 0.7596 - categorical_accuracy: 1.0000 - val_loss: 4.7003 - val_categorical_accuracy: 0.3386 - lr: 1.0000e-06\n",
            "Epoch 172/200\n",
            "39/39 [==============================] - 7s 181ms/step - loss: 0.7600 - categorical_accuracy: 0.9994 - val_loss: 4.7005 - val_categorical_accuracy: 0.3388 - lr: 1.0000e-06\n",
            "Epoch 173/200\n",
            "39/39 [==============================] - 7s 186ms/step - loss: 0.7602 - categorical_accuracy: 0.9998 - val_loss: 4.7019 - val_categorical_accuracy: 0.3387 - lr: 1.0000e-06\n",
            "Epoch 174/200\n",
            "39/39 [==============================] - 7s 187ms/step - loss: 0.7589 - categorical_accuracy: 1.0000 - val_loss: 4.7008 - val_categorical_accuracy: 0.3387 - lr: 1.0000e-06\n",
            "Epoch 175/200\n",
            "39/39 [==============================] - 7s 192ms/step - loss: 0.7600 - categorical_accuracy: 1.0000 - val_loss: 4.7002 - val_categorical_accuracy: 0.3385 - lr: 1.0000e-06\n",
            "Epoch 176/200\n",
            "39/39 [==============================] - 8s 194ms/step - loss: 0.7597 - categorical_accuracy: 0.9998 - val_loss: 4.6994 - val_categorical_accuracy: 0.3387 - lr: 1.0000e-06\n",
            "Epoch 177/200\n",
            "39/39 [==============================] - 7s 182ms/step - loss: 0.7587 - categorical_accuracy: 1.0000 - val_loss: 4.7011 - val_categorical_accuracy: 0.3387 - lr: 1.0000e-06\n",
            "Epoch 178/200\n",
            "39/39 [==============================] - 7s 180ms/step - loss: 0.7606 - categorical_accuracy: 0.9994 - val_loss: 4.7011 - val_categorical_accuracy: 0.3385 - lr: 1.0000e-06\n",
            "Epoch 179/200\n",
            "39/39 [==============================] - 7s 180ms/step - loss: 0.7588 - categorical_accuracy: 0.9998 - val_loss: 4.7013 - val_categorical_accuracy: 0.3390 - lr: 1.0000e-06\n",
            "Epoch 180/200\n",
            "39/39 [==============================] - 7s 182ms/step - loss: 0.7590 - categorical_accuracy: 0.9998 - val_loss: 4.7007 - val_categorical_accuracy: 0.3391 - lr: 1.0000e-06\n",
            "Epoch 181/200\n",
            "39/39 [==============================] - 7s 184ms/step - loss: 0.7584 - categorical_accuracy: 0.9998 - val_loss: 4.6998 - val_categorical_accuracy: 0.3385 - lr: 5.0000e-07\n",
            "Epoch 182/200\n",
            "39/39 [==============================] - 8s 197ms/step - loss: 0.7583 - categorical_accuracy: 1.0000 - val_loss: 4.6991 - val_categorical_accuracy: 0.3382 - lr: 5.0000e-07\n",
            "Epoch 183/200\n",
            "39/39 [==============================] - 8s 199ms/step - loss: 0.7594 - categorical_accuracy: 0.9996 - val_loss: 4.6986 - val_categorical_accuracy: 0.3385 - lr: 5.0000e-07\n",
            "Epoch 184/200\n",
            "39/39 [==============================] - 7s 194ms/step - loss: 0.7588 - categorical_accuracy: 1.0000 - val_loss: 4.6994 - val_categorical_accuracy: 0.3381 - lr: 5.0000e-07\n",
            "Epoch 185/200\n",
            "39/39 [==============================] - 7s 191ms/step - loss: 0.7589 - categorical_accuracy: 0.9996 - val_loss: 4.6988 - val_categorical_accuracy: 0.3379 - lr: 5.0000e-07\n",
            "Epoch 186/200\n",
            "39/39 [==============================] - 7s 189ms/step - loss: 0.7596 - categorical_accuracy: 0.9994 - val_loss: 4.6991 - val_categorical_accuracy: 0.3384 - lr: 5.0000e-07\n",
            "Epoch 187/200\n",
            "39/39 [==============================] - 7s 175ms/step - loss: 0.7585 - categorical_accuracy: 1.0000 - val_loss: 4.6989 - val_categorical_accuracy: 0.3388 - lr: 5.0000e-07\n",
            "Epoch 188/200\n",
            "39/39 [==============================] - 7s 180ms/step - loss: 0.7589 - categorical_accuracy: 0.9996 - val_loss: 4.6988 - val_categorical_accuracy: 0.3387 - lr: 5.0000e-07\n",
            "Epoch 189/200\n",
            "39/39 [==============================] - 8s 199ms/step - loss: 0.7576 - categorical_accuracy: 1.0000 - val_loss: 4.6984 - val_categorical_accuracy: 0.3387 - lr: 5.0000e-07\n",
            "Epoch 190/200\n",
            "39/39 [==============================] - 7s 184ms/step - loss: 0.7581 - categorical_accuracy: 1.0000 - val_loss: 4.6985 - val_categorical_accuracy: 0.3390 - lr: 5.0000e-07\n",
            "Epoch 191/200\n",
            "39/39 [==============================] - 7s 191ms/step - loss: 0.7583 - categorical_accuracy: 1.0000 - val_loss: 4.6994 - val_categorical_accuracy: 0.3388 - lr: 5.0000e-07\n",
            "Epoch 192/200\n",
            "39/39 [==============================] - 7s 181ms/step - loss: 0.7583 - categorical_accuracy: 1.0000 - val_loss: 4.6980 - val_categorical_accuracy: 0.3391 - lr: 5.0000e-07\n",
            "Epoch 193/200\n",
            "39/39 [==============================] - 7s 179ms/step - loss: 0.7584 - categorical_accuracy: 0.9998 - val_loss: 4.6993 - val_categorical_accuracy: 0.3387 - lr: 5.0000e-07\n",
            "Epoch 194/200\n",
            "39/39 [==============================] - 7s 186ms/step - loss: 0.7582 - categorical_accuracy: 0.9994 - val_loss: 4.6986 - val_categorical_accuracy: 0.3382 - lr: 5.0000e-07\n",
            "Epoch 195/200\n",
            "39/39 [==============================] - 7s 185ms/step - loss: 0.7589 - categorical_accuracy: 0.9998 - val_loss: 4.6992 - val_categorical_accuracy: 0.3389 - lr: 5.0000e-07\n",
            "Epoch 196/200\n",
            "39/39 [==============================] - 7s 179ms/step - loss: 0.7581 - categorical_accuracy: 0.9998 - val_loss: 4.6996 - val_categorical_accuracy: 0.3386 - lr: 5.0000e-07\n",
            "Epoch 197/200\n",
            "39/39 [==============================] - 7s 178ms/step - loss: 0.7588 - categorical_accuracy: 0.9998 - val_loss: 4.7002 - val_categorical_accuracy: 0.3387 - lr: 5.0000e-07\n",
            "Epoch 198/200\n",
            "39/39 [==============================] - 7s 185ms/step - loss: 0.7580 - categorical_accuracy: 1.0000 - val_loss: 4.6992 - val_categorical_accuracy: 0.3391 - lr: 5.0000e-07\n",
            "Epoch 199/200\n",
            "39/39 [==============================] - 7s 178ms/step - loss: 0.7582 - categorical_accuracy: 0.9998 - val_loss: 4.6994 - val_categorical_accuracy: 0.3387 - lr: 5.0000e-07\n",
            "Epoch 200/200\n",
            "39/39 [==============================] - 7s 184ms/step - loss: 0.7582 - categorical_accuracy: 1.0000 - val_loss: 4.6992 - val_categorical_accuracy: 0.3384 - lr: 5.0000e-07\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Evaluate the trained teacher on the held-out test split.\n",
        "score = teacher_model.evaluate(data_split.test.examples,\n",
        "                                data_split.test.labels)\n",
        "\n",
        "# With compiled metrics `evaluate` returns [loss, metric, ...]; without them\n",
        "# it returns a bare scalar loss, from which no accuracy can be derived.\n",
        "if isinstance(score, (list, tuple, np.ndarray)):\n",
        "  teacher_accuracy = score[1] * 100.0\n",
        "else:\n",
        "  raise ValueError(\n",
        "      'teacher_model.evaluate() returned only a scalar loss; compile the '\n",
        "      'model with an accuracy metric to report the teacher accuracy.')\n",
        "print(\"The teacher's accuracy is:\", teacher_accuracy)"
      ],
      "metadata": {
        "id": "t57Im5BNojtX",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1674703841801,
          "user_tz": 300,
          "elapsed": 0,
          "user": {
            "displayName": "XXXXX",
            "userId": "00000"
          }
        },
        "outputId": "a438f59e-ee49-4e35-a784-2951895f1bd7"
      },
      "execution_count": 9,
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "313/313 [==============================] - 4s 12ms/step - loss: 4.6992 - categorical_accuracy: 0.3384\n",
            "The teacher's accuracy is: 33.84000062942505\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Estimate Teacher's Accuracy (Isotonic Regression)\n",
        "\n",
        "Use either the (normalized) entropy or the (normalized) margin of the teacher's predictions to estimate the teacher model's accuracy. We assume that the probability that a label from the teacher\n",
        "is correct is monotone with respect to the chosen confidence measure, and we enforce this monotonicity\n",
        "in our regression problem via isotonic regression."
      ],
      "metadata": {
        "id": "P-sNmGYSPWIZ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import dataclasses\n",
        "from typing import Callable, Optional\n",
        "\n",
        "from sklearn.isotonic import IsotonicRegression\n",
        "from sklearn.neighbors import KNeighborsRegressor\n",
        "\n",
        "\n",
        "@dataclasses.dataclass\n",
        "class Advice:\n",
        "  \"\"\"Per-split advice tensors used in distillation.\n",
        "\n",
        "  Every field defaults to None, so an instance may carry advice for only\n",
        "  some of the splits.\n",
        "\n",
        "  Attributes:\n",
        "    train: A tensor containing the advice for the training data.\n",
        "    validation: A tensor containing the advice for the validation data.\n",
        "    pretraining: A tensor containing the advice for the pretraining data.\n",
        "    test: A tensor containing the advice for the test data.\n",
        "  \"\"\"\n",
        "  train: Optional[tf.Tensor] = None\n",
        "  validation: Optional[tf.Tensor] = None\n",
        "  pretraining: Optional[tf.Tensor] = None\n",
        "  test: Optional[tf.Tensor] = None\n",
        "\n",
        "\n",
        "def margin(x: tf.Tensor,\n",
        "           with_logits: Optional[bool] = False,\n",
        "           normalize: Optional[bool] = False) -> tf.Tensor:\n",
        "  \"\"\"Returns the gap between the two largest entries of each row of `x`.\n",
        "\n",
        "  Args:\n",
        "    x: Class probabilities, or logits when `with_logits` is True.\n",
        "    with_logits: If True, `x` is passed through a softmax first.\n",
        "    normalize: If True, every margin is divided by the mean margin.\n",
        "\n",
        "  Returns:\n",
        "    The top-1 minus top-2 value per row, with axis 1 kept at size one.\n",
        "  \"\"\"\n",
        "\n",
        "  probs = tf.nn.softmax(x) if with_logits else x\n",
        "\n",
        "  # top_k returns values in descending order: column 0 holds the largest.\n",
        "  top_two, _ = tf.math.top_k(probs, k=2)\n",
        "  gap = top_two[:, 0:1] - top_two[:, 1:2]\n",
        "\n",
        "  if not normalize:\n",
        "    return gap\n",
        "  return gap / tf.math.reduce_mean(gap)\n",
        "\n",
        "\n",
        "def entropy(x: tf.Tensor,\n",
        "            with_logits: Optional[bool] = False,\n",
        "            normalize: Optional[bool] = False) -> tf.Tensor:\n",
        "  \"\"\"Returns the entropy of each distribution in `x`.\n",
        "\n",
        "  Args:\n",
        "    x: Class probabilities, or logits when `with_logits` is True.\n",
        "    with_logits: If True, `x` is passed through a softmax first.\n",
        "    normalize: If True, every entropy is divided by the mean entropy.\n",
        "\n",
        "  Returns:\n",
        "    The cross-entropy of each distribution with itself, i.e. its entropy.\n",
        "  \"\"\"\n",
        "\n",
        "  probs = tf.nn.softmax(x) if with_logits else x\n",
        "\n",
        "  # The cross-entropy of a distribution with itself is its entropy.\n",
        "  ent = tf.keras.losses.categorical_crossentropy(probs, probs)\n",
        "\n",
        "  if not normalize:\n",
        "    return ent\n",
        "  return ent / tf.math.reduce_mean(ent)\n",
        "\n",
        "\n",
        "def disagreement(q: tf.Tensor, p: tf.Tensor) -> tf.Tensor:\n",
        "  \"\"\"Returns a 0/1 tensor marking where the top-1 classes of q and p agree.\n",
        "\n",
        "  NOTE: despite its name, this function returns 1.0 where the argmax of `q`\n",
        "  equals the argmax of `p` and 0.0 where they differ, i.e. it measures\n",
        "  agreement. teacher_accuracy_advice_isotonic() relies on this to obtain the\n",
        "  pointwise accuracy of the teacher, so the behavior and the name are kept.\n",
        "\n",
        "  Args:\n",
        "    q: Per-example class scores; the argmax is taken along axis 1.\n",
        "    p: Per-example class scores or labels; the argmax is taken along axis 1.\n",
        "\n",
        "  Returns:\n",
        "    A float32 tensor with one entry per example.\n",
        "  \"\"\"\n",
        "\n",
        "  top_1_agreement = tf.math.equal(tf.argmax(q, 1), tf.argmax(p, 1))\n",
        "  return tf.cast(top_1_agreement, tf.float32)\n",
        "\n",
        "\n",
        "def _add_default_advice(data_split: DataSplit,\n",
        "                        train_advice: tf.Tensor) -> Advice:\n",
        "  \"\"\"Adds the default advice for pretraining, validation, and test data.\n",
        "\n",
        "  Args:\n",
        "    data_split: Instance of DataSplit. Its `dataset_a` (pretraining data) and\n",
        "      `validation` members are each skipped when they are None.\n",
        "    train_advice: Advice for the training data, stored unchanged.\n",
        "\n",
        "  Returns:\n",
        "    Instance of Advice with all-ones advice for the validation and\n",
        "    pretraining data (None for a split that is missing) and all-zeros advice\n",
        "    for the test data.\n",
        "  \"\"\"\n",
        "\n",
        "  pretraining = data_split.dataset_a\n",
        "  validation = data_split.validation\n",
        "  test = data_split.test\n",
        "\n",
        "  validation_advice = None\n",
        "  pretraining_advice = None\n",
        "\n",
        "  # NOTE: no trailing comma after tf.ones(...); a stray comma would wrap the\n",
        "  # tensor in a 1-tuple before it is flattened.\n",
        "  if validation is not None:\n",
        "    validation_advice = tf.reshape(tf.ones(validation.size), [-1])\n",
        "  if pretraining is not None:\n",
        "    pretraining_advice = tf.reshape(tf.ones(pretraining.size), [-1])\n",
        "\n",
        "  test_advice = tf.reshape(tf.zeros(test.size), [-1])\n",
        "\n",
        "  return Advice(\n",
        "      train=train_advice,\n",
        "      validation=validation_advice,\n",
        "      pretraining=pretraining_advice,\n",
        "      test=test_advice)\n",
        "\n",
        "\n",
        "def vanilla_advice(data_split: DataSplit) -> Advice:\n",
        "  \"\"\"Builds the advice for vanilla distillation (used for testing).\n",
        "\n",
        "  Every training example (dataset_b) receives an advice value of one.\n",
        "\n",
        "  Args:\n",
        "    data_split: Instance of DataSplit containing the\n",
        "      training, validation, and test data.\n",
        "\n",
        "  Returns:\n",
        "    Instance of Advice containing advice vectors.\n",
        "  \"\"\"\n",
        "\n",
        "  all_ones = tf.ones(data_split.dataset_b.size)\n",
        "  return _add_default_advice(data_split, all_ones)\n",
        "  \n",
        "\n",
        "def teacher_accuracy_advice_isotonic(\n",
        "    teacher: tf.keras.Model,\n",
        "    confidence: Callable[[tf.Tensor], tf.Tensor],\n",
        "    data_split: DataSplit) -> Advice:\n",
        "  \"\"\"Creates teacher-accuracy advice for the training dataset (dataset_b).\n",
        "\n",
        "  It uses a confidence measure (e.g., margin, entropy) for the teacher\n",
        "  prediction over a validation dataset and isotonic regression to learn\n",
        "  the mapping from confidence to accuracy.\n",
        "\n",
        "  The direction of the monotone fit (increasing or decreasing) is inferred\n",
        "  from the validation data, so the measure may grow with the teacher's\n",
        "  accuracy (e.g., margin) or shrink with it (e.g., entropy).\n",
        "\n",
        "  Args:\n",
        "    teacher: Instance of tf.keras.Model: the teacher model.\n",
        "    confidence: A function that maps soft labels to one confidence value per\n",
        "      example.\n",
        "    data_split: Instance of DataSplit containing the\n",
        "      training, validation, and test data.\n",
        "\n",
        "  Returns:\n",
        "    Instance of Advice containing advice vectors. The training advice is the\n",
        "    output of IsotonicRegression.predict() as returned by scikit-learn; it is\n",
        "    not converted to a tf.Tensor.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: If `data_split` has no validation data to fit the regressor.\n",
        "  \"\"\"\n",
        "\n",
        "  train = data_split.dataset_b\n",
        "  validation = data_split.validation\n",
        "\n",
        "  if validation is None:\n",
        "    raise ValueError(\n",
        "        'Isotonic teacher-accuracy advice requires a validation split.')\n",
        "\n",
        "  # These are the lower bound/upper bounds used in isotonic regression.\n",
        "  # All values predicted should be in the range [min_threshold, max_threshold].\n",
        "  # Since the outputs correspond to the probability that the teacher model is\n",
        "  # correct the range [0.5, 1.] is reasonable.\n",
        "  min_threshold = .5\n",
        "  max_threshold = 1.\n",
        "\n",
        "  # Compute teacher predictions on the validation examples.\n",
        "  teacher_predictions_validation = teacher.predict(validation.examples)\n",
        "\n",
        "  # Compute the teacher confidence on the validation examples (flattened).\n",
        "  covariate = tf.reshape(confidence(teacher_predictions_validation), [-1])\n",
        "\n",
        "  # Compute the (pointwise) accuracy of the teacher on the validation data:\n",
        "  # disagreement() yields 1.0 where the teacher's top-1 class matches the label.\n",
        "  response = tf.reshape(\n",
        "      disagreement(teacher_predictions_validation, validation.labels), [-1])\n",
        "\n",
        "  # Create an isotonic regressor that maps teacher confidence to teacher\n",
        "  # accuracy. increasing='auto' lets scikit-learn pick the direction of the\n",
        "  # fit from the data: accuracy grows with margin but shrinks with entropy, so\n",
        "  # a hard-coded increasing fit would be wrong for entropy.\n",
        "  iso = IsotonicRegression(\n",
        "      y_min=min_threshold, y_max=max_threshold, increasing='auto',\n",
        "      out_of_bounds='clip')\n",
        "  iso.fit(covariate, response)\n",
        "\n",
        "  # Compute the teacher soft labels on the training dataset.\n",
        "  teacher_predictions_training = teacher.predict(train.examples)\n",
        "\n",
        "  # Compute the teacher confidence on the training dataset.\n",
        "  teacher_confidence = tf.reshape(\n",
        "      confidence(teacher_predictions_training), [-1])\n",
        "\n",
        "  # Use the isotonic regressor to predict the accuracy of the teacher on the\n",
        "  # training examples.\n",
        "  main_advice = iso.predict(teacher_confidence)\n",
        "\n",
        "  return _add_default_advice(data_split, main_advice)\n",
        "\n",
        "\n",
        "def teacher_accuracy_advice_from_entropy(\n",
        "    teacher: tf.keras.Model,\n",
        "    data_split: DataSplit,\n",
        "    normalize: Optional[bool] = False) -> Advice:\n",
        "  \"\"\"Builds teacher-accuracy advice for dataset_b from prediction entropy.\n",
        "\n",
        "  The entropy of the teacher's predictions is the uncertainty measure handed\n",
        "  to teacher_accuracy_advice_isotonic(), which fits an isotonic regression on\n",
        "  the validation dataset.\n",
        "\n",
        "  Args:\n",
        "    teacher: Instance of tf.keras.Model: the teacher model.\n",
        "    data_split: Instance of DataSplit containing the\n",
        "      training, validation, and test data.\n",
        "    normalize: True if the entropy should be normalized by the average\n",
        "      entropy over the dataset.\n",
        "\n",
        "  Returns:\n",
        "    Instance of Advice containing advice vectors.\n",
        "  \"\"\"\n",
        "\n",
        "  def entropy_of_predictions(predictions):\n",
        "    # The teacher outputs are treated as probabilities, not logits.\n",
        "    return entropy(predictions, with_logits=False, normalize=normalize)\n",
        "\n",
        "  return teacher_accuracy_advice_isotonic(\n",
        "      teacher, entropy_of_predictions, data_split)\n",
        "\n",
        "\n",
        "def teacher_accuracy_advice_from_margin(\n",
        "    teacher: tf.keras.Model,\n",
        "    data_split: DataSplit,\n",
        "    normalize: Optional[bool] = False) -> Advice:\n",
        "  \"\"\"Builds teacher-accuracy advice for dataset_b from prediction margin.\n",
        "\n",
        "  The margin of the teacher's predictions is the confidence measure handed\n",
        "  to teacher_accuracy_advice_isotonic(), which fits an isotonic regression on\n",
        "  the validation dataset.\n",
        "\n",
        "  Args:\n",
        "    teacher: Instance of tf.keras.Model: the teacher model.\n",
        "    data_split: Instance of DataSplit containing the\n",
        "      training, validation, and test data.\n",
        "    normalize: True if margin should be normalized by the average margin\n",
        "      over the dataset.\n",
        "\n",
        "  Returns:\n",
        "    Instance of Advice containing advice vectors.\n",
        "  \"\"\"\n",
        "\n",
        "  def margin_of_predictions(predictions):\n",
        "    # The teacher outputs are treated as probabilities, not logits.\n",
        "    return margin(predictions, with_logits=False, normalize=normalize)\n",
        "\n",
        "  return teacher_accuracy_advice_isotonic(\n",
        "      teacher, margin_of_predictions, data_split)\n"
      ],
      "metadata": {
        "id": "uC03_ES3PTMr",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1674703841789,
          "user_tz": 300,
          "elapsed": 4,
          "user": {
            "displayName": "XXXX",
            "userId": "000"
          }
        }
      },
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Distillation Function\n",
        "\n",
        "`distill()` is a simple wrapper function that takes the student model and the teacher's predictions and trains the student on them. It can optionally take\n",
        "both advice and weights to be used during distillation."
      ],
      "metadata": {
        "id": "PaD7DKb-T4HQ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from typing import TypedDict\n",
        "\n",
        "# A simple dict for the various training parameters of the distill() function,\n",
        "class TrainingParameters(TypedDict):\n",
        "  with_soft_labels: Optional[bool]\n",
        "  epochs: Optional[int]\n",
        "  batch_size: Optional[int]\n",
        "  data_augmentation: Optional[str]\n",
        "\n",
        "def _as_column(values):\n",
        "  \"\"\"Returns `values` reshaped to shape (-1, 1) if rank-1, else unchanged.\"\"\"\n",
        "  if len(tf.shape(values)) == 1:\n",
        "    return tf.reshape(values, shape=(-1, 1))\n",
        "  return values\n",
        "\n",
        "\n",
        "def distill(model: tf.keras.Model,\n",
        "            teacher_predictions: tf.Tensor,\n",
        "            data_split: DataSplit,\n",
        "            params: Optional[TrainingParameters],\n",
        "            advice_data: Optional[Advice] = None,\n",
        "            weight_data: Optional[Advice] = None\n",
        "            ) -> 'Tuple[float, np.ndarray]':\n",
        "  \"\"\"A function that implements Distillation.\n",
        "\n",
        "  The student is trained on data_split.dataset_b, labelled with the teacher's\n",
        "  predictions. dataset_a and the validation split are added to the training\n",
        "  data, with their own labels, whenever they exist and are marked trainable.\n",
        "  The test split is passed to train_model() as evaluation data.\n",
        "\n",
        "  Args:\n",
        "    model: Instance of tf.keras.Model: the student to be trained.\n",
        "    teacher_predictions: Instance of tf.Tensor containing teacher's predictions\n",
        "      for the examples of data_split.dataset_b.\n",
        "    data_split: Instance of DataSplit.\n",
        "    params: Instance of TrainingParameters typed dict, or None to use the\n",
        "      default of every parameter. It may contain the following parameters:\n",
        "      epochs - The number of training epochs (default 90).\n",
        "      batch_size - The batch size used for training (default 128).\n",
        "      with_soft_labels - If true use the soft-labels provided by the teacher,\n",
        "        otherwise the one-hot encoding of their argmax (default True).\n",
        "      data_augmentation - One of None, 'offline', 'online' (default None).\n",
        "    advice_data: Instance of Advice to be used as advice feature. It is\n",
        "      appended as extra column(s) to the training and test labels.\n",
        "    weight_data: Instance of Advice to be used as weights for the training data.\n",
        "\n",
        "  Returns:\n",
        "    A tuple (result, trajectory). trajectory holds accuracies in percent: the\n",
        "    score of the model on the test data before training, followed by the\n",
        "    'val_categorical_accuracy' recorded after every epoch. result is the\n",
        "    maximum of trajectory.\n",
        "  \"\"\"\n",
        "\n",
        "  pretraining = data_split.dataset_a\n",
        "  train = data_split.dataset_b\n",
        "  validation = data_split.validation\n",
        "  test = data_split.test\n",
        "\n",
        "  # Optional splits take part in training only if present and trainable.\n",
        "  use_pretraining = pretraining is not None and pretraining.trainable\n",
        "  use_validation = validation is not None and validation.trainable\n",
        "\n",
        "  # A missing params dict means that every parameter takes its default value.\n",
        "  if params is None:\n",
        "    params = {}\n",
        "  batch_size = params.get('batch_size', 128)\n",
        "  epochs = params.get('epochs', 90)\n",
        "  with_soft_labels = params.get('with_soft_labels', True)\n",
        "  data_augmentation = params.get('data_augmentation', None)\n",
        "\n",
        "  # We are doing soft-distillation by default.\n",
        "  if with_soft_labels:\n",
        "    targets = teacher_predictions\n",
        "  else:\n",
        "    # Hard labels: one-hot encoding of the teacher's most likely class.\n",
        "    arg_max_indices = tf.argmax(teacher_predictions, -1)\n",
        "    targets = tf.keras.utils.to_categorical(arg_max_indices,\n",
        "                                            len(teacher_predictions[0]))\n",
        "\n",
        "  train_data = train.examples\n",
        "  train_labels = tf.reshape(targets, tf.shape(train.labels))\n",
        "\n",
        "  test_data = test.examples\n",
        "  test_labels = test.labels\n",
        "\n",
        "  # Include pretraining data in the training dataset (placed first).\n",
        "  if use_pretraining:\n",
        "    train_data = tf.concat((pretraining.examples, train_data), axis=0)\n",
        "    train_labels = tf.concat((pretraining.labels, train_labels), axis=0)\n",
        "\n",
        "  # Include validation data in the training dataset (placed last).\n",
        "  if use_validation:\n",
        "    train_data = tf.concat((train_data, validation.examples), axis=0)\n",
        "    train_labels = tf.concat((train_labels, validation.labels), axis=0)\n",
        "\n",
        "  if advice_data is not None:\n",
        "    # Assemble the advice in the same order as the training examples.\n",
        "    full_advice = advice_data.train\n",
        "    if use_pretraining:\n",
        "      full_advice = tf.concat((advice_data.pretraining, full_advice), axis=0)\n",
        "    if use_validation:\n",
        "      full_advice = tf.concat((full_advice, advice_data.validation), axis=0)\n",
        "\n",
        "    # The advice travels as extra column(s) appended to the labels.\n",
        "    train_labels = tf.concat((train_labels, _as_column(full_advice)), axis=1)\n",
        "\n",
        "    # We need dummy test advice in order for test data to be consistent\n",
        "    # with the training data (we should always use zero for this).\n",
        "    test_labels = tf.concat((test.labels, _as_column(advice_data.test)),\n",
        "                            axis=1)\n",
        "\n",
        "  # Use no weights by default.\n",
        "  weights = None\n",
        "  if weight_data is not None:\n",
        "    # Assemble the weights in the same order as the training examples.\n",
        "    full_weights = weight_data.train\n",
        "    if use_pretraining:\n",
        "      full_weights = tf.concat((weight_data.pretraining, full_weights), axis=0)\n",
        "    if use_validation:\n",
        "      full_weights = tf.concat((full_weights, weight_data.validation), axis=0)\n",
        "\n",
        "    weights = tf.reshape(full_weights, [-1])\n",
        "\n",
        "  # Score of the untrained model; it becomes the first trajectory entry.\n",
        "  score = model.evaluate(test_data, test_labels)\n",
        "  if isinstance(score, (list, tuple, np.ndarray)):\n",
        "    # With several values returned, index 1 is taken as the accuracy\n",
        "    # (presumably index 0 is the loss -- depends on how the model is compiled).\n",
        "    score = score[1]\n",
        "\n",
        "  history = train_model(\n",
        "      model,\n",
        "      train_data,\n",
        "      train_labels,\n",
        "      test_data,\n",
        "      test_labels,\n",
        "      data_augmentation=data_augmentation,\n",
        "      weights=weights,\n",
        "      epochs=epochs,\n",
        "      batch_size=batch_size)\n",
        "\n",
        "  # Convert the accuracies to percent and report the best one.\n",
        "  trajectory = np.array(history.history['val_categorical_accuracy'])\n",
        "  trajectory = np.insert(trajectory, 0, score)\n",
        "  trajectory *= 100.0\n",
        "  result = np.max(trajectory)\n",
        "\n",
        "  return (result, trajectory)\n"
      ],
      "metadata": {
        "id": "4ru3NlfFTz-T",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1674703841789,
          "user_tz": 300,
          "elapsed": 4,
          "user": {
            "displayName": "XXXXX",
            "userId": "000000"
          }
        }
      },
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Randomize dataset B and create validation data."
      ],
      "metadata": {
        "id": "GcAbd8Woyngd"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Create a small validation dataset out of dataset_b.\n",
        "# NOTE(review): this cell reassigns data_split.dataset_b below, so running it\n",
        "# a second time without rebuilding data_split splits the already reduced\n",
        "# dataset again.\n",
        "\n",
        "# Optionally shuffle dataset_b first, seeded with _SEED (presumably in place,\n",
        "# since the return value of shuffle() is not used).\n",
        "if _RANDOMIZE_DATASET: \n",
        "  data_split.dataset_b.shuffle(_SEED)\n",
        "\n",
        "# Round the size of validation to the nearest multiple of _BATCH_SIZE\n",
        "# (Python's round() resolves exact .5 ties to the even integer).\n",
        "rounded_validation_dataset_size = _BATCH_SIZE * round(\n",
        "    _VALIDATION_DATASET_SIZE / _BATCH_SIZE)\n",
        "\n",
        "# Split dataset_b into the remaining training data and the validation data\n",
        "# (assumes the second argument of split() is the validation size -- confirm).\n",
        "train, validation = split(data_split.dataset_b, rounded_validation_dataset_size)\n",
        "\n",
        "train.trainable = True\n",
        "\n",
        "# The student model is trained on the validation dataset as well only if\n",
        "# _TRAIN_ON_VALIDATION is set.\n",
        "validation.trainable = _TRAIN_ON_VALIDATION\n",
        "\n",
        "# Update the data_split datasets with the new training and validation data.\n",
        "data_split.dataset_b = train\n",
        "data_split.validation = validation\n",
        "\n",
        "# Predict the labels of the (reduced) dataset_b using the teacher model; these\n",
        "# predictions are later passed to distill() as teacher_predictions.\n",
        "teacher_predictions = teacher_model.predict(data_split.dataset_b.examples)"
      ],
      "metadata": {
        "id": "RtjaN3i7yk2Q",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1674703841802,
          "user_tz": 300,
          "elapsed": 0,
          "user": {
            "displayName": "XXXX",
            "userId": "00000"
          }
        },
        "outputId": "3f626d95-9cb9-463d-92a1-651e82618e6d"
      },
      "execution_count": 12,
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1391/1391 [==============================] - 13s 8ms/step\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Vanilla Distillation"
      ],
      "metadata": {
        "id": "rt4LsGyCTFjB"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Build a fresh student model (with the standard cross-entropy loss) to serve\n",
        "# as the vanilla-distillation student.\n",
        "\n",
        "vanilla_student = load_model(\n",
        "    _STUDENT_MODEL,\n",
        "    num_classes=num_classes,\n",
        "    optimizer_name=_STUDENT_OPTIMIZER,\n",
        "    width_multiplier=1,\n",
        "    depth_multiplier=_STUDENT_MOBILENET_DEPTH_MULTIPLIER,\n",
        "    resnet_depth=_STUDENT_RESNET_DEPTH)\n",
        "\n"
      ],
      "metadata": {
        "id": "bqjYRwwVfwV8",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1674703841870,
          "user_tz": 300,
          "elapsed": 85,
          "user": {
            "displayName": "XXXX",
            "userId": "0000"
          }
        }
      },
      "execution_count": 13,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Training parameters handed to the distill() function.\n",
        "training_params = dict(\n",
        "    epochs=_STUDENT_EPOCHS,\n",
        "    with_soft_labels=_WITH_SOFT_LABELS,\n",
        "    batch_size=_BATCH_SIZE,\n",
        "    data_augmentation=_STUDENT_DISTILLATION_DATA_AUGMENTATION,\n",
        ")\n",
        "\n",
        "# Train the student model using distillation, without advice. Being the last\n",
        "# expression of the cell, the tuple returned by distill() is displayed.\n",
        "distill(\n",
        "    model=vanilla_student,\n",
        "    teacher_predictions=teacher_predictions,\n",
        "    data_split=data_split,\n",
        "    params=training_params,\n",
        "    advice_data=None)"
      ],
      "metadata": {
        "id": "qjmn101if7k3",
        "outputId": "63e5547b-9ffe-40da-f82b-91c246897fb2",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1674705955384,
          "user_tz": 300,
          "elapsed": 2113512,
          "user": {
            "displayName": "XXXXXX",
            "userId": "00000"
          }
        }
      },
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "313/313 [==============================] - 4s 7ms/step - loss: 785.9287 - categorical_accuracy: 0.0099\n",
            "Epoch 1/200\n",
            "390/390 [==============================] - 53s 54ms/step - loss: 4.3412 - categorical_accuracy: 0.1888 - val_loss: 4.2619 - val_categorical_accuracy: 0.1589 - lr: 0.0010\n",
            "Epoch 2/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 3.3923 - categorical_accuracy: 0.3210 - val_loss: 3.6788 - val_categorical_accuracy: 0.2487 - lr: 0.0010\n",
            "Epoch 3/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 3.0082 - categorical_accuracy: 0.3858 - val_loss: 3.7917 - val_categorical_accuracy: 0.2395 - lr: 0.0010\n",
            "Epoch 4/200\n",
            "390/390 [==============================] - 24s 60ms/step - loss: 2.7920 - categorical_accuracy: 0.4244 - val_loss: 3.5636 - val_categorical_accuracy: 0.2706 - lr: 0.0010\n",
            "Epoch 5/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 2.6495 - categorical_accuracy: 0.4514 - val_loss: 3.5210 - val_categorical_accuracy: 0.2918 - lr: 0.0010\n",
            "Epoch 6/200\n",
            "390/390 [==============================] - 22s 55ms/step - loss: 2.5267 - categorical_accuracy: 0.4771 - val_loss: 3.3725 - val_categorical_accuracy: 0.2977 - lr: 0.0010\n",
            "Epoch 7/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 2.4352 - categorical_accuracy: 0.4990 - val_loss: 3.6924 - val_categorical_accuracy: 0.2813 - lr: 0.0010\n",
            "Epoch 8/200\n",
            "390/390 [==============================] - 22s 55ms/step - loss: 2.4051 - categorical_accuracy: 0.4974 - val_loss: 3.3533 - val_categorical_accuracy: 0.3164 - lr: 0.0010\n",
            "Epoch 9/200\n",
            "390/390 [==============================] - 22s 55ms/step - loss: 2.3300 - categorical_accuracy: 0.5168 - val_loss: 3.3369 - val_categorical_accuracy: 0.3150 - lr: 0.0010\n",
            "Epoch 10/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 2.2744 - categorical_accuracy: 0.5308 - val_loss: 3.5184 - val_categorical_accuracy: 0.3012 - lr: 0.0010\n",
            "Epoch 11/200\n",
            "390/390 [==============================] - 22s 55ms/step - loss: 2.2258 - categorical_accuracy: 0.5444 - val_loss: 3.4994 - val_categorical_accuracy: 0.3067 - lr: 0.0010\n",
            "Epoch 12/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 2.1892 - categorical_accuracy: 0.5527 - val_loss: 3.3165 - val_categorical_accuracy: 0.3292 - lr: 0.0010\n",
            "Epoch 13/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 2.1477 - categorical_accuracy: 0.5641 - val_loss: 3.3747 - val_categorical_accuracy: 0.3099 - lr: 0.0010\n",
            "Epoch 14/200\n",
            "390/390 [==============================] - 22s 55ms/step - loss: 2.1767 - categorical_accuracy: 0.5510 - val_loss: 3.4357 - val_categorical_accuracy: 0.3165 - lr: 0.0010\n",
            "Epoch 15/200\n",
            "390/390 [==============================] - 22s 57ms/step - loss: 2.1049 - categorical_accuracy: 0.5707 - val_loss: 3.4654 - val_categorical_accuracy: 0.3122 - lr: 0.0010\n",
            "Epoch 16/200\n",
            "390/390 [==============================] - 22s 57ms/step - loss: 2.0701 - categorical_accuracy: 0.5838 - val_loss: 3.4860 - val_categorical_accuracy: 0.3101 - lr: 0.0010\n",
            "Epoch 17/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 2.0506 - categorical_accuracy: 0.5915 - val_loss: 3.4252 - val_categorical_accuracy: 0.3232 - lr: 0.0010\n",
            "Epoch 18/200\n",
            "390/390 [==============================] - 23s 60ms/step - loss: 2.0231 - categorical_accuracy: 0.5962 - val_loss: 3.5500 - val_categorical_accuracy: 0.3187 - lr: 0.0010\n",
            "Epoch 19/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 2.0728 - categorical_accuracy: 0.5751 - val_loss: 3.5921 - val_categorical_accuracy: 0.3054 - lr: 0.0010\n",
            "Epoch 20/200\n",
            "390/390 [==============================] - 22s 55ms/step - loss: 2.0045 - categorical_accuracy: 0.6008 - val_loss: 3.6075 - val_categorical_accuracy: 0.3134 - lr: 0.0010\n",
            "Epoch 21/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 1.9712 - categorical_accuracy: 0.6149 - val_loss: 3.4664 - val_categorical_accuracy: 0.3197 - lr: 0.0010\n",
            "Epoch 22/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 1.9576 - categorical_accuracy: 0.6177 - val_loss: 3.5923 - val_categorical_accuracy: 0.3119 - lr: 0.0010\n",
            "Epoch 23/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 2.0110 - categorical_accuracy: 0.5928 - val_loss: 3.4127 - val_categorical_accuracy: 0.3265 - lr: 0.0010\n",
            "Epoch 24/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 1.9460 - categorical_accuracy: 0.6178 - val_loss: 3.5138 - val_categorical_accuracy: 0.3175 - lr: 0.0010\n",
            "Epoch 25/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 1.9714 - categorical_accuracy: 0.6050 - val_loss: 3.5388 - val_categorical_accuracy: 0.3205 - lr: 0.0010\n",
            "Epoch 26/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 1.9353 - categorical_accuracy: 0.6171 - val_loss: 3.5288 - val_categorical_accuracy: 0.3172 - lr: 0.0010\n",
            "Epoch 27/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 1.9255 - categorical_accuracy: 0.6214 - val_loss: 3.4279 - val_categorical_accuracy: 0.3298 - lr: 0.0010\n",
            "Epoch 28/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 1.9200 - categorical_accuracy: 0.6214 - val_loss: 3.4713 - val_categorical_accuracy: 0.3216 - lr: 0.0010\n",
            "Epoch 29/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 1.8963 - categorical_accuracy: 0.6303 - val_loss: 3.5599 - val_categorical_accuracy: 0.3281 - lr: 0.0010\n",
            "Epoch 30/200\n",
            "390/390 [==============================] - 22s 55ms/step - loss: 1.8728 - categorical_accuracy: 0.6404 - val_loss: 3.4393 - val_categorical_accuracy: 0.3365 - lr: 0.0010\n",
            "Epoch 31/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 1.9043 - categorical_accuracy: 0.6264 - val_loss: 3.7360 - val_categorical_accuracy: 0.3022 - lr: 0.0010\n",
            "Epoch 32/200\n",
            "390/390 [==============================] - 22s 57ms/step - loss: 1.8652 - categorical_accuracy: 0.6428 - val_loss: 3.4302 - val_categorical_accuracy: 0.3401 - lr: 0.0010\n",
            "Epoch 33/200\n",
            "390/390 [==============================] - 22s 55ms/step - loss: 1.8819 - categorical_accuracy: 0.6322 - val_loss: 3.4079 - val_categorical_accuracy: 0.3339 - lr: 0.0010\n",
            "Epoch 34/200\n",
            "390/390 [==============================] - 22s 56ms/step - loss: 1.8540 - categorical_accuracy: 0.6444 - val_loss: 3.6362 - val_categorical_accuracy: 0.3170 - lr: 0.0010\n",
            "Epoch 35/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 1.8588 - categorical_accuracy: 0.6434 - val_loss: 3.5857 - val_categorical_accuracy: 0.3282 - lr: 0.0010\n",
            "Epoch 36/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 1.8459 - categorical_accuracy: 0.6453 - val_loss: 3.5872 - val_categorical_accuracy: 0.3245 - lr: 0.0010\n",
            "Epoch 37/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 1.8393 - categorical_accuracy: 0.6519 - val_loss: 3.6664 - val_categorical_accuracy: 0.3160 - lr: 0.0010\n",
            "Epoch 38/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 1.8294 - categorical_accuracy: 0.6535 - val_loss: 3.6346 - val_categorical_accuracy: 0.3157 - lr: 0.0010\n",
            "Epoch 39/200\n",
            "390/390 [==============================] - 22s 55ms/step - loss: 1.8091 - categorical_accuracy: 0.6605 - val_loss: 3.4683 - val_categorical_accuracy: 0.3290 - lr: 0.0010\n",
            "Epoch 40/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 1.8011 - categorical_accuracy: 0.6666 - val_loss: 3.5935 - val_categorical_accuracy: 0.3238 - lr: 0.0010\n",
            "Epoch 41/200\n",
            "390/390 [==============================] - 22s 56ms/step - loss: 1.8274 - categorical_accuracy: 0.6520 - val_loss: 3.4697 - val_categorical_accuracy: 0.3309 - lr: 0.0010\n",
            "Epoch 42/200\n",
            "390/390 [==============================] - 22s 56ms/step - loss: 1.7967 - categorical_accuracy: 0.6666 - val_loss: 3.5106 - val_categorical_accuracy: 0.3233 - lr: 0.0010\n",
            "Epoch 43/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 1.8079 - categorical_accuracy: 0.6601 - val_loss: 3.5887 - val_categorical_accuracy: 0.3302 - lr: 0.0010\n",
            "Epoch 44/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 1.7846 - categorical_accuracy: 0.6690 - val_loss: 3.4187 - val_categorical_accuracy: 0.3306 - lr: 0.0010\n",
            "Epoch 45/200\n",
            "390/390 [==============================] - 22s 56ms/step - loss: 1.7845 - categorical_accuracy: 0.6655 - val_loss: 3.5654 - val_categorical_accuracy: 0.3231 - lr: 0.0010\n",
            "Epoch 46/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 1.7757 - categorical_accuracy: 0.6742 - val_loss: 3.8087 - val_categorical_accuracy: 0.3180 - lr: 0.0010\n",
            "Epoch 47/200\n",
            "390/390 [==============================] - 22s 55ms/step - loss: 1.7606 - categorical_accuracy: 0.6769 - val_loss: 3.5795 - val_categorical_accuracy: 0.3268 - lr: 0.0010\n",
            "Epoch 48/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 1.7941 - categorical_accuracy: 0.6662 - val_loss: 3.6418 - val_categorical_accuracy: 0.3222 - lr: 0.0010\n",
            "Epoch 49/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 1.7600 - categorical_accuracy: 0.6786 - val_loss: 3.4948 - val_categorical_accuracy: 0.3307 - lr: 0.0010\n",
            "Epoch 50/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 1.7636 - categorical_accuracy: 0.6749 - val_loss: 3.7354 - val_categorical_accuracy: 0.3144 - lr: 0.0010\n",
            "Epoch 51/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 1.7384 - categorical_accuracy: 0.6871 - val_loss: 3.5184 - val_categorical_accuracy: 0.3327 - lr: 0.0010\n",
            "Epoch 52/200\n",
            "390/390 [==============================] - 22s 56ms/step - loss: 1.7657 - categorical_accuracy: 0.6743 - val_loss: 3.6319 - val_categorical_accuracy: 0.3227 - lr: 0.0010\n",
            "Epoch 53/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 1.7445 - categorical_accuracy: 0.6852 - val_loss: 3.5615 - val_categorical_accuracy: 0.3270 - lr: 0.0010\n",
            "Epoch 54/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 1.7488 - categorical_accuracy: 0.6811 - val_loss: 3.6335 - val_categorical_accuracy: 0.3236 - lr: 0.0010\n",
            "Epoch 55/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 1.7310 - categorical_accuracy: 0.6898 - val_loss: 3.6144 - val_categorical_accuracy: 0.3289 - lr: 0.0010\n",
            "Epoch 56/200\n",
            "390/390 [==============================] - 22s 56ms/step - loss: 1.7451 - categorical_accuracy: 0.6837 - val_loss: 3.6558 - val_categorical_accuracy: 0.3286 - lr: 0.0010\n",
            "Epoch 57/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 1.7242 - categorical_accuracy: 0.6930 - val_loss: 3.6332 - val_categorical_accuracy: 0.3312 - lr: 0.0010\n",
            "Epoch 58/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 1.7113 - categorical_accuracy: 0.6985 - val_loss: 3.5761 - val_categorical_accuracy: 0.3301 - lr: 0.0010\n",
            "Epoch 59/200\n",
            "390/390 [==============================] - 22s 56ms/step - loss: 1.6996 - categorical_accuracy: 0.7006 - val_loss: 3.6432 - val_categorical_accuracy: 0.3221 - lr: 0.0010\n",
            "Epoch 60/200\n",
            "390/390 [==============================] - 22s 55ms/step - loss: 1.6954 - categorical_accuracy: 0.7034 - val_loss: 3.6952 - val_categorical_accuracy: 0.3223 - lr: 0.0010\n",
            "Epoch 61/200\n",
            "390/390 [==============================] - 23s 58ms/step - loss: 1.6945 - categorical_accuracy: 0.7044 - val_loss: 3.5314 - val_categorical_accuracy: 0.3326 - lr: 0.0010\n",
            "Epoch 62/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 1.7540 - categorical_accuracy: 0.6763 - val_loss: 3.5225 - val_categorical_accuracy: 0.3273 - lr: 0.0010\n",
            "Epoch 63/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 1.7086 - categorical_accuracy: 0.6970 - val_loss: 3.7030 - val_categorical_accuracy: 0.3105 - lr: 0.0010\n",
            "Epoch 64/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 1.7044 - categorical_accuracy: 0.6996 - val_loss: 3.5547 - val_categorical_accuracy: 0.3291 - lr: 0.0010\n",
            "Epoch 65/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 1.6809 - categorical_accuracy: 0.7078 - val_loss: 3.5754 - val_categorical_accuracy: 0.3275 - lr: 0.0010\n",
            "Epoch 66/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 1.7244 - categorical_accuracy: 0.6889 - val_loss: 3.6452 - val_categorical_accuracy: 0.3252 - lr: 0.0010\n",
            "Epoch 67/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 1.6929 - categorical_accuracy: 0.7039 - val_loss: 3.5322 - val_categorical_accuracy: 0.3355 - lr: 0.0010\n",
            "Epoch 68/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 1.6890 - categorical_accuracy: 0.7039 - val_loss: 3.5760 - val_categorical_accuracy: 0.3392 - lr: 0.0010\n",
            "Epoch 69/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 1.7071 - categorical_accuracy: 0.6945 - val_loss: 3.5952 - val_categorical_accuracy: 0.3225 - lr: 0.0010\n",
            "Epoch 70/200\n",
            "390/390 [==============================] - 22s 56ms/step - loss: 1.6858 - categorical_accuracy: 0.7054 - val_loss: 3.6345 - val_categorical_accuracy: 0.3218 - lr: 0.0010\n",
            "Epoch 71/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 1.6940 - categorical_accuracy: 0.6995 - val_loss: 3.5864 - val_categorical_accuracy: 0.3305 - lr: 0.0010\n",
            "Epoch 72/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 1.6830 - categorical_accuracy: 0.7068 - val_loss: 3.6353 - val_categorical_accuracy: 0.3327 - lr: 0.0010\n",
            "Epoch 73/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 1.6637 - categorical_accuracy: 0.7154 - val_loss: 3.7403 - val_categorical_accuracy: 0.3257 - lr: 0.0010\n",
            "Epoch 74/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 1.6543 - categorical_accuracy: 0.7169 - val_loss: 3.6345 - val_categorical_accuracy: 0.3276 - lr: 0.0010\n",
            "Epoch 75/200\n",
            "390/390 [==============================] - 23s 58ms/step - loss: 1.7018 - categorical_accuracy: 0.6960 - val_loss: 3.5827 - val_categorical_accuracy: 0.3281 - lr: 0.0010\n",
            "Epoch 76/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 1.6643 - categorical_accuracy: 0.7142 - val_loss: 3.4788 - val_categorical_accuracy: 0.3312 - lr: 0.0010\n",
            "Epoch 77/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 1.6763 - categorical_accuracy: 0.7073 - val_loss: 3.6736 - val_categorical_accuracy: 0.3171 - lr: 0.0010\n",
            "Epoch 78/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 1.6639 - categorical_accuracy: 0.7131 - val_loss: 3.6521 - val_categorical_accuracy: 0.3193 - lr: 0.0010\n",
            "Epoch 79/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 1.6771 - categorical_accuracy: 0.7072 - val_loss: 3.5697 - val_categorical_accuracy: 0.3308 - lr: 0.0010\n",
            "Epoch 80/200\n",
            "390/390 [==============================] - 22s 55ms/step - loss: 1.6537 - categorical_accuracy: 0.7192 - val_loss: 3.6632 - val_categorical_accuracy: 0.3273 - lr: 0.0010\n",
            "Epoch 81/200\n",
            "390/390 [==============================] - 22s 56ms/step - loss: 1.5265 - categorical_accuracy: 0.7760 - val_loss: 3.4128 - val_categorical_accuracy: 0.3520 - lr: 1.0000e-04\n",
            "Epoch 82/200\n",
            "390/390 [==============================] - 22s 55ms/step - loss: 1.4478 - categorical_accuracy: 0.8141 - val_loss: 3.4624 - val_categorical_accuracy: 0.3512 - lr: 1.0000e-04\n",
            "Epoch 83/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 1.4508 - categorical_accuracy: 0.8087 - val_loss: 3.4841 - val_categorical_accuracy: 0.3490 - lr: 1.0000e-04\n",
            "Epoch 84/200\n",
            "390/390 [==============================] - 22s 56ms/step - loss: 1.4162 - categorical_accuracy: 0.8251 - val_loss: 3.4688 - val_categorical_accuracy: 0.3507 - lr: 1.0000e-04\n",
            "Epoch 85/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 1.4043 - categorical_accuracy: 0.8304 - val_loss: 3.4673 - val_categorical_accuracy: 0.3511 - lr: 1.0000e-04\n",
            "Epoch 86/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 1.4164 - categorical_accuracy: 0.8190 - val_loss: 3.4924 - val_categorical_accuracy: 0.3478 - lr: 1.0000e-04\n",
            "Epoch 87/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 1.3857 - categorical_accuracy: 0.8361 - val_loss: 3.5015 - val_categorical_accuracy: 0.3500 - lr: 1.0000e-04\n",
            "Epoch 88/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.3750 - categorical_accuracy: 0.8380 - val_loss: 3.4862 - val_categorical_accuracy: 0.3506 - lr: 1.0000e-04\n",
            "Epoch 89/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.3688 - categorical_accuracy: 0.8389 - val_loss: 3.4892 - val_categorical_accuracy: 0.3501 - lr: 1.0000e-04\n",
            "Epoch 90/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.3866 - categorical_accuracy: 0.8288 - val_loss: 3.5245 - val_categorical_accuracy: 0.3466 - lr: 1.0000e-04\n",
            "Epoch 91/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.3729 - categorical_accuracy: 0.8337 - val_loss: 3.5056 - val_categorical_accuracy: 0.3515 - lr: 1.0000e-04\n",
            "Epoch 92/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.3546 - categorical_accuracy: 0.8444 - val_loss: 3.4844 - val_categorical_accuracy: 0.3536 - lr: 1.0000e-04\n",
            "Epoch 93/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.3600 - categorical_accuracy: 0.8371 - val_loss: 3.5183 - val_categorical_accuracy: 0.3447 - lr: 1.0000e-04\n",
            "Epoch 94/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.3411 - categorical_accuracy: 0.8451 - val_loss: 3.4983 - val_categorical_accuracy: 0.3506 - lr: 1.0000e-04\n",
            "Epoch 95/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.3345 - categorical_accuracy: 0.8486 - val_loss: 3.5205 - val_categorical_accuracy: 0.3497 - lr: 1.0000e-04\n",
            "Epoch 96/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.3467 - categorical_accuracy: 0.8408 - val_loss: 3.4881 - val_categorical_accuracy: 0.3517 - lr: 1.0000e-04\n",
            "Epoch 97/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.3239 - categorical_accuracy: 0.8527 - val_loss: 3.4780 - val_categorical_accuracy: 0.3511 - lr: 1.0000e-04\n",
            "Epoch 98/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.3345 - categorical_accuracy: 0.8409 - val_loss: 3.4712 - val_categorical_accuracy: 0.3527 - lr: 1.0000e-04\n",
            "Epoch 99/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.3256 - categorical_accuracy: 0.8454 - val_loss: 3.4803 - val_categorical_accuracy: 0.3504 - lr: 1.0000e-04\n",
            "Epoch 100/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.3126 - categorical_accuracy: 0.8518 - val_loss: 3.4879 - val_categorical_accuracy: 0.3472 - lr: 1.0000e-04\n",
            "Epoch 101/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.3083 - categorical_accuracy: 0.8524 - val_loss: 3.4642 - val_categorical_accuracy: 0.3534 - lr: 1.0000e-04\n",
            "Epoch 102/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.3161 - categorical_accuracy: 0.8470 - val_loss: 3.4819 - val_categorical_accuracy: 0.3495 - lr: 1.0000e-04\n",
            "Epoch 103/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2970 - categorical_accuracy: 0.8564 - val_loss: 3.5013 - val_categorical_accuracy: 0.3478 - lr: 1.0000e-04\n",
            "Epoch 104/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2919 - categorical_accuracy: 0.8571 - val_loss: 3.4936 - val_categorical_accuracy: 0.3498 - lr: 1.0000e-04\n",
            "Epoch 105/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 1.2882 - categorical_accuracy: 0.8588 - val_loss: 3.4839 - val_categorical_accuracy: 0.3482 - lr: 1.0000e-04\n",
            "Epoch 106/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.3037 - categorical_accuracy: 0.8485 - val_loss: 3.5274 - val_categorical_accuracy: 0.3450 - lr: 1.0000e-04\n",
            "Epoch 107/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2821 - categorical_accuracy: 0.8589 - val_loss: 3.4799 - val_categorical_accuracy: 0.3508 - lr: 1.0000e-04\n",
            "Epoch 108/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2774 - categorical_accuracy: 0.8600 - val_loss: 3.5080 - val_categorical_accuracy: 0.3436 - lr: 1.0000e-04\n",
            "Epoch 109/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2930 - categorical_accuracy: 0.8499 - val_loss: 3.4823 - val_categorical_accuracy: 0.3482 - lr: 1.0000e-04\n",
            "Epoch 110/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2725 - categorical_accuracy: 0.8606 - val_loss: 3.4638 - val_categorical_accuracy: 0.3467 - lr: 1.0000e-04\n",
            "Epoch 111/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2683 - categorical_accuracy: 0.8600 - val_loss: 3.4701 - val_categorical_accuracy: 0.3516 - lr: 1.0000e-04\n",
            "Epoch 112/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2648 - categorical_accuracy: 0.8605 - val_loss: 3.4986 - val_categorical_accuracy: 0.3458 - lr: 1.0000e-04\n",
            "Epoch 113/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2839 - categorical_accuracy: 0.8490 - val_loss: 3.4731 - val_categorical_accuracy: 0.3506 - lr: 1.0000e-04\n",
            "Epoch 114/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2750 - categorical_accuracy: 0.8527 - val_loss: 3.4976 - val_categorical_accuracy: 0.3432 - lr: 1.0000e-04\n",
            "Epoch 115/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2602 - categorical_accuracy: 0.8614 - val_loss: 3.4866 - val_categorical_accuracy: 0.3468 - lr: 1.0000e-04\n",
            "Epoch 116/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2551 - categorical_accuracy: 0.8629 - val_loss: 3.5080 - val_categorical_accuracy: 0.3459 - lr: 1.0000e-04\n",
            "Epoch 117/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2516 - categorical_accuracy: 0.8643 - val_loss: 3.4597 - val_categorical_accuracy: 0.3482 - lr: 1.0000e-04\n",
            "Epoch 118/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2685 - categorical_accuracy: 0.8529 - val_loss: 3.4904 - val_categorical_accuracy: 0.3481 - lr: 1.0000e-04\n",
            "Epoch 119/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2451 - categorical_accuracy: 0.8638 - val_loss: 3.4568 - val_categorical_accuracy: 0.3486 - lr: 1.0000e-04\n",
            "Epoch 120/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 1.2618 - categorical_accuracy: 0.8555 - val_loss: 3.4920 - val_categorical_accuracy: 0.3517 - lr: 1.0000e-04\n",
            "Epoch 121/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2443 - categorical_accuracy: 0.8653 - val_loss: 3.4761 - val_categorical_accuracy: 0.3489 - lr: 1.0000e-05\n",
            "Epoch 122/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2378 - categorical_accuracy: 0.8677 - val_loss: 3.4743 - val_categorical_accuracy: 0.3503 - lr: 1.0000e-05\n",
            "Epoch 123/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2296 - categorical_accuracy: 0.8742 - val_loss: 3.4728 - val_categorical_accuracy: 0.3515 - lr: 1.0000e-05\n",
            "Epoch 124/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2356 - categorical_accuracy: 0.8700 - val_loss: 3.4777 - val_categorical_accuracy: 0.3490 - lr: 1.0000e-05\n",
            "Epoch 125/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2244 - categorical_accuracy: 0.8748 - val_loss: 3.4679 - val_categorical_accuracy: 0.3523 - lr: 1.0000e-05\n",
            "Epoch 126/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2233 - categorical_accuracy: 0.8764 - val_loss: 3.4744 - val_categorical_accuracy: 0.3519 - lr: 1.0000e-05\n",
            "Epoch 127/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2329 - categorical_accuracy: 0.8696 - val_loss: 3.4734 - val_categorical_accuracy: 0.3511 - lr: 1.0000e-05\n",
            "Epoch 128/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2308 - categorical_accuracy: 0.8705 - val_loss: 3.4671 - val_categorical_accuracy: 0.3518 - lr: 1.0000e-05\n",
            "Epoch 129/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2300 - categorical_accuracy: 0.8700 - val_loss: 3.4677 - val_categorical_accuracy: 0.3495 - lr: 1.0000e-05\n",
            "Epoch 130/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2214 - categorical_accuracy: 0.8767 - val_loss: 3.4683 - val_categorical_accuracy: 0.3529 - lr: 1.0000e-05\n",
            "Epoch 131/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2200 - categorical_accuracy: 0.8759 - val_loss: 3.4701 - val_categorical_accuracy: 0.3512 - lr: 1.0000e-05\n",
            "Epoch 132/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2298 - categorical_accuracy: 0.8691 - val_loss: 3.4702 - val_categorical_accuracy: 0.3511 - lr: 1.0000e-05\n",
            "Epoch 133/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2274 - categorical_accuracy: 0.8730 - val_loss: 3.4732 - val_categorical_accuracy: 0.3507 - lr: 1.0000e-05\n",
            "Epoch 134/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2200 - categorical_accuracy: 0.8765 - val_loss: 3.4684 - val_categorical_accuracy: 0.3520 - lr: 1.0000e-05\n",
            "Epoch 135/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 1.2167 - categorical_accuracy: 0.8791 - val_loss: 3.4674 - val_categorical_accuracy: 0.3509 - lr: 1.0000e-05\n",
            "Epoch 136/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2169 - categorical_accuracy: 0.8792 - val_loss: 3.4690 - val_categorical_accuracy: 0.3504 - lr: 1.0000e-05\n",
            "Epoch 137/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2276 - categorical_accuracy: 0.8714 - val_loss: 3.4731 - val_categorical_accuracy: 0.3507 - lr: 1.0000e-05\n",
            "Epoch 138/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2162 - categorical_accuracy: 0.8779 - val_loss: 3.4686 - val_categorical_accuracy: 0.3503 - lr: 1.0000e-05\n",
            "Epoch 139/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2160 - categorical_accuracy: 0.8781 - val_loss: 3.4698 - val_categorical_accuracy: 0.3488 - lr: 1.0000e-05\n",
            "Epoch 140/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2258 - categorical_accuracy: 0.8708 - val_loss: 3.4701 - val_categorical_accuracy: 0.3494 - lr: 1.0000e-05\n",
            "Epoch 141/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2158 - categorical_accuracy: 0.8811 - val_loss: 3.4676 - val_categorical_accuracy: 0.3505 - lr: 1.0000e-05\n",
            "Epoch 142/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2144 - categorical_accuracy: 0.8786 - val_loss: 3.4720 - val_categorical_accuracy: 0.3512 - lr: 1.0000e-05\n",
            "Epoch 143/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2251 - categorical_accuracy: 0.8722 - val_loss: 3.4764 - val_categorical_accuracy: 0.3499 - lr: 1.0000e-05\n",
            "Epoch 144/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2131 - categorical_accuracy: 0.8784 - val_loss: 3.4733 - val_categorical_accuracy: 0.3498 - lr: 1.0000e-05\n",
            "Epoch 145/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2115 - categorical_accuracy: 0.8799 - val_loss: 3.4674 - val_categorical_accuracy: 0.3497 - lr: 1.0000e-05\n",
            "Epoch 146/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2131 - categorical_accuracy: 0.8802 - val_loss: 3.4671 - val_categorical_accuracy: 0.3514 - lr: 1.0000e-05\n",
            "Epoch 147/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2243 - categorical_accuracy: 0.8716 - val_loss: 3.4784 - val_categorical_accuracy: 0.3499 - lr: 1.0000e-05\n",
            "Epoch 148/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2121 - categorical_accuracy: 0.8786 - val_loss: 3.4701 - val_categorical_accuracy: 0.3522 - lr: 1.0000e-05\n",
            "Epoch 149/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 1.2100 - categorical_accuracy: 0.8810 - val_loss: 3.4678 - val_categorical_accuracy: 0.3499 - lr: 1.0000e-05\n",
            "Epoch 150/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 1.2230 - categorical_accuracy: 0.8762 - val_loss: 3.4710 - val_categorical_accuracy: 0.3492 - lr: 1.0000e-05\n",
            "Epoch 151/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2116 - categorical_accuracy: 0.8796 - val_loss: 3.4629 - val_categorical_accuracy: 0.3501 - lr: 1.0000e-05\n",
            "Epoch 152/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2104 - categorical_accuracy: 0.8812 - val_loss: 3.4617 - val_categorical_accuracy: 0.3500 - lr: 1.0000e-05\n",
            "Epoch 153/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2211 - categorical_accuracy: 0.8738 - val_loss: 3.4660 - val_categorical_accuracy: 0.3491 - lr: 1.0000e-05\n",
            "Epoch 154/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2093 - categorical_accuracy: 0.8793 - val_loss: 3.4576 - val_categorical_accuracy: 0.3506 - lr: 1.0000e-05\n",
            "Epoch 155/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2083 - categorical_accuracy: 0.8794 - val_loss: 3.4685 - val_categorical_accuracy: 0.3497 - lr: 1.0000e-05\n",
            "Epoch 156/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2216 - categorical_accuracy: 0.8721 - val_loss: 3.4656 - val_categorical_accuracy: 0.3490 - lr: 1.0000e-05\n",
            "Epoch 157/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2185 - categorical_accuracy: 0.8739 - val_loss: 3.4721 - val_categorical_accuracy: 0.3493 - lr: 1.0000e-05\n",
            "Epoch 158/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2181 - categorical_accuracy: 0.8751 - val_loss: 3.4739 - val_categorical_accuracy: 0.3498 - lr: 1.0000e-05\n",
            "Epoch 159/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2168 - categorical_accuracy: 0.8748 - val_loss: 3.4747 - val_categorical_accuracy: 0.3487 - lr: 1.0000e-05\n",
            "Epoch 160/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2179 - categorical_accuracy: 0.8727 - val_loss: 3.4737 - val_categorical_accuracy: 0.3471 - lr: 1.0000e-05\n",
            "Epoch 161/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2081 - categorical_accuracy: 0.8782 - val_loss: 3.4806 - val_categorical_accuracy: 0.3471 - lr: 1.0000e-06\n",
            "Epoch 162/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2164 - categorical_accuracy: 0.8736 - val_loss: 3.4744 - val_categorical_accuracy: 0.3474 - lr: 1.0000e-06\n",
            "Epoch 163/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2151 - categorical_accuracy: 0.8751 - val_loss: 3.4756 - val_categorical_accuracy: 0.3466 - lr: 1.0000e-06\n",
            "Epoch 164/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2057 - categorical_accuracy: 0.8816 - val_loss: 3.4762 - val_categorical_accuracy: 0.3477 - lr: 1.0000e-06\n",
            "Epoch 165/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 1.2069 - categorical_accuracy: 0.8823 - val_loss: 3.4725 - val_categorical_accuracy: 0.3480 - lr: 1.0000e-06\n",
            "Epoch 166/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2068 - categorical_accuracy: 0.8798 - val_loss: 3.4742 - val_categorical_accuracy: 0.3482 - lr: 1.0000e-06\n",
            "Epoch 167/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2060 - categorical_accuracy: 0.8779 - val_loss: 3.4711 - val_categorical_accuracy: 0.3484 - lr: 1.0000e-06\n",
            "Epoch 168/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2144 - categorical_accuracy: 0.8751 - val_loss: 3.4748 - val_categorical_accuracy: 0.3483 - lr: 1.0000e-06\n",
            "Epoch 169/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2055 - categorical_accuracy: 0.8803 - val_loss: 3.4741 - val_categorical_accuracy: 0.3483 - lr: 1.0000e-06\n",
            "Epoch 170/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2136 - categorical_accuracy: 0.8753 - val_loss: 3.4728 - val_categorical_accuracy: 0.3484 - lr: 1.0000e-06\n",
            "Epoch 171/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2074 - categorical_accuracy: 0.8771 - val_loss: 3.4742 - val_categorical_accuracy: 0.3480 - lr: 1.0000e-06\n",
            "Epoch 172/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2052 - categorical_accuracy: 0.8802 - val_loss: 3.4762 - val_categorical_accuracy: 0.3476 - lr: 1.0000e-06\n",
            "Epoch 173/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2144 - categorical_accuracy: 0.8779 - val_loss: 3.4736 - val_categorical_accuracy: 0.3477 - lr: 1.0000e-06\n",
            "Epoch 174/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2154 - categorical_accuracy: 0.8756 - val_loss: 3.4736 - val_categorical_accuracy: 0.3481 - lr: 1.0000e-06\n",
            "Epoch 175/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2151 - categorical_accuracy: 0.8739 - val_loss: 3.4762 - val_categorical_accuracy: 0.3475 - lr: 1.0000e-06\n",
            "Epoch 176/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2060 - categorical_accuracy: 0.8799 - val_loss: 3.4701 - val_categorical_accuracy: 0.3477 - lr: 1.0000e-06\n",
            "Epoch 177/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 1.2139 - categorical_accuracy: 0.8778 - val_loss: 3.4768 - val_categorical_accuracy: 0.3480 - lr: 1.0000e-06\n",
            "Epoch 178/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 1.2061 - categorical_accuracy: 0.8781 - val_loss: 3.4736 - val_categorical_accuracy: 0.3482 - lr: 1.0000e-06\n",
            "Epoch 179/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2049 - categorical_accuracy: 0.8801 - val_loss: 3.4737 - val_categorical_accuracy: 0.3482 - lr: 1.0000e-06\n",
            "Epoch 180/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 1.2147 - categorical_accuracy: 0.8750 - val_loss: 3.4728 - val_categorical_accuracy: 0.3494 - lr: 1.0000e-06\n",
            "Epoch 181/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2142 - categorical_accuracy: 0.8776 - val_loss: 3.4745 - val_categorical_accuracy: 0.3486 - lr: 5.0000e-07\n",
            "Epoch 182/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2148 - categorical_accuracy: 0.8751 - val_loss: 3.4733 - val_categorical_accuracy: 0.3478 - lr: 5.0000e-07\n",
            "Epoch 183/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2137 - categorical_accuracy: 0.8762 - val_loss: 3.4750 - val_categorical_accuracy: 0.3484 - lr: 5.0000e-07\n",
            "Epoch 184/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2048 - categorical_accuracy: 0.8801 - val_loss: 3.4704 - val_categorical_accuracy: 0.3486 - lr: 5.0000e-07\n",
            "Epoch 185/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2057 - categorical_accuracy: 0.8781 - val_loss: 3.4723 - val_categorical_accuracy: 0.3489 - lr: 5.0000e-07\n",
            "Epoch 186/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2148 - categorical_accuracy: 0.8753 - val_loss: 3.4682 - val_categorical_accuracy: 0.3490 - lr: 5.0000e-07\n",
            "Epoch 187/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2149 - categorical_accuracy: 0.8732 - val_loss: 3.4714 - val_categorical_accuracy: 0.3483 - lr: 5.0000e-07\n",
            "Epoch 188/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2146 - categorical_accuracy: 0.8757 - val_loss: 3.4757 - val_categorical_accuracy: 0.3480 - lr: 5.0000e-07\n",
            "Epoch 189/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2059 - categorical_accuracy: 0.8787 - val_loss: 3.4690 - val_categorical_accuracy: 0.3488 - lr: 5.0000e-07\n",
            "Epoch 190/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2145 - categorical_accuracy: 0.8740 - val_loss: 3.4736 - val_categorical_accuracy: 0.3486 - lr: 5.0000e-07\n",
            "Epoch 191/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2056 - categorical_accuracy: 0.8816 - val_loss: 3.4696 - val_categorical_accuracy: 0.3495 - lr: 5.0000e-07\n",
            "Epoch 192/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2062 - categorical_accuracy: 0.8785 - val_loss: 3.4702 - val_categorical_accuracy: 0.3485 - lr: 5.0000e-07\n",
            "Epoch 193/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2048 - categorical_accuracy: 0.8813 - val_loss: 3.4738 - val_categorical_accuracy: 0.3486 - lr: 5.0000e-07\n",
            "Epoch 194/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2043 - categorical_accuracy: 0.8814 - val_loss: 3.4747 - val_categorical_accuracy: 0.3484 - lr: 5.0000e-07\n",
            "Epoch 195/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2053 - categorical_accuracy: 0.8821 - val_loss: 3.4704 - val_categorical_accuracy: 0.3478 - lr: 5.0000e-07\n",
            "Epoch 196/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 1.2054 - categorical_accuracy: 0.8781 - val_loss: 3.4718 - val_categorical_accuracy: 0.3486 - lr: 5.0000e-07\n",
            "Epoch 197/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2049 - categorical_accuracy: 0.8803 - val_loss: 3.4710 - val_categorical_accuracy: 0.3488 - lr: 5.0000e-07\n",
            "Epoch 198/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2146 - categorical_accuracy: 0.8770 - val_loss: 3.4749 - val_categorical_accuracy: 0.3478 - lr: 5.0000e-07\n",
            "Epoch 199/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 1.2048 - categorical_accuracy: 0.8799 - val_loss: 3.4713 - val_categorical_accuracy: 0.3489 - lr: 5.0000e-07\n",
            "Epoch 200/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 1.2141 - categorical_accuracy: 0.8762 - val_loss: 3.4703 - val_categorical_accuracy: 0.3480 - lr: 5.0000e-07\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(35.35999953746796,\n",
              " array([ 0.98999999, 15.88999927, 24.86999929, 23.95000011, 27.05999911,\n",
              "        29.17999923, 29.76999879, 28.13000083, 31.63999915, 31.49999976,\n",
              "        30.12000024, 30.66999912, 32.91999996, 30.98999858, 31.65000081,\n",
              "        31.22000098, 31.00999892, 32.31999874, 31.86999857, 30.54000139,\n",
              "        31.34000003, 31.97000027, 31.18999898, 32.64999986, 31.74999952,\n",
              "        32.04999864, 31.72000051, 32.98000097, 32.15999901, 32.80999959,\n",
              "        33.6499989 , 30.21999896, 34.00999904, 33.39000046, 31.70000017,\n",
              "        32.82000124, 32.44999945, 31.60000145, 31.56999946, 32.89999962,\n",
              "        32.37999976, 33.09000134, 32.3300004 , 33.01999867, 33.05999935,\n",
              "        32.31000006, 31.79999888, 32.67999887, 32.22000003, 33.07000101,\n",
              "        31.43999875, 33.27000141, 32.26999938, 32.69999921, 32.35999942,\n",
              "        32.89000094, 32.85999894, 33.12000036, 33.00999999, 32.21000135,\n",
              "        32.22999871, 33.25999975, 32.73000121, 31.04999959, 32.91000128,\n",
              "        32.74999857, 32.51999915, 33.55000019, 33.919999  , 32.24999905,\n",
              "        32.17999935, 33.05000067, 33.27000141, 32.57000148, 32.76000023,\n",
              "        32.80999959, 33.12000036, 31.70999885, 31.92999959, 33.07999969,\n",
              "        32.73000121, 35.19999981, 35.12000144, 34.90000069, 35.0699991 ,\n",
              "        35.10999978, 34.77999866, 34.9999994 , 35.06000042, 35.01000106,\n",
              "        34.65999961, 35.15000045, 35.35999954, 34.47000086, 35.06000042,\n",
              "        34.97000039, 35.17000079, 35.10999978, 35.2699995 , 35.04000008,\n",
              "        34.72000062, 35.3399992 , 34.95000005, 34.77999866, 34.97999907,\n",
              "        34.81999934, 34.49999988, 35.08000076, 34.36000049, 34.81999934,\n",
              "        34.67000127, 35.15999913, 34.58000124, 35.06000042, 34.31999981,\n",
              "        34.67999995, 34.58999991, 34.81999934, 34.81000066, 34.86000001,\n",
              "        35.17000079, 34.88999903, 35.0300014 , 35.15000045, 34.90000069,\n",
              "        35.22999883, 35.19000113, 35.10999978, 35.17999947, 34.95000005,\n",
              "        35.28999984, 35.12000144, 35.10999978, 35.0699991 , 35.19999981,\n",
              "        35.08999944, 35.04000008, 35.0699991 , 35.0300014 , 34.88000035,\n",
              "        34.94000137, 35.04999876, 35.12000144, 34.99000072, 34.97999907,\n",
              "        34.97000039, 35.13999879, 34.99000072, 35.22000015, 34.99000072,\n",
              "        34.92000103, 35.01000106, 34.9999994 , 34.90999937, 35.06000042,\n",
              "        34.97000039, 34.90000069, 34.92999971, 34.97999907, 34.86999869,\n",
              "        34.70999897, 34.70999897, 34.74000096, 34.65999961, 34.76999998,\n",
              "        34.799999  , 34.81999934, 34.83999968, 34.830001  , 34.830001  ,\n",
              "        34.83999968, 34.799999  , 34.7600013 , 34.76999998, 34.81000066,\n",
              "        34.74999964, 34.76999998, 34.799999  , 34.81999934, 34.81999934,\n",
              "        34.94000137, 34.86000001, 34.77999866, 34.83999968, 34.86000001,\n",
              "        34.88999903, 34.90000069, 34.830001  , 34.799999  , 34.88000035,\n",
              "        34.86000001, 34.95000005, 34.85000134, 34.86000001, 34.83999968,\n",
              "        34.77999866, 34.86000001, 34.88000035, 34.77999866, 34.88999903,\n",
              "        34.799999  ]))"
            ]
          },
          "metadata": {},
          "execution_count": 14
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# SLaM"
      ],
      "metadata": {
        "id": "4os1XlFbTAoq"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def top_k_disagreement(q: tf.Tensor, p: tf.Tensor, k: int) -> tf.Tensor:\n",
        "  num_classes = tf.shape(p)[-1]\n",
        "  _, top_k_pos = tf.math.top_k(p, k=k)\n",
        "  one_hot_top_k = tf.reduce_sum(tf.one_hot(top_k_pos,  num_classes), axis=1)\n",
        "  true_top_1 = tf.one_hot(tf.argmax(q, 1), num_classes)\n",
        "\n",
        "  return tf.reduce_sum(one_hot_top_k * true_top_1, axis=1)\n",
        "\n",
        "\n",
        "def top_k_margin(x: tf.Tensor, \n",
        "                k: int,\n",
        "                with_logits: Optional[bool] = False,\n",
        "                normalize: Optional[bool] = False) -> tf.Tensor:\n",
        "  \"\"\"Computes the margin of a probability/logit tensor.\"\"\"\n",
        "\n",
        "  if with_logits:\n",
        "    class_probabilities = tf.nn.softmax(x, axis=None, name=None)\n",
        "  else:\n",
        "    class_probabilities = x\n",
        "  a, _ = tf.math.top_k(class_probabilities, k=k+1, sorted=True)\n",
        "\n",
        "  # Compute total probability of top-k elements\n",
        "  top_k_prob = tf.reduce_sum(a[:, :k], axis=1)\n",
        "  next_best_prob =  tf.reduce_sum(a[:, k:], axis=1)\n",
        "  marg = top_k_prob - next_best_prob\n",
        "\n",
        "  if normalize:\n",
        "    marg = marg/tf.math.reduce_mean(marg)\n",
        "\n",
        "  return marg\n",
        "\n",
        "\n",
        "def teacher_top_k_accuracy_predictions(\n",
        "    teacher_predictions_training: tf.keras.Model,\n",
        "    teacher_predictions_validation: tf.keras.Model,\n",
        "    k: int,\n",
        "    confidence: Callable[[tf.Tensor], tf.Tensor],\n",
        "    data_split: DataSplit) -> tf.Tensor:\n",
        "  \"\"\"Creates teacher-accuracy advice for the training dataset (dataset_b).\n",
        "\n",
        "  It uses a confidence measure (e.g., margin, entropy) for the teacher\n",
        "  prediction over a validation dataset and isotonic regression to learn\n",
        "  the mapping from confidence to accuracy.\n",
        "\n",
        "  Args:\n",
        "    teacher: Instance of tf.keras.Model: the teacher model.\n",
        "    confidence: A function that maps soft labels to confidence.\n",
        "    k: Compute the top-k accuracy of teacher.\n",
        "    data_split: Instance of DataSplit containing the\n",
        "      training, validation, and test data.\n",
        "\n",
        "  Returns:\n",
        "    Instance of Advice containing advice vectors.\n",
        "  \"\"\"\n",
        "\n",
        "  train = data_split.dataset_b\n",
        "  validation = data_split.validation\n",
        "\n",
        "  # These are the lower bound/upper bounds used in isotonic regression.\n",
        "  # All values predicted should be in the range [min_threshold, max_threshold].\n",
        "  # Since the outputs correspond to the probability that the teacher model is\n",
        "  # correct the range [0.5, 1.] is reasonable.\n",
        "  min_threshold = 0.5\n",
        "  max_threshold = 1.\n",
        "\n",
        "  # Compute the teacher margins on the validation examples.\n",
        "  covariate = tf.reshape(confidence(teacher_predictions_validation), (-1, 1))\n",
        "\n",
        "  # Compute the (pointwise) top-k accuracy of the teacher on the validation data.\n",
        "  response = top_k_disagreement(validation.labels, teacher_predictions_validation, k)\n",
        "\n",
        "  # Create a isotonic regressor that maps teacher_margins to teacher_accuracies.\n",
        "  covariate = tf.reshape(covariate, -1)\n",
        "  response = tf.reshape(response, -1)\n",
        "  iso = IsotonicRegression(\n",
        "      y_min=min_threshold, y_max=max_threshold, out_of_bounds='clip')\n",
        "  iso.fit(covariate, response)\n",
        "\n",
        "\n",
        "  # Compute the teacher confidence on the training dataset.\n",
        "  teacher_confidence = tf.reshape(confidence(teacher_predictions_training), -1)\n",
        "\n",
        "  # Use the knn to predict the accuracy of the teacher on the training examples.\n",
        "  main_advice = iso.predict(teacher_confidence)\n",
        "\n",
        "  return main_advice\n",
        "\n",
        "def teacher_top_k_accuracy_advice(\n",
        "    teacher: tf.keras.Model,\n",
        "    num_classes: int,\n",
        "    data_split: DataSplit) -> Advice:\n",
        "  \"\"\"Builds top-k teacher-accuracy advice vectors for every data split.\n",
        "\n",
        "  For each k in [1, num_classes) the advice column for the training examples\n",
        "  (dataset_b) is computed by teacher_top_k_accuracy_predictions, using the\n",
        "  top-k margin of the teacher predictions as the confidence measure. A final\n",
        "  column of ones is appended, so the training advice has num_classes columns.\n",
        "\n",
        "  Args:\n",
        "    teacher: Instance of tf.keras.Model: the teacher model.\n",
        "    num_classes: Number of classes; also the number of advice columns.\n",
        "    data_split: Instance of DataSplit containing the pretraining (dataset_a),\n",
        "      training (dataset_b), validation, and test data. The validation split\n",
        "      must not be None because the teacher is evaluated on it.\n",
        "\n",
        "  Returns:\n",
        "    Instance of Advice where:\n",
        "      train: one row per training example, num_classes columns.\n",
        "      validation, pretraining: the one-hot vector of index 0 repeated for\n",
        "        every example of the split (None when the split is None).\n",
        "      test: all zeros, one row per test example, num_classes columns.\n",
        "  \"\"\"\n",
        "\n",
        "  # Bug fix: train was never bound inside this function, so the call only\n",
        "  # worked when a global named train happened to exist in the kernel.\n",
        "  train = data_split.dataset_b\n",
        "  pretraining = data_split.dataset_a\n",
        "  validation = data_split.validation\n",
        "  test = data_split.test\n",
        "\n",
        "  # Run the teacher once per split; the predictions are reused for every k.\n",
        "  teacher_predictions_training = teacher.predict(train.examples)\n",
        "  teacher_predictions_validation = teacher.predict(validation.examples)\n",
        "\n",
        "  advice_columns = []\n",
        "  for k in range(1, num_classes):\n",
        "\n",
        "    # The closure is only invoked inside this iteration, so capturing the\n",
        "    # loop variable k is safe.\n",
        "    def _top_k_confidence(x):\n",
        "      return top_k_margin(x, k, with_logits=False, normalize=False)\n",
        "\n",
        "    advice_columns.append(teacher_top_k_accuracy_predictions(\n",
        "        teacher_predictions_training,\n",
        "        teacher_predictions_validation,\n",
        "        k, _top_k_confidence, data_split))\n",
        "\n",
        "  # The last column (k == num_classes) is the constant 1.\n",
        "  advice_columns.append(tf.ones(train.size))\n",
        "\n",
        "  train_advice = tf.stack(advice_columns, axis=1)\n",
        "\n",
        "  def _one_hot_zero_rows(num_rows):\n",
        "    # The one-hot vector of index 0 repeated num_rows times, built as a single\n",
        "    # op instead of stacking one small tensor per example in a Python loop.\n",
        "    return tf.one_hot(tf.zeros([num_rows], dtype=tf.int32), num_classes)\n",
        "\n",
        "  validation_advice = None\n",
        "  pretraining_advice = None\n",
        "\n",
        "  if validation is not None:\n",
        "    validation_advice = _one_hot_zero_rows(validation.size)\n",
        "\n",
        "  if pretraining is not None:\n",
        "    pretraining_advice = _one_hot_zero_rows(pretraining.size)\n",
        "\n",
        "  test_advice = tf.zeros([test.size, num_classes])\n",
        "\n",
        "  advice_data = Advice(\n",
        "      train=train_advice,\n",
        "      validation=validation_advice,\n",
        "      pretraining=pretraining_advice,\n",
        "      test=test_advice)\n",
        "\n",
        "  return advice_data\n"
      ],
      "metadata": {
        "id": "fV8jRh_cHaV1",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1674705990992,
          "user_tz": 300,
          "elapsed": 166,
          "user": {
            "displayName": "XXXX",
            "userId": "00000"
          }
        }
      },
      "execution_count": 15,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Build the teacher top-k accuracy advice for all splits; 100 is passed as num_classes.\n",
        "full_top_k_advice = teacher_top_k_accuracy_advice(teacher_model, 100, data_split)"
      ],
      "metadata": {
        "id": "vQ-qCtDcTkW_",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1674711899399,
          "user_tz": 300,
          "elapsed": 0,
          "user": {
            "displayName": "XXXX",
            "userId": "0000"
          }
        },
        "outputId": "0d165e3e-2c3d-4c68-9caf-9fd1f293ca5e"
      },
      "execution_count": 16,
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1391/1391 [==============================] - 11s 8ms/step\n",
            "16/16 [==============================] - 0s 8ms/step\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import tensorflow_probability as tfp\n",
        "\n",
        "def soft_top_k_mask(x, k_mask, temp):\n",
        "  \"\"\"Contracts the soft sorting matrix of x with a mask.\n",
        "\n",
        "  Args:\n",
        "    x: Tensor passed to tfp.math.soft_sorting_matrix. Presumably of shape\n",
        "      (batch, n), given the rank-3 operand of the einsum -- TODO confirm.\n",
        "    k_mask: Rank-2 tensor contracted against the second axis of the soft\n",
        "      sorting matrix.\n",
        "    temp: Temperature passed to tfp.math.soft_sorting_matrix.\n",
        "\n",
        "  Returns:\n",
        "    Rank-2 tensor out with out[i, k] = sum_j P[i, j, k] * k_mask[i, j], where\n",
        "    P is the soft sorting matrix of x.\n",
        "  \"\"\"\n",
        "  soft_permutations = tfp.math.soft_sorting_matrix(x, temp)\n",
        "  masked = tf.einsum('ijk, ij -> ik', soft_permutations, k_mask)\n",
        "  return masked\n",
        "\n",
        "\n",
        "def advice_mix_soft_top_full(num_classes, temperature, threshold):\n",
        "  \"\"\"Returns a loss decorator that mixes the student prediction using advice.\n",
        "\n",
        "  The decorated loss expects y_true to hold the label columns followed by\n",
        "  num_classes advice columns.\n",
        "\n",
        "  Args:\n",
        "    num_classes: Number of trailing columns of y_true that hold the advice.\n",
        "    temperature: Temperature forwarded to soft_top_k_mask.\n",
        "    threshold: Advice entries strictly below this value are set to 1 in the\n",
        "      mask given to soft_top_k_mask; all other entries are set to 0.\n",
        "\n",
        "  Returns:\n",
        "    A decorator that takes a loss(label, prediction) callable and returns a\n",
        "    loss(y_true, y_pred) callable evaluated on the mixed prediction.\n",
        "  \"\"\"\n",
        "\n",
        "  def _advice_mix_soft_top_k(loss):\n",
        "\n",
        "    def _loss(y_true, y_pred):\n",
        "\n",
        "      # The trailing num_classes columns of y_true are the advice and the\n",
        "      # columns before them are the label; neither receives gradients.\n",
        "      label = tf.stop_gradient(y_true[:, :-num_classes])\n",
        "      advice = tf.stop_gradient(y_true[:, -num_classes:])\n",
        "\n",
        "      # Per-example mixing weights: the first advice column and its complement.\n",
        "      keep_weight = advice[:, 0]\n",
        "      boost_weight = 1. - keep_weight\n",
        "\n",
        "      # Mask that is 1 wherever the advice falls below the threshold.\n",
        "      below_threshold = tf.cast(advice < threshold, tf.float32)\n",
        "      soft_mask = soft_top_k_mask(label, below_threshold, temperature)\n",
        "      boost = soft_mask * (1 - y_pred)\n",
        "\n",
        "      # Scale every row of each term by its per-example weight.\n",
        "      kept = tf.einsum('ij, i -> ij', y_pred, keep_weight)\n",
        "      boosted = tf.einsum('ij, i -> ij', boost, boost_weight)\n",
        "\n",
        "      # Loss between the label and the mixed student prediction.\n",
        "      return loss(label, kept + boosted)\n",
        "\n",
        "    return _loss\n",
        "\n",
        "  return _advice_mix_soft_top_k\n",
        "\n",
        "\n",
        "def mixed_cce_soft_top_full(\n",
        "    y_true: tf.Tensor,\n",
        "    y_pred: tf.Tensor,\n",
        "    *,\n",
        "    num_classes: int = 100,\n",
        "    temperature: float = 0.1,\n",
        "    threshold: float = 0.9) -> tf.Tensor:\n",
        "  \"\"\"Computes the mean cross entropy of the advice-mixed student prediction.\n",
        "\n",
        "  Args:\n",
        "    y_true: tf.Tensor whose last num_classes columns contain the advice and\n",
        "      whose leading columns contain the (teacher) label.\n",
        "    y_pred: tf.Tensor that contains the predicted probabilities for each\n",
        "      class.\n",
        "    num_classes: Number of trailing advice columns in y_true. Defaults to the\n",
        "      previously hard-coded value of 100.\n",
        "    temperature: Temperature forwarded to advice_mix_soft_top_full. Defaults\n",
        "      to the previously hard-coded value of 0.1.\n",
        "    threshold: Advice threshold forwarded to advice_mix_soft_top_full.\n",
        "      Defaults to the previously hard-coded value of 0.9.\n",
        "\n",
        "  Returns:\n",
        "    Scalar tf.Tensor: the mean over the batch of the categorical cross entropy\n",
        "    between y_true[:, :-num_classes] and the mixed prediction produced by\n",
        "    advice_mix_soft_top_full.\n",
        "  \"\"\"\n",
        "\n",
        "  # We have to use tf.keras.losses.Reduction.NONE when we are using\n",
        "  # tf.distribute.Strategy.\n",
        "  # (see https://www.tensorflow.org/tutorials/distribute/custom_training).\n",
        "  mixing_operator = advice_mix_soft_top_full(\n",
        "      num_classes, temperature, threshold)\n",
        "  loss = mixing_operator(tf.keras.losses.CategoricalCrossentropy(\n",
        "    reduction=tf.keras.losses.Reduction.NONE))\n",
        "\n",
        "  # Reduce the per-example losses to a single scalar.\n",
        "  return tf.math.reduce_mean(loss(y_true, y_pred))\n"
      ],
      "metadata": {
        "id": "jSzSDWLvHb8o",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1674711899392,
          "user_tz": 300,
          "elapsed": 1,
          "user": {
            "displayName": "XXX",
            "userId": "00000"
          }
        }
      },
      "execution_count": 17,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "# Build the student via load_model; mixed_cce_soft_top_full is passed as its\n",
        "# loss function.\n",
        "student_soft_full = load_model(\n",
        "      _STUDENT_MODEL,\n",
        "      num_classes=num_classes,\n",
        "      optimizer_name=_STUDENT_OPTIMIZER,\n",
        "      loss_function=mixed_cce_soft_top_full,\n",
        "      width_multiplier=1,\n",
        "      depth_multiplier=_STUDENT_MOBILENET_DEPTH_MULTIPLIER,\n",
        "      resnet_depth=_STUDENT_RESNET_DEPTH)"
      ],
      "metadata": {
        "id": "_8ZjtS89StWS",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1674711899392,
          "user_tz": 300,
          "elapsed": 0,
          "user": {
            "displayName": "XXXXX",
            "userId": "000000"
          }
        }
      },
      "execution_count": 18,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Train the student model using distillation, passing the teacher predictions\n",
        "# and full_top_k_advice as the advice data.\n",
        "distill(model=student_soft_full,\n",
        "        teacher_predictions=teacher_predictions,\n",
        "        data_split=data_split,\n",
        "        advice_data=full_top_k_advice,\n",
        "        params=training_params)\n"
      ],
      "metadata": {
        "id": "RbT_U0M4VKv-",
        "outputId": "37421506-bb40-4199-b18f-7040d7db8805",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1674711899400,
          "user_tz": 300,
          "elapsed": 1,
          "user": {
            "displayName": "XXXX",
            "userId": "00000"
          }
        }
      },
      "execution_count": 19,
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "313/313 [==============================] - 4s 6ms/step - loss: 5.8440 - categorical_accuracy: 0.0100\n",
            "Epoch 1/200\n",
            "390/390 [==============================] - 49s 52ms/step - loss: 3.9651 - categorical_accuracy: 0.0086 - val_loss: 5.0158 - val_categorical_accuracy: 0.0906 - lr: 0.0010\n",
            "Epoch 2/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 3.5196 - categorical_accuracy: 0.0139 - val_loss: 4.8894 - val_categorical_accuracy: 0.1008 - lr: 0.0010\n",
            "Epoch 3/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 3.3884 - categorical_accuracy: 0.0178 - val_loss: 4.9193 - val_categorical_accuracy: 0.1210 - lr: 0.0010\n",
            "Epoch 4/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 3.3175 - categorical_accuracy: 0.0215 - val_loss: 4.9100 - val_categorical_accuracy: 0.1577 - lr: 0.0010\n",
            "Epoch 5/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 3.2684 - categorical_accuracy: 0.0239 - val_loss: 4.8999 - val_categorical_accuracy: 0.1694 - lr: 0.0010\n",
            "Epoch 6/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 3.2374 - categorical_accuracy: 0.0250 - val_loss: 4.9577 - val_categorical_accuracy: 0.1535 - lr: 0.0010\n",
            "Epoch 7/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 3.2106 - categorical_accuracy: 0.0277 - val_loss: 4.9750 - val_categorical_accuracy: 0.1864 - lr: 0.0010\n",
            "Epoch 8/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 3.1839 - categorical_accuracy: 0.0307 - val_loss: 4.9756 - val_categorical_accuracy: 0.1933 - lr: 0.0010\n",
            "Epoch 9/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 3.1604 - categorical_accuracy: 0.0314 - val_loss: 4.9502 - val_categorical_accuracy: 0.1700 - lr: 0.0010\n",
            "Epoch 10/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 3.1435 - categorical_accuracy: 0.0339 - val_loss: 5.0261 - val_categorical_accuracy: 0.2365 - lr: 0.0010\n",
            "Epoch 11/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 3.1262 - categorical_accuracy: 0.0364 - val_loss: 5.0358 - val_categorical_accuracy: 0.2378 - lr: 0.0010\n",
            "Epoch 12/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 3.1215 - categorical_accuracy: 0.0365 - val_loss: 5.0205 - val_categorical_accuracy: 0.2179 - lr: 0.0010\n",
            "Epoch 13/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 3.1036 - categorical_accuracy: 0.0381 - val_loss: 5.1278 - val_categorical_accuracy: 0.2456 - lr: 0.0010\n",
            "Epoch 14/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 3.0938 - categorical_accuracy: 0.0391 - val_loss: 5.0897 - val_categorical_accuracy: 0.2349 - lr: 0.0010\n",
            "Epoch 15/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 3.0760 - categorical_accuracy: 0.0412 - val_loss: 5.0470 - val_categorical_accuracy: 0.2271 - lr: 0.0010\n",
            "Epoch 16/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 3.0640 - categorical_accuracy: 0.0427 - val_loss: 5.0718 - val_categorical_accuracy: 0.2438 - lr: 0.0010\n",
            "Epoch 17/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 3.0653 - categorical_accuracy: 0.0428 - val_loss: 5.1074 - val_categorical_accuracy: 0.2403 - lr: 0.0010\n",
            "Epoch 18/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 3.0494 - categorical_accuracy: 0.0446 - val_loss: 5.0806 - val_categorical_accuracy: 0.2351 - lr: 0.0010\n",
            "Epoch 19/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 3.0374 - categorical_accuracy: 0.0473 - val_loss: 5.2238 - val_categorical_accuracy: 0.2835 - lr: 0.0010\n",
            "Epoch 20/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 3.0352 - categorical_accuracy: 0.0465 - val_loss: 5.1941 - val_categorical_accuracy: 0.2551 - lr: 0.0010\n",
            "Epoch 21/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 3.0269 - categorical_accuracy: 0.0482 - val_loss: 5.2121 - val_categorical_accuracy: 0.2937 - lr: 0.0010\n",
            "Epoch 22/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 3.0202 - categorical_accuracy: 0.0493 - val_loss: 5.1785 - val_categorical_accuracy: 0.2697 - lr: 0.0010\n",
            "Epoch 23/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 3.0121 - categorical_accuracy: 0.0499 - val_loss: 5.2585 - val_categorical_accuracy: 0.2920 - lr: 0.0010\n",
            "Epoch 24/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 3.0047 - categorical_accuracy: 0.0508 - val_loss: 5.2518 - val_categorical_accuracy: 0.2819 - lr: 0.0010\n",
            "Epoch 25/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.9968 - categorical_accuracy: 0.0532 - val_loss: 5.1901 - val_categorical_accuracy: 0.2644 - lr: 0.0010\n",
            "Epoch 26/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.9986 - categorical_accuracy: 0.0527 - val_loss: 5.3433 - val_categorical_accuracy: 0.2928 - lr: 0.0010\n",
            "Epoch 27/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.9881 - categorical_accuracy: 0.0545 - val_loss: 5.2903 - val_categorical_accuracy: 0.2970 - lr: 0.0010\n",
            "Epoch 28/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.9842 - categorical_accuracy: 0.0543 - val_loss: 5.4288 - val_categorical_accuracy: 0.3257 - lr: 0.0010\n",
            "Epoch 29/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.9771 - categorical_accuracy: 0.0560 - val_loss: 5.3922 - val_categorical_accuracy: 0.3030 - lr: 0.0010\n",
            "Epoch 30/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.9667 - categorical_accuracy: 0.0581 - val_loss: 5.4330 - val_categorical_accuracy: 0.2828 - lr: 0.0010\n",
            "Epoch 31/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.9709 - categorical_accuracy: 0.0564 - val_loss: 5.3424 - val_categorical_accuracy: 0.3004 - lr: 0.0010\n",
            "Epoch 32/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.9602 - categorical_accuracy: 0.0585 - val_loss: 5.4082 - val_categorical_accuracy: 0.3312 - lr: 0.0010\n",
            "Epoch 33/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.9600 - categorical_accuracy: 0.0585 - val_loss: 5.4409 - val_categorical_accuracy: 0.3064 - lr: 0.0010\n",
            "Epoch 34/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.9524 - categorical_accuracy: 0.0603 - val_loss: 5.4819 - val_categorical_accuracy: 0.3162 - lr: 0.0010\n",
            "Epoch 35/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.9495 - categorical_accuracy: 0.0609 - val_loss: 5.5822 - val_categorical_accuracy: 0.3522 - lr: 0.0010\n",
            "Epoch 36/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.9416 - categorical_accuracy: 0.0626 - val_loss: 5.4265 - val_categorical_accuracy: 0.3235 - lr: 0.0010\n",
            "Epoch 37/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.9387 - categorical_accuracy: 0.0633 - val_loss: 5.5403 - val_categorical_accuracy: 0.3440 - lr: 0.0010\n",
            "Epoch 38/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.9301 - categorical_accuracy: 0.0655 - val_loss: 5.5399 - val_categorical_accuracy: 0.3206 - lr: 0.0010\n",
            "Epoch 39/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.9472 - categorical_accuracy: 0.0617 - val_loss: 5.5496 - val_categorical_accuracy: 0.3383 - lr: 0.0010\n",
            "Epoch 40/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.9317 - categorical_accuracy: 0.0638 - val_loss: 5.5484 - val_categorical_accuracy: 0.3507 - lr: 0.0010\n",
            "Epoch 41/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.9306 - categorical_accuracy: 0.0650 - val_loss: 5.5243 - val_categorical_accuracy: 0.3298 - lr: 0.0010\n",
            "Epoch 42/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.9215 - categorical_accuracy: 0.0665 - val_loss: 5.5967 - val_categorical_accuracy: 0.3380 - lr: 0.0010\n",
            "Epoch 43/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.9167 - categorical_accuracy: 0.0683 - val_loss: 5.5870 - val_categorical_accuracy: 0.3379 - lr: 0.0010\n",
            "Epoch 44/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.9111 - categorical_accuracy: 0.0698 - val_loss: 5.4566 - val_categorical_accuracy: 0.3207 - lr: 0.0010\n",
            "Epoch 45/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.9109 - categorical_accuracy: 0.0690 - val_loss: 5.5674 - val_categorical_accuracy: 0.3043 - lr: 0.0010\n",
            "Epoch 46/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.9040 - categorical_accuracy: 0.0704 - val_loss: 5.6918 - val_categorical_accuracy: 0.3600 - lr: 0.0010\n",
            "Epoch 47/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.8973 - categorical_accuracy: 0.0723 - val_loss: 5.6231 - val_categorical_accuracy: 0.3332 - lr: 0.0010\n",
            "Epoch 48/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.8960 - categorical_accuracy: 0.0730 - val_loss: 5.5466 - val_categorical_accuracy: 0.3276 - lr: 0.0010\n",
            "Epoch 49/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.9274 - categorical_accuracy: 0.0657 - val_loss: 5.6016 - val_categorical_accuracy: 0.3320 - lr: 0.0010\n",
            "Epoch 50/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.9148 - categorical_accuracy: 0.0684 - val_loss: 5.8010 - val_categorical_accuracy: 0.3477 - lr: 0.0010\n",
            "Epoch 51/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.8981 - categorical_accuracy: 0.0724 - val_loss: 5.6541 - val_categorical_accuracy: 0.3165 - lr: 0.0010\n",
            "Epoch 52/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.9071 - categorical_accuracy: 0.0689 - val_loss: 5.6554 - val_categorical_accuracy: 0.3413 - lr: 0.0010\n",
            "Epoch 53/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.8945 - categorical_accuracy: 0.0723 - val_loss: 5.4930 - val_categorical_accuracy: 0.3107 - lr: 0.0010\n",
            "Epoch 54/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.8864 - categorical_accuracy: 0.0753 - val_loss: 5.7654 - val_categorical_accuracy: 0.3577 - lr: 0.0010\n",
            "Epoch 55/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.8785 - categorical_accuracy: 0.0770 - val_loss: 5.7440 - val_categorical_accuracy: 0.3665 - lr: 0.0010\n",
            "Epoch 56/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.9087 - categorical_accuracy: 0.0691 - val_loss: 5.7490 - val_categorical_accuracy: 0.3460 - lr: 0.0010\n",
            "Epoch 57/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.8819 - categorical_accuracy: 0.0759 - val_loss: 5.8648 - val_categorical_accuracy: 0.3524 - lr: 0.0010\n",
            "Epoch 58/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.8991 - categorical_accuracy: 0.0720 - val_loss: 5.5826 - val_categorical_accuracy: 0.2965 - lr: 0.0010\n",
            "Epoch 59/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.8853 - categorical_accuracy: 0.0745 - val_loss: 5.7553 - val_categorical_accuracy: 0.3368 - lr: 0.0010\n",
            "Epoch 60/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.8826 - categorical_accuracy: 0.0755 - val_loss: 5.7418 - val_categorical_accuracy: 0.3653 - lr: 0.0010\n",
            "Epoch 61/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.8899 - categorical_accuracy: 0.0743 - val_loss: 5.7809 - val_categorical_accuracy: 0.3457 - lr: 0.0010\n",
            "Epoch 62/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.8825 - categorical_accuracy: 0.0748 - val_loss: 5.8248 - val_categorical_accuracy: 0.3275 - lr: 0.0010\n",
            "Epoch 63/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.8762 - categorical_accuracy: 0.0772 - val_loss: 5.8426 - val_categorical_accuracy: 0.3567 - lr: 0.0010\n",
            "Epoch 64/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.8762 - categorical_accuracy: 0.0778 - val_loss: 5.7494 - val_categorical_accuracy: 0.3165 - lr: 0.0010\n",
            "Epoch 65/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.8750 - categorical_accuracy: 0.0775 - val_loss: 5.8507 - val_categorical_accuracy: 0.3352 - lr: 0.0010\n",
            "Epoch 66/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 2.8771 - categorical_accuracy: 0.0779 - val_loss: 5.6869 - val_categorical_accuracy: 0.3420 - lr: 0.0010\n",
            "Epoch 67/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.8663 - categorical_accuracy: 0.0803 - val_loss: 5.8277 - val_categorical_accuracy: 0.3523 - lr: 0.0010\n",
            "Epoch 68/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.8652 - categorical_accuracy: 0.0802 - val_loss: 5.8286 - val_categorical_accuracy: 0.3525 - lr: 0.0010\n",
            "Epoch 69/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.8647 - categorical_accuracy: 0.0803 - val_loss: 5.8160 - val_categorical_accuracy: 0.3064 - lr: 0.0010\n",
            "Epoch 70/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.8591 - categorical_accuracy: 0.0825 - val_loss: 5.8817 - val_categorical_accuracy: 0.3526 - lr: 0.0010\n",
            "Epoch 71/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.8597 - categorical_accuracy: 0.0813 - val_loss: 5.8150 - val_categorical_accuracy: 0.3373 - lr: 0.0010\n",
            "Epoch 72/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.8826 - categorical_accuracy: 0.0762 - val_loss: 5.8488 - val_categorical_accuracy: 0.3757 - lr: 0.0010\n",
            "Epoch 73/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.8679 - categorical_accuracy: 0.0796 - val_loss: 5.8595 - val_categorical_accuracy: 0.3357 - lr: 0.0010\n",
            "Epoch 74/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.8637 - categorical_accuracy: 0.0807 - val_loss: 5.9412 - val_categorical_accuracy: 0.3404 - lr: 0.0010\n",
            "Epoch 75/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.8547 - categorical_accuracy: 0.0828 - val_loss: 5.8073 - val_categorical_accuracy: 0.3283 - lr: 0.0010\n",
            "Epoch 76/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.8524 - categorical_accuracy: 0.0835 - val_loss: 5.8933 - val_categorical_accuracy: 0.3541 - lr: 0.0010\n",
            "Epoch 77/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.8453 - categorical_accuracy: 0.0857 - val_loss: 5.9466 - val_categorical_accuracy: 0.3468 - lr: 0.0010\n",
            "Epoch 78/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.8809 - categorical_accuracy: 0.0766 - val_loss: 5.8521 - val_categorical_accuracy: 0.3336 - lr: 0.0010\n",
            "Epoch 79/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.8503 - categorical_accuracy: 0.0847 - val_loss: 5.8647 - val_categorical_accuracy: 0.3420 - lr: 0.0010\n",
            "Epoch 80/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.8628 - categorical_accuracy: 0.0808 - val_loss: 5.8971 - val_categorical_accuracy: 0.3253 - lr: 0.0010\n",
            "Epoch 81/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.8093 - categorical_accuracy: 0.0953 - val_loss: 6.1487 - val_categorical_accuracy: 0.4203 - lr: 1.0000e-04\n",
            "Epoch 82/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 2.7894 - categorical_accuracy: 0.1007 - val_loss: 6.1849 - val_categorical_accuracy: 0.4243 - lr: 1.0000e-04\n",
            "Epoch 83/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7852 - categorical_accuracy: 0.1011 - val_loss: 6.2060 - val_categorical_accuracy: 0.4224 - lr: 1.0000e-04\n",
            "Epoch 84/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7768 - categorical_accuracy: 0.1043 - val_loss: 6.2393 - val_categorical_accuracy: 0.4269 - lr: 1.0000e-04\n",
            "Epoch 85/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7718 - categorical_accuracy: 0.1049 - val_loss: 6.2212 - val_categorical_accuracy: 0.4199 - lr: 1.0000e-04\n",
            "Epoch 86/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7731 - categorical_accuracy: 0.1033 - val_loss: 6.2435 - val_categorical_accuracy: 0.4208 - lr: 1.0000e-04\n",
            "Epoch 87/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7652 - categorical_accuracy: 0.1059 - val_loss: 6.2533 - val_categorical_accuracy: 0.4213 - lr: 1.0000e-04\n",
            "Epoch 88/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7662 - categorical_accuracy: 0.1049 - val_loss: 6.2652 - val_categorical_accuracy: 0.4238 - lr: 1.0000e-04\n",
            "Epoch 89/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7620 - categorical_accuracy: 0.1057 - val_loss: 6.2640 - val_categorical_accuracy: 0.4179 - lr: 1.0000e-04\n",
            "Epoch 90/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7582 - categorical_accuracy: 0.1063 - val_loss: 6.2446 - val_categorical_accuracy: 0.4180 - lr: 1.0000e-04\n",
            "Epoch 91/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7574 - categorical_accuracy: 0.1059 - val_loss: 6.2713 - val_categorical_accuracy: 0.4163 - lr: 1.0000e-04\n",
            "Epoch 92/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7526 - categorical_accuracy: 0.1070 - val_loss: 6.2652 - val_categorical_accuracy: 0.4200 - lr: 1.0000e-04\n",
            "Epoch 93/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7503 - categorical_accuracy: 0.1070 - val_loss: 6.2732 - val_categorical_accuracy: 0.4149 - lr: 1.0000e-04\n",
            "Epoch 94/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7493 - categorical_accuracy: 0.1071 - val_loss: 6.2849 - val_categorical_accuracy: 0.4157 - lr: 1.0000e-04\n",
            "Epoch 95/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7474 - categorical_accuracy: 0.1071 - val_loss: 6.2713 - val_categorical_accuracy: 0.4125 - lr: 1.0000e-04\n",
            "Epoch 96/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7444 - categorical_accuracy: 0.1076 - val_loss: 6.2864 - val_categorical_accuracy: 0.4161 - lr: 1.0000e-04\n",
            "Epoch 97/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.7427 - categorical_accuracy: 0.1077 - val_loss: 6.2889 - val_categorical_accuracy: 0.4099 - lr: 1.0000e-04\n",
            "Epoch 98/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7404 - categorical_accuracy: 0.1078 - val_loss: 6.2995 - val_categorical_accuracy: 0.4165 - lr: 1.0000e-04\n",
            "Epoch 99/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7389 - categorical_accuracy: 0.1080 - val_loss: 6.3393 - val_categorical_accuracy: 0.4126 - lr: 1.0000e-04\n",
            "Epoch 100/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7372 - categorical_accuracy: 0.1081 - val_loss: 6.3062 - val_categorical_accuracy: 0.4133 - lr: 1.0000e-04\n",
            "Epoch 101/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7348 - categorical_accuracy: 0.1081 - val_loss: 6.3049 - val_categorical_accuracy: 0.4133 - lr: 1.0000e-04\n",
            "Epoch 102/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7324 - categorical_accuracy: 0.1086 - val_loss: 6.3076 - val_categorical_accuracy: 0.4085 - lr: 1.0000e-04\n",
            "Epoch 103/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7311 - categorical_accuracy: 0.1085 - val_loss: 6.3065 - val_categorical_accuracy: 0.4056 - lr: 1.0000e-04\n",
            "Epoch 104/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7297 - categorical_accuracy: 0.1083 - val_loss: 6.3185 - val_categorical_accuracy: 0.4066 - lr: 1.0000e-04\n",
            "Epoch 105/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7303 - categorical_accuracy: 0.1082 - val_loss: 6.3299 - val_categorical_accuracy: 0.4095 - lr: 1.0000e-04\n",
            "Epoch 106/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7268 - categorical_accuracy: 0.1089 - val_loss: 6.3135 - val_categorical_accuracy: 0.4080 - lr: 1.0000e-04\n",
            "Epoch 107/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.7282 - categorical_accuracy: 0.1079 - val_loss: 6.3556 - val_categorical_accuracy: 0.4057 - lr: 1.0000e-04\n",
            "Epoch 108/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.7252 - categorical_accuracy: 0.1086 - val_loss: 6.3179 - val_categorical_accuracy: 0.4052 - lr: 1.0000e-04\n",
            "Epoch 109/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 2.7223 - categorical_accuracy: 0.1091 - val_loss: 6.3468 - val_categorical_accuracy: 0.4050 - lr: 1.0000e-04\n",
            "Epoch 110/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7238 - categorical_accuracy: 0.1084 - val_loss: 6.3304 - val_categorical_accuracy: 0.4013 - lr: 1.0000e-04\n",
            "Epoch 111/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7206 - categorical_accuracy: 0.1087 - val_loss: 6.3379 - val_categorical_accuracy: 0.4097 - lr: 1.0000e-04\n",
            "Epoch 112/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.7185 - categorical_accuracy: 0.1093 - val_loss: 6.3567 - val_categorical_accuracy: 0.4086 - lr: 1.0000e-04\n",
            "Epoch 113/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 2.7168 - categorical_accuracy: 0.1090 - val_loss: 6.3344 - val_categorical_accuracy: 0.4066 - lr: 1.0000e-04\n",
            "Epoch 114/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7196 - categorical_accuracy: 0.1084 - val_loss: 6.3841 - val_categorical_accuracy: 0.3984 - lr: 1.0000e-04\n",
            "Epoch 115/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7155 - categorical_accuracy: 0.1091 - val_loss: 6.3646 - val_categorical_accuracy: 0.4035 - lr: 1.0000e-04\n",
            "Epoch 116/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7139 - categorical_accuracy: 0.1091 - val_loss: 6.3610 - val_categorical_accuracy: 0.3998 - lr: 1.0000e-04\n",
            "Epoch 117/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7165 - categorical_accuracy: 0.1086 - val_loss: 6.3851 - val_categorical_accuracy: 0.4043 - lr: 1.0000e-04\n",
            "Epoch 118/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7146 - categorical_accuracy: 0.1089 - val_loss: 6.3792 - val_categorical_accuracy: 0.4002 - lr: 1.0000e-04\n",
            "Epoch 119/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7124 - categorical_accuracy: 0.1092 - val_loss: 6.3430 - val_categorical_accuracy: 0.3971 - lr: 1.0000e-04\n",
            "Epoch 120/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7116 - categorical_accuracy: 0.1091 - val_loss: 6.3746 - val_categorical_accuracy: 0.4005 - lr: 1.0000e-04\n",
            "Epoch 121/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7083 - categorical_accuracy: 0.1094 - val_loss: 6.3967 - val_categorical_accuracy: 0.4040 - lr: 1.0000e-05\n",
            "Epoch 122/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7067 - categorical_accuracy: 0.1099 - val_loss: 6.3936 - val_categorical_accuracy: 0.4032 - lr: 1.0000e-05\n",
            "Epoch 123/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7068 - categorical_accuracy: 0.1097 - val_loss: 6.3981 - val_categorical_accuracy: 0.4050 - lr: 1.0000e-05\n",
            "Epoch 124/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7060 - categorical_accuracy: 0.1098 - val_loss: 6.4007 - val_categorical_accuracy: 0.4063 - lr: 1.0000e-05\n",
            "Epoch 125/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7064 - categorical_accuracy: 0.1096 - val_loss: 6.4021 - val_categorical_accuracy: 0.4054 - lr: 1.0000e-05\n",
            "Epoch 126/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7054 - categorical_accuracy: 0.1098 - val_loss: 6.4000 - val_categorical_accuracy: 0.4059 - lr: 1.0000e-05\n",
            "Epoch 127/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.7061 - categorical_accuracy: 0.1097 - val_loss: 6.4055 - val_categorical_accuracy: 0.4076 - lr: 1.0000e-05\n",
            "Epoch 128/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7059 - categorical_accuracy: 0.1096 - val_loss: 6.4110 - val_categorical_accuracy: 0.4066 - lr: 1.0000e-05\n",
            "Epoch 129/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.7051 - categorical_accuracy: 0.1097 - val_loss: 6.4061 - val_categorical_accuracy: 0.4062 - lr: 1.0000e-05\n",
            "Epoch 130/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7045 - categorical_accuracy: 0.1098 - val_loss: 6.4029 - val_categorical_accuracy: 0.4061 - lr: 1.0000e-05\n",
            "Epoch 131/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7050 - categorical_accuracy: 0.1098 - val_loss: 6.4041 - val_categorical_accuracy: 0.4067 - lr: 1.0000e-05\n",
            "Epoch 132/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7042 - categorical_accuracy: 0.1099 - val_loss: 6.4072 - val_categorical_accuracy: 0.4043 - lr: 1.0000e-05\n",
            "Epoch 133/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7049 - categorical_accuracy: 0.1097 - val_loss: 6.4043 - val_categorical_accuracy: 0.4056 - lr: 1.0000e-05\n",
            "Epoch 134/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7041 - categorical_accuracy: 0.1099 - val_loss: 6.4048 - val_categorical_accuracy: 0.4063 - lr: 1.0000e-05\n",
            "Epoch 135/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7042 - categorical_accuracy: 0.1099 - val_loss: 6.4107 - val_categorical_accuracy: 0.4052 - lr: 1.0000e-05\n",
            "Epoch 136/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7042 - categorical_accuracy: 0.1100 - val_loss: 6.3976 - val_categorical_accuracy: 0.4048 - lr: 1.0000e-05\n",
            "Epoch 137/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 2.7035 - categorical_accuracy: 0.1101 - val_loss: 6.3990 - val_categorical_accuracy: 0.4058 - lr: 1.0000e-05\n",
            "Epoch 138/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7039 - categorical_accuracy: 0.1099 - val_loss: 6.3947 - val_categorical_accuracy: 0.4049 - lr: 1.0000e-05\n",
            "Epoch 139/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7035 - categorical_accuracy: 0.1100 - val_loss: 6.3961 - val_categorical_accuracy: 0.4055 - lr: 1.0000e-05\n",
            "Epoch 140/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.7025 - categorical_accuracy: 0.1099 - val_loss: 6.3982 - val_categorical_accuracy: 0.4045 - lr: 1.0000e-05\n",
            "Epoch 141/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 2.7026 - categorical_accuracy: 0.1099 - val_loss: 6.4060 - val_categorical_accuracy: 0.4050 - lr: 1.0000e-05\n",
            "Epoch 142/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.7029 - categorical_accuracy: 0.1099 - val_loss: 6.3966 - val_categorical_accuracy: 0.4033 - lr: 1.0000e-05\n",
            "Epoch 143/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7024 - categorical_accuracy: 0.1100 - val_loss: 6.4077 - val_categorical_accuracy: 0.4060 - lr: 1.0000e-05\n",
            "Epoch 144/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7019 - categorical_accuracy: 0.1100 - val_loss: 6.3985 - val_categorical_accuracy: 0.4052 - lr: 1.0000e-05\n",
            "Epoch 145/200\n",
            "390/390 [==============================] - 21s 52ms/step - loss: 2.7020 - categorical_accuracy: 0.1100 - val_loss: 6.4002 - val_categorical_accuracy: 0.4041 - lr: 1.0000e-05\n",
            "Epoch 146/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7020 - categorical_accuracy: 0.1098 - val_loss: 6.3992 - val_categorical_accuracy: 0.4067 - lr: 1.0000e-05\n",
            "Epoch 147/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7015 - categorical_accuracy: 0.1101 - val_loss: 6.4023 - val_categorical_accuracy: 0.4052 - lr: 1.0000e-05\n",
            "Epoch 148/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.7020 - categorical_accuracy: 0.1101 - val_loss: 6.4053 - val_categorical_accuracy: 0.4063 - lr: 1.0000e-05\n",
            "Epoch 149/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7014 - categorical_accuracy: 0.1099 - val_loss: 6.4045 - val_categorical_accuracy: 0.4046 - lr: 1.0000e-05\n",
            "Epoch 150/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7018 - categorical_accuracy: 0.1099 - val_loss: 6.4176 - val_categorical_accuracy: 0.4051 - lr: 1.0000e-05\n",
            "Epoch 151/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7012 - categorical_accuracy: 0.1100 - val_loss: 6.4166 - val_categorical_accuracy: 0.4075 - lr: 1.0000e-05\n",
            "Epoch 152/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.7011 - categorical_accuracy: 0.1100 - val_loss: 6.4099 - val_categorical_accuracy: 0.4058 - lr: 1.0000e-05\n",
            "Epoch 153/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 2.7007 - categorical_accuracy: 0.1101 - val_loss: 6.4100 - val_categorical_accuracy: 0.4053 - lr: 1.0000e-05\n",
            "Epoch 154/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.7007 - categorical_accuracy: 0.1099 - val_loss: 6.4058 - val_categorical_accuracy: 0.4042 - lr: 1.0000e-05\n",
            "Epoch 155/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7005 - categorical_accuracy: 0.1101 - val_loss: 6.4088 - val_categorical_accuracy: 0.4046 - lr: 1.0000e-05\n",
            "Epoch 156/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7010 - categorical_accuracy: 0.1099 - val_loss: 6.4090 - val_categorical_accuracy: 0.4042 - lr: 1.0000e-05\n",
            "Epoch 157/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 2.7005 - categorical_accuracy: 0.1100 - val_loss: 6.4128 - val_categorical_accuracy: 0.4017 - lr: 1.0000e-05\n",
            "Epoch 158/200\n",
            "390/390 [==============================] - 22s 56ms/step - loss: 2.7004 - categorical_accuracy: 0.1101 - val_loss: 6.4079 - val_categorical_accuracy: 0.4023 - lr: 1.0000e-05\n",
            "Epoch 159/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.7001 - categorical_accuracy: 0.1100 - val_loss: 6.4113 - val_categorical_accuracy: 0.4053 - lr: 1.0000e-05\n",
            "Epoch 160/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.6997 - categorical_accuracy: 0.1099 - val_loss: 6.4087 - val_categorical_accuracy: 0.4036 - lr: 1.0000e-05\n",
            "Epoch 161/200\n",
            "390/390 [==============================] - 20s 51ms/step - loss: 2.6996 - categorical_accuracy: 0.1100 - val_loss: 6.4106 - val_categorical_accuracy: 0.4030 - lr: 1.0000e-06\n",
            "Epoch 162/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.6997 - categorical_accuracy: 0.1100 - val_loss: 6.4087 - val_categorical_accuracy: 0.4030 - lr: 1.0000e-06\n",
            "Epoch 163/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.6990 - categorical_accuracy: 0.1100 - val_loss: 6.4088 - val_categorical_accuracy: 0.4040 - lr: 1.0000e-06\n",
            "Epoch 164/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 2.6996 - categorical_accuracy: 0.1099 - val_loss: 6.4101 - val_categorical_accuracy: 0.4049 - lr: 1.0000e-06\n",
            "Epoch 165/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.6993 - categorical_accuracy: 0.1099 - val_loss: 6.4071 - val_categorical_accuracy: 0.4027 - lr: 1.0000e-06\n",
            "Epoch 166/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.7000 - categorical_accuracy: 0.1099 - val_loss: 6.4116 - val_categorical_accuracy: 0.4040 - lr: 1.0000e-06\n",
            "Epoch 167/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.6989 - categorical_accuracy: 0.1100 - val_loss: 6.4099 - val_categorical_accuracy: 0.4032 - lr: 1.0000e-06\n",
            "Epoch 168/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.6996 - categorical_accuracy: 0.1101 - val_loss: 6.4072 - val_categorical_accuracy: 0.4034 - lr: 1.0000e-06\n",
            "Epoch 169/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.6994 - categorical_accuracy: 0.1100 - val_loss: 6.4095 - val_categorical_accuracy: 0.4039 - lr: 1.0000e-06\n",
            "Epoch 170/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.6992 - categorical_accuracy: 0.1100 - val_loss: 6.4084 - val_categorical_accuracy: 0.4038 - lr: 1.0000e-06\n",
            "Epoch 171/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.7004 - categorical_accuracy: 0.1099 - val_loss: 6.4106 - val_categorical_accuracy: 0.4039 - lr: 1.0000e-06\n",
            "Epoch 172/200\n",
            "390/390 [==============================] - 22s 56ms/step - loss: 2.6992 - categorical_accuracy: 0.1100 - val_loss: 6.4053 - val_categorical_accuracy: 0.4030 - lr: 1.0000e-06\n",
            "Epoch 173/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.6998 - categorical_accuracy: 0.1099 - val_loss: 6.4111 - val_categorical_accuracy: 0.4042 - lr: 1.0000e-06\n",
            "Epoch 174/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.6991 - categorical_accuracy: 0.1100 - val_loss: 6.4073 - val_categorical_accuracy: 0.4032 - lr: 1.0000e-06\n",
            "Epoch 175/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.6992 - categorical_accuracy: 0.1101 - val_loss: 6.4109 - val_categorical_accuracy: 0.4034 - lr: 1.0000e-06\n",
            "Epoch 176/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.6990 - categorical_accuracy: 0.1101 - val_loss: 6.4122 - val_categorical_accuracy: 0.4040 - lr: 1.0000e-06\n",
            "Epoch 177/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.6991 - categorical_accuracy: 0.1101 - val_loss: 6.4075 - val_categorical_accuracy: 0.4032 - lr: 1.0000e-06\n",
            "Epoch 178/200\n",
            "390/390 [==============================] - 21s 52ms/step - loss: 2.6996 - categorical_accuracy: 0.1099 - val_loss: 6.4099 - val_categorical_accuracy: 0.4030 - lr: 1.0000e-06\n",
            "Epoch 179/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.7001 - categorical_accuracy: 0.1099 - val_loss: 6.4090 - val_categorical_accuracy: 0.4034 - lr: 1.0000e-06\n",
            "Epoch 180/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.6993 - categorical_accuracy: 0.1101 - val_loss: 6.4096 - val_categorical_accuracy: 0.4044 - lr: 1.0000e-06\n",
            "Epoch 181/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 2.6996 - categorical_accuracy: 0.1100 - val_loss: 6.4063 - val_categorical_accuracy: 0.4035 - lr: 5.0000e-07\n",
            "Epoch 182/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.6994 - categorical_accuracy: 0.1101 - val_loss: 6.4063 - val_categorical_accuracy: 0.4041 - lr: 5.0000e-07\n",
            "Epoch 183/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 2.6985 - categorical_accuracy: 0.1102 - val_loss: 6.4091 - val_categorical_accuracy: 0.4031 - lr: 5.0000e-07\n",
            "Epoch 184/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 2.6994 - categorical_accuracy: 0.1100 - val_loss: 6.4080 - val_categorical_accuracy: 0.4042 - lr: 5.0000e-07\n",
            "Epoch 185/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 2.6990 - categorical_accuracy: 0.1102 - val_loss: 6.4085 - val_categorical_accuracy: 0.4041 - lr: 5.0000e-07\n",
            "Epoch 186/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 2.6990 - categorical_accuracy: 0.1099 - val_loss: 6.4099 - val_categorical_accuracy: 0.4035 - lr: 5.0000e-07\n",
            "Epoch 187/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 2.6988 - categorical_accuracy: 0.1099 - val_loss: 6.4087 - val_categorical_accuracy: 0.4029 - lr: 5.0000e-07\n",
            "Epoch 188/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.6995 - categorical_accuracy: 0.1099 - val_loss: 6.4081 - val_categorical_accuracy: 0.4034 - lr: 5.0000e-07\n",
            "Epoch 189/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.6994 - categorical_accuracy: 0.1099 - val_loss: 6.4097 - val_categorical_accuracy: 0.4042 - lr: 5.0000e-07\n",
            "Epoch 190/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.6994 - categorical_accuracy: 0.1100 - val_loss: 6.4089 - val_categorical_accuracy: 0.4038 - lr: 5.0000e-07\n",
            "Epoch 191/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 2.6993 - categorical_accuracy: 0.1101 - val_loss: 6.4051 - val_categorical_accuracy: 0.4037 - lr: 5.0000e-07\n",
            "Epoch 192/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.6997 - categorical_accuracy: 0.1098 - val_loss: 6.4087 - val_categorical_accuracy: 0.4039 - lr: 5.0000e-07\n",
            "Epoch 193/200\n",
            "390/390 [==============================] - 20s 52ms/step - loss: 2.6992 - categorical_accuracy: 0.1100 - val_loss: 6.4077 - val_categorical_accuracy: 0.4032 - lr: 5.0000e-07\n",
            "Epoch 194/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.6993 - categorical_accuracy: 0.1101 - val_loss: 6.4100 - val_categorical_accuracy: 0.4042 - lr: 5.0000e-07\n",
            "Epoch 195/200\n",
            "390/390 [==============================] - 22s 57ms/step - loss: 2.6986 - categorical_accuracy: 0.1101 - val_loss: 6.4075 - val_categorical_accuracy: 0.4035 - lr: 5.0000e-07\n",
            "Epoch 196/200\n",
            "390/390 [==============================] - 21s 54ms/step - loss: 2.6993 - categorical_accuracy: 0.1100 - val_loss: 6.4093 - val_categorical_accuracy: 0.4042 - lr: 5.0000e-07\n",
            "Epoch 197/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.6994 - categorical_accuracy: 0.1100 - val_loss: 6.4090 - val_categorical_accuracy: 0.4031 - lr: 5.0000e-07\n",
            "Epoch 198/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.6989 - categorical_accuracy: 0.1100 - val_loss: 6.4081 - val_categorical_accuracy: 0.4035 - lr: 5.0000e-07\n",
            "Epoch 199/200\n",
            "390/390 [==============================] - 21s 55ms/step - loss: 2.6990 - categorical_accuracy: 0.1100 - val_loss: 6.4070 - val_categorical_accuracy: 0.4041 - lr: 5.0000e-07\n",
            "Epoch 200/200\n",
            "390/390 [==============================] - 21s 53ms/step - loss: 2.6990 - categorical_accuracy: 0.1100 - val_loss: 6.4074 - val_categorical_accuracy: 0.4033 - lr: 5.0000e-07\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "(42.68999993801117,\n",
              " array([ 0.99999998,  9.05999988, 10.08000001, 12.09999993, 15.77000022,\n",
              "        16.94000065, 15.35000056, 18.63999963, 19.32999939, 17.00000018,\n",
              "        23.6499995 , 23.78000021, 21.78999931, 24.56      , 23.48999977,\n",
              "        22.70999998, 24.37999994, 24.02999997, 23.51000011, 28.34999859,\n",
              "        25.51000118, 29.37000096, 26.96999907, 29.19999957, 28.18999887,\n",
              "        26.44000053, 29.28000093, 29.69999909, 32.57000148, 30.30000031,\n",
              "        28.2799989 , 30.03999889, 33.12000036, 30.6400001 , 31.61999881,\n",
              "        35.22000015, 32.35000074, 34.40000117, 32.0600003 , 33.82999897,\n",
              "        35.0699991 , 32.98000097, 33.79999995, 33.79000127, 32.06999898,\n",
              "        30.43000102, 36.00000143, 33.32000077, 32.76000023, 33.19999874,\n",
              "        34.76999998, 31.65000081, 34.13000107, 31.06999993, 35.76999903,\n",
              "        36.64999902, 34.59999859, 35.24000049, 29.64999974, 33.6800009 ,\n",
              "        36.52999997, 34.56999958, 32.74999857, 35.67000031, 31.65000081,\n",
              "        33.52000117, 34.20000076, 35.22999883, 35.24999917, 30.6400001 ,\n",
              "        35.26000082, 33.73000026, 37.56999969, 33.57000053, 34.04000103,\n",
              "        32.82999992, 35.40999889, 34.67999995, 33.36000144, 34.20000076,\n",
              "        32.53000081, 42.03000069, 42.42999852, 42.23999977, 42.68999994,\n",
              "        41.99000001, 42.08000004, 42.1299994 , 42.37999916, 41.7899996 ,\n",
              "        41.80000126, 41.62999988, 41.99999869, 41.49000049, 41.56999886,\n",
              "        41.2499994 , 41.60999954, 40.99000096, 41.65000021, 41.26000106,\n",
              "        41.33000076, 41.33000076, 40.84999859, 40.56000113, 40.65999985,\n",
              "        40.95000029, 40.79999924, 40.56999981, 40.52000046, 40.50000012,\n",
              "        40.13000131, 40.97000062, 40.86000025, 40.65999985, 39.84000087,\n",
              "        40.34999907, 39.98000026, 40.43000042, 40.02000093, 39.71000016,\n",
              "        40.04999995, 40.40000141, 40.32000005, 40.50000012, 40.63000083,\n",
              "        40.5400008 , 40.59000015, 40.75999856, 40.65999985, 40.61999917,\n",
              "        40.61000049, 40.66999853, 40.43000042, 40.56000113, 40.63000083,\n",
              "        40.52000046, 40.47999978, 40.58000147, 40.49000144, 40.54999948,\n",
              "        40.45000076, 40.50000012, 40.32999873, 40.59999883, 40.52000046,\n",
              "        40.41000009, 40.66999853, 40.52000046, 40.63000083, 40.45999944,\n",
              "        40.5099988 , 40.74999988, 40.58000147, 40.52999914, 40.41999876,\n",
              "        40.45999944, 40.41999876, 40.169999  , 40.23000002, 40.52999914,\n",
              "        40.36000073, 40.29999971, 40.29999971, 40.40000141, 40.49000144,\n",
              "        40.2700007 , 40.40000141, 40.32000005, 40.34000039, 40.38999975,\n",
              "        40.38000107, 40.38999975, 40.29999971, 40.41999876, 40.32000005,\n",
              "        40.34000039, 40.40000141, 40.32000005, 40.29999971, 40.34000039,\n",
              "        40.4399991 , 40.34999907, 40.41000009, 40.31000137, 40.41999876,\n",
              "        40.41000009, 40.34999907, 40.29000103, 40.34000039, 40.41999876,\n",
              "        40.38000107, 40.36999941, 40.38999975, 40.32000005, 40.41999876,\n",
              "        40.34999907, 40.41999876, 40.31000137, 40.34999907, 40.41000009,\n",
              "        40.32999873]))"
            ]
          },
          "execution_count": 19,
          "metadata": {},
          "output_type": "execute_result"
        }
      ]
    }
  ]
}