# Implementation of Quantum Convolutional Neural Network (QCNN) circuit structure.
import pennylane as qml


class QCNN2:
    """Two-qubit QCNN: one SU(4) convolution followed by one pooling block.

    The ``make_*`` methods return closures that call the ansatz functions
    (``U_SU4``, ``Pooling_ansatz1``); those queue PennyLane operations, so the
    circuit function returned by ``make_qcnn`` is meant to run inside a QNode.
    ``make_conv_layer2`` / ``make_pooling_layer2`` (reversed wire order) are
    provided for a deeper variant but are not used by ``make_qcnn``.
    """

    def __init__(self):
        # Stored on the instance so callers can query the circuit width.
        self.n_qubits = 2

    def make_qcnn(self):
        """Build the 2-qubit circuit function.

        Returns:
            ``QCNN_structure1(params)``, which reads 17 parameters:
            ``params[0:15]`` for the SU(4) convolution on wires [0, 1] and
            ``params[15:17]`` for the pooling block on wires [0, 1].
        """
        n_params = 15
        U_unitary = U_SU4  # 15 params
        V_unitary = Pooling_ansatz1  # 2 params

        conv_layer1 = self.make_conv_layer1(U_unitary)
        pooling_layer1 = self.make_pooling_layer1(V_unitary)

        def QCNN_structure1(params):
            conv_param1 = params[0:n_params]  # [0:15]
            pooling_param1 = params[n_params:n_params + 2]  # [15:17]
            conv_layer1(conv_param1)
            pooling_layer1(pooling_param1)

        return QCNN_structure1

    def make_conv_layer1(self, U_unitary):
        """Return ``conv_layer(params)`` applying ``U_unitary`` on wires [0, 1]."""
        def conv_layer(params):
            U_unitary(params, wires=[0, 1])
        return conv_layer

    def make_pooling_layer1(self, V_unitary):
        """Return ``pooling_layer(params)`` applying ``V_unitary`` on wires [0, 1]."""
        def pooling_layer(params):
            V_unitary(params, wires=[0, 1])
        return pooling_layer

    def make_conv_layer2(self, U_unitary):
        """Return ``conv_layer(params)`` applying ``U_unitary`` on wires [1, 0]."""
        def conv_layer(params):
            U_unitary(params, wires=[1, 0])
        return conv_layer

    def make_pooling_layer2(self, V_unitary):
        """Return ``pooling_layer(params)`` applying ``V_unitary`` on wires [1, 0]."""
        def pooling_layer(params):
            V_unitary(params, wires=[1, 0])
        return pooling_layer
    
    
class QCNN3:
    """Three-qubit QCNN: a ring of 2-parameter ``U_TTN`` blocks, then pooling.

    ``make_conv_layer2`` / ``make_pooling_layer2`` are provided for a deeper
    variant but are not used by ``make_qcnn``.
    """

    def __init__(self):
        # Stored on the instance so callers can query the circuit width.
        self.n_qubits = 3

    def make_qcnn(self):
        """Build the 3-qubit circuit function.

        Returns:
            ``QCNN_structure(params)``, which reads 4 parameters:
            ``params[0:2]`` shared by every ``U_TTN`` block of the ring
            convolution, and ``params[2:4]`` for the pooling block on [0, 1].
        """
        U_unitary = U_TTN  # 2 params (U_SU4 would need 15)
        V_unitary = Pooling_ansatz1  # 2 params

        conv_layer1 = self.make_conv_layer1(U_unitary)
        pooling_layer1 = self.make_pooling_layer1(V_unitary)

        def QCNN_structure(params):
            conv_param1 = params[0:2]
            pooling_param1 = params[2:4]

            conv_layer1(conv_param1)
            pooling_layer1(pooling_param1)

        return QCNN_structure

    def make_conv_layer1(self, U_unitary):
        """Return ``conv_layer(params)``: ``U_unitary`` on [0, 1], [1, 2], [2, 0]."""
        def conv_layer(params):
            U_unitary(params, wires=[0, 1])
            U_unitary(params, wires=[1, 2])
            U_unitary(params, wires=[2, 0])
        return conv_layer

    def make_pooling_layer1(self, V_unitary):
        """Return ``pooling_layer(params)`` applying ``V_unitary`` on wires [0, 1]."""
        def pooling_layer(params):
            V_unitary(params, wires=[0, 1])
        return pooling_layer

    def make_conv_layer2(self, U_unitary):
        """Return ``conv_layer(params)`` applying ``U_unitary`` on wires [1, 2]."""
        def conv_layer(params):
            U_unitary(params, wires=[1, 2])
        return conv_layer

    def make_pooling_layer2(self, V_unitary):
        """Return ``pooling_layer(params)`` applying ``V_unitary`` on wires [1, 2]."""
        def pooling_layer(params):
            V_unitary(params, wires=[1, 2])
        return pooling_layer


    
class QCNN4:
    """Four-qubit QCNN made of three stacked SU(4) ring convolutions (no pooling).

    The pooling-layer and second-convolution factories are provided but are
    not used by ``make_qcnn``.
    """

    def __init__(self, args):
        print(f'QCNN n qubit == 4')
        # Stored on the instance so callers can query the circuit width.
        self.n_qubits = 4
        # NOTE(review): stored, but make_qcnn always stacks exactly three
        # convolution layers regardless of this value — confirm intent.
        self.n_layers = args.n_layers

    def make_qcnn(self):
        """Build the 4-qubit circuit function.

        Returns:
            ``QCNN_structure(params)``, which reads 45 parameters: three
            consecutive 15-value slices (``[0:15]``, ``[15:30]``, ``[30:45]``),
            each driving one application of the ``make_conv_layer1`` ring.
        """
        n_params = 15  # per U_SU4 block
        conv_layer1 = self.make_conv_layer1(U_SU4)

        def QCNN_structure(params):
            for layer_ind in range(3):
                conv_layer1(params[layer_ind * n_params:(layer_ind + 1) * n_params])

        return QCNN_structure

    def make_conv_layer1(self, U_unitary):
        """Return ``conv_layer(params)``: ``U_unitary`` on [0,1], [2,3], then [1,2], [3,0]."""
        def conv_layer(params):
            # Even pairs first, then the odd pairs closing the ring.
            U_unitary(params, wires=[0, 1])
            U_unitary(params, wires=[2, 3])

            U_unitary(params, wires=[1, 2])
            U_unitary(params, wires=[3, 0])
        return conv_layer

    def make_pooling_layer1(self, V_unitary):
        """Return ``pooling_layer(params)``: ``V_unitary`` on [0, 1] and [2, 3]."""
        def pooling_layer(params):
            V_unitary(params, wires=[0, 1])
            V_unitary(params, wires=[2, 3])
        return pooling_layer

    def make_conv_layer2(self, U_unitary):
        """Return ``conv_layer(params)``: ``U_unitary`` on [1, 3] then [3, 1]."""
        def conv_layer(params):
            U_unitary(params, wires=[1, 3])
            U_unitary(params, wires=[3, 1])
        return conv_layer

    def make_pooling_layer2(self, V_unitary):
        """Return ``pooling_layer(params)`` applying ``V_unitary`` on wires [1, 3]."""
        def pooling_layer(params):
            V_unitary(params, wires=[1, 3])
        return pooling_layer



class QCNN6:
    """Six-qubit QCNN consisting of a single SU(4) ring convolution."""

    def __init__(self):
        # Stored on the instance so callers can query the circuit width.
        self.n_qubits = 6

    def make_qcnn(self):
        """Build the 6-qubit circuit function.

        Returns:
            ``QCNN_structure(params)``, which reads 15 parameters
            (``params[0:15]``) shared by every block of the ring convolution.
        """
        n_params = 15  # per U_SU4 block
        conv_layer1 = self.make_conv_layer1(U_SU4)

        def QCNN_structure(params):
            # No debug print here: it ran on every circuit evaluation and
            # failed for inputs without a ``.shape`` attribute (e.g. lists).
            conv_layer1(params[0:n_params])

        return QCNN_structure

    def make_conv_layer1(self, U_unitary):
        """Return ``conv_layer(params)`` applying ``U_unitary`` around a 6-wire ring.

        Order: [0,1], [2,3], [4,5] (even pairs), then [1,2], [3,4], [5,0].
        """
        def conv_layer(params):
            U_unitary(params, wires=[0, 1])
            U_unitary(params, wires=[2, 3])
            U_unitary(params, wires=[4, 5])

            U_unitary(params, wires=[1, 2])
            U_unitary(params, wires=[3, 4])
            U_unitary(params, wires=[5, 0])
        return conv_layer


class QCNN8:
    """Eight-qubit QCNN made of ``n_layers`` stacked SU(4) ring convolutions."""

    def __init__(self):
        # Stored on the instance so callers can query the circuit width.
        self.n_qubits = 8

    def make_qcnn(self, n_layers):
        """Build the 8-qubit circuit function.

        Args:
            n_layers: number of ring-convolution layers to stack.

        Returns:
            ``QCNN_structure(params)``, which reads ``15 * n_layers`` parameters;
            layer ``k`` uses ``params[15*k : 15*(k+1)]``.
        """
        n_params = 15  # per U_SU4 block
        conv_layer1 = self.make_conv_layer1(U_SU4)

        def QCNN_structure(params):
            for layer_ind in range(n_layers):
                conv_layer1(params[n_params * layer_ind:n_params * (layer_ind + 1)])

        return QCNN_structure

    def make_conv_layer1(self, U_unitary):
        """Return ``conv_layer(params)`` applying ``U_unitary`` around an 8-wire ring.

        Order: [0,1], [2,3], [4,5], [6,7] (even pairs), then
        [1,2], [3,4], [5,6], [7,0].
        """
        def conv_layer(params):
            U_unitary(params, wires=[0, 1])
            U_unitary(params, wires=[2, 3])
            U_unitary(params, wires=[4, 5])
            U_unitary(params, wires=[6, 7])

            U_unitary(params, wires=[1, 2])
            U_unitary(params, wires=[3, 4])
            U_unitary(params, wires=[5, 6])
            U_unitary(params, wires=[7, 0])
        return conv_layer


class QCNN:
    """Generic ring QCNN over ``args.n_qubits`` wires with ``args.n_layers`` layers.

    Each layer is an SU(4) ring convolution (15 shared params) followed by a
    ``Pooling_ansatz1`` pooling pass (2 shared params).
    """

    def __init__(self, args):
        self.n_qubits = args.n_qubits
        self.n_layers = args.n_layers

    def make_qcnn(self):
        """Build the circuit function.

        Returns:
            ``QCNN_structure(params)``, which reads ``17 * n_layers`` parameters.
            Layer ``k`` uses ``params[17*k : 17*k + 15]`` for the convolution
            and ``params[17*k + 15 : 17*(k+1)]`` for the pooling.
        """
        n_layers = self.n_layers  # captured now; later changes to self are ignored
        n_conv_params = 15
        n_pooling_params = 2
        n_params = n_conv_params + n_pooling_params

        conv_layer = self.make_conv_layer(U_SU4)
        pooling_layer = self.make_pooling_layer(Pooling_ansatz1)

        def QCNN_structure(params):
            for layer_ind in range(n_layers):
                start = n_params * layer_ind
                conv_layer(params[start:start + n_conv_params])
                pooling_layer(params[start + n_conv_params:start + n_params])

        return QCNN_structure

    def make_conv_layer(self, U_unitary):
        """Return ``conv_layer(params)`` applying ``U_unitary`` around the wire ring.

        First on pairs ``[i, (i+1) % n]`` for even ``i``, then for odd ``i``.
        Intended for an even ``n_qubits``; with ``n_qubits == 2`` the second
        pass applies ``U_unitary`` on ``[1, 0]``.
        """
        n_qubits = self.n_qubits

        def conv_layer(params):
            for i in range(0, n_qubits, 2):
                U_unitary(params, wires=[i, (i + 1) % n_qubits])
            for i in range(1, n_qubits, 2):
                U_unitary(params, wires=[i, (i + 1) % n_qubits])

        return conv_layer

    def make_pooling_layer(self, V_unitary):
        """Return ``pooling_layer(params)`` applying ``V_unitary`` on ``[i, (i+1) % n]`` for even ``i``.

        With an odd ``n_qubits`` the last pair wraps onto wire 0 again.
        """
        n_qubits = self.n_qubits

        def pooling_layer(params):
            for i in range(0, n_qubits, 2):
                V_unitary(params, wires=[i, (i + 1) % n_qubits])

        return pooling_layer




##### QCNN Backbone #####

def QCNN_structure_without_pooling(U, params, n_params):
    """Apply ``conv_layer1``, ``conv_layer2`` and ``conv_layer3`` in sequence.

    Each layer gets its own consecutive ``n_params``-sized slice of ``params``
    (``[0:n]``, ``[n:2n]``, ``[2n:3n]``); all layers use the same block ``U``.
    """
    param1 = params[0:n_params]
    param2 = params[n_params: 2 * n_params]
    param3 = params[2 * n_params: 3 * n_params]

    conv_layer1(U, param1)
    conv_layer2(U, param2)
    conv_layer3(U, param3)


def QCNN_1D_circuit(U, params, n_params):
    """Open-chain (non-ring) 8-wire circuit of ``U`` blocks, no pooling.

    Layer 1 (``params[0:n]``): pairs [0,1], [2,3], [4,5], [6,7], then
    [1,2], [3,4], [5,6]. Layer 2 (``params[n:2n]``): [2,3], [4,5].
    Layer 3 (``params[2n:3n]``): [3,4].
    """
    param1 = params[0: n_params]
    param2 = params[n_params: 2 * n_params]
    param3 = params[2 * n_params: 3 * n_params]

    for i in range(0, 8, 2):
        U(param1, wires=[i, i + 1])
    for i in range(1, 7, 2):
        U(param1, wires=[i, i + 1])

    U(param2, wires=[2, 3])
    U(param2, wires=[4, 5])

    U(param3, wires=[3, 4])


# Unitary Ansatze for Convolutional Layer
def U_TTN(params, wires):  # 2 params
    """RY on each of the two wires, then a CNOT from wires[0] to wires[1]."""
    ctrl, tgt = wires[0], wires[1]
    qml.RY(params[0], wires=ctrl)
    qml.RY(params[1], wires=tgt)
    qml.CNOT(wires=[ctrl, tgt])


def U_5(params, wires):  # 10 params
    """Two-qubit block: RX/RZ rotations, CRZ in both directions, RX/RZ again.

    Reads ``params[0:10]`` and ``wires[0]`` / ``wires[1]``.
    """
    first, second = wires[0], wires[1]

    def rx_rz(offset):
        # RX on both wires, then RZ on both wires (4 consecutive params).
        qml.RX(params[offset], wires=first)
        qml.RX(params[offset + 1], wires=second)
        qml.RZ(params[offset + 2], wires=first)
        qml.RZ(params[offset + 3], wires=second)

    rx_rz(0)
    qml.CRZ(params[4], wires=[second, first])
    qml.CRZ(params[5], wires=[first, second])
    rx_rz(6)


def U_6(params, wires):  # 10 params
    """Two-qubit block: RX/RZ rotations, CRX in both directions, RX/RZ again.

    Reads ``params[0:10]`` and ``wires[0]`` / ``wires[1]``.
    """
    first, second = wires[0], wires[1]

    def rx_rz(offset):
        # RX on both wires, then RZ on both wires (4 consecutive params).
        qml.RX(params[offset], wires=first)
        qml.RX(params[offset + 1], wires=second)
        qml.RZ(params[offset + 2], wires=first)
        qml.RZ(params[offset + 3], wires=second)

    rx_rz(0)
    qml.CRX(params[4], wires=[second, first])
    qml.CRX(params[5], wires=[first, second])
    rx_rz(6)


def U_9(params, wires):  # 2 params
    """Hadamard on both wires, CZ between them, then an RX on each wire."""
    pair = [wires[0], wires[1]]
    for w in pair:
        qml.Hadamard(wires=w)
    qml.CZ(wires=[pair[0], pair[1]])
    for k, w in enumerate(pair):
        qml.RX(params[k], wires=w)


def U_13(params, wires):  # 6 params
    """Two rounds of (RY on both wires, CRZ); first CRZ is 1->0, second 0->1."""
    a, b = wires[0], wires[1]
    for start, (ctrl, tgt) in zip((0, 3), ((b, a), (a, b))):
        qml.RY(params[start], wires=a)
        qml.RY(params[start + 1], wires=b)
        qml.CRZ(params[start + 2], wires=[ctrl, tgt])


def U_14(params, wires):  # 6 params
    """Two rounds of (RY on both wires, CRX); first CRX is 1->0, second 0->1."""
    a, b = wires[0], wires[1]
    for start, (ctrl, tgt) in zip((0, 3), ((b, a), (a, b))):
        qml.RY(params[start], wires=a)
        qml.RY(params[start + 1], wires=b)
        qml.CRX(params[start + 2], wires=[ctrl, tgt])


def U_15(params, wires):  # 4 params
    """Two rounds of (RY on both wires, CNOT); first CNOT is 1->0, second 0->1."""
    a, b = wires[0], wires[1]
    for k, (ctrl, tgt) in enumerate(((b, a), (a, b))):
        qml.RY(params[2 * k], wires=a)
        qml.RY(params[2 * k + 1], wires=b)
        qml.CNOT(wires=[ctrl, tgt])


def U_SO4(params, wires):  # 6 params
    """Three RY layers on both wires, separated by CNOT(wires[0] -> wires[1])."""
    a, b = wires[0], wires[1]
    for k in range(3):
        if k > 0:
            # Entangle between consecutive rotation layers only.
            qml.CNOT(wires=[a, b])
        qml.RY(params[2 * k], wires=a)
        qml.RY(params[2 * k + 1], wires=b)


def U_SU4(params, wires): # 15 params
    """General two-qubit block (SU(4) by name): U3 pair, 3-CNOT core, U3 pair.

    Reads ``params[0:15]``: 0-5 and 9-14 feed the U3 gates, 6-8 the core.
    """
    a, b = wires[0], wires[1]

    def u3_pair(offset):
        # Independent U3 rotations on both wires (6 consecutive params).
        qml.U3(params[offset], params[offset + 1], params[offset + 2], wires=a)
        qml.U3(params[offset + 3], params[offset + 4], params[offset + 5], wires=b)

    u3_pair(0)
    qml.CNOT(wires=[a, b])
    qml.RY(params[6], wires=a)
    qml.RZ(params[7], wires=b)
    qml.CNOT(wires=[b, a])
    qml.RY(params[8], wires=a)
    qml.CNOT(wires=[a, b])
    u3_pair(9)

# Pooling Layer

def Pooling_ansatz1(params, wires): #2 params
    """Two-parameter pooling block with control wires[0] and target wires[1].

    Applies CRZ(params[0]), flips the control with PauliX, then CRX(params[1]);
    the PauliX is not undone afterwards.
    """
    control, target = wires[0], wires[1]
    qml.CRZ(params[0], wires=[control, target])
    qml.PauliX(wires=control)
    qml.CRX(params[1], wires=[control, target])

def Pooling_ansatz2(wires): #0 params
    """Parameter-free pooling block: a controlled-Z from wires[0] to wires[1].

    NOTE: ``qml.CRZ`` requires a rotation angle, so calling it with only
    ``wires`` raises a TypeError; ``qml.CZ`` is the angle-free controlled-Z
    matching this block's "0 params" contract.
    """
    qml.CZ(wires=[wires[0], wires[1]])

def Pooling_ansatz3(*params, wires): #3 params
    """Pooling block: one controlled general rotation ``qml.CRot``.

    Accepts the three angles (phi, theta, omega) either as separate positional
    arguments, ``Pooling_ansatz3(phi, theta, omega, wires=w)``, or as a single
    sequence, ``Pooling_ansatz3(params, wires=w)``. The latter matches how the
    ``pooling_layer*`` helpers call their ``V`` argument.

    Args:
        *params: three angles, or one sequence whose first three entries
            are the angles.
        wires: sequence whose first two entries are (control, target).
    """
    if len(params) == 1:
        # Single sequence argument: unpack its first three entries as angles.
        angles = params[0]
        params = (angles[0], angles[1], angles[2])
    qml.CRot(*params, wires=[wires[0], wires[1]])


# Convolutional layers
def conv_layer1(U, params):
    """Apply ``U`` with shared ``params`` around an 8-wire ring.

    Order: the wrap-around pair [0, 7], then [0,1], [2,3], [4,5], [6,7],
    then [1,2], [3,4], [5,6].
    """
    edges = [[0, 7]]
    edges += [[i, i + 1] for i in range(0, 8, 2)]
    edges += [[i, i + 1] for i in range(1, 7, 2)]
    for pair in edges:
        U(params, wires=pair)
def conv_layer2(U, params):
    """Apply ``U`` with shared ``params`` on [0,6], [0,2], [4,6], [2,4], in that order."""
    for pair in ([0, 6], [0, 2], [4, 6], [2, 4]):
        U(params, wires=pair)
def conv_layer3(U, params):
    """Apply a single ``U`` block on wires [0, 4]."""
    target_wires = [0, 4]
    U(params, wires=target_wires)

# Pooling layers
def pooling_layer1(V, params):
    """Apply ``V`` with shared ``params`` on [1,0], [3,2], [5,4], [7,6].

    Odd wires come first in each pair, so they are ``V``'s first wire argument.
    """
    for i in range(0, 8, 2):
        V(params, wires=[i + 1, i])
def pooling_layer2(V, params):
    """Apply ``V`` with shared ``params`` on [2, 0] then [6, 4]."""
    for pair in ([2, 0], [6, 4]):
        V(params, wires=pair)
def pooling_layer3(V, params):
    """Apply a single ``V`` block on wires [0, 4]."""
    last_pair = [0, 4]
    V(params, wires=last_pair)