# Effect of omega on the convergence rate

from algorithms import *

# When True, the optimization runs and the save step are skipped and the
# script only loads previously pickled results and plots them.
plot_only = False

# Plot styling tables. They are not referenced by the code visible in this
# script; presumably kept for the plotting helpers / sibling experiments.
linestyles = [
    '-', '-.', '--', ':', '-.', '-',
    '-', '-.', '--', ':', '-.', '-',
]
markers = ['o', '*', 'd', 'v', 'P', '1', 'p', 'X']
colors = [
    'tab:blue',
    'tab:red',
    'tab:green',
    'tab:brown',
    'tab:purple',
    'tab:gray',
    'tab:olive',
    'tab:cyan',
]

# Experiment settings. `mu` is handed to make_fg_logreg and
# get_stepsize_saga; `experiment`, `pwork` and `method` are also part of the
# result file name built below.
mu = 1e-4
experiment = "CompareMethods"
pwork = 1.0
method = "SAGA"

# Datasets to process, and the value of T used for each known dataset.
datasets = ["a1a"]

T_dict = {
    "a1a": 5,
    "mushrooms": 12,
    "w2a": 10,
    "phishing": 11,
    "duke": 4,
    "madelon": 50,
    "gisette_scale": 100,
    "w7a": 4,
    "a8a": 8,
}

# Geometric grid of lam values: num + 1 points starting at 1e-5 with common
# ratio c, where c**num == 1e6 (up to rounding), i.e. the grid ends near 1e1.
num = 30
c = 1e6 ** (1 / num)
lams = [(c ** k) * 0.00001 for k in range(num + 1)]




# characterization of the problem
for dataset in datasets:
    # For each dataset: (optionally) compute the iterates for every lam in the
    # grid plus two reference points, pickle them, then reload the pickle and
    # plot the distances of the iterates to the two reference points.
    print("#############################")
    print(dataset)
    print("#############################")

    T = T_dict[dataset]

    A, b = get_data(dataset)
    A = normalize_data(A)
    A, b = rearrange_data(A, method="random", b=b)

    n, d = A.shape
    m = int(n / T)

    # The n rows must split into exactly T groups of m rows each.
    if m * T != n:
        print("skipping {}: n = {} is not divisible by T = {}".format(dataset, n, T))
        continue
    K = int(1000 * n / T)

    def pagg_star(lam):
        """Return (4*lam + mu) / (4*lam + 4 + (m + 1)*mu) for the current m and mu."""
        return (4*lam+mu)/(4*lam + 4 + (m+1)*mu)

    if not plot_only:
        f, g = make_fg_logreg(A, b, mu)

        def make_total_loss(group_size, penalty):
            """Build the loss callback X -> objective(f, psi_func, X, group_size, penalty).

            The arguments are bound at creation time, so the callback does not
            depend on loop variables that change later.
            """
            def total_loss(X):
                return objective(f, psi_func, X, group_size, penalty)
            return total_loss

        # Reference point X0: T groups of m samples with the last objective
        # argument and both positional probability/penalty arguments of
        # vr_lgd set to zero (presumably the lam = 0 solution -- confirm
        # against vr_lgd's signature).
        alpha = get_stepsize_saga(v=np.ones(n), p=np.ones(n) / m, pagg=pagg_star(0), pwork=pwork, omega=0, n=n, T=T, mu=mu)
        _, X0, _ = vr_lgd(make_total_loss(m, 0), g, alpha, d, K, T, m, 0.0, 0.0, method="SAGA", psvrg=0, skip_it=K, track_agg=False)

        # Reference point Xinf: a single group holding all n samples, run for
        # 5*K iterations; the result is then repeated T times (one row per
        # group, assuming xinf is 1-D) so it is comparable with the X's.
        alpha = get_stepsize_saga(v=np.ones(n), p=np.ones(n) / n, pagg=pagg_star(0), pwork=pwork, omega=0, n=n, T=1, mu=mu)
        _, xinf, _ = vr_lgd(make_total_loss(n, 0), g, alpha, d, 5*K, 1, n, 0.0, 0.0, method="SAGA", psvrg=0, skip_it=K, track_agg=False)
        Xinf = np.outer(np.ones(T), xinf)

        # Xlist[0] is X0; Xlist[1:] holds one result per entry of lams.
        Xlist = [X0]

        for lam in lams:
            omega = lam/T

            print("lam = {}".format(lam))

            alpha = get_stepsize_saga(v=np.ones(n), p=np.ones(n) / m, pagg=pagg_star(lam), pwork=pwork, omega=omega, n=n, T=T,
                                      mu=mu)
            _, X, _ = vr_lgd(make_total_loss(m, omega), g, alpha, d, K, T, m, pagg_star(lam), omega, method="SAGA", psvrg=0, skip_it=K, track_agg=False)

            Xlist.append(X)

        # save X (once as backup, once as the regular file); the file name is
        # keyed on the last lam of the grid, matching the load call below.
        for backup in [True, False]:
            filename = createfilename(experiment, dataset, T, lams[-1], mu, pwork, method, backup)
            with open(filename, "wb") as pickle_out:
                pickle.dump((Xlist, X0, Xinf, np.asarray(lams)), pickle_out)

    # load X
    Xlist, X0, Xinf, lams_np = load_pickle(experiment, dataset, T, lams[-1], mu, pwork, method)

    dist_0 = create_distances(Xlist, X0)
    dist_inf = create_distances(Xlist, Xinf)

    # Xlist starts with X0, so prepend lam = 0 to line the grid up with the
    # distances. A separate name is used on purpose: rebinding the
    # module-level `lams` here would corrupt the sweep for the next dataset.
    lams_plot = np.zeros(len(lams_np)+1)
    lams_plot[1:] = lams_np

    visualize_dist(dist_0, dist_inf, lams_plot)

