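"""
Hybrid classical and quantum neural network (HQNN).

A classical network (CNN) serves as the feature extractor and a quantum
neural network (QNN) serves as the classifier.

Reference: "Hybrid Quantum Neural Network Structures for Image
Multi-classification" (local copy: 23 Hybrid Quantum Neural Network
Structures for Image Multi-classification.pdf).
"""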
import numpy as np
import pennylane as qml
import torch
import torch.nn as nn
from pennylane import AmplitudeEmbedding
from models.circuits import hqnn_circuit, qubit_dict, encoding_block, depth_dict
import matplotlib.pyplot as plt


n_qubits = qubit_dict['hqnn']
depth = depth_dict['hqnn']
# Wire indices [0, 1, ..., n_qubits - 1]; used as the default set of measured
# wires in HQNN.circuit_prob.
l = list(range(n_qubits))


class HQNN(nn.Module):
    """Hybrid classical-quantum classifier.

    A linear layer followed by ReLU maps ``n_features`` inputs to
    ``2**n_qubits`` features. These are fed to a variational quantum circuit
    wrapped as a ``qml.qnn.TorchLayer``. The circuit returns one PauliZ
    expectation value per qubit, and the first ``num_classes`` of them are
    used as class logits.

    Args:
        dev: PennyLane device the QNodes are built on.
        n_features: Size of the flattened classical input.
        num_classes: Number of logits returned by ``predict``. They are
            sliced from the per-qubit expectations, so this should not
            exceed ``n_qubits``.
        embedding: Encoding name forwarded to ``encoding_block``.
        noise: Flag forwarded to ``hqnn_circuit``.
    """

    def __init__(self, dev, n_features, num_classes=2, embedding='amplitude', noise=False):
        super().__init__()
        self.dev = dev
        self.num_classes = num_classes
        self.embedding = embedding
        self.noise = noise
        # Weight tensor shape expected by ``hqnn_circuit``. The meaning of the
        # last axis (4 + n_qubits - 1) is defined there -- TODO confirm.
        weight_shapes = {'weights': (depth, n_qubits, (4 + n_qubits - 1))}
        self.cl = nn.Linear(n_features, 2**n_qubits)
        self.ql = qml.qnn.TorchLayer(self.create_circuit(dev, embedding, noise=noise), weight_shapes)

    def _classical_features(self, x):
        """Apply the linear layer and ReLU, nudging all-zero feature vectors.

        Every vector (last axis) that is entirely zero after ReLU gets 1e-6
        added to each entry. Presumably this is so the quantum encoding never
        receives a zero vector (e.g. amplitude embedding cannot normalise
        one). Works for a single 1-D vector or a 2-D batch.
        """
        x = torch.relu(self.cl(x))
        mask = torch.all(x == 0, dim=-1, keepdim=True)
        return x + mask * 1e-6

    @staticmethod
    def create_circuit(dev, embedding, noise=False):
        """Build the QNode wrapped by the quantum layer.

        The QNode casts ``inputs`` to float64 and encodes them with
        ``encoding_block``. It then applies ``hqnn_circuit`` with
        ``weights`` and returns the PauliZ expectation value of every qubit.
        """
        @qml.qnode(dev, interface='torch')
        def circuit(inputs, weights):
            inputs = inputs.to(torch.float64)
            encoding_block(inputs, n_qubits, embedding)
            hqnn_circuit(weights, n_qubits, depth, noise)
            return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]
        return circuit

    def circuit_state(self, inputs, weights, depth_=depth, exec_=False):
        """Return the statevector after encoding ``inputs``.

        Args:
            inputs: Classical input; passed through the linear layer first.
            weights: Sequence whose last element is used as the quantum
                weights (presumably ``list(self.parameters())``; TODO confirm).
            depth_: Depth forwarded to ``hqnn_circuit``.
            exec_: If False, only the encoding is applied. If True,
                ``hqnn_circuit`` is applied as well.
        """
        @qml.qnode(self.dev, interface='torch')
        def circuit(inputs, weights, depth_, exec_):
            # Unlike the other circuits, inputs are not cast to float64 here.
            encoding_block(inputs, n_qubits, self.embedding)
            if exec_:
                hqnn_circuit(weights, n_qubits, depth_, self.noise)
            return qml.state()

        x = self._classical_features(inputs)
        return circuit(x, weights[-1], depth_, exec_)

    def circuit_prob(self, inputs, weights, depth_=depth, exec_=True, qubit_l=None):
        """Return computational-basis probabilities on the wires ``qubit_l``.

        Args:
            inputs: Classical input; passed through the linear layer first.
            weights: Sequence whose last element is used as the quantum
                weights (presumably ``list(self.parameters())``; TODO confirm).
            depth_: Depth forwarded to ``hqnn_circuit``.
            exec_: If False, only the encoding is applied.
            qubit_l: Wires to measure. ``None`` means all qubits (the
                module-level list ``l``).
        """
        wires = l if qubit_l is None else qubit_l

        @qml.qnode(self.dev, interface='torch')
        def circuit(inputs, weights, depth_, exec_, qubit_l):
            inputs = inputs.to(torch.float64)
            encoding_block(inputs, n_qubits, self.embedding)
            if exec_:
                hqnn_circuit(weights, n_qubits, depth_, noise=self.noise)
            return qml.probs(wires=qubit_l)

        x = self._classical_features(inputs)
        return circuit(x, weights[-1], depth_, exec_, wires)

    def circuit2matrix_wo_embed(self, inputs, weights):
        """Return ``qml.matrix`` of the encoding + ``hqnn_circuit`` circuit.

        NOTE(review): despite the ``wo_embed`` suffix, ``encoding_block`` is
        still applied before ``hqnn_circuit``. Confirm which is intended.
        """
        @qml.qnode(self.dev, interface='torch')
        def circuit(inputs, weights):
            encoding_block(inputs, n_qubits, self.embedding)
            hqnn_circuit(weights, n_qubits, depth, noise=self.noise)
            return qml.state()

        x = self._classical_features(inputs)
        return qml.matrix(circuit)(x, weights[-1])

    def forward(self, x, y):
        """Return the mean cross-entropy loss of ``predict(x)`` against targets ``y``."""
        return nn.functional.cross_entropy(self.predict(x), y)

    def predict(self, x):
        """Return logits for a batch ``x``.

        ``x`` is flattened from dim 1 onwards, so it needs a leading batch
        dimension. The result has one row per sample and the first
        ``num_classes`` quantum-layer outputs as columns.
        """
        features = self._classical_features(torch.flatten(x, start_dim=1))
        # Samples are sent through the quantum layer one at a time.
        outputs = torch.stack([self.ql(sample) for sample in features])
        return outputs[:, :self.num_classes]

    def visualize_circuit(self, x, weights, save_path):
        """Draw the quantum circuit for a single input ``x`` and save it to ``save_path``.

        ``x`` is fully flattened, so it is expected to hold one sample.
        ``weights[-1]`` is used as the quantum weights.
        """
        features = self._classical_features(torch.flatten(x))
        circuit = self.create_circuit(self.dev, noise=self.noise, embedding=self.embedding)
        fig, _ = qml.draw_mpl(circuit)(features, weights[-1])
        fig.show()
        # Save the figure that was just drawn, not whatever pyplot considers current.
        fig.savefig(save_path)

    def obtain_qtape(self, x, weights):
        """Execute the circuit once for a single input ``x`` and return its recorded tape.

        NOTE(review): ``QNode.qtape`` is deprecated in newer PennyLane
        releases; check the installed version.
        """
        circuit = self.create_circuit(self.dev, noise=self.noise, embedding=self.embedding)
        features = self._classical_features(torch.flatten(x))
        circuit(features, weights[-1])
        return circuit.qtape

    def preprocess_input(self, input):
        """Apply the classical layer to ``input`` and call ``encoding_block`` on the result.

        Returns nothing. Presumably meant to be called inside a quantum
        function so the encoding operations are queued -- TODO confirm.
        """
        x = self._classical_features(input)
        # Embedding passed positionally, matching every other call site.
        encoding_block(x, n_qubits, self.embedding)

    def process_output(self, out):
        """Return the index of the largest of the first ``num_classes`` entries of ``out``."""
        # as_tensor avoids the copy (and warning) torch.tensor incurs on tensors.
        return torch.argmax(torch.as_tensor(out[:self.num_classes]))