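# Effect of the aggregation probability p on the number of aggregation
# (communication) rounds needed to reach a fixed target accuracy; omega is
# held fixed for each dataset.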

from algorithms import *

# When True, the training sweep is skipped and previously saved results are
# only reloaded and plotted.
plot_only = True

# Line / marker / colour specs. They are not referenced in the code of this
# file; presumably kept for plotting — TODO confirm.
linestyles = [
    '-', '-.', '--', ':', '-.', '-',
    '-', '-.', '--', ':', '-.', '-',
]
markers = ['o', '*', 'd', 'v', 'P', '1', 'p', 'X']
colors = [
    'tab:blue', 'tab:red', 'tab:green', 'tab:brown',
    'tab:purple', 'tab:gray', 'tab:olive', 'tab:cyan',
]

# Experiment settings (mu is passed to the loss constructor and step-size rule;
# experiment/pwork/method also identify the result files on disk).
mu = 1e-4
experiment = "commp5"
pwork = 1.0
method = "SAGA"

datasets = ["a1a"]

# Number of parts T per dataset: each dataset is split into T blocks of n / T rows.
T_dict = {
    "a1a": 5,
    "mushrooms": 12,
    "w2a": 10,
    "phishing": 11,
    "duke": 4,
    "madelon": 50,
    "gisette_scale": 100,
    "w7a": 4,
    "a8a": 8,
}

# Geometric grid of num + 1 aggregation probabilities from 0.01 to 0.01 * 50 = 0.5.
num = 30
c = 50.0 ** (1 / num)  # common ratio of the grid
ps = [0.01 * c**i for i in range(num + 1)]




# For every dataset: (optionally) run the p-sweep and save the number of
# aggregations needed to reach a target accuracy, then reload and plot them.
for dataset in datasets:
    print("#############################")
    print(dataset)
    print("#############################")

    T = T_dict[dataset]

    lam = 0.1
    omega = lam / T

    comm_list = []

    A, b = get_data(dataset)
    A = normalize_data(A)
    A, b = rearrange_data(A, method="random", b=b)

    n, d = A.shape
    # The data are split into T equal blocks of m rows; datasets whose size is
    # not divisible by T are skipped (neither trained nor plotted).
    m, remainder = divmod(n, T)
    if remainder != 0:
        continue
    K = 1000 * m  # iteration budget of the reference run (== 1000 * n / T)
    skip_it = K // 1000

    def pagg_star(lam):
        """Return the closed-form aggregation probability for the given lam.

        Presumably the theoretically optimal p from the analysis — TODO confirm.
        Uses m and mu from the enclosing scope.
        """
        return (4 * lam + mu) / (4 * lam + 4 + (m + 1) * mu)

    if not plot_only:
        f, g = make_fg_logreg(A, b, mu)

        def total_loss(X):
            """Evaluate the full objective at the stacked iterate X."""
            return objective(f, psi_func, X, m, omega)

        # Reference run at p = pagg_star(lam) to estimate the optimal value.
        p_ref = pagg_star(lam)
        alpha = get_stepsize_saga(v=np.ones(n), p=np.ones(n) / m, pagg=p_ref, pwork=pwork,
                                  omega=omega, n=n, T=T, mu=mu)
        # NOTE(review): no RNG seed is set in this file, so the starting point
        # (shared by all runs below) differs between executions.
        X_start = np.random.randn(T, d)
        FF, _, _ = vr_lgd(total_loss, g, alpha, d, K, T, m, p_ref, omega, method=method,
                          psvrg=0, skip_it=skip_it, track_agg=False, X0=X_start)

        # Assumes FF holds the tracked objective values: its minimum is used as
        # a proxy for F*, and the target shrinks the initial gap by 1e5.
        Fstar = np.min(FF)
        target_acc = Fstar + (FF[0] - Fstar) / 100000.0

        for p in ps:
            print(f"p = {p}")

            alpha = get_stepsize_saga(v=np.ones(n), p=np.ones(n) / m, pagg=p, pwork=pwork,
                                      omega=omega, n=n, T=T, mu=mu)
            # Larger budget (100 * K); target_acc is passed so the run can
            # report when the target is reached.
            _, _, (aggregations, _) = vr_lgd(total_loss, g, alpha, d, 100 * K, T, m, p, omega,
                                             method=method, psvrg=0, skip_it=skip_it,
                                             track_agg=False, target_acc=target_acc,
                                             X0=X_start)

            comm_list.append(aggregations)

        # Save (comm_list, ps) once per backup flag; createfilename presumably
        # yields a distinct path for each flag. `with` closes the file even if
        # pickle.dump raises.
        for backup in (True, False):
            filename = createfilename(experiment, dataset, T, omega, mu, pwork, method, backup)
            with open(filename, "wb") as pickle_out:
                pickle.dump((comm_list, ps), pickle_out)

    # Reload the saved results (the only source when plot_only is True). They go
    # into separate names so the global p-grid `ps` is not overwritten for the
    # next dataset's sweep.
    saved_comm_list, saved_ps = load_pickle(experiment, dataset, T, omega, mu, pwork, method)

    visualize_(np.asarray(saved_comm_list), np.asarray(saved_ps))

