{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pEAjCLI8QCjU"
      },
      "source": [
        "##### Copyright 2020 Google LLC.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2lCKaZSM2Ac0"
      },
      "source": [
        "## RandBits MNIST using Contrastive Learning."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PZrwIvTl05P6"
      },
      "source": [
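        "This notebook trains an unsupervised model on the RandBits MNIST dataset, similar to the setting of Figure 5(a) in ***Intriguing Properties of Contrastive Losses***. Instead of the NT-Xent loss, the contrastive objective used here is an eigenvalue-based Fisher discriminant (NFDA) loss that treats the two augmented views of each image as one class. A linear classifier trained on stop-gradient encoder features measures test accuracy for different numbers of random bits."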
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WE6axfvMFuvW",
        "outputId": "d3006c02-c351-4f27-f18c-28f3bb9c9e93"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Collecting snfpy\n",
            "  Downloading snfpy-0.2.2-py3-none-any.whl (550 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m550.6/550.6 kB\u001b[0m \u001b[31m7.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: numpy>=1.14 in /usr/local/lib/python3.10/dist-packages (from snfpy) (1.25.2)\n",
            "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (from snfpy) (1.2.2)\n",
            "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from snfpy) (1.11.4)\n",
            "Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->snfpy) (1.4.2)\n",
            "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->snfpy) (3.5.0)\n",
            "Installing collected packages: snfpy\n",
            "Successfully installed snfpy-0.2.2\n"
          ]
        }
      ],
      "source": [
        "# Pinned to the snfpy version recorded in this notebook's saved output;\n",
        "# %pip installs into the environment of the running kernel.\n",
        "%pip install snfpy==0.2.2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "wHZSEBw5FgnE"
      },
      "outputs": [],
      "source": [
        "# import os\n",
        "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\"\n",
        "\n",
        "#@title Imports.\n",
        "import tensorflow.compat.v2 as tf\n",
        "tf.enable_v2_behavior()\n",
        "\n",
        "import snf.compute\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "p_YbAcw4FgnF",
        "outputId": "5db71c2e-7936-4598-a89c-c36d846163b0"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]\n"
          ]
        }
      ],
      "source": [
        "# List the GPUs visible to TensorFlow.\n",
        "gpu_devices = tf.config.list_physical_devices('GPU')\n",
        "print(gpu_devices)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "cellView": "form",
        "id": "0amaY7x4wgGr"
      },
      "outputs": [],
      "source": [
        "import tensorflow_datasets as tfds\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "import math\n",
        "import pandas as pd"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "EsD0gwvcFgnG",
        "outputId": "047e62f3-2f05-406e-c524-43a736c71f41"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "4.9.4\n"
          ]
        }
      ],
      "source": [
        "# Print the version of TensorFlow Datasets (imported as tfds in the imports cell).\n",
        "print(tfds.__version__)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5mPAqRcbFgnG",
        "outputId": "550a1a9d-d6e1-4952-d6d4-9fc7719d962a"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2.15.0\n"
          ]
        }
      ],
      "source": [
        "print(tf.__version__)  # The saved run of this notebook used TensorFlow 2.15.0."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "dAa3IRTov5Bk"
      },
      "outputs": [],
      "source": [
        "#@title Data preprocessing.\n",
        "def random_crop_and_resize(image):\n",
        "  \"\"\"Randomly crops `image` (area_range 0.5-1.0) and resizes the crop to 28x28.\"\"\"\n",
        "  whole_image_box = tf.constant(\n",
        "      [0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])\n",
        "  crop_begin, crop_size, _ = tf.image.sample_distorted_bounding_box(\n",
        "      image_size=tf.shape(image),\n",
        "      bounding_boxes=whole_image_box,\n",
        "      min_object_covered=0.1,\n",
        "      aspect_ratio_range=(3. / 4, 4. / 3.),\n",
        "      area_range=(0.5, 1.0),\n",
        "      max_attempts=100,\n",
        "      use_image_if_no_bounding_boxes=True)\n",
        "\n",
        "  top, left, _ = tf.unstack(crop_begin)\n",
        "  height, width, _ = tf.unstack(crop_size)\n",
        "  cropped = tf.image.crop_to_bounding_box(image, top, left, height, width)\n",
        "  return tf.image.resize(\n",
        "      cropped, [28, 28], method=tf.image.ResizeMethod.BILINEAR)\n",
        "\n",
        "def hash_image_to_bits(image, extra_channel_bits):\n",
        "  \"\"\"Maps an image to `extra_channel_bits` bits taken from a hash of its JPEG bytes.\n",
        "\n",
        "  The JPEG encoding is hashed into 2**extra_channel_bits buckets and the\n",
        "  bucket id is split into binary digits, least-significant bit first.\n",
        "  Returns a float32 tensor of shape [1, extra_channel_bits].\n",
        "  \"\"\"\n",
        "  bucket = tf.compat.v1.strings.to_hash_bucket_fast(\n",
        "      tf.image.encode_jpeg(image), num_buckets=2**extra_channel_bits)\n",
        "  bit_positions = tf.range(extra_channel_bits, dtype=tf.int64)\n",
        "  shifted = tf.bitwise.right_shift(tf.reshape(bucket, [1, 1]), bit_positions)\n",
        "  return tf.cast(tf.math.mod(shifted, 2), tf.float32)\n",
        "\n",
        "def pack_extra_channels(image, bits):\n",
        "  \"\"\"Appends `bits` to `image` as constant 28x28 planes along the channel axis.\"\"\"\n",
        "  num_bits = tf.shape(bits)[-1]\n",
        "  bit_planes = tf.cast(tf.broadcast_to(bits, [28, 28, num_bits]), tf.float32)\n",
        "  return tf.concat([image, bit_planes], axis=-1)\n",
        "\n",
        "def get_process_fns(extra_channel_bits):\n",
        "  \"\"\"Builds the tf.data map functions for training and evaluation.\n",
        "\n",
        "  Args:\n",
        "    extra_channel_bits: number of hash-derived bit channels appended to every\n",
        "      image; 0 disables them.\n",
        "\n",
        "  Returns:\n",
        "    (preprocess_train_fn, preprocess_eval_fn). The train fn returns two\n",
        "    independently cropped views stacked as [2, h, w, c]; the eval fn returns\n",
        "    the un-cropped image. Every view of an image carries the same bits.\n",
        "  \"\"\"\n",
        "  use_bits = extra_channel_bits > 0\n",
        "\n",
        "  def preprocess_train_fn(image, label):\n",
        "    # Hash the raw image, so the bits do not depend on the random crops.\n",
        "    # The JPEG encode + hash is skipped when no extra channels are requested.\n",
        "    bits = hash_image_to_bits(image, extra_channel_bits) if use_bits else None\n",
        "    image = tf.image.convert_image_dtype(image, dtype=tf.float32)\n",
        "    label = tf.cast(label, tf.int32)\n",
        "\n",
        "    image_a = random_crop_and_resize(image)\n",
        "    image_b = random_crop_and_resize(image)\n",
        "\n",
        "    # Pack extra channels after cropping so the bits are not augmented.\n",
        "    if use_bits:\n",
        "      image_a = pack_extra_channels(image_a, bits)\n",
        "      image_b = pack_extra_channels(image_b, bits)\n",
        "\n",
        "    image = tf.stack([image_a, image_b], axis=0)  # [2, h, w, c]\n",
        "    return (image, label)\n",
        "\n",
        "  def preprocess_eval_fn(image, label):\n",
        "    bits = hash_image_to_bits(image, extra_channel_bits) if use_bits else None\n",
        "    image = tf.image.convert_image_dtype(image, dtype=tf.float32)\n",
        "    label = tf.cast(label, tf.int32)\n",
        "\n",
        "    if use_bits:\n",
        "      image = pack_extra_channels(image, bits)\n",
        "    return (image, label)\n",
        "\n",
        "  return preprocess_train_fn, preprocess_eval_fn"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "I94Fd2qCWaDE"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "from scipy.linalg import eigh\n",
        "\n",
        "\n",
        "# Per-example softmax cross-entropy on logits; get_cls_loss averages it.\n",
        "cls_loss_object = tf.keras.losses.CategoricalCrossentropy(\n",
        "    from_logits=True,\n",
        "    label_smoothing=0.0,  # keep this a float\n",
        "    reduction=tf.keras.losses.Reduction.NONE,\n",
        ")\n",
        "\n",
        "def get_cls_loss(labels, outputs):\n",
        "  \"\"\"Mean cross-entropy between one-hot `labels` and logit `outputs`.\"\"\"\n",
        "  per_example = cls_loss_object(labels, outputs)\n",
        "  return tf.reduce_mean(per_example)\n",
        "\n",
        "\n",
        "def repmat(A, N, M):\n",
        "    \"\"\"MATLAB-style repmat for a 2-D tensor.\n",
        "\n",
        "    Tiles `A` N times along the rows and M times along the columns, giving a\n",
        "    tensor of shape [N * rows(A), M * cols(A)].\n",
        "    \"\"\"\n",
        "    # tf.tile on the 2-D tensor itself yields repmat's block layout. An\n",
        "    # expand_dims/tile/reshape formulation interleaves rows whenever A has more\n",
        "    # than one row and M > 1 (e.g. [[1], [2]] with M=2 would give\n",
        "    # [[1, 2], [1, 2]] instead of [[1, 1], [2, 2]]).\n",
        "    return tf.tile(A, [N, M])\n",
        "\n",
        "\n",
        "def get_nfda_loss(z1, z2, penal_para, margin, n_components=127):\n",
        "    \"\"\"Fisher-discriminant (NFDA) contrastive loss over a batch of positive pairs.\n",
        "\n",
        "    Row i of `z1` and row i of `z2` form \"class\" i. After L2 normalization the\n",
        "    within-class (Sw) and between-class (Sb) scatter matrices are built and the\n",
        "    eigenvalues of the symmetrized (Sw + penal_para * I)^+ Sb are computed.\n",
        "    Among its `n_components` largest eigenvalues, those strictly below\n",
        "    min + `margin` are averaged and the negated mean is returned.\n",
        "\n",
        "    Args:\n",
        "      z1: [batch_size, dim] projections of the first view.\n",
        "      z2: [batch_size, dim] projections of the second view.\n",
        "      penal_para: ridge term added to Sw before the pseudo-inverse.\n",
        "      margin: eigenvalues strictly below min(evals) + margin enter the loss.\n",
        "      n_components: how many of the largest eigenvalues are considered\n",
        "        (previously hard-coded as 128 - 1). NOTE(review): possibly meant to be\n",
        "        batch_size - 1, the number of pair-classes minus one; confirm.\n",
        "\n",
        "    Returns:\n",
        "      Scalar loss tensor.\n",
        "    \"\"\"\n",
        "    z1 = tf.math.l2_normalize(z1, -1)\n",
        "    z2 = tf.math.l2_normalize(z2, -1)\n",
        "    z = tf.concat([z1, z2], axis=0)  # [2 * batch_size, dim]\n",
        "\n",
        "    # Closed forms of the per-pair sums (each pair (a, b) contributes\n",
        "    # (a - b)(a - b)^T / 2 to Sw), computed with matmuls instead of a\n",
        "    # sequential tf.range loop over the batch:\n",
        "    #   St = sum_k (z_k - mean)(z_k - mean)^T   (total scatter)\n",
        "    #   Sw = 1/2 * sum_i (a_i - b_i)(a_i - b_i)^T\n",
        "    #   Sb = St - Sw\n",
        "    centered = z - tf.reduce_mean(z, axis=0, keepdims=True)\n",
        "    St = tf.matmul(centered, centered, transpose_a=True)\n",
        "    diff = z1 - z2\n",
        "    Sw = 0.5 * tf.matmul(diff, diff, transpose_a=True)\n",
        "    Sb = St - Sw\n",
        "\n",
        "    # Remove floating-point asymmetry.\n",
        "    Sb = (Sb + tf.transpose(Sb)) / 2.0  # Final between-class scatter matrix\n",
        "    Sw = (Sw + tf.transpose(Sw)) / 2.0  # Final within-class scatter matrix\n",
        "\n",
        "    eye_mat = penal_para * tf.eye(tf.shape(Sw)[0], dtype=Sw.dtype)\n",
        "    B = Sw + eye_mat  # A positive penal_para makes B positive definite.\n",
        "\n",
        "    temp = tf.linalg.pinv(B) @ Sb\n",
        "    temp = (temp + tf.transpose(temp)) / 2.0\n",
        "\n",
        "    # temp is symmetric, so use the symmetric eigensolver: it returns real\n",
        "    # eigenvalues already sorted in ascending order (no complex output to\n",
        "    # project with tf.math.real, no separate argsort).\n",
        "    evals = tf.linalg.eigvalsh(temp)\n",
        "    evals = evals[-n_components:]\n",
        "\n",
        "    threshold = tf.reduce_min(evals) + margin  # default of margin is margin = 0.01\n",
        "    n_eig = tf.reduce_sum(tf.cast(evals < threshold, tf.int32))\n",
        "    # With margin <= 0 nothing is strictly below the threshold; keep at least the\n",
        "    # smallest eigenvalue so the mean is never over an empty slice (NaN).\n",
        "    n_eig = tf.maximum(n_eig, 1)\n",
        "\n",
        "    return -tf.reduce_mean(evals[:n_eig])\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "1_vu0d9tHlg2"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "from scipy.linalg import eigh\n",
        "\n",
        "\n",
        "# Duplicate of the classification objective defined in the previous cell.\n",
        "cls_loss_object = tf.keras.losses.CategoricalCrossentropy(\n",
        "    from_logits=True,\n",
        "    reduction=tf.keras.losses.Reduction.NONE,\n",
        "    label_smoothing=0.0 # Ensure this is a float.\n",
        ")\n",
        "\n",
        "def get_cls_loss(labels, outputs):\n",
        "  \"\"\"Mean categorical cross-entropy of one-hot `labels` against logits `outputs`.\"\"\"\n",
        "  return tf.reduce_mean(cls_loss_object(labels, outputs))\n",
        "\n",
        "\n",
        "def repmat(A, N, M):\n",
        "    \"\"\"MATLAB-style repmat for a 2-D tensor.\n",
        "\n",
        "    Tiles `A` N times along the rows and M times along the columns, giving a\n",
        "    tensor of shape [N * rows(A), M * cols(A)].\n",
        "    \"\"\"\n",
        "    # tf.tile on the 2-D tensor itself yields repmat's block layout. An\n",
        "    # expand_dims/tile/reshape formulation interleaves rows whenever A has more\n",
        "    # than one row and M > 1 (e.g. [[1], [2]] with M=2 would give\n",
        "    # [[1, 2], [1, 2]] instead of [[1, 1], [2, 2]]).\n",
        "    return tf.tile(A, [N, M])\n",
        "\n",
        "\n",
        "def get_nfda_loss0(z1, z2, penal_para, n_components=127):\n",
        "    \"\"\"Fisher-discriminant (NFDA) contrastive loss without eigenvalue thresholding.\n",
        "\n",
        "    Row i of `z1` and row i of `z2` form \"class\" i. After L2 normalization the\n",
        "    within-class (Sw) and between-class (Sb) scatter matrices are built; the\n",
        "    loss is the negated mean of the `n_components` largest eigenvalues of the\n",
        "    symmetrized (Sw + penal_para * I)^+ Sb.\n",
        "\n",
        "    Args:\n",
        "      z1: [batch_size, dim] projections of the first view.\n",
        "      z2: [batch_size, dim] projections of the second view.\n",
        "      penal_para: ridge term added to Sw before the pseudo-inverse.\n",
        "      n_components: how many of the largest eigenvalues are averaged\n",
        "        (previously hard-coded as 128 - 1). NOTE(review): possibly meant to be\n",
        "        batch_size - 1, the number of pair-classes minus one; confirm.\n",
        "\n",
        "    Returns:\n",
        "      Scalar loss tensor.\n",
        "    \"\"\"\n",
        "    z1 = tf.math.l2_normalize(z1, -1)\n",
        "    z2 = tf.math.l2_normalize(z2, -1)\n",
        "    z = tf.concat([z1, z2], axis=0)  # [2 * batch_size, dim]\n",
        "\n",
        "    # Closed forms of the per-pair sums (each pair (a, b) contributes\n",
        "    # (a - b)(a - b)^T / 2 to Sw), computed with matmuls instead of a\n",
        "    # sequential tf.range loop over the batch:\n",
        "    #   St = sum_k (z_k - mean)(z_k - mean)^T   (total scatter)\n",
        "    #   Sw = 1/2 * sum_i (a_i - b_i)(a_i - b_i)^T\n",
        "    #   Sb = St - Sw\n",
        "    centered = z - tf.reduce_mean(z, axis=0, keepdims=True)\n",
        "    St = tf.matmul(centered, centered, transpose_a=True)\n",
        "    diff = z1 - z2\n",
        "    Sw = 0.5 * tf.matmul(diff, diff, transpose_a=True)\n",
        "    Sb = St - Sw\n",
        "\n",
        "    # Remove floating-point asymmetry.\n",
        "    Sb = (Sb + tf.transpose(Sb)) / 2.0  # Final between-class scatter matrix\n",
        "    Sw = (Sw + tf.transpose(Sw)) / 2.0  # Final within-class scatter matrix\n",
        "\n",
        "    eye_mat = penal_para * tf.eye(tf.shape(Sw)[0], dtype=Sw.dtype)\n",
        "    B = Sw + eye_mat  # A positive penal_para makes B positive definite.\n",
        "\n",
        "    temp = tf.linalg.pinv(B) @ Sb\n",
        "    temp = (temp + tf.transpose(temp)) / 2.0\n",
        "\n",
        "    # temp is symmetric, so use the symmetric eigensolver: it returns real\n",
        "    # eigenvalues already sorted in ascending order.\n",
        "    evals = tf.linalg.eigvalsh(temp)\n",
        "\n",
        "    return -tf.reduce_mean(evals[-n_components:])\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "0usbo7eUseR8"
      },
      "outputs": [],
      "source": [
        "#@title Model.\n",
        "def dense_bn_relu(units):\n",
        "  \"\"\"Bias-free Dense (L2 1e-4) -> BatchNorm -> ReLU, as one Sequential block.\"\"\"\n",
        "  block_layers = [\n",
        "      tf.keras.layers.Dense(\n",
        "          units,\n",
        "          use_bias=False,\n",
        "          kernel_regularizer=tf.keras.regularizers.l2(1e-4)),\n",
        "      tf.keras.layers.BatchNormalization(center=True, scale=True),\n",
        "      tf.keras.layers.ReLU(),\n",
        "  ]\n",
        "  return tf.keras.Sequential(block_layers)\n",
        "\n",
        "def conv2d_bn_relu(filters, kernel_size, strides):\n",
        "  \"\"\"Bias-free Conv2D (L2 1e-4) -> BatchNorm -> ReLU, as one Sequential block.\"\"\"\n",
        "  block_layers = [\n",
        "      tf.keras.layers.Conv2D(\n",
        "          filters,\n",
        "          kernel_size,\n",
        "          strides=strides,\n",
        "          use_bias=False,\n",
        "          kernel_regularizer=tf.keras.regularizers.l2(1e-4)),\n",
        "      tf.keras.layers.BatchNormalization(center=True, scale=True),\n",
        "      tf.keras.layers.ReLU(),\n",
        "  ]\n",
        "  return tf.keras.Sequential(block_layers)\n",
        "\n",
        "class ConvN(tf.keras.Model):\n",
        "  \"\"\"Small conv encoder with a projection head and a linear probe.\n",
        "\n",
        "  call() returns (y, z, pred): encoder features y, projections z (fed to the\n",
        "  contrastive loss) and classifier logits pred. The classifier sees\n",
        "  tf.stop_gradient(y), so its loss does not update the encoder.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, width_multiplier):\n",
        "    super(ConvN, self).__init__()\n",
        "    self.num_classes = 10\n",
        "    self.latent_dim = 256 * width_multiplier\n",
        "    # Integer division: `/` would make this a float (128.0), but Dense expects\n",
        "    # an integer number of units.\n",
        "    self.proj_dim = self.latent_dim // 2\n",
        "\n",
        "    self.enc = tf.keras.Sequential([\n",
        "        conv2d_bn_relu(32 * width_multiplier, 3, 2),\n",
        "        conv2d_bn_relu(64 * width_multiplier, 3, 2),\n",
        "        conv2d_bn_relu(64 * width_multiplier, 3, 2),\n",
        "        tf.keras.layers.Flatten(),\n",
        "        dense_bn_relu(self.latent_dim)\n",
        "    ])\n",
        "\n",
        "    self.proj = tf.keras.Sequential([\n",
        "        dense_bn_relu(self.latent_dim * 2),\n",
        "        tf.keras.layers.Dense(\n",
        "            self.proj_dim, use_bias=False, activation=None,\n",
        "            kernel_regularizer=tf.keras.regularizers.l2(1e-4)),\n",
        "    ])\n",
        "\n",
        "    self.classifier = tf.keras.layers.Dense(self.num_classes)\n",
        "\n",
        "  def call(self, inputs, training=None):\n",
        "    # Pass `training` by keyword so the BatchNorm layers receive it unambiguously.\n",
        "    y = self.enc(inputs, training=training)\n",
        "    z = self.proj(y, training=training)\n",
        "    pred = self.classifier(tf.stop_gradient(y))\n",
        "    return y, z, pred"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "u9F2LOzm2zy4"
      },
      "outputs": [],
      "source": [
        "#@title Define train_and_eval() for contrastive learning.\n",
        "def train_and_eval(\n",
        "    batch_size=128,\n",
        "    width_multiplier=1,\n",
        "    extra_channel_bits=10,\n",
        "    nt_xent_temp=0.1,\n",
        "    margin=0.01,\n",
        "    learning_rate=0.001,\n",
        "    epochs=10,\n",
        "    log_summary_every_n_steps=100,\n",
        "    eval_every_n_steps=100,\n",
        "    print_output=False,\n",
        "    shuffle_buffer_size=10000):\n",
        "  \"\"\"Trains ConvN with the NFDA contrastive loss on RandBits MNIST.\n",
        "\n",
        "  A linear classifier on stop-gradient encoder features is trained alongside\n",
        "  the contrastive objective and its test accuracy is tracked.\n",
        "\n",
        "  Args:\n",
        "    batch_size: positive pairs per training step; also the eval batch size.\n",
        "    width_multiplier: width multiplier for ConvN.\n",
        "    extra_channel_bits: number of hash-derived bit channels per image.\n",
        "    nt_xent_temp: despite its name, forwarded to the NFDA loss as\n",
        "      `penal_para` (the ridge term added to the within-class scatter).\n",
        "    margin: eigenvalue margin for get_nfda_loss; 0 selects get_nfda_loss0.\n",
        "    learning_rate: initial Adam learning rate, decayed linearly to 0.\n",
        "    epochs: passes over the training split; determines the step count.\n",
        "    log_summary_every_n_steps: period (in steps) of the training-metric log.\n",
        "    eval_every_n_steps: evaluation period; the last step is always evaluated.\n",
        "    print_output: whether to print the logs.\n",
        "    shuffle_buffer_size: shuffle buffer for training examples; a falsy value\n",
        "      disables shuffling.\n",
        "\n",
        "  Returns:\n",
        "    (steps, eval_accuracies): the steps at which evaluation ran and the test\n",
        "    accuracy (scalar tensor) measured at each of them.\n",
        "  \"\"\"\n",
        "  strategy = tf.distribute.MirroredStrategy()\n",
        "\n",
        "  # Load dataset.\n",
        "  builder = tfds.builder('mnist')\n",
        "  builder.download_and_prepare()\n",
        "  num_train_examples = builder.info.splits['train'].num_examples\n",
        "  num_test_examples = builder.info.splits['test'].num_examples\n",
        "\n",
        "  preprocess_train_fn, preprocess_eval_fn = get_process_fns(extra_channel_bits)\n",
        "  train_dataset = builder.as_dataset(split='train', as_supervised=True)\n",
        "  if shuffle_buffer_size:\n",
        "    # Without this the examples are visited in the same fixed order every epoch.\n",
        "    train_dataset = train_dataset.shuffle(shuffle_buffer_size)\n",
        "  train_dataset = train_dataset.repeat().map(preprocess_train_fn)\n",
        "  train_dataset = train_dataset.batch(batch_size)\n",
        "  train_iter = iter(train_dataset)\n",
        "\n",
        "  test_dataset = builder.as_dataset(split='test', as_supervised=True)\n",
        "  test_dataset = test_dataset.map(preprocess_eval_fn)\n",
        "  test_dataset = test_dataset.batch(batch_size, drop_remainder=False)\n",
        "\n",
        "  total_steps = int(num_train_examples * epochs / batch_size)\n",
        "  steps_per_epoch_test = math.ceil(num_test_examples / batch_size)\n",
        "\n",
        "  # Counts completed training steps. Starting at 0 makes `step` below equal to\n",
        "  # the number of steps taken, so the `step == total_steps` evaluation runs\n",
        "  # after the final training step rather than one step before it.\n",
        "  global_step = tf.Variable(\n",
        "      0, trainable=False, name=\"global_step\", dtype=tf.int64)\n",
        "\n",
        "  # Model, optimizer and metrics are used inside strategy.run, so create them\n",
        "  # under the strategy's scope.\n",
        "  with strategy.scope():\n",
        "    model = ConvN(width_multiplier)\n",
        "\n",
        "    lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(\n",
        "        initial_learning_rate=learning_rate,\n",
        "        decay_steps=total_steps,\n",
        "        end_learning_rate=0.)\n",
        "    optimizer = tf.keras.optimizers.Adam(lr_schedule)\n",
        "\n",
        "    # Define metrics.\n",
        "    loss_metrics = [\n",
        "        \"train_classification_loss\",\n",
        "        \"train_contrastive_loss\",\n",
        "        \"train_total_loss\",\n",
        "    ]\n",
        "    acc_metrics = [\"train_accuracy\", \"eval_accuracy\"]\n",
        "    metric_list = {s: tf.keras.metrics.Mean(name=s) for s in loss_metrics}\n",
        "    metric_list.update({\n",
        "        s: tf.keras.metrics.SparseCategoricalAccuracy(name=s)\n",
        "        for s in acc_metrics\n",
        "    })\n",
        "\n",
        "  # NOTE(review): batches are handed to strategy.run undistributed, so with\n",
        "  # more than one replica every replica would receive the same full batch.\n",
        "  # As written this is only correct on a single device.\n",
        "\n",
        "  # Step functions.\n",
        "  @tf.function\n",
        "  def train_step(iterator):\n",
        "    def step_fn(inputs):\n",
        "      images, labels = inputs  # images: [batch, 2, h, w, c]\n",
        "      labels_one_hot = tf.one_hot(labels, depth=10)\n",
        "      images_a, images_b = tf.unstack(images, num=2, axis=1)\n",
        "      with tf.GradientTape() as tape:\n",
        "        _, za, pred_a = model(images_a, training=True)\n",
        "        _, zb, _ = model(images_b, training=True)\n",
        "        if margin == 0:\n",
        "          contrastive_loss = get_nfda_loss0(za, zb, nt_xent_temp)\n",
        "        else:\n",
        "          contrastive_loss = get_nfda_loss(za, zb, nt_xent_temp, margin)\n",
        "\n",
        "        classifier_loss = get_cls_loss(labels_one_hot, pred_a)\n",
        "        wd_loss = sum(model.losses)\n",
        "        loss = classifier_loss + wd_loss + contrastive_loss\n",
        "\n",
        "      metric_list[\"train_contrastive_loss\"].update_state(contrastive_loss)\n",
        "      metric_list[\"train_classification_loss\"].update_state(classifier_loss)\n",
        "      metric_list[\"train_accuracy\"].update_state(labels, pred_a)\n",
        "      metric_list[\"train_total_loss\"].update_state(loss)\n",
        "\n",
        "      gradients = tape.gradient(loss, model.trainable_variables)\n",
        "      optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
        "\n",
        "    strategy.run(step_fn, args=(next(iterator),))\n",
        "    global_step.assign_add(1)\n",
        "\n",
        "  @tf.function\n",
        "  def eval_step(iterator):\n",
        "    def step_fn(inputs):\n",
        "      images, labels = inputs\n",
        "      _, _, predictions = model(images, training=False)\n",
        "      metric_list[\"eval_accuracy\"].update_state(labels, predictions)\n",
        "    strategy.run(step_fn, args=(next(iterator),))\n",
        "\n",
        "  # Train and eval loop.\n",
        "  steps = []\n",
        "  eval_accuracies = []\n",
        "  while global_step.numpy() < total_steps:\n",
        "    train_step(train_iter)\n",
        "    step = global_step.numpy()\n",
        "\n",
        "    if step % log_summary_every_n_steps == 0:\n",
        "      log_msg = \"Steps: {}\".format(step)\n",
        "      for m in loss_metrics + acc_metrics:\n",
        "        if m.startswith(\"train\"):\n",
        "          log_msg += \", {}: {}\".format(m, metric_list[m].result())\n",
        "          metric_list[m].reset_states()\n",
        "      if print_output:\n",
        "        print(log_msg)\n",
        "\n",
        "    if (step % eval_every_n_steps == 0) or (step == total_steps):\n",
        "      eval_iter = iter(test_dataset)\n",
        "      metric_list[\"eval_accuracy\"].reset_states()\n",
        "      for _ in range(steps_per_epoch_test):\n",
        "        eval_step(eval_iter)\n",
        "      if print_output:\n",
        "        print(\"Steps: {}, Test accuracy: {}\".format(\n",
        "            step, metric_list[\"eval_accuracy\"].result()))\n",
        "      steps.append(step)\n",
        "      eval_accuracies.append(metric_list[\"eval_accuracy\"].result())\n",
        "\n",
        "  return steps, eval_accuracies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "I3an389J6gCq",
        "outputId": "0a03ff50-9705-46c7-b74a-da503f0675fb"
      },
      "outputs": [],
      "source": [
        "# The actual training. This cell takes a long time to run, especially on a CPU.\n",
        "NUM_RUNS = 10  # independent runs per bit setting\n",
        "MARGIN = 0.01\n",
        "# Passed as `nt_xent_temp`, which train_and_eval forwards to the NFDA loss\n",
        "# as its ridge parameter `penal_para`.\n",
        "PENALTY = 15\n",
        "\n",
        "rows = []\n",
        "# One row per bit setting: mean and std of the final test accuracy over runs.\n",
        "cols = ['bits', 'mean_accuracy', 'std_accuracy']\n",
        "\n",
        "for bits in [0, 2, 4, 6, 8, 10]:\n",
        "  accuracies = []\n",
        "  for run in range(NUM_RUNS):\n",
        "    # Seed each run: reproducible, yet the runs differ from one another.\n",
        "    tf.keras.utils.set_random_seed(run)\n",
        "    _, run_accuracies = train_and_eval(\n",
        "      batch_size=128,\n",
        "      width_multiplier=1,\n",
        "      extra_channel_bits=bits,\n",
        "      nt_xent_temp=PENALTY,\n",
        "      margin=MARGIN,\n",
        "      learning_rate=0.001,\n",
        "      epochs=10,\n",
        "      eval_every_n_steps=1000,\n",
        "      print_output=False)\n",
        "\n",
        "    # Collect the final accuracy of each run\n",
        "    accuracies.append(run_accuracies[-1].numpy())\n",
        "\n",
        "  # Calculate mean and standard deviation of the accuracies\n",
        "  mean_accuracy = np.mean(accuracies)\n",
        "  std_accuracy = np.std(accuracies)\n",
        "  rows.append([bits, mean_accuracy, std_accuracy])\n",
        "\n",
        "  # Print the results for this bit setting\n",
        "  print(f\"bits={bits}, mean_accuracy={mean_accuracy}, std_accuracy={std_accuracy}\")\n",
        "\n",
        "# Create a DataFrame from the results\n",
        "plot_df = pd.DataFrame.from_records(rows, columns=cols)\n",
        "plot_df"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.14"
    },
    "vscode": {
      "interpreter": {
        "hash": "7174141f7d389c6cc80b1893a3784b426b3ad9b0a7c2e082f9b1ab8b9efe1539"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
