# Implementation of Quantum Convolutional Neural Network (QCNN) circuit structure.
import pennylane as qml


class QCNN:
    """Builder for a layered QCNN circuit with weight-shared convolution and pooling.

    Every layer queues a convolution layer (see ``make_conv_layer``) followed
    by a pooling layer (see ``make_pooling_layer``). Both act on all
    ``n_qubits`` wires in every layer; no wires are dropped between layers.

    Args:
        args: object exposing ``n_qubits`` (number of wires) and ``n_layers``
            (number of convolution + pooling layers).
    """

    def __init__(self, args):
        self.n_qubits = args.n_qubits
        self.n_layers = args.n_layers

    def make_qcnn(self, *, U_unitary=None, V_unitary=None,
                  n_conv_params=15, n_pooling_params=2):
        """Return a function that queues the whole QCNN for a flat parameter vector.

        Any two-qubit ansatz with the calling convention
        ``unitary(params, wires=[a, b])`` can be plugged in for either stage.

        Args:
            U_unitary: convolution unitary; defaults to ``U_SU4``.
            V_unitary: pooling unitary; defaults to ``Pooling_ansatz1``.
            n_conv_params: parameters consumed by ``U_unitary`` per layer
                (15 for the default ``U_SU4``).
            n_pooling_params: parameters consumed by ``V_unitary`` per layer
                (2 for the default ``Pooling_ansatz1``).

        Returns:
            ``QCNN_structure(params)``. ``params`` is a flat sequence of length
            ``n_layers * (n_conv_params + n_pooling_params)``, laid out layer
            by layer as ``[conv params, pooling params]``. With the defaults
            and ``n_layers == 2`` that is 34 entries.
        """
        if U_unitary is None:
            U_unitary = U_SU4  # 15 params
        if V_unitary is None:
            V_unitary = Pooling_ansatz1  # 2 params
        n_layers = self.n_layers
        n_params = n_conv_params + n_pooling_params  # parameters per layer

        conv_layer = self.make_conv_layer(U_unitary)
        pooling_layer = self.make_pooling_layer(V_unitary)

        def QCNN_structure(params):
            """Queue ``n_layers`` conv+pool layers.

            ``params`` must hold ``n_layers * n_params`` entries. A shorter
            sequence yields truncated slices, and the ansatz then fails when
            it indexes them.
            """
            for layer in range(n_layers):
                offset = layer * n_params
                conv_param = params[offset:offset + n_conv_params]
                pooling_param = params[offset + n_conv_params:offset + n_params]

                conv_layer(conv_param)
                pooling_layer(pooling_param)

        return QCNN_structure

    def make_conv_layer(self, U_unitary):
        """Return ``conv_layer(params)``, which applies ``U_unitary`` to neighbouring wire pairs.

        All pairs share the same ``params`` (weight sharing). Pairs are
        ``(i, (i + 1) % n_qubits)``, first for every even ``i`` and then for
        every odd ``i``. For odd ``i`` the last pair wraps around to wire 0.
        """
        n_qubits = self.n_qubits

        def conv_layer(params):
            '''
            n_qubits: Even Number (assumed; odd counts are not validated)
            '''
            for i in range(0, n_qubits, 2):
                U_unitary(params, wires=[i, (i + 1) % n_qubits])
            for i in range(1, n_qubits, 2):
                U_unitary(params, wires=[i, (i + 1) % n_qubits])

        return conv_layer

    def make_pooling_layer(self, V_unitary):
        """Return ``pooling_layer(params)``, which applies ``V_unitary`` to pairs ``(i, (i + 1) % n_qubits)`` for even ``i``.

        The same ``params`` are shared by every pair.
        """
        n_qubits = self.n_qubits

        def pooling_layer(params):
            for i in range(0, n_qubits, 2):
                V_unitary(params, wires=[i, (i + 1) % n_qubits])

        return pooling_layer

    
class QCNNNN:
    """QCNN builder that adds a StronglyEntanglingLayers block after each pooling layer.

    Each layer queues a convolution layer, a pooling layer, and then one
    sub-layer of ``qml.StronglyEntanglingLayers`` over all ``n_qubits`` wires.

    Args:
        args: object exposing ``n_qubits`` (number of wires) and ``n_layers``
            (number of layers).
    """

    def __init__(self, args):
        self.n_qubits = args.n_qubits
        self.n_layers = args.n_layers

    def make_qcnn(self, *, U_unitary=None, V_unitary=None,
                  n_conv_params=15, n_pooling_params=2):
        """Return a function that queues the whole circuit for a flat parameter vector.

        Args:
            U_unitary: convolution unitary ``U(params, wires=[a, b])``;
                defaults to ``U_SU4``.
            V_unitary: pooling unitary ``V(params, wires=[a, b])``;
                defaults to ``Pooling_ansatz1``.
            n_conv_params: parameters consumed by ``U_unitary`` per layer
                (15 for the default ``U_SU4``).
            n_pooling_params: parameters consumed by ``V_unitary`` per layer
                (2 for the default ``Pooling_ansatz1``).

        Returns:
            ``QCNN_structure(params)``. ``params`` is a flat array of length
            ``n_layers * (n_conv_params + n_pooling_params + 3 * n_qubits)``,
            laid out per layer as ``[conv, pooling, entangler weights]``. It
            must support ``.reshape`` because the entangler weights are
            reshaped.
        """
        if U_unitary is None:
            U_unitary = U_SU4  # 15 params
        if V_unitary is None:
            V_unitary = Pooling_ansatz1  # 2 params
        n_qubits = self.n_qubits
        n_layers = self.n_layers

        conv_layer = self.make_conv_layer(U_unitary)
        pooling_layer = self.make_pooling_layer(V_unitary)

        n_qnn_params = 3 * n_qubits  # one StronglyEntanglingLayers sub-layer
        n_params = n_conv_params + n_pooling_params + n_qnn_params  # per layer

        def QCNN_structure(params):
            """Queue ``n_layers`` layers of conv + pool + entangler."""
            for layer in range(n_layers):
                offset = layer * n_params
                pool_start = offset + n_conv_params
                qnn_start = pool_start + n_pooling_params

                conv_param = params[offset:pool_start]
                pooling_param = params[pool_start:qnn_start]
                # StronglyEntanglingLayers takes weights shaped
                # (n_sub_layers, n_wires, 3); a single sub-layer is used here.
                qnn_param = params[qnn_start:offset + n_params].reshape((1, n_qubits, 3))

                conv_layer(conv_param)
                pooling_layer(pooling_param)
                qml.StronglyEntanglingLayers(weights=qnn_param, wires=range(n_qubits))

        return QCNN_structure

    def make_conv_layer(self, U_unitary):
        """Return ``conv_layer(params)``, which applies ``U_unitary`` to neighbouring wire pairs.

        All pairs share the same ``params``. Pairs are
        ``(i, (i + 1) % n_qubits)``, first for every even ``i`` and then for
        every odd ``i``. The last odd pair wraps around to wire 0.
        """
        n_qubits = self.n_qubits

        def conv_layer(params):
            '''
            n_qubits: Even Number (assumed; odd counts are not validated)
            '''
            for i in range(0, n_qubits, 2):
                U_unitary(params, wires=[i, (i + 1) % n_qubits])
            for i in range(1, n_qubits, 2):
                U_unitary(params, wires=[i, (i + 1) % n_qubits])

        return conv_layer

    def make_pooling_layer(self, V_unitary):
        """Return ``pooling_layer(params)``, which applies ``V_unitary`` to pairs ``(i, (i + 1) % n_qubits)`` for even ``i``."""
        n_qubits = self.n_qubits

        def pooling_layer(params):
            for i in range(0, n_qubits, 2):
                V_unitary(params, wires=[i, (i + 1) % n_qubits])

        return pooling_layer




##### QCNN Backbone #####
def QCNN_structure_without_pooling(U, params, n_params):
    """Apply ``conv_layer1``, ``conv_layer2`` and ``conv_layer3`` in order, without pooling.

    Layer ``k`` (0-based) receives ``params[k * n_params:(k + 1) * n_params]``,
    and ``U`` is passed through to every layer.

    NOTE(review): an identical definition with this name appears again later
    in this module and replaces this one at import time; one copy can be
    deleted.
    """
    for k, layer in enumerate((conv_layer1, conv_layer2, conv_layer3)):
        layer(U, params[k * n_params:(k + 1) * n_params])

def QCNN_1D_circuit(U, params, n_params):
    """Fixed 8-wire circuit built from three parameter groups of ``n_params`` each.

    Group 1 is applied to the even pairs (0,1), (2,3), (4,5), (6,7) and then
    the odd pairs (1,2), (3,4), (5,6). Group 2 is applied to (2,3) and
    (4,5), and group 3 to (3,4).

    NOTE(review): an identical definition with this name appears again later
    in this module and replaces this one at import time; one copy can be
    deleted.
    """
    first = params[0:n_params]
    second = params[n_params:2 * n_params]
    third = params[2 * n_params:3 * n_params]

    for left in (0, 2, 4, 6):
        U(first, wires=[left, left + 1])
    for left in (1, 3, 5):
        U(first, wires=[left, left + 1])

    for pair in ([2, 3], [4, 5]):
        U(second, wires=pair)
    U(third, wires=[3, 4])


def QCNN_structure_without_pooling(U, params, n_params):
    """Run the three fixed convolution layers back to back, with no pooling.

    Each of ``conv_layer1``, ``conv_layer2`` and ``conv_layer3`` receives its
    own consecutive block of ``n_params`` entries from ``params``, together
    with the shared two-qubit unitary ``U``.

    NOTE(review): this redefinition replaces an identical earlier definition
    of the same name in this module.
    """
    layers = (conv_layer1, conv_layer2, conv_layer3)
    for k, layer in enumerate(layers):
        start = k * n_params
        layer(U, params[start:start + n_params])

def QCNN_1D_circuit(U, params, n_params):
    """Queue a fixed 8-wire circuit that uses three consecutive parameter groups.

    - group 1 (``params[0:n_params]``): even pairs (0,1)...(6,7), then odd
      pairs (1,2), (3,4), (5,6);
    - group 2: pairs (2,3) and (4,5);
    - group 3: pair (3,4).

    NOTE(review): this redefinition replaces an identical earlier definition
    of the same name in this module.
    """
    groups = [params[g * n_params:(g + 1) * n_params] for g in range(3)]

    for start, stop in ((0, 8), (1, 7)):
        for left in range(start, stop, 2):
            U(groups[0], wires=[left, left + 1])

    U(groups[1], wires=[2, 3])
    U(groups[1], wires=[4, 5])

    U(groups[2], wires=[3, 4])


# Unitary Ansatze for Convolutional Layer
def U_TTN(params, wires):  # 2 params
    """Two-qubit block: RY on each wire, then CNOT from ``wires[0]`` to ``wires[1]``.

    Uses ``params[0]`` (on ``wires[0]``) and ``params[1]`` (on ``wires[1]``).
    """
    q0, q1 = wires[0], wires[1]
    for k, wire in enumerate((q0, q1)):
        qml.RY(params[k], wires=wire)
    qml.CNOT(wires=[q0, q1])


def U_5(params, wires):  # 10 params
    """Two-qubit block with RX/RZ layers around a pair of opposite-direction CRZ gates.

    Parameter layout: 0-1 RX, 2-3 RZ, 4-5 CRZ (``wires[1]``->``wires[0]``,
    then ``wires[0]``->``wires[1]``), 6-7 RX, 8-9 RZ. For each
    single-qubit pair, the first index goes to ``wires[0]``.
    """
    pair = (wires[0], wires[1])
    for k, wire in enumerate(pair):
        qml.RX(params[k], wires=wire)
    for k, wire in enumerate(pair):
        qml.RZ(params[2 + k], wires=wire)
    qml.CRZ(params[4], wires=[pair[1], pair[0]])
    qml.CRZ(params[5], wires=[pair[0], pair[1]])
    for k, wire in enumerate(pair):
        qml.RX(params[6 + k], wires=wire)
    for k, wire in enumerate(pair):
        qml.RZ(params[8 + k], wires=wire)


def U_6(params, wires):  # 10 params
    """Like ``U_5``, but the entangling gates are CRX instead of CRZ.

    Parameter layout: 0-1 RX, 2-3 RZ, 4-5 CRX (``wires[1]``->``wires[0]``,
    then ``wires[0]``->``wires[1]``), 6-7 RX, 8-9 RZ.
    """
    pair = (wires[0], wires[1])
    for k, wire in enumerate(pair):
        qml.RX(params[k], wires=wire)
    for k, wire in enumerate(pair):
        qml.RZ(params[2 + k], wires=wire)
    qml.CRX(params[4], wires=[pair[1], pair[0]])
    qml.CRX(params[5], wires=[pair[0], pair[1]])
    for k, wire in enumerate(pair):
        qml.RX(params[6 + k], wires=wire)
    for k, wire in enumerate(pair):
        qml.RZ(params[8 + k], wires=wire)


def U_9(params, wires):  # 2 params
    """Hadamard on both wires, CZ between them, then one RX per wire (``params[0]``, ``params[1]``)."""
    q0, q1 = wires[0], wires[1]
    for wire in (q0, q1):
        qml.Hadamard(wires=wire)
    qml.CZ(wires=[q0, q1])
    for k, wire in enumerate((q0, q1)):
        qml.RX(params[k], wires=wire)


def U_13(params, wires):  # 6 params
    """Two RY layers, each followed by a CRZ; the CRZ direction alternates.

    Parameter layout: 0-1 RY, 2 CRZ (``wires[1]``->``wires[0]``), 3-4 RY,
    5 CRZ (``wires[0]``->``wires[1]``).
    """
    q0, q1 = wires[0], wires[1]
    for k, wire in enumerate((q0, q1)):
        qml.RY(params[k], wires=wire)
    qml.CRZ(params[2], wires=[q1, q0])
    for k, wire in enumerate((q0, q1)):
        qml.RY(params[3 + k], wires=wire)
    qml.CRZ(params[5], wires=[q0, q1])


def U_14(params, wires):  # 6 params
    """Like ``U_13``, but the controlled rotations are CRX.

    Parameter layout: 0-1 RY, 2 CRX (``wires[1]``->``wires[0]``), 3-4 RY,
    5 CRX (``wires[0]``->``wires[1]``).
    """
    q0, q1 = wires[0], wires[1]
    for k, wire in enumerate((q0, q1)):
        qml.RY(params[k], wires=wire)
    qml.CRX(params[2], wires=[q1, q0])
    for k, wire in enumerate((q0, q1)):
        qml.RY(params[3 + k], wires=wire)
    qml.CRX(params[5], wires=[q0, q1])


def U_15(params, wires):  # 4 params
    """Two RY layers, each followed by a CNOT; the CNOT direction alternates.

    Parameter layout: 0-1 RY, CNOT(``wires[1]``->``wires[0]``), 2-3 RY,
    CNOT(``wires[0]``->``wires[1]``).
    """
    q0, q1 = wires[0], wires[1]
    for k, wire in enumerate((q0, q1)):
        qml.RY(params[k], wires=wire)
    qml.CNOT(wires=[q1, q0])
    for k, wire in enumerate((q0, q1)):
        qml.RY(params[2 + k], wires=wire)
    qml.CNOT(wires=[q0, q1])


def U_SO4(params, wires):  # 6 params
    """Three RY layers separated by two CNOTs (both ``wires[0]``->``wires[1]``).

    Layer ``b`` (0, 1, 2) uses ``params[2*b]`` on ``wires[0]`` and
    ``params[2*b + 1]`` on ``wires[1]``.
    """
    q0, q1 = wires[0], wires[1]
    for block in range(3):
        if block:
            # An entangler sits between consecutive rotation layers.
            qml.CNOT(wires=[q0, q1])
        for k, wire in enumerate((q0, q1)):
            qml.RY(params[2 * block + k], wires=wire)


def U_SU4(params, wires): # 15 params
    """Two-qubit block: U3 on each wire, a three-CNOT core, then U3 on each wire again.

    Parameter layout: 0-2 U3 on ``wires[0]``, 3-5 U3 on ``wires[1]``,
    6 RY on ``wires[0]``, 7 RZ on ``wires[1]``, 8 RY on ``wires[0]``,
    9-11 U3 on ``wires[0]``, 12-14 U3 on ``wires[1]``.
    NOTE(review): the name suggests this is meant to span general SU(4);
    that is not verified here.
    """
    q0, q1 = wires[0], wires[1]
    for offset, wire in ((0, q0), (3, q1)):
        qml.U3(params[offset], params[offset + 1], params[offset + 2], wires=wire)
    qml.CNOT(wires=[q0, q1])
    qml.RY(params[6], wires=q0)
    qml.RZ(params[7], wires=q1)
    qml.CNOT(wires=[q1, q0])
    qml.RY(params[8], wires=q0)
    qml.CNOT(wires=[q0, q1])
    for offset, wire in ((9, q0), (12, q1)):
        qml.U3(params[offset], params[offset + 1], params[offset + 2], wires=wire)

# Pooling ansatze (two-qubit unitaries used by the pooling layers)

def Pooling_ansatz1(params, wires): #2 params
    """Two-parameter pooling block with ``wires[0]`` as control and ``wires[1]`` as target.

    Applies CRZ(``params[0]``), then PauliX on the control, then
    CRX(``params[1]``).
    """
    control, target = wires[0], wires[1]
    qml.CRZ(params[0], wires=[control, target])
    # Flipping the control makes the following CRX act on the control's
    # original |0> component. The flip is not undone afterwards.
    qml.PauliX(wires=control)
    qml.CRX(params[1], wires=[control, target])

def Pooling_ansatz2(wires): #0 params
    """Parameter-free pooling block: controlled-Z with ``wires[0]`` as control, ``wires[1]`` as target.

    NOTE(review): a bare ``qml.CRZ`` cannot be used here because it requires
    a rotation angle, and omitting it raises ``TypeError``. ``qml.CZ`` is the
    parameter-free controlled-Z gate; confirm it is the intended pooling gate.
    """
    qml.CZ(wires=[wires[0], wires[1]])

def Pooling_ansatz3(*params, wires): #3 params
    """Pooling block made of one controlled general rotation (``qml.CRot``) on ``wires[0]`` -> ``wires[1]``.

    The three rotation angles can be passed either as three positional
    arguments (``Pooling_ansatz3(a, b, c, wires=w)``) or as one length-3
    sequence/array (``Pooling_ansatz3(p, wires=w)``). The second form is the
    ``V(params, wires=...)`` convention used by the pooling layers in this
    module.
    """
    if len(params) == 1:
        # A single positional argument is the container of the three angles.
        angles = params[0]
        params = (angles[0], angles[1], angles[2])
    qml.CRot(*params, wires=[wires[0], wires[1]])


# Convolutional layers
def conv_layer1(U, params):
    """First fixed 8-wire convolution layer; every call to ``U`` shares ``params``.

    Order: the wrap-around pair [0, 7], then even pairs [0,1], [2,3], [4,5],
    [6,7], then odd pairs [1,2], [3,4], [5,6].
    """
    U(params, wires=[0, 7])
    for left in (0, 2, 4, 6):
        U(params, wires=[left, left + 1])
    for left in (1, 3, 5):
        U(params, wires=[left, left + 1])
def conv_layer2(U, params):
    """Second fixed convolution layer over the even wires 0, 2, 4, 6 (shared ``params``).

    Pairs, in order: [0, 6], [0, 2], [4, 6], [2, 4].
    """
    for pair in ([0, 6], [0, 2], [4, 6], [2, 4]):
        U(params, wires=pair)
def conv_layer3(U, params):
    """Third fixed convolution layer: one application of ``U`` on wires [0, 4]."""
    final_pair = [0, 4]
    U(params, wires=final_pair)

# Pooling layers
def pooling_layer1(V, params):
    """First fixed 8-wire pooling layer; every call to ``V`` shares ``params``.

    Applies ``V`` to [1, 0], [3, 2], [5, 4], [7, 6]: each odd wire paired
    with the even wire just below it, with the odd wire listed first.
    """
    # No debug print here: this runs on every circuit construction/evaluation.
    for even in range(0, 8, 2):
        V(params, wires=[even + 1, even])
def pooling_layer2(V, params):
    """Second fixed pooling layer: ``V`` on [2, 0] and then [6, 4] (shared ``params``)."""
    for pair in ([2, 0], [6, 4]):
        V(params, wires=pair)
def pooling_layer3(V, params):
    """Third fixed pooling layer: one application of ``V`` on wires [0, 4]."""
    last_pair = [0, 4]
    V(params, wires=last_pair)