import numpy as np
import matplotlib.pyplot as plt
import random
import numpy.linalg as la
import math
import random
from math import pi

import quantum_gate_operation as opera
from datetime import datetime

# np.random.seed(42)

def alter_layer_circuit(f, weights, num_qubit):
    """Apply the alternating-layer RY/CX ansatz to a copy of the state ``f``.

    ``weights`` is reshaped to a ``(2*L + 1, num_qubit)`` angle grid, where the
    number of blocks ``L`` is inferred from ``len(weights)``. The circuit is:

    * row 0: one RY rotation on every qubit;
    * for each block ``b``:
        - row ``2b+1``: CX on the pairs (0,1), (2,3), ... followed by RY on
          both qubits of each pair;
        - row ``2b+2``: CX on the pairs (1,2), (3,4), ..., (n-1, 0)
          (wrapping around) followed by RY on both qubits of each pair.

    ``num_qubit`` is expected to be even (the pairing assumes it). Gate
    semantics come from the project module ``opera``. Each gate call's return
    value is taken as the new state.

    Returns:
        The state after the full circuit.
    """
    n_blocks = int((len(weights) - num_qubit) / (2 * num_qubit))
    angles = weights.reshape((2 * n_blocks + 1, num_qubit))
    half = int(num_qubit / 2)
    psi = np.copy(f)  # never operate on the caller's array

    # Leading single-qubit rotation layer.
    for q, theta in enumerate(angles[0]):
        psi = opera.RY(psi, num_qubit, q, theta)

    for blk in range(n_blocks):
        even_row = angles[2 * blk + 1]
        odd_row = angles[2 * blk + 2]

        # Entangle neighbouring pairs starting at an even qubit.
        for p in range(half):
            ctrl, tgt = 2 * p, 2 * p + 1
            psi = opera.CX(psi, num_qubit, ctrl, tgt)
            psi = opera.RY(psi, num_qubit, ctrl, even_row[ctrl])
            psi = opera.RY(psi, num_qubit, tgt, even_row[tgt])

        # Entangle neighbouring pairs starting at an odd qubit (last pair wraps to 0).
        for p in range(half):
            ctrl, tgt = 2 * p + 1, (2 * p + 2) % num_qubit
            psi = opera.CX(psi, num_qubit, ctrl, tgt)
            psi = opera.RY(psi, num_qubit, ctrl, odd_row[ctrl])
            psi = opera.RY(psi, num_qubit, tgt, odd_row[tgt])

    return psi

def alter_layer_encoding(weights, num_qubit):
    """Run the mirrored alternating-layer circuit starting from the all-ones basis state.

    The initial state is the computational basis vector with index
    ``2**num_qubit - 1`` (every bit set). ``weights`` is indexed as a 2-D grid
    ``weights[row][qubit]``, and ``L = int((len(weights) - 1) / 2)`` blocks are applied.
    Each block ``b`` consists of:

    * row ``2b``: on the pairs (1,2), (3,4), ..., (n-1, 0), RY on both
      qubits and then CX;
    * row ``2b+1``: on the pairs (0,1), (2,3), ..., RY on both qubits and
      then CX.

    A final RY layer uses row ``2L``. This is the gate order of
    ``alter_layer_circuit`` reversed. When it is fed the row-reversed,
    negated angles of that circuit (as ``mnist_encoding`` does), it should
    act as its inverse. That assumes ``opera.RY`` with ``-theta`` undoes
    ``theta`` and CX is self-inverse. TODO confirm against ``opera``.
    ``num_qubit`` is expected to be even.

    Returns:
        The resulting state.
    """
    n_blocks = int((len(weights) - 1) / 2)
    dim = 2 ** num_qubit

    psi = np.zeros(dim)
    psi[dim - 1] = 1.0  # start from the basis state whose index has every bit set
    angles = np.copy(weights)
    half = int(num_qubit / 2)

    for blk in range(n_blocks):
        odd_row = angles[2 * blk]
        even_row = angles[2 * blk + 1]

        # Odd-offset pairs first (the last pair wraps to qubit 0): rotations, then CX.
        for p in range(half):
            a, b = 2 * p + 1, (2 * p + 2) % num_qubit
            psi = opera.RY(psi, num_qubit, a, odd_row[a])
            psi = opera.RY(psi, num_qubit, b, odd_row[b])
            psi = opera.CX(psi, num_qubit, a, b)

        # Even-offset pairs: rotations, then CX.
        for p in range(half):
            a, b = 2 * p, 2 * p + 1
            psi = opera.RY(psi, num_qubit, a, even_row[a])
            psi = opera.RY(psi, num_qubit, b, even_row[b])
            psi = opera.CX(psi, num_qubit, a, b)

    # Trailing single-qubit rotation layer.
    for q in range(num_qubit):
        psi = opera.RY(psi, num_qubit, q, angles[2 * n_blocks][q])
    return psi


def sample_obs(value, value_small, value_large, sample_number):
    """Estimate an observable from a finite number of simulated binary measurements.

    ``value`` is assumed to lie in ``[value_small, value_large]``. It is mapped
    to the success probability
    ``a = (value - value_small) / (value_large - value_small)``. Then
    ``int(sample_number)`` uniform draws from ``np.random.random`` are compared
    against ``a``, and the observed success fraction is mapped back onto the
    original range.

    Args:
        value: Exact expectation value to be sampled.
        value_small: Observable value reported for a "failed" draw.
        value_large: Observable value reported for a "successful" draw.
        sample_number: Number of draws. Non-integer values are truncated with
            ``int``, and the same truncated count is used as the denominator.

    Returns:
        The sampled estimate as a Python float.

    Raises:
        ValueError: If ``int(sample_number)`` is less than 1.
    """
    n = int(sample_number)
    if n < 1:
        raise ValueError("sample_number must be at least 1")

    span = value_large - value_small
    if span == 0:
        # Both outcomes report the same value, so the estimate is deterministic.
        return float(value_small)

    a = (value - value_small) / span
    # One vectorized batch of draws instead of a Python-level loop.
    hits = np.count_nonzero(np.random.random(n) < a)
    # Divide by the number of draws actually taken, not the raw argument.
    return float(span * hits / n + value_small)


def mnist_encoding(number, file, size, alter_layer, lr=0.1, iteration=100, num_qubit=8):
    """Learn a circuit encoding for each of the first ``size`` images in ``file + ".npy"``.

    For every image, the steps are:

    1. Flatten it to ``2**num_qubit`` amplitudes and L2-normalize it.
    2. Train ``alter_layer_circuit`` angles (random init in ``[0, 2*pi)``) by
       normalized gradient descent that lowers the mean ``opera.SigmaZ``
       over all qubits.
    3. Reverse and negate the trained angle rows and run
       ``alter_layer_encoding`` to produce the encoded state.
    4. Flip the sign of that state if the negated version is closer to the
       input, recording the sign in the phase array.

    Results are written to ``file + "_encoding.npy"`` (states, shape
    ``(size, 2**num_qubit)``), ``file + "_para.npy"`` (flattened reversed
    angles, shape ``(size, (1 + 2*alter_layer) * num_qubit)``) and
    ``file + "_phase.npy"`` (the +/-1 signs, shape ``(size,)``).

    Args:
        number: Label used only in the progress message.
        file: Path prefix of the input ``.npy`` file and of the outputs.
        size: Number of leading images to encode.
        alter_layer: Number of alternating blocks in the circuit.
        lr: Base learning rate (default 0.1, as before).
        iteration: Gradient-descent steps per image (default 100).
        num_qubit: Number of qubits. Each image must contain exactly
            ``2**num_qubit`` values. Expected to be even (default 8).

    Returns:
        0, unconditionally.
    """
    data = np.load(file + ".npy")
    dim = 2 ** num_qubit
    num_params = (1 + 2 * alter_layer) * num_qubit
    data_out = np.zeros((size, dim))
    para_out = np.zeros((size, num_params))
    glob_phase = np.zeros(size)

    for i in range(size):
        print(i, " th image ", number)
        state_in = data[i].reshape(dim)
        state_in = state_in / la.norm(state_in)
        weights = np.random.uniform(0, 2 * pi, size=num_params)
        weights = _optimize_circuit_weights(state_in, weights, num_qubit, lr, iteration)

        reverse_weights = _invert_circuit_weights(weights, alter_layer, num_qubit)
        state_out = alter_layer_encoding(reverse_weights, num_qubit)

        # Only the overall sign is corrected: pick whichever of +/-state_out
        # is closer to the target, and remember the choice.
        glob_phase[i] = 1.0
        if la.norm(state_in - state_out) > la.norm(state_in + state_out):
            state_out = -state_out
            glob_phase[i] = -1.0
        data_out[i] = state_out
        para_out[i] = reverse_weights.reshape(num_params)

    np.save(file + "_encoding.npy", data_out)
    np.save(file + "_para.npy", para_out)
    np.save(file + "_phase.npy", glob_phase)

    return 0


def _mean_sigma_z(state, num_qubit):
    """Return the mean of ``opera.SigmaZ`` over all qubits of ``state``."""
    return np.mean([opera.SigmaZ(state, num_qubit, q) for q in range(num_qubit)])


def _optimize_circuit_weights(state_in, weights, num_qubit, lr, iteration):
    """Update ``weights`` in place by normalized shifted-difference descent and return it."""
    for j in range(iteration):
        grad_exact = np.zeros(len(weights))
        for k in range(len(grad_exact)):
            weights_plus = np.copy(weights)
            weights_minus = np.copy(weights)
            # NOTE(review): pi/4 shift with an undivided difference; whether this
            # is the exact parameter-shift gradient depends on opera.RY's angle
            # convention. Only the direction matters here, because the step is
            # normalized below.
            weights_plus[k] = weights[k] + pi / 4
            weights_minus[k] = weights[k] - pi / 4
            a = _mean_sigma_z(alter_layer_circuit(state_in, weights_plus, num_qubit), num_qubit)
            b = _mean_sigma_z(alter_layer_circuit(state_in, weights_minus, num_qubit), num_qubit)
            grad_exact[k] = a - b

        grad_norm = la.norm(grad_exact)  # hoisted: previously recomputed per element
        if grad_norm > 0:
            # Step-decay schedule: lr * {1, 0.75, 0.5, 0.25} over four equal phases.
            decay = 1 - 0.25 * int(4 * j / iteration)
            weights -= lr * grad_exact * decay / grad_norm
        # A zero gradient previously produced 0/0 -> NaN weights; now the step is skipped.
        if j % 10 == 0:
            print(grad_norm)
    return weights


def _invert_circuit_weights(weights, num_layers, num_qubit):
    """Return the ``(2*num_layers + 1, num_qubit)`` angle grid with rows reversed and negated."""
    grid = weights.reshape((2 * num_layers + 1, num_qubit))
    return -1.0 * grid[::-1]
