
import keras
import pennylane as qml
from datetime import datetime
import data_preprocessing as data_pre

# Load the MNIST train/test splits through the Keras dataset helper.
mnist_dataset = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist_dataset.load_data()

# Experiment settings shared by the encoding, training and plotting steps below.
train_size = 400  # forwarded to aae.mnist_encoding and qnn.mnist_classification
test_size = 100  # forwarded to aae.mnist_encoding and qnn.mnist_classification
lr = 0.4  # forwarded to qnn.mnist_classification; presumably the learning rate -- confirm in qnn
shot_train = 100  # forwarded to qnn.mnist_classification; presumably measurement shots while training -- confirm
shot_test = 1000  # forwarded to qnn.mnist_classification; presumably measurement shots while testing -- confirm
iteration = 400  # forwarded to qnn.mnist_classification and figure.draw
check_slide = 5  # forwarded to qnn.mnist_classification and figure.draw
num_qubits = 8  # number of wires on every qml.device created below
alter_layer = 1  # forwarded to aae.mnist_encoding

# Data preprocessing: reduce each image from 784 to 256 values and normalize it so its norm is 1.

# Run the preprocessing export once for index 0 and once for index 1; the
# trailing 4 is forwarded unchanged to data_pre.output_data.
for subset_index in (0, 1):
	data_pre.output_data(
		train_images,
		train_labels,
		test_images,
		test_labels,
		subset_index,
		4,
	)

import aa_encoding as aae


# Encode every preprocessed image file. Each split is encoded with its own
# sample count: train files use train_size and test files use test_size.
# (Previously the "test_images_0" file was encoded with train_size and the
# "train_images_1" file with test_size.)
# NOTE(review): assumes the third argument of aae.mnist_encoding is the number
# of samples of the named file to encode -- confirm against aa_encoding.
for idx in (0, 1):
	# idx is both the first argument and the index embedded in the file name.
	a = aae.mnist_encoding(idx, f"train_images_{idx}_16_normalized", train_size, alter_layer)
	a = aae.mnist_encoding(idx, f"test_images_{idx}_16_normalized", test_size, alter_layer)

import qnn

def _run_and_time(circuit_fn, circuit_name, option):
	"""Run one qnn.mnist_classification job and print how long it took.

	A fresh 'default.qubit' device with num_qubits wires and a QNode wrapping
	circuit_fn are created for every run, so no state is shared between runs.

	Args:
		circuit_fn: circuit function wrapped in the QNode
			(qnn.ttn_circuit or qnn.random_circuit).
		circuit_name: label passed to qnn.mnist_classification and used in
			the printed timing line ("ttn" or "random").
		option: boolean forwarded unchanged as the final positional argument
			of qnn.mnist_classification; its meaning is defined in qnn.
	"""
	started = datetime.now()
	device = qml.device('default.qubit', wires = num_qubits)
	circuit = qml.QNode(circuit_fn, device)
	# NOTE(review): the literals 8 and 15 are kept exactly as in the original
	# calls; 8 presumably mirrors num_qubits -- confirm in qnn before linking them.
	qnn.mnist_classification(8, 15, circuit, circuit_name, iteration, check_slide, lr, shot_train, shot_test, test_size, train_size, option)
	elapsed = datetime.now() - started
	# timedelta.seconds drops whole days, so a run longer than 24 hours would be
	# under-reported; total_seconds() covers the full duration.
	print('seconds for the ' + circuit_name + ' circuit: ', int(elapsed.total_seconds()))


# Train and time both circuits, first with the final flag False and then True,
# in the same order as before: ttn, random, ttn, random.
for option in (False, True):
	_run_and_time(qnn.ttn_circuit, "ttn", option)
	_run_and_time(qnn.random_circuit, "random", option)

import figure

# Call figure.draw with the same iteration and check_slide values used for the
# training runs above; presumably this renders the result plots -- confirm in figure.
a = figure.draw(iteration, check_slide)