from typing import List, Optional

import pennylane as qml
import torch

from configs import NUM_QUBITS, Q_DEVICE, TOP_TYPE


def RY_layer(weights):
    """Queue one RY rotation per wire.

    Args:
        weights: 1-D sequence of angles, shape (n_qubits,); ``weights[k]``
            rotates wire ``k``.
    """
    num_wires = weights.shape[0]
    for wire in range(num_wires):
        angle = weights[wire]
        qml.RY(angle, wires=wire)


def RZ_layer(weights):
    """Queue one RZ rotation per wire.

    Args:
        weights: Angles indexed along their first axis; ``weights[k]``
            rotates wire ``k``.
    """
    num_wires = weights.shape[0]
    for wire in range(num_wires):
        angle = weights[wire]
        qml.RZ(angle, wires=wire)


def EntanglingLayer(n_qubits, top_type=None):
    """Queue one layer of CNOT entangling gates.

    Args:
        n_qubits: Number of wires. Only used when ``top_type`` is None.
        top_type: Topology selector.
            ``None``: nearest-neighbour CNOTs on the even pairs
            (0->1, 2->3, ...) followed by the odd pairs (1->2, 3->4, ...).
            Any other value: CNOTs from wires 0, 1 and 2 onto wire 3.

    NOTE(review): when ``top_type`` is not None its actual value is discarded
    and the fixed star onto wire 3 is always used (``n_qubits`` is ignored,
    and wire 3 must exist). A configured topology therefore has no effect.
    Confirm this is intended.
    """
    if top_type is not None:
        for control in range(3):
            qml.CNOT(wires=[control, 3])
        return

    # The even-offset pairs are all applied before the odd-offset pairs.
    for offset in (0, 1):
        for control in range(offset, n_qubits - 1, 2):
            qml.CNOT(wires=[control, control + 1])


def HadamardLayer(n_qubits):
    """Queue a Hadamard gate on each of wires ``0 .. n_qubits - 1``."""
    for wire in range(n_qubits):
        qml.Hadamard(wires=wire)


def single_aae_layer(num_qubits, top_type, weights):
    """Queue one AAE layer: per-wire RY rotations, then a CNOT entangling layer.

    Args:
        num_qubits: Wire count forwarded to ``EntanglingLayer``.
        top_type: Topology selector forwarded to ``EntanglingLayer``.
        weights: Per-wire RY angles; ``weights[k]`` rotates wire ``k``.
    """
    RY_layer(weights)
    EntanglingLayer(n_qubits=num_qubits, top_type=top_type)


@qml.qnode(
    Q_DEVICE, interface="torch", diff_method="backprop"
)  # qml.state() only supported for backprop
def aae_encoder(inputs=None, weights=None):
    """Apply the stacked AAE layers and return the resulting state vector.

    The circuit has no data-loading step. It starts from the device's initial
    state, which is always [1, 0, 0, ...].

    Args:
        inputs: Unused. It is kept only because PennyLane needs this argument,
            presumably for a ``qml.qnn.TorchLayer`` wrapper, which requires an
            ``inputs`` parameter. Confirm at the call site.
        weights: Layer-major angles. ``weights[layer]`` is handed to
            ``single_aae_layer`` as the per-wire RY angles of that layer.
            This argument must be supplied despite its ``None`` default.

    Returns:
        The final state vector from ``qml.state()``.
    """
    for layer in range(weights.shape[0]):
        single_aae_layer(NUM_QUBITS, TOP_TYPE, weights[layer])

    # Not applied: post-selection to deal with negative amplitudes in AAE.
    # It would be a Hadamard on the last wire followed by
    # ``qml.measure(wires=NUM_QUBITS - 1, postselect=1)``.
    return qml.state()


@qml.qnode(
    Q_DEVICE, interface="torch", diff_method="backprop"
)  # qml.state() only supported for backprop
def aae_encoder_hadamard(inputs, weights):
    """Apply the stacked AAE layers, then a Hadamard on every wire, and return the state.

    Args:
        inputs: Unused. It is kept only because PennyLane needs this argument,
            presumably for a ``qml.qnn.TorchLayer`` wrapper. Confirm at the
            call site.
        weights: Layer-major angles. ``weights[layer]`` holds the per-wire RY
            angles of that layer.

    Returns:
        The final state vector from ``qml.state()``.
    """
    for layer in range(weights.shape[0]):
        # top_type=None: this variant always uses the default even/odd
        # nearest-neighbour CNOT pattern and ignores the configured TOP_TYPE.
        # NOTE(review): confirm that differing from aae_encoder is intended.
        single_aae_layer(NUM_QUBITS, None, weights[layer])

    # Hadamards on all wires. The original comments tie this to dealing with
    # negative amplitudes in AAE; explicit post-selection is not applied.
    HadamardLayer(NUM_QUBITS)
    return qml.state()


# qml.state() only supported for backprop
@qml.qnode(Q_DEVICE, interface="torch", diff_method="backprop")
def single_layer_aae(input_state, weights):
    """Load ``input_state``, apply one AAE layer, and return the new state.

    Args:
        input_state: State vector loaded onto wires ``0 .. NUM_QUBITS - 1``.
        weights: Per-wire RY angles for the layer.

    Returns:
        The state vector from ``qml.state()``.
    """
    all_wires = range(NUM_QUBITS)
    qml.QubitStateVector(input_state, wires=all_wires)
    single_aae_layer(num_qubits=NUM_QUBITS, top_type=TOP_TYPE, weights=weights)
    return qml.state()


@qml.qnode(
    Q_DEVICE, interface="torch", diff_method="backprop"
)  # qml.state() only supported for backprop
def single_layer_aae_inverse(input_state, weights):
    """Load ``input_state``, apply the inverse of one AAE layer, and return the state.

    Args:
        input_state: State vector loaded onto wires ``0 .. NUM_QUBITS - 1``.
        weights: Per-wire RY angles of the layer being undone.

    Returns:
        The state vector after ``adjoint(single_aae_layer)``, from ``qml.state()``.
    """
    qml.QubitStateVector(input_state, wires=range(NUM_QUBITS))
    # qml.adjoint must wrap the quantum *function* and then be called with
    # that function's arguments. Calling single_aae_layer(...) first would
    # queue the forward gates immediately and pass its None return value to
    # qml.adjoint.
    qml.adjoint(single_aae_layer)(NUM_QUBITS, TOP_TYPE, weights)
    return qml.state()


# qml.state() only supported for backprop
@qml.qnode(Q_DEVICE, interface="torch", diff_method="backprop")
def single_layer_aae_raw_circuit(weights):
    """Apply one AAE layer to the device's initial state and return the state.

    Args:
        weights: Per-wire RY angles for the layer.

    Returns:
        The state vector from ``qml.state()``.
    """
    single_aae_layer(num_qubits=NUM_QUBITS, top_type=TOP_TYPE, weights=weights)
    return qml.state()


# Registry of encoder QNodes, keyed by the ``structure`` name accepted by
# ``get_aae_encoder``.
ENCODERS = dict(
    default=aae_encoder,
    hadamard=aae_encoder_hadamard,
    layer=single_layer_aae,
    inverse_layer=single_layer_aae_inverse,
    raw_layer=single_layer_aae_raw_circuit,
)


# AAE encoder lookup.
# TODO: drop the ignored ``num_qubits`` / ``top_type`` / ``single_layer_callable``
# parameters once no caller passes them.
def get_aae_encoder(
    num_qubits: int = 4,
    structure: str = "default",
    top_type: Optional[List[int]] = None,
    single_layer_callable=None,
):
    """Return the encoder QNode registered under ``structure`` in ``ENCODERS``.

    The registered QNodes read the module-level ``Q_DEVICE``, ``NUM_QUBITS``
    and ``TOP_TYPE`` config themselves. For that reason, every argument except
    ``structure`` currently has no effect.

    Args:
        num_qubits: Ignored; kept for backward compatibility.
        structure: Registry key, one of "default", "hadamard", "layer",
            "inverse_layer" or "raw_layer".
        top_type: Ignored; kept for backward compatibility.
        single_layer_callable: Ignored; kept for backward compatibility.

    Returns:
        The QNode stored at ``ENCODERS[structure]``.

    Raises:
        ValueError: If ``structure`` is not a key of ``ENCODERS``.
    """
    if structure not in ENCODERS:
        raise ValueError(f"Unsupported structure: {structure}")

    return ENCODERS[structure]
