import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import os 

# Path prefix under which the output figure is written, read from the
# UDG_DATA_PATH environment variable. Indexing os.environ raises KeyError
# right here if that variable is not set.
path_prefix = os.environ['UDG_DATA_PATH']

# Raw scores for the three plotted series (index 0, 1, 2). Each table is
# 3 rows x 4 columns: the rows are averaged together further down, and each
# column is paired with the entry of the matching ref_* list at the same index.
results_0 = [
    [5845.45, 7567.43, 8080.22, 10793.09],
    [5717.46, 7570.74, 8970.39, 11185.47],
    [5765.05, 7850.39, 8363.73, 10442.51],
]
results_1 = [
    [5598.59, 7301.39, 9286.69, 10578.04],
    [4521.38, 6433.97, 9949.53, 11063.28],
    [5997.72, 7670.18, 9634.41, 10949.07],
]
results_2 = [
    [5643.56, 7923.82, 8262.69, 9904.57],
    [5602.15, 7039.06, 9138.50, 10724.14],
    [5583.30, 7035.37, 8673.08, 9979.58],
]

# Reference scores used as the x coordinates of each series, one per column.
ref_0 = [1740.74, 4252.46, 6719.85, 9095.80]
ref_1 = [1368.66, 3834.41, 6678.89, 9223.15]
ref_2 = [1393.13, 4400.04, 6716.32, 9286.48]

# Raw-score bounds that the rescaling maps to 0.0 and 1.0 respectively.
REF_MIN = -280.17895
REF_MAX = 12135.0

def normalize(data, ref_min, ref_max):
    """Linearly rescale ``data`` so that ``ref_min`` maps to 0 and ``ref_max`` to 1.

    Values outside ``[ref_min, ref_max]`` are not clipped; they simply land
    outside ``[0, 1]``.

    Args:
        data: A scalar, a NumPy array, or a (nested) list/tuple of numbers.
            Lists and tuples are converted to a float array first, because
            plain Python sequences do not support element-wise arithmetic.
        ref_min: Raw value that should map to 0.
        ref_max: Raw value that should map to 1. No check is made that it
            differs from ``ref_min``; an empty range divides by zero.

    Returns:
        The rescaled value(s): an array for array/list/tuple input, a scalar
        for scalar input.
    """
    if isinstance(data, (list, tuple)):
        data = np.asarray(data, dtype=float)
    return (data - ref_min) / (ref_max - ref_min)

# Stack the per-series tables: `results` has one 2-D table per series and
# `refs` has one row of reference scores per series.
results = np.array([results_0, results_1, results_2])
refs = np.array([ref_0, ref_1, ref_2])

# Rescale raw scores so that REF_MIN maps to 0 and REF_MAX maps to 1.
results = normalize(results, REF_MIN, REF_MAX)
refs = normalize(refs, REF_MIN, REF_MAX)

# Per-series statistics across the rows of each table (axis 1).
mean = np.mean(results, axis=1)
# The standard deviation is taken on the already-normalised scores. A spread
# is unaffected by the offset and only scales with the range, so it must NOT
# be passed through normalize(): doing so would subtract REF_MIN from it and,
# with a negative REF_MIN, inflate every error band by a constant amount.
std = np.std(results, axis=1)

# Per-series styling and legend text, listed in the same order as the rows of
# refs / mean / std.
ref_colors = mcolors.TABLEAU_COLORS
colors = [ref_colors['tab:orange'], ref_colors['tab:red'], ref_colors['tab:blue']]
markers = ['o', '^', 's']
labels = ('multiple: diverse', 'multiple: random', 'single')

# Square figure with both axes fixed to the normalised [0, 1] range.
fig, ax = plt.subplots(figsize=(9, 9))
ax.set_ylim([0.0, 1.0])
ax.set_xlim([0.0, 1.0])
ax.set_xlabel('dataset mean score', fontsize=17)
ax.set_ylabel('offline training mean score', fontsize=17)
ax.tick_params(labelsize=17)
ax.set_aspect('equal', adjustable='box')

for x, y, err, color, marker, label in zip(refs, mean, std, colors, markers, labels):
    # The label is attached to the line itself so the legend entry cannot
    # drift out of sync with the artist it describes.
    ax.plot(x, y, color=color, marker=marker, markersize=10, label=label)
    # Shaded band of one standard deviation either side of the mean; the
    # underscore-prefixed label keeps the band out of the legend.
    ax.fill_between(x, y + err, y - err, facecolor=color, alpha=0.25, label='_nolegend_')

# Legend is built from the labelled lines only; loc is given as the (x, y)
# position of the legend's lower-left corner in axes coordinates.
legend = ax.legend(loc=[0.5, 0.1],
                   labelspacing=0.1, fontsize=17)

# Build the output path with os.path.join rather than string concatenation so
# the result is correct whether or not path_prefix ends with a path separator
# (a prefix that already ends with one yields exactly the same path as before).
fig_path = os.path.join(path_prefix, 'mopo-local/fig/halfcheetah.png')
# Writes the current figure; the target directory must already exist.
plt.savefig(fig_path)
print("Figure plotting complete!")