"""
QML (PennyLane) implementation of a QCNN (conv1, conv2, pool1, pool2, fc)
for 2- or 3-class classification.
"""

import pennylane as qml
import torch
import torch.nn as nn
from pennylane.templates.embeddings import AmplitudeEmbedding
import math
from models.circuits import pure_qcnn_circuit, pure_qcnn_block1, pure_qcnn_block2, qubit_dict, encoding_block


# Number of qubits used by the QCNN circuit, taken from the shared config.
n_qubits = qubit_dict['qcnn']
# All qubit indices; default wire list for probability read-outs.
l = list(range(n_qubits))
# The PennyLane device is supplied by the caller, e.g.
# qml.device('default.qubit', wires=n_qubits)


class QCNN(nn.Module):
    """Quantum convolutional neural network classifier built on PennyLane.

    The trainable circuit encodes a flattened input with ``encoding_block``,
    applies ``pure_qcnn_circuit`` (conv1, conv2, pool1, pool2, fc weights) and
    returns the Pauli-Z expectation value of every qubit. A fixed subset of
    those expectation values is used as the class logits (see
    ``_readout_indices``).

    Args:
        dev: PennyLane device all QNodes run on.
        num_classes: 2 selects the two-logit readout; any other value selects
            the three-logit readout.
        embedding: Encoding name forwarded to ``encoding_block``. With
            ``'amplitude'`` the whole batch is passed to the layer at once;
            other encodings are evaluated one sample at a time.
        noise: Forwarded as ``noise=`` to ``pure_qcnn_circuit`` in the
            trainable circuit.
    """

    def __init__(self, dev, num_classes=2, embedding='amplitude', noise=False):
        super().__init__()
        self.dev = dev
        self.num_classes = num_classes
        self.embedding = embedding
        self.depth = 3
        self.noise = noise

        # Trainable parameter shapes, keyed by the QNode argument names
        # (insertion order matches the QNode signature).
        weight_shapes = {
            'weights_conv1': (n_qubits, 15),
            'weights_conv2': (math.ceil((n_qubits - 2) / 2), 15),
            'weights_pool1': (math.ceil(n_qubits / 2), 2),
            'weights_pool2': (math.ceil(n_qubits / 4), 2),
            'weights_fc': (3,),
        }
        self.cir = qml.qnn.TorchLayer(self.create_circuit(dev, embedding, noise), weight_shapes)

    @staticmethod
    def create_circuit(dev, embedding, noise=False):
        """Build the trainable QNode wrapped by ``TorchLayer``.

        The QNode encodes ``inputs`` (cast to float64), applies the full QCNN
        circuit and returns ``<Z_i>`` for every qubit ``i``.
        """
        @qml.qnode(dev, interface='torch')
        def circuit(inputs, weights_conv1, weights_conv2, weights_pool1, weights_pool2, weights_fc):
            inputs = inputs.to(torch.float64)
            encoding_block(inputs, n_qubits, embedding)
            pure_qcnn_circuit(n_qubits, weights_conv1, weights_conv2, weights_pool1, weights_pool2, weights_fc,
                              noise=noise)
            return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]
        return circuit

    @staticmethod
    def _unpack_weights(weights):
        """Return ``(conv1, conv2, pool1, pool2, fc)`` from an indexable ``weights`` container."""
        return tuple(weights[i] for i in range(5))

    @staticmethod
    def _apply_layers(depth_, weights_conv1, weights_conv2, weights_pool1, weights_pool2, weights_fc):
        """Apply the first ``depth_`` QCNN stages; must be called inside a QNode.

        1 -> ``pure_qcnn_block1``, 2 -> ``pure_qcnn_block2``,
        3 -> ``pure_qcnn_circuit``. Any other value applies nothing.
        The helpers are called without a ``noise`` argument.
        """
        if depth_ == 1:
            pure_qcnn_block1(n_qubits, weights_conv1, weights_pool1)
        elif depth_ == 2:
            pure_qcnn_block2(n_qubits, weights_conv1, weights_conv2, weights_pool1, weights_pool2)
        elif depth_ == 3:
            pure_qcnn_circuit(n_qubits, weights_conv1, weights_conv2, weights_pool1, weights_pool2, weights_fc)

    def circuit_state(self, inputs, weights, exec_=True, depth_=1):
        """Return the state vector after encoding and (optionally) ``depth_`` QCNN stages.

        Args:
            inputs: A single input to encode.
            weights: Indexable ``(conv1, conv2, pool1, pool2, fc)`` weights.
            exec_: If False, only the encoding is applied.
            depth_: Number of QCNN stages to apply (1, 2 or 3).
        """
        @qml.qnode(self.dev, interface='torch')
        def circuit(inputs, weights_conv1, weights_conv2, weights_pool1, weights_pool2, weights_fc):
            inputs = inputs.to(torch.float64)
            encoding_block(inputs, n_qubits, self.embedding)
            if exec_:
                self._apply_layers(depth_, weights_conv1, weights_conv2, weights_pool1, weights_pool2,
                                   weights_fc)
            return qml.state()
        return circuit(inputs, *self._unpack_weights(weights))

    def circuit_prob(self, inputs, weights, depth_=1, qubit_l=None, exec_=True):
        """Return measurement probabilities on ``qubit_l`` after ``depth_`` QCNN stages.

        Args:
            inputs: A single input to encode.
            weights: Indexable ``(conv1, conv2, pool1, pool2, fc)`` weights.
            depth_: Number of QCNN stages to apply (1, 2 or 3).
            qubit_l: Wires to measure; ``None`` means all ``n_qubits`` wires.
            exec_: If False, only the encoding is applied.
        """
        # A fresh list per call instead of a shared mutable default.
        wires = list(range(n_qubits)) if qubit_l is None else qubit_l

        @qml.qnode(self.dev, interface='torch')
        def circuit(inputs, weights_conv1, weights_conv2, weights_pool1, weights_pool2, weights_fc):
            inputs = inputs.to(torch.float64)
            encoding_block(inputs, n_qubits, self.embedding)
            if exec_:
                self._apply_layers(depth_, weights_conv1, weights_conv2, weights_pool1, weights_pool2,
                                   weights_fc)
            return qml.probs(wires=wires)
        return circuit(inputs, *self._unpack_weights(weights))

    def circuit2matrix_wo_embed(self, inputs, weights):
        """Return the unitary matrix of the full QCNN circuit without the data encoding.

        ``inputs`` is passed through to the QNode for interface compatibility
        but is not used by the circuit.
        """
        @qml.qnode(self.dev, interface='torch')
        def circuit(inputs, weights_conv1, weights_conv2, weights_pool1, weights_pool2, weights_fc):
            pure_qcnn_circuit(n_qubits, weights_conv1, weights_conv2, weights_pool1, weights_pool2, weights_fc)
            return qml.state()
        return qml.matrix(circuit)(inputs, *self._unpack_weights(weights))

    def forward(self, x, y):
        """Return the mean cross-entropy loss of ``predict(x)`` against labels ``y``."""
        return nn.functional.cross_entropy(self.predict(x), y)

    def _readout_indices(self):
        """Return the per-qubit output columns used as class logits.

        Qubits are spaced 2 apart on an 8-qubit circuit and 4 apart otherwise;
        ``num_classes == 2`` reads two qubits, any other value reads three.
        """
        spacing = 2 if n_qubits == 8 else 4
        n_logits = 2 if self.num_classes == 2 else 3
        return [k * spacing for k in range(n_logits)]

    def predict(self, x):
        """Return class logits for a batch ``x``.

        Every dimension of ``x`` after the first is flattened before encoding.
        The result holds the ``<Z>`` values of the readout qubits, one row per
        sample.
        """
        x = torch.flatten(x, start_dim=1)
        if self.embedding == 'amplitude':
            results = self.cir(x)
        else:
            # Non-amplitude encodings are evaluated one sample at a time.
            results = torch.zeros((x.shape[0], n_qubits))
            for i, sample in enumerate(x):
                results[i] = self.cir(sample)
        return results[:, self._readout_indices()]

    def visualize_circuit(self, x, weights, save_path):
        """Draw the trainable circuit for input ``x`` and save the figure to ``save_path``."""
        import matplotlib.pyplot as plt
        circuit = self.create_circuit(self.dev, noise=self.noise, embedding=self.embedding)
        fig, _ = qml.draw_mpl(circuit)(torch.flatten(x), *self._unpack_weights(weights))
        try:
            # Save this specific figure rather than pyplot's "current" figure.
            fig.savefig(save_path)
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)

    def obtain_qtape(self, x, weights):
        """Return the tape of the trainable circuit built for input ``x`` and ``weights``."""
        circuit = self.create_circuit(self.dev, noise=self.noise, embedding=self.embedding)
        return qml.workflow.construct_tape(circuit)(torch.flatten(x), *self._unpack_weights(weights))

    def preprocess_input(self, input):
        """Apply the configured data encoding to ``input`` (cast to float64); returns None.

        NOTE(review): presumably meant to be called inside a QNode / queuing
        context so the encoding operations are recorded — confirm with callers.
        """
        input = input.to(torch.float64)
        encoding_block(input, n_qubits, name=self.embedding)
        return None

    def process_output(self, out):
        """Return the predicted class index per row of per-qubit outputs ``out``."""
        return torch.argmax(out[:, self._readout_indices()], dim=1)
