import matplotlib.pyplot as plt
import pennylane as qml
import qiskit
import tensorflow as tf
from pennylane import numpy as np
from pennylane.templates import RandomLayers
from qiskit import BasicAer, QuantumCircuit
from tensorflow import keras

# ---------------------------------------------------------------------------
# Experiment configuration and MNIST data preparation.
# NOTE: statement order matters here - the seeded NumPy RNG (np.random.seed
# below) is consumed by the noise draws, so reordering these lines shifts
# every later random draw in this script.
# ---------------------------------------------------------------------------
n_epochs = 30  # Number of optimization epochs
n_layers = 1  # Number of random layers
n_train = 50  # Size of the train dataset
n_test = 30  # Size of the test dataset

SAVE_PATH = "./QNN_save"  # Data saving folder
PREPROCESS = True  # If False, skip quantum processing and load data from SAVE_PATH
np.random.seed(0)  # Seed for NumPy random number generator
tf.random.set_seed(0)  # Seed for TensorFlow's global random number generator

mnist_dataset = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist_dataset.load_data()

# Reduce dataset size: keep only the first n_train / n_test samples.
train_images = train_images[:n_train]
train_labels = train_labels[:n_train]
test_images = test_images[:n_test]
test_labels = test_labels[:n_test]

# Normalize pixel values within 0 and 1 (dividing by 255 assumes 8-bit pixels).
# A tiny random offset is added first. Without it, an all-black patch becomes
# [0, 0, 0, 0], which cannot be amplitude-encoded because it has zero norm.
# The offset is at most 1e-8 before scaling (about 4e-11 after dividing by 255).
train_images = (train_images + np.random.rand(*train_images.shape) * 1e-8) / 255
test_images = (test_images + np.random.rand(*test_images.shape) * 1e-8) / 255

# Add extra dimension for convolution channels: appends a trailing axis of
# size 1 (presumably (N, 28, 28) -> (N, 28, 28, 1) for MNIST).
# requires_grad=False marks these PennyLane arrays as non-trainable.
train_images = np.array(train_images[..., tf.newaxis], requires_grad=False)
test_images = np.array(test_images[..., tf.newaxis], requires_grad=False)

from qiskit import Aer, execute, transpile

# NOTE(review): `backend` is not referenced anywhere else in the visible code;
# the QNodes in this script run on PennyLane devices instead.
backend = Aer.get_backend("statevector_simulator")

# Project-local helper that builds a Qiskit data-encoding circuit
# (its `.qcircuit` and `.num_qubits` attributes are used below).
from encoder import Encoding

# 4-wire simulator: one wire per pixel of a 2x2 image patch.
dev = qml.device("default.qubit", wires=4)
# Random circuit parameters: shape (n_layers, 4), drawn uniformly from [0, 2*pi).
# NOTE(review): this draw also advances the seeded NumPy RNG, so removing it
# would change every later random value in the script.
rand_params = np.random.uniform(high=2 * np.pi, size=(n_layers, 4))
# Plain NumPy (as opposed to PennyLane's wrapped `np`), used to hand raw
# arrays to `Encoding`.
import numpy


@qml.qnode(dev, interface="autograd")
def circuit(phi):
    """Quanvolution kernel for one 2x2 patch, executed on the 4-wire ``dev``.

    The four classical values in ``phi`` are encoded with the project's
    "qubit_encoding" Qiskit circuit. That is followed by ``RandomLayers`` using
    the module-level ``rand_params``, which is looked up when the QNode runs.

    Args:
        phi: Sequence of 4 classical input values.

    Returns:
        List of the 4 Pauli-Z expectation values, one per wire.
    """
    encoding_qc = Encoding(numpy.array(phi), encode_type="qubit_encoding").qcircuit
    # The encoding is attached to the wires in reverse order, presumably to
    # map Qiskit's little-endian qubit ordering onto PennyLane wire labels.
    qml.from_qiskit(transpile(encoding_qc))(wires=[3, 2, 1, 0])

    # Random quantum circuit acting on all four wires.
    RandomLayers(rand_params, wires=[0, 1, 2, 3])

    # Measurement: one classical output value per wire.
    measured_wires = range(4)
    return [qml.expval(qml.PauliZ(wire)) for wire in measured_wires]


# NOTE: results differ from those of the code in "6" (presumably an earlier
# notebook cell or script version; its origin is not visible here).
class QCNN:
    """Single-patch quanvolution circuit: data encoding followed by RandomLayers.

    The project-local ``Encoding`` helper turns the classical input ``phi``
    into a Qiskit circuit. That circuit is transpiled and replayed in PennyLane
    through ``qml.from_qiskit``. ``RandomLayers`` with ``params`` is applied on
    top, and every wire is measured in the Pauli-Z basis.

    Args:
        phi: Classical values to encode (converted with ``numpy.array``).
        params: ``RandomLayers`` weights. Passing the same array to several
            instances gives them the same random circuit.
        encode_type: Encoding scheme name forwarded to ``Encoding``.
        n_layers: Not used. The number of random layers is implied by the
            shape of ``params``; the argument is kept for call compatibility.
        simulator: Name of the PennyLane device to create.

    Attributes:
        embd_layer: Transpiled Qiskit encoding circuit.
        qnum: Number of qubits reported by ``Encoding``.
        dev: PennyLane device with ``qnum`` wires.
        rand_params: The ``params`` passed in.
    """

    def __init__(self, phi, params, encode_type, n_layers, simulator="default.qubit"):
        self.embd_layer, self.qnum = self._encoder(phi, encode_type)
        self.dev = qml.device(simulator, wires=self.qnum)
        self.rand_params = params

    def _encoder(self, phi, encode_type="qubit_encoding"):
        """Return ``(transpiled encoding circuit, qubit count)`` for ``phi``.

        ``Encoding`` is built once, so both values come from the same instance
        and the encoding work is not repeated.
        """
        encoding = Encoding(numpy.array(phi), encode_type=encode_type)
        return transpile(encoding.qcircuit), encoding.num_qubits

    def circuit(self):
        """Quantum function intended for ``qml.QNode(self.circuit, self.dev)``.

        Returns:
            List of Pauli-Z expectation values, one per wire ``0..qnum-1``.
        """
        # The encoding is attached to the wires in reverse order, presumably
        # to map Qiskit's little-endian qubit ordering onto PennyLane wire labels.
        qml.from_qiskit(self.embd_layer)(wires=list(range(self.qnum))[::-1])
        RandomLayers(self.rand_params, wires=list(range(self.qnum)))
        return [qml.expval(qml.PauliZ(j)) for j in range(self.qnum)]


# Sanity check: build one QCNN patch circuit on a random 4-value input.
random_data = np.random.rand(4)
print(random_data)
# Shared RandomLayers weights. The module-level `circuit` reads this global
# when it executes.
rand_params = np.random.uniform(high=2 * np.pi, size=(n_layers, 4))


model = QCNN(random_data, rand_params, n_layers=n_layers, encode_type="qubit_encoding")
dev = model.dev
qnode = qml.QNode(model.circuit, dev)

# qml.draw_mpl only returns a drawing function. It has to be called with the
# circuit's arguments to actually build the figure that plt.show() displays.
qml.draw_mpl(circuit)(random_data)
plt.show()


@qml.qnode(dev)
def c(state_vector=None, n=2):
    """Amplitude-encode a classical vector on wires ``0..n-1`` and return the state.

    Args:
        state_vector: ``2**n`` real amplitudes. The vector is L2-normalized
            here, so it only needs a non-zero norm.
        n: Number of wires to embed on (default 2). Must not exceed the number
            of wires of ``dev``.

    Returns:
        The full device state (``qml.state()``).

    Raises:
        ValueError: If ``state_vector`` is missing or has zero norm (an
            all-zero vector cannot be amplitude-encoded).
    """
    if state_vector is None:
        raise ValueError("c() requires a state_vector to embed")
    norm = np.linalg.norm(state_vector)  # pylint: disable=no-member
    if norm == 0:
        raise ValueError("cannot amplitude-encode a zero vector")
    X = state_vector / norm
    qml.templates.AmplitudeEmbedding(X, wires=range(n))

    return qml.state()


def quanv(image, params=None):
    """Convolve ``image`` with many applications of the same quantum circuit.

    The image is tiled into non-overlapping 2x2 patches. The four pixel values
    of each patch are encoded by ``QCNN`` ("qubit_encoding"), passed through
    ``RandomLayers`` and measured. The four Pauli-Z expectation values become
    the four channels of the corresponding output pixel.

    Args:
        image: Array indexed as ``image[row, col, 0]`` (e.g. shape (28, 28, 1)).
        params: ``RandomLayers`` weights. Defaults to the module-level
            ``rand_params``. Using one shared set of weights for every patch of
            every image is what makes this a consistent convolutional filter.

    Returns:
        Array of shape ``(rows // 2, cols // 2, 4)``.
    """
    if params is None:
        # Share one random circuit across all calls. A per-call draw would give
        # each image a different "filter", making features incomparable.
        params = rand_params
    n_rows, n_cols = image.shape[0], image.shape[1]
    out = np.zeros((n_rows // 2, n_cols // 2, 4))
    # Loop over the top-left pixel of each 2x2 square. The "- 1" bound keeps
    # a patch from running past the last row/column on odd-sized images.
    for j in range(0, n_rows - 1, 2):
        for k in range(0, n_cols - 1, 2):
            patch = [
                image[j, k, 0],
                image[j, k + 1, 0],
                image[j + 1, k, 0],
                image[j + 1, k + 1, 0],
            ]
            qlayer = QCNN(
                patch,
                params=params,
                encode_type="qubit_encoding",
                n_layers=n_layers,
            )
            q_results = qml.QNode(qlayer.circuit, qlayer.dev)()
            # Assign expectation values to the channels of output pixel (j/2, k/2).
            for ch in range(4):
                out[j // 2, k // 2, ch] = q_results[ch]
    return out


import os

# The pre-processed arrays live inside the SAVE_PATH folder. os.path.join puts
# the files *inside* that folder. Plain string concatenation would instead
# create sibling files named "QNN_saveq_...".
q_train_path = os.path.join(SAVE_PATH, "q_train_images.npy")
q_test_path = os.path.join(SAVE_PATH, "q_test_images.npy")


def _quanv_dataset(images, total):
    """Apply ``quanv`` to every image, print in-place progress, and stack the results."""
    processed = []
    for idx, img in enumerate(images):
        print("{}/{}        ".format(idx + 1, total), end="\r")
        processed.append(quanv(img))
    return np.asarray(processed)


if PREPROCESS:
    # Create the output folder up front so a missing directory fails fast,
    # before the slow quantum pre-processing runs.
    os.makedirs(SAVE_PATH, exist_ok=True)

    print("Quantum pre-processing of train images:")
    q_train_images = _quanv_dataset(train_images, n_train)

    print("\nQuantum pre-processing of test images:")
    q_test_images = _quanv_dataset(test_images, n_test)

    # Save pre-processed images
    np.save(q_train_path, q_train_images)
    np.save(q_test_path, q_test_images)

# Load pre-processed images (always, so PREPROCESS=False reuses a previous run).
q_train_images = np.load(q_train_path)
q_test_images = np.load(q_test_path)

# Show a few input images (top row) next to each quantum output channel.
n_samples = 4
n_channels = 4
fig, axes = plt.subplots(1 + n_channels, n_samples, figsize=(10, 10))

# Row labels only need to be set once, on the first column.
axes[0, 0].set_ylabel("Input")
for ch in range(n_channels):
    axes[ch + 1, 0].set_ylabel("Output [ch. {}]".format(ch))

for k in range(n_samples):
    axes[0, k].imshow(train_images[k, :, :, 0], cmap="gray")
    # Plot all output channels
    for ch in range(n_channels):
        axes[ch + 1, k].imshow(q_train_images[k, :, :, ch], cmap="gray")
    if k != 0:
        # Only the first column keeps its y-axis. Hide it on every row
        # (input row and all output-channel rows) of the other columns.
        for row in range(1 + n_channels):
            axes[row, k].yaxis.set_visible(False)

plt.tight_layout()
plt.show()


def MyModel():
    """Build and compile a minimal dense classifier, ready for ``fit``.

    The network flattens its input and applies one 10-unit softmax layer. It
    is compiled with the Adam optimizer, sparse categorical cross-entropy
    (which expects integer class labels) and an accuracy metric.
    """
    classifier = keras.models.Sequential()
    classifier.add(keras.layers.Flatten())
    classifier.add(keras.layers.Dense(10, activation="softmax"))

    classifier.compile(
        loss="sparse_categorical_crossentropy",
        optimizer="adam",
        metrics=["accuracy"],
    )
    return classifier


# Identical training settings for both models so their curves are comparable.
fit_options = dict(batch_size=4, epochs=n_epochs, verbose=2)

# Model trained on the quantum-preprocessed images.
q_model = MyModel()
q_history = q_model.fit(
    q_train_images,
    train_labels,
    validation_data=(q_test_images, test_labels),
    **fit_options,
)

# Baseline model trained on the raw images.
c_model = MyModel()
c_history = c_model.fit(
    train_images,
    train_labels,
    validation_data=(test_images, test_labels),
    **fit_options,
)

import matplotlib.pyplot as plt

# Matplotlib 3.6 renamed the bundled "seaborn" style to "seaborn-v0_8" and 3.8
# removed the old name. Use whichever one this installation provides.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 9))

# Validation accuracy (top) and loss (bottom) per epoch for both models.
for ax, metric, ylabel in ((ax1, "val_accuracy", "Accuracy"), (ax2, "val_loss", "Loss")):
    ax.plot(q_history.history[metric], "-ob", label="With quantum layer")
    ax.plot(c_history.history[metric], "-og", label="Without quantum layer")
    ax.set_ylabel(ylabel)
    ax.set_xlabel("Epoch")
    ax.legend()

ax1.set_ylim([0, 1])
ax2.set_ylim(top=2.5)
plt.tight_layout()
plt.show()
