import os
import pickle

import matplotlib.pyplot as plt
import numpy as np

result_path = "result_fixed/"
plot_path = "plots_fixed/"
std_plot = 1.0  # half-width of the shaded band, in multiples of the standard deviation


def _load_result(name, nu):
    """Unpickle ``<result_path>/<name>_nu=<nu>.txt`` and return it as an ndarray.

    The files are pickles despite the ``.txt`` extension. ``pickle.load`` can
    execute arbitrary code, so this must only be used on trusted, locally
    produced experiment outputs.
    """
    file_name = os.path.join(result_path, name + "_nu=" + str(nu) + ".txt")
    with open(file_name, "rb") as f:
        # np.asarray makes the mean +/- std arithmetic below work even if the
        # pickled object is a plain Python sequence rather than an ndarray.
        return np.asarray(pickle.load(f))


def _plot_band(ax, x, mean, std, color, label):
    """Draw ``mean`` as a solid line on ``ax`` with a ``mean +/- std_plot*std`` band."""
    ax.plot(x, mean, color=color, label=label, linewidth=3)
    ax.fill_between(x, mean - std_plot * std, mean + std_plot * std,
                    facecolor=color, alpha=0.2, edgecolor=color, linestyle='dashdot')


# visualization: log(loss) in the top row, log(regret) in the bottom row,
# one column per heterogeneity level nu
fig, axes = plt.subplots(2, 3, figsize=[24., 12.])

nu_list = [1.0, 3.0, 10.0]  # parameter controlling the level of heterogeneity
# (file-name key, legend label, colour) for each compared method
methods = [("Adaptive_OSMD", "Ada-OSMD", "g"), ("IS", "IS", "r")]
# (file-name prefix, y-axis label) for each row of the figure
rows = [("log_loss", "log(loss)"), ("log_regret", "log(regret)")]

for j, nu in enumerate(nu_list):
    # The optimal-policy curve is not drawn; it only fixes the horizon length.
    n_iter = len(_load_result("log_loss_list_optimal_mean", nu))
    t = np.arange(1, n_iter + 1)

    for i, (metric, ylabel) in enumerate(rows):
        ax = axes[i][j]
        for key, label, color in methods:
            mean = _load_result(metric + "_list_" + key + "_mean", nu)
            std = _load_result(metric + "_list_" + key + "_std", nu)
            _plot_band(ax, t, mean, std, color, label)

        ax.tick_params(axis='both', labelsize=25)
        if j == 0:
            ax.set_ylabel(ylabel, fontsize=25)

        # Shrink each panel by 15% from the bottom to leave room for the shared legend.
        box = ax.get_position()
        ax.set_position([box.x0, box.y0 + box.height * 0.15, box.width, box.height * 0.85])

    axes[0][j].set_title(r'$\nu=$' + str(nu), fontsize=25)
    axes[1][j].set_xlabel("t", fontsize=25)

# Pass explicit handles: with ``labels=`` alone, matplotlib pairs labels with the
# first artists it finds, which (matplotlib >= 3.5) can be a fill_between band
# rather than the IS mean line.
handles, labels = axes[0][0].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', ncol=2, fontsize=25)

if plot_path:
    os.makedirs(plot_path, exist_ok=True)  # savefig does not create directories
fig.savefig(os.path.join(plot_path, 'loss_regret_fixed.png'), dpi=400, bbox_inches='tight')
plt.close(fig)