import numpy as np

from algorithms import *
from utils import *

import matplotlib.pyplot as plt
import warnings
import matplotlib

# Use font type 42 (TrueType) for both PDF and PostScript output so that text
# in the saved figure is embedded as real fonts.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

# Suppress every warning for the remainder of the script.
warnings.simplefilter('ignore')

# Experiment constants. `d` is handed to ConvexSmooth and to the Householder
# optimizer; `l` is handed to the multi-direction optimizers
# (presumably the number of directions per iteration -- confirm in `algorithms`).
seed = 121212
d = 10
l = 10

# Objective under test and the constant L it reports via get_L().
conv_smooth = ConvexSmooth(d)
L = conv_smooth.get_L()


def alpha_single_gaus(t):
    """Constant step size 0.11 * (1/d) * (1/L); the iteration index is ignored."""
    return 0.11 * (1/d) * (1/L)


def alpha_single_sph(t):
    """Constant step size 0.99 * (1/d) * (1/L); the iteration index is ignored."""
    return 0.99 * (1/d) * (1/L)


def alpha_multi(t):
    """Constant step size 0.99 * (l/d) * (1/L); the iteration index is ignored."""
    return 0.99 * (l/d) * (1/L)


def alpha_multi_gaus(t):
    """Constant step size 0.11 * (l/d) * (1/L); the iteration index is ignored."""
    return 0.11 * (l/d) * (1/L)


def h(t):
    """Decaying schedule (1/d**2) * 1e-7 / (t + 1), passed to every optimizer as `h`."""
    return (1/d**2) * 1e-7/(t + 1)

# Optimizers under comparison. They all share the same `h` schedule and seed;
# only the direction sampler, the number of directions `l` and the step size
# differ.
h_alg = RandomCenteredHouseholder(d=d, l=l, h=h, alpha=alpha_multi, seed=seed)
ms_alg = RandomCenteredSphere(l=l, alpha=alpha_multi, h=h, seed=seed)
mg_alg = RandomCenterGaussian(l=l, alpha=alpha_multi_gaus, h=h, seed=seed)
ss_alg = RandomCenteredSphere(l=1, alpha=alpha_single_sph, h=h, seed=seed)
# Fix: this run is plotted as "Single Gaussian" and uses the Gaussian step
# size, but it was built with RandomCenteredSphere (copy of the line above).
# Use the Gaussian sampler, as mg_alg does.
sg_alg = RandomCenterGaussian(l=1, alpha=alpha_single_gaus, h=h, seed=seed)

# Common starting point (all ones) and run budget.
x0 = np.full(d, 1.0, dtype=np.float64)
T = 1000
reps = 20


def _run(alg):
    """Run `alg` on conv_smooth with the shared start point and budget.

    A fresh copy of `x0` is passed on every call so that one run cannot
    affect the starting point of the next. The reference value `f_star` is
    the objective evaluated at `conv_smooth.x_star`.
    """
    return alg.optimize(
        conv_smooth,
        x0=x0.copy(),
        f_star=conv_smooth(conv_smooth.x_star),
        T=T,
        reps=reps,
    )


ris_h = _run(h_alg)
ris_ms = _run(ms_alg)
ris_mg = _run(mg_alg)
ris_ss = _run(ss_alg)
ris_sg = _run(sg_alg)


def _time_and_value_stats(ris):
    """Return ((mean, std) for idx=2, (mean, std) for idx=1) of one run's output.

    Judging by how the outputs are used for plotting, idx=2 presumably selects
    per-iteration times and idx=1 the objective gap -- confirm against
    `process_result2`.
    """
    time_stats = process_result2(ris, l=l, T=T, idx=2)
    value_stats = process_result2(ris, l=l, T=T, idx=1)
    return time_stats, value_stats


# NOTE(review): `l=l` is also passed for the runs created with l=1 (ris_ss,
# ris_sg); verify that this is what process_result2 expects.
(ozd_mean_tm, ozd_std_tm), (ozd_mean_v, ozd_std_v) = _time_and_value_stats(ris_h)
(ms_mean_tm, ms_std_tm), (ms_mean_v, ms_std_v) = _time_and_value_stats(ris_ms)
(mg_mean_tm, mg_std_tm), (mg_mean_v, mg_std_v) = _time_and_value_stats(ris_mg)
(ss_mean_tm, ss_std_tm), (ss_mean_v, ss_std_v) = _time_and_value_stats(ris_ss)
(sg_mean_tm, sg_std_tm), (sg_mean_v, sg_std_v) = _time_and_value_stats(ris_sg)


# Legend entries, in the same order as the entries of `results` below.
labels = ['Single Gaussian', 'Single Spherical', 'Multi Gaussian', 'Multi Spherical', 'Ours']

# One (cumulative time, mean value, std of value) triple per curve.
results = [
    (np.cumsum(step_times), mean_curve, std_curve)
    for step_times, mean_curve, std_curve in (
        (sg_mean_tm, sg_mean_v, sg_std_v),
        (ss_mean_tm, ss_mean_v, ss_std_v),
        (mg_mean_tm, mg_mean_v, mg_std_v),
        (ms_mean_tm, ms_mean_v, ms_std_v),
        (ozd_mean_tm, ozd_mean_v, ozd_std_v),
    )
]

fig, ax = plt.subplots()
ax.set_title("CPU Comparison: Convex Smooth Setting", fontsize=18)

# One curve per method: the mean as a solid line, with a shaded band of
# one standard deviation on either side.
for label, (elapsed, mean_curve, std_curve) in zip(labels, results):
    lower = mean_curve - std_curve
    upper = mean_curve + std_curve
    ax.plot(elapsed, mean_curve, '-', lw=3, label=label)
    ax.fill_between(elapsed, lower, upper, alpha=0.45)

ax.set_xlabel("Cumulative time ($s$)", fontsize=16)
ax.set_ylabel("$f(x_k) - f(x^*)$", fontsize=16)
ax.set_yscale("log")
ax.legend()

fig.savefig("./cpu_comparison.pdf", bbox_inches='tight')

# Print the first mean value of the two multi-direction baselines.
print(mg_mean_v[0], ms_mean_v[0])


