# Effect of aggregation probability on the convergence rate



from algorithms import *

# If True, skip all solver runs and only re-plot results saved by earlier runs.
plot_only = False

# Per-curve plotting styles handed to `visualize` (one entry per curve).
linestyles = [
    '-', '--', ':', '-', '--', ':', '-.',
    '-', '-', '-.', '--', ':', '-.', '-',
]
markers = ['o', 'd', '*', 'v', 'P', '1', 'p', 'X']
colors = [
    'tab:blue', 'tab:brown', 'tab:green', 'tab:red',
    'tab:purple', 'tab:gray', 'tab:olive', 'tab:cyan',
]

# Regularization parameter passed to make_fg_logreg (it also appears as
# mu * I in the smoothness matrix built below).
mu = 1e-5

# Tag used to build result file names.
experiment = "vr_acc2"

# Datasets to run. Other datasets with an entry in T_dict:
# "a1a", "phishing", "mushrooms", "gisette_scale".
datasets = ["a8a", "madelon", "duke"]

# Number of data blocks T per dataset (the data is split via create_blocks(n, T)).
T_dict = dict(
    a1a=15,
    mushrooms=677,
    phishing=1005,
    duke=11,
    madelon=500,
    gisette_scale=60,
    a8a=2837,
)

# Per-dataset omega values; NOTE(review): not referenced anywhere in this script.
omega_dict = dict(
    a1a=0.1,
    mushrooms=0.05,
    w2a=0.01,
    phishing=0.1,
    duke=0.4,
    madelon=0.02,
    gisette_scale=0.2,
    a8a=0.1,
)

# For every dataset: run (or, with plot_only, reload) the two new ADIANA
# variants, merge them into the previously saved six-method results, and plot.
for dataset in datasets:
    print("#############################")
    print(dataset)
    print("#############################")

    T = T_dict[dataset]

    A, b = get_data(dataset)
    n, d = A.shape

    # The n samples are split into T equally sized blocks; datasets where that
    # is impossible are skipped before any preprocessing or model setup.
    if n % T != 0:
        print("Skipping {}: n = {} is not divisible by T = {}".format(dataset, n, T))
        continue

    A = normalize_data(A)
    A, b = rearrange_data(A, "random", b)

    blocks = create_blocks(n, T)

    K = 2000               # number of iterations
    skip_it = K // 200     # iteration stride passed to the solvers / create_it

    f, g = make_fg_logreg(A, b, mu)

    def total_loss(x):
        """Full objective at x (wraps objective(f, x, n))."""
        return objective(f, x, n)

    if not plot_only:
        # A^T A / (4n) + mu*I; its largest eigenvalue is used as L (presumably
        # the smoothness constant of the regularized logistic loss).
        Lmat = np.dot(A.T, A) / n / 4 + mu * np.eye(d)
        # eigvalsh returns eigenvalues in ascending order, so [-1] is the
        # largest one. This replaces scipy's eigh(..., eigvals=...), whose
        # `eigvals` keyword was removed in SciPy 1.14.
        L = np.linalg.eigvalsh(Lmat)[-1]

        x0 = np.zeros(d)
        xstar = find_xstar(g, mu, L, d, n, K, pre_text='', loss=total_loss, x0=x0)
        Fstar = total_loss(xstar)

    labels = ["ADIANA+",
              "DIANA+",
              "DCGD+",
              "ADIANA",
              "DIANA",
              "DCGD"
              ]

    tau = 1
    it = create_it(T=K, skip_it=skip_it, tau=tau)

    imp = False
    # (acc, vr) pairs to run; (False, True) and (False, False) were also used
    # at some point but are disabled here.
    settings = [(True, True)]

    if plot_only:
        # Reuse the runs saved by an earlier full run. Previously Flist2 stayed
        # empty here and the Flist2[0] lookup below raised IndexError.
        # NOTE(review): assumes load_pickle(name, dataset, T, mu) reads the file
        # written via createfilename(name, dataset, T, mu, False) — confirm.
        Flist2, _ = load_pickle(experiment + "xx", dataset, T, mu)
    else:
        Flist2 = []
        Xlist = []
        # simple_spars=False first, then True (order matters for the merge below).
        for simple_spars in [False, True]:
            for (acc, vr) in settings:
                F, x = dist_GD(total_loss, g, blocks, A, mu, vr, imp, K, skip_it, tau,
                               simple_spars, acc, x0=x0, pre_text='')
                # Relative suboptimality, normalized by its initial value.
                Flist2.append((F - Fstar) / (F[0] - Fstar))
                Xlist.append(x)

        # Save the new runs, once with backup=True and once with backup=False.
        for backup in [True, False]:
            filename = createfilename(experiment + "xx", dataset, T, mu, backup)
            with open(filename, "wb") as pickle_out:
                pickle.dump((Flist2, Xlist), pickle_out)

    # Load the previously saved results for all six methods.
    (Flist, Xlist) = load_pickle(experiment, dataset, T, mu)

    # Replace the curves at labels[0] ("ADIANA+") and labels[3] ("ADIANA")
    # with the new runs, then trim the remaining curves to the same length.
    Flist[0] = Flist2[0]
    Flist[3] = Flist2[1]
    n_points = len(Flist[0])
    for idx in (1, 2, 4, 5):
        Flist[idx] = Flist[idx][:n_points]

    it_list = [it] * len(Flist)

    # plot
    alphas = None
    visualize(Flist, it_list, experiment + "_", dataset, labels, linestyles=linestyles, markers=markers,
              colors=colors, alphas=alphas, muT=mu / T, ppe=n / T)



