import pennylane as qml

# from configs import NUM_QUBITS, Q_DEVICE, TOP_TYPE


# ## Model
def RY_layer(weights, to_float: bool = False):  # modify for batched inputs
    """Apply one RY rotation per qubit, batched over the leading axis.

    Args:
        weights: rotation angles indexed as ``weights[:, wire]``; column
            ``wire`` drives qubit ``wire``. Expected shape (B, n_qubits).
        to_float: if True, each column is replaced by the first entry of its
            ``tolist()``, i.e. the angle of batch element 0 as a plain Python
            number (other batch elements are ignored).
    """
    n_wires = weights.shape[1]
    for wire in range(n_wires):
        column = weights[:, wire]
        angle = column.tolist()[0] if to_float else column
        qml.RY(angle, wires=wire)


def EntanglingLayer(n_qubits):
    """Apply two passes of nearest-neighbour CNOTs over ``n_qubits`` wires.

    The first pass acts on pairs (0, 1), (2, 3), ...; the second on
    (1, 2), (3, 4), .... In every pair the lower-index wire is listed first
    (the control in ``qml.CNOT``).
    """
    for offset in (0, 1):
        for control in range(offset, n_qubits - 1, 2):
            qml.CNOT(wires=[control, control + 1])


def HadamardLayer(n_qubits):
    """Apply a Hadamard gate to each wire 0 .. n_qubits - 1, in order."""
    for wire in range(0, n_qubits):
        qml.Hadamard(wires=wire)


def aae_encoder_for_train(
    encoder_params, num_layers, num_qubits, to_float: bool = False
):  # modify original aae_encoder for batched inputs
    """Apply ``num_layers`` rounds of RY rotations followed by CNOT entanglers.

    Args:
        encoder_params: batched rotation angles, expected shape
            (B, num_layers * num_qubits). Each row holds the angles for one
            batch element in layer-major order: entry
            ``layer * num_qubits + qubit`` drives ``qubit`` in ``layer``.
        num_layers: number of RY+CX layers.
        num_qubits: number of qubits.
        to_float: whether to convert the torch tensor angles to Python
            floats. This is useful when transforming the qnode to a QASM
            string. Note that in this case the batch size is forced to be 1
            (only batch element 0 is used).

    Returns:
        None. Gates are emitted through ``RY_layer`` / ``EntanglingLayer``;
        the enclosing qnode is responsible for the measurement.
    """
    if num_layers <= 0:
        # No layer to apply, so there is nothing to reshape either.
        return

    batch_size = encoder_params.shape[0]
    # Reshape once (loop-invariant) to (B, num_layers, num_qubits).
    # ``reshape`` rather than ``view``: it also accepts non-contiguous torch
    # tensors and NumPy arrays (where ``.view`` means a dtype reinterpretation).
    layered_params = encoder_params.reshape(batch_size, num_layers, num_qubits)
    for layer in range(num_layers):
        RY_layer(weights=layered_params[:, layer], to_float=to_float)
        EntanglingLayer(num_qubits)

    # post selection to deal with negative number in AAE
    # qml.Hadamard(wires=n_qubits-1)
    # qml.measure(wires=n_qubits-1, postselect=1)

    # return qml.state()


def single_aae_layer(num_qubits, top_type, weights):
    """Apply one AAE layer: an RY rotation on each qubit, then CNOT entanglers.

    Args:
        num_qubits: number of qubits in the quantum circuit (forwarded to
            ``EntanglingLayer``).
        top_type: topology type. NOTE(review): currently unused; the
            entangling pattern is always ``EntanglingLayer``'s fixed one.
        weights: encoder parameters forwarded to ``RY_layer``; presumably
            shaped (B, num_qubits), so verify against callers.
    """
    RY_layer(weights, to_float=False)
    EntanglingLayer(n_qubits=num_qubits)


# @qml.qnode(
#     Q_DEVICE, interface="torch", diff_method="backprop"
# )  # qml.state() only supported for backprop
# def batch_single_layer_aae_raw_circuit(weights):
#     single_aae_layer(NUM_QUBITS, TOP_TYPE, weights)
#     return qml.state()


# # wrap around aae_encoder_for_train
# @qml.qnode(Q_DEVICE, interface="torch", diff_method="backprop")
# @qml.simplify
# def aae_encoder(
#     encoder_params, num_layers, num_qubits
# ):  # TODO: need refactor for batch_encoders.py and encoders.py
#     aae_encoder_for_train(encoder_params, num_layers, num_qubits)
#     return qml.state()
