import pickle
import os
import numpy as np
import matplotlib.pyplot as plt


# Directory holding the per-dimension evaluation pickles and the output figure.
DATA_DIR = 'figs/evals_epoch250'
# Problem dimensions d with a `cumulative_regret_d{d}.pkl` file; also the x-axis values.
DIMENSIONS = (10, 20, 30, 40)
# (column index into the per-method results, legend label), plotted in this order.
# Columns follow the key order of the FIRST pickle loaded; column 0 is not plotted.
PLOT_SERIES = (
    (1, 'DPT (greedy)'),
    (2, 'PreDeToR (ours)'),
    (3, 'PreDeToR-$\\tau$ (ours)'),
    (4, 'AD'),
    (5, 'Thomp'),
    (6, 'LinUCB'),
)
# Half-width of the shaded band, in multiples of the standard error.
BAND_WIDTH = 1.0


def summarize_regret(cumulative_regret):
    """Reduce each method's per-environment regret array to (mean, standard error).

    Args:
        cumulative_regret: mapping from method key to a 2-D array-like. Axis 0 is
            reduced by mean/std below (presumably environments — confirm), and
            axis 1 is summed per row.

    Returns:
        dict mapping each key (in the input's iteration order) to a tuple
        ``(mean, sem)`` of Python floats.
    """
    stats = {}
    for key, values in cumulative_regret.items():
        # NOTE(review): this sums along axis 1. If the stored arrays are already
        # running (cumulative) totals, the final cumulative regret would be
        # values[:, -1] instead — confirm against the producer of the pickle.
        totals = np.sum(values, axis=1)
        # NOTE(review): np.std defaults to ddof=0 (population std); a sample-based
        # standard error would use ddof=1. Kept as-is to preserve existing figures.
        sem = np.std(totals) / np.sqrt(len(totals))
        stats[key] = (float(np.mean(totals)), float(sem))
    return stats


def load_dimension_results(dimensions=DIMENSIONS, data_dir=DATA_DIR):
    """Load each dimension's pickle and build aligned mean / standard-error tables.

    Every file's methods are aligned to the key order of the first file loaded, so
    column j refers to the same method in every row.

    Args:
        dimensions: iterable of dimensions d; reads ``cumulative_regret_d{d}.pkl``.
        data_dir: directory containing the pickle files.

    Returns:
        ``(means, sems)``: 2-D float arrays of shape (len(dimensions), n_methods).

    Raises:
        ValueError: if a file's method keys differ from those of the first file.
    """
    mean_rows, sem_rows = [], []
    reference_keys = None
    for dim in dimensions:
        path = os.path.join(data_dir, f'cumulative_regret_d{dim}.pkl')
        # pickle.load executes arbitrary code: only point this at trusted files.
        with open(path, 'rb') as f:
            cumulative_regret, _regret_means, _regret_sems = pickle.load(f)
        stats = summarize_regret(cumulative_regret)

        if reference_keys is None:
            reference_keys = list(stats)
        elif set(stats) != set(reference_keys):
            raise ValueError(
                f'{path}: method keys {sorted(map(str, stats))} differ from '
                f'{sorted(map(str, reference_keys))} in the first file'
            )

        for key in reference_keys:
            print(key)
            print(stats[key][0])
        mean_rows.append([stats[key][0] for key in reference_keys])
        sem_rows.append([stats[key][1] for key in reference_keys])
    return np.array(mean_rows), np.array(sem_rows)


def plot_regret_vs_dimension(dimensions, regret_mean, regret_sem, out_path, k=BAND_WIDTH):
    """Plot each series in PLOT_SERIES with a +/- k*SEM band and save to out_path.

    Args:
        dimensions: x-axis values, one per row of regret_mean / regret_sem.
        regret_mean: 2-D array (dimension x method) of mean regret.
        regret_sem: 2-D array with the same shape, holding standard errors.
        out_path: file path for the saved figure.
        k: half-width of the shaded band in standard errors.
    """
    fig, ax = plt.subplots(1, 1, figsize=(15, 10))
    for column, label in PLOT_SERIES:
        mean = regret_mean[:, column]
        band = k * regret_sem[:, column]
        ax.plot(dimensions, mean, label=label, linewidth=7.0)
        ax.fill_between(dimensions, mean - band, mean + band, alpha=0.4)

    ax.set_xlabel('Dimension', fontsize=30)
    ax.set_xticks(dimensions)
    ax.set_ylabel('Cumulative Regret', fontsize=30)
    # The x-axis is problem dimension, not time.
    ax.set_title('Regret vs. Dimension', fontsize=30)
    ax.legend(fontsize=30)
    ax.tick_params(axis='both', which='major', labelsize=30)

    fig.savefig(out_path)
    # Release the figure instead of only clearing it, so it is not kept in memory.
    plt.close(fig)


def main():
    """Load all dimension results, print a summary, and write the comparison figure."""
    regret_mean, regret_sem = load_dimension_results()
    print(regret_mean, len(regret_mean))
    plot_regret_vs_dimension(
        DIMENSIONS, regret_mean, regret_sem, os.path.join(DATA_DIR, 'dim_fig.png')
    )


if __name__ == '__main__':
    main()