"""
circuit structures, sub-structure
dict key names for saving weights
"""
import numpy as np
import pennylane as qml
from pennylane import AmplitudeEmbedding, AngleEmbedding
import torch.nn.functional as F
import torch
from math import ceil
from models.layers import *


def pad_input(x, target_dim=270, dim=1):
    """Zero-pad ``x`` along ``dim`` so that it has ``target_dim`` entries there.

    The zeros are inserted at the *start* of ``dim`` (the count is passed to
    ``F.pad`` as the "before" amount of that dimension). If ``x`` already has
    ``target_dim`` or more entries along ``dim`` it is returned unchanged;
    nothing is ever truncated.

    Args:
        x: torch tensor to pad (default ``dim=1`` suits an (N, n_features)
            batch; callers in this file also pass 1-D vectors with ``dim=0``).
        target_dim: desired size of ``dim`` after padding.
        dim: dimension to pad; negative indices work as well.

    Returns:
        ``x`` itself when no padding is needed, otherwise a new padded tensor.
    """
    shortfall = target_dim - x.shape[dim]
    if shortfall > 0:
        # F.pad expects (before, after) pairs starting from the LAST dimension,
        # so the "before" slot of dimension ``dim`` sits at -2 * (dim + 1).
        widths = [0] * (2 * x.dim())
        widths[-2 * (dim + 1)] = shortfall
        x = F.pad(x, widths)
    return x


def encoding_block(features, n_qubits, name='amplitude', weights=None, normalize=True, target_dim=264):
    """Queue the data-encoding gates for one sample on wires ``0..n_qubits-1``.

    Args:
        features: feature vector of a single sample. The 'angle' and
            'encoding-first' branches pad it along dim 0 with ``pad_input``,
            so there it must be a torch tensor.
        n_qubits: number of wires to encode onto.
        name: encoding scheme:
            * 'amplitude': ``AmplitudeEmbedding`` over all wires, padded with
              zeros and optionally normalised.
            * 'angle': ``features`` is zero-padded to ``target_dim`` and split
              into ``n_qubits`` consecutive chunks of
              ``len(features) // n_qubits`` entries; every entry of chunk ``q``
              becomes an ``RX`` angle on wire ``q``.
            * 'encoding-first': same layout as 'angle', but the angle is
              ``1.5 * feature + weights[q, f]`` (a trainable offset per entry).
        weights: offsets indexed ``weights[q, f]``; required for
            'encoding-first', ignored otherwise.
        normalize: forwarded to ``AmplitudeEmbedding``.
        target_dim: padded length for the 'angle' / 'encoding-first' branches
            (default 264, i.e. 33 features per qubit with 8 qubits).

    Raises:
        ValueError: if ``name`` is not a known scheme, or ``weights`` is
            missing for 'encoding-first'.
    """
    if name == 'amplitude':
        AmplitudeEmbedding(features, wires=range(n_qubits), normalize=normalize, pad_with=0)
        return
    if name not in ('angle', 'encoding-first'):
        raise ValueError(f"unknown encoding name: {name!r}")
    if name == 'encoding-first' and weights is None:
        raise ValueError("'encoding-first' encoding requires `weights`")

    features = pad_input(features, target_dim=target_dim, dim=0)
    # Floor division: the trailing len(features) % n_qubits entries (if any)
    # are not encoded.
    feature_per_qubit = features.shape[0] // n_qubits
    for q in range(n_qubits):
        base = q * feature_per_qubit
        for f in range(feature_per_qubit):
            value = features[base + f]
            if name == 'angle':
                qml.RX(value, wires=q)
            else:
                qml.RX(1.5 * value + weights[q, f], wires=q)


def channel_noise_simple(n_qubits, p=0.5, seed=0):
    """Queue at most one random Pauli error, emulating a single noise channel.

    One target wire and one channel type are drawn uniformly at random, then:
      * 'depolarizing': with probability ``p`` apply X, Y or Z (uniformly).
      * 'bit_flip':     with probability ``p`` apply X.
      * 'phase_flip':   with probability ``p`` apply Z.
      * 'phase_damping': apply Z with probability ``(1 - sqrt(1 - p)) / 2``,
        the phase-flip probability equivalent to phase damping with
        parameter ``p``.

    All draws happen in Python when this function runs, so each call yields
    one sampled error (or none), not a mixed-state channel.

    Args:
        n_qubits: number of wires; the target is drawn from ``[0, n_qubits)``.
        p: error probability / channel strength, expected in [0, 1].
        seed: anything accepted by ``np.random.default_rng``. The default
            ``0`` re-seeds on every call, so every call draws the same wire,
            channel and outcome (the historical behaviour). Pass ``None`` for
            fresh entropy on each call, or keep one ``np.random.Generator``
            and pass it every time for a reproducible but varying sequence.
    """
    rng = np.random.default_rng(seed)
    # int(): hand PennyLane a plain Python wire label rather than np.int64.
    target = int(rng.integers(0, n_qubits))
    noise = rng.choice(['depolarizing', 'phase_damping',
                        'bit_flip', 'phase_flip'])

    if noise == "depolarizing":
        if rng.random() < p:
            gate = rng.choice(["X", "Y", "Z"])
            {"X": qml.PauliX, "Y": qml.PauliY, "Z": qml.PauliZ}[gate](target)

    elif noise == "bit_flip":
        if rng.random() < p:
            qml.PauliX(target)

    elif noise == "phase_flip":
        if rng.random() < p:
            qml.PauliZ(target)

    elif noise == "phase_damping":
        lam = (1.0 - np.sqrt(1.0 - p)) / 2.0
        if rng.random() < lam:
            qml.PauliZ(target)


def gate_noise(mu=0, sigma=0.01):
    """Draw one Gaussian sample (named for gate-parameter noise).

    Args:
        mu: mean of the distribution.
        sigma: standard deviation.

    Returns:
        A float from ``random.gauss(mu, sigma)``, using Python's global
        ``random`` generator (so ``random.seed`` makes it reproducible).
    """
    # ``random`` is not imported at the top of this module; import it here
    # instead of depending on a wildcard import to provide it.
    import random

    return random.gauss(mu, sigma)


########## QCL ##########
def QCL_circuit(depth, n_qubits, weights, noise=False):
    """Stack ``depth`` QCL layers, optionally followed by one noise event.

    Args:
        depth: number of layers; layer ``k`` is ``QCL_block(k, n_qubits, weights)``.
        n_qubits: number of wires.
        weights: trainable parameters, forwarded unchanged to every layer.
        noise: when True, finish with ``channel_noise_simple(n_qubits)``.
    """
    for layer_idx in range(depth):
        QCL_block(layer_idx, n_qubits, weights)
    if not noise:
        return
    channel_noise_simple(n_qubits)


########## QCNN ##########
def pure_qcnn_circuit(n_qubits, weights_conv1, weights_conv2, weights_pool1, weights_pool2, weights_fc, noise=False):
    """Build the full QCNN: conv1 -> pool1 -> conv2 -> pool2 -> fc.

    Each stage is called as ``stage(n_qubits, stage_weights)``. When ``noise``
    is True, one ``channel_noise_simple(n_qubits)`` event is appended at the end.
    """
    pipeline = (
        (QCNN_conv1, weights_conv1),
        (QCNN_pool1, weights_pool1),
        (QCNN_conv2, weights_conv2),
        (QCNN_pool2, weights_pool2),
        (QCNN_fc, weights_fc),
    )
    for stage, stage_weights in pipeline:
        stage(n_qubits, stage_weights)

    if noise:
        channel_noise_simple(n_qubits)

def pure_qcnn_block1(n_qubits, weights_conv1, weights_pool1):
    """First QCNN stage only: convolution 1 followed by pooling 1."""
    for layer, params in ((QCNN_conv1, weights_conv1), (QCNN_pool1, weights_pool1)):
        layer(n_qubits, params)

def pure_qcnn_block2(n_qubits, weights_conv1, weights_conv2, weights_pool1, weights_pool2):
    """First two QCNN stages: conv1 -> pool1 -> conv2 -> pool2 (no fc layer)."""
    schedule = (
        (QCNN_conv1, weights_conv1),
        (QCNN_pool1, weights_pool1),
        (QCNN_conv2, weights_conv2),
        (QCNN_pool2, weights_pool2),
    )
    for layer, params in schedule:
        layer(n_qubits, params)


def DRNN_circuit(x, weights_input, weights_var, n_qubits, depth, scaling=1.5, ent_train=False, noise=False, target_dim=144):
    """
    Data re-uploading circuit. Each of the ``depth`` layers is:
         Data uploading  (RX, RZ, RX per qubit; angle = scaling * feature + trainable offset)
         Rotation        (one trainable RX per qubit)
         Entanglement    (ring of CNOTs, or ring of trainable CRZs when ent_train)
    x: (n_features, ) -- zero-padded at the front of dim 0 up to target_dim
    weights_input: (depth, n_qubits, 3)
    weights_var: (depth, n_qubits, 4) if entangling layer is trainable else (depth, n_qubits, 3);
        only weights_var[l, q, 0] (RX angle) and, when ent_train, weights_var[l, q, -1]
        (CRZ angle) are read
    n_qubits: number of qubits
    depth: number of layers; depth <= 0 builds no layers
    scaling: scaling factor for each feature, e.g. 1.5~3.5
    ent_train: use trainable CRZ entanglers instead of fixed CNOTs
    noise: append one channel_noise_simple event after the last layer
    target_dim: length x is padded to before it is sliced

    Feature layout: with feature_per_layer = n_features // depth, layer l reads
    x[l * feature_per_layer : l * feature_per_layer + 3 * n_qubits], i.e. three
    consecutive features per qubit.
    NOTE(review): with target_dim=144 and the depth/qubit settings in depth_dict /
    qubit_dict (4 layers, 6 qubits), each layer reads 18 of its 36-feature slice,
    so half of x is never encoded -- confirm this is intended.

    Raises:
        IndexError: if the last layer would read past the end of the padded x.
    """

    def entangling_layer():
        # Fixed ring of CNOTs: i -> i+1 for every neighbour pair, closed by
        # (n_qubits - 1) -> 0.
        if n_qubits < 2:
            return
        for i in range(n_qubits - 1):
            qml.CNOT(wires=[i, i + 1])
        qml.CNOT(wires=[n_qubits - 1, 0])

    def entangling_layer_trainable(l, weights):
        # Ring of controlled-RZ gates; the gate controlled by qubit i uses the
        # last weight entry weights[l, i, -1].
        if n_qubits < 2:
            return
        for i in range(n_qubits - 1):
            qml.CRZ(weights[l, i, -1], wires=[i, i + 1])
        qml.CRZ(weights[l, n_qubits - 1, -1], wires=[n_qubits - 1, 0])

    feature_per_qubit = 3
    x = pad_input(x, target_dim=target_dim, dim=0)
    n_features = x.shape[0]
    if depth > 0:
        # Floor division: any remainder features at the end are never read.
        feature_per_layer = n_features // depth
        needed = (depth - 1) * feature_per_layer + n_qubits * feature_per_qubit
        if needed > n_features:
            raise IndexError(
                f"DRNN_circuit needs {needed} features but x has {n_features} "
                f"after padding; increase target_dim"
            )
    else:
        feature_per_layer = 0  # no layers are built

    for l in range(depth):
        offset = l * feature_per_layer
        # encoding: three re-uploaded features per qubit, each shifted by a trainable offset
        for q in range(n_qubits):
            base = offset + q * feature_per_qubit
            qml.RX(scaling * x[base + 0] + weights_input[l, q, 0], wires=q)
            qml.RZ(scaling * x[base + 1] + weights_input[l, q, 1], wires=q)
            qml.RX(scaling * x[base + 2] + weights_input[l, q, 2], wires=q)
        # rotation (trainable): only the RX angle weights_var[l, q, 0] is applied
        for q in range(n_qubits):
            qml.RX(weights_var[l, q, 0], wires=q)
        if ent_train:
            entangling_layer_trainable(l, weights_var)
        else:
            entangling_layer()

    if noise:
        channel_noise_simple(n_qubits)


def hqnn_circuit(weights, n_qubits, depth, noise=False):
    """Variational part of the hybrid (HQNN) model.

    Each of the ``depth`` layers ``d`` performs three sweeps over the wires:
      1. RY, RZ, RY on every wire ``i`` with angles ``weights[d, i, 0..2]``;
      2. for every control wire ``c``, a CRZ onto each other wire, visited
         cyclically as ``c+1, c+2, ...`` (mod ``n_qubits``), with angles
         ``weights[d, c, 3 .. n_qubits + 1]``;
      3. a final RY on every wire ``i`` with angle ``weights[d, i, -1]``.
    That indexing is consistent with a last weight axis of length
    ``n_qubits + 3`` (TODO confirm against the layer that owns ``weights``).
    When ``noise`` is True, one ``channel_noise_simple`` event is appended.
    """
    for d in range(depth):
        for wire in range(n_qubits):
            qml.RY(weights[d, wire, 0], wires=wire)
            qml.RZ(weights[d, wire, 1], wires=wire)
            qml.RY(weights[d, wire, 2], wires=wire)

        for ctrl in range(n_qubits):
            for hop in range(1, n_qubits):
                # hop-th neighbour, wrapping around the register
                qml.CRZ(weights[d, ctrl, hop + 2], wires=[ctrl, (ctrl + hop) % n_qubits])

        for wire in range(n_qubits):
            qml.RY(weights[d, wire, -1], wires=wire)

    if noise:
        channel_noise_simple(n_qubits)



########## dict ##########
# Per-model settings, one row per model:
#   (weight key name(s) used when saving, depth, n_qubits, angle-encoding layer size)
_MODEL_SPECS = {
    'qcl': ('ql.weights', 5, 8, 270),
    'qcnn': (
        ['cir.weights_conv1', 'cir.weights_conv2', 'cir.weights_pool1', 'cir.weights_pool2', 'cir.weights_fc'],
        3,  # max value in block_dict
        8,
        270,
    ),
    'drnn': (['ql.weights_input', 'ql.weights_var'], 4, 6, [18, 30, 3]),
    'hqnn': (['cl.weight', 'cl.bias', 'ql.weights'], 2, 4, 270),
}

# Weight key name(s) per model (a single string for 'qcl', lists otherwise).
weight_dict = {model: spec[0] for model, spec in _MODEL_SPECS.items()}
# Depth setting per model.
depth_dict = {model: spec[1] for model, spec in _MODEL_SPECS.items()}
# Qubit count per model.
qubit_dict = {model: spec[2] for model, spec in _MODEL_SPECS.items()}
# Angle-encoding layer size per model.
layer_size_dict = {model: spec[3] for model, spec in _MODEL_SPECS.items()}



