from NeuralNet.LinearNet.LinearNetPy import LinearNet
from DataRecord.DataRecordPy import DataRecord
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt


if __name__ == '__main__':
    # Numerical verification of Theorem 1: the level of imbalance at
    # initialization governs the convergence rate of a one-hidden-layer
    # linear network trained on a linear regression problem.
    tf.compat.v1.disable_eager_execution()

    # ---- Generate training data ------------------------------------------
    D = 400  # input dimension
    n = 100  # number of samples

    # Random Gaussian design; then drop its singular values (i.e. set them
    # all to 1). With full_matrices=False and n <= D, W is n x n orthogonal
    # and Phi_t has orthonormal rows, so the rows of X are orthonormal.
    X = np.random.normal(0, 1, [n, D]) / np.sqrt(D)
    W, _, Phi_t = np.linalg.svd(X, full_matrices=False)
    X = np.matmul(W, Phi_t)

    sigma = 0.01  # standard deviation of the additive label noise
    beta = np.random.normal(0, 1, [D, 1]) / np.sqrt(D)
    Y = np.matmul(X, beta) + np.random.normal(0, sigma, [n, 1])

    # ---- Role of imbalance on convergence rate ---------------------------
    print("Numerical Verification for Theorem 1...")
    dr = DataRecord('data/experiment_thm_1.npy')

    h = 500  # hidden-layer width
    lr = 0.0005
    epoch_num = 15000
    batch_size = n  # every mini-batch is the full training set

    # (case number, DataRecord channel to rename/plot, U init scale, V init scale).
    # NOTE(review): channels 0 and 2 are assumed to hold the loss curves of
    # the first and second training run respectively (presumably each run
    # records two channels) -- confirm against DataRecord/LinearNet.
    cases = [
        (1, 0, 0.1, 0.1),
        (2, 2, 0.5, 0.02),
    ]
    imbalance_by_channel = {}

    for case_num, channel, U_init_scale, V_init_scale in cases:
        print("Running Case {}...".format(case_num))
        tf.compat.v1.reset_default_graph()
        with tf.compat.v1.Session() as sess:
            train_data = {'inputs': X, 'outputs': Y}

            net = LinearNet(sess, train_data, data_recorder=dr)
            net.set_value({'batch_size': batch_size, 'learning_rate': lr, 'loss_type': 'l2'})
            # NOTE(review): the V scale is listed before the U scale here while
            # W1 is read back as Ut below; confirm LinearNet's ordering of
            # layer_init_scale (it only matters when the two scales differ).
            net.build_model(hid_layer_num=1, hid_layer_dim=[h],
                            layer_init_scale=[V_init_scale, U_init_scale])

            # Read the initial weights back from the graph (needs the default
            # session opened by the enclosing `with`).
            with tf.compat.v1.variable_scope('LinearNet', reuse=True):
                Ut_eval = tf.compat.v1.get_variable("W1").eval()
                V_eval = tf.compat.v1.get_variable("W2").eval()

            # Imbalance U1^T U1 - V^T V, where U1 = Phi_t U is the first layer
            # restricted to the row space of X.
            U1 = np.matmul(Phi_t, np.transpose(Ut_eval))
            imbalance = np.matmul(np.transpose(U1), U1) - np.matmul(np.transpose(V_eval), V_eval)
            _, S_imb, _ = np.linalg.svd(imbalance, full_matrices=False)

            # Singular values come back in descending order: take the n-th largest.
            imbalance_c = S_imb[n - 1]
            print("Level of imbalance at initialization: c={:.2e}".format(imbalance_c))

            net.train_mini_batch(train_epochs=epoch_num, stop_check_every=1000, display_every=100,
                                 test_every=5000)

        # Keep the exact value for the bound; the label only shows 3 digits.
        imbalance_by_channel[channel] = imbalance_c
        channel_name = r'Case {}: $\sigma_U={}$,$\sigma_V={}$, $c={:.2e}$'.format(
            case_num, U_init_scale, V_init_scale, imbalance_c)
        dr.set_channel(channel, new_channel_name=channel_name)

    # ---- Plot results ----------------------------------------------------
    label_font_size = 12
    ax_tick_font_size = 14
    ax_label_font_size = 14

    # Step size entering the Thm 1 bound: lr rescaled by batch size and a
    # factor 2 (presumably matching the loss normalization used in the
    # theorem -- TODO confirm). Numerically equal to 5e-4 / 100 * 2.
    eta = lr / batch_size * 2

    color = ['tab:blue', 'tab:orange']
    channels = [channel for _, channel, _, _ in cases]

    # rcParams only affect artists created after the update, so configure
    # text rendering once, before the figure and its axes exist.
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "sans-serif",
        "font.sans-serif": ["Helvetica"],
        "font.size": label_font_size})

    _, (ax1, ax2) = plt.subplots(1, 2)

    # Left panel: loss on a linear scale.
    plt.sca(ax1)
    plt.tick_params(labelsize=ax_tick_font_size)
    dr.visualize_plot(channels, y_log=False)
    plt.xlabel('Iteration $k$')
    plt.ylabel('Squared $l2$ Loss')
    plt.grid(alpha=0.8, linestyle='--')
    plt.legend()

    ax1.xaxis.label.set_fontsize(ax_label_font_size)
    ax1.yaxis.label.set_fontsize(ax_label_font_size)
    ax1.ticklabel_format(axis='x', style='sci', scilimits=(0, 4))

    # Right panel: log-scale loss together with the Thm 1 bounds.
    plt.sca(ax2)
    plt.tick_params(labelsize=ax_tick_font_size)
    plt_list = dr.visualize_plot(channels, y_log=True)

    # Empty labels are skipped by plt.legend(), so only the bounds are listed.
    for p in plt_list:
        p[0].set_label('')

    for i, (case_num, channel, _, _) in enumerate(cases):
        idx = dr._channel_num_list.index(channel)
        imb = imbalance_by_channel[channel]

        # NOTE(review): the first recorded row is skipped and the next one is
        # used as L(0); presumably row 0 is a pre-training entry -- confirm.
        data = np.copy(dr._data_list[idx])[1:, :]
        iterations = data[:, 0]
        initial_loss = data[0, 1]
        bound = np.power((1 - 2 * eta * imb), iterations) * initial_loss
        plt.plot(iterations, bound, '--',
                 label=r'Case {}: bound by Thm 1, $L(0)(1-2c\eta)^k$'.format(case_num),
                 color=color[i])

    plt.xlabel('Iteration $k$')
    plt.ylabel('Log squared $l2$ Loss')
    ax2.xaxis.label.set_fontsize(ax_label_font_size)
    ax2.yaxis.label.set_fontsize(ax_label_font_size)
    plt.grid(alpha=0.8, linestyle='--')

    plt.legend()
    ax2.ticklabel_format(axis='x', style='sci', scilimits=(0, 4))

    plt.show()

