{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "SinglePhotonTraining.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8FKkZJSgYmur"
      },
      "source": [
        "# Single-photon Classification\n",
        "\n",
        "## Training"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dKQbNNzeOmOL"
      },
      "source": [
        "from google.colab import files\n",
        "import numpy\n",
        "import tensorflow.compat.v2 as tf\n",
        "import tensorflow.compat.v2.keras as keras\n",
        "from matplotlib import pyplot\n",
        "\n",
        "print('TensorFlow Version:', tf.__version__)\n",
        "\n",
        "numpy.set_printoptions(linewidth=200)\n",
        "rng = numpy.random.RandomState(seed=2)\n",
        "\n",
        "# The computationally expensive part is exponentiating the Lie algebra element\n",
        "# to obtain the U(790) \"de-mixing\" basis transformation. (The generator is a\n",
        "# general anti-Hermitian matrix, so the result is unitary, but its determinant\n",
        "# need not be 1, i.e. it is not restricted to SU(790).)\n",
        "# This is then applied to all items in the batch uniformly. Since the rest\n",
        "# of the computation is easy, this suggests we should go with large batch sizes.\n",
        "\n",
        "BATCH_SIZE = 8192\n",
        "NUM_CLASSES = 10\n",
        "XDIM = 28  # Images are XDIM x XDIM pixels.\n",
        "XYDIM = XDIM * XDIM  # = 784 pixels per image.\n",
        "# The flattened quantum state is zero-padded to the smallest multiple of\n",
        "# NUM_CLASSES that is >= XYDIM, so that it splits into NUM_CLASSES equal-size\n",
        "# blocks, one per class (= 790 for 28x28 images and 10 classes).\n",
        "XYDIM_EXT = -(-XYDIM // NUM_CLASSES) * NUM_CLASSES  # Ceiling division.\n",
        "\n",
        "# Choose one of the following sets.\n",
        "\n",
        "DATASET = tf.keras.datasets.mnist\n",
        "# DATASET = tf.keras.datasets.fashion_mnist\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PUS91GCsOqqQ"
      },
      "source": [
        "# Upload a previously trained model (if available), so that an existing\n",
        "# model can be trained further.\n",
        "# E.g.:\n",
        "#   for fashion_mnist use fashion_mnist_trained_U790.npy\n",
        "#   for mnist use mnist_trained_U790.npy\n",
        "# The file must hold a single (XYDIM_EXT, XYDIM_EXT) matrix, which is used\n",
        "# as the starting value of the trainable parameter matrix.\n",
        "\n",
        "uploaded_h0 = None\n",
        "\n",
        "print('Upload a pre-trained model *.npy file if available.\\n'\n",
        "      'Otherwise, a new model will be trained.\\n\\n')\n",
        "uploaded_files = None\n",
        "while True:\n",
        "  uploaded_files = files.upload()\n",
        "  if not uploaded_files:  # User does not upload a model file.\n",
        "    print('No user-provided model. Training a new model.')\n",
        "    break\n",
        "  # Exactly one file is accepted; with several it is ambiguous which to use.\n",
        "  if (len(uploaded_files) != 1 or\n",
        "      not all(x.endswith('.npy') for x in uploaded_files.keys())):\n",
        "    print('Please only upload exactly one .npy file')\n",
        "    continue\n",
        "  try:\n",
        "    candidate_h0 = numpy.load(list(uploaded_files)[0]).astype(numpy.float64)\n",
        "  except Exception as e:\n",
        "    print('Parse error:', e)\n",
        "    continue\n",
        "  if candidate_h0.shape != (XYDIM_EXT, XYDIM_EXT):\n",
        "    # Don't keep a wrong-shaped matrix: if the user then uploads nothing, a\n",
        "    # new model must be trained instead.\n",
        "    print('Matrix size mismatch.')\n",
        "    continue\n",
        "  uploaded_h0 = candidate_h0\n",
        "  break\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XJRN6wqDE_oO"
      },
      "source": [
        "# Train the uploaded model (or a new one, if none was uploaded).\n",
        "# Results are periodically saved to the file h0.npy on the Colab server.\n",
        "# Training can be stopped at any time with the cell's \"stop\" button; the\n",
        "# most recent model can then be downloaded with the download_model()\n",
        "# function (defined in this cell, called in the next cell).\n",
        "\n",
        "def brightness_normalized_examples(examples):\n",
        "  \"\"\"Rescales every example so that its pixel brightnesses add up to 1.0.\n",
        "\n",
        "  Both classically and quantum-mechanically, we are interested in the first\n",
        "  photon that passes the image filter. A 'bright' digit (like an \"8\") yields\n",
        "  that first photon sooner than an equally illuminated 'dark' digit (like a\n",
        "  \"1\"), but per example we only care about the probability of the first\n",
        "  photon coming from each of its pixels - and these probabilities must sum\n",
        "  to 1.\n",
        "\n",
        "  Args:\n",
        "    examples: [batch, y, x] array of brightnesses.\n",
        "\n",
        "  Returns:\n",
        "    Array of the same shape, normalized per example.\n",
        "  \"\"\"\n",
        "  per_example_totals = numpy.einsum('byx->b', examples)\n",
        "  return examples / per_example_totals.reshape(-1, 1, 1)\n",
        "\n",
        "\n",
        "(x_train255, y_train), (x_test255, y_test) = DATASET.load_data()\n",
        "# Float copies of the raw images, each normalized to total brightness 1.\n",
        "x_train, x_test = [\n",
        "    brightness_normalized_examples(raw_images.astype(float))\n",
        "    for raw_images in (x_train255, x_test255)]\n",
        "\n",
        "# Quantum amplitudes of the incoming photon as it passes the filter.\n",
        "# We must normalize to total_intensity=1, since we see 1 photon.\n",
        "# Also, we flatten each image and zero-pad it to XYDIM_EXT components (a\n",
        "# multiple of NUM_CLASSES), so that it splits into one equal-size block per\n",
        "# class.\n",
        "\n",
        "def quantum_states_from_xs(brightness_normalized_examples):\n",
        "  \"\"\"Maps brightness-normalized examples to padded quantum states.\n",
        "\n",
        "  Pixel quantum amplitudes are the square roots of the per-example pixel\n",
        "  probabilities (up to an un-observable normalization factor). The quantum\n",
        "  state space is flattened to a single index and zero-padded to XYDIM_EXT\n",
        "  components, a multiple of the number of classes.\n",
        "\n",
        "  Args:\n",
        "    brightness_normalized_examples: [batch, y, x] array of per-example pixel\n",
        "      probabilities.\n",
        "\n",
        "  Returns:\n",
        "    [batch, XYDIM_EXT] array of real amplitudes.\n",
        "  \"\"\"\n",
        "  amplitudes = numpy.sqrt(brightness_normalized_examples)\n",
        "  flat_amplitudes = amplitudes.reshape(amplitudes.shape[0], -1)\n",
        "  # Pad based on the actual flattened size rather than assuming XYDIM pixels;\n",
        "  # numpy.pad raises if an image has more than XYDIM_EXT pixels.\n",
        "  return numpy.pad(\n",
        "      flat_amplitudes, ((0, 0), (0, XYDIM_EXT - flat_amplitudes.shape[1])))\n",
        "\n",
        "xq_train = quantum_states_from_xs(x_train)\n",
        "xq_test = quantum_states_from_xs(x_test)\n",
        "\n",
        "y_train_cat = keras.utils.to_categorical(y_train, NUM_CLASSES)\n",
        "y_test_cat = keras.utils.to_categorical(y_test, NUM_CLASSES)\n",
        "\n",
        "train_ds = tf.data.Dataset.from_tensor_slices(\n",
        "    (xq_train, y_train_cat)).shuffle(10000).batch(BATCH_SIZE)\n",
        "\n",
        "test_ds = tf.data.Dataset.from_tensor_slices(\n",
        "    (xq_test, y_test_cat)).batch(BATCH_SIZE)\n",
        "\n",
        "# \"Classical baseline\" proven performance threshold:\n",
        "# If a photon arrives at pixel (row, col), we need to know the most likely\n",
        "# digit, with probabilities as observed on the test(!) set\n",
        "# (see appendix A.2 for explanation).\n",
        "x_test_intensity_per_cat = (\n",
        "    numpy.einsum('byx,bc->yxc', x_test, y_test_cat) +\n",
        "    # Trick: We add a tiny randomized-for-random-tie-breaking 'epsilon'\n",
        "    # to the brightnesses that only changes results for pixels that were\n",
        "    # dark in every single test example.\n",
        "    rng.uniform(low=0, high=1e-100,\n",
        "                size=x_test.shape[1:] + (NUM_CLASSES,)))\n",
        "# Pass NUM_CLASSES explicitly: otherwise the one-hot width would be\n",
        "# max(argmax) + 1, which is too small (and breaks the einsum below) if the\n",
        "# last class(es) are not the most likely class for any pixel.\n",
        "x_test_most_likely_pixel_cat = keras.utils.to_categorical(\n",
        "    x_test_intensity_per_cat.argmax(axis=2), NUM_CLASSES)\n",
        "\n",
        "# For every test-set image, we need to find out what fraction of photons\n",
        "# would make our classifier predict the category correctly.\n",
        "x_test_batched_total_photons_yielding_correct_cat = numpy.einsum(\n",
        "    'byx,yxc,bc->b', x_test, x_test_most_likely_pixel_cat, y_test_cat)\n",
        "x_test_fraction_photons_yielding_correct_cat = (\n",
        "    x_test_batched_total_photons_yielding_correct_cat /\n",
        "    numpy.einsum('byx->b', x_test))\n",
        "baseline_accuracy = numpy.round(\n",
        "    x_test_fraction_photons_yielding_correct_cat.mean(), 4)\n",
        "# Baseline accuracy 21.27% is found to be RNG seed independent. This is\n",
        "# expected: the seed only decides the class of pixels that are dark in every\n",
        "# single test example, and no test-set photon ever comes from such a pixel.\n",
        "print('Classical-detection baseline accuracy: %.2f%%' %\n",
        "      (100 * baseline_accuracy))\n",
        "\n",
        "# Trainable parameter matrix from which get_loss_accuracies_u() builds the\n",
        "# unitary. Starts from the uploaded model if there is one, otherwise from\n",
        "# small random values.\n",
        "k_h0 = tf.keras.backend.variable(\n",
        "     (uploaded_h0 if uploaded_h0 is not None\n",
        "      else rng.normal(size=(XYDIM_EXT, XYDIM_EXT), scale=0.01)),\n",
        "     dtype='float64', name='H0')\n",
        "\n",
        "x_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=False)\n",
        "\n",
        "\n",
        "@tf.function\n",
        "def get_loss_accuracies_u(h0, psi, labels):\n",
        "  \"\"\"Evaluates the model on a batch of padded quantum states.\n",
        "\n",
        "  The real matrix h0 parametrizes the anti-Hermitian generator\n",
        "  iH = (h0 - h0^T) + i * (h0 + h0^T), whose matrix exponential U is unitary.\n",
        "  U is applied to every state, the result is split into NUM_CLASSES\n",
        "  equal-size blocks, and the probability of class c is the squared norm of\n",
        "  block c.\n",
        "\n",
        "  Args:\n",
        "    h0: [XYDIM_EXT, XYDIM_EXT] float64 parameter matrix (array or variable).\n",
        "    psi: [batch, XYDIM_EXT] real amplitudes.\n",
        "    labels: [batch, NUM_CLASSES] one-hot labels.\n",
        "\n",
        "  Returns:\n",
        "    (loss, batch_quantum_accuracy, batch_manyphoton_accuracy, u): the\n",
        "    categorical cross-entropy of the class probabilities; the mean\n",
        "    probability assigned to the correct class; the fraction of examples\n",
        "    whose most likely class is correct; and the unitary matrix U.\n",
        "  \"\"\"\n",
        "  h0_T = tf.transpose(h0)\n",
        "  # Weight of the symmetric (imaginary) part of the generator: 1.0 gives a\n",
        "  # general unitary U, 0.0 would restrict U to real orthogonal matrices.\n",
        "  DDD_ORTHO_ONLY = 1.0\n",
        "  ih = tf.complex(h0 - h0_T, DDD_ORTHO_ONLY * (h0 + h0_T))\n",
        "  u = tf.linalg.expm(ih)\n",
        "  # Row b of u_psi is U @ psi[b], split into one block per class.\n",
        "  u_psi = tf.reshape(\n",
        "      tf.einsum('bp,qp->bq', tf.cast(psi, tf.complex128), u),\n",
        "      (-1, NUM_CLASSES, XYDIM_EXT // NUM_CLASSES))\n",
        "  u_psi_r = tf.math.real(u_psi)\n",
        "  u_psi_i = tf.math.imag(u_psi)\n",
        "  probs = tf.einsum('bcs->bc', tf.math.square(u_psi_r) + tf.math.square(u_psi_i))\n",
        "  labels_f64 = tf.cast(labels, tf.float64)\n",
        "  loss = x_loss(labels_f64, probs)\n",
        "  # Per-example probability of the correct class, averaged over the batch.\n",
        "  batch_quantum_accuracy = tf.math.reduce_mean(\n",
        "      tf.einsum('bc,bc->b', labels_f64, probs))\n",
        "  batch_manyphoton_accuracy = tf.math.reduce_mean(\n",
        "      keras.metrics.categorical_accuracy(labels, probs))\n",
        "  return loss, batch_quantum_accuracy, batch_manyphoton_accuracy, u\n",
        "\n",
        "\n",
        "def train_step(optimizer, k_h0, psis, labels):\n",
        "  \"\"\"Applies one optimizer update to k_h0 for a batch of states and labels.\n",
        "\n",
        "  Returns:\n",
        "    (loss, quantum accuracy, many-photon accuracy, unitary) as computed by\n",
        "    get_loss_accuracies_u() before the update is applied.\n",
        "  \"\"\"\n",
        "  with tf.GradientTape() as tape:\n",
        "    tape.watch(k_h0)  # Redundant if k_h0 is a trainable tf.Variable; harmless.\n",
        "    loss, accuracy_q, accuracy_m, u = get_loss_accuracies_u(k_h0, psis, labels)\n",
        "  gradients = tape.gradient(loss, [k_h0])\n",
        "  optimizer.apply_gradients(zip(gradients, [k_h0]))\n",
        "  return loss, accuracy_q, accuracy_m, u\n",
        "\n",
        "def get_test_set_accuracies(h0):\n",
        "  \"\"\"Returns (quantum accuracy, many-photon accuracy) of h0 on the test set.\"\"\"\n",
        "  _, accuracy_q, accuracy_m, _ = get_loss_accuracies_u(h0, xq_test, y_test_cat)\n",
        "  return accuracy_q, accuracy_m\n",
        "\n",
        "\n",
        "def download_model():\n",
        "  \"\"\"Saves the current model to h0.npy and downloads that file from Colab.\n",
        "\n",
        "  Saving first ensures the downloaded file matches the accuracies printed\n",
        "  here, even if training was interrupted before its last save (or before\n",
        "  any save at all).\n",
        "  \"\"\"\n",
        "  h0 = k_h0.numpy()\n",
        "  numpy.save('h0.npy', h0)\n",
        "  accuracy_q, accuracy_m = get_test_set_accuracies(h0)\n",
        "  print(\"This model's test set accuracy: Quantum=%.2f%%, Manyphoton=%.2f%%\" %\n",
        "        (100 * accuracy_q, 100 * accuracy_m))\n",
        "  files.download('h0.npy')\n",
        "\n",
        "\n",
        "def _run_training_stage(optimizer, num_epochs):\n",
        "  \"\"\"Trains k_h0 for num_epochs passes over train_ds with the given optimizer.\n",
        "\n",
        "  Progress is printed every 5 steps. At the first step of every 10th epoch\n",
        "  (starting with epoch 0), the test-set accuracies are printed and the model\n",
        "  is saved to h0.npy.\n",
        "  \"\"\"\n",
        "  for ep in range(num_epochs):\n",
        "    for n, (xq_img, labels) in enumerate(train_ds):\n",
        "      loss, accuracy_q, accuracy_m, u = train_step(\n",
        "          optimizer, k_h0, xq_img, labels)\n",
        "      if n % 5 == 0:\n",
        "        tr_u = tf.linalg.trace(u)\n",
        "        print('[Epoch=%d STEP=%04d, loss=%.3f, acc_q=%.3f, acc_m=%.3f, tr_U=%s]' %\n",
        "              (ep, n, loss.numpy(), accuracy_q.numpy(), accuracy_m.numpy(),\n",
        "               numpy.round(tr_u.numpy(), 3)))\n",
        "      if n == 0 and ep % 10 == 0:\n",
        "        print('Test Set accuracies: %.4f / %.4f' %\n",
        "              get_test_set_accuracies(k_h0))\n",
        "        numpy.save('h0.npy', k_h0.numpy())\n",
        "        print('Saved model')\n",
        "\n",
        "\n",
        "def train_a_model(stage1_steps=2, stage2_steps=2):\n",
        "  \"\"\"Trains the model k_h0 in place.\n",
        "\n",
        "  Training uses plain SGD in two stages: stage1_steps epochs with learning\n",
        "  rate 0.3, then stage2_steps epochs with learning rate 0.03. (Actual\n",
        "  training would use much larger numbers of stage1_steps and stage2_steps.)\n",
        "\n",
        "  The model is saved to h0.npy periodically during both stages and once more\n",
        "  at the end. Interrupting this function leaves a valid partially-trained\n",
        "  model in k_h0.\n",
        "  \"\"\"\n",
        "  _run_training_stage(tf.keras.optimizers.SGD(learning_rate=0.3), stage1_steps)\n",
        "  _run_training_stage(tf.keras.optimizers.SGD(learning_rate=0.03), stage2_steps)\n",
        "  numpy.save('h0.npy', k_h0.numpy())\n",
        "  print('Saved model')\n",
        "\n",
        "train_a_model()\n",
        "\n",
        "print('Model accuracies: %.4f / %.4f' % get_test_set_accuracies(k_h0))\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qnyudg7GR1Vq"
      },
      "source": [
        "# Download model file locally.\n",
        "\n",
        "download_model()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eY7AowH5Ud_N"
      },
      "source": [
        "# Studying projectors.\n",
        "\n",
        "# cat_projectors[c] is the diagonal projector onto the c-th block of\n",
        "# XYDIM_EXT // NUM_CLASSES basis states - the block whose squared norm (after\n",
        "# applying the unitary) is the probability of class c in get_loss_accuracies_u.\n",
        "cat_projectors = [\n",
        "    numpy.diag(numpy.repeat(numpy.arange(NUM_CLASSES) == c,\n",
        "                            XYDIM_EXT // NUM_CLASSES).astype(float))\n",
        "    for c in range(NUM_CLASSES)]\n",
        "\n",
        "# Only the unitary is needed here; loss and accuracies are discarded.\n",
        "n_u = get_loss_accuracies_u(k_h0, xq_test, y_test_cat)[-1].numpy()\n",
        "n_u_inv = n_u.T.conj()  # Using unitarity.\n",
        "assert numpy.allclose(n_u_inv @ n_u, numpy.eye(n_u.shape[0]))\n",
        "\n",
        "# Pull the class projectors back to the input (pixel) basis. Together they\n",
        "# must still resolve the identity.\n",
        "projectors = numpy.stack([n_u_inv @ cp @ n_u for cp in cat_projectors], axis=0)\n",
        "assert numpy.allclose(projectors.sum(axis=0), numpy.eye(projectors[0].shape[0]))\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wNwpHy3CYkPn"
      },
      "source": [
        "def decompose_img(xq_img):\n",
        "  \"\"\"Projects an amplitude image onto each class subspace.\n",
        "\n",
        "  Args:\n",
        "    xq_img: array of pixel amplitudes; it is flattened and zero-padded to\n",
        "      XYDIM_EXT components.\n",
        "\n",
        "  Returns:\n",
        "    (proj_prob_residuals, proj_pieces): for each class c, the probability\n",
        "    of the projected state that lies in the padding components, and the\n",
        "    projected amplitudes on the image pixels, reshaped like xq_img.\n",
        "  \"\"\"\n",
        "  num_pixels = xq_img.size\n",
        "  psi = numpy.pad(xq_img.reshape(-1), (0, XYDIM_EXT - num_pixels))\n",
        "  proj_psi = numpy.einsum('cqp,p->cq', projectors, psi)\n",
        "  proj_probs = (proj_psi * proj_psi.conj()).real\n",
        "  # Split at the actual number of pixels (rather than assuming XYDIM), so the\n",
        "  # reshape below always matches xq_img.\n",
        "  proj_prob_residuals = proj_probs[:, num_pixels:].sum(axis=1)\n",
        "  proj_pieces = proj_psi[:, :num_pixels].reshape((-1,) + xq_img.shape)\n",
        "  return proj_prob_residuals, proj_pieces\n",
        "\n",
        "def show_test_images(test_indices):\n",
        "  \"\"\"Plots each given test image next to its per-class decomposition.\n",
        "\n",
        "  Panel 1 shows the image's pixel probabilities; panel 8 (below it) shows the\n",
        "  squared magnitude of the sum of all class pieces; the remaining panels show\n",
        "  the pixel probabilities of each class piece.\n",
        "  \"\"\"\n",
        "  digit_locations = (3, 4, 5, 6, 7, 10, 11, 12, 13, 14)\n",
        "  for idx in test_indices:\n",
        "    y = y_test[idx]\n",
        "    xq = xq_test[idx][:XYDIM].reshape(x_test.shape[1:])\n",
        "    x_probs = (xq * xq.conj()).real\n",
        "    p_max = x_probs.max()\n",
        "    residuals, q_pieces = decompose_img(xq)\n",
        "    probs_pieces = (q_pieces * q_pieces.conj()).real\n",
        "    # Per-class probability carried by the image pixels; adding the class's\n",
        "    # padding residual gives the full probability of that class.\n",
        "    tot_probs_pieces = probs_pieces.sum(axis=(1, 2))\n",
        "    fig = pyplot.figure(figsize=(9, 8), dpi=75, facecolor='w', edgecolor='k')\n",
        "    # A figure-level title avoids creating an extra full-figure axes (as\n",
        "    # fig.gca() would), whose frame and ticks show behind the panels.\n",
        "    fig.suptitle('Label=%s, Residuals=%5.3g, p=%.3f\\nps=%s, t=%s' %\n",
        "                 (y, residuals.sum(), x_probs.sum(),\n",
        "                  numpy.round(tot_probs_pieces, 3),\n",
        "                  tot_probs_pieces.sum()))\n",
        "    sp = fig.add_subplot(2, 7, 1)\n",
        "    sp.matshow(x_probs, vmin=0, vmax=p_max, cmap='gray')\n",
        "    # Summing the per-class probabilities (probs_pieces.sum(axis=0)) would not\n",
        "    # reproduce the image, as it misses the interference between the pieces.\n",
        "    # Summing the amplitudes first does.\n",
        "    sp = fig.add_subplot(2, 7, 8)\n",
        "    sp.matshow(abs(q_pieces.sum(axis=0))**2, vmin=0, vmax=p_max, cmap='gray')\n",
        "    for location, piece_probs in zip(digit_locations, probs_pieces):\n",
        "      sp = fig.add_subplot(2, 7, location)\n",
        "      sp.matshow(piece_probs, vmin=0, vmax=p_max, cmap='gray')\n",
        "    # fig.show() does not display figures with non-GUI backends such as the\n",
        "    # inline one; pyplot.show() does.\n",
        "    pyplot.show()\n",
        "\n",
        "show_test_images([3000 + list(y_test[3000:]).index(k) for k in range(NUM_CLASSES)])"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}