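'''
Paper figures: model error as a function of the number of features p
(experimental curves, and experiment compared against the theoretical curve).
'''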
from matplotlib import rcParams
from cProfile import label
import numpy as np
import time
import pickle
from datetime import timedelta
import matplotlib.pyplot as plt
import matplotlib
from support_func import *
import os
curr_path = os.path.dirname(os.path.abspath(__file__))
# Cache file holding the simulated curves (written/read further below).
data_file_name = os.path.join(curr_path, "change_p_model_err.pkl")

# Problem sizes.
m = 10   # number of training tasks
n = 50   # number of samples for each training task
nv = 3   # number of validation samples for each task
s = 5    # sparsity: w0 has exactly s entries

# Feature counts p to sweep: 5 and 10, every integer from 15 to 40
# (this range contains m * nv = 30), then a coarse tail up to 1000.
p_array = [5, 10] + list(range(15, 41)) + [50, 60, 70, 100, 120, 150, 200, 400, 1000]

# Ground-truth vector: s entries drawn uniformly from [-2, 2] with a fixed
# seed, then rescaled to have Euclidean norm 10.
np.random.seed(0)
w0 = np.random.uniform(-2, 2, (s, 1))
w0 = w0 / np.linalg.norm(w0) * 10

repeat_num = 100  # Monte-Carlo repetitions per (nu, sigma) setting


def model_err_with_p(NU, noise_sigma):
    """Monte-Carlo estimate of the model error for every feature count in p_array.

    For each of ``repeat_num`` repetitions, one dataset is drawn with
    ``sparse_linear_Gaussian`` at the largest feature count ``p_array[-1]``.
    Every p in ``p_array`` then reuses that draw by keeping only the first p
    entries of the feature axis (axis 1 of the X arrays). For each p,
    ``solve`` (step size ``0.02 / p``) produces ``w_star``, and two
    quantities are recorded:

    * ``||w0_padded - w_star||**2``, where ``w0_padded`` is ``w0``
      zero-padded to length p (or truncated when p < s);
    * ``ideal_diff_W(X, Y, X_v, Y_v, step_size, w0_padded)``.

    Both are averaged over the repetitions.

    Args:
        NU: forwarded to ``sparse_linear_Gaussian`` as its second argument
            (presumably the task-variation scale; confirm in support_func).
        noise_sigma: noise level forwarded to ``sparse_linear_Gaussian``.

    Returns:
        ``(avg_w2, avg_ideal)``: the per-p means over repetitions. ``avg_w2``
        is 1-D with one entry per element of ``p_array``. ``avg_ideal`` has
        the same leading length; its trailing shape depends on what
        ``ideal_diff_W`` returns.
    """
    start_time = time.time()
    # Report progress roughly ten times. max() guards repeat_num < 10, where
    # repeat_num // 10 would be 0 and the modulus below would raise.
    report_every = max(1, repeat_num // 10)
    all_w2 = []
    all_ideal = []
    for repeat_index in range(repeat_num):
        if repeat_index % report_every == report_every - 1:
            print("{:< 3}. {}".format(repeat_index, timedelta(
                seconds=int(time.time() - start_time))))
        # Reseed per repetition so runs are reproducible (assumes
        # sparse_linear_Gaussian draws from NumPy's global RNG).
        np.random.seed(repeat_index)
        w2_model_err = []
        ideal_err = []
        # Draw once at the largest p; smaller p values use prefix slices below.
        _W, _X, Y, _X_v, Y_v = sparse_linear_Gaussian(
            w0, NU, p_array[-1], m, n, nv, noise_sigma, s)
        for p in p_array:
            step_size = 0.02 / p
            X = _X[:, :p, :]
            X_v = _X_v[:, :p, :]
            w_star = solve(X, Y, X_v, Y_v, step_size)

            # Bring the ground truth to length p so it is comparable to w_star.
            if p >= s:
                w0_padded = np.zeros((p, 1))
                w0_padded[:s, :] = w0
            else:
                w0_padded = w0[:p, :]
            w2_model_err.append(np.linalg.norm(w0_padded - w_star) ** 2)
            ideal_err.append(ideal_diff_W(
                X, Y, X_v, Y_v, step_size, w0_padded))
        all_w2.append(w2_model_err)
        all_ideal.append(ideal_err)

    avg_w2 = np.asarray(all_w2).mean(axis=0)
    avg_ideal = np.asarray(all_ideal).mean(axis=0)
    return avg_w2, avg_ideal


def get_theory_value(NU, noise_sigma, p_values=None):
    """Evaluate the theoretical model-error curve at each feature count p.

    For every p the returned value is ``bw0 + b_delta / esti_eig`` with::

        bw0      = (p - m*nv) / p * ||w0||**2
        b_delta  = m*nv * ((1 + C1/n) * noise_sigma**2 + C2 * (1 + C3/n) * NU**2)
        esti_eig = p - C4 * m * nv

    Here m, nv, n and w0 are the module-level problem constants.

    Args:
        NU: the nu parameter (same meaning as in ``model_err_with_p``).
        noise_sigma: the noise level sigma.
        p_values: iterable of feature counts to evaluate. Defaults to the
            module-level ``p_array_theory``, looked up at call time. This
            keeps the original behaviour for existing callers.

    Returns:
        list of floats, one per p, in iteration order.
    """
    if p_values is None:
        p_values = p_array_theory
    # Hard-coded constants of the estimate; presumably tuned so the curve
    # matches the experiments. TODO(review): document their origin.
    C1 = C3 = 0.001
    C2 = C4 = 0.99995
    # m * nv is the total number of validation samples across tasks.
    total_val = m * nv
    w0_norm_sq = np.linalg.norm(w0) ** 2
    # b_delta does not depend on p, so compute it once.
    b_delta = total_val * ((1 + C1 / n) * noise_sigma **
                           2 + C2 * (1 + C3 / n) * NU ** 2)
    result = []
    for p in p_values:
        bw0 = (p - total_val) / p * w0_norm_sq
        esti_eig = p - C4 * m * nv  # same evaluation order as before
        result.append(bw0 + b_delta / esti_eig)
    return result


# (nu, sigma) settings: each pair is passed as NU / noise_sigma to
# model_err_with_p and get_theory_value.
nu_sigma_list = [(60, 0), (10, 10), (0.5, 0.5), (0.3, 0.3), (0.1, 0.1)]
# One plot marker per (nu, sigma) setting.
marker_list = ["+", "x", "1", "2", "3"]
# Per-setting p value at which the zoomed panel of the first figure draws its
# marker. Each value must be an entry of p_array.
floor_list = [31, 31, 34, 33, 31]
all_curves = []

# The simulation (repeat_num repetitions per setting) is expensive, so its
# results are cached in data_file_name. Set RECOMPUTE_DATA = True to
# regenerate the cache. It is also regenerated when the cache file is missing
# (previously that case crashed with FileNotFoundError).
RECOMPUTE_DATA = False

if RECOMPUTE_DATA or not os.path.exists(data_file_name):
    # calculate and save data
    for a, b in nu_sigma_list:
        curve, _ = model_err_with_p(NU=a, noise_sigma=b)
        all_curves.append(curve)
    with open(data_file_name, 'wb') as output:
        pickle.dump({'pArray': p_array, 'data': all_curves}, output)

# load data
# NOTE: pickle.load can execute arbitrary code. This file is written by this
# script itself; never point data_file_name at an untrusted pickle.
with open(data_file_name, 'rb') as pkl_file:
    data1 = pickle.load(pkl_file)
p_array = data1['pArray']
all_curves = data1['data']


rcParams.update({'figure.autolayout': True})
plt.rcParams.update({'font.size': 12})

# Two panels: (a) the full p sweep on log-log axes, (b) a zoom on 30 <= p <= 39.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 3),
                               gridspec_kw={'width_ratios': [1.5, 1]})

# Print curve index 2 (nu=0.5, sigma=0.5) at p = 33, 34, 35.
for p_probe in (33, 34, 35):
    print(all_curves[2][p_array.index(p_probe)])

for i, (a, b) in enumerate(nu_sigma_list):
    label = r"$\nu={},\sigma={}$".format(a, b)
    marker_style = dict(marker=marker_list[i], markersize=12, markeredgewidth=2)
    # Left panel: markers only on the first two points.
    ax1.plot(p_array, all_curves[i], label=label, markevery=[0, 1],
             **marker_style)
    # Right panel: a single marker at p = floor_list[i].
    ax2.plot(p_array, all_curves[i], label=label,
             markevery=[p_array.index(floor_list[i])], **marker_style)

# Dashed reference line at ||w0||^2 (the error of predicting w = 0).
null_risk = np.linalg.norm(w0) ** 2
for ax in (ax1, ax2):
    ax.hlines(null_risk, p_array[0], p_array[-1],
              color='black', linestyles='dashed', label="null risk")

ax1.set(yscale='log', xscale='log', xlabel='$p$',
        ylabel='model error', title='(a)')
ax2.set(yscale='log', xlabel='$p$', title='(b)', xlim=(30, 39), ylim=(1, 400))
# Only the right panel carries a legend, placed outside the axes.
ax2.legend(loc='center left', bbox_to_anchor=(1, 0.5))
for ax in (ax1, ax2):
    # Major and minor grid lines. The first argument of grid() is the
    # boolean `visible`, not an axis name.
    ax.grid(True, which='both')
plt.savefig(os.path.join(curr_path, 'change_p_model_err.eps'),
            format='eps', bbox_inches='tight')

# compare the theoretical result with experimental result
fig = plt.figure(figsize=(9, 9))
spec = matplotlib.gridspec.GridSpec(ncols=4, nrows=3)
ax1 = fig.add_subplot(spec[0, 0:2])  # row 0 with axes spanning 2 cols on evens
ax2 = fig.add_subplot(spec[0, 2:4])
ax3 = fig.add_subplot(spec[1, 0:2])
ax4 = fig.add_subplot(spec[1, 2:4])  # row 0 with axes spanning 2 cols on odds
ax5 = fig.add_subplot(spec[2, 1:3])
all_axes = [ax1, ax2, ax3, ax4, ax5]
all_titles = ['(a)', '(b)', '(c)', '(d)', '(e)']
all_theory_curves = []
all_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
p_array_theory = [i for i in range(m * nv, 1000 + 1)]
i = 0
start_p_index = p_array.index(m * nv)
for a, b in nu_sigma_list:
    curve = get_theory_value(NU=a, noise_sigma=b)
    all_theory_curves.append(curve)
    all_axes[i].plot(p_array[start_p_index:], all_curves[i][start_p_index:], label="experiment",
                     marker=marker_list[i], linestyle='None', markersize=12, markeredgewidth=1.5, color=all_colors[i])
    all_axes[i].plot(p_array_theory, curve,
                     label="theory", color=all_colors[i])
    all_axes[i].hlines(np.linalg.norm(w0) ** 2, p_array[start_p_index], p_array[-1],
                       color='black', linestyles='dashed', label="null risk")

    all_axes[i].set(yscale='log', xscale='log', xlabel='$p$',
                    ylabel='model error', title=all_titles[i]+r" $\nu={},\sigma={}$".format(
                        a, b))
    # all_axes[i].legend(loc='center left', bbox_to_anchor=(1, 0.5))
    all_axes[i].legend()
    all_axes[i].grid('both', 'both')
    i += 1
plt.savefig(os.path.join(curr_path, 'theory_match.eps'),
            format='eps', bbox_inches='tight')
plt.show()

