import numpy as np
import matplotlib.pyplot as plt
import matplotlib

# Default font size for every figure produced by this script.
matplotlib.rcParams["font.size"] = 15

# Earlier result sets, kept commented out for reference.
# NOTE(review): in the old `uniform` row, 29.22 breaks an otherwise decreasing
# sequence -- possibly a typo; confirm before reusing that dataset.
# baseline =    100 - np.array([94.62, 85.05, 72.23, 59.01, 48.45, 37.62, 30.68, 24.22, 20.26])
# pi_star =     100 - np.array([92.90, 82.26, 69.56, 56.81, 45.59, 35.66, 27.41, 22.82, 17.34])
# uncertainty = 100 - np.array([72.05, 63.67, 54.50, 43.24, 35.54, 28.14, 22.73, 18.43, 14.87])
# uniform =     100 - np.array([43.28, 35.67, 28.99, 23.19, 29.22, 15.17, 13.07, 11.09,  9.65])
# gbs =         100 - np.array([ 1.76,  1.76,  1.07,  1.07,  1.07, 1.00,   1.00,  1.00,  1.00])
# x = 2. ** np.array([3, 3.5, 4, 4.5, 5, 5.5, 6, 6.5, 7])

# pi_star =     100 - np.array([20.86, 12.88,  7.36,  4.06])

# baseline =    100 - np.array([27.73, 13.87,  7.91,  4.76])
# pi_star =     100 - np.array([23.81, 13.32,  6.77,  4.26])
# uncertainty = 100 - np.array([27.84, 15.71,  8.76,  4.21])
# uniform =     100 - np.array([ 9.81,  5.54,  2.54,  1.24])
# gbs =         100 - np.array([0, 0, 0, 0])
# x = 2. ** np.array([3.5, 4, 4.5, 5])

# Current data: worst-case regret of each method, one value per r in `x`.
baseline = np.array([2.065, 4.888, 6.041, 6.215, 6.433])
pi_star = np.array([2.102, 4.935, 6.073, 6.303, 6.669])
uncertainty = np.array([6.827, 7.243, 7.570, 7.580, 7.899])
uniform = np.array([4.655, 7.036, 7.267, 7.335, 7.335])
# Presumably a placeholder for an off-scale result: the regret figure labels
# its top tick (value 10) as ">40".  TODO confirm.
gbs = np.full(5, 10)
# r = 8, 16, 32, 64, 128.
x = 2.0 ** np.arange(3, 8)


# Figure 1: worst-case regret as a function of r, one line per method.
ax = plt.axes()
# Colour source for the loop below.  The private ``ax._get_lines.prop_cycler``
# read here previously was removed in Matplotlib 3.8; calling the public
# ``axes.prop_cycle`` Cycler returns an endless iterator over the same colours.
# The two plots drawn before the loop pass explicit colours ('r', 'b'), which
# do not advance the default cycle, so the loop still starts at its first colour.
ax_color_cycle = matplotlib.rcParams['axes.prop_cycle']()

plt.plot(x, pi_star, 'r-', linewidth=3)
# `gbs` is drawn twice (here as SGBS, and last in the loop as LAL); the small
# offset presumably keeps the two overlapping lines distinguishable.
plt.plot(x, gbs + .15, 'b-', linewidth=3)

for lst in [uncertainty, uniform, baseline, gbs]:
    color = next(ax_color_cycle)['color']
    # Skip the default cycle's blue and red, which would clash with the
    # 'b' and 'r' lines drawn above.
    if color == "#1f77b4" or color == "#d62728":
        color = next(ax_color_cycle)['color']
    print(color)
    plt.plot(x, lst, "--" if lst is baseline else "-", color=color, linewidth=3)
# Legend entries follow plot order: pi_star, gbs + .15, then the loop's lines.
plt.legend([r"$\widehat{\pi}$ (Ours)", r"SGBS, various $\beta$", "Uncertainty", "Uniform", "r-dependent baseline", "LAL"])
plt.xlabel("r", fontsize=15)

# Alternative y label for the commented-out error-rate datasets:
# plt.ylabel(r"Error (%) $\sup_{\theta : \widehat{\rho}(\theta) \leq r} \mathbb{P}_{\pi,\theta}( \widehat{z} \neq z_*(\theta))$",
#            fontsize=15)

# Pin the tick positions before relabelling them.  Relabelling auto-located
# ticks by position only lined up when the locator happened to pick exactly
# 1, 2, ..., 10 (and Matplotlib warns about a FixedFormatter without a
# FixedLocator).  8 and 9 stay unlabelled; 10, where `gbs` is drawn, reads ">40".
ax.set_yticks(np.arange(2, 11))
ax.set_yticklabels(["2", "3", "4", "5", "6", "7", "", "", ">40"])
plt.ylabel(r"Regret $\sup_{\theta : \widehat{\rho}(\theta) \leq r}\;\langle z^\star(\theta) - \widehat{z},\theta\rangle$",
           fontsize=15)

plt.tight_layout()
plt.show()


# rho_star = [88.11, 80.43, 69.46, 57.29, 48.57, 37.28]
# pi_up = [88.53, 79.57, 71.75, 61.89, 47.38]
# pi_bottom = [84.77, 72.99, 65.26, 51.85, 41.79]

# Same font size as set at the top of the file; presumably repeated so this
# section can also be run on its own.
matplotlib.rcParams["font.size"] = 15
# Older per-design numbers, kept commented out for reference.
# pi_star = np.array([84.50, 72.85, 59.32, 49.04, 37.86, 30.29, 23.64, 18.50, 15.04])
# pi1 =     np.array([88.53, 67.17, 50.84, 37.65, 28.00, 20.48, 15.41, 11.54,  9.32])
# pi2 =     np.array([87.74, 77.08, 61.37, 47.88, 36.42, 27.86, 20.43, 17.18, 12.90])
# pi3 =     np.array([86.83, 76.27, 63.05, 49.15, 37.51, 29.15, 22.23, 16.03, 13.62])
# pi4 =     np.array([84.45, 72.60, 62.76, 50.88, 40.68, 29.25, 22.26, 18.13, 13.98])
# pi5 =     np.array([81.78, 71.23, 60.86, 52.23, 42.51, 32.13, 25.40, 20.45, 16.28])
# pi6 =     np.array([80.12, 67.28, 58.25, 49.08, 40.98, 34.19, 26.56, 21.64, 16.74])
# pi7 =     np.array([77.46, 66.99, 54.73, 46.41, 38.60, 32.18, 25.48, 20.71, 17.52])
# pi8 =     np.array([76.06, 64.80, 53.30, 43.58, 35.60, 29.85, 24.30, 19.63, 16.49])
# pi9 =     np.array([75.89, 63.06, 53.11, 43.13, 35.30, 28.43, 23.69, 20.31, 16.82])


def _error_pct(fractions):
    """Return ``100 - 100 * fractions`` as a float array (fractions -> percent)."""
    return 100 - np.asarray(fractions) * 100


# Worst-case error (%) of pi-hat at each r in `x`; the raw numbers are percents.
pi_star = 100 - np.asarray([92.90, 82.26, 69.56, 56.81, 45.59, 35.66, 27.41, 22.82, 17.34])
# Same quantity for the fixed designs pi_1 .. pi_9; the raw numbers are fractions.
pi1 = _error_pct([0.9462, 0.7913, 0.6105, 0.4502, 0.3317, 0.2481, 0.1712, 0.1353, 0.1161])
pi2 = _error_pct([0.9451, 0.8505, 0.7093, 0.5452, 0.4061, 0.3195, 0.2474, 0.1942, 0.1553])
pi3 = _error_pct([0.9245, 0.8412, 0.7223, 0.5594, 0.4427, 0.3370, 0.2590, 0.1889, 0.1610])
pi4 = _error_pct([0.8947, 0.7992, 0.6916, 0.5863, 0.4598, 0.3528, 0.2672, 0.2148, 0.1706])
pi5 = _error_pct([0.8727, 0.7617, 0.6780, 0.5901, 0.4845, 0.3655, 0.2869, 0.2340, 0.1828])
pi6 = _error_pct([0.8434, 0.7178, 0.6323, 0.5634, 0.4753, 0.3762, 0.3005, 0.2422, 0.2002])
pi7 = _error_pct([0.8518, 0.7372, 0.6235, 0.5246, 0.4422, 0.3669, 0.3068, 0.2412, 0.2026])
pi8 = _error_pct([0.8460, 0.7153, 0.6216, 0.5141, 0.4275, 0.3511, 0.2816, 0.2359, 0.2003])
pi9 = _error_pct([0.8514, 0.7193, 0.5983, 0.4947, 0.4060, 0.3287, 0.2780, 0.2300, 0.1919])
pis = [pi1, pi2, pi3, pi4, pi5, pi6, pi7, pi8, pi9]
# r = 2**3, 2**3.5, ..., 2**7 (exponents 3 to 7 in steps of 1/2).
x = 2.0 ** (np.arange(6, 15) / 2)
# Colour ramp from red (pi_1) to blue (pi_9).
colors = [
    "#EE0000", "#E00020", "#C00040", "#A00060", "#800080",
    "#6000A0", "#4000C0", "#2000E0", "#0000EE",
]
# Alternative 4-point dataset, kept commented out for reference:
# pi_star = 100 - np.array([20.86, 12.88,  7.36,  4.06])
# pi1 =     100 - np.array([0.2773, 0.1234, 0.0642, 0.0366]) * 100
# pi2 =     100 - np.array([0.2381, 0.1332, 0.0677, 0.0426]) * 100
# pi3 =     100 - np.array([0.2267, 0.1387, 0.0791, 0.0409]) * 100
# pi4 =     100 - np.array([0.2145, 0.1192, 0.0748, 0.0476]) * 100
# pis = [pi1, pi2, pi3, pi4]
# x = 2. ** np.array([3.5, 4, 4.5, 5])
# colors = ["#EE0000", "#C00040", "#4000C0", "#0000EE"]

# Index selector applied to every series before plotting; currently all points.
ind = range(len(x))
# Figure 2: worst-case error of each fixed design pi_k, and of pi-hat, vs. r.
for i, (pi, color) in enumerate(zip(pis, colors)):
    # Raw string: "\p" is an invalid escape in a plain literal (SyntaxWarning
    # since Python 3.12).  Braces keep multi-digit subscripts together.
    plt.plot(x[ind], pi[ind], '-', label=r"$\pi_{%d}$" % (i + 1), linewidth=1, color=color)
    print(pi[ind])
# "(Ours)" sits outside math mode so it renders as plain text, as in the
# regret figure's legend.
plt.plot(x[ind], pi_star[ind], 'r-', label=r"$\widehat{\pi}$ (Ours)", linewidth=3)
print(pi_star[ind])
plt.xlabel("r", fontsize=15)
plt.ylabel(r"Error (%) $\sup_{\theta : \rho_*(\theta) \leq r} \mathbb{P}_{\pi,\theta}( \widehat{z} \neq z_*(\theta))$",
           fontsize=15)
plt.legend(loc='lower right', ncol=2, fancybox=True, shadow=True)
plt.show()

# Figure 3: gap between each design's worst-case error and the pointwise best
# error over the fixed designs in `pis`, at every r.
f_star = np.min(np.array(pis), axis=0)
for i, (pi, color) in enumerate(zip(pis, colors)):
    gap = (pi - f_star)[ind]
    # Raw string avoids the invalid "\p" escape; braces keep multi-digit
    # subscripts together.
    plt.plot(x[ind], gap, '-', label=r"$\pi_{%d}$" % (i + 1), linewidth=1, color=color)
    # Print the series that is actually plotted here.
    print(gap)
pi_star_gap = (pi_star - f_star)[ind]
plt.plot(x[ind], pi_star_gap, 'r-', label=r"$\widehat{\pi}$ (Ours)", linewidth=3)
print(pi_star_gap)
plt.xlabel("r", fontsize=15)
plt.ylabel(r"Gap (%) $\sup_{\theta : \widehat{\rho}(\theta) \leq r} \mathbb{P}_{\pi,\theta}( \widehat{z} \neq z_*(\theta)) - \ell(\pi_{k(\theta)}, \Theta^{r(\theta)})$",
           fontsize=15)
plt.ylim([-1, 25])
plt.legend(loc='upper right', ncol=3, fancybox=True, shadow=True)
plt.show()
