{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "SinglePhotonFigures.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python2",
      "display_name": "Python 2"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2E3uzJVKjxsB"
      },
      "source": [
        "# Single-photon Classification\n",
        "\n",
        "## Paper figures (see SinglePhotonToyExample.ipynb for the content of Fig. 1)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dKQbNNzeOmOL"
      },
      "source": [
        "from colorsys import hls_to_rgb\n",
        "from google.colab import files\n",
        "import math\n",
        "from matplotlib import pyplot as plt\n",
        "from matplotlib.font_manager import FontProperties\n",
        "from matplotlib import cm\n",
        "import numpy as np\n",
        "from scipy.linalg import expm\n",
        "import tensorflow.compat.v2 as tf\n",
        "import tensorflow.compat.v2.keras as keras"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Tr3Jq7P29vQR"
      },
      "source": [
        "# Wide print-outs so that large matrices stay readable.\n",
        "np.set_printoptions(linewidth=200)\n",
        "# Seeded generator shared by the notebook (used for random tie-breaking).\n",
        "rng = np.random.RandomState(2)\n",
        "\n",
        "# Exponentiating the Lie algebra element to obtain the SU(790) \"de-mixing\"\n",
        "# basis transformation is the computationally expensive part. The result is\n",
        "# applied uniformly to all items in a batch and the rest of the computation\n",
        "# is cheap, which suggests going with large batch sizes.\n",
        "BATCH_SIZE = 8192\n",
        "NUM_CLASSES = 10\n",
        "XDIM = 28  # Image side length in pixels.\n",
        "XYDIM = XDIM * XDIM  # 784 pixels per image.\n",
        "XYDIM_EXT = 790  # Zero-padded state size; a multiple of NUM_CLASSES.\n",
        "\n",
        "DATASET_MNIST = tf.keras.datasets.mnist\n",
        "DATASET_FASHION = tf.keras.datasets.fashion_mnist"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "65v25La5-hU7"
      },
      "source": [
        "def UploadFile(user_message):\n",
        "  \"\"\"Asks for one .npy upload in Colab and returns it as a float64 matrix.\n",
        "\n",
        "  Keeps prompting until exactly one `.npy` file was uploaded that can be\n",
        "  loaded and has shape (XYDIM_EXT, XYDIM_EXT).\n",
        "\n",
        "  Args:\n",
        "    user_message: Text printed once before the first upload prompt.\n",
        "\n",
        "  Returns:\n",
        "    The uploaded array converted to np.float64.\n",
        "  \"\"\"\n",
        "  print(user_message)\n",
        "  while True:\n",
        "    uploaded_files = files.upload()\n",
        "    # Require exactly one file: the previous `> 2` test silently accepted\n",
        "    # two uploads (and used the first), contradicting the message below.\n",
        "    if (len(uploaded_files) != 1 or\n",
        "        not all(name.endswith('.npy') for name in uploaded_files)):\n",
        "      print('Please only upload exactly one .npy file')\n",
        "      continue\n",
        "    try:\n",
        "      # The upload is read back by file name.\n",
        "      uploaded_h0 = np.load(list(uploaded_files)[0]).astype(np.float64)\n",
        "    except Exception as e:\n",
        "      print('Parse error:', e)\n",
        "      continue\n",
        "    if uploaded_h0.shape != (XYDIM_EXT, XYDIM_EXT):\n",
        "      print('Matrix size mismatch.')\n",
        "      continue\n",
        "    return uploaded_h0"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Lq7RSKZj_ovi"
      },
      "source": [
        "# Upload the trained models used to visualize confusion matrices and\n",
        "# projectors:\n",
        "#   Fashion-MNIST: fashion_mnist_trained_U790.npy\n",
        "#   MNIST:         mnist_trained_U790.npy\n",
        "\n",
        "mnist_h0 = UploadFile(user_message='Upload model for MNIST dataset.')\n",
        "fashion_mnist_h0 = UploadFile(\n",
        "    user_message='Upload model for Fashion-MNIST dataset.')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "86_Sq--ZhKHI"
      },
      "source": [
        "def brightness_normalized_examples(examples):\n",
        "  \"\"\"Rescales every example so that its pixel brightness sums to 1.0.\n",
        "\n",
        "  Both the classical and the quantum setup only look at the first photon\n",
        "  that passes the image filter. A 'bright' digit (like an '8') yields its\n",
        "  first photon sooner than an equally illuminated 'dark' digit (like a '1'),\n",
        "  but per example only the probability of that photon coming from each\n",
        "  pixel matters, so these probabilities have to sum to 1.\n",
        "\n",
        "  Args:\n",
        "    examples: Array indexed as [batch, y, x].\n",
        "\n",
        "  Returns:\n",
        "    Array of the same shape, each example divided by its total brightness.\n",
        "  \"\"\"\n",
        "  total_brightness = np.einsum('byx->b', examples)\n",
        "  return examples / total_brightness[:, np.newaxis, np.newaxis]\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vq7HSRp8i_nB"
      },
      "source": [
        "def ReadSets(dataset, num_classes):\n",
        "  \"\"\"Loads a dataset as brightness-normalized images plus one-hot labels.\n",
        "\n",
        "  Args:\n",
        "    dataset: Object whose load_data() returns\n",
        "      ((x_train, y_train), (x_test, y_test)).\n",
        "    num_classes: Width of the one-hot label encoding.\n",
        "\n",
        "  Returns:\n",
        "    Tuple (x_train, y_train_cat, x_test, y_test_cat).\n",
        "  \"\"\"\n",
        "  train_split, test_split = dataset.load_data()\n",
        "  images = []\n",
        "  one_hot_labels = []\n",
        "  for images255, labels in (train_split, test_split):\n",
        "    images.append(brightness_normalized_examples(images255.astype(float)))\n",
        "    one_hot_labels.append(keras.utils.to_categorical(labels, num_classes))\n",
        "  return images[0], one_hot_labels[0], images[1], one_hot_labels[1]\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "m7zOdHKcM-Hi"
      },
      "source": [
        "\n",
        "def MostLikelyCategory(x_train, y_train_cat):\n",
        "  \"\"\"Classical baseline: most likely class for each photon arrival pixel.\n",
        "\n",
        "  If a photon arrives at (row, col), the best guess is the class that\n",
        "  contributed the most brightness to that pixel over the given examples.\n",
        "\n",
        "  Args:\n",
        "    x_train: Examples indexed as [batch, y, x].\n",
        "    y_train_cat: One-hot labels indexed as [batch, class].\n",
        "\n",
        "  Returns:\n",
        "    Tuple (most_likely_cat, max_intensity), both indexed as [y, x]: the\n",
        "    winning class per pixel and the summed brightness of that class there\n",
        "    (a weight, not a normalized probability).\n",
        "  \"\"\"\n",
        "  x_train_intensity_per_cat = np.einsum('byx,bc->yxc', x_train, y_train_cat)\n",
        "\n",
        "  # Trick: We add a tiny randomized-for-random-tie-breaking 'epsilon'\n",
        "  # to the brightnesses that only changes results for pixels that were\n",
        "  # dark in every single training example. Its shape follows the data\n",
        "  # instead of a hard-coded (28, 28, 10).\n",
        "  tie_breaker = rng.uniform(\n",
        "      low=0, high=1e-100, size=x_train_intensity_per_cat.shape)\n",
        "  x_train_most_likely_pixel_cat = (\n",
        "      x_train_intensity_per_cat + tie_breaker).argmax(axis=2)\n",
        "  x_train_prob = x_train_intensity_per_cat.max(axis=2)\n",
        "  return (x_train_most_likely_pixel_cat, x_train_prob)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4-b65wzs8fUt"
      },
      "source": [
        "def Confusion(x_test, y_test_cat, x_train_most_likely_pixel_cat):\n",
        "  \"\"\"Confusion matrix of the per-pixel lookup classifier.\n",
        "\n",
        "  Args:\n",
        "    x_test: Examples indexed as [batch, y, x].\n",
        "    y_test_cat: One-hot labels indexed as [batch, class].\n",
        "    x_train_most_likely_pixel_cat: Predicted class per pixel, indexed [y, x].\n",
        "\n",
        "  Returns:\n",
        "    Matrix indexed as [true class, predicted class]. Callers read its trace\n",
        "    as the accuracy.\n",
        "  \"\"\"\n",
        "  per_sample_total = np.einsum('byx->b', x_test)\n",
        "  x_test_unit = x_test / per_sample_total.reshape([-1, 1, 1])\n",
        "  # Brightness per pixel and true class, averaged over all examples.\n",
        "  p_pixel_given_class = (\n",
        "      np.einsum('byx,bc->yxc', x_test_unit, y_test_cat) / x_test.shape[0])\n",
        "  # One-hot version of the per-pixel prediction. NOTE(review): the number of\n",
        "  # classes is left to to_categorical here - confirm every class occurs.\n",
        "  p_est_class_given_pixel = keras.utils.to_categorical(\n",
        "      x_train_most_likely_pixel_cat)\n",
        "  return np.einsum('yxc,yxd->dc', p_est_class_given_pixel, p_pixel_given_class)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "flRIPxhuNwTJ"
      },
      "source": [
        "def quantum_states_from_xs(xs):\n",
        "  \"\"\"Turns images into zero-padded amplitude vectors.\n",
        "\n",
        "  Each image is divided by its total brightness, the square root is taken\n",
        "  per pixel, and the flattened result is padded with zeros on the right by\n",
        "  XYDIM_EXT - XYDIM entries.\n",
        "\n",
        "  Args:\n",
        "    xs: Images indexed as [batch, y, x].\n",
        "\n",
        "  Returns:\n",
        "    Array indexed as [batch, padded flattened pixel].\n",
        "  \"\"\"\n",
        "  totals = np.einsum('byx->b', xs)\n",
        "  amplitudes = np.sqrt(xs / totals[:, np.newaxis, np.newaxis])\n",
        "  flat_amplitudes = amplitudes.reshape(xs.shape[0], -1)\n",
        "  pad_width = ((0, 0), (0, XYDIM_EXT - XYDIM))\n",
        "  return np.pad(flat_amplitudes, pad_width,\n",
        "                'constant', constant_values=(0, 0))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vWYWlqKIMnvg"
      },
      "source": [
        "def parameters_to_unitary(h0):\n",
        "  \"\"\"Maps a real square parameter matrix to a unitary matrix.\n",
        "\n",
        "  Builds A = (h0 - h0^T) + i * (h0 + h0^T), which is anti-Hermitian for\n",
        "  real h0, and returns its matrix exponential.\n",
        "\n",
        "  Args:\n",
        "    h0: Real square matrix of parameters.\n",
        "\n",
        "  Returns:\n",
        "    complex128 matrix expm(A) of the same shape as h0.\n",
        "  \"\"\"\n",
        "  transposed = np.transpose(h0)\n",
        "  generator = np.empty_like(h0, dtype=np.complex128)\n",
        "  generator.real = h0 - transposed  # Antisymmetric real part.\n",
        "  generator.imag = h0 + transposed  # Symmetric imaginary part.\n",
        "  return expm(generator)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qApxP8LSNaj3"
      },
      "source": [
        "def ConfusionOurs(x_test, y_test_cat, h0):\n",
        "  \"\"\"Confusion matrix of the unitary-transform classifier given by h0.\n",
        "\n",
        "  Args:\n",
        "    x_test: Examples indexed as [batch, y, x].\n",
        "    y_test_cat: One-hot labels indexed as [batch, class].\n",
        "    h0: Parameter matrix passed to parameters_to_unitary.\n",
        "\n",
        "  Returns:\n",
        "    Matrix indexed as [true class, predicted class].\n",
        "  \"\"\"\n",
        "  unitary = parameters_to_unitary(h0)\n",
        "  states = quantum_states_from_xs(x_test).astype(np.complex128)\n",
        "  transformed = np.einsum('bp,qp->bq', states, unitary)\n",
        "  # Split the output amplitudes into NUM_CLASSES consecutive, equally sized\n",
        "  # slots; the probability of a class is the squared magnitude in its slot.\n",
        "  per_class = np.reshape(\n",
        "      transformed, (-1, NUM_CLASSES, XYDIM_EXT // NUM_CLASSES))\n",
        "  squared_magnitude = (\n",
        "      np.square(np.real(per_class)) + np.square(np.imag(per_class)))\n",
        "  probs = np.einsum('bcs->bc', squared_magnitude)\n",
        "  return np.einsum('bc,bd->dc', probs, y_test_cat) / x_test.shape[0]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vuonexr6Umtg"
      },
      "source": [
        "def PlotPixelMap(x_most_likely_pixel_cat, prob, filename, xlabel=None):\n",
        "  \"\"\"Draws the per-pixel class map as a grid of labels, saves and downloads it.\n",
        "\n",
        "  Args:\n",
        "    x_most_likely_pixel_cat: Class per pixel, indexed as [row, col].\n",
        "    prob: Weight per pixel, same indexing; pixels with weight <= 0 stay blank.\n",
        "    filename: Path the figure is saved to before it is offered for download.\n",
        "    xlabel: Optional caption shown below the map.\n",
        "  \"\"\"\n",
        "  plt.rcParams['axes.facecolor'] = 'white'\n",
        "  plt.figure(figsize=(6, 6))\n",
        "  num_rows, num_cols = x_most_likely_pixel_cat.shape\n",
        "  # x runs over columns and y over rows (inverted, so row 0 is on top). The\n",
        "  # limits used to be swapped, which only went unnoticed on square maps.\n",
        "  plt.axis([-1, num_cols, num_rows, -1])\n",
        "  label_font = FontProperties()\n",
        "  label_font.set_size('xx-small')\n",
        "  for (r, c), value in np.ndenumerate(x_most_likely_pixel_cat):\n",
        "    if prob[r, c] > 0:\n",
        "      plt.text(c, r, str(value), horizontalalignment='center',\n",
        "               verticalalignment='center', fontproperties=label_font)\n",
        "  ax = plt.gca()\n",
        "  # Removing the ticks also removes their labels.\n",
        "  ax.set_xticks([])\n",
        "  ax.set_yticks([])\n",
        "  if xlabel:\n",
        "    ax.set_xlabel(xlabel)\n",
        "  plt.savefig(filename, bbox_inches='tight')\n",
        "  files.download(filename)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "w9dDcSrzC3II"
      },
      "source": [
        "def PlotConfusionMatrix(p_confusion, filename, xlabel=None):\n",
        "  \"\"\"Shows a confusion matrix as an image, then saves and downloads it.\n",
        "\n",
        "  Args:\n",
        "    p_confusion: 2-D matrix indexed as [true class, predicted class].\n",
        "    filename: Path the figure is saved to before it is offered for download.\n",
        "    xlabel: Optional caption shown below the matrix.\n",
        "  \"\"\"\n",
        "  fig = plt.figure(figsize=(6, 6))\n",
        "  plt.rcParams['axes.facecolor'] = 'white'\n",
        "  ax = fig.gca()\n",
        "  ax.grid(False)\n",
        "  # One tick per class, taken from the matrix instead of a hard-coded 10.\n",
        "  num_true, num_predicted = p_confusion.shape\n",
        "  ax.set_xticks(range(num_predicted))\n",
        "  ax.set_yticks(range(num_true))\n",
        "  ax.set_ylabel('True')\n",
        "  # The title doubles as the label of the (predicted) column axis.\n",
        "  ax.set_title('Predicted')\n",
        "  if xlabel:\n",
        "    ax.set_xlabel(xlabel)\n",
        "  # Fixed color scale so that different matrices are directly comparable.\n",
        "  im = plt.imshow(p_confusion, interpolation='none',\n",
        "                  norm=plt.Normalize(vmin=0, vmax=0.06))\n",
        "  fig.colorbar(im, ax=ax, ticks=[0, 0.012, 0.024, 0.036, 0.048, 0.060],\n",
        "               shrink=0.8)\n",
        "  plt.savefig(filename, bbox_inches='tight')\n",
        "  files.download(filename)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EKopaAbpMLk_"
      },
      "source": [
        "def ComputeProjectors(h0):\n",
        "  \"\"\"Returns the per-class projectors u^dagger * P_c * u.\n",
        "\n",
        "  P_c is the diagonal 0/1 matrix selecting the XYDIM_EXT // NUM_CLASSES\n",
        "  consecutive basis states assigned to class c, and u is the unitary\n",
        "  obtained from h0.\n",
        "\n",
        "  Args:\n",
        "    h0: Parameter matrix passed to parameters_to_unitary.\n",
        "\n",
        "  Returns:\n",
        "    Complex array indexed as [class, row, col].\n",
        "  \"\"\"\n",
        "  slot_size = XYDIM_EXT // NUM_CLASSES\n",
        "  u = parameters_to_unitary(h0)\n",
        "  u_inv = u.T.conj()\n",
        "  projectors = []\n",
        "  # Loop over NUM_CLASSES; the class count used to be a hard-coded 10 here.\n",
        "  for c in range(NUM_CLASSES):\n",
        "    selector = np.zeros(NUM_CLASSES * slot_size)\n",
        "    selector[c * slot_size:(c + 1) * slot_size] = 1.0\n",
        "    projectors.append(np.matmul(u_inv, np.matmul(np.diag(selector), u)))\n",
        "  return np.stack(projectors, axis=0)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NS4RFHVORVsF"
      },
      "source": [
        "def DecomposeSamples(h0, x_test, test_indices):\n",
        "  \"\"\"Applies every class projector to the selected examples.\n",
        "\n",
        "  Args:\n",
        "    h0: Parameter matrix passed to ComputeProjectors.\n",
        "    x_test: Examples indexed as [batch, y, x].\n",
        "    test_indices: Indices of the examples to decompose.\n",
        "\n",
        "  Returns:\n",
        "    Array indexed as [example, class, y, x]; the padding entries beyond\n",
        "    XYDIM are dropped before reshaping back into images.\n",
        "  \"\"\"\n",
        "  projectors = ComputeProjectors(h0)\n",
        "  selected_states = quantum_states_from_xs(x_test[test_indices, :, :])\n",
        "  components = np.einsum('cyx,bx->bcy', projectors, selected_states)\n",
        "  image_part = components[:, :, :XYDIM]\n",
        "  return np.reshape(image_part, (-1, NUM_CLASSES, XYDIM // XDIM, XDIM))\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CSOLkxjD2Ykx"
      },
      "source": [
        "def ComplexToRGB(sample, mag_scale):\n",
        "  \"\"\"Renders a 2-D complex array as RGB: phase -> hue, magnitude -> lightness.\n",
        "\n",
        "  Args:\n",
        "    sample: 2-D complex array.\n",
        "    mag_scale: Factor applied to |sample| before the lightness mapping.\n",
        "\n",
        "  Returns:\n",
        "    Array of shape (height, width, 3) holding the r, g, b channels.\n",
        "  \"\"\"\n",
        "  num_rows, num_cols = sample.shape\n",
        "  scaled_magnitude = np.absolute(sample) * mag_scale\n",
        "  hue = (np.angle(sample) + math.pi) / (2 * math.pi) + 0.5\n",
        "  # Lightness is 0 for zero magnitude and approaches 1 for large magnitudes.\n",
        "  lightness = 1.0 - 1.0 / (1.0 + scaled_magnitude ** 0.8)\n",
        "  saturation = np.full((num_rows, num_cols), 0.8, dtype=np.float64)\n",
        "  channels = np.vectorize(hls_to_rgb)(hue, lightness, saturation)\n",
        "  return np.stack(channels, axis=2)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2btI1quiPuLV"
      },
      "source": [
        "# Test Complex Visualization\n",
        "\n",
        "To visualize complex probability amplitudes, we use brightness for magnitude and hue for phase. The following figure illustrates this mapping.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3okGuDYw7YNl"
      },
      "source": [
        "# Reference chart for ComplexToRGB: magnitude grows from 0 to 0.1 along y,\n",
        "# phase sweeps from -pi to pi along x.\n",
        "# np.complex128 replaces the `np.complex` alias, which NumPy has removed.\n",
        "ex_g = np.zeros((100, 100), np.complex128)\n",
        "eg_g_magnitude = np.reshape(np.linspace(0, 0.1, 100, endpoint=True), (-1, 1))\n",
        "ex_g_phase = np.linspace(-math.pi, math.pi, 100)[np.newaxis, :]\n",
        "ex_g.real = eg_g_magnitude * np.cos(ex_g_phase)\n",
        "ex_g.imag = eg_g_magnitude * np.sin(ex_g_phase)\n",
        "ex_g_rgb = ComplexToRGB(ex_g, 10.0)\n",
        "fig = plt.figure(figsize=(4, 4))\n",
        "plt.imshow(ex_g_rgb)\n",
        "ax = fig.gca()\n",
        "ax.grid(False)\n",
        "ax.set_yticks([0, 50, 100])\n",
        "ax.set_yticklabels([0, 0.05, 0.1])\n",
        "ax.set_xticks([0, 50, 100])\n",
        "# Raw strings: '\\p' is not a valid escape sequence in a normal string.\n",
        "ax.set_xticklabels([r'$-\\pi$', r'$0$', r'$\\pi$'])\n",
        "ax.set_ylabel('Magnitude')\n",
        "ax.set_xlabel('Phase')\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Lpbpg83bxYVC"
      },
      "source": [
        "def _ClearTicks(ax):\n",
        "  \"\"\"Removes the grid and all ticks (and with them the tick labels).\"\"\"\n",
        "  ax.grid(False)\n",
        "  ax.set_xticks([])\n",
        "  ax.set_yticks([])\n",
        "\n",
        "\n",
        "def PlotSamples(samples, dec_samples, filename, figsize=(16, 8), xlabel=None):\n",
        "  \"\"\"Plots samples next to their per-class components; saves and downloads.\n",
        "\n",
        "  Row i shows samples[i] in grayscale, then one panel per class with\n",
        "  dec_samples[i, c] rendered by ComplexToRGB, then a narrow phase color bar.\n",
        "\n",
        "  Args:\n",
        "    samples: Images indexed as [image, y, x].\n",
        "    dec_samples: Complex array of shape\n",
        "      (num_images, num_classes, height, width).\n",
        "    filename: Path the figure is saved to before it is offered for download.\n",
        "    figsize: Figure size handed to matplotlib.\n",
        "    xlabel: Optional caption placed below the whole grid.\n",
        "  \"\"\"\n",
        "  num_images, num_classes, _, _ = dec_samples.shape\n",
        "  # Columns: the sample, one panel per class, and the phase color bar. The\n",
        "  # count follows num_classes instead of a hard-coded 12.\n",
        "  fig, axs = plt.subplots(\n",
        "      num_images, num_classes + 2, figsize=figsize,\n",
        "      squeeze=False,  # Keep axs 2-D even when there is a single image row.\n",
        "      gridspec_kw={'width_ratios': [1] * (num_classes + 1) + [0.25],\n",
        "                   'height_ratios': [1] * num_images})\n",
        "  norm = cm.colors.Normalize(vmax=1.03, vmin=0)\n",
        "  # Phase color bar: constant magnitude, phase sweeping from -pi to pi.\n",
        "  # np.complex128 replaces the `np.complex` alias, which NumPy has removed.\n",
        "  bar = np.zeros((7, 28), np.complex128)\n",
        "  bar_magnitude = 0.07\n",
        "  bar_phase = np.linspace(-math.pi, math.pi, 28)[np.newaxis, :]\n",
        "  bar.real = bar_magnitude * np.cos(bar_phase)\n",
        "  bar.imag = bar_magnitude * np.sin(bar_phase)\n",
        "  bar_rgb = ComplexToRGB(bar.T, 6.0)\n",
        "  for im_index in range(num_images):\n",
        "    ax = axs[im_index, 0]\n",
        "    ax.imshow(samples[im_index, :, :], cmap=plt.get_cmap('gray'), norm=norm,\n",
        "              interpolation='nearest')\n",
        "    _ClearTicks(ax)\n",
        "    if not im_index:\n",
        "      ax.set_title('Sample')\n",
        "    for c_index in range(num_classes):\n",
        "      ax = axs[im_index, c_index + 1]\n",
        "      rgb = ComplexToRGB(dec_samples[im_index, c_index, :, :], 6)\n",
        "      ax.imshow(rgb, interpolation='nearest')\n",
        "      if not im_index:\n",
        "        ax.set_title('c = %d' % c_index)\n",
        "      _ClearTicks(ax)\n",
        "    ax = axs[im_index, num_classes + 1]\n",
        "    ax.imshow(bar_rgb)\n",
        "    ax.grid(False)\n",
        "    ax.yaxis.set_label_position('right')\n",
        "    ax.yaxis.tick_right()\n",
        "    # Positions first, labels second: labels assigned before the positions\n",
        "    # are fixed are not reliably kept by matplotlib.\n",
        "    ax.set_yticks([0, 14, 28])\n",
        "    ax.set_yticklabels([r'$-\\pi$', r'$0$', r'$\\pi$'])\n",
        "    ax.set_xticks([])\n",
        "  plt.subplots_adjust(wspace=0.03, hspace=0.01)\n",
        "  if xlabel:\n",
        "    # Frameless axes spanning the figure; they only carry the shared caption.\n",
        "    fig.add_subplot(111, frameon=False)\n",
        "    # Real booleans instead of the legacy 'off' strings.\n",
        "    plt.tick_params(labelcolor='red', top=False, bottom=False, left=False,\n",
        "                    right=False)\n",
        "    plt.grid(False)\n",
        "    plt.gca().set_xticklabels([])\n",
        "    plt.gca().set_yticklabels([])\n",
        "    plt.xlabel(xlabel)\n",
        "    plt.gca().xaxis.set_label_coords(0.5, -0.08)\n",
        "  plt.savefig(filename, bbox_inches='tight')\n",
        "  files.download(filename)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WdDC3KEjLqqI"
      },
      "source": [
        "# Fashion MNIST"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Q_0QwhcDJRoR"
      },
      "source": [
        "# Fashion-MNIST: most likely class per photon arrival pixel, estimated once\n",
        "# from the training set and once from the test set.\n",
        "x_train, y_train_cat, x_test, y_test_cat = ReadSets(\n",
        "    DATASET_FASHION, NUM_CLASSES)\n",
        "x_train_most_likely_pixel_cat, x_train_prob = MostLikelyCategory(\n",
        "    x_train, y_train_cat)\n",
        "PlotPixelMap(x_train_most_likely_pixel_cat, x_train_prob,\n",
        "             filename='fashion_map.png', xlabel='(a)')\n",
        "x_test_most_likely_pixel_cat, x_test_prob = MostLikelyCategory(\n",
        "    x_test, y_test_cat)\n",
        "PlotPixelMap(x_test_most_likely_pixel_cat, x_test_prob,\n",
        "             filename='fashion_map_test.png', xlabel='(a)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kafbwXRLKMqe"
      },
      "source": [
        "# Baseline: pixel map estimated on the training set, evaluated on test.\n",
        "p_confusion_test = Confusion(\n",
        "    x_test, y_test_cat, x_train_most_likely_pixel_cat)\n",
        "PlotConfusionMatrix(p_confusion_test,\n",
        "                    filename='fashion_mnist_confusion.png', xlabel='(a)')\n",
        "# Bound: pixel map estimated on the test set itself, evaluated on test.\n",
        "p_confusion_test_test = Confusion(\n",
        "    x_test, y_test_cat, x_test_most_likely_pixel_cat)\n",
        "PlotConfusionMatrix(p_confusion_test_test,\n",
        "                    filename='fashion_mnist_confusion_tt.png', xlabel='(a)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hs5LtxsodPdA"
      },
      "source": [
        "# The trace of a confusion matrix is the accuracy. The printed label used\n",
        "# to be misspelled.\n",
        "print(\"Accuracy baseline \", np.trace(p_confusion_test))\n",
        "print(\"Accuracy bound \", np.trace(p_confusion_test_test))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TMfMFqVWU4J-"
      },
      "source": [
        "# Confusion matrix of the trained unitary model on the Fashion-MNIST test set.\n",
        "p_confusion_ours = ConfusionOurs(x_test, y_test_cat, h0=fashion_mnist_h0)\n",
        "PlotConfusionMatrix(p_confusion_ours,\n",
        "                    filename='fashion_mnist_confusion_ours.png', xlabel='(c)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xKDie-JbboXA"
      },
      "source": [
        "# The trace of the confusion matrix is the accuracy. The printed label\n",
        "# used to be misspelled.\n",
        "print(\"Accuracy ours \", np.trace(p_confusion_ours))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FJbY5xr3x1hb"
      },
      "source": [
        "# Decompose two Fashion-MNIST test examples with respect to the projectors.\n",
        "test_indices = [0, 2]\n",
        "dec_samples = DecomposeSamples(fashion_mnist_h0, x_test, test_indices)\n",
        "PlotSamples(\n",
        "    x_test[test_indices, :, :],\n",
        "    dec_samples,\n",
        "    'fashion_mnist_projection.png',\n",
        "    figsize=(18, 1.5 * (len(test_indices) + 1)),\n",
        "    xlabel='(a)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e-1WGHLgLfXI"
      },
      "source": [
        "## MNIST"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BCdP3eiNIjPv"
      },
      "source": [
        "# Most likely class for each photon arrival location (figure 2(b)),\n",
        "# estimated once from the training set and once from the test set.\n",
        "\n",
        "x_train, y_train_cat, x_test, y_test_cat = ReadSets(\n",
        "    DATASET_MNIST, NUM_CLASSES)\n",
        "x_train_most_likely_pixel_cat, x_train_prob = MostLikelyCategory(\n",
        "    x_train, y_train_cat)\n",
        "PlotPixelMap(x_train_most_likely_pixel_cat, x_train_prob,\n",
        "             filename='mnist_map.png', xlabel='(b)')\n",
        "x_test_most_likely_pixel_cat, x_test_prob = MostLikelyCategory(\n",
        "    x_test, y_test_cat)\n",
        "PlotPixelMap(x_test_most_likely_pixel_cat, x_test_prob,\n",
        "             filename='mnist_map_tt.png', xlabel='(b)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-FCibnp49mSq"
      },
      "source": [
        "# Confusion matrix (figure 3(b)).\n",
        "\n",
        "# Baseline: pixel map estimated on the training set, evaluated on test.\n",
        "p_confusion_test = Confusion(\n",
        "    x_test, y_test_cat, x_train_most_likely_pixel_cat)\n",
        "PlotConfusionMatrix(p_confusion_test,\n",
        "                    filename='mnist_confusion.png', xlabel='(b)')\n",
        "# Bound: pixel map estimated on the test set itself, evaluated on test.\n",
        "p_confusion_test_test = Confusion(\n",
        "    x_test, y_test_cat, x_test_most_likely_pixel_cat)\n",
        "PlotConfusionMatrix(p_confusion_test_test,\n",
        "                    filename='mnist_confusion_tt.png', xlabel='(b)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2ygd-Ll-ddS2"
      },
      "source": [
        "# The trace of a confusion matrix is the accuracy. The printed label used\n",
        "# to be misspelled.\n",
        "print(\"Accuracy baseline \", np.trace(p_confusion_test))\n",
        "print(\"Accuracy bound \", np.trace(p_confusion_test_test))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "g4JnwN5fDguH"
      },
      "source": [
        "# Confusion matrix of the trained unitary model on the MNIST test set.\n",
        "p_confusion_ours = ConfusionOurs(x_test, y_test_cat, h0=mnist_h0)\n",
        "PlotConfusionMatrix(p_confusion_ours,\n",
        "                    filename='mnist_confusion_ours.png', xlabel='(d)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "C-UxyX-Oa84c"
      },
      "source": [
        "# The trace of the confusion matrix is the accuracy. The printed label\n",
        "# used to be misspelled.\n",
        "print(\"Accuracy ours \", np.trace(p_confusion_ours))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Dogn7No9sDAl"
      },
      "source": [
        "# Visualize the decomposition of some test instances with respect to the\n",
        "# set of projectors.\n",
        "\n",
        "test_indices = [0, 3]  # Indices of the test examples to decompose.\n",
        "dec_samples = DecomposeSamples(mnist_h0, x_test, test_indices)\n",
        "PlotSamples(\n",
        "    x_test[test_indices, :, :],\n",
        "    dec_samples,\n",
        "    'mnist_projection.png',\n",
        "    figsize=(18, 1.5 * (len(test_indices) + 1)),\n",
        "    xlabel='(b)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PKH2yel6LyNx"
      },
      "source": [
        "# An example of factorization\n",
        "\n",
        "Below, we give an example of a 'measurement' of an atom's spin."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dk7_9jbiizn8"
      },
      "source": [
        "import numpy as np\n",
        "import math\n",
        "import scipy.linalg\n",
        "\n",
        "\n",
        "def xstate(atom, app):\n",
        "  \"\"\"Flattened product state of an atom state and an apparatus state.\"\"\"\n",
        "  return np.einsum('a,A->aA', atom, app).reshape(-1)\n",
        "\n",
        "\n",
        "def normalized(v):\n",
        "  \"\"\"Returns v scaled to unit Euclidean norm.\"\"\"\n",
        "  return v / (v * v).sum() ** .5\n",
        "\n",
        "\n",
        "# Initial psi atom.\n",
        "atom_initial_up = np.array([1.0, 0.0])\n",
        "atom_initial_down = np.array([0.0, 1.0])\n",
        "c = np.sqrt([0.9, 0.1])\n",
        "atom_initial = c[0] * atom_initial_up + c[1] * atom_initial_down\n",
        "# Initial psi apparatus\n",
        "apparatus_initial = np.array([0.0, 1.0, 0.0, 0.0])\n",
        "\n",
        "# Initial joint psi atom-apparatus.\n",
        "all_initial = xstate(atom_initial, apparatus_initial)\n",
        "print(\"psi_all_initial = \", all_initial)\n",
        "\n",
        "# We model the interaction between atom and apparatus with a rotation (this\n",
        "# represents the solution of the Schroedinger equation modeling the\n",
        "# interaction.)\n",
        "# A dedicated generator is used here: re-binding the notebook-wide `rng`\n",
        "# would change the random tie-breaking of earlier cells when they are re-run.\n",
        "interaction_rng = np.random.RandomState(seed=0)\n",
        "gen0 = interaction_rng.normal(size=(8, 8))\n",
        "gen = gen0 - gen0.T  # Antisymmetric generator, so expm(gen) is a rotation.\n",
        "rotation = scipy.linalg.expm(gen)\n",
        "\n",
        "# Final state after the interaction.\n",
        "all_final = np.dot(rotation, all_initial)\n",
        "\n",
        "print(\"psi_all_final = \", all_final.round(3))\n",
        "\n",
        "# Apparatus state when spin up is observed.\n",
        "apparatus_up = normalized(all_final[0:4])\n",
        "# Apparatus state when spin down is observed.\n",
        "apparatus_down = normalized(all_final[4:])\n",
        "# Probability amplitudes of observing spin up or down. Each half of\n",
        "# all_final equals its norm times its normalized direction, so the amplitude\n",
        "# is that norm. Taking the norm directly avoids the 0/0 that the ratio of\n",
        "# first components produces whenever that component vanishes.\n",
        "alpha = np.linalg.norm(all_final[0:4])\n",
        "beta = np.linalg.norm(all_final[4:])\n",
        "print(\"Factorization of psi_all_final \",\n",
        "      alpha * xstate([1.0, 0.0], apparatus_up) +\n",
        "      beta * xstate([0.0, 1.0], apparatus_down))\n",
        "print(\"apparatus_up \", apparatus_up)\n",
        "print(\"apparatus_down \", apparatus_down)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}