import numpy as np

from algorithms import *
from utils import *
import warnings
import matplotlib

# Use font type 42 (TrueType, per the matplotlib rcParams documentation) for
# both the PDF and the PostScript backends.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

# Install an 'ignore' filter so warnings raised during the runs are not shown.
warnings.simplefilter(action='ignore')

# Problem configuration shared by every run in this script.
d = 10  # argument of ConvexSmooth(...); also scales the alpha and h schedules
l = 10  # `l` argument of the multi-direction optimizers; also scales their alpha

seed = 121212  # the same seed is handed to every optimizer

# The objective `conv_smooth` is constructed once, in the SMOOTH SETTING
# section below; a second identical construction here was redundant.

# SMOOTH SETTING

# Objective used by every run below, and the constant it reports via get_L().
conv_smooth = ConvexSmooth(d)
L = conv_smooth.get_L()


# Schedules passed to the optimizers as `alpha=` (presumably the step size --
# confirm in `algorithms`). Each takes the iteration index `t` and ignores it,
# returning a constant proportional to 1/L.

def alpha_single_gaus(t):
    """Return the constant 0.11 * (1/d) * (1/L), independent of `t`."""
    return 0.11 * (1/d) * (1/L)


def alpha_single_sph(t):
    """Return the constant 0.99 * (1/d) * (1/L), independent of `t`."""
    return 0.99 * (1/d) * (1/L)


def alpha_multi(t):
    """Return the constant 0.99 * (l/d) * (1/L), independent of `t`."""
    return 0.99 * (l/d) * (1/L)


def alpha_multi_gaus(t):
    """Return the constant 0.11 * (l/d) * (1/L), independent of `t`."""
    return 0.11 * (l/d) * (1/L)


def h(t):
    """Return (1/d**2) * 1e-7 / (t + 1), which shrinks as `t` grows.

    Passed to the optimizers as `h=` (presumably the finite-difference
    discretization parameter -- confirm in `algorithms`).
    """
    return (1/d**2) * 1e-7/(t + 1)


# Run budget: both values are forwarded to every optimize(...) call below.
T = 1000
reps = 10

# Common starting point: the all-ones float64 vector of length d.
x0 = np.ones(d, dtype=np.float64)

# Gaussian variants: l=1 and l=l, each with its own alpha schedule.
cgaus = RandomCenterGaussian(l=1, alpha=alpha_single_gaus, h=h, seed=seed)
cgausm = RandomCenterGaussian(l=l, alpha=alpha_multi_gaus, h=h, seed=seed)

# Spherical variants: l=1 and l=l.
csph = RandomCenteredSphere(l=1, alpha=alpha_single_sph, h=h, seed=seed)
csphm = RandomCenteredSphere(l=l, alpha=alpha_multi, h=h, seed=seed)

# Structured variant; shares the alpha schedule of the multi spherical one.
our = RandomCenteredStructured(l=l, alpha=alpha_multi, h=h, seed=seed)




def _optimize_smooth(method, f_star):
    """Run `method.optimize` on `conv_smooth` starting from a copy of `x0`.

    `T` and `reps` come from the module-level settings; `f_star` is forwarded
    unchanged. Returns whatever `method.optimize` returns.
    """
    return method.optimize(conv_smooth, x0=x0.copy(), f_star=f_star, T=T, reps=reps)


# NOTE(review): f_star is the int 0 for some runs and the float 0.0 for
# others; the literals are kept exactly as they were -- confirm whether the
# difference matters inside optimize().
ris_csph = _optimize_smooth(csph, 0)
ris_csphm = _optimize_smooth(csphm, 0.0)

ris_cgaus = _optimize_smooth(cgaus, 0)
ris_cgausm = _optimize_smooth(cgausm, 0.0)

ris_our = _optimize_smooth(our, 0)

# Post-process each raw run. The second argument to process_result is the `l`
# the corresponding optimizer was built with: 1 for the single-direction runs,
# l for the others.
ris_cgaus, ris_csph = (
    process_result(raw, 1, T) for raw in (ris_cgaus, ris_csph)
)
ris_cgausm, ris_csphm, ris_our = (
    process_result(raw, l, T) for raw in (ris_cgausm, ris_csphm, ris_our)
)

# Legend entries handed to plot_results together with `means` and `stds`;
# they are listed in the same order as those two series lists.
labels = [
    'Single Gaussian',
    'Single Spherical',
    'Multi Gaussian',
    'Multi Spherical',
    'Ours',
]

# NOTE(review): the four (mu, std) pairs below are not read again in the
# visible part of this script, and get_fvals is applied here to values that
# process_result has already returned -- confirm these calls are still needed.
csph_mu, csph_std = get_fvals(ris_csph, 1)
csphm_mu, csphm_std = get_fvals(ris_csphm, 1)

cgaus_mu, cgaus_std = get_fvals(ris_cgaus, 1)
cgausm_mu, cgausm_std = get_fvals(ris_cgausm, 1)

# Series handed to the plot: element 0 of each processed result goes to
# `means`, element 1 to `stds`. Order: single Gaussian, single spherical,
# multi Gaussian, multi spherical, ours.
_plot_series = (ris_cgaus, ris_csph, ris_cgausm, ris_csphm, ris_our)
means = [series[0] for series in _plot_series]
stds = [series[1] for series in _plot_series]


# Plot all series on one figure; `out_file` names the PDF path handed to
# plot_results (relative to the current working directory).
plot_results(
    'Smooth Convex Target',
    means,
    stds,
    labels,
    "$f(x_k) - f(x^*)$",
    legend=True,
    out_file="./conv_smooth_comp.pdf",
)