
import pennylane as qml
from pennylane.templates import AmplitudeEmbedding
from pennylane import numpy as np
import matplotlib.pyplot as plt
import random
import numpy.linalg as la
from pennylane.templates import BasicEntanglerLayers
import math
import random
from math import pi
from pennylane.init import basic_entangler_layers_normal
from pennylane.templates.layers import RandomLayers
from pennylane.ops import CNOT, RX, RY, RZ


# np.random.seed(42)

def RY(input_state, num_qubit, wire, theta):
    """Apply the rotation exp(-i * theta * sigma_y) to the ``wire``-th qubit.

    Each amplitude pair (|..0..>, |..1..>) on the target qubit is multiplied by
    [[cos(theta), -sin(theta)], [sin(theta), cos(theta)]]. Note this is the
    conventional RY(2 * theta), not RY(theta).

    Basis indices treat wire 0 as the most significant bit, i.e. qubit ``wire``
    corresponds to bit ``num_qubit - 1 - wire`` of the amplitude index.

    Args:
        input_state: State vector of length ``2**num_qubit``.
        num_qubit: Number of qubits.
        wire: Index of the target qubit, ``0 <= wire < num_qubit``.
        theta: Rotation angle in radians (used as-is in cos/sin).

    Returns:
        A new state vector; ``input_state`` is not modified. Integer input is
        promoted to float so rotated amplitudes are not truncated.

    Raises:
        ValueError: If the state length is not ``2**num_qubit`` or ``wire`` is
            out of range.
    """
    dim = 2 ** num_qubit
    if len(input_state) != dim:
        raise ValueError(
            f"input state dimension not proper: expected {dim}, got {len(input_state)}"
        )
    if not 0 <= wire < num_qubit:
        raise ValueError(f"wire {wire} out of range for {num_qubit} qubits")

    # `* 1.0` promotes integer arrays to float (complex/float dtypes are kept);
    # otherwise assigning cos/sin-weighted amplitudes would truncate them.
    output_state = np.copy(input_state) * 1.0
    stride = 1 << (num_qubit - 1 - wire)
    c, s = math.cos(theta), math.sin(theta)
    for i in range(dim):
        if i & stride:
            continue  # handled as the partner k of index i - stride
        k = i | stride
        amp0, amp1 = input_state[i], input_state[k]
        output_state[i] = c * amp0 - s * amp1
        output_state[k] = s * amp0 + c * amp1
    return output_state

def CX(input_state, num_qubit, wire_1, wire_2):
    """Apply a CNOT with control qubit ``wire_1`` and target qubit ``wire_2``.

    For every basis index whose control bit is 1, the amplitudes with target
    bit 0 and target bit 1 are swapped. Wire 0 is the most significant bit of
    the amplitude index.

    Args:
        input_state: State vector of length ``2**num_qubit``.
        num_qubit: Number of qubits.
        wire_1: Control qubit index.
        wire_2: Target qubit index; must differ from ``wire_1``.

    Returns:
        A new state vector; ``input_state`` is not modified.

    Raises:
        ValueError: If the state length is not ``2**num_qubit``, a wire is out
            of range, or ``wire_1 == wire_2``.
    """
    dim = 2 ** num_qubit
    if len(input_state) != dim:
        raise ValueError(
            f"input state dimension not proper: expected {dim}, got {len(input_state)}"
        )
    if not (0 <= wire_1 < num_qubit and 0 <= wire_2 < num_qubit):
        raise ValueError(f"wires ({wire_1}, {wire_2}) out of range for {num_qubit} qubits")
    if wire_1 == wire_2:
        raise ValueError("control and target wires must differ")

    control = 1 << (num_qubit - 1 - wire_1)
    target = 1 << (num_qubit - 1 - wire_2)
    output_state = np.copy(input_state)
    for i in range(dim):
        # Anchor each swap at the (control=1, target=0) index so every pair is
        # visited exactly once.
        if (i & control) and not (i & target):
            j = i | target
            output_state[i] = input_state[j]
            output_state[j] = input_state[i]
    return output_state

def CZ(input_state, num_qubit, wire_1, wire_2):
    """Apply a controlled-Z between qubits ``wire_1`` and ``wire_2``.

    Every amplitude whose basis index has both qubit bits set to 1 is negated;
    all other amplitudes are copied unchanged. The gate is symmetric in its two
    wires. Wire 0 is the most significant bit of the amplitude index.

    Args:
        input_state: State vector of length ``2**num_qubit``.
        num_qubit: Number of qubits.
        wire_1: First (control) qubit index.
        wire_2: Second (target) qubit index; must differ from ``wire_1``.

    Returns:
        A new state vector; ``input_state`` is not modified.

    Raises:
        ValueError: If the state length is not ``2**num_qubit``, a wire is out
            of range, or ``wire_1 == wire_2``.
    """
    dim = 2 ** num_qubit
    if len(input_state) != dim:
        raise ValueError(
            f"input state dimension not proper: expected {dim}, got {len(input_state)}"
        )
    if not (0 <= wire_1 < num_qubit and 0 <= wire_2 < num_qubit):
        raise ValueError(f"wires ({wire_1}, {wire_2}) out of range for {num_qubit} qubits")
    if wire_1 == wire_2:
        raise ValueError("control and target wires must differ")

    both = (1 << (num_qubit - 1 - wire_1)) | (1 << (num_qubit - 1 - wire_2))
    output_state = np.copy(input_state)
    for i in range(dim):
        if (i & both) == both:
            output_state[i] = -1 * input_state[i]
    return output_state

def SigmaZ(input_state, num_qubit, wire):
    """Return the expectation value <sigma_z> on the ``wire``-th qubit.

    Computes sum_i (+1 or -1) * |amp_i|**2, with +1 where the qubit's bit in
    index i is 0 and -1 where it is 1. Wire 0 is the most significant bit of
    the amplitude index. Using ``|amp|**2`` makes the result correct for
    complex amplitudes; for real amplitudes it equals ``amp**2``.

    Args:
        input_state: State vector of length ``2**num_qubit``.
        num_qubit: Number of qubits.
        wire: Index of the measured qubit, ``0 <= wire < num_qubit``.

    Returns:
        The expectation value (unnormalized if the state is not normalized).

    Raises:
        ValueError: If the state length is not ``2**num_qubit`` or ``wire`` is
            out of range.
    """
    dim = 2 ** num_qubit
    if len(input_state) != dim:
        raise ValueError(
            f"input state dimension not proper: expected {dim}, got {len(input_state)}"
        )
    if not 0 <= wire < num_qubit:
        raise ValueError(f"wire {wire} out of range for {num_qubit} qubits")

    stride = 1 << (num_qubit - 1 - wire)
    a = 0
    for i in range(dim):
        prob = abs(input_state[i]) ** 2
        a = a + (-prob if i & stride else prob)
    return a

def SigmaX(input_state, num_qubit, wire):
    """Return the expectation value <sigma_x> on the ``wire``-th qubit.

    For each amplitude pair (a_i, a_k) that differs only in the measured
    qubit's bit (bit 0 at i, bit 1 at k), adds 2 * Re(conj(a_i) * a_k). For
    real amplitudes this equals the plain product 2 * a_i * a_k; the conjugate
    makes it correct for complex amplitudes. Wire 0 is the most significant bit
    of the amplitude index.

    Args:
        input_state: State vector of length ``2**num_qubit``.
        num_qubit: Number of qubits.
        wire: Index of the measured qubit, ``0 <= wire < num_qubit``.

    Returns:
        The expectation value (unnormalized if the state is not normalized).

    Raises:
        ValueError: If the state length is not ``2**num_qubit`` or ``wire`` is
            out of range.
    """
    dim = 2 ** num_qubit
    if len(input_state) != dim:
        raise ValueError(
            f"input state dimension not proper: expected {dim}, got {len(input_state)}"
        )
    if not 0 <= wire < num_qubit:
        raise ValueError(f"wire {wire} out of range for {num_qubit} qubits")

    stride = 1 << (num_qubit - 1 - wire)
    a = 0
    for i in range(dim):
        if i & stride:
            continue  # already counted as the partner of index i - stride
        k = i | stride
        a = a + 2 * np.real(np.conj(input_state[i]) * input_state[k])
    return a