
import pennylane as qml
from pennylane.templates import AmplitudeEmbedding
import numpy as np
import matplotlib.pyplot as plt
import random
import numpy.linalg as la
from pennylane.templates import BasicEntanglerLayers
import math
import random
from math import pi
from pennylane.init import basic_entangler_layers_normal
from pennylane.templates.layers import RandomLayers
from pennylane.ops import CNOT, RX, RY, RZ
from datetime import datetime

# np.random.seed(42)

def ttn_circuit(f, weights, num_qubits):
    """Tree-tensor-network (TTN) style classifier circuit.

    Amplitude-encodes ``f`` on all wires, then runs ``int(log2(num_qubits))``
    layers.  In layer ``i`` (stride ``2**i``) an RY rotation is applied to
    wires ``0, stride, 2*stride, ...`` (one per currently active wire), after
    which CNOTs with control ``(2j+1)*stride`` and target ``2j*stride`` fold
    each pair of active wires into one, halving the active count.  Only
    wire 0 is measured.

    Args:
        f: features handed to ``AmplitudeEmbedding`` (the caller in this file
            reshapes each image to ``2**num_qubits`` amplitudes).
        weights: flat array of rotation angles.  Layer ``i`` reads the slice
            ``[(2**(i+1)-2)*n, (2**(i+1)-1)*n)`` where ``n`` is the number of
            wires active in that layer.
        num_qubits: object exposing the qubit count as ``.val`` (looks like
            an older PennyLane wrapper for non-differentiable QNode
            arguments -- TODO confirm against the installed version).

    Returns:
        A one-element list holding the expectation of PauliZ on wire 0.

    NOTE(review): for 8 qubits only weights[0:14] are read (8 + 4 + 2
    rotations) and no rotation follows the final CNOT onto wire 0; confirm
    this is intended if callers allocate 15 weights.
    """
    depth = int(math.log(num_qubits.val, 2))
    active = int(num_qubits.val)
    AmplitudeEmbedding(features=f, wires=range(active))
    for level in range(depth):
        stride = 2 ** level
        # Offsets written in closed form; kept identical to the original
        # formula (it is not a running sum for non-power-of-two widths).
        start = (2 * stride - 2) * active
        stop = (2 * stride - 1) * active
        level_params = weights[start:stop]
        for idx in range(active):
            qml.RY(level_params[idx], wires=idx * stride)
        active = int(active / 2)
        # range(0) is empty, so no explicit "active > 0" guard is needed.
        for idx in range(active):
            qml.CNOT(wires=[(2 * idx + 1) * stride, 2 * idx * stride])
    return [qml.expval(qml.PauliZ(0))]
#    return [qml.expval(qml.PauliZ(i)) for i in range(n_wires)]

def random_circuit(f, weights, num_qubits):
    """Randomly structured classifier circuit used as a baseline.

    Amplitude-encodes ``f`` on all wires, applies one ``RandomLayers`` block
    driven by ``weights`` (CNOT as the entangler, ratio_imprim=0.5, fixed
    seed=42), and measures PauliZ on every wire.

    Args:
        f: features handed to ``AmplitudeEmbedding``.
        weights: flat sequence of rotation angles; wrapped as a single layer
            before being passed on (presumably ``RandomLayers`` expects shape
            ``(n_layers, n_rotations)`` -- confirm for the installed version).
        num_qubits: object exposing the qubit count as ``.val`` (older
            PennyLane argument wrapper -- TODO confirm).

    Returns:
        List of PauliZ expectation values, one per wire in order.
    """
    wire_count = int(num_qubits.val)
    AmplitudeEmbedding(features=f, wires=range(wire_count))
    single_layer = [list(weights)]
    # The fixed seed presumably keeps the random gate layout identical on
    # every evaluation, so only the angles change during training.
    RandomLayers(weights=single_layer, wires=range(wire_count), ratio_imprim=0.5, imprimitive=CNOT, seed=42)
    observables = []
    for wire in range(wire_count):
        observables.append(qml.expval(qml.PauliZ(wire)))
    return observables

def sample_obs(value, value_small, value_large, sample_number):
    """Simulate a finite-shot estimate of a two-outcome expectation value.

    ``value`` is treated as the exact mean of an observable whose only
    outcomes are ``value_small`` and ``value_large``.  Each of
    ``sample_number`` simulated shots yields ``value_large`` with probability
    ``p = (value - value_small) / (value_large - value_small)`` and
    ``value_small`` otherwise; the empirical mean of the shots is returned.

    Randomness comes from NumPy's global legacy RNG (``np.random``), so
    ``np.random.seed`` makes results reproducible.

    Args:
        value: exact expectation value, normally within
            ``[value_small, value_large]``.
        value_small: outcome value for a "miss" (e.g. -1.0 for PauliZ).
        value_large: outcome value for a "hit" (e.g. +1.0 for PauliZ); must
            differ from ``value_small``.
        sample_number: number of shots.  A non-positive count draws nothing
            (zero hits); zero shots then fails on the final division.

    Returns:
        The sampled estimate as a Python ``float``.
    """
    p = (value - value_small) / (value_large - value_small)
    shots = max(int(sample_number), 0)
    # One vectorised draw instead of a Python loop of scalar draws;
    # np.random.random(n) consumes the global RNG stream exactly like n
    # successive np.random.random() calls, so results are unchanged.
    hits = int(np.count_nonzero(np.random.random(shots) < p))
    return float((value_large - value_small) * hits / sample_number + value_small)


def _shift_gradient(qcircuit, state_in, weights, num_qubits, num_params, sample):
    """Parameter-shift gradient of the mean circuit output for one input.

    Returns ``(grad, grad_exact)``: ``grad[k]`` is the difference of the two
    shot-sampled estimates at ``weights[k] +/- pi/2`` (``sample`` shots each,
    no 1/2 factor, so it estimates twice the true partial), and
    ``grad_exact[k]`` is the exact parameter-shift partial ``(a - b) / 2``.
    """
    grad = np.zeros(num_params)
    grad_exact = np.zeros(num_params)
    for k in range(num_params):
        weights_plus = np.copy(weights)
        weights_minus = np.copy(weights)
        weights_plus[k] = weights[k] + pi/2
        weights_minus[k] = weights[k] - pi/2
        a = np.mean(qcircuit(state_in, weights_plus, num_qubits))
        b = np.mean(qcircuit(state_in, weights_minus, num_qubits))
        grad[k] = sample_obs(a, -1.0, 1.0, sample) - sample_obs(b, -1.0, 1.0, sample)
        grad_exact[k] = (a-b)*0.5
    return grad, grad_exact


def _scaled_step(grad, lr, decay, gradient_normalize):
    """Update vector ``lr*grad*decay``, divided by ``||grad||`` if normalizing."""
    step = lr*grad*decay
    if gradient_normalize:
        norm = la.norm(grad)
        if norm == 0:
            # Every sampled partial cancelled: normalizing would be 0/0 = NaN
            # and poison all later steps, so take no step at all instead.
            return np.zeros_like(step)
        step = step/norm
    return step


def _evaluate(qcircuit, data_0, data_1, size, weights, num_qubits, sample_inf):
    """Return ``(error, loss)`` on the first ``size`` images of each class.

    A class-0 image counts as an error when its ``sample_inf``-shot estimate
    is > 0, a class-1 image when it is < 0.  The loss averages
    ``(1 + <Z>)/2`` over class 0 and ``(1 - <Z>)/2`` over class 1, using the
    exact expectation.  Each image is evaluated once and reused for both.
    Returns ``(nan, nan)`` when ``size`` is 0.
    """
    if size == 0:
        return float("nan"), float("nan")
    dim = 2**num_qubits
    exp_0 = [np.mean(qcircuit(data_0[i].reshape(dim), weights, num_qubits)) for i in range(size)]
    exp_1 = [np.mean(qcircuit(data_1[i].reshape(dim), weights, num_qubits)) for i in range(size)]
    # Sampling order (all class 0, then all class 1) matches the original
    # loop, so the global RNG stream is consumed identically.
    errors = 0
    for e in exp_0:
        if sample_obs(e, -1.0, 1.0, sample_inf) > 0:
            errors += 1
    for e in exp_1:
        if sample_obs(e, -1.0, 1.0, sample_inf) < 0:
            errors += 1
    # Interleaved accumulation kept as in the original for identical rounding.
    loss = 0.0
    for e0, e1 in zip(exp_0, exp_1):
        loss = loss + e0*0.5 + 0.5
        loss = loss - e1*0.5 + 0.5
    return errors/(2*size), loss/(2*size)


def mnist_classification(num_qubits, num_params, qcircuit, circuit_name, iteration, check_slide, lr, sample, sample_inf, test_size, train_size, gradient_normalize):
    """Train a quantum binary classifier (digit 0 vs digit 1) with sampled SGD.

    Loads the four ``*_16_normalized_encoding.npy`` image files from the
    current directory.  Each of ``iteration`` steps takes one random class-0
    training image (gradient descent on the mean output, pushing it below 0)
    and then one random class-1 image (gradient ascent, pushing it above 0).
    Gradients come from the parameter-shift rule estimated with ``sample``
    shots.  The learning rate is scaled by 1, 0.75, 0.5, 0.25 over the four
    quarters of training.  Every ``check_slide`` steps (starting at step 0)
    test/train error and loss are computed and printed.

    Args:
        num_qubits: images are reshaped to ``2**num_qubits`` amplitudes; the
            value is also passed through unchanged to ``qcircuit``.
        num_params: number of trainable weights (initialised uniformly in
            ``[0, 4*pi)``).
        qcircuit: callable ``(state, weights, num_qubits)`` returning
            expectation values whose mean is the classifier output (assumed
            to lie in ``[-1, 1]``, the bounds given to ``sample_obs``).
        circuit_name: prefix of the saved ``.npy`` result files.
        iteration: number of training steps (two weight updates each).
        check_slide: evaluation interval in steps.
        lr: base learning rate.
        sample: shots per gradient-component estimate.
        sample_inf: shots per inference when counting errors.
        test_size: images per class used for test evaluation.
        train_size: images per class drawn from for training and used for
            train evaluation.
        gradient_normalize: if truthy, each step is divided by the sampled
            gradient norm and result files get the ``_sngd`` suffix.

    Returns:
        0.  Results are written as ``.npy`` files (train/test error and loss,
        plus the exact gradient norms when not normalizing).
    """
    train_0 = np.load("train_images_0_16_normalized_encoding.npy")
    train_1 = np.load("train_images_1_16_normalized_encoding.npy")
    test_0 = np.load("test_images_0_16_normalized_encoding.npy")
    test_1 = np.load("test_images_1_16_normalized_encoding.npy")
    print(len(train_0),len(train_1))
    print(len(test_0),len(test_1))
    dim = 2**num_qubits
    # Evaluation runs for every j in range(iteration) with j % check_slide == 0,
    # i.e. ceil(iteration / check_slide) times; int(iteration / check_slide)
    # under-allocated (IndexError) whenever the division was not exact.
    n_checks = (iteration + check_slide - 1)//check_slide
    error_test = np.zeros(n_checks)
    error_train = np.zeros(n_checks)
    loss_test = np.zeros(n_checks)
    loss_train = np.zeros(n_checks)
    weights = np.random.uniform(0, 4*pi, size=num_params)
    grad_norm = np.zeros(2*iteration)
    index_0_list = np.random.choice(range(train_size), size=iteration, replace=True)
    index_1_list = np.random.choice(range(train_size), size=iteration, replace=True)
    for j in range(iteration):
        # Step-decay schedule: 1, 0.75, 0.5, 0.25 over successive quarters.
        decay = 1-0.25*int(4*j/iteration)

        # Class-0 image: descend, pushing the output towards -1.
        state_in = train_0[index_0_list[j]].reshape(dim)
        grad, grad_exact = _shift_gradient(qcircuit, state_in, weights, num_qubits, num_params, sample)
        weights = weights - _scaled_step(grad, lr, decay, gradient_normalize)
        grad_norm[2*j] = la.norm(grad_exact)

        # Class-1 image: ascend, pushing the output towards +1.
        state_in = train_1[index_1_list[j]].reshape(dim)
        grad, grad_exact = _shift_gradient(qcircuit, state_in, weights, num_qubits, num_params, sample)
        weights = weights + _scaled_step(grad, lr, decay, gradient_normalize)
        grad_norm[2*j+1] = la.norm(grad_exact)

        if j % check_slide == 0:
            k = j//check_slide
            error_test[k], loss_test[k] = _evaluate(qcircuit, test_0, test_1, test_size, weights, num_qubits, sample_inf)
            error_train[k], loss_train[k] = _evaluate(qcircuit, train_0, train_1, train_size, weights, num_qubits, sample_inf)
            print("the test error at iteration: ", j, error_test[k])
            print("the train error at iteration: ", j, error_train[k])
            print("the test loss at iteration: ", j, loss_test[k])
            print("the train loss at iteration: ", j, loss_train[k])

    # save data
    suffix = "_sngd" if gradient_normalize else ""
    np.save(circuit_name+"_train_error_"+ str(num_qubits)+suffix+".npy", error_train)
    np.save(circuit_name+"_test_error_"+ str(num_qubits)+suffix+".npy", error_test)
    np.save(circuit_name+"_train_loss_"+ str(num_qubits)+suffix+".npy", loss_train)
    np.save(circuit_name+"_test_loss_"+ str(num_qubits)+suffix+".npy", loss_test)
    if not gradient_normalize:
        np.save(circuit_name+"_gradnorm_"+ str(num_qubits)+".npy",  grad_norm)
    return 0

# num_qubits = 8
# iteration = 100
# check_slide = 5
# lr = 0.4
# shot = 100
# shot_inference = 1000
# test_size = 2
# train_size = 2

# figure_iteration_x = range(1,1+iteration, check_slide)
# figure_gradient_x = np.zeros(2*iteration)
# for i in range(2*iteration):
#     figure_gradient_x[i] = 0.5 *(1+i)

# t0 = datetime.now()
# dev = qml.device('default.qubit', wires = num_qubits)
# qcircuit = qml.QNode(ttn_circuit, dev)
# a = mnist_classification(8, 15, qcircuit, "ttn", iteration, check_slide, lr, shot, shot_inference, test_size, train_size, False)
# t1 = datetime.now()
# print('seconds for the ttn circuit: ', (t1-t0).seconds)
# dev = qml.device('default.qubit', wires = num_qubits)
# qcircuit = qml.QNode(random_circuit, dev)
# b = mnist_classification(8, 15, qcircuit, "random", iteration, check_slide, lr, shot, shot_inference, test_size, train_size, False)
# t2 = datetime.now()
# print('seconds for the random circuit: ', (t2-t1).seconds)

# tt_train_error = np.load("ttn_train_error_8.npy")
# tt_test_error = np.load("ttn_test_error_8.npy")
# tt_train_loss = np.load("ttn_train_loss_8.npy")
# tt_test_loss = np.load("ttn_test_loss_8.npy")
# tt_gradient_norm = np.load("ttn_gradnorm_8.npy")
# random_train_error = np.load("random_train_error_8.npy")
# random_test_error = np.load("random_test_error_8.npy")
# random_train_loss = np.load("random_train_loss_8.npy")
# random_test_loss = np.load("random_test_loss_8.npy")
# random_gradient_norm = np.load("random_gradnorm_8.npy")
# # figure of the error
# plt.plot(figure_iteration_x, tt_train_error, '--', label = "TT QNN train", color = 'red', lw = 2)
# plt.plot(figure_iteration_x, tt_test_error,  label = "TT QNN test", color = 'red', lw = 2)
# plt.plot(figure_iteration_x, random_train_error, '--', label = "Random QNN train", color = 'blue', lw = 2)
# plt.plot(figure_iteration_x, random_test_error, label = "Random QNN test", color = 'blue', lw = 2)
# plt.xlabel('training iteration', fontsize = 16)
# plt.ylabel('error', fontsize = 16)
# #    plt.ylim(0, 1)
# # Show the major grid lines with dark grey lines
# plt.grid(b=True, which='major', color='#666666', linestyle='--')
# # Show the minor grid lines with very faint and almost transparent grey lines
# plt.minorticks_on()
# plt.grid(b=True, which='minor', color='#999999', linestyle=':', alpha=0.2)
# plt.legend()
# plt.savefig('all_error_sgd.eps')
# plt.close()
# # figure of the loss
# plt.plot(figure_iteration_x, tt_train_loss, '--', label = "TT QNN train", color = 'red', lw = 2)
# plt.plot(figure_iteration_x, tt_test_loss, label = "TT QNN test", color = 'red', lw = 2)
# plt.plot(figure_iteration_x, random_train_loss, '--', label = "Random QNN train", color = 'blue', lw = 2)
# plt.plot(figure_iteration_x, random_test_loss, label = "Random QNN test", color = 'blue', lw = 2)
# plt.xlabel('training iteration', fontsize = 16)
# plt.ylabel('loss', fontsize = 16)
# #    plt.ylim(0, 1)
# # Show the major grid lines with dark grey lines
# plt.grid(b=True, which='major', color='#666666', linestyle='--')
# # Show the minor grid lines with very faint and almost transparent grey lines
# plt.minorticks_on()
# plt.grid(b=True, which='minor', color='#999999', linestyle=':', alpha=0.2)
# plt.legend()
# plt.savefig('all_loss_sgd.eps')
# plt.close()
# # figure of the gradient norm
# plt.plot(figure_gradient_x, tt_gradient_norm, label = "TT QNN")
# plt.plot(figure_gradient_x, random_gradient_norm, label = "Random QNN")
# plt.xlabel('training iteration', fontsize = 16)
# plt.ylabel('gradient norm', fontsize = 16)
# #    plt.ylim(0, 1)
# # Show the major grid lines with dark grey lines
# plt.grid(b=True, which='major', color='#666666', linestyle='--')
# # Show the minor grid lines with very faint and almost transparent grey lines
# plt.minorticks_on()
# plt.grid(b=True, which='minor', color='#999999', linestyle=':', alpha=0.2)
# plt.legend()
# plt.savefig('all_gradnorm_sgd.eps')
# plt.close()

# t0 = datetime.now()
# dev = qml.device('default.qubit', wires = num_qubits)
# qcircuit = qml.QNode(ttn_circuit, dev)
# a = mnist_classification(8, 15, qcircuit, "ttn", iteration, check_slide, lr, shot, shot_inference, test_size, train_size, True)
# t1 = datetime.now()
# print('seconds for the ttn circuit: ', (t1-t0).seconds)
# dev = qml.device('default.qubit', wires = num_qubits)
# qcircuit = qml.QNode(random_circuit, dev)
# b = mnist_classification(8, 15, qcircuit, "random", iteration, check_slide, lr, shot, shot_inference, test_size, train_size, True)
# t2 = datetime.now()
# print('seconds for the random circuit: ', (t2-t1).seconds)

# tt_train_error = np.load("ttn_train_error_8_sngd.npy")
# tt_test_error = np.load("ttn_test_error_8_sngd.npy")
# tt_train_loss = np.load("ttn_train_loss_8_sngd.npy")
# tt_test_loss = np.load("ttn_test_loss_8_sngd.npy")
# random_train_error = np.load("random_train_error_8_sngd.npy")
# random_test_error = np.load("random_test_error_8_sngd.npy")
# random_train_loss = np.load("random_train_loss_8_sngd.npy")
# random_test_loss = np.load("random_test_loss_8_sngd.npy")

# # figure of the error
# plt.plot(figure_iteration_x, tt_train_error, '--', label = "TT QNN train", color = 'red', lw = 2)
# plt.plot(figure_iteration_x, tt_test_error,  label = "TT QNN test", color = 'red', lw = 2)
# plt.plot(figure_iteration_x, random_train_error, '--', label = "Random QNN train", color = 'blue', lw = 2)
# plt.plot(figure_iteration_x, random_test_error, label = "Random QNN test", color = 'blue', lw = 2)
# plt.xlabel('training iteration', fontsize = 16)
# plt.ylabel('error', fontsize = 16)
# #    plt.ylim(0, 1)
# # Show the major grid lines with dark grey lines
# plt.grid(b=True, which='major', color='#666666', linestyle='--')
# # Show the minor grid lines with very faint and almost transparent grey lines
# plt.minorticks_on()
# plt.grid(b=True, which='minor', color='#999999', linestyle=':', alpha=0.2)
# plt.legend()
# plt.savefig('all_error_sngd.eps')
# plt.close()
# # figure of the loss
# plt.plot(figure_iteration_x, tt_train_loss, '--', label = "TT QNN train", color = 'red', lw = 2)
# plt.plot(figure_iteration_x, tt_test_loss, label = "TT QNN test", color = 'red', lw = 2)
# plt.plot(figure_iteration_x, random_train_loss, '--', label = "Random QNN train", color = 'blue', lw = 2)
# plt.plot(figure_iteration_x, random_test_loss, label = "Random QNN test", color = 'blue', lw = 2)
# plt.xlabel('training iteration', fontsize = 16)
# plt.ylabel('loss', fontsize = 16)
# #    plt.ylim(0, 1)
# # Show the major grid lines with dark grey lines
# plt.grid(b=True, which='major', color='#666666', linestyle='--')
# # Show the minor grid lines with very faint and almost transparent grey lines
# plt.minorticks_on()
# plt.grid(b=True, which='minor', color='#999999', linestyle=':', alpha=0.2)
# plt.legend()
# plt.savefig('all_loss_sngd.eps')
# plt.close()
