import matplotlib.pyplot as plt
import numpy as np
import matplotlib

matplotlib.rcParams.update({'font.size': 15})
ax = plt.axes()


def _error_pct(values):
    """Return ``100 - 100 * values`` as a NumPy array.

    The y-axis label below describes the plotted quantity as an error
    percentage, so each input series is presumably a success rate in [0, 1]
    (TODO confirm against the experiment that produced the numbers).
    """
    return 100 - np.asarray(values) * 100


# Values of h on the x-axis: 0.5, 0.6, ..., 1.0 -- one point per entry in
# every series below.
x = np.arange(5, 11) / 10.
pi_star = [0.5296154, 0.56461537, 0.66423076, 0.78884614, 0.8984615, 1.0]
# SGBS results keyed by its beta parameter (as a string).
bs = {'0.01': [0.3576923, 0.4080769, 0.5003846, 0.62153846, 0.795, 1.0],
      '0.03': [0.34730768, 0.38038462, 0.46153846, 0.5896154, 0.75769234, 1.0],
      '0.1': [0.37615386, 0.39653847, 0.5080769, 0.62653846, 0.79423076, 1.0],
      '0.2': [0.32384616, 0.38346153, 0.4523077, 0.57576925, 0.7553846, 1.0],
      '0.3': [0.3419231, 0.41923076, 0.49846154, 0.60346156, 0.8046154, 1.0],
      '0.4': [0.29346153, 0.3776923, 0.45076925, 0.53846157, 0.7165385, 1.0]}
uncertainty = [0.455, 0.47346154, 0.56038463, 0.71384615, 0.87846154, 1.0]
uniform = [0.32576925, 0.3146154, 0.35038462, 0.42807692, 0.49423078, 0.55653846]
lal = [0.44384616216023765, 0.593461545308431, 0.7342308044433594, 0.8469230651855468, 0.9403846104939778, 1.0]


# Alternative data set with the same layout (disabled; kept for reference).
# x = np.arange(1, 11) / 10.
# x = x[4:]
# pi_star = [0.0664, 0.052, 0.078, 0.2091, 0.5463, 0.9674]
# bs = {'0.01': [0.0533, 0.0687, 0.105, 0.2854, 0.7042, 0.9989], '0.03': [0.0565, 0.0727, 0.1131, 0.2868, 0.7068, 0.9989],
#       '0.1': [0.0569, 0.07, 0.1124, 0.2942, 0.7163, 0.9986], '0.2': [0.0516, 0.068, 0.1091, 0.2779, 0.6914, 0.9984],
#       '0.3': [0.0408, 0.0524, 0.0813, 0.21, 0.5783, 0.9993], '0.4': [0.0241, 0.0297, 0.0445, 0.1219, 0.3479, 0.847]}
# uniform = [0.0415, 0.0252, 0.0345, 0.0771, 0.1883, 0.439]
# uncertainty = [0.0454, 0.0367, 0.0529, 0.1193, 0.3585, 0.8195]

# Keep the math in $...$ only: inside mathtext, spaces are dropped and
# "Ours" would be set in italics as if it were a product of variables.
ax.plot(x, _error_pct(pi_star), "r-", linewidth=3, label=r"$\pi_*$ (Ours)")
for beta, values in bs.items():
    # All SGBS curves share one style; only the first gets a legend entry.
    label = r"SGBS, various $\beta$" if beta == '0.01' else "_nolegend_"
    ax.plot(x, _error_pct(values), "b-", linewidth=1, label=label)
# Baseline colours use the public "CN" notation (N-th colour of the active
# property cycle) instead of advancing the private
# ``ax._get_lines.prop_cycler``, which no longer exists in Matplotlib >= 3.8.
# Lines with an explicit colour never advance the cycle, so the colours the
# old code ended up with were C1, C2 and C5 -- reproduced here.
ax.plot(x, _error_pct(uncertainty), "-", linewidth=3, label="Uncertainty", color="C1")
ax.plot(x, _error_pct(uniform), "-", linewidth=3, label="Uniform", color="C2")
ax.plot(x, _error_pct(lal), "-", linewidth=3, label="LAL", color="C5")
ax.legend()
ax.set_xlabel("h", fontsize=15)
ax.set_ylabel(r"Error (%) $\mathbb{E}_{\theta \sim \mathcal{P}_h}[\mathbb{P}_{\pi,\theta}( \widehat{z} \neq z_\star(\theta) )]$", fontsize=15)
plt.show()

# x = np.arange(1, 11) / 10.
# x = x[4:]
# pi_star = [3.89, 4.66, 3.63, 1.39, 0.03, 0.0]
# bs = {'0.01': [3.25, 4.34, 3.14, 2.02, 0.0, 0.0], '0.03': [3.5, 3.32, 3.69, 2.84, 1.05, 0.0],
#       '0.1': [3.64, 3.45, 5.0, 2.79, 0.59, 0.0], '0.2': [3.29, 3.15, 4.62, 2.81, 0.14, 0.0],
#       '0.3': [3.04, 4.1, 4.73, 3.16, 0.17, 0.0], '0.4': [3.66, 3.63, 3.26, 1.47, 0.58, 0.0]}
# uniform = [4.23, 5.15, 5.49, 4.95, 1.97, 1.05]
# uncertainty = [3.75, 4.88, 2.68, 1.42, 0.17, 0.0]
#
# plt.plot(x, np.array(pi_star), "r-", linewidth=3, label=r"$\pi_* (Ours)$")
# for beta in bs:
#     if beta == '0.01':
#         plt.plot(x, np.array(bs[beta]), "b-", linewidth=1, label=r"SGBS, various $\beta$")
#     else:
#         plt.plot(x, np.array(bs[beta]), "b-", linewidth=1)
# plt.plot(x, np.array(uncertainty), "-", linewidth=1, label="Uncertainty")
# plt.plot(x, np.array(uniform), "-", linewidth=1, label="Uniform")
# plt.legend()
# plt.xlabel("h", fontsize=15)
# plt.ylabel(r"Regret $\mathbb{E}_{\theta \sim \mathcal{P}_h}[\langle z^\star(\theta) - \widehat{z},\theta\rangle]$", fontsize=15)
# plt.show()
