import matplotlib.pyplot as plt
import numpy as np
import matplotlib.ticker as mticker
import pickle
from matplotlib.ticker import MaxNLocator

# Simulation 1: fixed setting (figure title 'm=3,n=3') -- UCBVI vs. C-UCBVI.
# Plots the mean cumulative-regret curve of each algorithm with a shaded
# band of +/- one standard deviation around it.
# NOTE: pickle.load executes arbitrary code from the file; only point these
# paths at result files produced by your own runs.
with open('results/fix_3_3/reg_UCBVI.pickle', 'rb') as f:
    reg_UCBVI = pickle.load(f)
with open('results/fix_3_3/reg_CUCBVI.pickle', 'rb') as f:
    reg_CUCBVI = pickle.load(f)

# Assumes each array is (runs, iterations): statistics are reduced over
# axis 0 and the x axis spans the second dimension -- TODO confirm.
nn, K = reg_UCBVI.shape

mean_reg_UCBVI = np.mean(reg_UCBVI, axis=0)
mean_reg_CUCBVI = np.mean(reg_CUCBVI, axis=0)
sd_reg_UCBVI = np.std(reg_UCBVI, axis=0)
sd_reg_CUCBVI = np.std(reg_CUCBVI, axis=0)

# Despite the "ci" name, the band half-width is exactly one standard deviation.
ci_UCBVI = sd_reg_UCBVI
ci_CUCBVI = sd_reg_CUCBVI

fig, ax = plt.subplots(1, 1)
# Labels are attached to the line artists themselves. Passing a bare label
# list to ax.legend() pairs labels with artists by order, and the shaded
# fill_between bands sit between the lines, so a label could end up on a band.
ax.plot(range(K), mean_reg_UCBVI, 'b-', linewidth=2, label='UCBVI')
ax.fill_between(range(K), mean_reg_UCBVI - ci_UCBVI, mean_reg_UCBVI + ci_UCBVI,
                color='y', alpha=0.2)
ax.plot(range(K), mean_reg_CUCBVI, 'r-', linewidth=2, label='C-UCBVI')
ax.fill_between(range(K), mean_reg_CUCBVI - ci_CUCBVI, mean_reg_CUCBVI + ci_CUCBVI,
                color='y', alpha=0.2)
ax.set_xlabel('Number of Iterations', fontsize=15)
ax.set_ylabel('Cumulative Regret', fontsize=15)
# Only explicitly labelled artists (the two lines) appear in the legend.
ax.legend(loc='upper left')
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
fig.suptitle('m=3,n=3', fontsize=15)
plt.savefig('results/fix_3_3/UCBVI_CUCBVI_K3000_nn5.png', bbox_inches='tight')

# Simulation 2: cumulative regret of UCBVI and C-UCBVI for each X_max value.
# The x axis is labelled 'm'; one marker is drawn per entry of X_max_list.
X_max_list = [2, 3, 4, 5, 6]
with open('results/Xmax_K3000_N20/reg_UCBVI.pickle', 'rb') as ucbvi_file:
    reg_UCBVI = pickle.load(ucbvi_file)
with open('results/Xmax_K3000_N20/reg_CUCBVI.pickle', 'rb') as cucbvi_file:
    reg_CUCBVI = pickle.load(cucbvi_file)

fig, ax = plt.subplots()
# Plot order matters: the legend labels below are matched to lines by order.
for regret_values, line_format in ((reg_UCBVI, 'k-o'), (reg_CUCBVI, 'b-o')):
    ax.plot(X_max_list, regret_values, line_format, linewidth=2)
# Keep the major x ticks on integer positions.
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.set_xlabel('m', fontsize=15)
ax.set_ylabel('Cumulative Regret', fontsize=15)
# No axes title is drawn for this figure.
ax.legend(['UCBVI', 'C-UCBVI'], loc='upper left', fontsize=10)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.savefig('results/Xmax_K3000_N20/Xmax_comparison.png', bbox_inches='tight')

# Simulation 3: fix_3_3_v2 -- UCBVI vs. C-UCBVI vs. C-F-UCBVI.
# Plots the mean cumulative-regret curve of each algorithm with a shaded
# band of +/- one standard deviation around it.
# NOTE: pickle.load executes arbitrary code from the file; only point these
# paths at result files produced by your own runs.
with open('results/fix_3_3_v2/reg_UCBVI.pickle', 'rb') as f:
    reg_UCBVI = pickle.load(f)
with open('results/fix_3_3_v2/reg_CUCBVI.pickle', 'rb') as f:
    reg_CUCBVI = pickle.load(f)
with open('results/fix_3_3_v2/reg_fac_CUCBVI.pickle', 'rb') as f:
    reg_fac_CUCBVI = pickle.load(f)

# Assumes each array is (runs, iterations): statistics are reduced over
# axis 0 and the x axis spans the second dimension -- TODO confirm.
nn, K = reg_UCBVI.shape

mean_reg_UCBVI = np.mean(reg_UCBVI, axis=0)
mean_reg_CUCBVI = np.mean(reg_CUCBVI, axis=0)
mean_reg_fac_CUCBVI = np.mean(reg_fac_CUCBVI, axis=0)
sd_reg_UCBVI = np.std(reg_UCBVI, axis=0)
sd_reg_CUCBVI = np.std(reg_CUCBVI, axis=0)
sd_reg_fac_CUCBVI = np.std(reg_fac_CUCBVI, axis=0)

# Despite the "ci" name, the band half-width is exactly one standard deviation.
ci_UCBVI = sd_reg_UCBVI
ci_CUCBVI = sd_reg_CUCBVI
ci_fac_CUCBVI = sd_reg_fac_CUCBVI

fig, ax = plt.subplots(1, 1)
# Labels are attached to the line artists themselves. Passing a bare label
# list to ax.legend() pairs labels with artists by order, and the shaded
# fill_between bands sit between the lines, so a label could end up on a band.
ax.plot(range(K), mean_reg_UCBVI, 'k-', linewidth=2, label='UCBVI')
ax.fill_between(range(K), mean_reg_UCBVI - ci_UCBVI, mean_reg_UCBVI + ci_UCBVI,
                color='y', alpha=0.2)
ax.plot(range(K), mean_reg_CUCBVI, 'b-', linewidth=2, label='C-UCBVI')
ax.fill_between(range(K), mean_reg_CUCBVI - ci_CUCBVI, mean_reg_CUCBVI + ci_CUCBVI,
                color='y', alpha=0.2)
ax.plot(range(K), mean_reg_fac_CUCBVI, 'r-', linewidth=2, label='C-F-UCBVI')
# Both edges of this band use the C-F-UCBVI statistics; the upper edge was
# previously taken from the C-UCBVI curve by mistake.
ax.fill_between(range(K),
                mean_reg_fac_CUCBVI - ci_fac_CUCBVI,
                mean_reg_fac_CUCBVI + ci_fac_CUCBVI,
                color='y', alpha=0.2)
ax.set_xlabel('Number of Iterations', fontsize=15)
ax.set_ylabel('Cumulative Regret', fontsize=15)
# Only explicitly labelled artists (the three lines) appear in the legend.
ax.legend(loc='upper left')
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.savefig('results/fix_3_3_v2/UCBVI_CUCBVI_CFUCBVI_K3000_nn10.png', bbox_inches='tight')

# Simulation 4: X_max_v2 -- cumulative regret of UCBVI, C-UCBVI and C-F-UCBVI
# for each X_max value. The x axis is labelled 'm'.
X_max_list = [2, 3, 4, 5, 6]
with open('results/Xmax_K3000_N10_v2/reg_UCBVI.pickle', 'rb') as ucbvi_file:
    reg_UCBVI = pickle.load(ucbvi_file)
with open('results/Xmax_K3000_N10_v2/reg_CUCBVI.pickle', 'rb') as cucbvi_file:
    reg_CUCBVI = pickle.load(cucbvi_file)
with open('results/Xmax_K3000_N10_v2/reg_fac_CUCBVI.pickle', 'rb') as fac_cucbvi_file:
    reg_fac_CUCBVI = pickle.load(fac_cucbvi_file)

fig, ax = plt.subplots()
# Plot order matters: the legend labels below are matched to lines by order.
for regret_values, line_format in (
    (reg_UCBVI, 'k-o'),
    (reg_CUCBVI, 'b-o'),
    (reg_fac_CUCBVI, 'r-o'),
):
    ax.plot(X_max_list, regret_values, line_format, linewidth=2)
# Keep the major x ticks on integer positions.
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.set_xlabel('m', fontsize=15)
ax.set_ylabel('Cumulative Regret', fontsize=15)
# No axes title is drawn for this figure.
ax.legend(['UCBVI', 'C-UCBVI', 'C-F-UCBVI'], loc='upper left', fontsize=10)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.savefig('results/Xmax_K3000_N10_v2/Xmax_comparison.png', bbox_inches='tight')

# Simulation 5: ds -- cumulative regret of C-UCBVI and C-F-UCBVI for each
# state dimension in d_s_list. Plain UCBVI is neither loaded nor plotted here.
d_s_list = [1, 2, 3, 4, 5, 6]
with open('results/ds_K5000_N10/reg_CUCBVI.pickle', 'rb') as cucbvi_file:
    reg_CUCBVI = pickle.load(cucbvi_file)
with open('results/ds_K5000_N10/reg_fac_CUCBVI.pickle', 'rb') as fac_cucbvi_file:
    reg_fac_CUCBVI = pickle.load(fac_cucbvi_file)

fig, ax = plt.subplots()
# Plot order matters: the legend labels below are matched to lines by order.
for regret_values, line_format in ((reg_CUCBVI, 'b-o'), (reg_fac_CUCBVI, 'r-o')):
    ax.plot(d_s_list, regret_values, line_format, linewidth=2)
# Keep the major x ticks on integer positions.
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.set_xlabel('Dimension of States', fontsize=15)
ax.set_ylabel('Cumulative Regret', fontsize=15)
# No axes title is drawn for this figure.
ax.legend(['C-UCBVI', 'C-F-UCBVI'], loc='upper left', fontsize=10)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.savefig('results/ds_K5000_N10/ds_comparison.png', bbox_inches='tight')

# Simulation 6: X_max_4algo -- cumulative regret of UCBVI, C-UCBVI, F-UCBVI
# and CF-UCBVI for each X_max value. The x axis is labelled 'm'.
X_max_list = [2, 3, 4, 5, 6]
with open('results/Xmax_H2_K5000_N10_4algo/reg_UCBVI.pickle', 'rb') as ucbvi_file:
    reg_UCBVI = pickle.load(ucbvi_file)
with open('results/Xmax_H2_K5000_N10_4algo/reg_CUCBVI.pickle', 'rb') as cucbvi_file:
    reg_CUCBVI = pickle.load(cucbvi_file)
with open('results/Xmax_H2_K5000_N10_4algo/reg_fac_UCBVI.pickle', 'rb') as fac_ucbvi_file:
    reg_fac_UCBVI = pickle.load(fac_ucbvi_file)
with open('results/Xmax_H2_K5000_N10_4algo/reg_fac_CUCBVI.pickle', 'rb') as fac_cucbvi_file:
    reg_fac_CUCBVI = pickle.load(fac_cucbvi_file)

fig, ax = plt.subplots()
# Plot order matters: the legend labels below are matched to lines by order.
for regret_values, line_format in (
    (reg_UCBVI, 'k-o'),
    (reg_CUCBVI, 'b-o'),
    (reg_fac_UCBVI, 'g-o'),
    (reg_fac_CUCBVI, 'r-o'),
):
    ax.plot(X_max_list, regret_values, line_format, linewidth=2)
# Keep the major x ticks on integer positions.
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.set_xlabel('m', fontsize=15)
ax.set_ylabel('Cumulative Regret', fontsize=15)
# No axes title is drawn for this figure.
ax.legend(['UCBVI', 'C-UCBVI', 'F-UCBVI', 'CF-UCBVI'], loc='upper left', fontsize=10)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.savefig('results/Xmax_H2_K5000_N10_4algo/Xmax_comparison.png', bbox_inches='tight')

# Simulation 7: fix_3_3_4algo -- UCBVI, C-UCBVI, F-UCBVI and CF-UCBVI.
# Plots the mean cumulative-regret curve of each algorithm with a shaded
# band of +/- one standard deviation around it.
# NOTE: pickle.load executes arbitrary code from the file; only point these
# paths at result files produced by your own runs.
with open('results/fix_3_3_4algo/reg_UCBVI.pickle', 'rb') as f:
    reg_UCBVI = pickle.load(f)
with open('results/fix_3_3_4algo/reg_CUCBVI.pickle', 'rb') as f:
    reg_CUCBVI = pickle.load(f)
with open('results/fix_3_3_4algo/reg_fac_UCBVI.pickle', 'rb') as f:
    reg_fac_UCBVI = pickle.load(f)
with open('results/fix_3_3_4algo/reg_fac_CUCBVI.pickle', 'rb') as f:
    reg_fac_CUCBVI = pickle.load(f)

# Assumes each array is (runs, episodes): statistics are reduced over
# axis 0 and the x axis spans the second dimension -- TODO confirm.
nn, K = reg_UCBVI.shape

mean_reg_UCBVI = np.mean(reg_UCBVI, axis=0)
mean_reg_CUCBVI = np.mean(reg_CUCBVI, axis=0)
mean_reg_fac_UCBVI = np.mean(reg_fac_UCBVI, axis=0)
mean_reg_fac_CUCBVI = np.mean(reg_fac_CUCBVI, axis=0)
sd_reg_UCBVI = np.std(reg_UCBVI, axis=0)
sd_reg_CUCBVI = np.std(reg_CUCBVI, axis=0)
sd_reg_fac_UCBVI = np.std(reg_fac_UCBVI, axis=0)
sd_reg_fac_CUCBVI = np.std(reg_fac_CUCBVI, axis=0)

# Despite the "ci" name, the band half-width is exactly one standard deviation.
ci_UCBVI = sd_reg_UCBVI
ci_CUCBVI = sd_reg_CUCBVI
ci_fac_UCBVI = sd_reg_fac_UCBVI
ci_fac_CUCBVI = sd_reg_fac_CUCBVI

fig, ax = plt.subplots(1, 1)
# Labels are attached to the line artists themselves. Passing a bare label
# list to ax.legend() pairs labels with artists by order, and the shaded
# fill_between bands sit between the lines, so a label could end up on a band.
ax.plot(range(K), mean_reg_UCBVI, 'k-', linewidth=2, label='UCBVI')
ax.fill_between(range(K), mean_reg_UCBVI - ci_UCBVI, mean_reg_UCBVI + ci_UCBVI,
                color='y', alpha=0.2)
ax.plot(range(K), mean_reg_CUCBVI, 'b-', linewidth=2, label='C-UCBVI')
ax.fill_between(range(K), mean_reg_CUCBVI - ci_CUCBVI, mean_reg_CUCBVI + ci_CUCBVI,
                color='y', alpha=0.2)
ax.plot(range(K), mean_reg_fac_UCBVI, 'g-', linewidth=2, label='F-UCBVI')
ax.fill_between(range(K),
                mean_reg_fac_UCBVI - ci_fac_UCBVI,
                mean_reg_fac_UCBVI + ci_fac_UCBVI,
                color='y', alpha=0.2)
ax.plot(range(K), mean_reg_fac_CUCBVI, 'r-', linewidth=2, label='CF-UCBVI')
ax.fill_between(range(K),
                mean_reg_fac_CUCBVI - ci_fac_CUCBVI,
                mean_reg_fac_CUCBVI + ci_fac_CUCBVI,
                color='y', alpha=0.2)
ax.set_xlabel('Number of Episodes', fontsize=15)
ax.set_ylabel('Cumulative Regret', fontsize=15)
# Only explicitly labelled artists (the four lines) appear in the legend.
ax.legend(loc='upper left')
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.savefig('results/fix_3_3_4algo/UCBVI_CUCBVI_FUCBVI_CFUCBVI_K5000_nn10.png', bbox_inches='tight')

# Simulation 8: ds_4algo -- cumulative regret of C-UCBVI, F-UCBVI and CF-UCBVI
# for each state dimension in d_s_list. Plain UCBVI is neither loaded nor
# plotted here.
d_s_list = [2, 3, 4, 5]
with open('results/ds_K5000_N10_4algo_adim2/reg_CUCBVI.pickle', 'rb') as cucbvi_file:
    reg_CUCBVI = pickle.load(cucbvi_file)
with open('results/ds_K5000_N10_4algo_adim2/reg_fac_UCBVI.pickle', 'rb') as fac_ucbvi_file:
    reg_fac_UCBVI = pickle.load(fac_ucbvi_file)
with open('results/ds_K5000_N10_4algo_adim2/reg_fac_CUCBVI.pickle', 'rb') as fac_cucbvi_file:
    reg_fac_CUCBVI = pickle.load(fac_cucbvi_file)

fig, ax = plt.subplots()
# The first stored entry of every series is skipped so the rest pairs up with
# d_s_list (presumably that entry is the d_s = 1 result -- TODO confirm).
# Plot order matters: the legend labels below are matched to lines by order.
for regret_values, line_format in (
    (reg_CUCBVI[1:], 'b-o'),
    (reg_fac_UCBVI[1:], 'g-o'),
    (reg_fac_CUCBVI[1:], 'r-o'),
):
    ax.plot(d_s_list, regret_values, line_format, linewidth=2)
# Keep the major x ticks on integer positions.
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.set_xlabel('Dimension of States', fontsize=15)
ax.set_ylabel('Cumulative Regret', fontsize=15)
# No axes title is drawn for this figure.
ax.legend(['C-UCBVI', 'F-UCBVI', 'CF-UCBVI'], loc='upper left', fontsize=10)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.savefig('results/ds_K5000_N10_4algo_adim2/ds_comparison.png', bbox_inches='tight')