import numpy as np

from algorithms import *
from utils import *
from other_nonsmooth_utils import *

import matplotlib
import warnings

# Embed fonts as TrueType (Type 42) instead of Type 3 in both PDF and PS
# output produced by matplotlib later in this script.
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

# Suppress every warning for the remainder of the run.
# NOTE(review): this also hides numpy RuntimeWarnings (overflow, invalid
# value, divide by zero) that could reveal a diverging step-size schedule;
# consider narrowing the filter to the specific category being silenced.
warnings.simplefilter('ignore')

# ---------------------------------------------------------------------------
# Experiment: compare single/multi-direction Gaussian and spherical
# estimators against RandomCenteredStructured ("Ours") on a SparseGroupLasso
# problem, then plot f(x_k) (mean +/- std over `reps` repetitions).
# ---------------------------------------------------------------------------

# One seed shared by the test problem and every optimizer, so a single edit
# changes the whole experiment consistently.
SEED = 121314

d = 50
t1 = SparseGroupLasso(d=d, seed=SEED)

# `l` value used by the multi-direction variants and by "Ours"; the "Single"
# variants below use l=1.
l_multi = 25
reps = 20


def alpha(k):
    """Schedule passed as ``alpha`` to the spherical and structured optimizers."""
    return 5.0 * np.sqrt(1 / k)


def alpha_gaus(k):
    """Schedule passed as ``alpha`` to the Gaussian optimizers (smaller scale)."""
    return 0.1 * np.sqrt(1 / k)


def h(k):
    """Schedule passed as ``h`` to every optimizer: 1e-3 / sqrt(k)."""
    # presumably a finite-difference / smoothing radius — TODO confirm in algorithms
    return 1e-3 / np.sqrt(k)


# Construction order is kept identical to the original script in case the
# constructors touch shared RNG state.
our = RandomCenteredStructured(l=l_multi, alpha=alpha, h=h, seed=SEED)
sing_sph = RandomCenteredSphere(l=1, alpha=alpha, h=h, seed=SEED)
multi_sph = RandomCenteredSphere(l=l_multi, alpha=alpha, h=h, seed=SEED)
multi_gaus = RandomCenterGaussian(l=l_multi, alpha=alpha_gaus, h=h, seed=SEED)
sing_gaus = RandomCenterGaussian(l=1, alpha=alpha_gaus, h=h, seed=SEED)

T = 20000
# NOTE(review): assumed optimal objective value of t1 — confirm it is 0.0
# for SparseGroupLasso.
F_STAR = 0.0


def run(method):
    """Run ``method`` on ``t1`` from ``t1.x0`` with the shared T/reps/f_star."""
    return method.optimize(t1, t1.x0, f_star=F_STAR, T=T, reps=reps)


# Execution order is kept identical to the original script (it may matter if
# the optimizers draw from a shared RNG).
sing_gaus_ris = run(sing_gaus)
mul_gaus_ris = run(multi_gaus)
sing_sph_ris = run(sing_sph)
mul_sph_ris = run(multi_sph)
our_ris = run(our)

# Each result is paired with its legend label so the two can never drift
# out of alignment; this list order is the plotting/legend order.
experiments = [
    ('Single Gaussian', sing_gaus_ris),
    ('Single Spherical', sing_sph_ris),
    ('Multi Gaussian', mul_gaus_ris),
    ('Multi Spherical', mul_sph_ris),
    ('Ours', our_ris),
]
labels = [label for label, _ in experiments]
results = [ris for _, ris in experiments]

means, stds = [], []
for ris in results:
    # NOTE(review): l_multi is passed for every method, including the l=1
    # "Single" runs — confirm process_result expects the same value for all.
    mean, std = process_result(ris, l_multi, T, idx=1)
    means.append(mean)
    stds.append(std)

plot_results(
    t1.title,
    means,
    stds,
    labels,
    "$f(x_k)$",
    legend=True,
    out_file="./{}.pdf".format(t1.name),
)