#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@title: Mitigating Barren Plateaus in Quantum Neural Networks via an AI-Driven Submartingale-Based Framework.
@topic: Quantum Model and Variational Quantum Circuits.
@author: anonymous
"""

import torch
import pennylane as qml


class QuantumModel(torch.nn.Module):
    """
    @descriptions: hybrid quantum-classical model: a variational quantum circuit
        wrapped as a PennyLane TorchLayer, followed by a classical linear read-out.
    @inputs:
        circuit: a function with signature (params, inputs) that returns
            params_shapes[1] expectation values (one per qubit for circuit1/2/3).
        qml_dev: the PennyLane device the circuit is executed on.
        params_shapes: shape of the trainable circuit parameters,
            (nlayers, nqubits, nrot).
        out_shape: number of output features (e.g. classes) of the linear layer.
    """

    def __init__(self, circuit, qml_dev, params_shapes, out_shape):
        super().__init__()
        # Turn the circuit function into a quantum node on `qml_dev` using the torch interface.
        self.qnode = qml.QNode(circuit, qml_dev, interface="torch")
        # TorchLayer registers the circuit argument "params" as a trainable torch
        # parameter of shape params_shapes = (nlayers, nqubits, nrot).
        self.qlayer = qml.qnn.TorchLayer(self.qnode, weight_shapes={"params": params_shapes})
        # Map the nqubits circuit outputs to out_shape features; weights are float64.
        self.fc = torch.nn.Linear(params_shapes[1], out_shape, dtype=torch.float64)

    def forward(self, inputs):
        """
        @descriptions: run the circuit on `inputs`, then apply the linear read-out.
        @inputs:
            inputs: torch tensor of data fed to the circuit's AngleEmbedding.
        @return:
            Tensor whose last dimension has size out_shape, in the dtype of self.fc.
        """
        # TorchLayer casts its result to the dtype of its input. Float32 data
        # (torch's default) would therefore reach the float64 linear layer and
        # fail with a dtype mismatch. Casting up front makes the circuit and the
        # read-out share one dtype; it is a no-op for float64 inputs.
        inputs = inputs.to(self.fc.weight.dtype)
        features = self.qlayer(inputs)
        return self.fc(features)


def _layer(params):
    """
    @descriptions: apply one layer of the parameterized rotation structure:
        a general rotation qml.Rot on every qubit, followed by a linear chain
        of CNOTs in which qubit i controls qubit i+1.
    @inputs:
        params [nqubits x NUM_ROT]: rotation angles for this layer (NUM_ROT=3);
            row i holds the three Rot angles applied to qubit i.
    @return:
        None. The gates are applied as a side effect of calling the operations.
    """
    # One general single-qubit rotation per wire; row `wire` of params holds its angles.
    for wire, angles in enumerate(params):
        qml.Rot(angles[0], angles[1], angles[2], wires=wire)

    # Entangle neighbouring wires along a line: (0, 1), (1, 2), ..., (n-2, n-1).
    wire_count = params.shape[0]
    for control, target in zip(range(wire_count - 1), range(1, wire_count)):
        qml.CNOT(wires=[control, target])


# Not decorated here: QuantumModel wraps this function in a QNode with interface="torch".
def circuit1(params, inputs):
    """
    @descriptions: this vqc
        1) encodes classical data by AngleEmbedding (X rotations),
        2) applies params.shape[0] layers of the parameterized rotation
           structure (_layer), and
        3) measures the Pauli-Z expectation value on every qubit.
    @limitation: The number of qubits (nqubits) must be equal to or greater than
        the number of dimensions of the data (d). (d <= nqubits)
    @inputs:
        params [nlayers x nqubits x NUM_ROT]: the parameters of the vqc.
        inputs: the data used for AngleEmbedding.
    @return:
        A list of nqubits Pauli-Z expectation values, one per qubit.
    """
    ndim = len(params.shape)
    assert ndim == 3, f"Expected param shape len: 3, got {ndim}"
    n_layers, nqubits = params.shape[0], params.shape[1]
    wires = range(nqubits)

    qml.AngleEmbedding(inputs, wires=wires, rotation="X")
    # Stack the rotation/entanglement layers; params[layer] is (nqubits x NUM_ROT).
    for layer in range(n_layers):
        _layer(params[layer])
    return [qml.expval(qml.PauliZ(wire)) for wire in wires]


# Not decorated here: QuantumModel wraps this function in a QNode with interface="torch".
def circuit2(params, inputs):
    """
    @descriptions: this vqc
        1) encodes classical data by AngleEmbedding (X rotations),
        2) uses RandomLayers with a fixed (seeded) random architecture to
           manipulate qubits, and
        3) measures the Pauli-Z expectation value on every qubit.
    @limitation: The number of qubits (nqubits) must be equal to or greater than
        the number of dimensions of the data (d). (d <= nqubits)
    @inputs:
        params [nlayers x nqubits x k]: the parameters of the vqc. Only
            params[:, :, 0] (shape (nlayers, nqubits)) is used, i.e. nqubits
            random rotations per layer.
        inputs: the data used for AngleEmbedding.
    @return:
        A list of nqubits Pauli-Z expectation values, one per qubit.
    @ref:
        https://docs.pennylane.ai/en/stable/code/api/pennylane.RandomLayers.html
    """
    assert len(params.shape) == 3, f"Expected param shape len: 3, got {len(params.shape)}"
    nqubits = params.shape[1]
    qml.AngleEmbedding(inputs, wires=range(nqubits), rotation="X")
    # RandomLayers draws its gate sequence from `seed`. With seed=None the
    # architecture is re-drawn on every execution, so successive forward passes
    # would run different circuits and the trained weights would not describe
    # a fixed model. Under parameter-shift differentiation, the shifted
    # evaluations would also run different circuits. Pin the seed to
    # PennyLane's documented default (42).
    # The seed is hard-coded, not an argument, because TorchLayer requires every
    # non-`inputs` argument of the QNode to be listed in weight_shapes.
    qml.RandomLayers(weights=params[:, :, 0], wires=range(nqubits), seed=42)
    return [qml.expval(qml.PauliZ(i)) for i in range(nqubits)]


# Not decorated here: QuantumModel wraps this function in a QNode with interface="torch".
def circuit3(params, inputs):
    """
    @descriptions: this vqc
        1) encodes classical data by AngleEmbedding (X rotations),
        2) uses StronglyEntanglingLayers with CZ entanglers to manipulate
           qubits, and
        3) measures the Pauli-Z expectation value on every qubit.
    @limitation: The number of qubits (nqubits) must be equal to or greater than
        the number of dimensions of the data (d). (d <= nqubits)
    @inputs:
        params [nlayers x nqubits x 3]: the parameters of the vqc (three
            rotation angles per qubit per layer).
        inputs: the data used for AngleEmbedding.
    @return:
        A list of nqubits Pauli-Z expectation values, one per qubit.
    @ref:
        https://docs.pennylane.ai/en/stable/code/api/pennylane.StronglyEntanglingLayers.html
    """
    # Same checks as before, split so that a failure says which one was violated.
    assert len(params.shape) == 3, f"Expected param shape len: 3, got {len(params.shape)}"
    assert params.shape[2] == 3, f"Expected 3 rotation angles per qubit, got {params.shape[2]}"
    nqubits = params.shape[1]
    qml.AngleEmbedding(inputs, wires=range(nqubits), rotation="X")
    # ranges=None lets PennyLane pick its default entangling range per layer;
    # CZ is used as the two-qubit entangler.
    qml.StronglyEntanglingLayers(weights=params, wires=range(nqubits),
                                 ranges=None, imprimitive=qml.ops.CZ)
    return [qml.expval(qml.PauliZ(i)) for i in range(nqubits)]
