import pickle
import os
import numpy as np
import matplotlib.pyplot as plt


# Plot summed regret (mean +/- standard error across rows) against the number
# of attention heads, one curve per method.

# Directory that holds the per-head-count result pickles and receives the figure.
EVAL_DIR = 'figs/evals_epoch350'

# Head counts to load; also used as the x-axis of the plot, so the two can no
# longer drift apart. hd10 is deliberately absent: its loading section was
# commented out in the previous version of this script.
HEAD_COUNTS = [2, 4, 6, 8, 12]

# Half-width of the shaded band, in units of the standard error.
BAND_SCALE = 1.0

# (column index, legend label) for each plotted curve. The column index is the
# position of the method among the keys of the pickled mapping; column 0 is
# not plotted.
# NOTE(review): labels are assumed to match the key order of the pickles --
# confirm against the code that writes them.
PLOTTED_METHODS = [
    (1, 'DPT (greedy)'),
    (2, 'PreDeToR (ours)'),
    (3, 'PreDeToR-$\\tau$ (ours)'),
    (4, 'AD'),
    (5, 'Thomp'),
    (6, 'LinUCB'),
]


def load_cumulative_regret(path):
    """Return the first element of the 3-tuple pickled at *path*.

    The pickle holds ``(cumulative_regret, regret_means, regret_sems)``; only
    the first element is used by this script.
    """
    # SECURITY: pickle.load executes arbitrary code embedded in the file.
    # Only point this at result files produced by a trusted run.
    with open(path, 'rb') as f:
        cumulative_regret, _regret_means, _regret_sems = pickle.load(f)
    return cumulative_regret


def summarize_regret(cumulative_regret):
    """Reduce each entry of *cumulative_regret* to a mean and a standard error.

    For every key (in iteration order) the value is summed along axis 1, and
    the mean and standard error of those per-row sums are recorded. The key
    and its mean are printed as progress output.

    Args:
        cumulative_regret: mapping from a method key to an array with at
            least two dimensions (presumably environments x time steps --
            TODO confirm).

    Returns:
        ``(means, std_errs)``: two lists with one entry per key.
    """
    means = []
    std_errs = []
    for method in cumulative_regret:
        print(method)
        row_totals = np.sum(cumulative_regret[method], axis=1)
        mean_total = np.mean(row_totals)
        # Population std (ddof=0), unchanged from the previous version.
        # NOTE(review): a sample-based standard error would use ddof=1.
        std_err = np.std(row_totals) / np.sqrt(len(row_totals))
        print(mean_total)
        means.append(mean_total)
        std_errs.append(std_err)
    return means, std_errs


def main():
    """Load every per-head-count result file and save the comparison figure."""
    data_regret = []
    data_std_regret = []
    for heads in HEAD_COUNTS:
        pickle_path = f'{EVAL_DIR}/cumulative_regret_hd{heads}.pkl'
        means, std_errs = summarize_regret(load_cumulative_regret(pickle_path))
        data_regret.append(means)
        data_std_regret.append(std_errs)

    print(data_regret, len(data_regret))

    # Rows: head counts; columns: methods in pickle key order.
    data_regret = np.array(data_regret)
    data_std_regret = np.array(data_std_regret)

    print(data_std_regret[:, 1], data_std_regret[:, 2],
          data_std_regret[:, 3], data_std_regret[:, 4])
    print(data_regret[:, 1])

    # Create exactly one figure; the previous version created two and left
    # the first one open and unused.
    fig, ax1 = plt.subplots(1, 1, figsize=(15, 10))

    for column, label in PLOTTED_METHODS:
        mean = data_regret[:, column]
        band = BAND_SCALE * data_std_regret[:, column]
        ax1.plot(HEAD_COUNTS, mean, label=label, linewidth=7.0)
        ax1.fill_between(HEAD_COUNTS, mean - band, mean + band, alpha=0.4)

    ax1.set_xlabel('Heads', fontsize=30)
    ax1.set_xticks(HEAD_COUNTS)
    ax1.set_ylabel('Cumulative Regret', fontsize=30)
    # NOTE(review): the x-axis is the number of heads, not time -- consider
    # renaming the title.
    ax1.set_title('Regret Over Time', fontsize=30)
    ax1.set_yscale('log')
    ax1.legend(fontsize=30)
    ax1.tick_params(axis='both', which='major', labelsize=30)

    fig.savefig(f'{EVAL_DIR}/heads_fig.png')
    # Release the figure instead of merely clearing it.
    plt.close(fig)


if __name__ == "__main__":
    main()