"""
PennyLane (``qml``) implementation of the QCL classifier, exposed as a PyTorch ``nn.Module``.
"""

import pennylane as qml
import torch
import torch.nn as nn
from pennylane.transforms import broadcast_expand
from models.circuits import QCL_circuit, qubit_dict, depth_dict, encoding_block
import matplotlib.pyplot as plt


# Model size for the 'qcl' configuration, shared with models.circuits.
n_qubits = qubit_dict['qcl']
# Every wire index in order; QCL.circuit_prob measures these wires by default.
l = list(range(n_qubits))
# Depth passed to QCL_circuit (also the leading dimension of the trainable weights).
depth = depth_dict['qcl']
# No device is created here: callers build one and pass it to QCL(dev, ...).


class QCL(nn.Module):
    """Quantum classifier built from ``QCL_circuit`` and exposed as a PyTorch module.

    The trainable part is ``self.ql``, a ``qml.qnn.TorchLayer`` wrapping the QNode
    returned by :meth:`create_circuit`. Its only trainable tensor is ``weights``,
    with shape ``(depth, n_qubits, 3)``. The class logits are the PauliZ
    expectation values of the first ``num_classes`` qubits.

    The remaining methods build throw-away QNodes on the same device. Some are for
    inspection (states, probabilities, unitaries, drawings, tapes). Others compose
    the classifier with an externally supplied "generator" circuit.

    Args:
        dev: PennyLane device shared by every QNode this module creates.
        num_classes: number of leading qubit expectation values used as logits.
        embedding: encoding name forwarded to ``encoding_block``. :meth:`predict`
            supports ``'amplitude'`` (batched) and ``'angle'`` (per sample).
        noise: flag forwarded to ``QCL_circuit``.
    """

    def __init__(self, dev, num_classes=2, embedding='amplitude', noise=False):
        super().__init__()
        self.dev = dev
        self.num_classes = num_classes
        self.embedding = embedding
        self.noise = noise
        weight_shapes = {'weights': (depth, n_qubits, 3)}
        self.ql = qml.qnn.TorchLayer(self.create_circuit(dev, embedding, noise=noise), weight_shapes)

    @staticmethod
    def create_circuit(dev, embedding, noise=False):
        """Build the QNode wrapped by ``self.ql``.

        The circuit does four things in order:

        1. Casts ``inputs`` to float64.
        2. Encodes them with ``encoding_block(..., name=embedding)``.
        3. Applies ``QCL_circuit`` with the module-level ``depth``.
        4. Returns one PauliZ expectation value per qubit.
        """
        @qml.qnode(dev, interface='torch')
        def circuit(inputs, weights):
            inputs = inputs.to(torch.float64)
            encoding_block(inputs, n_qubits, name=embedding)
            QCL_circuit(depth, n_qubits, weights, noise=noise)
            return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]
        return circuit

    def circuit_state(self, inputs, weights, depth_=depth, exec_=True):
        """Return the device state after encoding ``inputs``.

        Args:
            inputs: data passed to ``encoding_block``. Unlike :meth:`circuit_prob`,
                it is not cast to float64 here.
            weights: parameters for ``QCL_circuit``.
            depth_: depth passed to ``QCL_circuit``; defaults to the module depth.
            exec_: if False, ``QCL_circuit`` is skipped and the state right after
                the encoding is returned.
        """
        @qml.qnode(self.dev, interface='torch')
        def circuit(inputs, weights, depth_, exec_):
            encoding_block(inputs, n_qubits, name=self.embedding)
            if exec_:
                QCL_circuit(depth_, n_qubits, weights, self.noise)

            return qml.state()
        return circuit(inputs, weights, depth_=depth_, exec_=exec_)

    def circuit_prob(self, inputs, weights, depth_=depth, exec_=True, qubit_l=None):
        """Return measurement probabilities over the wires ``qubit_l``.

        Args:
            inputs: data passed to ``encoding_block`` after a cast to float64.
            weights: parameters for ``QCL_circuit``.
            depth_: depth passed to ``QCL_circuit``; defaults to the module depth.
            exec_: if False, ``QCL_circuit`` is skipped and only the encoding runs.
            qubit_l: wires to measure. ``None`` (the default) means every qubit,
                i.e. the module-level list ``l``.
        """
        # Resolve the default here rather than sharing one mutable list as a default argument.
        if qubit_l is None:
            qubit_l = l

        @qml.qnode(self.dev, interface='torch')
        def circuit(inputs, weights, depth_, exec_, qubit_l):
            inputs = inputs.to(torch.float64)
            encoding_block(inputs, n_qubits, name=self.embedding)
            if exec_:
                QCL_circuit(depth_, n_qubits, weights, self.noise)
            return qml.probs(wires=qubit_l)
        return circuit(inputs, weights, depth_=depth_, exec_=exec_, qubit_l=qubit_l)

    def circuit2matrix_wo_embed(self, inputs, weights):
        """Return the matrix of ``QCL_circuit`` alone, without the data encoding.

        ``inputs`` is forwarded to the QNode for interface compatibility, but the
        circuit body does not use it.
        """
        @qml.qnode(self.dev, interface='torch')
        def circuit(inputs, weights):
            QCL_circuit(depth, n_qubits, weights, self.noise)
            return qml.state()

        return qml.matrix(circuit)(inputs, weights)

    def circuit_with_generator(self, inputs, weights, g_circuit, g_depth, g_qubits, g_weights):
        """Run encoding, then ``g_circuit``, then ``QCL_circuit``, and return logits.

        ``g_circuit(g_depth, g_qubits, g_weights)`` is inserted between the data
        encoding and the classifier circuit.

        Returns:
            The PauliZ expectation values of the first ``num_classes`` qubits.
            The shape is ``(batch, num_classes)`` for batched inputs, or
            ``(num_classes,)`` for a single unbatched sample.
        """
        @qml.qnode(self.dev, interface='torch')
        def circuit(inputs, weights):
            encoding_block(inputs, n_qubits, name=self.embedding)
            g_circuit(g_depth, g_qubits, g_weights)
            QCL_circuit(depth, n_qubits, weights, self.noise)
            return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]
        p = circuit(inputs, weights)
        # The QNode returns one entry per qubit. Under parameter broadcasting each
        # entry is a (batch,) tensor. Stacking on the last axis gives
        # (batch, n_qubits), or (n_qubits,) when the input is unbatched.
        p = torch.stack(p, dim=-1)
        return p[..., :self.num_classes]

    def adversarial_state(self, inputs, g_circuit, g_depth, g_qubits, g_weights):
        """Return the device state after the encoding followed by ``g_circuit``.

        The classifier circuit is not applied.
        """
        @qml.qnode(self.dev, interface='torch')
        def circuit(inputs):
            encoding_block(inputs, n_qubits, name=self.embedding)
            g_circuit(g_depth, g_qubits, g_weights)

            return qml.state()
        return circuit(inputs)

    def adversarial_inter_prob(self, inputs, weights, t_depth, g_circuit, g_depth, g_qubits, g_weights):
        """Return probabilities after encoding, ``g_circuit`` and ``QCL_circuit`` of depth ``t_depth``.

        ``qml.probs()`` is called without wires, so it measures all device wires.
        """
        @qml.qnode(self.dev, interface='torch')
        def circuit(inputs):
            encoding_block(inputs, n_qubits, name=self.embedding)
            g_circuit(g_depth, g_qubits, g_weights)
            QCL_circuit(t_depth, n_qubits, weights, self.noise)
            return qml.probs()
        return circuit(inputs)

    def forward(self, x, y):
        """Return the cross-entropy loss of :meth:`predict` on ``x`` against targets ``y``."""
        preds = self.predict(x)
        criterion = nn.CrossEntropyLoss()
        loss = criterion(preds, y)
        return loss

    def predict(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x``.

        ``x`` is flattened from dimension 1 onward, so dimension 0 is the batch.

        Raises:
            ValueError: if ``self.embedding`` is neither ``'amplitude'`` nor ``'angle'``.
        """
        x = torch.flatten(x, start_dim=1)
        if self.embedding == 'amplitude':
            # The whole batch goes through a single TorchLayer call.
            x = self.ql(x)
            return x[:, :self.num_classes]
        if self.embedding == 'angle':
            # One TorchLayer call per sample. NOTE(review): presumably the angle
            # encoding in encoding_block does not support batched inputs; confirm.
            results = torch.zeros((x.shape[0], self.num_classes))
            for i, sample in enumerate(x):
                results[i] = self.ql(sample)[:self.num_classes]
            return results
        # Fail loudly. Returning None would only surface later as an obscure error in the loss.
        raise ValueError(
            f"predict() supports embedding 'amplitude' or 'angle', got {self.embedding!r}"
        )

    def visualize_circuit(self, x, weights, save_path):
        """Draw the classifier circuit for a flattened ``x`` and ``weights``, and save it to ``save_path``."""
        drawer = qml.draw_mpl(self.create_circuit(self.dev, noise=self.noise, embedding=self.embedding))
        fig, _ = drawer(torch.flatten(x), weights)
        # Save the figure draw_mpl returned, not whatever pyplot treats as current.
        fig.savefig(save_path)
        plt.close(fig)

    def obtain_qtape(self, x, weights):
        """Return the tape of the classifier QNode for a flattened ``x`` and ``weights``."""
        circuit = self.create_circuit(self.dev, noise=self.noise, embedding=self.embedding)
        return qml.workflow.construct_tape(circuit)(torch.flatten(x), weights)

    def preprocess_input(self, input):
        """Cast ``input`` to float64 and apply ``encoding_block`` to it; always return None.

        NOTE(review): no QNode is created here. The encoding operations are
        presumably only recorded when this is called inside an active
        tape/queuing context; confirm the intended caller.
        """
        input = input.to(torch.float64)  # TODO: revisit for finite shots
        encoding_block(input, n_qubits, name=self.embedding)
        return None

    def process_output(self, out):
        """Return the predicted class index per row: argmax over the first ``num_classes`` columns."""
        return torch.argmax(out[:, :self.num_classes], dim=1)
