# Effect of aggregation probability on the convergence rate

from algorithms import *

# When True, the optimisation runs are skipped and previously saved results
# are only re-loaded and plotted.
plot_only = False

# Per-curve plot styling. The line-style list is one six-style cycle repeated twice.
linestyles = ['-', '-.', '--', ':', '-.', '-'] * 2
markers = ['o', '*', 'd', 'v', 'P', '1', 'p', 'X']
colors = ['tab:' + name for name in
          ('blue', 'brown', 'green', 'red', 'purple', 'gray', 'olive', 'cyan')]

# Experiment-wide settings; these also identify the saved result files.
mu = 1e-4  # passed to make_fg_logreg and the step-size rule (presumably the regulariser -- confirm)
experiment = "ComparePagg"
pwork = 1.0
method = "SAGA"

# Datasets are processed in this order.
datasets = [
    "w2a",
    "phishing",
    "duke",
    "madelon",
    "gisette_scale",
    "a1a",
    "mushrooms",
    "a8a",
]

# Per-dataset value of T.
T_dict = {
    "a1a": 5,
    "mushrooms": 12,
    "w2a": 10,
    "phishing": 11,
    "duke": 4,
    "madelon": 50,
    "gisette_scale": 100,
    "a8a": 8,
}

# Per-dataset value that is divided by T to obtain omega (i.e. it stores T*omega).
omega_dict = {
    "a1a": 0.1,
    "mushrooms": 0.05,
    "w2a": 0.01,
    "phishing": 0.1,
    "duke": 0.4,
    "madelon": 0.02,
    "gisette_scale": 0.2,
    "a8a": 0.1,
}

# For every dataset: set up the problem, sweep the aggregation probability
# (unless plot_only is set), save the results, then reload and plot them.
for dataset in datasets:
    print("#############################")
    print(dataset)
    print("#############################")

    T = T_dict[dataset]
    # omega_dict stores the product T*omega, so divide by T to obtain omega.
    omega = omega_dict[dataset]/T

    A, b = get_data(dataset)
    A = normalize_data(A)
    A, b = rearrange_data(A, "random", b)

    n, d = A.shape
    m = int(n/T)

    # Only datasets whose row count is an exact multiple of T are processed.
    # The check runs before the (unneeded) problem setup for skipped datasets.
    if m*T != n:
        print("Skipping {}: n={} is not divisible by T={}".format(dataset, n, T))
        continue

    K = int(500*n/T)
    skip_it = int(K/1000)

    # NOTE(review): v and x0 are not used in the visible part of this script;
    # the step-size call below passes np.ones(n) rather than v -- confirm intent.
    v = (1 + 1e-6)*np.ones(n)
    x0 = np.zeros(d)
    f, g = make_fg_logreg(A, b, mu)

    def total_loss(X):
        """Evaluate objective(f, psi_func, X, m, omega) for this dataset's f, m and omega."""
        return objective(f, psi_func, X, m, omega)

    # Candidate aggregation probabilities: fixed multiples of pagg_star,
    # restricted to the interval [0.001, 1.0].
    pagg_star = T*omega/(1+T*omega)
    paggs = pagg_star*np.asarray([0.01, 0.1, 0.3, 1.0, 3.0, 10.0])
    paggs = paggs[paggs <= 1.0]
    # Number of small multiples dropped from the front; used to offset the plot
    # styles so that a given multiple keeps the same style on every dataset.
    shift = np.sum(paggs < 0.001)
    paggs = paggs[paggs >= 0.001]

    Flist = []
    Xlist = []

    labels = ["{:.3f}".format(pagg) for pagg in paggs]

    it = create_it(T=K, skip_it=skip_it, tau=T)

    if not plot_only:
        for pagg in paggs:
            alpha = get_stepsize_saga(v=np.ones(n), p=np.ones(n) / m, pagg=pagg, pwork=pwork, omega=omega, n=n, T=T, mu=mu)
            F, X, _ = vr_lgd(total_loss, g, alpha, d, K, T, m, pagg, omega, method="SAGA", psvrg=0, skip_it=skip_it, track_agg=False)
            Flist.append(F)
            Xlist.append(X)

        # Save the results twice: once under the backup file name, once under the regular one.
        for backup in (True, False):
            filename = createfilename(experiment, dataset, T, omega, mu, pwork, method, backup)
            # The context manager closes the file even if pickling fails.
            with open(filename, "wb") as pickle_out:
                pickle.dump((Flist, Xlist), pickle_out)

    # Load the results (freshly written above, or from an earlier run when plot_only is set).
    (Flist, Xlist) = load_pickle(experiment, dataset, T, omega, mu, pwork, method)

    # One x-axis per curve: the iteration counter divided by n.
    it_list = [it/n for _ in Flist]

    # plot
    alphas = None
    visualize(Flist, it_list, "pagg_exp_epoch", dataset, labels, linestyles=linestyles[shift:], markers=markers[shift:],
              colors=colors[shift:], alphas=alphas, muT=mu/T, ppe=n/T)



