{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "SinglePhotonToyExample.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_Mmq1t9gCDA_"
      },
      "source": [
        "# Single-photon Classification\n",
        "\n",
        "## Toy Example"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FgVqBR5D2H8r"
      },
      "source": [
        "from google.colab import files\n",
        "import math\n",
        "import numpy as np\n",
        "import tensorflow.compat.v2 as tf\n",
        "import tensorflow.compat.v2.keras as keras\n",
        "from matplotlib import pyplot as plt\n",
        "from matplotlib.font_manager import FontProperties\n",
        "from matplotlib import cm"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dKQbNNzeOmOL"
      },
      "source": [
        "np.set_printoptions(linewidth=200)\n",
        "# Seeded generator shared by the tie-breaking noise and the H0 initialization.\n",
        "rng = np.random.RandomState(seed=2)\n",
        "\n",
        "NUM_CLASSES = 2\n",
        "# Every symbol is an image with XYDIM // XDIM rows of XDIM pixels.\n",
        "XDIM = 4\n",
        "XYDIM = 8\n",
        "# Forcing the generator to be real makes the unitary transformation real\n",
        "# (orthogonal). This does not affect the maximum achievable performance for\n",
        "# this simple example.\n",
        "USE_ORTHOGONAL = True\n",
        "\n",
        "# Symbol images for class 0 and class 1.\n",
        "c0 = np.array([[1, 1, 1, 0],\n",
        "               [0, 1, 0, 0]], dtype=np.float64)\n",
        "c1 = np.array([[0, 1, 1, 1],\n",
        "               [0, 0, 1, 0]], dtype=np.float64)\n",
        "\n",
        "# The toy problem is evaluated on its own training set.\n",
        "x_train = np.stack([c0, c1], axis=0)\n",
        "y_train = np.array([0, 1], dtype=np.float64)\n",
        "x_test, y_test = x_train, y_train\n"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vq7HSRp8i_nB"
      },
      "source": [
        "def quantum_states_from_xs(xs):\n",
        "  \"\"\"Maps non-negative images to normalized single-photon amplitude vectors.\n",
        "\n",
        "  Each image is flattened (row-major) and every pixel intensity is replaced by\n",
        "  the square root of its share of that image's total intensity, so the\n",
        "  squared entries of every returned row sum to one.\n",
        "\n",
        "  Args:\n",
        "    xs: array of shape (batch, ...) with non-negative pixel intensities.\n",
        "\n",
        "  Returns:\n",
        "    Array of shape (batch, num_pixels_per_image) with the amplitudes.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if some image has no positive total intensity, in which case\n",
        "      the normalized state is undefined (it used to become NaN silently).\n",
        "  \"\"\"\n",
        "  xs = np.asarray(xs)\n",
        "  flat = xs.reshape(xs.shape[0], -1)\n",
        "  totals = flat.sum(axis=1, keepdims=True)\n",
        "  if np.any(totals <= 0):\n",
        "    raise ValueError('Every image needs a positive total intensity.')\n",
        "  # This matches the expression under the square root in Formula (8).\n",
        "  return np.sqrt(flat / totals)\n",
        "\n",
        "xq_train = quantum_states_from_xs(x_train)\n",
        "xq_test = quantum_states_from_xs(x_test)\n",
        "\n",
        "y_train_cat = keras.utils.to_categorical(y_train, NUM_CLASSES)\n",
        "y_test_cat = keras.utils.to_categorical(y_test, NUM_CLASSES)"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ov7HK_uis4Tm"
      },
      "source": [
        "def accuracy_from_samples(samples, labels, decision):\n",
        "  \"\"\"Mean fraction of photons that a per-pixel decision rule gets right.\n",
        "\n",
        "  Args:\n",
        "    samples: (batch, rows, cols) non-negative intensities or probabilities.\n",
        "    labels: (batch,) class label of every sample.\n",
        "    decision: (rows, cols) class predicted when a photon lands on that pixel.\n",
        "\n",
        "  Returns:\n",
        "    Average over the batch of the share of each sample's total intensity that\n",
        "    falls on pixels whose decision equals the sample's label.\n",
        "  \"\"\"\n",
        "  decision_onehot = keras.utils.to_categorical(decision, NUM_CLASSES)\n",
        "  label_onehot = keras.utils.to_categorical(labels, NUM_CLASSES)\n",
        "  # Intensity (per sample) landing on pixels that vote for the true class.\n",
        "  correct_mass = np.einsum(\n",
        "      'byx,yxc,bc->b', samples, decision_onehot, label_onehot)\n",
        "  total_mass = np.einsum('byx->b', samples)\n",
        "  per_sample_accuracy = correct_mass / total_mass\n",
        "  return per_sample_accuracy.mean()\n"
      ],
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bZ5Oy7qgC4Uz"
      },
      "source": [
        "# \"Classical baseline\" performance:\n",
        "# If a photon arrives at (row, col), we guess the class whose training images\n",
        "# were brightest at that pixel.\n",
        "per_class_brightness = np.einsum('byx,bc->yxc', x_train, y_train_cat)\n",
        "# Trick: a tiny random 'epsilon' for random tie-breaking. It only changes\n",
        "# results for pixels that were dark in every single training example.\n",
        "tie_breaker = rng.uniform(low=0, high=1e-100,\n",
        "                          size=(XYDIM // XDIM, XDIM, NUM_CLASSES))\n",
        "x_train_intensity_per_cat = per_class_brightness + tie_breaker\n",
        "x_train_most_likely_pixel_cat = keras.utils.to_categorical(\n",
        "    x_train_intensity_per_cat.argmax(axis=2))"
      ],
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "v9IIvMHjvUmk",
        "outputId": "3bf1165a-c137-4c7c-c0c1-230750eaec1f",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "baseline_decision = x_train_intensity_per_cat.argmax(axis=2)\n",
        "baseline_accuracy = np.round(\n",
        "    accuracy_from_samples(x_test, y_test, baseline_decision), 4)\n",
        "print('Classical-detection baseline accuracy: %.2f%%' %\n",
        "      (100 * baseline_accuracy))"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Classical-detection baseline accuracy: 75.00%\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XJRN6wqDE_oO",
        "outputId": "081da031-92c9-40ed-9ae8-d86b09614b80",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "h0_initial = rng.normal(size=(XYDIM, XYDIM), scale=0.01)\n",
        "k_h0 = tf.keras.backend.variable(h0_initial, dtype='float64', name='H0')\n",
        "\n",
        "# Predictions fed to this loss are already probabilities, not logits.\n",
        "x_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=False)\n",
        "\n",
        "def get_loss_accuracies_p(h0, psi):\n",
        "  \"\"\"Evaluates the unitary generated by h0 on a batch of states.\n",
        "\n",
        "  The batch is assumed to hold exactly one state per class, in class order\n",
        "  (row b belongs to class b); the identity-matrix labels and the trace-based\n",
        "  accuracy below rely on that.\n",
        "\n",
        "  Args:\n",
        "    h0: real (XYDIM, XYDIM) parameter matrix. Its antisymmetric part (plus,\n",
        "      unless USE_ORTHOGONAL, i times its symmetric part) is the generator.\n",
        "    psi: (batch, XYDIM) amplitude vectors.\n",
        "\n",
        "  Returns:\n",
        "    (loss, single_photon_accuracy, many_photon_accuracy, u), where u is the\n",
        "    complex (XYDIM, XYDIM) unitary.\n",
        "  \"\"\"\n",
        "  h0_T = tf.transpose(h0)\n",
        "  mult_imag = 0.0 if USE_ORTHOGONAL else 1.0\n",
        "  # Antisymmetric real part plus i * symmetric part is anti-Hermitian, so its\n",
        "  # matrix exponential is unitary.\n",
        "  ih = tf.complex(h0 - h0_T, mult_imag * (h0 + h0_T))\n",
        "  u = tf.linalg.expm(ih)\n",
        "  # Row b of the product is U @ psi[b]; its XYDIM outputs are grouped into\n",
        "  # NUM_CLASSES detector rows, one per class.\n",
        "  u_psi = tf.reshape(\n",
        "      tf.einsum('bp,qp->bq', tf.cast(psi, tf.complex128), u),\n",
        "      (-1, NUM_CLASSES, XYDIM // NUM_CLASSES))\n",
        "  probs = tf.einsum('bcs->bc',\n",
        "                    tf.math.square(tf.math.real(u_psi)) +\n",
        "                    tf.math.square(tf.math.imag(u_psi)))\n",
        "  labels = tf.eye(NUM_CLASSES, dtype=tf.float64)\n",
        "  loss = x_loss(labels, probs)\n",
        "  # Mean probability that a single photon lands in its own class's row.\n",
        "  batch_quantum_accuracy = tf.linalg.trace(probs) / NUM_CLASSES\n",
        "  # With many photons, the most probable row wins.\n",
        "  batch_manyphoton_accuracy = tf.math.reduce_mean(\n",
        "      keras.metrics.categorical_accuracy(labels, probs))\n",
        "  return loss, batch_quantum_accuracy, batch_manyphoton_accuracy, u\n",
        "\n",
        "  \n",
        "def train_step(optimizer, k_h0, psis):\n",
        "  \"\"\"Applies one optimizer update to k_h0.\n",
        "\n",
        "  Returns the (loss, single-photon accuracy, many-photon accuracy, unitary)\n",
        "  evaluated before the update was applied.\n",
        "  \"\"\"\n",
        "  with tf.GradientTape() as tape:\n",
        "    # Trainable variables are watched automatically; explicit watch is harmless.\n",
        "    tape.watch(k_h0)\n",
        "    loss, accuracy_q, accuracy_m, u = get_loss_accuracies_p(k_h0, psis)\n",
        "  gradients = tape.gradient(loss, [k_h0])\n",
        "  optimizer.apply_gradients(zip(gradients, [k_h0]))\n",
        "  return loss, accuracy_q, accuracy_m, u\n",
        "\n",
        "\n",
        "def get_test_set_accuracies(h0):\n",
        "  \"\"\"Returns (single-photon, many-photon) accuracies of h0 on xq_test.\"\"\"\n",
        "  _, accuracy_q, accuracy_m, _ = get_loss_accuracies_p(h0, xq_test)\n",
        "  return accuracy_q, accuracy_m\n",
        "\n",
        "\n",
        "def train_a_model(num_steps=50, learning_rate=0.3, log_every=20):\n",
        "  \"\"\"Trains the global k_h0 with full-batch SGD on xq_train.\n",
        "\n",
        "  Interrupting this function will produce a valid partially-trained model,\n",
        "  since k_h0 is updated in place after every step.\n",
        "\n",
        "  Args:\n",
        "    num_steps: number of gradient steps.\n",
        "    learning_rate: SGD step size.\n",
        "    log_every: print training and test metrics every this many steps;\n",
        "      0 disables logging.\n",
        "  \"\"\"\n",
        "  optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)\n",
        "  for ep in range(num_steps):\n",
        "    loss, accuracy_q, accuracy_m, u = train_step(optimizer, k_h0, xq_train)\n",
        "    if log_every and ep % log_every == 0:\n",
        "      # The trace of U is only needed for logging.\n",
        "      tr_u = tf.linalg.trace(u)\n",
        "      print('[Epoch=%d, loss=%.3f, acc_q=%.3f, acc_m=%.3f, tr_U=%s]' %\n",
        "            (ep, loss.numpy(), accuracy_q.numpy(), accuracy_m.numpy(),\n",
        "             np.round(tr_u.numpy(), 3)))\n",
        "      print('Test Set accuracies: %.4f / %.4f' % get_test_set_accuracies(k_h0))\n",
        "\n",
        "\n",
        "train_a_model()\n",
        "\n",
        "print('Model accuracies: %.4f / %.4f' % get_test_set_accuracies(k_h0))\n"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[Epoch=0, loss=0.830, acc_q=0.501, acc_m=0.500, tr_U=(7.995+0j)]\n",
            "Test Set accuracies: 0.8348 / 1.0000\n",
            "[Epoch=20, loss=0.069, acc_q=0.933, acc_m=1.000, tr_U=(6.825+0j)]\n",
            "Test Set accuracies: 0.9330 / 1.0000\n",
            "[Epoch=40, loss=0.069, acc_q=0.933, acc_m=1.000, tr_U=(6.825+0j)]\n",
            "Test Set accuracies: 0.9330 / 1.0000\n",
            "Model accuracies: 0.9330 / 1.0000\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YSGcWDw6vD4w"
      },
      "source": [
        "def PlotSamples(samples, filename, xlabel=None, image_labels=None, figsize=(4, 2)):\n",
        "  \"\"\"Draws one grayscale panel per sample, saves the figure and downloads it.\n",
        "\n",
        "  Args:\n",
        "    samples: (num_panels, rows, cols) array or tensor of per-pixel values;\n",
        "      panel i is titled 'c = i'.\n",
        "    filename: path the figure is saved to before the Colab download.\n",
        "    xlabel: optional caption drawn centred below all panels.\n",
        "    image_labels: optional (num_panels, rows, cols) array of strings; every\n",
        "      non-empty entry is written on top of the corresponding pixel.\n",
        "    figsize: matplotlib figure size in inches.\n",
        "  \"\"\"\n",
        "  num_panels = int(samples.shape[0])\n",
        "  # squeeze=False keeps axs two-dimensional even for a single panel.\n",
        "  fig, axs = plt.subplots(1, num_panels, figsize=figsize, squeeze=False)\n",
        "  # vmax < 1 boosts brightness a bit to get clearer images.\n",
        "  norm = cm.colors.Normalize(vmax=0.8, vmin=0)\n",
        "  for i, ax in enumerate(axs[0]):\n",
        "    ax.imshow(samples[i, :, :], cmap=plt.get_cmap('gray'), norm=norm,\n",
        "              interpolation='nearest')\n",
        "    if image_labels is not None:\n",
        "      for (row, col), text in np.ndenumerate(image_labels[i]):\n",
        "        if text:\n",
        "          ax.text(col, row, text, horizontalalignment='center',\n",
        "                  verticalalignment='center')\n",
        "    ax.set_title('c = %d' % i)\n",
        "    ax.grid(True)\n",
        "    ax.set_xticks([])\n",
        "    ax.set_yticks([])\n",
        "  if xlabel:\n",
        "    # Frameless axes spanning the whole figure, used only to hold a shared caption.\n",
        "    overlay = fig.add_subplot(111, frameon=False)\n",
        "    # tick_params expects booleans: the string 'off' is truthy and does not\n",
        "    # hide the tick marks on current matplotlib versions.\n",
        "    overlay.tick_params(labelcolor='red', top=False, bottom=False, left=False, right=False)\n",
        "    overlay.grid(False)\n",
        "    overlay.set_xticklabels([])\n",
        "    overlay.set_yticklabels([])\n",
        "    overlay.set_xlabel(xlabel)\n",
        "    overlay.xaxis.set_label_coords(0.5, 0)\n",
        "  fig.savefig(filename, bbox_inches='tight')\n",
        "  files.download(filename)"
      ],
      "execution_count": 14,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Sb-v54cpAOCN"
      },
      "source": [
        "def get_transformed_images(psi, h0=None, u=None):\n",
        "  \"\"\"Per-pixel detection probabilities |U psi|^2, reshaped into images.\n",
        "\n",
        "  Args:\n",
        "    psi: (batch, XYDIM) amplitude vectors, one per row.\n",
        "    h0: optional real (XYDIM, XYDIM) parameter matrix. When given, U is built\n",
        "      from it exactly as in training (honouring USE_ORTHOGONAL) and u is ignored.\n",
        "    u: optional (XYDIM, XYDIM) matrix (real or complex), used when h0 is None.\n",
        "\n",
        "  Returns:\n",
        "    float64 tensor of shape (batch, XYDIM // XDIM, XDIM).\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if neither h0 nor u is provided.\n",
        "  \"\"\"\n",
        "  if h0 is not None:\n",
        "    h0_T = tf.transpose(h0)\n",
        "    # Must match get_loss_accuracies_p; otherwise, with USE_ORTHOGONAL set, the\n",
        "    # plotted transformation differs from the one that was trained.\n",
        "    mult_imag = 0.0 if USE_ORTHOGONAL else 1.0\n",
        "    ih = tf.complex(h0 - h0_T, mult_imag * (h0 + h0_T))\n",
        "    u = tf.linalg.expm(ih)\n",
        "  elif u is None:\n",
        "    raise ValueError('Either h0 or u must be given.')\n",
        "  u_psi = tf.reshape(\n",
        "      tf.einsum('cp,qp->cq', tf.cast(psi, tf.complex128),\n",
        "                tf.cast(u, tf.complex128)),\n",
        "      (-1, XYDIM // XDIM, XDIM))\n",
        "  return tf.math.square(tf.math.real(u_psi)) + tf.math.square(tf.math.imag(u_psi))\n",
        "\n"
      ],
      "execution_count": 15,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MrxZ7ShWU1On"
      },
      "source": [
        "# The simple suboptimal transformation presented in the paper:\n",
        "# U = sqrt(1/2) * [[-I, I], [I, I]], with I the XDIM x XDIM identity.\n",
        "\n",
        "i4 = np.eye(XDIM, dtype=np.complex128)\n",
        "u_toy = math.sqrt(1 / 2.) * np.block([[-i4, i4],\n",
        "                                       [i4, i4]])\n"
      ],
      "execution_count": 16,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cxS5aCK7Bwfn",
        "outputId": "27023ec1-7d3c-4c78-f2d7-303cd541ecee",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 103
        }
      },
      "source": [
        "# Detection probabilities (squared amplitudes) for the two symbols of the toy\n",
        "# example, i.e. under the identity transformation.\n",
        "\n",
        "identity = np.eye(XYDIM)\n",
        "xq_train_initial = get_transformed_images(xq_train, u=identity)\n",
        "PlotSamples(xq_train_initial, 'toy_samples.png', xlabel='(a)', figsize=(3, 1))"
      ],
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "application/javascript": [
              "\n",
              "    async function download(id, filename, size) {\n",
              "      if (!google.colab.kernel.accessAllowed) {\n",
              "        return;\n",
              "      }\n",
              "      const div = document.createElement('div');\n",
              "      const label = document.createElement('label');\n",
              "      label.textContent = `Downloading \"${filename}\": `;\n",
              "      div.appendChild(label);\n",
              "      const progress = document.createElement('progress');\n",
              "      progress.max = size;\n",
              "      div.appendChild(progress);\n",
              "      document.body.appendChild(div);\n",
              "\n",
              "      const buffers = [];\n",
              "      let downloaded = 0;\n",
              "\n",
              "      const channel = await google.colab.kernel.comms.open(id);\n",
              "      // Send a message to notify the kernel that we're ready.\n",
              "      channel.send({})\n",
              "\n",
              "      for await (const message of channel.messages) {\n",
              "        // Send a message to notify the kernel that we're ready.\n",
              "        channel.send({})\n",
              "        if (message.buffers) {\n",
              "          for (const buffer of message.buffers) {\n",
              "            buffers.push(buffer);\n",
              "            downloaded += buffer.byteLength;\n",
              "            progress.value = downloaded;\n",
              "          }\n",
              "        }\n",
              "      }\n",
              "      const blob = new Blob(buffers, {type: 'application/binary'});\n",
              "      const a = document.createElement('a');\n",
              "      a.href = window.URL.createObjectURL(blob);\n",
              "      a.download = filename;\n",
              "      div.appendChild(a);\n",
              "      a.click();\n",
              "      div.remove();\n",
              "    }\n",
              "  "
            ],
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "display_data",
          "data": {
            "application/javascript": [
              "download(\"download_94bff8f8-8552-47b6-a3e4-42f7db8ca710\", \"toy_samples.png\", 1644)"
            ],
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAALkAAABWCAYAAACTvzcLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAE0ElEQVR4nO3cz4tVdRzG8ffjmDNNYUK5UdSKIiFI6MciCtFo0yKaVkISRYvJbYvAf0FI2rRpiDYStJIpM0IyXUZcJNEwWsRYCpFFU5lmap8W90xdYu7MPeM59xw/93nBgbn3fuecZ773mTPfuTP3KCIwy2xV0wHM6uaSW3ouuaXnklt6Lrml55Jbei65pTcyJVfXPkk/F9s+SWo6V0aSdko6JulXSXNN5xmZkgPTwBSwDXgIeBZ4tdFEef0BvAu83nQQaFHJJW2SdFDSheJM+1bFh3gJ2B8R5yLiPLAfeLniY9wU6p7riPgiIg4A31a535UqVXJJ03WEkDQGfAScBe4GNgLv9xn7jqT5JbbNfQ7zIHCy5/bJ4r6VZq5lLurOUHKuXxhkrls/FxEx8AZ0yowvsd/HgQvA6royANeBrT237wcC0Ar3V8tc1J2hzFzfaA7gaWCu6bloy3JlE3A2Iq7VeIyLwNqe22uBi1HM0AgZxly3yrIllzQtqSOpA5ypKcf3wGZJqwcY+6Wki0ts/ZYrX9H9pXPBtuK+lZq5gc+tykoyDDzXknYPONdtmIszCz0ttv+WL03/yC1OpGN018hvALcBE8ATFR9jD91v0o3ABroF39P01550rlcV+32G7tp/AljT1NfciuVKRFyn+5LefcB3wDlgV8WHeRs4BJwCTgOHi/tGypDmejtwGfgY2Fx8fKTiYwxMxXeeWVqtOJOb1cklt/RcckvPJbf0XHJLb5A/vvxrfHw8Jicn68rSGvPz85XtKyJW9O+8kip72WvdunVV7aq1Ll26xJUrVxad61Iln5ycZMeOHZWEarPZ2dmmI1RqFJ6z48eP933MyxVLzyW39FxyS88lt/RcckvPJbf0XHJLzyW39FxyS88lt/RcckvPJbf0XHJLzyW39FxyS88lt/RKXXelynerTE1NVbWrVr/JoQ3vDKpSlc8bVPvc9ZvrQa6HN033AvZmrVZcr3PBTETMwAAlLwbOFDtp5dnFDCAiHl3sfq/JLT2X3NJzyS09l9zSc8ktPZfc0nPJLT2X3NJzyS09l9zSc8ktPZfc0nPJLT2X3NJzyS09l9zSc8ktvbLv8bwAnK0vTjpbImL9Sj7Rc11a37kuVXKzm5GXK5aeS27pueSWnktu6bnklp5LbuktW3JJ05I6xXZgGKGWy9N0BmhHjjZkgHbkkHSgp6ed3kxl/xjU6XcprmFpQ4ayOSTdCnwCPBUR1xd5fA3wafH4tToy1KkNOZbK4OXKcLwCHFys4AAR8RdwFNg11FQjwiUfjt3AB5Jul3RU0glJpyQ91zNmthhnFVv2qrb/M1NLinLakAEGzFEsRe6NiDlJq4HnI+I3SXcBn0v6MLprxtPAY3VkGII25Oibwf+7UjNJG4DPImKrpFuAN4HtwN/AA8A9EfFDMfY8sDUifm8scEJlz+RW3mVgovh4N7AeeCQirkqa63kMYBz4c7jx8vOavGYR8QswJmkCuAP4sSj4TmDLwjhJdwI/RcTVhqKm5TP5cBwBngTeAw5JOgV0gK97xuwEDjeQLT2vyYdA0sPAaxHx4hJjDgJ7I+Kb4SUbDV6uDEFEnACOSRpb7PHiFZhZF7wePpNbej6TW3ouuaXnklt6Lrml55Jbev8ACek67C2RkZEAAAAASUVORK5CYII=\n",
            "text/plain": [
              "<Figure size 216x72 with 3 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EOUFDOfIMtcS",
        "outputId": "6e7829ac-a025-4d40-df8d-86a45c249554",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 103
        }
      },
      "source": [
        "# Detection probabilities (squared amplitudes) for the optimal (learnt)\n",
        "# transformation. This is not included in the paper for lack of space.\n",
        "\n",
        "t_xq_train_optimal = get_transformed_images(xq_train, h0=k_h0).numpy()\n",
        "PlotSamples(t_xq_train_optimal, 'toy_projected_optimal.png', xlabel='(c)',\n",
        "            figsize=(3, 1))"
      ],
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "application/javascript": [
              "\n",
              "    async function download(id, filename, size) {\n",
              "      if (!google.colab.kernel.accessAllowed) {\n",
              "        return;\n",
              "      }\n",
              "      const div = document.createElement('div');\n",
              "      const label = document.createElement('label');\n",
              "      label.textContent = `Downloading \"${filename}\": `;\n",
              "      div.appendChild(label);\n",
              "      const progress = document.createElement('progress');\n",
              "      progress.max = size;\n",
              "      div.appendChild(progress);\n",
              "      document.body.appendChild(div);\n",
              "\n",
              "      const buffers = [];\n",
              "      let downloaded = 0;\n",
              "\n",
              "      const channel = await google.colab.kernel.comms.open(id);\n",
              "      // Send a message to notify the kernel that we're ready.\n",
              "      channel.send({})\n",
              "\n",
              "      for await (const message of channel.messages) {\n",
              "        // Send a message to notify the kernel that we're ready.\n",
              "        channel.send({})\n",
              "        if (message.buffers) {\n",
              "          for (const buffer of message.buffers) {\n",
              "            buffers.push(buffer);\n",
              "            downloaded += buffer.byteLength;\n",
              "            progress.value = downloaded;\n",
              "          }\n",
              "        }\n",
              "      }\n",
              "      const blob = new Blob(buffers, {type: 'application/binary'});\n",
              "      const a = document.createElement('a');\n",
              "      a.href = window.URL.createObjectURL(blob);\n",
              "      a.download = filename;\n",
              "      div.appendChild(a);\n",
              "      a.click();\n",
              "      div.remove();\n",
              "    }\n",
              "  "
            ],
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "display_data",
          "data": {
            "application/javascript": [
              "download(\"download_b74595c2-f1d6-4ae2-90e7-64404f21bb55\", \"toy_projected_optimal.png\", 1631)"
            ],
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAALkAAABWCAYAAACTvzcLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAE20lEQVR4nO3dT4hVZRzG8e8zY3GnxEXXIBQtwkIIEiSIMIKijYuoVkES/VlMboOiQFoGCc7KNk3Rxk0gCf0TMdC1cRFFw1bhlEGgMQVOo+DMr8U91iBznXuv55x75jfPBw5475z7nmdennnn1Zl7VERgltnYqAOYVc0lt/RcckvPJbf0XHJLzyW39FxyS2/NlFxd+yX9WRz7JWnUuTKS9Kykk5L+lnRx1HnWTMmBSeAlYAfwOPAC8PZIE+U1B3wBvDfqINCgkkvaIumIpMvFSvtJyZd4HZiKiEsR8TswBbxR8jVWharnOiJ+jIhDwC9ljjusgUouabKKEJLGge+AGeAhYDPwZY9zP5f0122OrT0u8xhwdsnjs8Vzw2auZC6qzjDgXL/az1w3fi4iou8D6Axy/gDjPgVcBtZVlQFYALYvefwIEICGHK+Suag6wyBzfac5gOeBi6Oei6ZsV7YAMxFxo8JrXAU2LHm8AbgaxQytIXXMdaOsWHJJk5I6kjrAhYpy/AZslbSuj3PPSLp6m6PXduUnun/pvGlH8dywpu/gtWUZJkPfcy1pT59z3YS5uHCzp8Xx//Zl1N9yi4V0nO4e+QBwL9ACdpV8jb10v0g3A5voFnzvqD/3pHM9Voy7m+7evwXcParPuRHblYhYoPtPetuAX4FLwCslX+ZT4FvgHHAe+L54bk2paa6fAeaBo8DW4s/HS75G31R85Zml1YiV3KxKLrml55Jbei65peeSW3r9/PDlP+vXr492u13KhWdnZ0sZpwpzc3OljLO4uEhEDPXrvJJibKycNWhxcbGUcZqu11wPVPJ2u82+fftKCXT48OFSxgFYWFgobSyAU6dOlTLOtWvXhn7t2NgYExMTpeS4fv16KeNU4caN6n+7wNsVS88lt/RcckvPJbf0XHJLzyW39FxyS88lt/RcckvPJbf0XHJLzyW39FxyS88lt/RcckvPJbf0BrrviqRG3qRl48aNpY535cqV0sa6k3cGlZWh1WqVNRTz8/OljQVQ5v+DMPQ7g4p7yo381rxmKynu13nTdERMg1fyZXkl7201ruTek1t6Lrml55Jbei65peeSW3ouuaXnklt6Lrml55Jbei65peeSW3ouuaXnklt6Lrml55Jbei65peeSW3qDvjPoMjBTXZx0HoyI+4d5oed6YD3neqCSm61G3q5Yei65peeSW3ouuaXnklt6Lrmlt2LJJU1K6hTHoTpCrZRn1BmgGTmakAGakUPSoSU97SzNNOgPgzoR8UQlKVdRhmFySJoAjgHPRcTCMh8/AByNiBNVZahKE3LcLoO3K/V5CziyXMELB4EPasyzZrjk9dkDfA0g6X1J5ySdlfQxQETMAG1JD4wyZEYr3rr5FtOVpBhMEzLAADkk3Q08HBEXJe0GXgSejIh/JN235NTTwC7gq7IzVKwJOXpm8O+u1EDSJuBERGyXNAX8HBGfLXPeR8AfEXGw9pCJebtSj3mgn5uEt4pzrUQueQ0iYhYYl9QCfgDelHQPwC3blUeB8yOImJpLXp/jwNMRcQz4BuhIOgO8CyDpLmAb0Ok9hA3De/KaSNoJvBMRr/X4+MvAzoj4sN5k+Xklr0lEnAZOShrvcco6YKrGSGuGV3JLzyu5peeSW3ouuaXnklt6Lrml9y9DRz8clDbwmQAAAABJRU5ErkJggg==\n",
            "text/plain": [
              "<Figure size 216x72 with 3 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rabj_Bq2ADHf",
        "outputId": "d1295cd0-9796-4bd6-a328-0b62bd09dc03",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "# Accuracy for the optimal transformation: a photon detected at a pixel is\n",
        "# assigned to the class with the larger probability at that pixel.\n",
        "\n",
        "optimal_decision = t_xq_train_optimal.argmax(axis=0)\n",
        "optimal_accuracy = np.round(\n",
        "    accuracy_from_samples(t_xq_train_optimal, y_train, optimal_decision), 4)\n",
        "print('Optimal accuracy: %.2f%%' % (100 * optimal_accuracy))"
      ],
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Optimal accuracy: 93.27%\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6gxt6LBCUM3X",
        "outputId": "453e2557-f6dd-45ed-8388-db6fb4724923",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 103
        }
      },
      "source": [
        "# Detection probabilities (squared amplitudes) for the suboptimal\n",
        "# transformation presented in the paper.\n",
        "\n",
        "t_xq_train_toy = get_transformed_images(xq_train, u=u_toy)\n",
        "PlotSamples(t_xq_train_toy, 'toy_projected_simple.png', xlabel='(b)',\n",
        "            figsize=(3, 1))"
      ],
      "execution_count": 20,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "application/javascript": [
              "\n",
              "    async function download(id, filename, size) {\n",
              "      if (!google.colab.kernel.accessAllowed) {\n",
              "        return;\n",
              "      }\n",
              "      const div = document.createElement('div');\n",
              "      const label = document.createElement('label');\n",
              "      label.textContent = `Downloading \"${filename}\": `;\n",
              "      div.appendChild(label);\n",
              "      const progress = document.createElement('progress');\n",
              "      progress.max = size;\n",
              "      div.appendChild(progress);\n",
              "      document.body.appendChild(div);\n",
              "\n",
              "      const buffers = [];\n",
              "      let downloaded = 0;\n",
              "\n",
              "      const channel = await google.colab.kernel.comms.open(id);\n",
              "      // Send a message to notify the kernel that we're ready.\n",
              "      channel.send({})\n",
              "\n",
              "      for await (const message of channel.messages) {\n",
              "        // Send a message to notify the kernel that we're ready.\n",
              "        channel.send({})\n",
              "        if (message.buffers) {\n",
              "          for (const buffer of message.buffers) {\n",
              "            buffers.push(buffer);\n",
              "            downloaded += buffer.byteLength;\n",
              "            progress.value = downloaded;\n",
              "          }\n",
              "        }\n",
              "      }\n",
              "      const blob = new Blob(buffers, {type: 'application/binary'});\n",
              "      const a = document.createElement('a');\n",
              "      a.href = window.URL.createObjectURL(blob);\n",
              "      a.download = filename;\n",
              "      div.appendChild(a);\n",
              "      a.click();\n",
              "      div.remove();\n",
              "    }\n",
              "  "
            ],
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "display_data",
          "data": {
            "application/javascript": [
              "download(\"download_572e3ba5-b761-4fad-9cf2-7d6e658f5c0c\", \"toy_projected_simple.png\", 1630)"
            ],
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAALkAAABWCAYAAACTvzcLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAE1ElEQVR4nO3cz4tVdRzG8ffT3LEZAiEiCEWbfoEUKFQIUUhRmxZBtJGSKFxMEgS16B9opZCLwEUNEcJsWhmZtpBKgoKoIRINbVNj2UqiAkNhrnxa3DM4hXO99/o99xw/87zgwP1x5pxnPvPcM1/HmauIwCyzm5oOYFY3l9zSc8ktPZfc0nPJLT2X3NJzyS29NVNy9eyT9Ee17ZOkpnNlJOkJSccl/S1psek8a6bkwCzwLLAN2Ao8A7zSaKK8/gE+AN5sOgi0qOSSNkk6JOl8daU9UPgULwH7I+JcRPwO7AdeLnyOG0Lds46IbyNiHvi55HFHNVTJJc3WEULSBHAEOAvMABuBD1fZ931Jf/XZNq9ymgeAEyvun6geGzVzLbOoO8OQs35hkFm3fhYRMfAGLAyz/xDHfQQ4D3TqygBcBrasuH8fEIBGPF4ts6g7wzCzvt4cwFPAYtOzaMtyZRNwNiK6NZ7jArB+xf31wIWoJrSGjGPWrXLNkkualbQgaQE4XVOO34DNkjoD7PuDpAt9ttWWKz/S+0fnsm3VY6Oau46PLWWUDAPPWtKuAWfdhlmcXu5ptV1ZvjT9Lbe6kE7QWyO/DdwCTAGPFj7HHnov0o3ABnoF39P055501jdVx32a3tp/CljX1OfciuVKRFym9yO9e4FfgXPAzsKneQ/4BDgJnAKOVo+tKWOa9Q7gIvApsLm6fazwOQam6pVnllYrruRmdXLJLT2X3NJzyS09l9zSG+Q/X67s3OnE5ORkkRNfunSpyHEApqamih0LymaLiJF+nVdSsR97lZxPydlAuWxLS0t0u92rznqokk9OTjIzM1Mk1JkzZ4ocByiWaVnJbG1Qcj6lZ1Mq2+Li4qrPebli6bnklp5Lbum55JaeS27pueSWnktu6bnklp5Lbum55JaeS27pueSWnktu6bnklp5Lbum55JbeUH80UdLBgweLHWvv3r3FjmX9lfy6wXi+doO8H94svTewp9Np7DVh1le326V6v85lcxExBwOUvNpxDmB6etpvt2Wt1Ol0WFpaevhqz3lNbum55JaeS27pueSWnktu6bnklp5Lbum55JaeS27pueSWnktu6bnklp5Lbum55JaeS27pueSWnktu6Sli8D/2kXQeOFtfnHTujIjbR/lAz3poq856qJKb3Yi8XLH0XHJLzyW39FxyS88lt/RcckvvmiWXNCtpodrmxxHqWnmazgDtyNGGDNCOHJLmV/R04T+ZImLgDVgYZv86tjZkGDYHMA18CTwJHFlln8+AW7PPookMXq6Mx27gEHC5zz7zwKvjibO2uOTjsQv4uLq9XtJRST9JelfS8tfgMPB8M/FyG7bkc7WkGE4bMsCAOSStA+6OiMXqoe3Aa8D9wD3AcwAR8Sdws6TbSmcYgzbkWDWDf3elZpI2AF9ExBZJjwNvRcSO6rndwNaIeL26/zWwJyJONhY4IS9X6ncRmFpx//9XlZX3p6r9rSCXvGbVMmRC0nLRt0u6q1qL7wS+ApAk4A5gsZGgibnk43EMeKy6/R1wADgN/AJ8VD3+EPBNRHTHHy83r8nHQNKDwBsR8WKffd4BDkfE5+NLtjb4Sj4GEfE9cFzSRJ/dTrng9fCV3NLzldzSc8ktPZfc0nPJLT2X3NL7F24NFLQNgPOPAAAAAElFTkSuQmCC\n",
            "text/plain": [
              "<Figure size 216x72 with 3 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YX_5Sv9cYem0",
        "outputId": "96bbcd60-dc80-4b18-d491-80696b106085",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 150
        }
      },
      "source": [
        "# Magnitudes of the probability amplitudes for the suboptimal transformation,\n",
        "# written on top of uniformly bright panels.\n",
        "\n",
        "# Raw strings: '\\s' is not a valid Python escape sequence and triggers a\n",
        "# DeprecationWarning (SyntaxWarning on Python 3.12+) in a normal literal.\n",
        "frac_8 = r'$1/\\sqrt{8}$'\n",
        "frac_2 = r'$1/\\sqrt{2}$'\n",
        "labels = np.array([[[frac_8, '$0$', frac_8, '$0$'], [frac_8, frac_2, frac_8, '$0$']],\n",
        "                   [['$0$', frac_8, '$0$', frac_8], ['$0$', frac_8, frac_2, frac_8]]])\n",
        "PlotSamples(np.ones((NUM_CLASSES, XYDIM // XDIM, XDIM)), 'toy_projected_values_simple.png',\n",
        "            '(b)', image_labels=labels, figsize=(6, 2))"
      ],
      "execution_count": 21,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "application/javascript": [
              "\n",
              "    async function download(id, filename, size) {\n",
              "      if (!google.colab.kernel.accessAllowed) {\n",
              "        return;\n",
              "      }\n",
              "      const div = document.createElement('div');\n",
              "      const label = document.createElement('label');\n",
              "      label.textContent = `Downloading \"${filename}\": `;\n",
              "      div.appendChild(label);\n",
              "      const progress = document.createElement('progress');\n",
              "      progress.max = size;\n",
              "      div.appendChild(progress);\n",
              "      document.body.appendChild(div);\n",
              "\n",
              "      const buffers = [];\n",
              "      let downloaded = 0;\n",
              "\n",
              "      const channel = await google.colab.kernel.comms.open(id);\n",
              "      // Send a message to notify the kernel that we're ready.\n",
              "      channel.send({})\n",
              "\n",
              "      for await (const message of channel.messages) {\n",
              "        // Send a message to notify the kernel that we're ready.\n",
              "        channel.send({})\n",
              "        if (message.buffers) {\n",
              "          for (const buffer of message.buffers) {\n",
              "            buffers.push(buffer);\n",
              "            downloaded += buffer.byteLength;\n",
              "            progress.value = downloaded;\n",
              "          }\n",
              "        }\n",
              "      }\n",
              "      const blob = new Blob(buffers, {type: 'application/binary'});\n",
              "      const a = document.createElement('a');\n",
              "      a.href = window.URL.createObjectURL(blob);\n",
              "      a.download = filename;\n",
              "      div.appendChild(a);\n",
              "      a.click();\n",
              "      div.remove();\n",
              "    }\n",
              "  "
            ],
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "display_data",
          "data": {
            "application/javascript": [
              "download(\"download_347afaf1-876f-4c08-a3e1-b16381b00a96\", \"toy_projected_values_simple.png\", 3708)"
            ],
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAACFCAYAAABlnyLJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAANF0lEQVR4nO3dX2xUZ3rH8d+b2O72gj+lxsSMcc10wHLGwcKYRRUVktUV8UVBKwNaXFJtRSSKFKGqEb1NFSmV2IsVicpFiqJyEaRyVeSwFNiCHJQaEdcGBRZMC6yd2uaipoIGGIw99OkFg8cT28NMPGfeM/b3I43kc3zw++jhmR+HM+M5zswEACi+13wXAAALFQEMAJ4QwADgSV4B7JzbF1QhpYZepNGLNHqRRi/SZutFvmfANDSNXqTRizR6kUYv0goSwACAAil71QGpU+eX6d0fbDkl5ajvAkKEXqTRizR6kdbvnOudsn3UzI463gcMAH5wCcIz98IvnHP/k3r8wjnnfNcFzIVzrtU51+Wc+1/n3KDvesKKAPZvn6SfSmqStE7SNkl/6bUiYO6eSPpHSX/ju5AwI4BfwTm3yjn3z8650dQZ6pECL/FzSb80s2EzG5H0S0l/UeA1gAxBz7WZ9ZjZ55J+W8ifO98QwFk4516X9CtJ30qqkxSRdGKWY//MOfcwy6N2lmXikr6Zsv1Nah8QiCLNNXLAi3BZOOf+SNIXkqrNLBnQGs8lxc3sVmp7jaT/lPSa8ZeDABRjrqes9RNJn5lZXZDrlCrOgLNbJenbgIf0saTFU7YXS3pM+CJAxZhr5IAAzm5IUq1zLpf3S+9xzj3O8pjtv2o39OIFuJeaUvuAoBRjrpEDLkFkkbpWdkXSv0r6W0nPJW0ws+4CrrFf0l9J+okkS63192b2aaHWAKYq0ly/JqlCUqukTyXVS/o/Mxsv1BrzAWfAWZjZc714W1hM0n9JGpb0swIv8w+STkm6Luk3kk6n9gGBKNJcb5H0VNK/SKpNff3rAq9R8jgDBgBPOAMGAE8IYADwhAAGAE8IYADw5JXvA5yqsrLS6urqAioFC93g4KDu379f9E+CY64RpGxznVcA19XVqbe399UHAj9AS0uLl3WZawQp21xzCQIAPCGAAcATAhgAPCGAAcCTVwawc26fc67XOdc7OjpajJoAYF6ZmqOpxz4ph3dBmNlRpW4v3dLSwgdHAECepuboVFyCAABPCGAA8IQABgBPCGAA8CSvX0UOgnO5/+o/Hx5PvzB3zFB+guyX1zPga9eu6ebNmzKznB4LHf3CXDFD+Qm6X14D+NatW2poaPBZQkmhX5grZig/QffLawDzL2x+6BfmihnKT9D9KlgA7927V1VVVWpsbJz2vf3796u7O/OO16Ojo6qqqsrYd/jwYcXjcTU2Nqqjo0NjY2OFKu+Vzp49q/r6esViMR06dCjw9ejX/MMM5Yd+STlf2zAzbdiwwWZz8eJF6+vrs3g8Pu17TU1NlkwmM/adPHnSxsfHJ7eHh4etrq7OEomEmZnt2rXLjh07Nut6hZRMJi0ajdrdu3ft2bNntm7dOrtx40aga9Kv6VLzlddMFuKRba5zxQzlZyH1K9tcF+wMeMuWLVq2bNm0/f39/Vq7dq0ePXqkAwcOTO6fmJhQeXl5xrHJZFJPnz5VMplUIpHQypUrC1VeVj09PYrFYopGo6qoqNDu3bvV2dkZ6Jr0a35hhvJDv14I/BrwmTNn1NbWpqVLlyoWi+n69euamJhQRUVFxnGRSEQHDx5UbW2tqqurtWTJEm3dujXo8iRJIyMjWrVq1eR2TU2NRkZGirL299Gv0hSmnjBD+fHZr8AD+Ny5c2pra5Mkbdu2TadOnVJ3d7c2b96ccdyDBw/U2dmpgYEB3bt3T0+ePNHx48eDLi906BfmihnKj89+BRrAiURCDx8+nDxNj0ajunPnjk
ZHR1VZWZlx7Pnz57V69WotX75c5eXlam9v16VLl4Isb1IkEtHQ0NDk9vDwsCKRSFHWnop+la6w9IQZyo/vfgUawF1dXWptbc3Yt2LFCiUSiWnH1tbW6vLly0okEjIzXbhwoWjvV9y4caNu376tgYEBjY+P68SJE9q+fXtR1p6KfpWusPSEGcqP734V7FeROzo69OWXX+r+/fuqqanRhx9+qKtXr2rnzp0Zx7W3t2vRokXT/vymTZu0c+dONTc3q6ysTOvXr9e+ffsKVV5WZWVlOnLkiN5++209f/5ce/fuVTweD3RN+jW/MEP5oV8vOMvjjcYtLS2Wz+27m5ub9fXXX097JREzW+j9amlpUW9vb+6/eF+4dfOa6zBb6DOUr2L0K9tcB/phPFeuXAnyx8879AtzxQzlx3e/+DhKAPCEAAYAT7grMgAEjLsiA4Anxl2RASBcCGAA8IQABgBPCGAA8IQABgBPCGAA8IQABgBPCGAA8IQABgBPCGAA8IQABgBPCGAA8IQABgBPCGAA8IQABgBPCGAA8IQABgBPCGAA8IQABgBPuCknAAQstDfldM7lfKxZ8e4JSl2Yr8I6Q/O5rlDelPPatWu6efOmzCynB3WFsy6UjrDO0EKty2sA37p1Sw0NDT5LmBF1Yb4K6wwt1Lq8BnBYz9KoC/NVWGdoodZVsADeu3evqqqq1NjYOO17+/fvV3d3d8a+0dFRVVVVZew7fPiw4vG4Ghsb1dHRobGxMe91DQ0NqbW1VW+++abi8bg++eSTOddUiLqkYPqVq7Nnz6q+vl6xWEyHDh0q2rphVuyehHWGeM7lIddrG2amDRs22GwuXrxofX19Fo/Hp32vqanJkslkxr6TJ0/a+Pj45Pbw8LDV1dVZIpEwM7Ndu3bZsWPHZl0vV3Ot6969e9bX12dmZt99952tWbPGbty44b2uoPqVi2QyadFo1O7evWvPnj2zdevWFaQnqfnKayYL8cg217kKqifZhHWGeM5lyjbXBTsD3rJli5YtWzZtf39/v9auXatHjx7pwIEDk/snJiZUXl6ecWwymdTTp0+VTCaVSCS0cuVK73VVV1erublZkrRo0SI1NDRoZGTEe11SMP3KRU9Pj2KxmKLRqCoqKrR79251dnYWZe2w8tGTsM4Qz7ncBX4N+MyZM2pra9PSpUsVi8V0/fp1TUxMqKKiIuO4SCSigwcPqra2VtXV1VqyZIm2bt3qva6pBgcHdfXqVW3atMl7XcXu11QjIyNatWrV5HZNTU1BniClLEw9CesM8ZybLvAAPnfunNra2iRJ27Zt06lTp9Td3a3NmzdnHPfgwQN1dnZqYGBA9+7d05MnT3T8+HHvdb30+PFj7dixQx9//LEWL17sva5i9wulI6wzxHNuukADOJFI6OHDh5On6dFoVHfu3NHo6KgqKyszjj1//rxWr16t5cuXq7y8XO3t7bp06ZL3uqQX/xXZsWOH9uzZo/b29kBqyreuYvbr+yKRiIaGhia3h4eHFYlEirJ2WIWlJ2GdIZ5zMws0gLu6utTa2pqxb8WKFUokEtOOra2t1eXLl5VIJGRmunDhQmDvv8unLjPTu+++q4aGBr3//vuB1PND6ipmv75v48aNun37tgYGBjQ+Pq4TJ05o+/btRVk7rMLSk7DOEM+5Wcz26txMj2yvFu/evdveeOMNKysrs0gkYp999pm999571tXVlXFcT0+P9ff3z/gzPvjgA6uvr7d4PG7vvPOOjY2NvfIVxleZa11fffWVSbK33nrLmpqarKmpyU6fPu29LrNg+pWr06dP25o1aywajdpHH31UkJ9Zyu+CMAumJ9mEdYZ4zmXKNteBDur69esz3sYRFtQVTqUewGEQ1hlayHVlm+tXfhjPXFy5ciXIH/+DURfmq7DOEHXNjM8DBgBPCGAA8IQABgBPCGAA8IQABgBPCGAA8IQABgBPuCsyAAQstHdFBoD5zsJ4V2QAWMgIYADwhAAGAE+cWe
6XdZ1zo5K+Da4cLHB/YGbLi70oc42AzTrXeQUwAKBwuAQBAJ4QwADgCQEMAJ4QwADgCQEMAJ4QwADgCQEMAJ4QwADgCQEMAJ4QwADgCQEMAJ4QwADgCQEMAJ4QwADgCQEMAJ7kdVdk59znxSiqFLy8qymK3wvn3O865y465/7EOferWY4575z7vWLWlVqXuUihF2nOuc9nuityvnfE6DWzlsCqLCH0Iq3YvXDOvacXd/T+RtJBM/vTGY75uaQaM/u7YtWVWpe5SKEXabP1gksQKEV7JHWmvl7snDvtnPsP59ynzrmXM/2FpA4/5QG5IYBRUpxzFZKiZjaY2vVjSQckvSnpDyW1S5KZPZD0O8653/dRJ5CLfAP4aCBVlCZ6kVbMXlRKejhlu8fMfmtmzyX9k6Q/nvK9/5a0soi1SczFVPQibcZe5BXAZkZDU+hFWpF78VTSj6Yu//1ypnz9o9TxRcNcpNGLtNl6wSUIlJTUpYXXnXMvQ/jHzrnVqWu/P5P0b5LknHOS3pA06KVQIAcEMErRr5W+1PDvko5I6pc0IOlkav8GSZfNLFn88oDc5PU2NCAMnHPNkv7azP48yzGfSPrCzC4UrzIgP5wBo+SY2RVJXc6517Mc9hvCF2HHGTAAeMIZMAB4QgADgCcEMAB4QgADgCcEMAB4QgADgCf/D7Jmra9H4+MwAAAAAElFTkSuQmCC\n",
            "text/plain": [
              "<Figure size 432x144 with 3 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3StCIJEQaLdK",
        "outputId": "9a8b3fa8-8ff9-4905-8e85-b611ec35cea6",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "toy_accuracy = np.round(accuracy_from_samples(\n",
        "    t_xq_train_toy, y_train, np.argmax(t_xq_train_toy, axis=0)), 4)\n",
        "print('Toy example accuracy: %.2f%%' % (100 * toy_accuracy))"
      ],
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Toy example accuracy: 87.50%\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yW_5udoNih8G"
      },
      "source": [
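        "The maximum achievable accuracy for distinguishing the two equally likely class states (the Helstrom bound) can be computed analytically:\n",
        "\n",
        "$a_M=\\left(\\sqrt{1-\\cos^2\\alpha} + 1\\right)/2,$\n",
        "\n",
        "where $\\cos\\alpha$ is the inner product of the two unit-norm class states `xq_train[0]` and `xq_train[1]` (real-valued here because the transformation is restricted to be orthogonal)."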
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kCULFsrIa9g5",
        "outputId": "ffde9363-5d42-48a8-b00d-a1de49ecf4ab",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "cos_a = np.dot(xq_train[0], xq_train[1])\n",
        "max_accuracy = (np.sqrt(1 - cos_a ** 2) + 1) / 2\n",
        "print('Max possible accuracy %.2f%%' % (100 * max_accuracy))"
      ],
      "execution_count": 23,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Max possible accuracy 93.30%\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}