from NeuralNet.LinearNet.LinearNetPy import LinearNet
import numpy as np
from DataRecord.DataRecordPy import DataRecord
import tensorflow as tf
import matplotlib.pyplot as plt


def func(x):
    """Return the inverse square root of *x*, i.e. ``1 / sqrt(x)``.

    Accepts anything ``np.sqrt`` accepts (scalars or array-likes); array-like
    input is transformed element-wise.
    """
    root = np.sqrt(x)
    return 1 / root


if __name__ == '__main__':
    # Experiment script: for a range of hidden-layer widths, repeatedly train a
    # one-hidden-layer LinearNet on a fixed synthetic regression problem,
    # record mean/std statistics per width in a DataRecord, then plot them.

    # The network is driven through a TF1-style Session below, so eager
    # execution has to be switched off first.
    tf.compat.v1.disable_eager_execution()

    # Generate Training data
    D = 400  # input dimension
    n = 100  # number of samples (n < D)

    sigma = 0.01  # standard deviation of the additive label noise
    # np.sqrt is used instead of np.math.sqrt: the np.math alias was deprecated
    # in NumPy 1.25 and removed in NumPy 2.0.
    beta = np.random.normal(0, 1, [D, 1]) / np.sqrt(D)
    X = np.random.normal(0, 1, [n, D]) / np.sqrt(D)
    Y = np.matmul(X, beta) + np.random.normal(0, sigma, [n, 1])

    # Reference solution pinv(X^T X) X^T Y.
    # NOTE(review): np.linalg.pinv(X) @ Y is mathematically equivalent and
    # avoids forming the rank-deficient D x D Gram matrix -- consider switching.
    beta_min_norm = np.matmul(np.linalg.pinv(np.matmul(np.transpose(X), X)), np.matmul(np.transpose(X), Y))
    # Reduced SVD: the rows of Phi_t are the right singular vectors of X, so
    # proj2 projects onto the orthogonal complement of their span.
    _, _, Phi_t = np.linalg.svd(X, full_matrices=False)
    proj2 = np.eye(D) - np.matmul(np.transpose(Phi_t), Phi_t)

    # Role of imbalance on convergence rate
    print("Numerical Verification for Theorem 2...")

    dr = DataRecord('data/experiment_thm_2.npy')
    # The plotting code below addresses channels by integer index 0..5;
    # presumably these follow the order of the init() calls -- verify against
    # DataRecord.
    init_dist_mean_num = dr.init(channel_name='Initial_dist_U_2VT_mean')
    final_dist_mean_num = dr.init(channel_name='Final_dist_to_min_norm_mean')
    init_dist_std_num = dr.init(channel_name='Initial_dist_U_2VT_std')
    final_dist_std_num = dr.init(channel_name='Final_dist_to_min_norm_std')
    iter_mean_num = dr.init(channel_name='iteration_num_mean')
    iter_std_num = dr.init(channel_name='iteration_num_std')

    # different hidden layer width
    h_list = np.arange(500, 10000, 300)

    lr = 0.005
    epoch_num = 120000

    # repeat for each h
    repeat = 10
    for hid_num in h_list:

        init_dist_list = []
        final_dist_list = []
        iter_list = []

        # Both layers use the same 1/sqrt(h) initialisation scale; it depends
        # only on the width, so compute it once per width.
        U_init_scale = 1 / np.sqrt(hid_num)
        V_init_scale = 1 / np.sqrt(hid_num)

        for rep in range(repeat):
            # Start every repetition from an empty graph so nothing from the
            # previous network is carried over.
            tf.compat.v1.reset_default_graph()
            with tf.compat.v1.Session() as sess:
                train_data = {'inputs': X, 'outputs': Y}

                net = LinearNet(sess, train_data)
                net.set_value({'batch_size': 100, 'learning_rate': lr, 'loss_type': 'l2'})
                net.build_model(hid_layer_num=1, hid_layer_dim=[hid_num],
                                layer_init_scale=[V_init_scale, U_init_scale])

                # Norm of the initial operator after applying proj2.
                # NOTE(review): assumes get_operator() returns an array whose
                # transpose has D rows -- confirm against LinearNet.
                init_dist_list.append(np.linalg.norm(np.matmul(proj2, np.transpose(net.get_operator()))))

                # The first return value is recorded as the iteration count
                # (see the 'iteration_num_*' channels).
                it, _ = net.train_mini_batch(train_epochs=epoch_num, stop_check_every=1000, display_every=5000,
                                             test_every=5000)

                # Distance between the trained operator and the reference solution.
                final_dist_list.append(np.linalg.norm(net.get_operator() - np.transpose(beta_min_norm)))
                iter_list.append(it)

        # Record mean and standard deviation over the repetitions for this width.
        dr.update(np.mean(init_dist_list), hid_num, init_dist_mean_num)
        dr.update(np.std(init_dist_list), hid_num, init_dist_std_num)

        dr.update(np.mean(final_dist_list), hid_num, final_dist_mean_num)
        dr.update(np.std(final_dist_list), hid_num, final_dist_std_num)

        dr.update(np.mean(iter_list), hid_num, iter_mean_num)
        dr.update(np.std(iter_list), hid_num, iter_std_num)

    # plot results
    label_font_size = 12
    ax_tick_font_size = 14
    ax_label_font_size = 14

    # Shared matplotlib settings (LaTeX text rendering, Helvetica sans-serif),
    # re-applied before drawing each panel exactly where the original did.
    latex_rc = {
        "text.usetex": True,
        "font.family": "sans-serif",
        "font.sans-serif": ["Helvetica"],
        "font.size": label_font_size}

    # Raw strings keep the LaTeX backslashes literal without relying on
    # Python's handling of invalid escape sequences.
    dr.set_channel(0, new_channel_name=r'$\|U_2(0)V^T(0)\|_F^2$')
    dr.set_channel(1, new_channel_name=r'$\|U(t_f)V^T(t_f)-\hat{\Theta}\|_F^2$')
    dr.set_channel(2, new_channel_name='')
    dr.set_channel(3, new_channel_name='')
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3)

    # Panel 1: channels 0 and 1 with error bars from channels 2 and 3,
    # plotted against the hidden layer width.
    plt.sca(ax1)
    plt.rcParams.update(latex_rc)
    plt.tick_params(labelsize=ax_tick_font_size)
    dr.visualize_errorbar(0, 2)
    dr.visualize_errorbar(1, 3)
    plt.ylim(top=1, bottom=0.1)
    plt.xlim(left=500, right=1e4)
    plt.grid(alpha=0.8, linestyle='--')
    plt.xlabel('Hidden layer width $h$')
    ax1.xaxis.label.set_fontsize(ax_label_font_size)
    ax1.yaxis.label.set_fontsize(ax_label_font_size)
    ax1.ticklabel_format(axis='x', style='sci', scilimits=(0, 4))
    plt.legend()

    # Panel 2: channels 0 and 1 again, with func (1/sqrt) passed as x_fun so
    # the x-axis is 1/sqrt(h); the axis is inverted afterwards.
    plt.sca(ax2)
    plt.rcParams.update(latex_rc)
    plt.tick_params(labelsize=ax_tick_font_size)
    dr.visualize_plot([0, 1], y_log=False, x_fun=func)
    plt.ylim(top=1, bottom=0.1)
    plt.xlim(right=1 / np.sqrt(500), left=1 / np.sqrt(1e4))
    plt.grid(alpha=0.8, linestyle='--')
    plt.xlabel(r'$1/\sqrt{h}$')
    plt.legend()
    ax2.xaxis.label.set_fontsize(ax_label_font_size)
    ax2.yaxis.label.set_fontsize(ax_label_font_size)
    ax2.invert_xaxis()

    # Panel 3: channel 4 with error bars from channel 5 (iteration counts).
    plt.sca(ax3)
    plt.rcParams.update(latex_rc)
    plt.tick_params(labelsize=ax_tick_font_size)
    dr.visualize_errorbar(4, 5)
    plt.ylim(top=8e4)
    plt.xlim(left=500, right=1e4)
    plt.grid(alpha=0.8, linestyle='--')
    plt.xlabel('Hidden layer width $h$')
    plt.ylabel('# of Iterations')
    plt.title('# of iterations until stop')
    ax3.xaxis.label.set_fontsize(ax_label_font_size)
    ax3.yaxis.label.set_fontsize(ax_label_font_size)
    ax3.ticklabel_format(axis='x', style='sci', scilimits=(0, 4))
    ax3.ticklabel_format(axis='y', style='sci', scilimits=(0, 4))

    plt.show()
