import numpy as np
import warnings
import matplotlib
import matplotlib.pyplot as plt

from algorithms import *
from utils import *

# Embed TrueType (Type 42) fonts in PDF and PostScript output instead of the
# default Type 3 fonts.
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

# Suppress every Python warning for the remainder of the run.
warnings.simplefilter('ignore')

# Experiment configuration: problem dimension, RNG seed, and the numbers of
# random directions ``l`` to compare.
d = 50
seed = 12131415
num_dirs = [1, 5, 10, 20, 50]
conv_smooth = ConvexSmooth(d)
conv_nonsmooth = ConvexNonSmooth(d)
x0 = np.ones(d)
L = conv_smooth.get_L()  # presumably the gradient-Lipschitz constant — confirm in the problem class
T = 4000
reps = 10

# One (mean, std) curve per value of ``l``, as returned by get_fvals.
means_smt, stds_smt = [], []
means_nsmt, stds_nsmt = [], []

# Finite-difference discretization schedules; they do not depend on ``l``,
# so they are built once instead of on every loop iteration.
h = lambda t: 1e-5/(t + 1)
h_nsmt = lambda t: 1e-7/(t + 1)

for l in num_dirs:
    # ``l`` is bound as a default argument so each schedule keeps the value of
    # its own iteration; a plain closure would look ``l`` up at call time and
    # silently pick up a later value if the optimizer outlived this step.
    alpha = lambda t, l=l: 0.99 * l/d * 1/L
    # NOTE(review): t**(-1/2 - 1e-5) is undefined at t = 0; presumably the
    # optimizer starts counting at t = 1 — confirm in RandomCenteredStructured.
    alpha_nsmt = lambda t, l=l: np.sqrt(l/d) * (t**(-1/2 - 1e-5))

    ozd = RandomCenteredStructured(l=l, alpha=alpha, h=h, seed=seed)
    ozd_nsmt = RandomCenteredStructured(l=l, alpha=alpha_nsmt, h=h_nsmt, seed=seed)

    # Both problems start from the same x0 and use f_star = 0 as the reference optimum.
    smt_ris = ozd.optimize(conv_smooth, x0=x0, f_star=0, T=T, reps=reps)
    nsmt_ris = ozd_nsmt.optimize(conv_nonsmooth, x0=x0, f_star=0, T=T, reps=reps)

    mu_smt, std_smt = get_fvals(smt_ris, 1)
    mu_nsmt, std_nsmt = get_fvals(nsmt_ris, 1)
    means_smt.append(mu_smt)
    stds_smt.append(std_smt)
    means_nsmt.append(mu_nsmt)
    stds_nsmt.append(std_nsmt)
    print("[--] Completed for l = {}".format(l))


def _per_evaluation(values, evals_per_iter, max_evals):
    """Stretch a per-iteration curve onto a function-evaluation axis.

    Each entry of ``values`` is repeated ``evals_per_iter`` times (this script
    charges ``2 * l`` evaluations per iteration). The result is truncated to at
    most ``max_evals`` points. Truncating, instead of stopping only when the
    length hits ``max_evals`` exactly, keeps the curve consistent when
    ``evals_per_iter`` does not divide ``max_evals`` or the run is shorter.
    """
    return np.repeat(np.asarray(values), evals_per_iter)[:max_evals]


# Left panel: smooth problem; right panel: non-smooth problem.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
for n_dirs, mu_s, sd_s, mu_n, sd_n in zip(num_dirs, means_smt, stds_smt,
                                          means_nsmt, stds_nsmt):
    # Raw string: "\e" is not a valid Python escape sequence.
    label = r"$\ell = {}$".format(n_dirs)
    for ax, mu, sd in ((ax1, mu_s, sd_s), (ax2, mu_n, sd_n)):
        mean = _per_evaluation(mu, 2 * n_dirs, T)
        std = _per_evaluation(sd, 2 * n_dirs, T)
        # x-axis sized to the data, so a short run cannot cause a length mismatch.
        evals = np.arange(len(mean))
        ax.plot(evals, mean, '-', lw=3, label=label)
        # Shaded band: mean +/- the std returned by get_fvals.
        ax.fill_between(evals, mean - std, mean + std, alpha=0.5)

ax1.set_title("Smooth Convex function", fontsize=20)
ax1.set_xlabel("function evaluations", fontsize=16)
ax1.set_ylabel("$f(x_k) - f(x^*)$", fontsize=16)
ax2.set_title("NonSmooth Convex function", fontsize=20)
ax2.set_xlabel("function evaluations", fontsize=16)
ax2.set_ylabel("$f(x_k) - f(x^*)$", fontsize=16)

ax1.set_yscale("log")
ax2.set_yscale("log")

# Both panels share the same l-labels, so a single legend is enough.
ax1.legend(fontsize=14)
fig.tight_layout()
fig.savefig("./change_l.pdf", bbox_inches='tight')
    