{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dfDL2zy0-d0u"
   },
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "executionInfo": {
     "elapsed": 519,
     "status": "ok",
     "timestamp": 1706798757973,
     "user": {
      "displayName": "",
      "userId": "07549615973646857260"
     },
     "user_tz": 300
    },
    "id": "1K1p461sY8Xz"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import math\n",
    "import random\n",
    "import tensorflow as tf\n",
    "\n",
    "tf.random.set_seed(103847532)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mxgCxwQVT3pd"
   },
   "source": [
    "# Experiment Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "executionInfo": {
     "elapsed": 52,
     "status": "ok",
     "timestamp": 1706815922808,
     "user": {
      "displayName": "",
      "userId": "07549615973646857260"
     },
     "user_tz": 300
    },
    "id": "z1q4L1xuGlRw"
   },
   "outputs": [],
   "source": [
    "_TEACHER_ARCHITECTURE = 'resnet'\n",
    "_STUDENT_ARCHITECTURE = 'mobilenet'\n",
    "_DATASET_NAME = 'cifar100'\n",
    "_TEACHER_EPOCHS = 200\n",
    "_STUDENT_EPOCHS = 200\n",
    "\n",
    "_NUM_TRIALS = 1\n",
    "_TEACHER_BATCH_SIZE = 256\n",
    "_STUDENT_BATCH_SIZE = 128\n",
    "_TEACHER_RESNET_DEPTH = 92\n",
    "_STUDENT_RESNET_DEPTH = 56\n",
    "_ALPHA = 0.0\n",
    "_BETA = 0.0\n",
    "_EMBEDDING_LOSS = 'squared_difference'\n",
    "_NUM_SUBCLASSES = 1000\n",
    "_TYPE_OF_CLUSTERING = 'PCANoWeightRotation'\n",
    "_TEACHER_LABEL_TEMP = 1.0\n",
    "_TEMPERATURE = 10.0\n",
    "_ADD_MLP = False\n",
    "_NUMBER_OF_LABELED_EXAMPLES = None\n",
    "_USE_SUBCLASS_TEACHER = False\n",
    "_AUXILIARY_LOSS_PARAM = 0.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gXb1pG73_bij"
   },
   "source": [
    "# Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "executionInfo": {
     "elapsed": 2588,
     "status": "ok",
     "timestamp": 1706806674629,
     "user": {
      "displayName": "",
      "userId": "07549615973646857260"
     },
     "user_tz": 300
    },
    "id": "lLWkr4p8_fko"
   },
   "outputs": [],
   "source": [
    "\"\"\"Methods for loading datasets.\"\"\"\n",
    "\n",
    "import copy\n",
    "import datetime as dt\n",
    "\n",
    "import numpy as np\n",
    "import scipy.optimize\n",
    "import sklearn.cluster\n",
    "import sklearn.decomposition\n",
    "import sklearn.manifold\n",
    "import sklearn.mixture\n",
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "import tensorflow_hub as hub\n",
    "import tensorflow_text as tf_text  # Required for the preprocessor.\n",
    "\n",
    "\n",
    "def load_cifar_data(start=0, end=50000, num_classes=10):\n",
    "  \"\"\"Loads cifar10 or cifar100 dataset.\"\"\"\n",
    "  if num_classes == 10:\n",
    "    cifar_data = tfds.load('cifar10')\n",
    "  else:\n",
    "    cifar_data = tfds.load('cifar100')\n",
    "\n",
    "  x_train = np.zeros([50000, 32, 32, 3], dtype=np.uint8)\n",
    "  y_train = np.zeros([50000, 1], dtype=np.uint8)\n",
    "  i = 0\n",
    "  for example in tfds.as_numpy(cifar_data['train']):\n",
    "    x_train[i] = example['image']\n",
    "    y_train[i] = example['label']\n",
    "    i += 1\n",
    "\n",
    "  x_test = np.zeros([10000, 32, 32, 3])\n",
    "  y_test = np.zeros([10000, 1])\n",
    "\n",
    "  i = 0\n",
    "  for example in tfds.as_numpy(cifar_data['test']):\n",
    "    x_test[i] = example['image']\n",
    "    y_test[i] = example['label']\n",
    "    i += 1\n",
    "\n",
    "  return (x_train[start:end], y_train[start:end]), (x_test, y_test)\n",
    "\n",
    "\n",
    "def load_cifarbin_data(start=0, end=50000, num_classes=10):\n",
    "  \"\"\"Loads cifar10 or cifar100 dataset.\"\"\"\n",
    "  if num_classes == 10:\n",
    "    cifar_data = tfds.load('cifar10')\n",
    "  else:\n",
    "    cifar_data = tfds.load('cifar100')\n",
    "\n",
    "  x_train = np.zeros([50000, 32, 32, 3], dtype=np.uint8)\n",
    "  y_train = np.zeros([50000, 1], dtype=np.uint8)\n",
    "  i = 0\n",
    "  for example in tfds.as_numpy(cifar_data['train']):\n",
    "    x_train[i] = example['image']\n",
    "    y_train[i] = example['label'] % 2\n",
    "    i += 1\n",
    "\n",
    "  x_test = np.zeros([10000, 32, 32, 3])\n",
    "  y_test = np.zeros([10000, 1])\n",
    "\n",
    "  i = 0\n",
    "  for example in tfds.as_numpy(cifar_data['test']):\n",
    "    x_test[i] = example['image']\n",
    "    y_test[i] = example['label'] % 2\n",
    "    i += 1\n",
    "\n",
    "  return (x_train[start:end], y_train[start:end]), (x_test, y_test)\n",
    "\n",
    "\n",
    "def load_cifar_corrupted_data(start=0, end=50000):\n",
    "  \"\"\"Loads cifar10_corrupted.\"\"\"\n",
    "  cifar_data = tfds.load('cifar10')\n",
    "  cifar_test_data = tfds.load('cifar10_corrupted')\n",
    "\n",
    "  x_train = np.zeros([50000, 32, 32, 3], dtype=np.uint8)\n",
    "  y_train = np.zeros([50000, 1], dtype=np.uint8)\n",
    "  i = 0\n",
    "  for example in tfds.as_numpy(cifar_data['train']):\n",
    "    x_train[i] = example['image']\n",
    "    y_train[i] = example['label']\n",
    "    i += 1\n",
    "\n",
    "  x_test = np.zeros([10000, 32, 32, 3])\n",
    "  y_test = np.zeros([10000, 1])\n",
    "\n",
    "  i = 0\n",
    "  for example in tfds.as_numpy(cifar_test_data['test']):\n",
    "    x_test[i] = example['image']\n",
    "    y_test[i] = example['label']\n",
    "    i += 1\n",
    "\n",
    "  return (x_train[start:end], y_train[start:end]), (x_test, y_test)\n",
    "\n",
    "\n",
    "def load_svhn_data(start=0, end=50000):\n",
    "  \"\"\"Loads SVHN.\"\"\"\n",
    "\n",
    "  svhn_data = tfds.load('svhn_cropped')\n",
    "\n",
    "  x_train = np.zeros([73257, 32, 32, 3], dtype=np.uint8)\n",
    "  y_train = np.zeros([73257, 1], dtype=np.uint8)\n",
    "\n",
    "  i = 0\n",
    "  for example in tfds.as_numpy(svhn_data['train']):\n",
    "    x_train[i] = example['image']\n",
    "    y_train[i] = example['label']\n",
    "    i += 1\n",
    "  x_test = np.zeros([26032, 32, 32, 3])\n",
    "  y_test = np.zeros([26032, 1])\n",
    "\n",
    "  i = 0\n",
    "  for example in tfds.as_numpy(svhn_data['test']):\n",
    "    x_test[i] = example['image']\n",
    "    y_test[i] = example['label']\n",
    "    i += 1\n",
    "\n",
    "  return (x_train[start:end], y_train[start:end]), (x_test, y_test)\n",
    "\n",
    "\n",
    "def load_celeb_a_data(start=0, end=162770, label_key='Male', group_key='Young'):\n",
    "  \"\"\"Loads celeb_a.\"\"\"\n",
    "\n",
    "  get_image_and_label = lambda feat_dict: (feat_dict['image'], feat_dict[\n",
    "      'attributes'][label_key])\n",
    "\n",
    "  def preprocess_input_dict(feat_dict):\n",
    "    # Separate out the image and target variable from the feature dictionary.\n",
    "    image = feat_dict['image']\n",
    "    label = feat_dict['attributes'][label_key]\n",
    "    group = feat_dict['attributes'][group_key]\n",
    "\n",
    "    image = tf.cast(image, tf.float32)\n",
    "    image = tf.image.resize(image, [32, 32])\n",
    "\n",
    "    label = tf.cast(label, tf.float32)\n",
    "    group = tf.cast(group, tf.float32)\n",
    "\n",
    "    feat_dict['image'] = image\n",
    "    feat_dict['attributes'][label_key] = label\n",
    "    feat_dict['attributes'][group_key] = group\n",
    "\n",
    "    return feat_dict\n",
    "\n",
    "  def celeb_a_train_data_wo_group(batch_size):\n",
    "    celeb_a_train_data = celeb_a_builder.as_dataset(\n",
    "        split='train').batch(batch_size).map(preprocess_input_dict)\n",
    "    return celeb_a_train_data.map(get_image_and_label)\n",
    "\n",
    "  celeb_a_builder = tfds.builder('celeb_a', version='2.0.1')\n",
    "\n",
    "  celeb_a_builder.download_and_prepare()\n",
    "\n",
    "  celeb_a_train_data = celeb_a_train_data_wo_group(1)\n",
    "  # Test data for the overall evaluation\n",
    "  celeb_a_test_data = celeb_a_builder.as_dataset(\n",
    "      split='test').batch(1).map(preprocess_input_dict).map(get_image_and_label)\n",
    "\n",
    "  x_train = np.zeros([162770, 32, 32, 3], dtype=np.uint8)\n",
    "  y_train = np.zeros([162770, 1], dtype=np.uint8)\n",
    "\n",
    "  i = 0\n",
    "  for example in tfds.as_numpy(celeb_a_train_data):\n",
    "    x_train[i] = example[0]\n",
    "    y_train[i] = example[1]\n",
    "    i += 1\n",
    "\n",
    "  x_test = np.zeros([19962, 32, 32, 3])\n",
    "  y_test = np.zeros([19962, 1])\n",
    "\n",
    "  i = 0\n",
    "  for example in tfds.as_numpy(celeb_a_test_data):\n",
    "    x_test[i] = example[0]\n",
    "    y_test[i] = example[1]\n",
    "    i += 1\n",
    "\n",
    "  return (x_train[start:end], y_train[start:end]), (x_test, y_test)\n",
    "\n",
    "\n",
    "def load_imdb_reviews_data(start=0, end=25000):\n",
    "  \"\"\"Loads the imdb reviews dataset.\"\"\"\n",
    "\n",
    "  train_dataset = tfds.load('imdb_reviews', split='train', shuffle_files=False)\n",
    "  test_dataset = tfds.load('imdb_reviews', split='test', shuffle_files=False)\n",
    "\n",
    "  x_train = []\n",
    "  y_train = []\n",
    "  for example in tfds.as_numpy(train_dataset):\n",
    "    x_train.append(example['text'])\n",
    "    y_train.append(example['label'])\n",
    "  x_train = np.array(x_train)\n",
    "  y_train = np.array(y_train)\n",
    "\n",
    "  x_test = []\n",
    "  y_test = []\n",
    "  for example in tfds.as_numpy(test_dataset):\n",
    "    x_test.append(example['text'])\n",
    "    y_test.append(example['label'])\n",
    "  x_test = np.array(x_test)\n",
    "  y_test = np.array(y_test)\n",
    "\n",
    "  return (x_train[start:end], y_train[start:end]), (x_test, y_test)\n",
    "\n",
    "\n",
    "def load_glue_data(name, start, end):\n",
    "  \"\"\"Loads the glue datasets.\"\"\"\n",
    "\n",
    "  if name == 'glue_cola':\n",
    "    train_dataset = tfds.load('glue/cola', split='train', shuffle_files=False)\n",
    "    test_dataset = tfds.load(\n",
    "        'glue/cola', split='validation', shuffle_files=False\n",
    "    )\n",
    "    text_data = 'sentence'\n",
    "  elif name == 'glue_sst2':\n",
    "    train_dataset = tfds.load('glue/sst2', split='train', shuffle_files=False)\n",
    "    test_dataset = tfds.load(\n",
    "        'glue/sst2', split='validation', shuffle_files=False\n",
    "    )\n",
    "    text_data = 'sentence'\n",
    "  else:\n",
    "    raise ValueError(f'Unsupported glue dataset: {name}')\n",
    "\n",
    "  x_train = []\n",
    "  y_train = []\n",
    "  for example in tfds.as_numpy(train_dataset):\n",
    "    x_train.append(example[text_data])\n",
    "    y_train.append(example['label'])\n",
    "  x_train = np.array(x_train)\n",
    "  y_train = np.array(y_train)\n",
    "\n",
    "  x_test = []\n",
    "  y_test = []\n",
    "  for example in tfds.as_numpy(test_dataset):\n",
    "    x_test.append(example[text_data])\n",
    "    y_test.append(example['label'])\n",
    "  x_test = np.array(x_test)\n",
    "  y_test = np.array(y_test)\n",
    "\n",
    "  return (x_train[start:end], y_train[start:end]), (x_test, y_test)\n",
    "\n",
    "\n",
    "def normalize_x(x):\n",
    "  x = x.astype('float32')\n",
    "  x /= 127.5\n",
    "  x -= 1.0\n",
    "  return x\n",
    "\n",
    "\n",
    "def normalize_y(y, num_classes):\n",
    "  y = tf.keras.utils.to_categorical(y, num_classes)\n",
    "  return y\n",
    "\n",
    "\n",
    "def prepare_nlp_dataset(\n",
    "    examples: tf.Tensor | np.ndarray,\n",
    "    labels: tf.Tensor | np.ndarray,\n",
    "    batch_size: int,\n",
    "    text_data: str,\n",
    "    preprocess: bool = True,\n",
    "    expand_label_dims: bool = True,\n",
    "):\n",
    "  \"\"\"Preprocesses the nlp dataset.\n",
    "\n",
    "  Args:\n",
    "    examples: The examples of the dataset.\n",
    "    labels: The labels of the dataset.\n",
    "    batch_size: The batch size.\n",
    "    text_data: A string correspodning to the key of tex_data\n",
    "    preprocess: Whether to preprocess the dataset.\n",
    "    expand_label_dims: Whether to expand the label dims.\n",
    "\n",
    "  Returns:\n",
    "    A preprocessed dataset.\n",
    "  \"\"\"\n",
    "  dataset = tf.data.Dataset.from_tensor_slices(\n",
    "      {text_data: examples, 'label': labels}\n",
    "  )\n",
    "  batched_dataset = dataset.batch(batch_size, drop_remainder=True)\n",
    "\n",
    "  if preprocess and expand_label_dims:\n",
    "    preprocessor = hub.KerasLayer(\n",
    "        '    )\n",
    "    return batched_dataset.map(\n",
    "        lambda x: (\n",
    "            preprocessor(x[text_data]),\n",
    "            tf.expand_dims(x['label'], axis=-1),\n",
    "        )\n",
    "    )\n",
    "  if preprocess:\n",
    "    preprocessor = hub.KerasLayer(\n",
    "        '    )\n",
    "    return batched_dataset.map(\n",
    "        lambda x: (\n",
    "            preprocessor(x[text_data]),\n",
    "            x['label'],\n",
    "        )\n",
    "    )\n",
    "\n",
    "  return batched_dataset\n",
    "\n",
    "\n",
    "def prepare_subclasses_dataset(\n",
    "    teacher_model: tf.keras.Model,\n",
    "    teacher_architecture: str,\n",
    "    train_x: np.ndarray,\n",
    "    train_y: np.ndarray,\n",
    "    num_classes: int,\n",
    "    num_subclasses: int,\n",
    "    type_of_clustering: str = 'Agglomerative',\n",
    "    subclass_teacher: bool = False,\n",
    "    text_data='text',\n",
    "    temperature: float = 1.0,\n",
    "    soft_labels: bool = False,\n",
    "):\n",
    "  \"\"\"Prepare the subclasses dataset.\"\"\"\n",
    "\n",
    "  def zero_out_part(x: np.ndarray, part: int):\n",
    "    n = x.size\n",
    "    if part == 0:\n",
    "      for i in range(n // 2):\n",
    "        x[i] = 0.0\n",
    "    elif part == 1:\n",
    "      for i in range(n // 2, n):\n",
    "        x[i] = 0.0\n",
    "    else:\n",
    "      raise ValueError(f'Invalid part: {part}')\n",
    "    return x\n",
    "\n",
    "  # This works only for binary classification datasets.\n",
    "  # TO DO: Make it work for multiclass datasets.\n",
    "\n",
    "  if 'bert' in teacher_architecture:\n",
    "    train_set = prepare_nlp_dataset(\n",
    "        examples=train_x,\n",
    "        labels=train_y,\n",
    "        batch_size=1,\n",
    "        text_data=text_data,\n",
    "    )\n",
    "  else:\n",
    "    train_set = None\n",
    "\n",
    "  if subclass_teacher:\n",
    "    if 'bert' in teacher_architecture:\n",
    "      teacher_predictions = teacher_model.predict(train_set)\n",
    "    else:\n",
    "      teacher_predictions = teacher_model.predict(train_x)\n",
    "    teacher_logits = tf.math.log(teacher_predictions)\n",
    "    scaled_teacher_logits = teacher_logits / temperature\n",
    "    teacher_predictions = tf.nn.softmax(scaled_teacher_logits / temperature)\n",
    "    return train_x, teacher_predictions\n",
    "\n",
    "  if teacher_architecture == 'resnet':\n",
    "    teacher_model_with_embeddings = tf.keras.Model(\n",
    "        teacher_model.inputs, teacher_model.layers[-3].output\n",
    "    )\n",
    "  elif teacher_architecture == 'mobilenet':\n",
    "    penultimate = teacher_model.get_layer('dropout').output\n",
    "    embeddings_layer = tf.keras.layers.Flatten()(penultimate)\n",
    "    teacher_model_with_embeddings = tf.keras.Model(\n",
    "        teacher_model.inputs, embeddings_layer\n",
    "    )\n",
    "  elif 'bert' in teacher_architecture:\n",
    "    penultimate = teacher_model.get_layer('bert_encoder').output['default']\n",
    "    embeddings_layer = tf.keras.layers.Flatten()(penultimate)\n",
    "    teacher_model_with_embeddings = tf.keras.Model(\n",
    "        teacher_model.inputs, embeddings_layer\n",
    "    )\n",
    "  else:\n",
    "    raise ValueError(\n",
    "        f'Unsupported teacher architecture: {teacher_architecture}'\n",
    "    )\n",
    "\n",
    "  if 'bert' in teacher_architecture:\n",
    "    teacher_embeddings = teacher_model_with_embeddings.predict(train_set)\n",
    "  else:\n",
    "    teacher_embeddings = teacher_model_with_embeddings.predict(train_x)\n",
    "  dataset_size = len(train_x)\n",
    "\n",
    "  embeddings_class_0 = np.zeros(\n",
    "      [dataset_size, teacher_embeddings.shape[1]], dtype=np.float32\n",
    "  )\n",
    "  y_train_class_0 = np.zeros([dataset_size, num_subclasses], dtype=np.float32)\n",
    "  x_train_class_0 = np.copy(train_x[0:dataset_size])\n",
    "  embeddings_class_1 = np.zeros(\n",
    "      [dataset_size, teacher_embeddings.shape[1]], dtype=np.float32\n",
    "  )\n",
    "  y_train_class_1 = np.zeros([dataset_size, num_subclasses], dtype=np.float32)\n",
    "  x_train_class_1 = np.copy(train_x[0:dataset_size])\n",
    "\n",
    "  if type_of_clustering == 'true_clustering' and num_subclasses == 10:\n",
    "    _, true_train_y, _, _ = get_data(0, len(train_x), 'cifar10')\n",
    "  elif type_of_clustering == 'true_clustering' and num_subclasses == 100:\n",
    "    _, true_train_y, _, _ = get_data(0, len(train_x), 'cifar100')\n",
    "  else:\n",
    "    true_train_y = None\n",
    "\n",
    "  ones = 0\n",
    "  zeros = 0\n",
    "  for i in range(len(train_y)):\n",
    "    if np.argmax(train_y[i]) == 1:\n",
    "      embeddings_class_1[ones] = teacher_embeddings[i]\n",
    "      x_train_class_1[ones] = train_x[i]\n",
    "      if type_of_clustering == 'true_clustering':\n",
    "        y_train_class_1[ones] = true_train_y[i]\n",
    "      ones += 1\n",
    "    elif np.argmax(train_y[i]) == 0:\n",
    "      embeddings_class_0[zeros] = teacher_embeddings[i]\n",
    "      x_train_class_0[zeros] = train_x[i]\n",
    "      if type_of_clustering == 'true_clustering':\n",
    "        y_train_class_0[zeros] = true_train_y[i]\n",
    "      zeros += 1\n",
    "  print(f'ones: {ones}, zeros: {zeros}')\n",
    "  x_train_class_0 = x_train_class_0[0:zeros]\n",
    "  y_train_class_0 = y_train_class_0[0:zeros]\n",
    "  x_train_class_1 = x_train_class_1[0:ones]\n",
    "  y_train_class_1 = y_train_class_1[0:ones]\n",
    "\n",
    "  cluster_predictions_0 = None\n",
    "  cluster_predictions_1 = None\n",
    "\n",
    "  if type_of_clustering == 'Agglomerative':\n",
    "    cluster_0 = sklearn.cluster.AgglomerativeClustering(\n",
    "        n_clusters=num_subclasses // num_classes,\n",
    "        affinity='euclidean',\n",
    "        linkage='complete',\n",
    "        distance_threshold=None,\n",
    "    )\n",
    "    cluster_1 = sklearn.cluster.AgglomerativeClustering(\n",
    "        n_clusters=num_subclasses // num_classes,\n",
    "        affinity='euclidean',\n",
    "        linkage='complete',\n",
    "        distance_threshold=None,\n",
    "    )\n",
    "    cluster_predictions_0 = cluster_0.fit_predict(embeddings_class_0)\n",
    "    cluster_predictions_1 = cluster_1.fit_predict(embeddings_class_1)\n",
    "\n",
    "  elif type_of_clustering == 'kmeans':\n",
    "    cluster_0 = sklearn.cluster.KMeans(\n",
    "        n_clusters=num_subclasses // num_classes, random_state=0\n",
    "    )\n",
    "    cluster_1 = sklearn.cluster.KMeans(\n",
    "        n_clusters=num_subclasses // num_classes, random_state=0\n",
    "    )\n",
    "    cluster_0.fit(embeddings_class_0)\n",
    "    cluster_1.fit(embeddings_class_1)\n",
    "    cluster_predictions_0 = cluster_0.predict(embeddings_class_0)\n",
    "    cluster_predictions_1 = cluster_1.predict(embeddings_class_1)\n",
    "  elif type_of_clustering == 'tsne_kmeans':\n",
    "    tsne_cluster_0 = sklearn.manifold.TSNE(\n",
    "        n_components=2, learning_rate='auto', init='random'\n",
    "    )\n",
    "    tsne_cluster_1 = sklearn.manifold.TSNE(\n",
    "        n_components=2, learning_rate='auto', init='random'\n",
    "    )\n",
    "    tsne_embedd_0 = tsne_cluster_0.fit_transform(embeddings_class_0)\n",
    "    tsne_embedd_1 = tsne_cluster_1.fit_transform(embeddings_class_1)\n",
    "    cluster_0 = sklearn.cluster.KMeans(\n",
    "        n_clusters=num_subclasses // num_classes, random_state=0\n",
    "    )\n",
    "    cluster_1 = sklearn.cluster.KMeans(\n",
    "        n_clusters=num_subclasses // num_classes, random_state=0\n",
    "    )\n",
    "    cluster_0.fit(tsne_embedd_0)\n",
    "    cluster_1.fit(tsne_embedd_1)\n",
    "    cluster_predictions_0 = cluster_0.predict(tsne_embedd_0)\n",
    "    cluster_predictions_1 = cluster_1.predict(tsne_embedd_1)\n",
    "  elif type_of_clustering == 'true_clustering':\n",
    "    cluster_predictions_0 = np.copy(np.argmax(y_train_class_0, axis=1)) // 2\n",
    "    cluster_predictions_1 = np.copy(np.argmax(y_train_class_1, axis=1)) // 2\n",
    "\n",
    "  x_subclass_train = np.copy(train_x)\n",
    "  y_subclass_train = np.zeros([len(train_x), 1], dtype=np.float32)\n",
    "\n",
    "  n0 = len(x_train_class_0)\n",
    "  n1 = len(x_train_class_1)\n",
    "\n",
    "  for i in range(n0):\n",
    "    x_subclass_train[i] = x_train_class_0[i]\n",
    "    y_subclass_train[i] = cluster_predictions_0[i]\n",
    "\n",
    "  for i in range(n0, n0 + n1):\n",
    "    x_subclass_train[i] = x_train_class_1[i - n0]\n",
    "    y_subclass_train[i] = (\n",
    "        cluster_predictions_1[i - n0] + num_subclasses // num_classes\n",
    "    )\n",
    "\n",
    "  y_subclass_train = normalize_y(y_subclass_train, num_subclasses)\n",
    "\n",
    "  if soft_labels:\n",
    "    teacher_model_with_logits = tf.keras.Model(\n",
    "        teacher_model.inputs, teacher_model.layers[-2].output\n",
    "    )\n",
    "    teacher_logits = teacher_model_with_logits.predict(x_subclass_train)\n",
    "    teacher_predictions_tensor = tf.nn.softmax(teacher_logits / temperature)\n",
    "    teacher_predictions = teacher_predictions_tensor.numpy()\n",
    "    for i in range(n0):\n",
    "      temp = np.ones_like(y_subclass_train[i])\n",
    "      temp = temp/(num_subclasses // num_classes)\n",
    "      temp = zero_out_part(temp, 0)\n",
    "      y_subclass_train[i] = (\n",
    "          teacher_predictions[i][0] * y_subclass_train[i]\n",
    "          + teacher_predictions[i][1] * temp\n",
    "      )\n",
    "    for i in range(n0, n0 + n1):\n",
    "      temp = np.ones_like(y_subclass_train[i])\n",
    "      temp = temp/(num_subclasses // num_classes)\n",
    "      temp = zero_out_part(temp, 1)\n",
    "      y_subclass_train[i] = (\n",
    "          teacher_predictions[i][1] * y_subclass_train[i]\n",
    "          + teacher_predictions[i][0] * temp\n",
    "      )\n",
    "\n",
    "  rand_seed = int(dt.datetime.now().strftime('%f')[:-4])\n",
    "  x_subclass_train = np.random.RandomState(seed=rand_seed).permutation(\n",
    "      x_subclass_train\n",
    "  )\n",
    "  y_subclass_train = np.random.RandomState(seed=rand_seed).permutation(\n",
    "      y_subclass_train\n",
    "  )\n",
    "\n",
    "  return x_subclass_train, y_subclass_train\n",
    "\n",
    "\n",
    "def prepare_subclasses_dataset_pca(\n",
    "    teacher_model: tf.keras.Model,\n",
    "    teacher_architecture: str,\n",
    "    train_x: np.ndarray,\n",
    "    train_y: np.ndarray,\n",
    "    num_classes: int,\n",
    "    num_subclasses: int,\n",
    "    type_of_clustering: str = 'PCA',\n",
    "    subclass_teacher: bool = False,\n",
    "    text_data='text',\n",
    "    temperature: float = 1.0,\n",
    "):\n",
    "  \"\"\"Prepare the subclasses dataset for PCA subclasses.\"\"\"\n",
    "\n",
    "  if teacher_architecture == 'resnet':\n",
    "    teacher_model_with_embeddings = tf.keras.Model(\n",
    "        teacher_model.inputs, teacher_model.layers[-3].output\n",
    "    )\n",
    "    teacher_output_weight = teacher_model.get_weight_paths()[\n",
    "        'dense_1.kernel'\n",
    "    ].numpy()  # H x C\n",
    "  elif teacher_architecture == 'mobilenet':\n",
    "    penultimate = teacher_model.get_layer('dropout').output\n",
    "    embeddings_layer = tf.keras.layers.Flatten()(penultimate)\n",
    "    teacher_model_with_embeddings = tf.keras.Model(\n",
    "        teacher_model.inputs, embeddings_layer\n",
    "    )\n",
    "    teacher_output_weight = teacher_model.get_weight_paths()[\n",
    "        'conv_preds.kernel'\n",
    "    ].numpy()[0, 0]\n",
    "  elif 'bert' in teacher_architecture:\n",
    "    penultimate = teacher_model.get_layer('bert_encoder').output['default']\n",
    "    embeddings_layer = tf.keras.layers.Flatten()(penultimate)\n",
    "    teacher_model_with_embeddings = tf.keras.Model(\n",
    "        teacher_model.inputs, embeddings_layer\n",
    "    )\n",
    "    teacher_output_weight = teacher_model.get_weight_paths()[\n",
    "        'dense.kernel'\n",
    "    ].numpy()\n",
    "  else:\n",
    "    raise ValueError(\n",
    "        f'Unsupported teacher architecture: {teacher_architecture}'\n",
    "    )\n",
    "  teacher_model_with_logits = tf.keras.Model(\n",
    "      teacher_model.inputs, teacher_model.layers[-2].output\n",
    "  )\n",
    "\n",
    "  if 'bert' in teacher_architecture:\n",
    "    train_set = prepare_nlp_dataset(\n",
    "        examples=train_x,\n",
    "        labels=train_y,\n",
    "        batch_size=1,\n",
    "        text_data=text_data,\n",
    "    )\n",
    "  else:\n",
    "    train_set = None\n",
    "\n",
    "  if 'bert' in teacher_architecture:\n",
    "    teacher_embeddings = teacher_model_with_embeddings.predict(train_set)\n",
    "  else:\n",
    "    teacher_embeddings = teacher_model_with_embeddings.predict(train_x)\n",
    "\n",
    "  class_counts = np.zeros([num_classes])\n",
    "  x_train_classes = [[] for i in range(num_classes)]\n",
    "  y_train_classes = [[] for i in range(num_classes)]\n",
    "  embeddings_classes = [[] for i in range(num_classes)]\n",
    "\n",
    "  for i in range(len(train_y)):\n",
    "    true_class = np.argmax(train_y[i])\n",
    "    class_counts[true_class] += 1\n",
    "    x_train_classes[true_class].append(train_x[i])\n",
    "    y_train_classes[true_class].append(true_class)\n",
    "    embeddings_classes[true_class].append(teacher_embeddings[i])\n",
    "\n",
    "  for c in range(num_classes):\n",
    "    x_train_classes[c] = np.stack(x_train_classes[c], 0)\n",
    "    y_train_classes[c] = np.stack(y_train_classes[c], 0)\n",
    "    embeddings_classes[c] = np.stack(embeddings_classes[c], 0)\n",
    "\n",
    "  clustering_info = None\n",
    "\n",
    "  if type_of_clustering == 'PCANoWeightRotation':\n",
    "    teacher_output_weight = np.linalg.qr(teacher_output_weight)[0]\n",
    "    clustering_subclass_info = []\n",
    "    for c in range(num_classes):\n",
    "      # H x C\n",
    "      mags = embeddings_classes[c] @ teacher_output_weight\n",
    "      # N x C\n",
    "\n",
    "      embeddings_class = embeddings_classes[c] - mags @ teacher_output_weight.T\n",
    "\n",
    "      # include_zero =\n",
    "\n",
    "      cluster_model = sklearn.decomposition.PCA(\n",
    "          n_components=(num_subclasses // num_classes) - 0,\n",
    "      )\n",
    "\n",
    "      cluster_model.fit(embeddings_class)\n",
    "\n",
    "      pca_mean = tf.convert_to_tensor(cluster_model.mean_, dtype=tf.float32)\n",
    "\n",
    "      rotation = get_random_orthogonal_matrix(\n",
    "          (num_subclasses // num_classes) - 0\n",
    "      )\n",
    "\n",
    "      pca_components = tf.convert_to_tensor(\n",
    "          rotation @ cluster_model.components_, dtype=tf.float32\n",
    "      )  # C x H\n",
    "\n",
    "      random_proj = (\n",
    "          embeddings_class - pca_mean.numpy()\n",
    "      ) @ pca_components.numpy().T  # N x C\n",
    "\n",
    "      if num_subclasses // num_classes - 0 == 1:\n",
    "        # max_cov = 1.0\n",
    "        max_cov = np.max(np.cov(random_proj.T))\n",
    "      else:\n",
    "        max_cov = np.max(np.diag(np.cov(random_proj.T)))\n",
    "\n",
    "      clustering_info_class = {\n",
    "          'variance_normalization': np.sqrt(max_cov),\n",
    "          'mean': pca_mean,\n",
    "          'components': pca_components,\n",
    "      }\n",
    "      clustering_subclass_info.append(clustering_info_class)\n",
    "\n",
    "    clustering_info = {\n",
    "        'type': 'linear',\n",
    "        'subclass_info': clustering_subclass_info,\n",
    "    }\n",
    "  else:\n",
    "    raise Exception('Invalid type of clustering: {}'.format(type_of_clustering))\n",
    "\n",
    "  return clustering_info\n",
    "\n",
    "\n",
    "def get_data(start=0, end=50000, dataset_name='cifar100'):\n",
    "  \"\"\"Loads and preprocesses datasets.\"\"\"\n",
    "  if dataset_name == 'cifar10':\n",
    "    num_classes = 10\n",
    "    (train_x, train_y), (test_x, test_y) = load_cifar_data(\n",
    "        start, end, num_classes\n",
    "    )\n",
    "  elif dataset_name == 'cifar10bin':\n",
    "    num_classes = 10\n",
    "    (train_x, train_y), (test_x, test_y) = load_cifarbin_data(\n",
    "        start, end, num_classes\n",
    "    )\n",
    "    num_classes = 2\n",
    "  elif dataset_name == 'cifar100bin':\n",
    "    num_classes = 100\n",
    "    (train_x, train_y), (test_x, test_y) = load_cifarbin_data(\n",
    "        start, end, num_classes\n",
    "    )\n",
    "    num_classes = 2\n",
    "  elif dataset_name == 'cifar100':\n",
    "    num_classes = 100\n",
    "    (train_x, train_y), (test_x, test_y) = load_cifar_data(\n",
    "        start, end, num_classes\n",
    "    )\n",
    "  elif dataset_name == 'cifar10_corrupted':\n",
    "    num_classes = 10\n",
    "    (train_x, train_y), (test_x, test_y) = load_cifar_corrupted_data(start, end)\n",
    "  elif dataset_name == 'svhn':\n",
    "    num_classes = 10\n",
    "    (train_x, train_y), (test_x, test_y) = load_svhn_data(start, end)\n",
    "  elif dataset_name == 'celeb_a':\n",
    "    num_classes = 2\n",
    "    (train_x, train_y), (test_x, test_y) = load_celeb_a_data(start, end)\n",
    "  elif dataset_name == 'imdb_reviews':\n",
    "    num_classes = 2\n",
    "    (train_x, train_y), (test_x, test_y) = load_imdb_reviews_data(start, end)\n",
    "    return (\n",
    "        train_x,\n",
    "        normalize_y(train_y, num_classes),\n",
    "        test_x,\n",
    "        normalize_y(test_y, num_classes),\n",
    "    )\n",
    "  elif 'glue' in dataset_name:\n",
    "    num_classes = 2\n",
    "    (train_x, train_y), (test_x, test_y) = load_glue_data(\n",
    "        dataset_name, start, end\n",
    "    )\n",
    "    return (\n",
    "        train_x,\n",
    "        normalize_y(train_y, num_classes),\n",
    "        test_x,\n",
    "        normalize_y(test_y, num_classes),\n",
    "    )\n",
    "  else:\n",
    "    print('invalid dataset name')\n",
    "    return None, None, None, None\n",
    "  return (\n",
    "      normalize_x(train_x),\n",
    "      normalize_y(train_y, num_classes),\n",
    "      normalize_x(test_x),\n",
    "      normalize_y(test_y, num_classes),\n",
    "  )\n",
    "\n",
    "\n",
    "def get_orthogonal_part(embeddings, weight):\n",
    "  # embeddings is N x H\n",
    "  # weight is H x C\n",
    "  weight = np.linalg.qr(weight)[0]\n",
    "\n",
    "  mags = embeddings @ weight\n",
    "  # N x C\n",
    "\n",
    "  orthog = embeddings - mags @ weight.T\n",
    "\n",
    "  return orthog\n",
    "\n",
    "\n",
    "def get_random_orthogonal_matrix(size):\n",
    "  random_mat = np.random.normal(\n",
    "      size=[size, size],\n",
    "  )\n",
    "\n",
    "  return np.linalg.qr(random_mat)[0]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yXbZEsNK_--3"
   },
   "source": [
    "# Distillers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "executionInfo": {
     "elapsed": 3951,
     "status": "ok",
     "timestamp": 1706806498420,
     "user": {
      "displayName": "",
      "userId": "07549615973646857260"
     },
     "user_tz": 300
    },
    "id": "DaBCjwgfABAS"
   },
   "outputs": [],
   "source": [
    "\n",
    "import dataclasses\n",
    "import math\n",
    "from typing import Callable\n",
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "@dataclasses.dataclass\n",
    "class DistillerParam:\n",
    "  \"\"\"The parameters for the distiller.\n",
    "\n",
    "  Attributes:\n",
    "    num_classes: The number of classes.\n",
    "    student_loss_fn: The student loss function with respect to the ground truth.\n",
    "    distillation_loss_fn: The distillation loss function with respect to the\n",
    "      teachere's predictions.\n",
    "    embedding_loss_fn: The embedding loss function with respect to the teacher's\n",
    "      embeddings.\n",
    "    teacher_subclass_fn: A function which converts teacher logits and embeddings\n",
    "      into subclass probabilities. None if not using subclasses.\n",
    "    metric_fn: The metric function.\n",
    "    resnet_depth: The depth of the renset model.\n",
    "    learning_rate: The (initial) learning rate.\n",
    "    decay_steps: The decay steps for the learning rate schedule.\n",
    "    temperature: The temperature.\n",
    "    alpha: The coeficient of the student_loss_fn. The coefficient of the\n",
    "      distillation loss_fn is 1-alpha.\n",
    "    beta: The coefficient of the embeddings loss function.\n",
    "    embedding_dimension: The dimension of the student's embeddings\n",
    "    teacher_architecture: The name of the teacher architecture\n",
    "    vid_min_bound: The min bound for VID.\n",
    "    pretraining: A boolean variable determining whether the distiller is in\n",
    "      pretraining mode or not. This is valid only for the FITNET case.\n",
    "    teacher_label_temperature: temperature for subclass logits\n",
    "  \"\"\"\n",
    "\n",
    "  num_classes: int\n",
    "  student_loss_fn: Callable[[tf.Tensor, tf.Tensor], tf.Tensor] | str | None = (\n",
    "      None\n",
    "  )\n",
    "  distillation_loss_fn: Callable[\n",
    "      [tf.Tensor, tf.Tensor], tf.Tensor\n",
    "  ] | str | None = None\n",
    "  embedding_loss_fn: Callable[\n",
    "      [tf.Tensor, tf.Tensor], tf.Tensor\n",
    "  ] | str | None = None\n",
    "  teacher_subclass_fn: Callable[[tf.Tensor, tf.Tensor], tf.Tensor] | None = None\n",
    "  metric_fn: Callable[[tf.Tensor, tf.Tensor], tf.Tensor] | str | None = None\n",
    "  resnet_depth: int | None = None\n",
    "  learning_rate: float | None = None\n",
    "  decay_steps: int | None = None\n",
    "  temperature: float | None = None\n",
    "  alpha: float | None = None\n",
    "  beta: float | None = None\n",
    "  embedding_dimension: int | None = None\n",
    "  teacher_architecture: str | None = 'resnet'\n",
    "  vid_min_bound: float = 0.0001\n",
    "  pretraining: bool = False\n",
    "  teacher_label_temperature: float | None = None\n",
    "\n",
    "\n",
    "class Distiller(tf.keras.Model):\n",
    "  \"\"\"A class for implementing distillation.\n",
    "\n",
    "  Attributes:\n",
    "    teacher_model: The teacher model, tf.keras.Model instance.\n",
    "    student_model: The student model, tf.keras.Model instance.\n",
    "    params: The parameters for distillation.\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(\n",
    "      self,\n",
    "      teacher_model: tf.keras.Model,\n",
    "      student_model: tf.keras.Model,\n",
    "      params: DistillerParam,\n",
    "  ):\n",
    "    super(Distiller, self).__init__()\n",
    "\n",
    "    if params.teacher_architecture == 'resnet':\n",
    "      self.teacher_model_with_embeddings = tf.keras.Model(\n",
    "          teacher_model.inputs, teacher_model.layers[-3].output\n",
    "      )\n",
    "    elif params.teacher_architecture == 'mobilenet':\n",
    "      penultimate = teacher_model.get_layer('dropout').output\n",
    "      embeddings_layer = tf.keras.layers.Flatten()(penultimate)\n",
    "      self.teacher_model_with_embeddings = tf.keras.Model(\n",
    "          teacher_model.inputs, embeddings_layer)\n",
    "    elif 'bert' in params.teacher_architecture:\n",
    "      penultimate = teacher_model.get_layer('bert_encoder').output['default']\n",
    "      embeddings_layer = tf.keras.layers.Flatten()(penultimate)\n",
    "      self.teacher_model_with_embeddings = tf.keras.Model(\n",
    "          teacher_model.inputs, embeddings_layer)\n",
    "    else:\n",
    "      raise ValueError(\n",
    "          f'Unsupported teacher architecture: {params.teacher_architecture}'\n",
    "      )\n",
    "    self.teacher_model_with_logits = tf.keras.Model(\n",
    "        teacher_model.inputs, teacher_model.layers[-2].output\n",
    "    )\n",
    "    self.student_model = student_model\n",
    "    self.num_classes = params.num_classes\n",
    "    self.embedding_dimension = params.embedding_dimension\n",
    "\n",
    "  def compile(self, optimizer, params: DistillerParam):\n",
    "    super(Distiller, self).compile(\n",
    "        optimizer=optimizer, metrics=params.metric_fn\n",
    "    )\n",
    "    self.student_loss_fn = params.student_loss_fn\n",
    "    self.distillation_loss_fn = params.distillation_loss_fn\n",
    "    self.embedding_loss_fn = params.embedding_loss_fn\n",
    "    self.alpha = params.alpha\n",
    "    self.beta = params.beta\n",
    "    self.temperature = params.temperature\n",
    "\n",
    "  def call(self, inputs, training=None, mask=None):\n",
    "    return self.student_model(inputs, training=training)\n",
    "\n",
    "  def train_step(self, data):\n",
    "    x, y = data\n",
    "\n",
    "    teacher_logits = self.teacher_model_with_logits(x, training=False)\n",
    "    teacher_probabilities = tf.nn.softmax(teacher_logits / self.temperature)\n",
    "    teacher_embeddings = self.teacher_model_with_embeddings(x, training=False)\n",
    "\n",
    "    if math.isclose(self.beta, 0.0):\n",
    "      if teacher_embeddings.shape[1] > self.embedding_dimension:\n",
    "        teacher_embeddings = teacher_embeddings[:, : self.embedding_dimension]\n",
    "      elif teacher_embeddings.shape[1] < self.embedding_dimension:\n",
    "        self.embedding_dimension = teacher_embeddings.shape[1]\n",
    "\n",
    "    with tf.GradientTape() as tape:\n",
    "      student_output = self.student_model(x, training=True)\n",
    "      student_logits = student_output[:, : self.num_classes]\n",
    "      scaled_student_probabilities = tf.nn.softmax(\n",
    "          student_logits / self.temperature\n",
    "      )\n",
    "      student_probabilities = tf.nn.softmax(student_logits)\n",
    "      student_embeddings = student_output[\n",
    "          :, self.num_classes : (self.embedding_dimension + self.num_classes)\n",
    "      ]\n",
    "      student_loss = tf.math.reduce_mean(\n",
    "          self.student_loss_fn(y, student_probabilities)\n",
    "      )\n",
    "      distillation_loss = (self.temperature) ** 2 * tf.math.reduce_mean(\n",
    "          self.distillation_loss_fn(\n",
    "              teacher_probabilities,\n",
    "              scaled_student_probabilities,\n",
    "          )\n",
    "      )\n",
    "      embedding_loss = tf.math.reduce_mean(\n",
    "          self.embedding_loss_fn(teacher_embeddings, student_embeddings)\n",
    "      )\n",
    "      loss = self.alpha * student_loss + (1 - self.alpha) * (\n",
    "          distillation_loss + self.beta * embedding_loss\n",
    "      )\n",
    "\n",
    "    trainable_vars = self.student_model.trainable_variables\n",
    "    self.optimizer.minimize(loss, trainable_vars, tape=tape)\n",
    "    self.compiled_metrics.update_state(y, student_output)\n",
    "    results = {m.name: m.result() for m in self.metrics}\n",
    "    results.update({\n",
    "        'student_loss': student_loss,\n",
    "        'distillation_loss': distillation_loss,\n",
    "        'embedding_loss': embedding_loss,\n",
    "    })\n",
    "    return results\n",
    "\n",
    "  def test_step(self, data):\n",
    "    x, y = data\n",
    "    student_output = self.student_model(x, training=False)\n",
    "    student_logits = student_output[:, : self.num_classes]\n",
    "    student_probabilities = tf.nn.softmax(student_logits)\n",
    "    student_loss = tf.math.reduce_mean(\n",
    "        self.student_loss_fn(y, student_probabilities)\n",
    "    )\n",
    "    self.compiled_metrics.update_state(y, student_probabilities)\n",
    "    results = {m.name: m.result() for m in self.metrics}\n",
    "    results.update({'student_loss': student_loss})\n",
    "    return results\n",
    "\n",
    "\n",
    "class FitNetDistiller(tf.keras.Model):\n",
    "  \"\"\"A class for implementing distillation.\n",
    "\n",
    "  Attributes:\n",
    "    teacher_model: The teacher model, tf.keras.Model instance.\n",
    "    student_model: The student model, tf.keras.Model instance.\n",
    "    params: The parameters for distillation.\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(\n",
    "      self,\n",
    "      teacher_model: tf.keras.Model,\n",
    "      student_model: tf.keras.Model,\n",
    "      params: DistillerParam,\n",
    "  ):\n",
    "    super(FitNetDistiller, self).__init__()\n",
    "\n",
    "    if params.teacher_architecture == 'resnet':\n",
    "      self.teacher_model_with_embeddings = tf.keras.Model(\n",
    "          teacher_model.inputs, teacher_model.layers[-3].output\n",
    "      )\n",
    "    elif params.teacher_architecture == 'mobilenet':\n",
    "      penultimate = teacher_model.get_layer('dropout').output\n",
    "      embeddings_layer = tf.keras.layers.Flatten()(penultimate)\n",
    "      self.teacher_model_with_embeddings = tf.keras.Model(\n",
    "          teacher_model.inputs, embeddings_layer\n",
    "      )\n",
    "    elif 'bert' in params.teacher_architecture:\n",
    "      penultimate = teacher_model.get_layer('bert_encoder').output['default']\n",
    "      embeddings_layer = tf.keras.layers.Flatten()(penultimate)\n",
    "      self.teacher_model_with_embeddings = tf.keras.Model(\n",
    "          teacher_model.inputs, embeddings_layer\n",
    "      )\n",
    "    else:\n",
    "      raise ValueError(\n",
    "          f'Unsupported teacher architecture: {params.teacher_architecture}'\n",
    "      )\n",
    "    self.teacher_model_with_logits = tf.keras.Model(\n",
    "        teacher_model.inputs, teacher_model.layers[-2].output\n",
    "    )\n",
    "    self.alpha = params.alpha\n",
    "    self.beta = params.beta\n",
    "    self.student_model = student_model\n",
    "    self.num_classes = params.num_classes\n",
    "    self.embedding_dimension = params.embedding_dimension\n",
    "    self.pretraining = params.pretraining\n",
    "    self.temperature = params.temperature\n",
    "\n",
    "  def compile(self, optimizer, params: DistillerParam):\n",
    "    super(FitNetDistiller, self).compile(\n",
    "        optimizer=optimizer, metrics=params.metric_fn\n",
    "    )\n",
    "    self.student_loss_fn = params.student_loss_fn\n",
    "    self.distillation_loss_fn = params.distillation_loss_fn\n",
    "    self.embedding_loss_fn = params.embedding_loss_fn\n",
    "\n",
    "  def call(self, inputs, training=None, mask=None):\n",
    "    return self.student_model(inputs, training=training)\n",
    "\n",
    "  def train_step(self, data):\n",
    "    x, y = data\n",
    "\n",
    "    teacher_logits = self.teacher_model_with_logits(x, training=False)\n",
    "    teacher_probabilities = tf.nn.softmax(teacher_logits / self.temperature)\n",
    "    teacher_embeddings = self.teacher_model_with_embeddings(x, training=False)\n",
    "\n",
    "    if math.isclose(self.beta, 0.0):\n",
    "      if teacher_embeddings.shape[1] > self.embedding_dimension:\n",
    "        teacher_embeddings = teacher_embeddings[:, : self.embedding_dimension]\n",
    "      elif teacher_embeddings.shape[1] < self.embedding_dimension:\n",
    "        self.embedding_dimension = teacher_embeddings.shape[1]\n",
    "    if self.pretraining:\n",
    "      with tf.GradientTape() as tape:\n",
    "        student_output = self.student_model(x, training=True)\n",
    "        student_embeddings = student_output[\n",
    "            :, self.num_classes : (self.embedding_dimension + self.num_classes)\n",
    "        ]\n",
    "        embedding_loss = tf.math.reduce_mean(\n",
    "            self.embedding_loss_fn(teacher_embeddings, student_embeddings)\n",
    "        )\n",
    "      trainable_vars = self.student_model.trainable_variables\n",
    "      self.optimizer.minimize(embedding_loss, trainable_vars, tape=tape)\n",
    "      self.compiled_metrics.update_state(y, student_output)\n",
    "      results = {m.name: m.result() for m in self.metrics}\n",
    "      results.update({\n",
    "          'embedding_loss': embedding_loss,\n",
    "      })\n",
    "      return results\n",
    "    else:\n",
    "      with tf.GradientTape() as tape:\n",
    "        student_output = self.student_model(x, training=True)\n",
    "        student_logits = student_output[:, : self.num_classes]\n",
    "        student_probabilities = tf.nn.softmax(student_logits)\n",
    "        scaled_student_probabilities = tf.nn.softmax(\n",
    "            student_logits / self.temperature\n",
    "        )\n",
    "        student_embeddings = student_output[\n",
    "            :, self.num_classes : (self.embedding_dimension + self.num_classes)\n",
    "        ]\n",
    "        student_loss = tf.math.reduce_mean(\n",
    "            self.student_loss_fn(y, student_probabilities)\n",
    "        )\n",
    "        distillation_loss = (self.temperature) ** 2 * tf.math.reduce_mean(\n",
    "            self.distillation_loss_fn(\n",
    "                teacher_probabilities,\n",
    "                scaled_student_probabilities,\n",
    "            )\n",
    "        )\n",
    "        embedding_loss = tf.math.reduce_mean(\n",
    "            self.embedding_loss_fn(teacher_embeddings, student_embeddings)\n",
    "        )\n",
    "        loss = self.alpha * student_loss + (1 - self.alpha) * (\n",
    "            distillation_loss + self.beta * embedding_loss\n",
    "        )\n",
    "\n",
    "      trainable_vars = self.student_model.trainable_variables\n",
    "      self.optimizer.minimize(loss, trainable_vars, tape=tape)\n",
    "      self.compiled_metrics.update_state(y, student_output)\n",
    "      results = {m.name: m.result() for m in self.metrics}\n",
    "      results.update({\n",
    "          'student_loss': student_loss,\n",
    "          'distillation_loss': distillation_loss,\n",
    "          'embedding_loss': embedding_loss,\n",
    "      })\n",
    "      return results\n",
    "\n",
    "  def test_step(self, data):\n",
    "    x, y = data\n",
    "    student_output = self.student_model(x, training=False)\n",
    "    student_logits = student_output[:, : self.num_classes]\n",
    "    student_probabilities = tf.nn.softmax(student_logits)\n",
    "    student_loss = tf.math.reduce_mean(\n",
    "        self.student_loss_fn(y, student_probabilities)\n",
    "    )\n",
    "    self.compiled_metrics.update_state(y, student_probabilities)\n",
    "    results = {m.name: m.result() for m in self.metrics}\n",
    "    results.update({'student_loss': student_loss})\n",
    "    return results\n",
    "\n",
    "\n",
    "class DistillerVID(tf.keras.Model):\n",
    "  \"\"\"A class for implementing VID distillation (for embeddings).\n",
    "\n",
    "  Attributes:\n",
    "    teacher_model: The teacher model, tf.keras.Model instance.\n",
    "    student_model: The student model, tf.keras.Model instance.\n",
    "    params: The parameters for distillation.\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(\n",
    "      self,\n",
    "      teacher_model: tf.keras.Model,\n",
    "      student_model: tf.keras.Model,\n",
    "      params: DistillerParam,\n",
    "  ):\n",
    "    super(DistillerVID, self).__init__()\n",
    "\n",
    "    if params.teacher_architecture == 'resnet':\n",
    "      self.teacher_model_with_embeddings = tf.keras.Model(\n",
    "          teacher_model.inputs, teacher_model.layers[-3].output\n",
    "      )\n",
    "    elif params.teacher_architecture == 'mobilenet':\n",
    "      penultimate = teacher_model.get_layer('dropout').output\n",
    "      embeddings_layer = tf.keras.layers.Flatten()(penultimate)\n",
    "      self.teacher_model_with_embeddings = tf.keras.Model(\n",
    "          teacher_model.inputs, embeddings_layer\n",
    "      )\n",
    "    elif 'bert' in params.teacher_architecture:\n",
    "      penultimate = teacher_model.get_layer('bert_encoder').output['default']\n",
    "      embeddings_layer = tf.keras.layers.Flatten()(penultimate)\n",
    "      self.teacher_model_with_embeddings = tf.keras.Model(\n",
    "          teacher_model.inputs, embeddings_layer\n",
    "      )\n",
    "    else:\n",
    "      raise ValueError(\n",
    "          f'Unsupported teacher architecture: {params.teacher_architecture}'\n",
    "      )\n",
    "    self.teacher_model_with_logits = tf.keras.Model(\n",
    "        teacher_model.inputs, teacher_model.layers[-2].output\n",
    "    )\n",
    "    self.student_model = student_model\n",
    "    self.num_classes = params.num_classes\n",
    "    self.embedding_dimension = params.embedding_dimension\n",
    "    self.vid_min_bound = params.vid_min_bound\n",
    "\n",
    "  def compile(self, optimizer, params: DistillerParam):\n",
    "    super(DistillerVID, self).compile(\n",
    "        optimizer=optimizer, metrics=params.metric_fn\n",
    "    )\n",
    "    self.student_loss_fn = params.student_loss_fn\n",
    "    self.distillation_loss_fn = params.distillation_loss_fn\n",
    "    self.embedding_loss_fn = params.embedding_loss_fn\n",
    "    self.alpha = params.alpha\n",
    "    self.beta = params.beta\n",
    "    self.temperature = params.temperature\n",
    "\n",
    "  def call(self, inputs, training=None, mask=None):\n",
    "    return self.student_model(inputs, training=training)\n",
    "\n",
    "  def train_step(self, data):\n",
    "    x, y = data\n",
    "\n",
    "    teacher_logits = self.teacher_model_with_logits(x, training=False)\n",
    "    teacher_probabilities = tf.nn.softmax(teacher_logits / self.temperature)\n",
    "    teacher_embeddings = self.teacher_model_with_embeddings(x, training=False)\n",
    "\n",
    "    if math.isclose(self.beta, 0.0):\n",
    "      if teacher_embeddings.shape[1] > self.embedding_dimension:\n",
    "        teacher_embeddings = teacher_embeddings[:, : self.embedding_dimension]\n",
    "      elif teacher_embeddings.shape[1] < self.embedding_dimension:\n",
    "        self.embedding_dimension = teacher_embeddings.shape[1]\n",
    "\n",
    "    with tf.GradientTape() as tape:\n",
    "      student_output = self.student_model(x, training=True)\n",
    "      student_logits = student_output[:, : self.num_classes]\n",
    "      student_probabilities = tf.nn.softmax(student_logits)\n",
    "      scaled_student_probabilities = tf.nn.softmax(\n",
    "          student_logits / self.temperature\n",
    "      )\n",
    "      student_embeddings = student_output[\n",
    "          :, self.num_classes : (self.embedding_dimension + self.num_classes)\n",
    "      ]\n",
    "      embeddings_scaling = student_output[\n",
    "          :,\n",
    "          self.num_classes\n",
    "          + self.embedding_dimension : (\n",
    "              2 * self.embedding_dimension + self.num_classes\n",
    "          ),\n",
    "      ]\n",
    "      student_loss = tf.math.reduce_mean(\n",
    "          self.student_loss_fn(y, student_probabilities)\n",
    "      )\n",
    "      distillation_loss = (self.temperature) ** 2 * tf.math.reduce_mean(\n",
    "          self.distillation_loss_fn(\n",
    "              teacher_probabilities,\n",
    "              scaled_student_probabilities,\n",
    "          )\n",
    "      )\n",
    "\n",
    "      positive_embeddings_scaling = (\n",
    "          tf.keras.activations.softplus(embeddings_scaling) + self.vid_min_bound\n",
    "      )\n",
    "      l2_scaled = (teacher_embeddings - student_embeddings) ** 2 / (\n",
    "          2 * positive_embeddings_scaling\n",
    "      )\n",
    "\n",
    "      mu_loss = tf.math.reduce_sum(\n",
    "          l2_scaled + tf.math.log(positive_embeddings_scaling) / 2, axis=1\n",
    "      )\n",
    "      embedding_loss = tf.math.reduce_mean(mu_loss)\n",
    "\n",
    "      loss = self.alpha * student_loss + (1 - self.alpha) * (\n",
    "          distillation_loss + self.beta * embedding_loss\n",
    "      )\n",
    "\n",
    "    trainable_vars = self.student_model.trainable_variables\n",
    "    self.optimizer.minimize(loss, trainable_vars, tape=tape)\n",
    "    self.compiled_metrics.update_state(y, student_output)\n",
    "    results = {m.name: m.result() for m in self.metrics}\n",
    "    results.update({\n",
    "        'student_loss': student_loss,\n",
    "        'distillation_loss': distillation_loss,\n",
    "        'embedding_loss': embedding_loss,\n",
    "    })\n",
    "    return results\n",
    "\n",
    "  def test_step(self, data):\n",
    "    x, y = data\n",
    "    student_output = self.student_model(x, training=False)\n",
    "    student_logits = student_output[:, : self.num_classes]\n",
    "    student_probabilities = tf.nn.softmax(student_logits)\n",
    "    student_loss = tf.math.reduce_mean(\n",
    "        self.student_loss_fn(y, student_probabilities)\n",
    "    )\n",
    "    self.compiled_metrics.update_state(y, student_probabilities)\n",
    "    results = {m.name: m.result() for m in self.metrics}\n",
    "    results.update({'student_loss': student_loss})\n",
    "    return results\n",
    "\n",
    "\n",
    "class SubclassDistiller(tf.keras.Model):\n",
    "  \"\"\"A class for implementing distillation with subclasses.\n",
    "\n",
    "  Attributes:\n",
    "    student_model: The student model, tf.keras.Model instance.\n",
    "    params: The parameters for distillation.\n",
    "    num_subclasses: The number of subclasses\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(\n",
    "      self,\n",
    "      student_model: tf.keras.Model,\n",
    "      params: DistillerParam,\n",
    "      num_subclasses: int,\n",
    "  ):\n",
    "    super(SubclassDistiller, self).__init__()\n",
    "\n",
    "    self.student_model = student_model\n",
    "    self.num_classes = params.num_classes\n",
    "    self.num_subclasses = num_subclasses\n",
    "    self.embedding_dimension = params.embedding_dimension\n",
    "    self.test_loss_fn = tf.keras.losses.CategoricalCrossentropy(\n",
    "        reduction=tf.keras.losses.Reduction.NONE\n",
    "    )\n",
    "\n",
    "  def compile(self, optimizer, params: DistillerParam):\n",
    "    super(SubclassDistiller, self).compile(\n",
    "        optimizer=optimizer, metrics=params.metric_fn\n",
    "    )\n",
    "    self.student_loss_fn = params.student_loss_fn\n",
    "    self.temperature = params.temperature\n",
    "\n",
    "  def call(self, inputs, training=None, mask=None):\n",
    "    return self.student_model(inputs, training=training)\n",
    "\n",
    "  def train_step(self, data):\n",
    "    x, y = data\n",
    "\n",
    "    with tf.GradientTape() as tape:\n",
    "      student_output = self.student_model(x, training=True)\n",
    "      student_logits = student_output[:, : self.num_subclasses]\n",
    "      scaled_student_probabilities = tf.nn.softmax(\n",
    "          student_logits / self.temperature\n",
    "      )\n",
    "      loss = (self.temperature) ** 2 * tf.math.reduce_mean(\n",
    "          self.student_loss_fn(y, scaled_student_probabilities)\n",
    "      )\n",
    "\n",
    "    trainable_vars = self.student_model.trainable_variables\n",
    "    self.optimizer.minimize(loss, trainable_vars, tape=tape)\n",
    "    self.compiled_metrics.update_state(y, student_output)\n",
    "    results = {m.name: m.result() for m in self.metrics}\n",
    "    return results\n",
    "\n",
    "  def test_step(self, data):\n",
    "    x, y = data\n",
    "    student_output = self.student_model(x, training=False)\n",
    "    student_logits = student_output[:, : self.num_subclasses]\n",
    "    subclass_predictions = tf.nn.softmax(student_logits / self.temperature)\n",
    "    reshaped_subclass_predictions = tf.reshape(\n",
    "        subclass_predictions,\n",
    "        [-1, self.num_classes, self.num_subclasses // self.num_classes],\n",
    "    )\n",
    "    student_probabilities = tf.math.reduce_sum(\n",
    "        reshaped_subclass_predictions, axis=2\n",
    "    )\n",
    "    student_loss = tf.math.reduce_mean(\n",
    "        self.test_loss_fn(y, student_probabilities)\n",
    "    )\n",
    "    self.compiled_metrics.update_state(y, student_probabilities)\n",
    "    results = {m.name: m.result() for m in self.metrics}\n",
    "    results.update({'student_loss': student_loss})\n",
    "    return results\n",
    "\n",
    "\n",
    "class SubclassTeacher(tf.keras.Model):\n",
    "  \"\"\"A class for implementing subclass distillation.\n",
    "\n",
    "  Attributes:\n",
    "    teacher_model: The teacher model, tf.keras.Model instance.\n",
    "    num_subclasses: The number of subclasses\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(\n",
    "      self,\n",
    "      teacher_model: tf.keras.Model,\n",
    "      num_classes: int,\n",
    "      num_subclasses: int,\n",
    "      loss_fn: Callable[[tf.Tensor, tf.Tensor], tf.Tensor],\n",
    "      metric_fn: Callable[[tf.Tensor, tf.Tensor], tf.Tensor] | str,\n",
    "  ):\n",
    "    super(SubclassTeacher, self).__init__()\n",
    "\n",
    "    self.teacher_model = teacher_model\n",
    "    self.teacher_model_with_logits = tf.keras.Model(\n",
    "        teacher_model.inputs, teacher_model.layers[-2].output\n",
    "    )\n",
    "    self.num_classes = num_classes\n",
    "    self.num_subclasses = num_subclasses\n",
    "    self.loss_fn = loss_fn\n",
    "    self.metric_fn = metric_fn\n",
    "\n",
    "  def compile(self, optimizer):\n",
    "    super(SubclassTeacher, self).compile(\n",
    "        optimizer=optimizer, metrics=self.metric_fn\n",
    "    )\n",
    "\n",
    "  def call(self, inputs, training=None, mask=None):\n",
    "    return self.teacher_model(inputs, training=training)\n",
    "\n",
    "  def train_step(self, data):\n",
    "    x, y = data\n",
    "\n",
    "    with tf.GradientTape() as tape:\n",
    "      teacher_logits = self.teacher_model_with_logits(x, training=True)\n",
    "      subclass_predictions = tf.nn.softmax(teacher_logits)\n",
    "      reshaped_subclass_predictions = tf.reshape(\n",
    "          subclass_predictions,\n",
    "          [-1, self.num_classes, self.num_subclasses // self.num_classes],\n",
    "      )\n",
    "      probabilities = tf.math.reduce_sum(reshaped_subclass_predictions, axis=2)\n",
    "      loss = tf.math.reduce_mean(self.loss_fn(y, teacher_logits))\n",
    "\n",
    "    trainable_vars = self.teacher_model.trainable_variables\n",
    "    self.optimizer.minimize(loss, trainable_vars, tape=tape)\n",
    "    self.compiled_metrics.update_state(y, probabilities)\n",
    "    results = {m.name: m.result() for m in self.metrics}\n",
    "    return results\n",
    "\n",
    "  def test_step(self, data):\n",
    "    x, y = data\n",
    "    teacher_logits = self.teacher_model_with_logits(x, training=False)\n",
    "    subclass_predictions = tf.nn.softmax(teacher_logits)\n",
    "    reshaped_subclass_predictions = tf.reshape(\n",
    "        subclass_predictions,\n",
    "        [-1, self.num_classes, self.num_subclasses // self.num_classes],\n",
    "    )\n",
    "    probabilities = tf.math.reduce_sum(reshaped_subclass_predictions, axis=2)\n",
    "    loss = tf.math.reduce_mean(self.loss_fn(y, teacher_logits))\n",
    "    self.compiled_metrics.update_state(y, probabilities)\n",
    "    results = {m.name: m.result() for m in self.metrics}\n",
    "    results.update({'subclass_loss': loss})\n",
    "    return results\n",
    "\n",
    "\n",
    "class DynamicSubclassDistiller(tf.keras.Model):\n",
    "  \"\"\"A class for implementing distillation which makes teacher labels given teacher embeddings and a subclass function.\n",
    "\n",
    "  Attributes:\n",
    "    teacher_model: The teacher model, tf.keras.Model instance.\n",
    "    student_model: The student model, tf.keras.Model instance.\n",
    "    params: The parameters for distillation.\n",
    "    num_subclasses: The number of subclasses\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(\n",
    "      self,\n",
    "      teacher_model: tf.keras.Model,\n",
    "      student_model: tf.keras.Model,\n",
    "      params: DistillerParam,\n",
    "      num_subclasses: int,\n",
    "  ):\n",
    "    super(DynamicSubclassDistiller, self).__init__()\n",
    "\n",
    "    if params.teacher_architecture == 'resnet':\n",
    "      self.teacher_model_with_embeddings = tf.keras.Model(\n",
    "          teacher_model.inputs, teacher_model.layers[-3].output\n",
    "      )\n",
    "    elif params.teacher_architecture == 'mobilenet':\n",
    "      penultimate = teacher_model.get_layer('dropout').output\n",
    "      embeddings_layer = tf.keras.layers.Flatten()(penultimate)\n",
    "      self.teacher_model_with_embeddings = tf.keras.Model(\n",
    "          teacher_model.inputs, embeddings_layer\n",
    "      )\n",
    "    elif 'bert' in params.teacher_architecture:\n",
    "      penultimate = teacher_model.get_layer('bert_encoder').output['default']\n",
    "      embeddings_layer = tf.keras.layers.Flatten()(penultimate)\n",
    "      self.teacher_model_with_embeddings = tf.keras.Model(\n",
    "          teacher_model.inputs, embeddings_layer\n",
    "      )\n",
    "    else:\n",
    "      raise ValueError(\n",
    "          f'Unsupported teacher architecture: {params.teacher_architecture}'\n",
    "      )\n",
    "    self.teacher_model_with_logits = tf.keras.Model(\n",
    "        teacher_model.inputs, teacher_model.layers[-2].output\n",
    "    )\n",
    "    self.student_model = student_model\n",
    "    self.num_classes = params.num_classes\n",
    "    self.embedding_dimension = params.embedding_dimension\n",
    "    self.num_subclasses = num_subclasses\n",
    "\n",
    "  def compile(self, optimizer, params: DistillerParam):\n",
    "    super(DynamicSubclassDistiller, self).compile(\n",
    "        optimizer=optimizer,\n",
    "        metrics=params.metric_fn,\n",
    "    )\n",
    "    self.student_loss_fn = params.student_loss_fn\n",
    "    self.distillation_loss_fn = params.distillation_loss_fn\n",
    "    self.embedding_loss_fn = params.embedding_loss_fn\n",
    "    self.alpha = params.alpha\n",
    "    self.beta = params.beta\n",
    "    self.temperature = params.temperature\n",
    "    self.teacher_subclass_fn = params.teacher_subclass_fn\n",
    "\n",
    "  def call(self, inputs, training=None, mask=None):\n",
    "    return self.student_model(inputs, training=training)\n",
    "\n",
    "  def train_step(self, data):\n",
    "    \"\"\"Runs one distillation step on a batch.\n",
    "\n",
    "    The teacher's logits and embeddings are converted to subclass logits by\n",
    "    `self.teacher_subclass_fn`, and the student is trained to match the\n",
    "    resulting temperature-softened subclass distribution.\n",
    "\n",
    "    Args:\n",
    "      data: An `(x, y)` batch. `y` is only used to update the metrics.\n",
    "\n",
    "    Returns:\n",
    "      A dict mapping each metric name to its current result.\n",
    "    \"\"\"\n",
    "    x, y = data\n",
    "\n",
    "    teacher_logits = self.teacher_model_with_logits(x, training=False)\n",
    "    teacher_embeddings = self.teacher_model_with_embeddings(x, training=False)\n",
    "    teacher_subclass_logits = self.teacher_subclass_fn(\n",
    "        teacher_logits, teacher_embeddings\n",
    "    )\n",
    "    teacher_probabilities = tf.nn.softmax(\n",
    "        teacher_subclass_logits / self.temperature\n",
    "    )\n",
    "\n",
    "    with tf.GradientTape() as tape:\n",
    "      student_output = self.student_model(x, training=True)\n",
    "      student_logits = student_output[:, : self.num_subclasses]\n",
    "      student_probabilities = tf.nn.softmax(student_logits / self.temperature)\n",
    "\n",
    "      # The loss is scaled by the squared temperature.\n",
    "      loss = (self.temperature) ** 2 * tf.math.reduce_mean(\n",
    "          self.student_loss_fn(teacher_probabilities, student_probabilities)\n",
    "      )\n",
    "\n",
    "    trainable_vars = self.student_model.trainable_variables\n",
    "    self.optimizer.minimize(loss, trainable_vars, tape=tape)\n",
    "\n",
    "    # The labels `y` are per class, so the subclass probabilities are summed\n",
    "    # within each class (exactly as test_step does) before updating the\n",
    "    # metrics. Feeding the raw student output would compare class labels\n",
    "    # against subclass logits.\n",
    "    class_probabilities = tf.math.reduce_sum(\n",
    "        tf.reshape(\n",
    "            student_probabilities,\n",
    "            [-1, self.num_classes, self.num_subclasses // self.num_classes],\n",
    "        ),\n",
    "        axis=2,\n",
    "    )\n",
    "    self.compiled_metrics.update_state(y, class_probabilities)\n",
    "    return {m.name: m.result() for m in self.metrics}\n",
    "\n",
    "  def test_step(self, data):\n",
    "    \"\"\"Evaluates the student on a batch using class-level probabilities.\n",
    "\n",
    "    Args:\n",
    "      data: An `(x, y)` batch.\n",
    "\n",
    "    Returns:\n",
    "      A dict with the current metric results plus a `student_loss` entry.\n",
    "    \"\"\"\n",
    "    x, y = data\n",
    "    subclasses_per_class = self.num_subclasses // self.num_classes\n",
    "\n",
    "    student_output = self.student_model(x, training=False)\n",
    "    subclass_predictions = tf.nn.softmax(\n",
    "        student_output[:, : self.num_subclasses] / self.temperature\n",
    "    )\n",
    "    # Sum the subclass probabilities that belong to the same class.\n",
    "    student_probabilities = tf.math.reduce_sum(\n",
    "        tf.reshape(\n",
    "            subclass_predictions,\n",
    "            [-1, self.num_classes, subclasses_per_class],\n",
    "        ),\n",
    "        axis=2,\n",
    "    )\n",
    "\n",
    "    student_loss = tf.math.reduce_mean(\n",
    "        self.student_loss_fn(y, student_probabilities)\n",
    "    )\n",
    "    self.compiled_metrics.update_state(y, student_probabilities)\n",
    "\n",
    "    results = {metric.name: metric.result() for metric in self.metrics}\n",
    "    results['student_loss'] = student_loss\n",
    "    return results\n",
    "\n",
    "\n",
    "class DynamicSubclassAdjustmentDistiller(tf.keras.Model):\n",
    "  \"\"\"Distiller whose student predicts class logits plus subclass adjustments.\n",
    "\n",
    "  Teacher targets are produced by applying `teacher_subclass_fn` to the\n",
    "  teacher's logits and embeddings. The first `num_classes` student outputs are\n",
    "  treated as class logits and the following `num_subclasses` outputs as\n",
    "  per-class subclass adjustments.\n",
    "\n",
    "  Attributes:\n",
    "    teacher_model_with_embeddings: Maps the inputs to the teacher embeddings.\n",
    "    teacher_model_with_logits: Maps the inputs to the output of the teacher's\n",
    "      second-to-last layer.\n",
    "    student_model: The student model, tf.keras.Model instance.\n",
    "    num_classes: The number of classes.\n",
    "    embedding_dimension: The embedding dimension taken from the params.\n",
    "    num_subclasses: The number of subclasses.\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(\n",
    "      self,\n",
    "      teacher_model: tf.keras.Model,\n",
    "      student_model: tf.keras.Model,\n",
    "      params: DistillerParam,\n",
    "      num_subclasses: int,\n",
    "  ):\n",
    "    \"\"\"Builds the teacher sub-models and stores the student.\n",
    "\n",
    "    Args:\n",
    "      teacher_model: The teacher model, tf.keras.Model instance.\n",
    "      student_model: The student model, tf.keras.Model instance.\n",
    "      params: The parameters for distillation.\n",
    "      num_subclasses: The number of subclasses.\n",
    "\n",
    "    Raises:\n",
    "      ValueError: If `params.teacher_architecture` is not supported.\n",
    "    \"\"\"\n",
    "    super().__init__()\n",
    "\n",
    "    architecture = params.teacher_architecture\n",
    "    if architecture == 'resnet':\n",
    "      embeddings_layer = teacher_model.layers[-3].output\n",
    "    elif architecture == 'mobilenet':\n",
    "      penultimate = teacher_model.get_layer('dropout').output\n",
    "      embeddings_layer = tf.keras.layers.Flatten()(penultimate)\n",
    "    elif 'bert' in architecture:\n",
    "      penultimate = teacher_model.get_layer('bert_encoder').output['default']\n",
    "      embeddings_layer = tf.keras.layers.Flatten()(penultimate)\n",
    "    else:\n",
    "      raise ValueError(\n",
    "          f'Unsupported teacher architecture: {params.teacher_architecture}'\n",
    "      )\n",
    "    self.teacher_model_with_embeddings = tf.keras.Model(\n",
    "        teacher_model.inputs, embeddings_layer\n",
    "    )\n",
    "    self.teacher_model_with_logits = tf.keras.Model(\n",
    "        teacher_model.inputs, teacher_model.layers[-2].output\n",
    "    )\n",
    "    self.student_model = student_model\n",
    "    self.num_classes = params.num_classes\n",
    "    self.embedding_dimension = params.embedding_dimension\n",
    "    self.num_subclasses = num_subclasses\n",
    "\n",
    "  def compile(self, optimizer, params: DistillerParam):\n",
    "    \"\"\"Compiles the model and stores the distillation settings on `self`.\n",
    "\n",
    "    Args:\n",
    "      optimizer: The optimizer forwarded to `tf.keras.Model.compile`.\n",
    "      params: Provides the metrics, loss functions, loss weights,\n",
    "        temperatures and the teacher subclass function.\n",
    "    \"\"\"\n",
    "    super().compile(\n",
    "        optimizer=optimizer,\n",
    "        metrics=params.metric_fn,\n",
    "    )\n",
    "    self.student_loss_fn = params.student_loss_fn\n",
    "    self.distillation_loss_fn = params.distillation_loss_fn\n",
    "    self.embedding_loss_fn = params.embedding_loss_fn\n",
    "    self.alpha = params.alpha\n",
    "    self.beta = params.beta\n",
    "    self.temperature = params.temperature\n",
    "    self.teacher_label_temperature = params.teacher_label_temperature\n",
    "    self.teacher_subclass_fn = params.teacher_subclass_fn\n",
    "\n",
    "  def call(self, inputs, training=None, mask=None):\n",
    "    \"\"\"Forwards `inputs` to the student model; `mask` is not used.\"\"\"\n",
    "    return self.student_model(inputs, training=training)\n",
    "\n",
    "  def train_step(self, data):\n",
    "    \"\"\"Runs one distillation step on a batch.\n",
    "\n",
    "    Args:\n",
    "      data: An `(x, y)` batch. `y` is only used to update the metrics.\n",
    "\n",
    "    Returns:\n",
    "      A dict mapping each metric name to its current result.\n",
    "    \"\"\"\n",
    "    x, y = data\n",
    "    subclasses_per_class = self.num_subclasses // self.num_classes\n",
    "\n",
    "    teacher_logits = self.teacher_model_with_logits(x, training=False)\n",
    "    teacher_embeddings = self.teacher_model_with_embeddings(x, training=False)\n",
    "\n",
    "    teacher_subclass_logits = self.teacher_subclass_fn(\n",
    "        teacher_logits, teacher_embeddings\n",
    "    )\n",
    "\n",
    "    teacher_probabilities = tf.nn.softmax(\n",
    "        teacher_subclass_logits / self.temperature\n",
    "    )\n",
    "\n",
    "    with tf.GradientTape() as tape:\n",
    "      student_output = self.student_model(x, training=True)\n",
    "      student_logits = student_output[\n",
    "          :, : self.num_subclasses + self.num_classes\n",
    "      ]\n",
    "      student_class_logits = student_logits[:, : self.num_classes]\n",
    "      student_base_logits = tf.nn.log_softmax(student_class_logits, axis=-1)\n",
    "      # The batch dimension is left as -1: the static batch size can be\n",
    "      # unknown (None) when this step is traced, and None is not a valid\n",
    "      # entry in a reshape target.\n",
    "      student_subclass_adjustment_logits = tf.nn.log_softmax(\n",
    "          tf.reshape(\n",
    "              student_logits[:, self.num_classes :],\n",
    "              [-1, self.num_classes, subclasses_per_class],\n",
    "          ),\n",
    "          axis=-1,\n",
    "      )\n",
    "\n",
    "      student_combined_logits = (\n",
    "          student_base_logits[:, :, None] / self.temperature\n",
    "          + student_subclass_adjustment_logits / self.teacher_label_temperature\n",
    "      )\n",
    "      student_combined_logits = tf.reshape(\n",
    "          student_combined_logits,\n",
    "          [-1, self.num_classes * subclasses_per_class],\n",
    "      )\n",
    "\n",
    "      student_probabilities = tf.nn.softmax(student_combined_logits)\n",
    "\n",
    "      # The loss is scaled by the squared temperature.\n",
    "      loss = (self.temperature) ** 2 * tf.math.reduce_mean(\n",
    "          self.student_loss_fn(teacher_probabilities, student_probabilities)\n",
    "      )\n",
    "\n",
    "    trainable_vars = self.student_model.trainable_variables\n",
    "    self.optimizer.minimize(loss, trainable_vars, tape=tape)\n",
    "\n",
    "    # The metrics receive the same class probabilities that test_step uses,\n",
    "    # rather than the raw student output, whose extra columns are not class\n",
    "    # scores.\n",
    "    self.compiled_metrics.update_state(\n",
    "        y, tf.nn.softmax(student_class_logits, axis=-1)\n",
    "    )\n",
    "    results = {m.name: m.result() for m in self.metrics}\n",
    "    return results\n",
    "\n",
    "  def test_step(self, data):\n",
    "    \"\"\"Evaluates the student's class logits on a batch.\n",
    "\n",
    "    Args:\n",
    "      data: An `(x, y)` batch.\n",
    "\n",
    "    Returns:\n",
    "      A dict with the current metric results plus a `student_loss` entry.\n",
    "    \"\"\"\n",
    "    x, y = data\n",
    "    student_output = self.student_model(x, training=False)\n",
    "    # Only the leading class logits are evaluated; the subclass adjustments\n",
    "    # are ignored here.\n",
    "    student_logits = student_output[:, : self.num_classes]\n",
    "    student_probabilities = tf.nn.softmax(student_logits, axis=-1)\n",
    "    student_loss = tf.math.reduce_mean(\n",
    "        self.student_loss_fn(y, student_probabilities)\n",
    "    )\n",
    "    self.compiled_metrics.update_state(y, student_probabilities)\n",
    "    results = {m.name: m.result() for m in self.metrics}\n",
    "    results.update({'student_loss': student_loss})\n",
    "    return results\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def subclass_teacher_loss(\n",
    "    alpha: float,\n",
    "    loss_fn: Callable[[tf.Tensor, tf.Tensor], tf.Tensor],\n",
    "    num_classes: int,\n",
    "    num_subclasses: int,\n",
    "    auxiliary_loss_temperature: float = 1.0,\n",
    ") -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:\n",
    "  \"\"\"Subclass teacher loss function.\n",
    "\n",
    "  The returned loss sums the subclass probabilities of each class, applies\n",
    "  `loss_fn` to the resulting class probabilities and adds `alpha` times an\n",
    "  auxiliary loss computed on the raw subclass logits.\n",
    "\n",
    "  Args:\n",
    "    alpha: The weight of the auxiliary loss.\n",
    "    loss_fn: The loss function to be used.\n",
    "    num_classes: The number of classes.\n",
    "    num_subclasses: The number of subclasses.\n",
    "    auxiliary_loss_temperature: The temperature for the auxiliary loss.\n",
    "\n",
    "  Returns:\n",
    "    The subclass_teacher loss function.\n",
    "  \"\"\"\n",
    "\n",
    "  def auxiliary_loss(subclass_logits, auxiliary_loss_temperature):\n",
    "    \"\"\"Auxiliary loss for subclass distillation.\n",
    "\n",
    "    This auxiliary loss is applied to the teacher to encourage the model to use\n",
    "    all subclasses. It penalizes high correlation between the logits in a given\n",
    "    minibatch.\n",
    "    Args:\n",
    "      subclass_logits: A tf.Tensor containing the logits for every subclass.\n",
    "      auxiliary_loss_temperature: Temperature of auxiliary loss.\n",
    "\n",
    "    Returns:\n",
    "      A scalar Tensor containing the overall loss.\n",
    "    \"\"\"\n",
    "    batch_size = tf.cast(tf.shape(subclass_logits)[0], subclass_logits.dtype)\n",
    "    num_all_subclasses = tf.convert_to_tensor(\n",
    "        int(subclass_logits.shape[-1]), dtype=subclass_logits.dtype\n",
    "    )\n",
    "    # Standardize every example's logits and scale them so that the pairwise\n",
    "    # dot products below are divided by the temperature.\n",
    "    mean_y, variance_y = tf.nn.moments(subclass_logits, -1, keepdims=True)\n",
    "    v = tf.nn.batch_normalization(\n",
    "        subclass_logits,\n",
    "        mean=mean_y,\n",
    "        variance=variance_y,\n",
    "        offset=None,\n",
    "        scale=1.0 / tf.sqrt(auxiliary_loss_temperature * num_all_subclasses),\n",
    "        variance_epsilon=1e-9,\n",
    "    )\n",
    "    # Pairwise similarities between the examples of the batch.\n",
    "    m = tf.matmul(v, v, transpose_b=True)\n",
    "    # log(sum(exp(m))) per column, computed with the numerically stable\n",
    "    # primitive so that small temperatures cannot overflow the exponential.\n",
    "    log_sums = tf.math.reduce_logsumexp(m, axis=0)\n",
    "    outer_sum = tf.math.reduce_sum(log_sums, axis=0)\n",
    "    return (\n",
    "        outer_sum / batch_size\n",
    "        - 1 / auxiliary_loss_temperature\n",
    "        - tf.math.log(batch_size)\n",
    "    )\n",
    "\n",
    "  def _loss(y_true, y_pred):\n",
    "    \"\"\"Class-level loss on `y_pred` subclass logits plus the auxiliary term.\"\"\"\n",
    "    subclass_predictions = tf.nn.softmax(y_pred)\n",
    "    reshaped_subclass_predictions = tf.reshape(\n",
    "        subclass_predictions,\n",
    "        [-1, num_classes, num_subclasses // num_classes],\n",
    "    )\n",
    "    probabilities = tf.math.reduce_sum(reshaped_subclass_predictions, axis=2)\n",
    "    loss = tf.math.reduce_mean(\n",
    "        loss_fn(y_true, probabilities)\n",
    "    ) + alpha * auxiliary_loss(\n",
    "        y_pred, auxiliary_loss_temperature=auxiliary_loss_temperature\n",
    "    )\n",
    "    return loss\n",
    "\n",
    "  return _loss\n",
    "\n",
    "\n",
    "\n",
    "def get_teacher_subclass_fn(\n",
    "    clustering_info,\n",
    "    teacher_label_temp=1.0,\n",
    "    soft_teacher_labels=False,\n",
    "    teacher_outer_label_temp=1.0,\n",
    "    num_classes=2,\n",
    "):\n",
    "  \"\"\"Returns a function that converts teacher logits to subclass labels.\n",
    "\n",
    "  Args:\n",
    "    clustering_info: parameters used to compute subclass labels.\n",
    "    teacher_label_temp: The temperature for subclass labels.\n",
    "    soft_teacher_labels: Whether to use soft superclass labels as well as soft\n",
    "      subclass inner labels.\n",
    "    teacher_outer_label_temp: teacher temperature for outer labels\n",
    "    num_classes: number of outer classes\n",
    "  Returns:\n",
    "    A function that converts teacher logits to subclass labels.\n",
    "\n",
    "  \"\"\"\n",
    "  if clustering_info is None:\n",
    "    return None\n",
    "\n",
    "  elif clustering_info['type'] == 'linear':\n",
    "    return get_projection_subclass_fn(\n",
    "        clustering_info,\n",
    "        teacher_label_temp,\n",
    "        soft_teacher_labels=soft_teacher_labels,\n",
    "        teacher_outer_label_temp=teacher_outer_label_temp,\n",
    "        num_classes=num_classes,\n",
    "    )\n",
    "\n",
    "  return None\n",
    "\n",
    "\n",
    "def get_projection_subclass_fn(\n",
    "    clustering_info,\n",
    "    teacher_label_temp=1.0,\n",
    "    soft_teacher_labels=False,\n",
    "    teacher_outer_label_temp=1.0,\n",
    "    num_classes=2,\n",
    "):\n",
    "  \"\"\"Returns a function that converts teacher logits to subclass labels.\n",
    "\n",
    "  For every class `c` the entries `subclass_info[c]['mean']`,\n",
    "  `subclass_info[c]['components']` and\n",
    "  `subclass_info[c]['variance_normalization']` are read. Embeddings are\n",
    "  centered on the class mean, projected onto the class components and divided\n",
    "  by the variance normalization times `teacher_label_temp`.\n",
    "\n",
    "  Args:\n",
    "    clustering_info: parameters used to compute subclass labels.\n",
    "    teacher_label_temp: The temperature for subclass labels.\n",
    "    soft_teacher_labels: Whether to use soft superclass labels as well as soft\n",
    "      subclass inner labels.\n",
    "    teacher_outer_label_temp: teacher temperature for outer labels\n",
    "    num_classes: number of outer classes\n",
    "\n",
    "  Returns:\n",
    "    A function `(teacher_logits, teacher_embeddings, onehot_labels=None)` that\n",
    "    returns the subclass logits flattened to two dimensions (batch first).\n",
    "  \"\"\"\n",
    "\n",
    "  subclass_info = clustering_info['subclass_info']\n",
    "\n",
    "  means = tf.stack(\n",
    "      [subclass_info[c]['mean'] for c in range(num_classes)], 1\n",
    "  )  # H x C\n",
    "  components = tf.stack(\n",
    "      [\n",
    "          tf.transpose(subclass_info[c]['components'])\n",
    "          for c in range(num_classes)\n",
    "      ],\n",
    "      1,\n",
    "  )  # H x C x S\n",
    "  variance_normalizations = tf.convert_to_tensor(\n",
    "      np.stack(\n",
    "          [\n",
    "              subclass_info[c]['variance_normalization']\n",
    "              for c in range(num_classes)\n",
    "          ],\n",
    "          -1,\n",
    "      ),\n",
    "      dtype=tf.float32,\n",
    "  )  # C\n",
    "\n",
    "  def projection_subclass_fn(\n",
    "      teacher_logits, teacher_embeddings, onehot_labels=None\n",
    "  ):\n",
    "    \"\"\"Maps teacher logits and embeddings to flattened subclass logits.\"\"\"\n",
    "    # Dynamic batch size: the static one may be unknown (None) when this\n",
    "    # function is traced, and None is not a valid reshape target.\n",
    "    batch_size = tf.shape(teacher_embeddings)[0]\n",
    "\n",
    "    linear_projection = (\n",
    "        teacher_embeddings[:, :, None] - means[None]\n",
    "    )  # B x H x C\n",
    "    linear_projection = tf.einsum(\n",
    "        'bhc, hcs->bcs', linear_projection, components\n",
    "    )  # bcs\n",
    "    linear_projection = linear_projection / (\n",
    "        variance_normalizations[None, :, None] * teacher_label_temp\n",
    "    )  # bcs\n",
    "\n",
    "    if soft_teacher_labels:\n",
    "      # Class log-probabilities plus within-class subclass log-probabilities.\n",
    "      teacher_logit_adjustment = tf.nn.log_softmax(\n",
    "          teacher_logits / teacher_outer_label_temp, axis=-1\n",
    "      )\n",
    "      stack_probs = tf.nn.log_softmax(linear_projection, axis=-1)\n",
    "      return teacher_outer_label_temp * tf.reshape(\n",
    "          teacher_logit_adjustment[:, :, None] + stack_probs,\n",
    "          [batch_size, -1],\n",
    "      )\n",
    "\n",
    "    if onehot_labels is None:\n",
    "      # The one-hot depth has to be the number of classes (not the batch\n",
    "      # size) so that the B x C mask broadcasts against the B x C x S\n",
    "      # projection.\n",
    "      one_hot_prediction = tf.one_hot(\n",
    "          tf.math.argmax(teacher_logits, axis=-1), num_classes\n",
    "      )\n",
    "    else:\n",
    "      one_hot_prediction = onehot_labels\n",
    "\n",
    "    # A large negative offset suppresses the subclasses of every class other\n",
    "    # than the selected one.\n",
    "    teacher_logit_adjustment = -99999 * (1 - one_hot_prediction)\n",
    "    return tf.reshape(\n",
    "        teacher_logit_adjustment[:, :, None] + linear_projection,\n",
    "        [batch_size, -1],\n",
    "    )\n",
    "\n",
    "  return projection_subclass_fn\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3Io7-UTq_dPN"
   },
   "source": [
    "# Model Definitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "executionInfo": {
     "elapsed": 910,
     "status": "ok",
     "timestamp": 1706806518365,
     "user": {
      "displayName": "",
      "userId": "07549615973646857260"
     },
     "user_tz": 300
    },
    "id": "Yz2k8PlV_np1"
   },
   "outputs": [],
   "source": [
    "\"\"\"A collection of Neural Net models to be used for experiments.\"\"\"\n",
    "\n",
    "from typing import Callable\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow_hub as hub\n",
    "\n",
    "\n",
    "\n",
    "# Registry of BERT encoders; it is left empty in this notebook.\n",
    "# NOTE(review): presumably looked up by get_bert_model -- confirm.\n",
    "BERT_ENCODERS = {\n",
    "}\n",
    "\n",
    "\n",
    "class Parameter(tf.keras.layers.Layer):\n",
    "  \"\"\"A wrapper class that creates a layer containing a trainable parameter.\n",
    "\n",
    "   Tensorflow requires all model outputs to be instances of\n",
    "   tf.keras.layers.Layers and this class provides a simple workaround to\n",
    "   create a trainable parameter that the model can output.\n",
    "   For its use see, e.g., the `vid_modify` option of the model builders in\n",
    "   this notebook.\n",
    "\n",
    "  Methods\n",
    "  ----------\n",
    "  build():\n",
    "      Creates a zero-initialized parameter tensor shaped like the non-batch\n",
    "      dimensions of the input.\n",
    "  call():\n",
    "      Returns the parameter tensor (broadcasted to match the dimensions of\n",
    "      the batch).\n",
    "  \"\"\"\n",
    "\n",
    "  def build(self, input_shape: tf.TensorShape) -> None:\n",
    "    \"\"\"Creates the parameter with the non-batch dimensions of input_shape.\"\"\"\n",
    "\n",
    "    # The name is passed by keyword: the position of `name` in `add_weight`\n",
    "    # differs between Keras versions, whereas the keyword works in all.\n",
    "    self.parameter = self.add_weight(\n",
    "        name='parameter',\n",
    "        shape=input_shape[1:],\n",
    "        initializer='zeros',\n",
    "        trainable=True,\n",
    "    )\n",
    "\n",
    "  def call(self, x: tf.Tensor) -> tf.Tensor:\n",
    "    \"\"\"Returns the parameter tensor (broadcasted to match batch dimensions).\"\"\"\n",
    "\n",
    "    # Only the shape of `x` is used; its values are ignored.\n",
    "    batch_size = tf.shape(x)[0]\n",
    "    expanded = tf.expand_dims(self.parameter, axis=0)\n",
    "    return tf.broadcast_to(expanded, [batch_size, tf.shape(x)[1]])\n",
    "\n",
    "\n",
    "def get_mobilenet(\n",
    "    depth_multiplier,\n",
    "    num_classes,\n",
    "    width_multiplier=1.0,\n",
    "    only_logits=False,\n",
    "    output_embeddings=False,\n",
    "    learnable_projection_layer_dimension: int | None = None,\n",
    "    vid_modify: bool = False,\n",
    "    input_shape: tuple[int, int, int] = (32, 32, 3),\n",
    "):\n",
    "  \"\"\"Builds a randomly initialized MobileNet classifier.\n",
    "\n",
    "  Args:\n",
    "    depth_multiplier: the depth_multiplier parameter for mobilenet\n",
    "    num_classes: the number of classes\n",
    "    width_multiplier: the width_multiplier for mobilenet\n",
    "    only_logits: Whether the class head is the logits (True) or the softmax\n",
    "      probabilities (False).\n",
    "    output_embeddings: Whether the embeddings are concatenated to the class\n",
    "      head in the model output.\n",
    "    learnable_projection_layer_dimension: The dimension of the learnable\n",
    "      projection matrix used for embedding distillation when there is a mismatch\n",
    "      between the teacher's and student's dimension of the penultimate layer.\n",
    "    vid_modify: Whether a trainable `Parameter` output shaped like the\n",
    "      embeddings is appended as well (only when embeddings are output).\n",
    "    input_shape: The input shape.\n",
    "\n",
    "  Returns:\n",
    "    Returns the mobilenet model.\n",
    "  \"\"\"\n",
    "  backbone = tf.keras.applications.mobilenet.MobileNet(\n",
    "      input_shape=input_shape,\n",
    "      alpha=width_multiplier,\n",
    "      depth_multiplier=depth_multiplier,\n",
    "      dropout=0.001,\n",
    "      include_top=True,\n",
    "      weights=None,\n",
    "      input_tensor=None,\n",
    "      pooling=None,\n",
    "      classes=num_classes,\n",
    "      classifier_activation='softmax',\n",
    "  )\n",
    "\n",
    "  # Both configurations start from the flattened output of the 'dropout' layer.\n",
    "  flat_penultimate = tf.keras.layers.Flatten()(\n",
    "      backbone.get_layer('dropout').output\n",
    "  )\n",
    "  if learnable_projection_layer_dimension is None:\n",
    "    # Reuse the backbone's own pre-softmax output as logits.\n",
    "    embeddings = flat_penultimate\n",
    "    logits = backbone.get_layer('reshape_2').output\n",
    "  else:\n",
    "    # Replace the head: project the embeddings, then classify the projection.\n",
    "    embeddings = tf.keras.layers.Dense(\n",
    "        learnable_projection_layer_dimension, kernel_initializer='he_normal'\n",
    "    )(flat_penultimate)\n",
    "    logits = tf.keras.layers.Dense(num_classes, kernel_initializer='he_normal')(\n",
    "        embeddings\n",
    "    )\n",
    "  probabilities = tf.keras.layers.Activation('softmax')(logits)\n",
    "\n",
    "  class_head = logits if only_logits else probabilities\n",
    "  if output_embeddings:\n",
    "    pieces = [class_head, embeddings]\n",
    "    if vid_modify:\n",
    "      pieces.append(Parameter()(embeddings))\n",
    "    outputs = tf.keras.backend.concatenate(pieces)\n",
    "  else:\n",
    "    outputs = class_head\n",
    "\n",
    "  # Instantiate model.\n",
    "  return tf.keras.models.Model(inputs=backbone.input, outputs=outputs)\n",
    "\n",
    "\n",
    "def resnet_layer(\n",
    "    inputs,\n",
    "    num_filters=16,\n",
    "    kernel_size=3,\n",
    "    strides=1,\n",
    "    activation='relu',\n",
    "    batch_normalization=True,\n",
    "    conv_first=True,\n",
    "    batch_norm_decay: float | None = None,\n",
    "    batch_norm_epsilon: float | None = None,\n",
    "):\n",
    "  \"\"\"Builds a Conv2D / BatchNormalization / Activation stack.\n",
    "\n",
    "  Args:\n",
    "    inputs: input tensor from input image or previous layer\n",
    "    num_filters: Conv2D number of filters\n",
    "    kernel_size: Conv2D square kernel dimensions\n",
    "    strides: Conv2D square stride dimensions\n",
    "    activation: activation name, or None to skip the activation\n",
    "    batch_normalization: whether to include batch normalization\n",
    "    conv_first: conv-bn-activation (True) or bn-activation-conv (False)\n",
    "    batch_norm_decay: Batch-norm momentum; only used when batch_norm_epsilon\n",
    "      is given too.\n",
    "    batch_norm_epsilon: Batch-norm epsilon; only used when batch_norm_decay\n",
    "      is given too.\n",
    "\n",
    "  Returns:\n",
    "   x: tensor as input to the next layer\n",
    "  \"\"\"\n",
    "  conv = tf.keras.layers.Conv2D(\n",
    "      num_filters,\n",
    "      kernel_size=kernel_size,\n",
    "      strides=strides,\n",
    "      padding='same',\n",
    "      kernel_initializer='he_normal',\n",
    "      kernel_regularizer=tf.keras.regularizers.l2(1e-4),\n",
    "  )\n",
    "\n",
    "  def _normalize_and_activate(tensor):\n",
    "    \"\"\"Applies the optional batch normalization and optional activation.\"\"\"\n",
    "    if batch_normalization:\n",
    "      custom_bn = (\n",
    "          batch_norm_decay is not None and batch_norm_epsilon is not None\n",
    "      )\n",
    "      if custom_bn:\n",
    "        norm = tf.keras.layers.BatchNormalization(\n",
    "            momentum=batch_norm_decay, epsilon=batch_norm_epsilon\n",
    "        )\n",
    "      else:\n",
    "        # Keras defaults are used unless both values were supplied.\n",
    "        norm = tf.keras.layers.BatchNormalization()\n",
    "      tensor = norm(tensor)\n",
    "    if activation is not None:\n",
    "      tensor = tf.keras.layers.Activation(activation)(tensor)\n",
    "    return tensor\n",
    "\n",
    "  if conv_first:\n",
    "    return _normalize_and_activate(conv(inputs))\n",
    "  return conv(_normalize_and_activate(inputs))\n",
    "\n",
    "\n",
    "def resnet_v2(\n",
    "    input_shape=(32, 32, 3),\n",
    "    depth=29,\n",
    "    num_classes=10,\n",
    "    data_augmentation=False,\n",
    "    only_logits=False,\n",
    "    output_embeddings=True,\n",
    "    learnable_projection_layer_dimension: int | None = None,\n",
    "    vid_modify: bool = False,\n",
    "    weight_decay: bool = False,\n",
    "    batch_norm_decay: float | None = None,\n",
    "    batch_norm_epsilon: float | None = None,\n",
    "):\n",
    "  \"\"\"ResNet Version 2 Model builder [a].\n",
    "\n",
    "  Args:\n",
    "    input_shape: shape of input image tensor\n",
    "    depth: number of core convolutional layers; must satisfy depth = 9n + 2\n",
    "    num_classes: number of classes\n",
    "    data_augmentation: Whether the inputs are resized to 36x36 and randomly\n",
    "      cropped to 32x32 before the first convolution.\n",
    "    only_logits: A boolean that determines whether the output of the model is\n",
    "      logits or probabilities.\n",
    "    output_embeddings: A boolean that determines whether the model outputs its\n",
    "      embeddings or not.\n",
    "    learnable_projection_layer_dimension: The dimension of the learnable\n",
    "      projection matrix used for embedding distillation when there is a mismatch\n",
    "      between the teacher's and student's dimension of the penultimate layer.\n",
    "    vid_modify: Whether we add the vid-modification for embeddings.\n",
    "    weight_decay: Whether the classification layer uses l2 regularization\n",
    "      (together with a RandomNormal kernel initializer).\n",
    "    batch_norm_decay: The value of batch_norm_decay\n",
    "    batch_norm_epsilon: The value of batch_norm_epsilon\n",
    "\n",
    "  Returns:\n",
    "    model (Model): Keras model instance\n",
    "\n",
    "  Raises:\n",
    "    ValueError: If `depth - 2` is not a multiple of 9.\n",
    "  \"\"\"\n",
    "  if (depth - 2) % 9 != 0:\n",
    "    raise ValueError('depth should be 9n+2 (eg 56 or 110 in [b])')\n",
    "\n",
    "  # Batch-norm settings shared by every resnet_layer call below.\n",
    "  bn_kwargs = dict(\n",
    "      batch_norm_decay=batch_norm_decay,\n",
    "      batch_norm_epsilon=batch_norm_epsilon,\n",
    "  )\n",
    "  num_filters_in = 16\n",
    "  num_res_blocks = int((depth - 2) / 9)\n",
    "  num_filters_out = 0\n",
    "\n",
    "  inputs = tf.keras.layers.Input(shape=input_shape)\n",
    "  x = inputs\n",
    "  if data_augmentation:\n",
    "    augment = tf.keras.Sequential([\n",
    "        tf.keras.layers.Resizing(36, 36),\n",
    "        tf.keras.layers.RandomCrop(height=32, width=32),\n",
    "    ])\n",
    "    x = augment(x)\n",
    "  # v2 performs Conv2D with BN-ReLU on input before splitting into 2 paths\n",
    "  x = resnet_layer(\n",
    "      inputs=x, num_filters=num_filters_in, conv_first=True, **bn_kwargs\n",
    "  )\n",
    "\n",
    "  # Three stages of bottleneck residual units.\n",
    "  for stage in range(3):\n",
    "    for res_block in range(num_res_blocks):\n",
    "      first_stage = stage == 0\n",
    "      first_block = res_block == 0\n",
    "      num_filters_out = num_filters_in * (4 if first_stage else 2)\n",
    "      # The very first unit has no BN/ReLU in front of its first convolution;\n",
    "      # the first unit of the later stages downsamples instead.\n",
    "      plain_entry = first_stage and first_block\n",
    "      activation = None if plain_entry else 'relu'\n",
    "      batch_normalization = not plain_entry\n",
    "      strides = 2 if (first_block and not first_stage) else 1\n",
    "\n",
    "      # Bottleneck: 1x1 -> 3x3 -> 1x1 convolutions, each bn-activation-conv.\n",
    "      y = resnet_layer(\n",
    "          inputs=x,\n",
    "          num_filters=num_filters_in,\n",
    "          kernel_size=1,\n",
    "          strides=strides,\n",
    "          activation=activation,\n",
    "          batch_normalization=batch_normalization,\n",
    "          conv_first=False,\n",
    "          **bn_kwargs,\n",
    "      )\n",
    "      y = resnet_layer(\n",
    "          inputs=y, num_filters=num_filters_in, conv_first=False, **bn_kwargs\n",
    "      )\n",
    "      y = resnet_layer(\n",
    "          inputs=y,\n",
    "          num_filters=num_filters_out,\n",
    "          kernel_size=1,\n",
    "          conv_first=False,\n",
    "          **bn_kwargs,\n",
    "      )\n",
    "      if first_block:\n",
    "        # Linear 1x1 projection of the shortcut so that its shape matches y.\n",
    "        x = resnet_layer(\n",
    "            inputs=x,\n",
    "            num_filters=num_filters_out,\n",
    "            kernel_size=1,\n",
    "            strides=strides,\n",
    "            activation=None,\n",
    "            batch_normalization=False,\n",
    "            **bn_kwargs,\n",
    "        )\n",
    "      x = tf.keras.layers.add([x, y])\n",
    "\n",
    "    num_filters_in = num_filters_out\n",
    "\n",
    "  # Classifier on top; v2 has BN-ReLU before pooling.\n",
    "  # NOTE(review): this BatchNormalization uses the Keras defaults and ignores\n",
    "  # batch_norm_decay / batch_norm_epsilon -- confirm that this is intended.\n",
    "  x = tf.keras.layers.BatchNormalization()(x)\n",
    "  x = tf.keras.layers.Activation('relu')(x)\n",
    "  x = tf.keras.layers.AveragePooling2D(pool_size=8)(x)\n",
    "\n",
    "  embeddings = tf.keras.layers.Flatten()(x)\n",
    "  if learnable_projection_layer_dimension is not None:\n",
    "    embeddings = tf.keras.layers.Dense(\n",
    "        learnable_projection_layer_dimension, kernel_initializer='he_normal'\n",
    "    )(embeddings)\n",
    "\n",
    "  if weight_decay:\n",
    "    classifier = tf.keras.layers.Dense(\n",
    "        num_classes,\n",
    "        kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),\n",
    "        kernel_regularizer=tf.keras.regularizers.l2(1e-4),\n",
    "        bias_regularizer=tf.keras.regularizers.l2(1e-4),\n",
    "    )\n",
    "  else:\n",
    "    classifier = tf.keras.layers.Dense(\n",
    "        num_classes, kernel_initializer='he_normal'\n",
    "    )\n",
    "  logits = classifier(embeddings)\n",
    "  probabilities = tf.keras.layers.Activation('softmax')(logits)\n",
    "\n",
    "  class_head = logits if only_logits else probabilities\n",
    "  if output_embeddings:\n",
    "    pieces = [class_head, embeddings]\n",
    "    if vid_modify:\n",
    "      pieces.append(Parameter()(embeddings))\n",
    "    outputs = tf.keras.backend.concatenate(pieces)\n",
    "  else:\n",
    "    outputs = class_head\n",
    "\n",
    "  # Instantiate model.\n",
    "  return tf.keras.models.Model(inputs=inputs, outputs=outputs)\n",
    "\n",
    "\n",
    "def get_bert_model(\n",
    "    encoder_size: str,\n",
    "    trainable: bool,\n",
    "    only_logits: bool = False,\n",
    "    num_classes: int = 2,\n",
    "    output_embeddings: bool = False,\n",
    "    learnable_projection_layer_dimension: int | None = None,\n",
    "    vid_modify: bool = False,\n",
    "    add_mlp: bool = False,\n",
    ") -> tf.keras.Model:\n",
    "  \"\"\"Returns an ALBERT model.\n",
    "\n",
    "  Args:\n",
    "    encoder_size: The encoder size\n",
    "    trainable: Whether the model is trainable or not.\n",
    "    only_logits: Whether we want to ouput only its logits.\n",
    "    num_classes: The number of classes.\n",
    "    output_embeddings: Whether we output the embeddings or not.\n",
    "    learnable_projection_layer_dimension: The dimension of the learnable\n",
    "      projection matrix used for embedding distillation when there is a mismatch\n",
    "      between the teacher's and student's dimension of the penultimate layer.\n",
    "    vid_modify: Whether we add the vid-modification for embeddings.\n",
    "    add_mlp: Whether to add MLP-layers after the BERT-embeddings.\n",
    "\n",
    "  Returns:\n",
    "    Returns an ALBERT model.\n",
    "  \"\"\"\n",
    "\n",
    "  def create_model():\n",
    "    bert_encoder = hub.KerasLayer(\n",
    "        BERT_ENCODERS[encoder_size], trainable=trainable, name='bert_encoder'\n",
    "    )\n",
    "    bert_inputs = dict(\n",
    "        input_word_ids=tf.keras.layers.Input(\n",
    "            shape=(None,), dtype=tf.int32, name='input_word_ids'\n",
    "        ),\n",
    "        input_mask=tf.keras.layers.Input(\n",
    "            shape=(None,), dtype=tf.int32, name='input_mask'\n",
    "        ),\n",
    "        input_type_ids=tf.keras.layers.Input(\n",
    "            shape=(None,), dtype=tf.int32, name='input_type_ids'\n",
    "        ),\n",
    "    )\n",
    "    bert_out = bert_encoder(bert_inputs)['pooled_output']\n",
    "    if add_mlp:\n",
    "      bert_out = tf.keras.layers.Dense(200, activation='relu')(bert_out)\n",
    "      bert_out = tf.keras.layers.Dense(100, activation='relu')(bert_out)\n",
    "      bert_out = tf.keras.layers.Dense(50, activation='relu')(bert_out)\n",
    "    if learnable_projection_layer_dimension is not None:\n",
    "      old_logits = tf.keras.layers.Dense(num_classes, activation=None)(bert_out)\n",
    "      model = tf.keras.Model(inputs=bert_inputs, outputs=old_logits)\n",
    "      penultimate = model.get_layer('bert_encoder').output['default']\n",
    "      old_embeddings = tf.keras.layers.Flatten()(penultimate)\n",
    "      embeddings = tf.keras.layers.Dense(\n",
    "          learnable_projection_layer_dimension, kernel_initializer='he_normal'\n",
    "      )(old_embeddings)\n",
    "      logits = tf.keras.layers.Dense(\n",
    "          num_classes, kernel_initializer='he_normal'\n",
    "      )(embeddings)\n",
    "      probabilities = tf.keras.layers.Activation('softmax')(logits)\n",
    "    else:\n",
    "      logits = tf.keras.layers.Dense(num_classes, activation=None)(bert_out)\n",
    "      probabilities = tf.keras.layers.Activation('softmax')(logits)\n",
    "      model = tf.keras.Model(inputs=bert_inputs, outputs=logits)\n",
    "      penultimate = model.get_layer('bert_encoder').output['default']\n",
    "      embeddings = tf.keras.layers.Flatten()(penultimate)\n",
    "\n",
    "    if only_logits and output_embeddings:\n",
    "      if vid_modify:\n",
    "        embeddings_scaling = Parameter()(embeddings)\n",
    "        outputs = tf.keras.backend.concatenate(\n",
    "            [logits, embeddings, embeddings_scaling]\n",
    "        )\n",
    "      else:\n",
    "        outputs = tf.keras.backend.concatenate([logits, embeddings])\n",
    "    elif not only_logits and output_embeddings:\n",
    "      if vid_modify:\n",
    "        embeddings_scaling = Parameter()(embeddings)\n",
    "        outputs = tf.keras.backend.concatenate(\n",
    "            [probabilities, embeddings, embeddings_scaling]\n",
    "        )\n",
    "      else:\n",
    "        outputs = tf.keras.backend.concatenate([probabilities, embeddings])\n",
    "    elif only_logits:\n",
    "      outputs = logits\n",
    "    else:\n",
    "      outputs = probabilities\n",
    "    bert_classifier = tf.keras.Model(bert_inputs, outputs)\n",
    "    return bert_classifier\n",
    "\n",
    "  bert_classifier = create_model()\n",
    "  return bert_classifier\n",
    "\n",
    "\n",
    "def get_compiled_model(\n",
    "    num_classes: int,\n",
    "    loss_fn: Callable[[tf.Tensor, tf.Tensor], tf.Tensor] | str,\n",
    "    metric_fn: Callable[[tf.Tensor, tf.Tensor], tf.Tensor] | str,\n",
    "    resnet_depth: int,\n",
    "    mobilenet_dm: int,\n",
    "    learning_rate: float,\n",
    "    decay_steps: int,\n",
    "    only_logits: bool,\n",
    "    output_embeddings: bool,\n",
    "    architecture: str = 'resnet',\n",
    "    subclass_teacher: bool = False,\n",
    "    num_subclasses: int | None = None,\n",
    "    subclass_teacher_temperature: float = 1.0,\n",
    "    auxiliary_loss_param: float = 0.1,\n",
    "    input_shape: tuple[int, int, int] = (32, 32, 3),\n",
    "    add_mlp: bool = False,\n",
    ") -> tf.keras.Model:\n",
    "  \"\"\"A method that returns a compiled model.\n",
    "\n",
    "  Args:\n",
    "    num_classes: The number of classes.\n",
    "    loss_fn: The loss function.\n",
    "    metric_fn: The metric function.\n",
    "    resnet_depth: The depth of the resnet model.\n",
    "    mobilenet_dm: The depth multiplier of mobilenet.\n",
    "    learning_rate: The learning rate.\n",
    "    decay_steps: The number of decay steps.\n",
    "    only_logits: A boolean determining whether the model outputs logits or\n",
    "      probabilites.\n",
    "    output_embeddings: A boolean determining whether the model outputs\n",
    "      embeddings or not.\n",
    "    architecture: A string with the name of the model architecture.\n",
    "    subclass_teacher: Whether this is a subclass teacher or not.\n",
    "    num_subclasses: The number of subclasses.\n",
    "    subclass_teacher_temperature: The temperature for subclass distillation.\n",
    "    auxiliary_loss_param: The paramater for the auxiliary loss.\n",
    "    input_shape: The input shape.\n",
    "    add_mlp: Whether to add MLP-layers to the BERT model.\n",
    "\n",
    "  Returns:\n",
    "    A compiled resnet model.\n",
    "  \"\"\"\n",
    "\n",
    "  output_dimension = num_subclasses if subclass_teacher else num_classes\n",
    "\n",
    "  if architecture == 'resnet':\n",
    "    model = resnet_v2(\n",
    "        depth=resnet_depth,\n",
    "        num_classes=output_dimension,\n",
    "        only_logits=only_logits,\n",
    "        output_embeddings=output_embeddings,\n",
    "        input_shape=input_shape,\n",
    "    )\n",
    "    lr_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(\n",
    "        learning_rate, decay_steps=decay_steps\n",
    "    )\n",
    "    opt = tf.keras.optimizers.SGD(\n",
    "        learning_rate=lr_decayed_fn, momentum=0.9, nesterov=False\n",
    "    )\n",
    "\n",
    "  elif architecture == 'mobilenet':\n",
    "    model = get_mobilenet(\n",
    "        depth_multiplier=mobilenet_dm,\n",
    "        num_classes=output_dimension,\n",
    "        width_multiplier=2.0,\n",
    "        only_logits=only_logits,\n",
    "        output_embeddings=output_embeddings,\n",
    "        input_shape=input_shape,\n",
    "    )\n",
    "    opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n",
    "  elif 'bert' in architecture:\n",
    "    model = get_bert_model(\n",
    "        encoder_size=architecture,\n",
    "        trainable=True,\n",
    "        only_logits=only_logits,\n",
    "        num_classes=output_dimension,\n",
    "        output_embeddings=output_embeddings,\n",
    "        add_mlp=add_mlp\n",
    "    )\n",
    "    opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n",
    "  else:\n",
    "    raise ValueError(f'student_architecture:{architecture} is not supported.')\n",
    "\n",
    "  if subclass_teacher:\n",
    "    if loss_fn == 'categorical_crossentropy':\n",
    "      loss_fn = tf.keras.losses.CategoricalCrossentropy(\n",
    "          reduction=tf.keras.losses.Reduction.NONE\n",
    "      )\n",
    "    elif loss_fn == 'binary_crossentropy':\n",
    "      loss_fn = tf.keras.losses.BinaryCrossentropy(\n",
    "          reduction=tf.keras.losses.Reduction.NONE\n",
    "      )\n",
    "    loss_function = subclass_teacher_loss(\n",
    "        alpha=auxiliary_loss_param,\n",
    "        loss_fn=loss_fn,\n",
    "        num_classes=num_classes,\n",
    "        num_subclasses=num_subclasses,\n",
    "        auxiliary_loss_temperature=subclass_teacher_temperature,\n",
    "    )\n",
    "    model = SubclassTeacher(\n",
    "        teacher_model=model,\n",
    "        num_classes=num_classes,\n",
    "        num_subclasses=num_subclasses,\n",
    "        loss_fn=loss_function,\n",
    "        metric_fn=metric_fn,\n",
    "    )\n",
    "    model.compile(optimizer=opt)\n",
    "  else:\n",
    "    model.compile(loss=loss_fn, optimizer=opt, metrics=metric_fn)\n",
    "  return model\n",
    "\n",
    "\n",
    "def get_compiled_distilled_student_model(\n",
    "    teacher_model: tf.keras.Model,\n",
    "    params,\n",
    "    student_architecture: str = 'resnet',\n",
    "    num_subclasses: int | None = None,\n",
    "    dynamic_subclasses: bool = False,\n",
    "    learnable_projection_dimension: int | None = None,\n",
    "    vid_modify: bool = False,\n",
    "    use_fitnet: bool = False,\n",
    "    decomposed_logits: bool = False,\n",
    "    input_shape: tuple[int, int, int] = (32, 32, 3),\n",
    "    add_mlp: bool = False,\n",
    "):\n",
    "  \"\"\"A method that returns a compiled distilled resnet model.\n",
    "\n",
    "  Args:\n",
    "    teacher_model: The teacher mdoel\n",
    "    params: The parameters for distillation.\n",
    "    student_architecture: The student architecture, e.g. 'resnet' or\n",
    "      'mobilenet'.\n",
    "    num_subclasses: The number of subclasses of the dataset.\n",
    "    dynamic_subclasses: whether to use the PCA distiller\n",
    "    learnable_projection_dimension: The dimension of the learnable projection\n",
    "      matrix used for embedding distillation when there is a mismatch between\n",
    "      the teacher's and student's dimension of the penultimate layer.\n",
    "    vid_modify: Whether we add the vid-modification for embeddings.\n",
    "    use_fitnet: Whether to use the fitnet approach.\n",
    "    decomposed_logits: use the logit adjustment thing\n",
    "    input_shape: The input shape.\n",
    "    add_mlp: Whether to add MLP-layers to the BERT model.\n",
    "\n",
    "  Returns:\n",
    "     The compiled student model (tf.keras.Model)\n",
    "  \"\"\"\n",
    "\n",
    "  num_classes = (\n",
    "      num_subclasses if num_subclasses is not None else params.num_classes\n",
    "  )\n",
    "\n",
    "  if decomposed_logits:\n",
    "    num_classes = params.num_classes + num_subclasses\n",
    "\n",
    "  if student_architecture == 'resnet':\n",
    "    student_model = resnet_v2(\n",
    "        depth=params.resnet_depth,\n",
    "        num_classes=num_classes,\n",
    "        only_logits=True,\n",
    "        output_embeddings=True,\n",
    "        learnable_projection_layer_dimension=learnable_projection_dimension,\n",
    "        vid_modify=vid_modify,\n",
    "        input_shape=input_shape,\n",
    "    )\n",
    "    lr_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(\n",
    "        initial_learning_rate=params.learning_rate,\n",
    "        decay_steps=params.decay_steps,\n",
    "        alpha=1e-6,\n",
    "    )\n",
    "    optimizer = tf.keras.optimizers.SGD(\n",
    "        learning_rate=lr_decayed_fn, momentum=0.9, nesterov=True\n",
    "    )\n",
    "  elif student_architecture == 'mobilenet':\n",
    "    student_model = get_mobilenet(\n",
    "        depth_multiplier=1,\n",
    "        num_classes=num_classes,\n",
    "        width_multiplier=1.0,\n",
    "        only_logits=True,\n",
    "        output_embeddings=True,\n",
    "        learnable_projection_layer_dimension=learnable_projection_dimension,\n",
    "        vid_modify=vid_modify,\n",
    "        input_shape=input_shape,\n",
    "    )\n",
    "    optimizer = tf.keras.optimizers.Adam(learning_rate=params.learning_rate)\n",
    "  elif 'bert' in student_architecture:\n",
    "    student_model = get_bert_model(\n",
    "        encoder_size=student_architecture,\n",
    "        trainable=True,\n",
    "        only_logits=True,\n",
    "        num_classes=num_classes,\n",
    "        output_embeddings=True,\n",
    "        learnable_projection_layer_dimension=learnable_projection_dimension,\n",
    "        vid_modify=vid_modify,\n",
    "        add_mlp=add_mlp\n",
    "    )\n",
    "    optimizer = tf.keras.optimizers.Adam(learning_rate=params.learning_rate)\n",
    "  else:\n",
    "    raise ValueError(\n",
    "        f'student_architecture:{student_architecture} is not supported.'\n",
    "    )\n",
    "\n",
    "  if num_subclasses is None:\n",
    "    if vid_modify:\n",
    "      compiled_distilled_model = DistillerVID(\n",
    "          teacher_model=teacher_model,\n",
    "          student_model=student_model,\n",
    "          params=params,\n",
    "      )\n",
    "    elif use_fitnet:\n",
    "      compiled_distilled_model = FitNetDistiller(\n",
    "          teacher_model=teacher_model,\n",
    "          student_model=student_model,\n",
    "          params=params,\n",
    "      )\n",
    "    else:\n",
    "      compiled_distilled_model = Distiller(\n",
    "          teacher_model=teacher_model,\n",
    "          student_model=student_model,\n",
    "          params=params,\n",
    "      )\n",
    "  elif dynamic_subclasses:\n",
    "    if decomposed_logits:\n",
    "      compiled_distilled_model = DynamicSubclassAdjustmentDistiller(\n",
    "          student_model=student_model,\n",
    "          teacher_model=teacher_model,\n",
    "          params=params,\n",
    "          num_subclasses=num_subclasses,\n",
    "      )\n",
    "    else:\n",
    "      compiled_distilled_model = DynamicSubclassDistiller(\n",
    "          student_model=student_model,\n",
    "          teacher_model=teacher_model,\n",
    "          params=params,\n",
    "          num_subclasses=num_subclasses,\n",
    "      )\n",
    "  else:\n",
    "    compiled_distilled_model = SubclassDistiller(\n",
    "        student_model=student_model,\n",
    "        params=params,\n",
    "        num_subclasses=num_subclasses,\n",
    "    )\n",
    "  compiled_distilled_model.compile(\n",
    "      optimizer=optimizer,\n",
    "      params=params,\n",
    "  )\n",
    "\n",
    "  return compiled_distilled_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pvdMXypiB3U5"
   },
   "source": [
    "# Training Methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "executionInfo": {
     "elapsed": 139,
     "status": "ok",
     "timestamp": 1706798801472,
     "user": {
      "displayName": "",
      "userId": "07549615973646857260"
     },
     "user_tz": 300
    },
    "id": "4GqLOWgFB6Y5"
   },
   "outputs": [],
   "source": [
    "\"\"\"A couple of standard helpful functions related to Neural Networks training.\"\"\"\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "def _train_model(\n",
    "    model: tf.keras.Model,\n",
    "    train_x: tf.Tensor,\n",
    "    train_y: tf.Tensor,\n",
    "    test_x: tf.Tensor,\n",
    "    test_y: tf.Tensor,\n",
    "    data_augmentation: bool|None = False,\n",
    "    weights: tf.Tensor|None = None,\n",
    "    epochs: int|None = 200,\n",
    "    with_lr_scheduler: bool|None = True,\n",
    "    batch_size: int|None = 128,\n",
    "    epochs_offset: int|None = 0) -> tf.keras.callbacks.History:\n",
    "  \"\"\"Trains a model.\n",
    "\n",
    "  Args:\n",
    "    model: Instance of tf.keras.Model.\n",
    "    train_x: A dataset containing the data to train on.\n",
    "    train_y: A dataset containing the labels of the data to train on.\n",
    "    test_x: A dataset containing the data to test on.\n",
    "    test_y: A dataset containing the labels of the data to test on.\n",
    "    data_augmentation: True if data augmentation is used and False otherwise.\n",
    "    weights: The weight sample-weights to be used.\n",
    "    epochs: The number of epochs to train for.\n",
    "    with_lr_scheduler: Whether we use a schedule for the learning-rate or not.\n",
    "    batch_size: The batch size used for training.\n",
    "    epochs_offset: How many epochs we assume we have performed so far.\n",
    "\n",
    "  Returns:\n",
    "      history: History of trained model.\n",
    "\n",
    "  Raises:\n",
    "    NotImplementedError: if online data augmention is used with weights or if\n",
    "    an augmentation method other than None, 'offline', or 'online' is provided.\n",
    "    To use weights with online data augmentation, weights should be passed as\n",
    "    an extra advice label and a corresponding loss with advice should be used,\n",
    "    see the loss functions defined in loss_functions.py.\n",
    "  \"\"\"\n",
    "\n",
    "  def lr_schedule(epoch):\n",
    "    lr = 1e-3\n",
    "    if epoch + epochs_offset > 180:\n",
    "      lr *= 0.5e-3\n",
    "    elif epoch + epochs_offset > 160:\n",
    "      lr *= 1e-3\n",
    "    elif epoch + epochs_offset > 120:\n",
    "      lr *= 1e-2\n",
    "    elif epoch + epochs_offset > 80:\n",
    "      lr *= 1e-1\n",
    "    return lr\n",
    "\n",
    "  callbacks = None\n",
    "  if with_lr_scheduler:\n",
    "    lr_scheduler = tf.keras.callbacks.LearningRateScheduler(lr_schedule)\n",
    "    callbacks = [lr_scheduler]\n",
    "\n",
    "  with_weights = False\n",
    "  if weights is not None:\n",
    "    with_weights = True\n",
    "    train_dataset = tf.data.Dataset.from_tensor_slices(\n",
    "        (train_x, train_y, weights))\n",
    "  else:\n",
    "    with_weights = False\n",
    "    train_dataset = tf.data.Dataset.from_tensor_slices(\n",
    "        (train_x, train_y))\n",
    "\n",
    "  def prepare(ds, shuffle=False, augment=False, with_weights=False):\n",
    "    autotune = tf.data.AUTOTUNE\n",
    "\n",
    "    data_augmentation_layer = tf.keras.Sequential([\n",
    "        tf.keras.layers.RandomFlip('horizontal'),\n",
    "        tf.keras.layers.RandomTranslation(\n",
    "            height_factor=0.1,\n",
    "            width_factor=0.1,\n",
    "            fill_mode='nearest',\n",
    "            interpolation='bilinear',\n",
    "            seed=None,\n",
    "            fill_value=0.0)\n",
    "    ])\n",
    "\n",
    "    if shuffle:\n",
    "      ds = ds.shuffle(len(train_x), reshuffle_each_iteration=True)\n",
    "\n",
    "    # Use data augmentation only on the training set.\n",
    "    if augment:\n",
    "      if with_weights:\n",
    "        ds = ds.map(\n",
    "            lambda x, y, w: (data_augmentation_layer(x, training=True), y, w),\n",
    "            num_parallel_calls=autotune)\n",
    "      else:\n",
    "        ds = ds.map(\n",
    "            lambda x, y: (data_augmentation_layer(x, training=True), y),\n",
    "            num_parallel_calls=autotune)\n",
    "\n",
    "    # Batch the dataset. The drop_remainder is needed to avoid uneven batches.\n",
    "    ds = ds.batch(batch_size, drop_remainder=True)\n",
    "\n",
    "    # Use buffered prefetching on all datasets.\n",
    "    ds = ds.prefetch(buffer_size=autotune)\n",
    "\n",
    "    return ds\n",
    "\n",
    "  train_dataset = prepare(\n",
    "      train_dataset,\n",
    "      shuffle=True,\n",
    "      augment=data_augmentation,\n",
    "      with_weights=with_weights)\n",
    "\n",
    "  history = model.fit(\n",
    "      train_dataset,\n",
    "      validation_data=(test_x, test_y),\n",
    "      epochs=epochs,\n",
    "      verbose=1,\n",
    "      workers=4,\n",
    "      callbacks=callbacks)\n",
    "\n",
    "  return history\n",
    "\n",
    "\n",
    "def train_model(\n",
    "    model: tf.keras.Model,\n",
    "    train_x: tf.Tensor,\n",
    "    train_y: tf.Tensor,\n",
    "    test_x: tf.Tensor,\n",
    "    test_y: tf.Tensor,\n",
    "    data_augmentation: str|None = None,\n",
    "    weights: tf.Tensor|None = None,\n",
    "    epochs: int|None = 200,\n",
    "    with_lr_scheduler: bool|None = False,\n",
    "    batch_size: int|None = 128,\n",
    "    epochs_offset: int|None = 0) -> tf.keras.callbacks.History:\n",
    "  \"\"\"Trains a model.\n",
    "\n",
    "  Args:\n",
    "    model: Instance of tf.keras.Model.\n",
    "    train_x: A dataset containing the data to train on.\n",
    "    train_y: A dataset containing the labels of the data to train on.\n",
    "    test_x: A dataset containing the data to test on.\n",
    "    test_y: A dataset containing the labels of the data to test on.\n",
    "    data_augmentation: Can be either 'no', 'offline', 'online'.\n",
    "    weights: The weight sample-weights to be used.\n",
    "    epochs: The number of epochs to train for.\n",
    "    with_lr_scheduler: Whether we use a schedule for the learning-rate or not.\n",
    "    batch_size: The batch size used for training.\n",
    "    epochs_offset: How many epochs we assume we have performed so far.\n",
    "\n",
    "  Returns:\n",
    "      history: History of trained model.\n",
    "  \"\"\"\n",
    "\n",
    "  if data_augmentation != 'online':\n",
    "\n",
    "    if data_augmentation == 'no' or not data_augmentation:\n",
    "      use_data_augmentation = False\n",
    "    else:\n",
    "      use_data_augmentation = True\n",
    "\n",
    "    return _train_model(\n",
    "        model=model,\n",
    "        train_x=train_x,\n",
    "        train_y=train_y,\n",
    "        test_x=test_x,\n",
    "        test_y=test_y,\n",
    "        epochs=epochs,\n",
    "        data_augmentation=use_data_augmentation,\n",
    "        weights=weights,\n",
    "        with_lr_scheduler=with_lr_scheduler,\n",
    "        batch_size=batch_size,\n",
    "        epochs_offset=epochs_offset)\n",
    "\n",
    "  elif data_augmentation == 'online':\n",
    "    results = []\n",
    "\n",
    "    for i in range(1, epochs+1):\n",
    "\n",
    "      print(f'Epoch {i}/{epochs}')\n",
    "\n",
    "      history_tmp = _train_model(\n",
    "          model=model,\n",
    "          train_x=train_x,\n",
    "          train_y=train_y,\n",
    "          test_x=test_x,\n",
    "          test_y=test_y,\n",
    "          epochs=1,\n",
    "          data_augmentation=True,\n",
    "          weights=weights,\n",
    "          with_lr_scheduler=with_lr_scheduler,\n",
    "          batch_size=batch_size,\n",
    "          epochs_offset=i+epochs_offset)\n",
    "\n",
    "      results.append(history_tmp)\n",
    "\n",
    "    history = results[0]\n",
    "\n",
    "    for hist in results[1:]:\n",
    "      for key in hist.history.keys():\n",
    "        history.history[key] += hist.history[key]\n",
    "    return history\n",
    "\n",
    "  else:\n",
    "    raise NotImplementedError()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GXN7PV-CFapj"
   },
   "source": [
    "# Load the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "executionInfo": {
     "elapsed": 15127,
     "status": "ok",
     "timestamp": 1706798820713,
     "user": {
      "displayName": "",
      "userId": "07549615973646857260"
     },
     "user_tz": 300
    },
    "id": "JVBiDl2WNr0A"
   },
   "outputs": [],
   "source": [
    "if _DATASET_NAME == 'cifar10':\n",
    "    size = 50000\n",
    "    num_classes = 10\n",
    "    text_data = None\n",
    "elif _DATASET_NAME == 'cifar100':\n",
    "    size = 50000\n",
    "    num_classes = 100\n",
    "    text_data = None\n",
    "elif _DATASET_NAME == 'cifar10bin':\n",
    "    size = 50000\n",
    "    num_classes = 2\n",
    "    text_data = None\n",
    "elif _DATASET_NAME == 'cifar100bin':\n",
    "    size = 50000\n",
    "    num_classes = 2\n",
    "    text_data = None\n",
    "elif _DATASET_NAME == 'celeb_a':\n",
    "    size = 162770\n",
    "    num_classes = 2\n",
    "    text_data = None\n",
    "elif _DATASET_NAME == 'imdb_reviews':\n",
    "    size = 25000\n",
    "    num_classes = 2\n",
    "    text_data = 'text'\n",
    "elif _DATASET_NAME == 'glue_cola':\n",
    "    size = 8551\n",
    "    num_classes = 2\n",
    "    text_data = 'sentence'\n",
    "elif _DATASET_NAME == 'glue_sst2':\n",
    "    size = 67349\n",
    "    num_classes = 2\n",
    "    text_data = 'sentence'\n",
    "else:\n",
    "    raise ValueError(f'Unsupported dataset name: {_DATASET_NAME}')\n",
    "\n",
    "train_x, train_y, test_x, test_y = get_data(\n",
    "      0, size, _DATASET_NAME\n",
    "  )\n",
    "\n",
    "train_set, test_dataset, teacher_train_set = None, None, None\n",
    "if 'bert' in _TEACHER_ARCHITECTURE:\n",
    "    train_set = prepare_nlp_dataset(\n",
    "        examples=train_x,\n",
    "        labels=train_y,\n",
    "        batch_size=_TEACHER_BATCH_SIZE,\n",
    "        text_data=text_data,\n",
    "    )\n",
    "    test_dataset = prepare_nlp_dataset(\n",
    "        examples=test_x,\n",
    "        labels=test_y,\n",
    "        batch_size=1,\n",
    "        text_data=text_data,\n",
    "    )\n",
    "\n",
    "if _NUMBER_OF_LABELED_EXAMPLES is not None:\n",
    "    teacher_train_x, teacher_train_y, _, _ = get_data(\n",
    "        0, _NUMBER_OF_LABELED_EXAMPLES, _DATASET_NAME\n",
    "    )\n",
    "    if 'bert' in _TEACHER_ARCHITECTURE:\n",
    "      teacher_train_set = prepare_nlp_dataset(\n",
    "          examples=teacher_train_x,\n",
    "          labels=teacher_train_y,\n",
    "          batch_size=_TEACHER_BATCH_SIZE,\n",
    "          text_data=text_data,\n",
    "      )\n",
    "else:\n",
    "    teacher_train_x = train_x\n",
    "    teacher_train_y = train_y\n",
    "    if 'bert' in _TEACHER_ARCHITECTURE:\n",
    "      teacher_train_set = train_set\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yF-piHz9FT3L"
   },
   "source": [
    "# Train the teacher"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "executionInfo": {
     "elapsed": 1,
     "status": "ok",
     "timestamp": 1706806214464,
     "user": {
      "displayName": "",
      "userId": "07549615973646857260"
     },
     "user_tz": 300
    },
    "id": "0axbvRGZFVi_",
    "outputId": "bbca2897-f0ff-4c6f-b351-a0bbe662b4b4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/200\n",
      "195/195 [==============================] - 56s 175ms/step - loss: 5.7448 - categorical_accuracy: 0.0894 - val_loss: 5.7591 - val_categorical_accuracy: 0.0815\n",
      "Epoch 2/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 5.0436 - categorical_accuracy: 0.1771 - val_loss: 5.4692 - val_categorical_accuracy: 0.1132\n",
      "Epoch 3/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 4.5204 - categorical_accuracy: 0.2474 - val_loss: 5.1557 - val_categorical_accuracy: 0.1508\n",
      "Epoch 4/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 4.1025 - categorical_accuracy: 0.3108 - val_loss: 4.3237 - val_categorical_accuracy: 0.2568\n",
      "Epoch 5/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 3.7523 - categorical_accuracy: 0.3606 - val_loss: 4.1179 - val_categorical_accuracy: 0.2857\n",
      "Epoch 6/200\n",
      "195/195 [==============================] - 32s 162ms/step - loss: 3.4495 - categorical_accuracy: 0.4087 - val_loss: 4.1721 - val_categorical_accuracy: 0.2905\n",
      "Epoch 7/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 3.1887 - categorical_accuracy: 0.4523 - val_loss: 4.6846 - val_categorical_accuracy: 0.2669\n",
      "Epoch 8/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 2.9910 - categorical_accuracy: 0.4845 - val_loss: 3.8664 - val_categorical_accuracy: 0.3365\n",
      "Epoch 9/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 2.7977 - categorical_accuracy: 0.5149 - val_loss: 3.4054 - val_categorical_accuracy: 0.4007\n",
      "Epoch 10/200\n",
      "195/195 [==============================] - 32s 161ms/step - loss: 2.6462 - categorical_accuracy: 0.5385 - val_loss: 3.1341 - val_categorical_accuracy: 0.4538\n",
      "Epoch 11/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 2.4977 - categorical_accuracy: 0.5657 - val_loss: 3.0583 - val_categorical_accuracy: 0.4500\n",
      "Epoch 12/200\n",
      "195/195 [==============================] - 32s 162ms/step - loss: 2.3811 - categorical_accuracy: 0.5839 - val_loss: 2.9800 - val_categorical_accuracy: 0.4578\n",
      "Epoch 13/200\n",
      "195/195 [==============================] - 32s 165ms/step - loss: 2.2725 - categorical_accuracy: 0.6020 - val_loss: 2.9427 - val_categorical_accuracy: 0.4571\n",
      "Epoch 14/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 2.1606 - categorical_accuracy: 0.6211 - val_loss: 2.7486 - val_categorical_accuracy: 0.4854\n",
      "Epoch 15/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 2.0855 - categorical_accuracy: 0.6340 - val_loss: 4.0821 - val_categorical_accuracy: 0.3370\n",
      "Epoch 16/200\n",
      "195/195 [==============================] - 33s 167ms/step - loss: 1.9940 - categorical_accuracy: 0.6483 - val_loss: 2.7954 - val_categorical_accuracy: 0.4880\n",
      "Epoch 17/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 1.9143 - categorical_accuracy: 0.6635 - val_loss: 2.8352 - val_categorical_accuracy: 0.4796\n",
      "Epoch 18/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 1.8554 - categorical_accuracy: 0.6752 - val_loss: 2.7616 - val_categorical_accuracy: 0.4921\n",
      "Epoch 19/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 1.7943 - categorical_accuracy: 0.6848 - val_loss: 2.4873 - val_categorical_accuracy: 0.5222\n",
      "Epoch 20/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 1.7205 - categorical_accuracy: 0.6996 - val_loss: 2.9671 - val_categorical_accuracy: 0.4620\n",
      "Epoch 21/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 1.6810 - categorical_accuracy: 0.7050 - val_loss: 2.9571 - val_categorical_accuracy: 0.4953\n",
      "Epoch 22/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 1.6369 - categorical_accuracy: 0.7136 - val_loss: 2.7917 - val_categorical_accuracy: 0.4874\n",
      "Epoch 23/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 1.5908 - categorical_accuracy: 0.7223 - val_loss: 2.3695 - val_categorical_accuracy: 0.5638\n",
      "Epoch 24/200\n",
      "195/195 [==============================] - 32s 165ms/step - loss: 1.5408 - categorical_accuracy: 0.7301 - val_loss: 2.4403 - val_categorical_accuracy: 0.5418\n",
      "Epoch 25/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 1.5157 - categorical_accuracy: 0.7369 - val_loss: 2.6528 - val_categorical_accuracy: 0.5053\n",
      "Epoch 26/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 1.4788 - categorical_accuracy: 0.7428 - val_loss: 2.4799 - val_categorical_accuracy: 0.5435\n",
      "Epoch 27/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 1.4389 - categorical_accuracy: 0.7531 - val_loss: 2.2651 - val_categorical_accuracy: 0.5748\n",
      "Epoch 28/200\n",
      "195/195 [==============================] - 32s 162ms/step - loss: 1.4124 - categorical_accuracy: 0.7584 - val_loss: 2.9580 - val_categorical_accuracy: 0.4939\n",
      "Epoch 29/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 1.3806 - categorical_accuracy: 0.7654 - val_loss: 2.3549 - val_categorical_accuracy: 0.5556\n",
      "Epoch 30/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 1.3409 - categorical_accuracy: 0.7768 - val_loss: 2.4599 - val_categorical_accuracy: 0.5617\n",
      "Epoch 31/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 1.3355 - categorical_accuracy: 0.7739 - val_loss: 3.2891 - val_categorical_accuracy: 0.4659\n",
      "Epoch 32/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 1.3064 - categorical_accuracy: 0.7826 - val_loss: 2.9479 - val_categorical_accuracy: 0.4821\n",
      "Epoch 33/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 1.2796 - categorical_accuracy: 0.7881 - val_loss: 2.7961 - val_categorical_accuracy: 0.5251\n",
      "Epoch 34/200\n",
      "195/195 [==============================] - 32s 161ms/step - loss: 1.2522 - categorical_accuracy: 0.7954 - val_loss: 2.7174 - val_categorical_accuracy: 0.5054\n",
      "Epoch 35/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 1.2410 - categorical_accuracy: 0.7986 - val_loss: 2.7863 - val_categorical_accuracy: 0.5445\n",
      "Epoch 36/200\n",
      "195/195 [==============================] - 32s 162ms/step - loss: 1.2194 - categorical_accuracy: 0.8036 - val_loss: 2.4809 - val_categorical_accuracy: 0.5622\n",
      "Epoch 37/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 1.2030 - categorical_accuracy: 0.8062 - val_loss: 2.6783 - val_categorical_accuracy: 0.5372\n",
      "Epoch 38/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 1.1814 - categorical_accuracy: 0.8131 - val_loss: 2.6395 - val_categorical_accuracy: 0.5331\n",
      "Epoch 39/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 1.1575 - categorical_accuracy: 0.8205 - val_loss: 2.8548 - val_categorical_accuracy: 0.5299\n",
      "Epoch 40/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 1.1338 - categorical_accuracy: 0.8282 - val_loss: 2.7292 - val_categorical_accuracy: 0.5600\n",
      "Epoch 41/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 1.1371 - categorical_accuracy: 0.8241 - val_loss: 3.2221 - val_categorical_accuracy: 0.4917\n",
      "Epoch 42/200\n",
      "195/195 [==============================] - 31s 161ms/step - loss: 1.1238 - categorical_accuracy: 0.8305 - val_loss: 3.4358 - val_categorical_accuracy: 0.4905\n",
      "Epoch 43/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 1.0916 - categorical_accuracy: 0.8387 - val_loss: 2.5938 - val_categorical_accuracy: 0.5700\n",
      "Epoch 44/200\n",
      "195/195 [==============================] - 32s 162ms/step - loss: 1.0902 - categorical_accuracy: 0.8375 - val_loss: 2.8462 - val_categorical_accuracy: 0.5260\n",
      "Epoch 45/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 1.0850 - categorical_accuracy: 0.8407 - val_loss: 3.3749 - val_categorical_accuracy: 0.4943\n",
      "Epoch 46/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 1.0543 - categorical_accuracy: 0.8502 - val_loss: 2.9025 - val_categorical_accuracy: 0.5238\n",
      "Epoch 47/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 1.0579 - categorical_accuracy: 0.8479 - val_loss: 3.0998 - val_categorical_accuracy: 0.5293\n",
      "Epoch 48/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 1.0290 - categorical_accuracy: 0.8581 - val_loss: 2.5392 - val_categorical_accuracy: 0.5908\n",
      "Epoch 49/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 1.0279 - categorical_accuracy: 0.8544 - val_loss: 2.3831 - val_categorical_accuracy: 0.5941\n",
      "Epoch 50/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 1.0190 - categorical_accuracy: 0.8594 - val_loss: 2.7748 - val_categorical_accuracy: 0.5368\n",
      "Epoch 51/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 1.0018 - categorical_accuracy: 0.8647 - val_loss: 3.0172 - val_categorical_accuracy: 0.5320\n",
      "Epoch 52/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.9976 - categorical_accuracy: 0.8654 - val_loss: 2.3453 - val_categorical_accuracy: 0.6037\n",
      "Epoch 53/200\n",
      "195/195 [==============================] - 33s 167ms/step - loss: 0.9768 - categorical_accuracy: 0.8712 - val_loss: 3.4118 - val_categorical_accuracy: 0.4977\n",
      "Epoch 54/200\n",
      "195/195 [==============================] - 32s 162ms/step - loss: 0.9789 - categorical_accuracy: 0.8724 - val_loss: 2.4522 - val_categorical_accuracy: 0.6025\n",
      "Epoch 55/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.9566 - categorical_accuracy: 0.8778 - val_loss: 2.4867 - val_categorical_accuracy: 0.5941\n",
      "Epoch 56/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.9513 - categorical_accuracy: 0.8784 - val_loss: 2.5538 - val_categorical_accuracy: 0.5849\n",
      "Epoch 57/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.9319 - categorical_accuracy: 0.8846 - val_loss: 3.2378 - val_categorical_accuracy: 0.5258\n",
      "Epoch 58/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.9395 - categorical_accuracy: 0.8830 - val_loss: 2.7459 - val_categorical_accuracy: 0.5643\n",
      "Epoch 59/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.9266 - categorical_accuracy: 0.8856 - val_loss: 2.8431 - val_categorical_accuracy: 0.5807\n",
      "Epoch 60/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.9119 - categorical_accuracy: 0.8891 - val_loss: 2.6200 - val_categorical_accuracy: 0.5818\n",
      "Epoch 61/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.9006 - categorical_accuracy: 0.8937 - val_loss: 2.8765 - val_categorical_accuracy: 0.5779\n",
      "Epoch 62/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.8930 - categorical_accuracy: 0.8949 - val_loss: 2.5773 - val_categorical_accuracy: 0.5821\n",
      "Epoch 63/200\n",
      "195/195 [==============================] - 32s 162ms/step - loss: 0.8761 - categorical_accuracy: 0.9011 - val_loss: 2.8994 - val_categorical_accuracy: 0.5437\n",
      "Epoch 64/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.8786 - categorical_accuracy: 0.8999 - val_loss: 3.2308 - val_categorical_accuracy: 0.5127\n",
      "Epoch 65/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.8746 - categorical_accuracy: 0.9000 - val_loss: 2.6787 - val_categorical_accuracy: 0.6061\n",
      "Epoch 66/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.8606 - categorical_accuracy: 0.9036 - val_loss: 2.7465 - val_categorical_accuracy: 0.5829\n",
      "Epoch 67/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.8539 - categorical_accuracy: 0.9062 - val_loss: 2.9895 - val_categorical_accuracy: 0.5777\n",
      "Epoch 68/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.8331 - categorical_accuracy: 0.9120 - val_loss: 2.3815 - val_categorical_accuracy: 0.6344\n",
      "Epoch 69/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.8237 - categorical_accuracy: 0.9144 - val_loss: 2.9570 - val_categorical_accuracy: 0.5773\n",
      "Epoch 70/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.8111 - categorical_accuracy: 0.9166 - val_loss: 2.6938 - val_categorical_accuracy: 0.5907\n",
      "Epoch 71/200\n",
      "195/195 [==============================] - 32s 165ms/step - loss: 0.8051 - categorical_accuracy: 0.9185 - val_loss: 2.8824 - val_categorical_accuracy: 0.5749\n",
      "Epoch 72/200\n",
      "195/195 [==============================] - 32s 162ms/step - loss: 0.8157 - categorical_accuracy: 0.9135 - val_loss: 2.9242 - val_categorical_accuracy: 0.5788\n",
      "Epoch 73/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.7924 - categorical_accuracy: 0.9213 - val_loss: 2.5667 - val_categorical_accuracy: 0.6266\n",
      "Epoch 74/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.7809 - categorical_accuracy: 0.9235 - val_loss: 2.9234 - val_categorical_accuracy: 0.5709\n",
      "Epoch 75/200\n",
      "195/195 [==============================] - 32s 162ms/step - loss: 0.7762 - categorical_accuracy: 0.9251 - val_loss: 2.6050 - val_categorical_accuracy: 0.6211\n",
      "Epoch 76/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.7501 - categorical_accuracy: 0.9310 - val_loss: 2.6154 - val_categorical_accuracy: 0.6101\n",
      "Epoch 77/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.7611 - categorical_accuracy: 0.9276 - val_loss: 2.5603 - val_categorical_accuracy: 0.6234\n",
      "Epoch 78/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.7616 - categorical_accuracy: 0.9248 - val_loss: 2.7044 - val_categorical_accuracy: 0.6106\n",
      "Epoch 79/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.7451 - categorical_accuracy: 0.9306 - val_loss: 2.7579 - val_categorical_accuracy: 0.5958\n",
      "Epoch 80/200\n",
      "195/195 [==============================] - 32s 165ms/step - loss: 0.7341 - categorical_accuracy: 0.9326 - val_loss: 2.5054 - val_categorical_accuracy: 0.6266\n",
      "Epoch 81/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.7386 - categorical_accuracy: 0.9302 - val_loss: 2.5604 - val_categorical_accuracy: 0.6195\n",
      "Epoch 82/200\n",
      "195/195 [==============================] - 32s 161ms/step - loss: 0.7306 - categorical_accuracy: 0.9332 - val_loss: 2.8571 - val_categorical_accuracy: 0.5857\n",
      "Epoch 83/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.7217 - categorical_accuracy: 0.9350 - val_loss: 2.4265 - val_categorical_accuracy: 0.6318\n",
      "Epoch 84/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.7078 - categorical_accuracy: 0.9374 - val_loss: 2.6381 - val_categorical_accuracy: 0.6147\n",
      "Epoch 85/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.6872 - categorical_accuracy: 0.9437 - val_loss: 2.3882 - val_categorical_accuracy: 0.6488\n",
      "Epoch 86/200\n",
      "195/195 [==============================] - 33s 168ms/step - loss: 0.6763 - categorical_accuracy: 0.9458 - val_loss: 2.7651 - val_categorical_accuracy: 0.6041\n",
      "Epoch 87/200\n",
      "195/195 [==============================] - 32s 162ms/step - loss: 0.6823 - categorical_accuracy: 0.9414 - val_loss: 2.6286 - val_categorical_accuracy: 0.6244\n",
      "Epoch 88/200\n",
      "195/195 [==============================] - 32s 161ms/step - loss: 0.6712 - categorical_accuracy: 0.9436 - val_loss: 2.7425 - val_categorical_accuracy: 0.6128\n",
      "Epoch 89/200\n",
      "195/195 [==============================] - 31s 161ms/step - loss: 0.6740 - categorical_accuracy: 0.9436 - val_loss: 2.4154 - val_categorical_accuracy: 0.6372\n",
      "Epoch 90/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.6456 - categorical_accuracy: 0.9490 - val_loss: 2.4338 - val_categorical_accuracy: 0.6394\n",
      "Epoch 91/200\n",
      "195/195 [==============================] - 32s 162ms/step - loss: 0.6380 - categorical_accuracy: 0.9511 - val_loss: 2.6624 - val_categorical_accuracy: 0.6258\n",
      "Epoch 92/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.6330 - categorical_accuracy: 0.9510 - val_loss: 2.8673 - val_categorical_accuracy: 0.5850\n",
      "Epoch 93/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.6285 - categorical_accuracy: 0.9521 - val_loss: 2.4320 - val_categorical_accuracy: 0.6433\n",
      "Epoch 94/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.6092 - categorical_accuracy: 0.9563 - val_loss: 2.7852 - val_categorical_accuracy: 0.5852\n",
      "Epoch 95/200\n",
      "195/195 [==============================] - 32s 165ms/step - loss: 0.6074 - categorical_accuracy: 0.9563 - val_loss: 2.3749 - val_categorical_accuracy: 0.6575\n",
      "Epoch 96/200\n",
      "195/195 [==============================] - 31s 160ms/step - loss: 0.5803 - categorical_accuracy: 0.9635 - val_loss: 2.5161 - val_categorical_accuracy: 0.6427\n",
      "Epoch 97/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.5893 - categorical_accuracy: 0.9578 - val_loss: 2.7468 - val_categorical_accuracy: 0.6231\n",
      "Epoch 98/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.5816 - categorical_accuracy: 0.9592 - val_loss: 2.4931 - val_categorical_accuracy: 0.6373\n",
      "Epoch 99/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.5658 - categorical_accuracy: 0.9643 - val_loss: 2.2785 - val_categorical_accuracy: 0.6660\n",
      "Epoch 100/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.5779 - categorical_accuracy: 0.9575 - val_loss: 2.7526 - val_categorical_accuracy: 0.6071\n",
      "Epoch 101/200\n",
      "195/195 [==============================] - 31s 161ms/step - loss: 0.5632 - categorical_accuracy: 0.9614 - val_loss: 2.7620 - val_categorical_accuracy: 0.6235\n",
      "Epoch 102/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.5584 - categorical_accuracy: 0.9606 - val_loss: 2.2582 - val_categorical_accuracy: 0.6575\n",
      "Epoch 103/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.5438 - categorical_accuracy: 0.9645 - val_loss: 2.7937 - val_categorical_accuracy: 0.6128\n",
      "Epoch 104/200\n",
      "195/195 [==============================] - 32s 162ms/step - loss: 0.5379 - categorical_accuracy: 0.9656 - val_loss: 2.1791 - val_categorical_accuracy: 0.6731\n",
      "Epoch 105/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.5186 - categorical_accuracy: 0.9696 - val_loss: 2.2380 - val_categorical_accuracy: 0.6678\n",
      "Epoch 106/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.5151 - categorical_accuracy: 0.9701 - val_loss: 2.3977 - val_categorical_accuracy: 0.6526\n",
      "Epoch 107/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.5064 - categorical_accuracy: 0.9708 - val_loss: 2.3133 - val_categorical_accuracy: 0.6618\n",
      "Epoch 108/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.4936 - categorical_accuracy: 0.9738 - val_loss: 2.4625 - val_categorical_accuracy: 0.6454\n",
      "Epoch 109/200\n",
      "195/195 [==============================] - 32s 166ms/step - loss: 0.4775 - categorical_accuracy: 0.9771 - val_loss: 2.2445 - val_categorical_accuracy: 0.6773\n",
      "Epoch 110/200\n",
      "195/195 [==============================] - 31s 161ms/step - loss: 0.4687 - categorical_accuracy: 0.9782 - val_loss: 2.2149 - val_categorical_accuracy: 0.6746\n",
      "Epoch 111/200\n",
      "195/195 [==============================] - 32s 162ms/step - loss: 0.4683 - categorical_accuracy: 0.9759 - val_loss: 2.3802 - val_categorical_accuracy: 0.6684\n",
      "Epoch 112/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.4604 - categorical_accuracy: 0.9778 - val_loss: 2.3962 - val_categorical_accuracy: 0.6615\n",
      "Epoch 113/200\n",
      "195/195 [==============================] - 31s 160ms/step - loss: 0.4522 - categorical_accuracy: 0.9784 - val_loss: 2.5317 - val_categorical_accuracy: 0.6412\n",
      "Epoch 114/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.4433 - categorical_accuracy: 0.9797 - val_loss: 2.2930 - val_categorical_accuracy: 0.6711\n",
      "Epoch 115/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.4286 - categorical_accuracy: 0.9823 - val_loss: 2.3347 - val_categorical_accuracy: 0.6602\n",
      "Epoch 116/200\n",
      "195/195 [==============================] - 31s 160ms/step - loss: 0.4277 - categorical_accuracy: 0.9813 - val_loss: 2.2315 - val_categorical_accuracy: 0.6765\n",
      "Epoch 117/200\n",
      "195/195 [==============================] - 32s 165ms/step - loss: 0.4156 - categorical_accuracy: 0.9841 - val_loss: 2.0725 - val_categorical_accuracy: 0.6839\n",
      "Epoch 118/200\n",
      "195/195 [==============================] - 32s 162ms/step - loss: 0.4093 - categorical_accuracy: 0.9839 - val_loss: 2.1630 - val_categorical_accuracy: 0.6796\n",
      "Epoch 119/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.4001 - categorical_accuracy: 0.9859 - val_loss: 2.1594 - val_categorical_accuracy: 0.6958\n",
      "Epoch 120/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.3847 - categorical_accuracy: 0.9889 - val_loss: 2.1327 - val_categorical_accuracy: 0.6972\n",
      "Epoch 121/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.3805 - categorical_accuracy: 0.9875 - val_loss: 2.1319 - val_categorical_accuracy: 0.6971\n",
      "Epoch 122/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.3682 - categorical_accuracy: 0.9903 - val_loss: 2.0016 - val_categorical_accuracy: 0.7048\n",
      "Epoch 123/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.3571 - categorical_accuracy: 0.9924 - val_loss: 2.0805 - val_categorical_accuracy: 0.6972\n",
      "Epoch 124/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.3606 - categorical_accuracy: 0.9892 - val_loss: 2.0662 - val_categorical_accuracy: 0.6991\n",
      "Epoch 125/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.3578 - categorical_accuracy: 0.9880 - val_loss: 2.1991 - val_categorical_accuracy: 0.6943\n",
      "Epoch 126/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.3449 - categorical_accuracy: 0.9911 - val_loss: 2.1103 - val_categorical_accuracy: 0.6998\n",
      "Epoch 127/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.3418 - categorical_accuracy: 0.9906 - val_loss: 2.0633 - val_categorical_accuracy: 0.7111\n",
      "Epoch 128/200\n",
      "195/195 [==============================] - 32s 162ms/step - loss: 0.3376 - categorical_accuracy: 0.9908 - val_loss: 2.1832 - val_categorical_accuracy: 0.6883\n",
      "Epoch 129/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.3249 - categorical_accuracy: 0.9936 - val_loss: 2.0752 - val_categorical_accuracy: 0.7000\n",
      "Epoch 130/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.3158 - categorical_accuracy: 0.9943 - val_loss: 2.0363 - val_categorical_accuracy: 0.7078\n",
      "Epoch 131/200\n",
      "195/195 [==============================] - 33s 167ms/step - loss: 0.3090 - categorical_accuracy: 0.9951 - val_loss: 1.9461 - val_categorical_accuracy: 0.7184\n",
      "Epoch 132/200\n",
      "195/195 [==============================] - 32s 162ms/step - loss: 0.2991 - categorical_accuracy: 0.9969 - val_loss: 1.9739 - val_categorical_accuracy: 0.7235\n",
      "Epoch 133/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.2929 - categorical_accuracy: 0.9968 - val_loss: 1.9495 - val_categorical_accuracy: 0.7223\n",
      "Epoch 134/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.2881 - categorical_accuracy: 0.9972 - val_loss: 2.0422 - val_categorical_accuracy: 0.7138\n",
      "Epoch 135/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.2826 - categorical_accuracy: 0.9971 - val_loss: 1.9309 - val_categorical_accuracy: 0.7256\n",
      "Epoch 136/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.2759 - categorical_accuracy: 0.9981 - val_loss: 1.8955 - val_categorical_accuracy: 0.7300\n",
      "Epoch 137/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.2699 - categorical_accuracy: 0.9987 - val_loss: 1.8763 - val_categorical_accuracy: 0.7364\n",
      "Epoch 138/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.2640 - categorical_accuracy: 0.9986 - val_loss: 1.8908 - val_categorical_accuracy: 0.7295\n",
      "Epoch 139/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.2602 - categorical_accuracy: 0.9985 - val_loss: 1.9165 - val_categorical_accuracy: 0.7307\n",
      "Epoch 140/200\n",
      "195/195 [==============================] - 32s 165ms/step - loss: 0.2550 - categorical_accuracy: 0.9991 - val_loss: 1.9082 - val_categorical_accuracy: 0.7342\n",
      "Epoch 141/200\n",
      "195/195 [==============================] - 32s 166ms/step - loss: 0.2510 - categorical_accuracy: 0.9990 - val_loss: 1.8489 - val_categorical_accuracy: 0.7360\n",
      "Epoch 142/200\n",
      "195/195 [==============================] - 31s 161ms/step - loss: 0.2474 - categorical_accuracy: 0.9988 - val_loss: 1.8967 - val_categorical_accuracy: 0.7359\n",
      "Epoch 143/200\n",
      "195/195 [==============================] - 32s 161ms/step - loss: 0.2433 - categorical_accuracy: 0.9991 - val_loss: 1.8367 - val_categorical_accuracy: 0.7381\n",
      "Epoch 144/200\n",
      "195/195 [==============================] - 32s 165ms/step - loss: 0.2402 - categorical_accuracy: 0.9990 - val_loss: 1.8609 - val_categorical_accuracy: 0.7401\n",
      "Epoch 145/200\n",
      "195/195 [==============================] - 32s 165ms/step - loss: 0.2359 - categorical_accuracy: 0.9992 - val_loss: 1.9135 - val_categorical_accuracy: 0.7399\n",
      "Epoch 146/200\n",
      "195/195 [==============================] - 33s 167ms/step - loss: 0.2331 - categorical_accuracy: 0.9992 - val_loss: 1.8259 - val_categorical_accuracy: 0.7422\n",
      "Epoch 147/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.2296 - categorical_accuracy: 0.9993 - val_loss: 1.8595 - val_categorical_accuracy: 0.7431\n",
      "Epoch 148/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.2263 - categorical_accuracy: 0.9994 - val_loss: 1.8390 - val_categorical_accuracy: 0.7443\n",
      "Epoch 149/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.2233 - categorical_accuracy: 0.9994 - val_loss: 1.8112 - val_categorical_accuracy: 0.7426\n",
      "Epoch 150/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.2206 - categorical_accuracy: 0.9995 - val_loss: 1.8251 - val_categorical_accuracy: 0.7480\n",
      "Epoch 151/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.2178 - categorical_accuracy: 0.9995 - val_loss: 1.8313 - val_categorical_accuracy: 0.7467\n",
      "Epoch 152/200\n",
      "195/195 [==============================] - 32s 162ms/step - loss: 0.2158 - categorical_accuracy: 0.9994 - val_loss: 1.8220 - val_categorical_accuracy: 0.7474\n",
      "Epoch 153/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.2134 - categorical_accuracy: 0.9995 - val_loss: 1.8205 - val_categorical_accuracy: 0.7427\n",
      "Epoch 154/200\n",
      "195/195 [==============================] - 32s 165ms/step - loss: 0.2108 - categorical_accuracy: 0.9995 - val_loss: 1.8047 - val_categorical_accuracy: 0.7497\n",
      "Epoch 155/200\n",
      "195/195 [==============================] - 31s 160ms/step - loss: 0.2088 - categorical_accuracy: 0.9996 - val_loss: 1.8161 - val_categorical_accuracy: 0.7468\n",
      "Epoch 156/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.2069 - categorical_accuracy: 0.9996 - val_loss: 1.8042 - val_categorical_accuracy: 0.7501\n",
      "Epoch 157/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.2048 - categorical_accuracy: 0.9997 - val_loss: 1.7934 - val_categorical_accuracy: 0.7530\n",
      "Epoch 158/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.2028 - categorical_accuracy: 0.9997 - val_loss: 1.7980 - val_categorical_accuracy: 0.7507\n",
      "Epoch 159/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.2012 - categorical_accuracy: 0.9996 - val_loss: 1.7979 - val_categorical_accuracy: 0.7508\n",
      "Epoch 160/200\n",
      "195/195 [==============================] - 31s 160ms/step - loss: 0.1997 - categorical_accuracy: 0.9996 - val_loss: 1.8074 - val_categorical_accuracy: 0.7483\n",
      "Epoch 161/200\n",
      "195/195 [==============================] - 32s 165ms/step - loss: 0.1981 - categorical_accuracy: 0.9997 - val_loss: 1.8225 - val_categorical_accuracy: 0.7458\n",
      "Epoch 162/200\n",
      "195/195 [==============================] - 32s 161ms/step - loss: 0.1967 - categorical_accuracy: 0.9997 - val_loss: 1.8051 - val_categorical_accuracy: 0.7476\n",
      "Epoch 163/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.1952 - categorical_accuracy: 0.9996 - val_loss: 1.8194 - val_categorical_accuracy: 0.7466\n",
      "Epoch 164/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.1940 - categorical_accuracy: 0.9997 - val_loss: 1.8064 - val_categorical_accuracy: 0.7494\n",
      "Epoch 165/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.1927 - categorical_accuracy: 0.9997 - val_loss: 1.8019 - val_categorical_accuracy: 0.7499\n",
      "Epoch 166/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.1918 - categorical_accuracy: 0.9996 - val_loss: 1.8115 - val_categorical_accuracy: 0.7478\n",
      "Epoch 167/200\n",
      "195/195 [==============================] - 33s 167ms/step - loss: 0.1905 - categorical_accuracy: 0.9997 - val_loss: 1.8176 - val_categorical_accuracy: 0.7487\n",
      "Epoch 168/200\n",
      "195/195 [==============================] - 32s 166ms/step - loss: 0.1896 - categorical_accuracy: 0.9997 - val_loss: 1.8085 - val_categorical_accuracy: 0.7504\n",
      "Epoch 169/200\n",
      "195/195 [==============================] - 32s 161ms/step - loss: 0.1884 - categorical_accuracy: 0.9998 - val_loss: 1.8070 - val_categorical_accuracy: 0.7503\n",
      "Epoch 170/200\n",
      "195/195 [==============================] - 31s 160ms/step - loss: 0.1876 - categorical_accuracy: 0.9998 - val_loss: 1.8111 - val_categorical_accuracy: 0.7511\n",
      "Epoch 171/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.1869 - categorical_accuracy: 0.9996 - val_loss: 1.8136 - val_categorical_accuracy: 0.7493\n",
      "Epoch 172/200\n",
      "195/195 [==============================] - 32s 165ms/step - loss: 0.1861 - categorical_accuracy: 0.9997 - val_loss: 1.8157 - val_categorical_accuracy: 0.7509\n",
      "Epoch 173/200\n",
      "195/195 [==============================] - 31s 160ms/step - loss: 0.1853 - categorical_accuracy: 0.9998 - val_loss: 1.8102 - val_categorical_accuracy: 0.7513\n",
      "Epoch 174/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.1846 - categorical_accuracy: 0.9998 - val_loss: 1.8122 - val_categorical_accuracy: 0.7509\n",
      "Epoch 175/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.1840 - categorical_accuracy: 0.9998 - val_loss: 1.8148 - val_categorical_accuracy: 0.7519\n",
      "Epoch 176/200\n",
      "195/195 [==============================] - 32s 162ms/step - loss: 0.1834 - categorical_accuracy: 0.9998 - val_loss: 1.8135 - val_categorical_accuracy: 0.7508\n",
      "Epoch 177/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.1831 - categorical_accuracy: 0.9997 - val_loss: 1.8146 - val_categorical_accuracy: 0.7499\n",
      "Epoch 178/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.1825 - categorical_accuracy: 0.9998 - val_loss: 1.8152 - val_categorical_accuracy: 0.7500\n",
      "Epoch 179/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.1820 - categorical_accuracy: 0.9997 - val_loss: 1.8137 - val_categorical_accuracy: 0.7496\n",
      "Epoch 180/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.1816 - categorical_accuracy: 0.9998 - val_loss: 1.8156 - val_categorical_accuracy: 0.7503\n",
      "Epoch 181/200\n",
      "195/195 [==============================] - 32s 161ms/step - loss: 0.1813 - categorical_accuracy: 0.9997 - val_loss: 1.8127 - val_categorical_accuracy: 0.7509\n",
      "Epoch 182/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.1809 - categorical_accuracy: 0.9997 - val_loss: 1.8156 - val_categorical_accuracy: 0.7503\n",
      "Epoch 183/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.1806 - categorical_accuracy: 0.9998 - val_loss: 1.8154 - val_categorical_accuracy: 0.7502\n",
      "Epoch 184/200\n",
      "195/195 [==============================] - 32s 165ms/step - loss: 0.1804 - categorical_accuracy: 0.9997 - val_loss: 1.8154 - val_categorical_accuracy: 0.7503\n",
      "Epoch 185/200\n",
      "195/195 [==============================] - 31s 160ms/step - loss: 0.1802 - categorical_accuracy: 0.9998 - val_loss: 1.8151 - val_categorical_accuracy: 0.7502\n",
      "Epoch 186/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.1799 - categorical_accuracy: 0.9998 - val_loss: 1.8154 - val_categorical_accuracy: 0.7503\n",
      "Epoch 187/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.1797 - categorical_accuracy: 0.9998 - val_loss: 1.8153 - val_categorical_accuracy: 0.7505\n",
      "Epoch 188/200\n",
      "195/195 [==============================] - 31s 161ms/step - loss: 0.1796 - categorical_accuracy: 0.9998 - val_loss: 1.8158 - val_categorical_accuracy: 0.7509\n",
      "Epoch 189/200\n",
      "195/195 [==============================] - 32s 162ms/step - loss: 0.1794 - categorical_accuracy: 0.9999 - val_loss: 1.8147 - val_categorical_accuracy: 0.7511\n",
      "Epoch 190/200\n",
      "195/195 [==============================] - 32s 162ms/step - loss: 0.1794 - categorical_accuracy: 0.9996 - val_loss: 1.8162 - val_categorical_accuracy: 0.7509\n",
      "Epoch 191/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.1792 - categorical_accuracy: 0.9998 - val_loss: 1.8157 - val_categorical_accuracy: 0.7506\n",
      "Epoch 192/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.1793 - categorical_accuracy: 0.9997 - val_loss: 1.8151 - val_categorical_accuracy: 0.7507\n",
      "Epoch 193/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.1791 - categorical_accuracy: 0.9998 - val_loss: 1.8149 - val_categorical_accuracy: 0.7510\n",
      "Epoch 194/200\n",
      "195/195 [==============================] - 32s 163ms/step - loss: 0.1791 - categorical_accuracy: 0.9997 - val_loss: 1.8145 - val_categorical_accuracy: 0.7510\n",
      "Epoch 195/200\n",
      "195/195 [==============================] - 32s 165ms/step - loss: 0.1791 - categorical_accuracy: 0.9998 - val_loss: 1.8148 - val_categorical_accuracy: 0.7510\n",
      "Epoch 196/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.1790 - categorical_accuracy: 0.9998 - val_loss: 1.8144 - val_categorical_accuracy: 0.7507\n",
      "Epoch 197/200\n",
      "195/195 [==============================] - 32s 164ms/step - loss: 0.1791 - categorical_accuracy: 0.9997 - val_loss: 1.8145 - val_categorical_accuracy: 0.7513\n",
      "Epoch 198/200\n",
      "195/195 [==============================] - 31s 160ms/step - loss: 0.1789 - categorical_accuracy: 0.9998 - val_loss: 1.8146 - val_categorical_accuracy: 0.7513\n",
      "Epoch 199/200\n",
      "195/195 [==============================] - 33s 166ms/step - loss: 0.1790 - categorical_accuracy: 0.9997 - val_loss: 1.8154 - val_categorical_accuracy: 0.7511\n",
      "Epoch 200/200\n",
      "195/195 [==============================] - 31s 160ms/step - loss: 0.1790 - categorical_accuracy: 0.9998 - val_loss: 1.8151 - val_categorical_accuracy: 0.7512\n",
      "313/313 [==============================] - 3s 9ms/step - loss: 1.8151 - categorical_accuracy: 0.7512\n",
      "The teacher's accuracy is: 75.1200020313263\n"
     ]
    }
   ],
   "source": [
    "# Train the teacher network and report its accuracy on the test set.\n",
    "\n",
    "# Optimizer settings per architecture. Unsupported architectures are\n",
    "# rejected here, so the rest of the cell only has to distinguish image\n",
    "# models (resnet / mobilenet) from BERT variants.\n",
    "if _TEACHER_ARCHITECTURE == 'resnet':\n",
    "    learning_rate = 0.1\n",
    "    with_lr_scheduler = False\n",
    "elif _TEACHER_ARCHITECTURE == 'mobilenet':\n",
    "    learning_rate = 0.001\n",
    "    with_lr_scheduler = True\n",
    "elif 'bert' in _TEACHER_ARCHITECTURE:\n",
    "    learning_rate = 1e-5\n",
    "    with_lr_scheduler = False\n",
    "else:\n",
    "    raise ValueError(\n",
    "        f'teacher_architecture:{_TEACHER_ARCHITECTURE} is not supported.'\n",
    "    )\n",
    "\n",
    "# After the validation above, every architecture that is not an image\n",
    "# model necessarily contains 'bert' in its name.\n",
    "teacher_is_image_model = _TEACHER_ARCHITECTURE in ('resnet', 'mobilenet')\n",
    "\n",
    "# NOTE(review): decay_steps is derived from train_x while training below\n",
    "# uses teacher_train_x; confirm both hold the same number of examples.\n",
    "teacher_model = get_compiled_model(\n",
    "    num_classes=num_classes,\n",
    "    loss_fn='categorical_crossentropy',\n",
    "    metric_fn='categorical_accuracy',\n",
    "    resnet_depth=_TEACHER_RESNET_DEPTH,\n",
    "    mobilenet_dm=2,\n",
    "    learning_rate=learning_rate,\n",
    "    decay_steps=int(\n",
    "        _TEACHER_EPOCHS * (1.0 * train_x.shape[0] / _TEACHER_BATCH_SIZE)\n",
    "    ),\n",
    "    only_logits=False,\n",
    "    output_embeddings=False,\n",
    "    architecture=_TEACHER_ARCHITECTURE,\n",
    "    subclass_teacher=_USE_SUBCLASS_TEACHER,\n",
    "    num_subclasses=_NUM_SUBCLASSES,\n",
    "    auxiliary_loss_param=_AUXILIARY_LOSS_PARAM,\n",
    "    add_mlp=_ADD_MLP,\n",
    ")\n",
    "\n",
    "if teacher_is_image_model:\n",
    "    # Build with a 32x32 RGB input signature, then train with online\n",
    "    # data augmentation through the notebook's train_model helper.\n",
    "    teacher_model.build((None, 32, 32, 3))\n",
    "    teacher_history = train_model(\n",
    "        model=teacher_model,\n",
    "        train_x=teacher_train_x,\n",
    "        train_y=teacher_train_y,\n",
    "        test_x=test_x,\n",
    "        test_y=test_y,\n",
    "        data_augmentation='online',\n",
    "        with_lr_scheduler=with_lr_scheduler,\n",
    "        batch_size=_TEACHER_BATCH_SIZE,\n",
    "        epochs=_TEACHER_EPOCHS,\n",
    "    )\n",
    "else:\n",
    "    # Run a single batch through predict() before fitting\n",
    "    # (presumably to create the model's weights - confirm).\n",
    "    dummy_train_set = prepare_nlp_dataset(\n",
    "        examples=teacher_train_x[0:_TEACHER_BATCH_SIZE],\n",
    "        labels=teacher_train_y[0:_TEACHER_BATCH_SIZE],\n",
    "        batch_size=_TEACHER_BATCH_SIZE,\n",
    "        text_data=text_data,\n",
    "    )\n",
    "    teacher_model.predict(dummy_train_set)\n",
    "    teacher_history = teacher_model.fit(\n",
    "        teacher_train_set,\n",
    "        epochs=_TEACHER_EPOCHS,\n",
    "        validation_data=test_dataset,\n",
    "        shuffle=True,\n",
    "    )\n",
    "\n",
    "if teacher_is_image_model:\n",
    "    score = teacher_model.evaluate(test_x, test_y)\n",
    "else:\n",
    "    score = teacher_model.evaluate(test_dataset)\n",
    "\n",
    "# Stays at 0.0 when evaluate() does not return a sequence of values.\n",
    "teacher_accuracy = 0.0\n",
    "if isinstance(score, (list, tuple, np.ndarray)):\n",
    "    # NOTE(review): index 0 is conventionally the loss in Keras; confirm\n",
    "    # that the subclass-teacher model really reports accuracy first.\n",
    "    accuracy_index = 0 if _USE_SUBCLASS_TEACHER else 1\n",
    "    teacher_accuracy = score[accuracy_index] * 100.0\n",
    "print(\"The teacher's accuracy is:\", teacher_accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Uy3SiMExF3fg"
   },
   "source": [
    "# Train the student"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "executionInfo": {
     "elapsed": 3054987,
     "status": "ok",
     "timestamp": 1706818980223,
     "user": {
      "displayName": "",
      "userId": "07549615973646857260"
     },
     "user_tz": 300
    },
    "id": "940JnPCLF5lT",
    "outputId": "276853b2-2f92-481d-ab5f-59973622cdb3"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1563/1563 [==============================] - 10s 5ms/step\n",
      "Dynamic Subclasses!\n",
      "Epoch 1/200\n",
      "390/390 [==============================] - 26s 40ms/step - categorical_accuracy: 4.2067e-04 - val_categorical_accuracy: 0.0100 - val_student_loss: 4.6554 - lr: 0.0010\n",
      "Epoch 2/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0011 - val_categorical_accuracy: 0.1496 - val_student_loss: 3.7942 - lr: 0.0010\n",
      "Epoch 3/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0014 - val_categorical_accuracy: 0.1839 - val_student_loss: 3.3283 - lr: 0.0010\n",
      "Epoch 4/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.2078 - val_student_loss: 3.2242 - lr: 0.0010\n",
      "Epoch 5/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.2493 - val_student_loss: 3.0858 - lr: 0.0010\n",
      "Epoch 6/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.2714 - val_student_loss: 3.0693 - lr: 0.0010\n",
      "Epoch 7/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.2984 - val_student_loss: 2.9450 - lr: 0.0010\n",
      "Epoch 8/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0021 - val_categorical_accuracy: 0.2974 - val_student_loss: 2.8564 - lr: 0.0010\n",
      "Epoch 9/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.2772 - val_student_loss: 3.1302 - lr: 0.0010\n",
      "Epoch 10/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0020 - val_categorical_accuracy: 0.3672 - val_student_loss: 2.9382 - lr: 0.0010\n",
      "Epoch 11/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0020 - val_categorical_accuracy: 0.3747 - val_student_loss: 2.9849 - lr: 0.0010\n",
      "Epoch 12/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.3862 - val_student_loss: 2.8420 - lr: 0.0010\n",
      "Epoch 13/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.3920 - val_student_loss: 2.7204 - lr: 0.0010\n",
      "Epoch 14/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0020 - val_categorical_accuracy: 0.3868 - val_student_loss: 2.8096 - lr: 0.0010\n",
      "Epoch 15/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.3971 - val_student_loss: 2.7716 - lr: 0.0010\n",
      "Epoch 16/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0020 - val_categorical_accuracy: 0.4249 - val_student_loss: 2.5544 - lr: 0.0010\n",
      "Epoch 17/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.4299 - val_student_loss: 2.4351 - lr: 0.0010\n",
      "Epoch 18/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0022 - val_categorical_accuracy: 0.4349 - val_student_loss: 2.6717 - lr: 0.0010\n",
      "Epoch 19/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.4376 - val_student_loss: 2.4682 - lr: 0.0010\n",
      "Epoch 20/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0021 - val_categorical_accuracy: 0.4527 - val_student_loss: 2.5426 - lr: 0.0010\n",
      "Epoch 21/200\n",
      "390/390 [==============================] - 14s 37ms/step - categorical_accuracy: 0.0020 - val_categorical_accuracy: 0.4500 - val_student_loss: 2.6072 - lr: 0.0010\n",
      "Epoch 22/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0021 - val_categorical_accuracy: 0.4597 - val_student_loss: 2.6887 - lr: 0.0010\n",
      "Epoch 23/200\n",
      "390/390 [==============================] - 14s 37ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.4744 - val_student_loss: 2.5231 - lr: 0.0010\n",
      "Epoch 24/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0022 - val_categorical_accuracy: 0.4527 - val_student_loss: 2.5944 - lr: 0.0010\n",
      "Epoch 25/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.4867 - val_student_loss: 2.3594 - lr: 0.0010\n",
      "Epoch 26/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.4765 - val_student_loss: 2.6691 - lr: 0.0010\n",
      "Epoch 27/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.4959 - val_student_loss: 2.4880 - lr: 0.0010\n",
      "Epoch 28/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.4802 - val_student_loss: 2.3589 - lr: 0.0010\n",
      "Epoch 29/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0020 - val_categorical_accuracy: 0.5096 - val_student_loss: 2.6404 - lr: 0.0010\n",
      "Epoch 30/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.5080 - val_student_loss: 2.4730 - lr: 0.0010\n",
      "Epoch 31/200\n",
      "390/390 [==============================] - 14s 37ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.5099 - val_student_loss: 2.4209 - lr: 0.0010\n",
      "Epoch 32/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.5079 - val_student_loss: 2.3480 - lr: 0.0010\n",
      "Epoch 33/200\n",
      "390/390 [==============================] - 14s 37ms/step - categorical_accuracy: 0.0020 - val_categorical_accuracy: 0.5108 - val_student_loss: 2.3827 - lr: 0.0010\n",
      "Epoch 34/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.5236 - val_student_loss: 2.4225 - lr: 0.0010\n",
      "Epoch 35/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0020 - val_categorical_accuracy: 0.5169 - val_student_loss: 2.4178 - lr: 0.0010\n",
      "Epoch 36/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0020 - val_categorical_accuracy: 0.5262 - val_student_loss: 2.2753 - lr: 0.0010\n",
      "Epoch 37/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.5002 - val_student_loss: 2.2798 - lr: 0.0010\n",
      "Epoch 38/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.5244 - val_student_loss: 2.4871 - lr: 0.0010\n",
      "Epoch 39/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.5300 - val_student_loss: 2.2517 - lr: 0.0010\n",
      "Epoch 40/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.5222 - val_student_loss: 2.3202 - lr: 0.0010\n",
      "Epoch 41/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.5170 - val_student_loss: 2.3543 - lr: 0.0010\n",
      "Epoch 42/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.5261 - val_student_loss: 2.2289 - lr: 0.0010\n",
      "Epoch 43/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.5194 - val_student_loss: 2.4232 - lr: 0.0010\n",
      "Epoch 44/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.5303 - val_student_loss: 2.3798 - lr: 0.0010\n",
      "Epoch 45/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.5383 - val_student_loss: 2.3367 - lr: 0.0010\n",
      "Epoch 46/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.5524 - val_student_loss: 2.2869 - lr: 0.0010\n",
      "Epoch 47/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.5114 - val_student_loss: 2.3908 - lr: 0.0010\n",
      "Epoch 48/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.5292 - val_student_loss: 2.3652 - lr: 0.0010\n",
      "Epoch 49/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.5605 - val_student_loss: 2.2697 - lr: 0.0010\n",
      "Epoch 50/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.5487 - val_student_loss: 2.3225 - lr: 0.0010\n",
      "Epoch 51/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.5573 - val_student_loss: 2.1132 - lr: 0.0010\n",
      "Epoch 52/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.5457 - val_student_loss: 2.1335 - lr: 0.0010\n",
      "Epoch 53/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.5552 - val_student_loss: 2.2455 - lr: 0.0010\n",
      "Epoch 54/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0021 - val_categorical_accuracy: 0.5576 - val_student_loss: 2.3118 - lr: 0.0010\n",
      "Epoch 55/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.5583 - val_student_loss: 2.1841 - lr: 0.0010\n",
      "Epoch 56/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.5484 - val_student_loss: 2.2946 - lr: 0.0010\n",
      "Epoch 57/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.5490 - val_student_loss: 2.4043 - lr: 0.0010\n",
      "Epoch 58/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.5428 - val_student_loss: 2.4669 - lr: 0.0010\n",
      "Epoch 59/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.5449 - val_student_loss: 2.4235 - lr: 0.0010\n",
      "Epoch 60/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.5592 - val_student_loss: 2.2876 - lr: 0.0010\n",
      "Epoch 61/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.5649 - val_student_loss: 2.2155 - lr: 0.0010\n",
      "Epoch 62/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.5531 - val_student_loss: 2.3301 - lr: 0.0010\n",
      "Epoch 63/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.5671 - val_student_loss: 2.2985 - lr: 0.0010\n",
      "Epoch 64/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.5637 - val_student_loss: 2.3050 - lr: 0.0010\n",
      "Epoch 65/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.5591 - val_student_loss: 2.3173 - lr: 0.0010\n",
      "Epoch 66/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.5719 - val_student_loss: 2.2204 - lr: 0.0010\n",
      "Epoch 67/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.5603 - val_student_loss: 2.2720 - lr: 0.0010\n",
      "Epoch 68/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0015 - val_categorical_accuracy: 0.5634 - val_student_loss: 2.1178 - lr: 0.0010\n",
      "Epoch 69/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.5546 - val_student_loss: 2.1953 - lr: 0.0010\n",
      "Epoch 70/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.5682 - val_student_loss: 2.2516 - lr: 0.0010\n",
      "Epoch 71/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.5683 - val_student_loss: 2.1494 - lr: 0.0010\n",
      "Epoch 72/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0015 - val_categorical_accuracy: 0.5806 - val_student_loss: 2.3678 - lr: 0.0010\n",
      "Epoch 73/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.5647 - val_student_loss: 2.3345 - lr: 0.0010\n",
      "Epoch 74/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0015 - val_categorical_accuracy: 0.5714 - val_student_loss: 2.2798 - lr: 0.0010\n",
      "Epoch 75/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.5646 - val_student_loss: 2.3225 - lr: 0.0010\n",
      "Epoch 76/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.5501 - val_student_loss: 2.3548 - lr: 0.0010\n",
      "Epoch 77/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.5626 - val_student_loss: 2.2646 - lr: 0.0010\n",
      "Epoch 78/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.5664 - val_student_loss: 2.3504 - lr: 0.0010\n",
      "Epoch 79/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0015 - val_categorical_accuracy: 0.5584 - val_student_loss: 2.2560 - lr: 0.0010\n",
      "Epoch 80/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.5574 - val_student_loss: 2.4007 - lr: 0.0010\n",
      "Epoch 81/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.5963 - val_student_loss: 2.3555 - lr: 1.0000e-04\n",
      "Epoch 82/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.5993 - val_student_loss: 2.3774 - lr: 1.0000e-04\n",
      "Epoch 83/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.6021 - val_student_loss: 2.3315 - lr: 1.0000e-04\n",
      "Epoch 84/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0020 - val_categorical_accuracy: 0.5993 - val_student_loss: 2.3592 - lr: 1.0000e-04\n",
      "Epoch 85/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.6001 - val_student_loss: 2.3356 - lr: 1.0000e-04\n",
      "Epoch 86/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.6013 - val_student_loss: 2.3563 - lr: 1.0000e-04\n",
      "Epoch 87/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6009 - val_student_loss: 2.3160 - lr: 1.0000e-04\n",
      "Epoch 88/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.5995 - val_student_loss: 2.3195 - lr: 1.0000e-04\n",
      "Epoch 89/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0020 - val_categorical_accuracy: 0.6010 - val_student_loss: 2.3146 - lr: 1.0000e-04\n",
      "Epoch 90/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6033 - val_student_loss: 2.3330 - lr: 1.0000e-04\n",
      "Epoch 91/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6017 - val_student_loss: 2.3613 - lr: 1.0000e-04\n",
      "Epoch 92/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6014 - val_student_loss: 2.3208 - lr: 1.0000e-04\n",
      "Epoch 93/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0020 - val_categorical_accuracy: 0.6051 - val_student_loss: 2.3495 - lr: 1.0000e-04\n",
      "Epoch 94/200\n",
      "390/390 [==============================] - 14s 37ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6011 - val_student_loss: 2.3165 - lr: 1.0000e-04\n",
      "Epoch 95/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.6026 - val_student_loss: 2.3178 - lr: 1.0000e-04\n",
      "Epoch 96/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.6022 - val_student_loss: 2.2997 - lr: 1.0000e-04\n",
      "Epoch 97/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6029 - val_student_loss: 2.3156 - lr: 1.0000e-04\n",
      "Epoch 98/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0020 - val_categorical_accuracy: 0.6027 - val_student_loss: 2.2912 - lr: 1.0000e-04\n",
      "Epoch 99/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6043 - val_student_loss: 2.3425 - lr: 1.0000e-04\n",
      "Epoch 100/200\n",
      "390/390 [==============================] - 14s 37ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6008 - val_student_loss: 2.2627 - lr: 1.0000e-04\n",
      "Epoch 101/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6030 - val_student_loss: 2.3188 - lr: 1.0000e-04\n",
      "Epoch 102/200\n",
      "390/390 [==============================] - 15s 37ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6038 - val_student_loss: 2.3003 - lr: 1.0000e-04\n",
      "Epoch 103/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6029 - val_student_loss: 2.3379 - lr: 1.0000e-04\n",
      "Epoch 104/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.5998 - val_student_loss: 2.3326 - lr: 1.0000e-04\n",
      "Epoch 105/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.5984 - val_student_loss: 2.3006 - lr: 1.0000e-04\n",
      "Epoch 106/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.6029 - val_student_loss: 2.3296 - lr: 1.0000e-04\n",
      "Epoch 107/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.6024 - val_student_loss: 2.3313 - lr: 1.0000e-04\n",
      "Epoch 108/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.6015 - val_student_loss: 2.3290 - lr: 1.0000e-04\n",
      "Epoch 109/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.5992 - val_student_loss: 2.3134 - lr: 1.0000e-04\n",
      "Epoch 110/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6005 - val_student_loss: 2.3180 - lr: 1.0000e-04\n",
      "Epoch 111/200\n",
      "390/390 [==============================] - 15s 37ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.5998 - val_student_loss: 2.2755 - lr: 1.0000e-04\n",
      "Epoch 112/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6033 - val_student_loss: 2.2898 - lr: 1.0000e-04\n",
      "Epoch 113/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.6006 - val_student_loss: 2.2602 - lr: 1.0000e-04\n",
      "Epoch 114/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6017 - val_student_loss: 2.2894 - lr: 1.0000e-04\n",
      "Epoch 115/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6016 - val_student_loss: 2.2983 - lr: 1.0000e-04\n",
      "Epoch 116/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6044 - val_student_loss: 2.3078 - lr: 1.0000e-04\n",
      "Epoch 117/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0020 - val_categorical_accuracy: 0.6032 - val_student_loss: 2.3257 - lr: 1.0000e-04\n",
      "Epoch 118/200\n",
      "390/390 [==============================] - 15s 37ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6043 - val_student_loss: 2.3113 - lr: 1.0000e-04\n",
      "Epoch 119/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.6021 - val_student_loss: 2.3228 - lr: 1.0000e-04\n",
      "Epoch 120/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.6005 - val_student_loss: 2.2592 - lr: 1.0000e-04\n",
      "Epoch 121/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.6032 - val_student_loss: 2.2916 - lr: 1.0000e-05\n",
      "Epoch 122/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.6032 - val_student_loss: 2.2934 - lr: 1.0000e-05\n",
      "Epoch 123/200\n",
      "390/390 [==============================] - 14s 37ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.6041 - val_student_loss: 2.2943 - lr: 1.0000e-05\n",
      "Epoch 124/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6042 - val_student_loss: 2.2895 - lr: 1.0000e-05\n",
      "Epoch 125/200\n",
      "390/390 [==============================] - 15s 37ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.6038 - val_student_loss: 2.2932 - lr: 1.0000e-05\n",
      "Epoch 126/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6043 - val_student_loss: 2.2893 - lr: 1.0000e-05\n",
      "Epoch 127/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6041 - val_student_loss: 2.2940 - lr: 1.0000e-05\n",
      "Epoch 128/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.6039 - val_student_loss: 2.2920 - lr: 1.0000e-05\n",
      "Epoch 129/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6050 - val_student_loss: 2.2926 - lr: 1.0000e-05\n",
      "Epoch 130/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.6049 - val_student_loss: 2.2909 - lr: 1.0000e-05\n",
      "Epoch 131/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6043 - val_student_loss: 2.2820 - lr: 1.0000e-05\n",
      "Epoch 132/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6057 - val_student_loss: 2.2762 - lr: 1.0000e-05\n",
      "Epoch 133/200\n",
      "390/390 [==============================] - 15s 37ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6048 - val_student_loss: 2.2901 - lr: 1.0000e-05\n",
      "Epoch 134/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6050 - val_student_loss: 2.2811 - lr: 1.0000e-05\n",
      "Epoch 135/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6051 - val_student_loss: 2.2893 - lr: 1.0000e-05\n",
      "Epoch 136/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6053 - val_student_loss: 2.2861 - lr: 1.0000e-05\n",
      "Epoch 137/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6054 - val_student_loss: 2.2848 - lr: 1.0000e-05\n",
      "Epoch 138/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6044 - val_student_loss: 2.2872 - lr: 1.0000e-05\n",
      "Epoch 139/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.6044 - val_student_loss: 2.2814 - lr: 1.0000e-05\n",
      "Epoch 140/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.6043 - val_student_loss: 2.2862 - lr: 1.0000e-05\n",
      "Epoch 141/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0020 - val_categorical_accuracy: 0.6041 - val_student_loss: 2.2910 - lr: 1.0000e-05\n",
      "Epoch 142/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6044 - val_student_loss: 2.2940 - lr: 1.0000e-05\n",
      "Epoch 143/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6036 - val_student_loss: 2.2871 - lr: 1.0000e-05\n",
      "Epoch 144/200\n",
      "390/390 [==============================] - 15s 37ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.6041 - val_student_loss: 2.2950 - lr: 1.0000e-05\n",
      "Epoch 145/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6041 - val_student_loss: 2.2949 - lr: 1.0000e-05\n",
      "Epoch 146/200\n",
      "390/390 [==============================] - 15s 37ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6046 - val_student_loss: 2.2915 - lr: 1.0000e-05\n",
      "Epoch 147/200\n",
      "390/390 [==============================] - 14s 37ms/step - categorical_accuracy: 0.0015 - val_categorical_accuracy: 0.6037 - val_student_loss: 2.2949 - lr: 1.0000e-05\n",
      "Epoch 148/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6049 - val_student_loss: 2.2991 - lr: 1.0000e-05\n",
      "Epoch 149/200\n",
      "390/390 [==============================] - 15s 38ms/step - categorical_accuracy: 0.0015 - val_categorical_accuracy: 0.6048 - val_student_loss: 2.2985 - lr: 1.0000e-05\n",
      "Epoch 150/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.6039 - val_student_loss: 2.2893 - lr: 1.0000e-05\n",
      "Epoch 151/200\n",
      "390/390 [==============================] - 15s 37ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6041 - val_student_loss: 2.2868 - lr: 1.0000e-05\n",
      "Epoch 152/200\n",
      "390/390 [==============================] - 15s 37ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6054 - val_student_loss: 2.2949 - lr: 1.0000e-05\n",
      "Epoch 153/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6046 - val_student_loss: 2.2939 - lr: 1.0000e-05\n",
      "Epoch 154/200\n",
      "390/390 [==============================] - 15s 38ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.6044 - val_student_loss: 2.2945 - lr: 1.0000e-05\n",
      "Epoch 155/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6043 - val_student_loss: 2.2887 - lr: 1.0000e-05\n",
      "Epoch 156/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6039 - val_student_loss: 2.2893 - lr: 1.0000e-05\n",
      "Epoch 157/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6046 - val_student_loss: 2.2984 - lr: 1.0000e-05\n",
      "Epoch 158/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.6042 - val_student_loss: 2.2907 - lr: 1.0000e-05\n",
      "Epoch 159/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6039 - val_student_loss: 2.2917 - lr: 1.0000e-05\n",
      "Epoch 160/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6034 - val_student_loss: 2.2939 - lr: 1.0000e-05\n",
      "Epoch 161/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6031 - val_student_loss: 2.2927 - lr: 1.0000e-06\n",
      "Epoch 162/200\n",
      "390/390 [==============================] - 14s 37ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6044 - val_student_loss: 2.2941 - lr: 1.0000e-06\n",
      "Epoch 163/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6033 - val_student_loss: 2.2928 - lr: 1.0000e-06\n",
      "Epoch 164/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6037 - val_student_loss: 2.2905 - lr: 1.0000e-06\n",
      "Epoch 165/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.6035 - val_student_loss: 2.2944 - lr: 1.0000e-06\n",
      "Epoch 166/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.6028 - val_student_loss: 2.2881 - lr: 1.0000e-06\n",
      "Epoch 167/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6029 - val_student_loss: 2.2928 - lr: 1.0000e-06\n",
      "Epoch 168/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6027 - val_student_loss: 2.2897 - lr: 1.0000e-06\n",
      "Epoch 169/200\n",
      "390/390 [==============================] - 15s 37ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6027 - val_student_loss: 2.2915 - lr: 1.0000e-06\n",
      "Epoch 170/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6026 - val_student_loss: 2.2947 - lr: 1.0000e-06\n",
      "Epoch 171/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6036 - val_student_loss: 2.2943 - lr: 1.0000e-06\n",
      "Epoch 172/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6029 - val_student_loss: 2.2971 - lr: 1.0000e-06\n",
      "Epoch 173/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.6038 - val_student_loss: 2.2929 - lr: 1.0000e-06\n",
      "Epoch 174/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.6042 - val_student_loss: 2.2885 - lr: 1.0000e-06\n",
      "Epoch 175/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6029 - val_student_loss: 2.2954 - lr: 1.0000e-06\n",
      "Epoch 176/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6030 - val_student_loss: 2.2980 - lr: 1.0000e-06\n",
      "Epoch 177/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6040 - val_student_loss: 2.2985 - lr: 1.0000e-06\n",
      "Epoch 178/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6028 - val_student_loss: 2.2960 - lr: 1.0000e-06\n",
      "Epoch 179/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6040 - val_student_loss: 2.2913 - lr: 1.0000e-06\n",
      "Epoch 180/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6036 - val_student_loss: 2.2987 - lr: 1.0000e-06\n",
      "Epoch 181/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6037 - val_student_loss: 2.2973 - lr: 5.0000e-07\n",
      "Epoch 182/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6030 - val_student_loss: 2.2947 - lr: 5.0000e-07\n",
      "Epoch 183/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6033 - val_student_loss: 2.2967 - lr: 5.0000e-07\n",
      "Epoch 184/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6035 - val_student_loss: 2.2982 - lr: 5.0000e-07\n",
      "Epoch 185/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.6033 - val_student_loss: 2.2965 - lr: 5.0000e-07\n",
      "Epoch 186/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6025 - val_student_loss: 2.2903 - lr: 5.0000e-07\n",
      "Epoch 187/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6027 - val_student_loss: 2.2939 - lr: 5.0000e-07\n",
      "Epoch 188/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.6040 - val_student_loss: 2.2941 - lr: 5.0000e-07\n",
      "Epoch 189/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6025 - val_student_loss: 2.2972 - lr: 5.0000e-07\n",
      "Epoch 190/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6034 - val_student_loss: 2.2927 - lr: 5.0000e-07\n",
      "Epoch 191/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6032 - val_student_loss: 2.2935 - lr: 5.0000e-07\n",
      "Epoch 192/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.6030 - val_student_loss: 2.2957 - lr: 5.0000e-07\n",
      "Epoch 193/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6030 - val_student_loss: 2.2931 - lr: 5.0000e-07\n",
      "Epoch 194/200\n",
      "390/390 [==============================] - 14s 37ms/step - categorical_accuracy: 0.0016 - val_categorical_accuracy: 0.6033 - val_student_loss: 2.2936 - lr: 5.0000e-07\n",
      "Epoch 195/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6034 - val_student_loss: 2.2926 - lr: 5.0000e-07\n",
      "Epoch 196/200\n",
      "390/390 [==============================] - 15s 37ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.6034 - val_student_loss: 2.2971 - lr: 5.0000e-07\n",
      "Epoch 197/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0019 - val_categorical_accuracy: 0.6038 - val_student_loss: 2.2888 - lr: 5.0000e-07\n",
      "Epoch 198/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6037 - val_student_loss: 2.2927 - lr: 5.0000e-07\n",
      "Epoch 199/200\n",
      "390/390 [==============================] - 14s 36ms/step - categorical_accuracy: 0.0017 - val_categorical_accuracy: 0.6029 - val_student_loss: 2.2923 - lr: 5.0000e-07\n",
      "Epoch 200/200\n",
      "390/390 [==============================] - 14s 35ms/step - categorical_accuracy: 0.0018 - val_categorical_accuracy: 0.6031 - val_student_loss: 2.2968 - lr: 5.0000e-07\n"
     ]
    }
   ],
   "source": [
    "embedding_loss_fn = tf.math.squared_difference\n",
    "student_loss_fn = tf.keras.losses.CategoricalCrossentropy(\n",
    "        reduction=tf.keras.losses.Reduction.NONE\n",
    "    )\n",
    "\n",
    "if _STUDENT_ARCHITECTURE == 'resnet':\n",
    "    embedding_dimension = 256\n",
    "    learning_rate = 0.05\n",
    "    with_lr_scheduler = False\n",
    "elif _STUDENT_ARCHITECTURE == 'mobilenet':\n",
    "    embedding_dimension = 1024\n",
    "    learning_rate = 0.001\n",
    "    with_lr_scheduler = True\n",
    "elif 'bert' in _STUDENT_ARCHITECTURE:\n",
    "    embedding_dimension = 768\n",
    "    learning_rate = 1e-6\n",
    "    with_lr_scheduler = False\n",
    "else:\n",
    "    raise ValueError(\n",
    "        f'student_architecture:{_STUDENT_ARCHITECTURE} is not supported.'\n",
    "    )\n",
    "\n",
    "history_list = []\n",
    "for trial_iter in range(_NUM_TRIALS):\n",
    "    if _NUM_SUBCLASSES is None:\n",
    "      trainset_x = train_x\n",
    "      trainset_y = train_y\n",
    "      teacher_subclass_fn = None\n",
    "    else:\n",
    "      trainset_x = train_x\n",
    "      trainset_y = train_y\n",
    "\n",
    "    clustering_info = prepare_subclasses_dataset_pca(\n",
    "            teacher_model=teacher_model,\n",
    "            teacher_architecture=_TEACHER_ARCHITECTURE,\n",
    "            train_x=train_x,\n",
    "            train_y=train_y,\n",
    "            num_classes=num_classes,\n",
    "            num_subclasses=_NUM_SUBCLASSES,\n",
    "            type_of_clustering=_TYPE_OF_CLUSTERING,\n",
    "            subclass_teacher=_USE_SUBCLASS_TEACHER,\n",
    "            text_data=text_data,\n",
    "        )\n",
    "\n",
    "    if clustering_info is not None:\n",
    "        teacher_subclass_fn = get_teacher_subclass_fn(\n",
    "            clustering_info,\n",
    "            _TEACHER_LABEL_TEMP,\n",
    "            True,\n",
    "            teacher_outer_label_temp=_TEMPERATURE,\n",
    "            num_classes=num_classes,\n",
    "        )\n",
    "        print('Dynamic Subclasses!')\n",
    "    else:\n",
    "        teacher_subclass_fn = None\n",
    "\n",
    "    params = DistillerParam(\n",
    "        num_classes=num_classes,\n",
    "        student_loss_fn=student_loss_fn,\n",
    "        distillation_loss_fn=tf.keras.losses.KLDivergence(\n",
    "            reduction=tf.keras.losses.Reduction.NONE\n",
    "        ),\n",
    "        embedding_loss_fn=embedding_loss_fn,\n",
    "        teacher_subclass_fn=teacher_subclass_fn,\n",
    "        resnet_depth=_STUDENT_RESNET_DEPTH,\n",
    "        teacher_architecture=_TEACHER_ARCHITECTURE,\n",
    "        learning_rate=learning_rate,\n",
    "        decay_steps=int(\n",
    "            _STUDENT_EPOCHS\n",
    "            * (1.0 * train_x.shape[0] / _STUDENT_BATCH_SIZE)\n",
    "        ),\n",
    "        temperature=_TEMPERATURE,\n",
    "        alpha=_ALPHA,\n",
    "        beta=_BETA,\n",
    "        metric_fn='categorical_accuracy',\n",
    "        embedding_dimension=embedding_dimension,\n",
    "        teacher_label_temperature=_TEACHER_LABEL_TEMP,\n",
    "    )\n",
    "\n",
    "    student_model = get_compiled_distilled_student_model(\n",
    "          teacher_model=teacher_model,\n",
    "          params=params,\n",
    "          student_architecture=_STUDENT_ARCHITECTURE,\n",
    "          num_subclasses=_NUM_SUBCLASSES,\n",
    "          dynamic_subclasses=teacher_subclass_fn is not None,\n",
    "          decomposed_logits=False,\n",
    "          add_mlp=_ADD_MLP,\n",
    "      )\n",
    "    if (\n",
    "        _STUDENT_ARCHITECTURE == 'resnet'\n",
    "        or _STUDENT_ARCHITECTURE == 'mobilenet'\n",
    "    ):\n",
    "      student_model.build((None, 32, 32, 3))\n",
    "    elif 'bert' in _STUDENT_ARCHITECTURE:\n",
    "      dummy_train_set = prepare_nlp_dataset(\n",
    "          examples=teacher_train_x[0 : _STUDENT_BATCH_SIZE],\n",
    "          labels=teacher_train_y[0:_STUDENT_BATCH_SIZE],\n",
    "          batch_size=_STUDENT_BATCH_SIZE,\n",
    "          text_data=text_data,\n",
    "      )\n",
    "      student_model.predict(dummy_train_set)\n",
    "\n",
    "    if (\n",
    "        _STUDENT_ARCHITECTURE == 'resnet'\n",
    "        or _STUDENT_ARCHITECTURE == 'mobilenet'\n",
    "    ):\n",
    "      student_history = train_model(\n",
    "          model=student_model,\n",
    "          train_x=trainset_x,\n",
    "          train_y=trainset_y,\n",
    "          test_x=test_x,\n",
    "          test_y=test_y,\n",
    "          data_augmentation='online',\n",
    "          with_lr_scheduler=with_lr_scheduler,\n",
    "          batch_size=_STUDENT_BATCH_SIZE,\n",
    "          epochs=_STUDENT_EPOCHS,\n",
    "      )\n",
    "    elif 'bert' in _STUDENT_ARCHITECTURE:\n",
    "      train_set = prepare_nlp_dataset(\n",
    "          examples=trainset_x,\n",
    "          labels=trainset_y,\n",
    "          batch_size=_STUDENT_BATCH_SIZE,\n",
    "          text_data=text_data,\n",
    "          expand_label_dims=False,\n",
    "      )\n",
    "      test_dataset = prepare_nlp_dataset(\n",
    "          examples=test_x,\n",
    "          labels=test_y,\n",
    "          batch_size=1,\n",
    "          text_data=text_data,\n",
    "          expand_label_dims=False,\n",
    "      )\n",
    "\n",
    "      student_history = student_model.fit(\n",
    "          train_set,\n",
    "          epochs=_STUDENT_EPOCHS,\n",
    "          validation_data=test_dataset,\n",
    "          shuffle=True,\n",
    "      )\n",
    "    else:\n",
    "      raise ValueError(\n",
    "          f'architecture:{_STUDENT_ARCHITECTURE} is not supported.'\n",
    "      )\n",
    "    history_list.append(student_history)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "last_runtime": {
    "build_target": "//learning/grp/tools/ml_python:ml_notebook",
    "kind": "private"
   },
   "provenance": [
    {
     "file_id": "1zhBPFCRR8hUTgVEWXGOp1zRKc731yJ4q",
     "timestamp": 1706795579700
    },
    {
     "file_id": "1E4wA_w015zwr9MM2FH3D9_Ua2yTwpc_n",
     "timestamp": 1692113177160
    },
    {
     "file_id": "1PCvysKa7ha8O9ZXJH6w4PGXRQRaH41Sx",
     "timestamp": 1690814257763
    },
    {
     "file_id": "12HAYGeg9eky9n0GCMawtrXkhWflntT5g",
     "timestamp": 1688749106217
    },
    {
     "file_id": "15OXCoJAt9tflIVEbBDcL72eJj-1w47Nv",
     "timestamp": 1675374818325
    },
    {
     "file_id": "16iux6HSF_fav3r88rMsMFkUglLA3MYSj",
     "timestamp": 1674663198394
    },
    {
     "file_id": "1kypcF3svum8PeZelkYZlDj2prLjpbV5o",
     "timestamp": 1665423411412
    },
    {
     "file_id": "1XTDIiAvYlr3PgACkrcYcM7ARUi1bulC0",
     "timestamp": 1665135649078
    },
    {
     "file_id": "1CbVFp70yBaRE8id2tBSsAgF2gTO9DkRo",
     "timestamp": 1660275155859
    }
   ],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
