"""Validates MNIST accuracy claims for mnist10x10_compiled.py

The file mnist10x10_compiled.py contains generated python code that was
synthesized from the trained model-parameters in mnist10x10_3605.npy
to simulate an optical circuit with only phase-shifters and beam-splitters.

These model parameters achieve 36.05% test set accuracy on MNIST images
that are padded by one pixel on each side (from 28x28 to 30x30) and then
downsampled to 10x10 by summing each non-overlapping 3x3 pixel block.

This code demonstrates that:

  (a) The compiled code provides a faithful implementation
      of the unitary transform obtained from training with TensorFlow, i.e.
      the optical network's output matches the unitary transform on all
      test set examples.

  (b) Test set classification accuracy -- the mean probability that the
      photon is detected in the output modes of the correct class -- is
      above 36%.

"""

import pdb
import numpy
import scipy.linalg
import tensorflow as tf

import mnist10x10_compiled


# Dataset source; load_data() yields the 28x28 MNIST images and labels.
DATASET = tf.keras.datasets.mnist
# Number of input modes carrying image data (one per pixel of a 10x10 image).
XYDIM = 100
# Total number of circuit modes; inputs are zero-padded from XYDIM up to this.
XYDIM_EXT = 100
# Number of digit classes; the output modes are split evenly among them.
NUM_CLASSES = 10

# Trained parameter matrix P. The generator i(P + P^T) + (P - P^T) is
# anti-Hermitian when P is real (presumably the case -- confirm the dtype
# stored in the .npy file), so its matrix exponential _u100 is unitary.
_u100_params = numpy.load('mnist10x10_3605.npy')
_u100 = scipy.linalg.expm(
    1j * (_u100_params + _u100_params.T)
    + (_u100_params - _u100_params.T))


def amplitudes_from_u100(q_features):
  """Apply the trained unitary matrix to an input state.

  Args:
    q_features: input amplitudes; flattened with ravel() before use.

  Returns:
    The vector of output amplitudes _u100 . q_features.ravel().
  """
  input_state = q_features.ravel()
  return _u100.dot(input_state)


def amplitudes_from_optical_circuit(q_features):
  """Propagate an input state through the compiled optical circuit.

  Args:
    q_features: input amplitudes; flattened with ravel() before use.

  Returns:
    A numpy array built from whatever mnist10x10_compiled.transform returns
    (presumably a sequence of complex output amplitudes -- confirm in the
    generated module).
  """
  circuit_output = mnist10x10_compiled.transform(q_features.ravel())
  return numpy.array(circuit_output)

                      
def _pad_and_sum_pool_to_10x10(images_28x28):
  """Pad 28x28 images to 30x30, then sum-pool 3x3 blocks down to 10x10.

  Args:
    images_28x28: array of shape (N, 28, 28).

  Returns:
    Array of shape (N, 10, 10). Each output pixel is the sum of one
    non-overlapping 3x3 block of the zero-padded 30x30 image.
  """
  # One pixel of zero padding on each side of both spatial axes: 28 -> 30.
  padded_30x30 = numpy.pad(images_28x28, [(0, 0), (1, 1), (1, 1)])
  # Split each spatial axis into (10 blocks, 3 pixels) and sum within blocks.
  return padded_30x30.reshape(-1, 10, 3, 10, 3).sum(axis=(2, 4))


(x_train255_28, y_train), (x_test255_28, y_test) = DATASET.load_data()
# Divide by 255 so pixel intensities span [0, 1] (MNIST pixels are 0..255).
x_train_28, x_test_28 = (x_train255_28 / 255.0), (x_test255_28 / 255.0)

x_train = _pad_and_sum_pool_to_10x10(x_train_28)
x_test = _pad_and_sum_pool_to_10x10(x_test_28)


def quantum_states_from_xs(xs):
  """Convert images into the amplitudes of a single incoming photon.

  Exactly one photon is observed, so each image is normalized to a total
  intensity of 1. The amplitude for a pixel is the square root of that
  pixel's share of the image's total intensity. An all-zero image divides
  by zero and yields NaN amplitudes.

  Args:
    xs: images indexed as (batch, y, x).

  Returns:
    Array of shape (batch, y * x + XYDIM_EXT - XYDIM). Each row holds the
    flattened amplitudes, zero-padded on the right with XYDIM_EXT - XYDIM
    extra modes.
  """
  batch_size = xs.shape[0]
  total_intensity = numpy.einsum('byx->b', xs)
  intensity_fraction = xs / total_intensity[:, numpy.newaxis, numpy.newaxis]
  flat_amplitudes = numpy.sqrt(intensity_fraction).reshape(batch_size, -1)
  extra_modes = XYDIM_EXT - XYDIM
  return numpy.pad(flat_amplitudes, ((0, 0), (0, extra_modes)))
  

xq_test = quantum_states_from_xs(x_test)


# Per-example probability that the photon is detected in the output modes of
# the example's true class; the mean of these is the reported accuracy.
probs_correct = []

# The loop variable is named distinctly so that it does not rebind the
# module-level `xq_test` array being iterated over.
for example_index, (label, xq_example) in enumerate(zip(y_test, xq_test)):
  amplitudes_u100 = (
      amplitudes_from_u100(xq_example).reshape(NUM_CLASSES, -1))
  amplitudes_circuit = (
      amplitudes_from_optical_circuit(xq_example).reshape(NUM_CLASSES, -1))
  # Check claim (a): the circuit acts as a faithful implementation of the
  # unitary transform on every test set example. An explicit raise (rather
  # than `assert`) keeps the check active when run with `python -O`.
  if not numpy.allclose(amplitudes_u100, amplitudes_circuit):
    raise AssertionError(
        'Optical circuit output differs from the unitary transform on test '
        'example %d.' % example_index)
  # Row c holds the output amplitudes assigned to class c; summing
  # |amplitude|^2 over each row gives the per-class detection probability.
  per_class_probability = numpy.einsum(
      'ca,ca->c', amplitudes_circuit, amplitudes_circuit.conj()).real
  probs_correct.append(per_class_probability[label])
  if len(probs_correct) % 100 == 0:
    print('%5d Examples processed, accuracy: %.2f%%' %
          (len(probs_correct), 100 * numpy.mean(probs_correct)))

accuracy_percent = 100 * numpy.mean(probs_correct)
print('Test set accuracy: %.2f%%' % accuracy_percent)
# Check claim (b): at least 36% accuracy. Written as `not (x > 36)` so that a
# NaN accuracy (e.g. no examples were processed) also fails the check.
if not (accuracy_percent > 36.0):
  raise AssertionError(
      'Test set accuracy %.2f%% is not above 36%%.' % accuracy_percent)

