"""
Data re-uploading for a universal quantum classifier (binary classification)

Pérez-Salinas, Adrián, et al. "Data re-uploading for a universal quantum classifier." arXiv preprint arXiv:1907.02085 (2019).
"""

import pennylane as qml
import torch
import torch.nn as nn
from models.circuits import DRNN_circuit, qubit_dict, depth_dict, pad_input
import matplotlib.pyplot as plt


# Circuit size (number of qubits, number of re-uploading layers) for this
# model, looked up under the 'drnn' key of the tables from models.circuits.
n_qubits, depth = qubit_dict['drnn'], depth_dict['drnn']



class DRNN(nn.Module):
    """Data re-uploading quantum classifier exposed as a ``torch.nn.Module``.

    The circuit is built by ``DRNN_circuit`` on ``n_qubits`` qubits with
    ``depth`` layers (module-level values looked up under ``'drnn'``) and is
    wrapped for PyTorch with ``qml.qnn.TorchLayer``. The circuit returns one
    Pauli-Z expectation value per qubit; the first ``num_classes`` of them are
    used as class scores.

    Args:
        dev: PennyLane device every QNode of this model is bound to.
        num_classes: Number of class scores taken from the qubit expectation
            values. NOTE(review): presumably must not exceed ``n_qubits``.
        tuple_size: Size of the last axis of the ``weights_var`` parameter
            (``tuple_size + 1`` when ``ent_train`` is set).
        target_dim: Input dimension forwarded to ``DRNN_circuit`` and
            ``pad_input``.
        scaling: Input scaling factor forwarded to ``DRNN_circuit``.
        ent_train: Flag forwarded to ``DRNN_circuit``; also enlarges
            ``weights_var`` by one entry per (layer, qubit).
        noise: Flag forwarded to ``DRNN_circuit``.
    """

    def __init__(self, dev, num_classes=2, tuple_size=1, target_dim=72, scaling=1.5, ent_train=False, noise=False):
        super().__init__()
        self.dev = dev
        self.num_classes = num_classes
        self.target_dim = target_dim
        self.scaling = scaling
        self.ent_train = ent_train
        self.noise = noise
        # Not read anywhere in this class; presumably consumed by external code.
        self.embedding = 'interleaved'

        # With ent_train, each (layer, qubit) carries one extra variational parameter.
        var_last_dim = tuple_size + 1 if ent_train else tuple_size
        weight_shapes = {
            'weights_input': (depth, n_qubits, 3),
            'weights_var': (depth, n_qubits, var_last_dim),
        }
        circuit = self.create_circuit(dev, target_dim=target_dim, noise=noise, scaling=scaling, ent_train=ent_train)
        self.ql = qml.qnn.TorchLayer(circuit, weight_shapes=weight_shapes)

    @staticmethod
    def create_circuit(dev, target_dim, noise=False, scaling=1.5, ent_train=False):
        """Build the training QNode: ``DRNN_circuit`` followed by a per-qubit <Z> readout.

        Inputs are cast to ``torch.float64`` before being fed to the circuit.
        Returns a list of ``n_qubits`` expectation values.
        """
        @qml.qnode(dev, interface='torch')
        def circuit(inputs, weights_input, weights_var):
            inputs = inputs.to(torch.float64)
            DRNN_circuit(inputs, weights_input, weights_var, n_qubits, depth, scaling=scaling, ent_train=ent_train, noise=noise, target_dim=target_dim)
            return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]
        return circuit

    def circuit_state(self, inputs, weights, depth_=depth, exec_=True):
        """Return ``qml.state()`` after ``depth_`` layers of the circuit.

        ``weights`` is indexed as ``(weights_input, weights_var)``. When
        ``exec_`` is false no gates are applied, so the returned state is the
        device's initial state. Unlike ``circuit_prob``, ``inputs`` is not cast
        to float64 here.
        """
        @qml.qnode(self.dev, interface='torch')
        def circuit(inputs, weights_input, weights_var, depth_, exec_):
            if exec_:
                DRNN_circuit(inputs, weights_input, weights_var, n_qubits, depth_, scaling=self.scaling,
                             ent_train=self.ent_train,
                             noise=self.noise, target_dim=self.target_dim)
            return qml.state()
        return circuit(inputs, weights[0], weights[1], depth_, exec_)

    def circuit_prob(self, inputs, weights, depth_=depth):
        """Return ``qml.probs()`` over all wires after ``depth_`` layers of the circuit.

        ``inputs`` is cast to float64; ``weights`` is ``(weights_input, weights_var)``.
        """
        @qml.qnode(self.dev, interface='torch')
        def circuit(inputs, weights_input, weights_var, depth_):
            inputs = inputs.to(torch.float64)
            DRNN_circuit(inputs, weights_input, weights_var, n_qubits, depth_, scaling=self.scaling,
                         ent_train=self.ent_train,
                         noise=self.noise, target_dim=self.target_dim)
            return qml.probs()

        return circuit(inputs, weights[0], weights[1], depth_)

    def circuit2matrix_wo_embed(self, inputs, weights):
        """Return ``qml.matrix`` of the full-depth circuit for the given inputs/weights.

        NOTE(review): despite the ``wo_embed`` name, the whole ``DRNN_circuit``
        (including whatever it does with ``inputs``) is traced — confirm intent.
        """
        @qml.qnode(self.dev, interface='torch')
        def circuit(inputs, weights_input, weights_var):
            DRNN_circuit(inputs, weights_input, weights_var, n_qubits, depth, scaling=self.scaling,
                         ent_train=self.ent_train, noise=self.noise, target_dim=self.target_dim)
            return qml.probs()

        return qml.matrix(circuit)(inputs, weights[0], weights[1])

    def forward(self, x, y):
        """Return the mean cross-entropy loss of ``predict(x)`` against targets ``y``."""
        # Functional form is identical to nn.CrossEntropyLoss() with default
        # arguments, without constructing a module on every call.
        return nn.functional.cross_entropy(self.predict(x), y)

    def predict(self, x):
        """Run the quantum layer on each sample of ``x`` and return class scores.

        ``x`` is flattened from dim 1 on; the result has shape
        ``(x.shape[0], num_classes)`` and the default torch dtype.
        """
        x = torch.flatten(x, start_dim=1)
        scores = torch.zeros((x.shape[0], self.num_classes))
        # The QNode processes one sample at a time; keep the first
        # num_classes qubit expectation values as the class scores.
        for row, sample in enumerate(x):
            scores[row] = self.ql(sample)[:self.num_classes]
        return scores

    def visualize_circuit(self, x, weights, save_path):
        """Draw the circuit for input ``x``, show it, and save it to ``save_path``."""
        # NOTE(review): ent_train=True is hard-coded here (and in obtain_qtape)
        # rather than self.ent_train; kept as-is — confirm this is intended.
        circuit = self.create_circuit(self.dev, self.target_dim, noise=self.noise,
                                      scaling=self.scaling, ent_train=True)
        fig, ax = qml.draw_mpl(circuit)(torch.flatten(x), weights[0], weights[1])
        fig.show()
        # Save this specific figure rather than whatever pyplot considers current.
        fig.savefig(save_path)

    def obtain_qtape(self, x, weights):
        """Return the quantum tape of the circuit for input ``x`` and ``weights``."""
        # NOTE(review): ent_train=True is hard-coded (see visualize_circuit).
        circuit = self.create_circuit(self.dev, self.target_dim, noise=self.noise,
                                      scaling=self.scaling, ent_train=True)
        return qml.workflow.construct_tape(circuit)(torch.flatten(x), weights[0], weights[1])

    def preprocess_input(self, input):
        """Cast ``input`` to float64 and pad it along dim 0 to ``target_dim`` via ``pad_input``."""
        input = input.to(torch.float64)
        return pad_input(input, target_dim=self.target_dim, dim=0)

    def process_output(self, out):
        """Return the index of the largest of the first ``num_classes`` entries of ``out``."""
        return torch.argmax(out[:self.num_classes])

