# Effect of aggregation probability on the convergence rate



from algorithms import *

# When True, skip the optimisation runs and only plot previously stored results.
plot_only = False

# Per-curve plot styling; the same six line styles are cycled twice.
linestyles = ['-', '-.', '--', ':', '-.', '-'] * 2
markers = ['o', '*', 'd', 'v', 'P', '1', 'p', 'X']
colors = ['tab:' + shade for shade in
          ('blue', 'brown', 'green', 'red', 'purple', 'gray', 'olive', 'cyan')]

# Experiment tag handed to the result-file helpers.
experiment = "quadratic"


# Data sets processed by the main loop, in this order.
datasets = ["a1a", "duke", "phishing", "madelon", "mushrooms", "a8a", "gisette_scale"]

# Per-data-set value of T, used for result file names and plot scaling.
T_dict = {
    "a1a": 321,
    "mushrooms": 677,
    "phishing": 1005,
    "duke": 11,
    "madelon": 500,
    "gisette_scale": 60,
    "a8a": 2837,
}



# characterization of the problem
# For every data set: build a small quadratic problem, run dist_GD_quad with
# three step sizes and both values of `simple_spars`, store the normalised
# results, and plot them.
for dataset in datasets:
    print("#############################")
    print(dataset)
    print("#############################")

    T = T_dict[dataset]

    # Load the data and keep at most the first 50 rows / 200 columns.
    A, b = get_data(dataset)
    A = A[:50, :200]
    b = b[:50]

    n, d = A.shape

    A = normalize_data(A)
    A, b = rearrange_data(A, "random", b)

    # Number of iterations, and how often an iterate is recorded.
    K = int(500)
    skip_it = int(K / 50)

    # Random starting point shared by every run on this data set.
    x0 = np.random.randn(d)

    # Constants that determine the base step size.
    Lmat = np.dot(A.T, A)
    # Index d-1 selects the last eigenvalue returned for A^T A.
    # NOTE(review): if `eigh` is scipy.linalg.eigh, the `eigvals` keyword is
    # deprecated since SciPy 1.5 in favour of `subset_by_index=[d-1, d-1]`;
    # confirm against the installed SciPy version.
    L = eigh(Lmat, eigvals = (d-1,d-1))[0][0]
    Lmax = np.max(A*A)*n

    alpha = 1/(L + Lmax*d/n)/2
    print("L: {}, LLL: {}".format(L, Lmax*d/n))

    Flist = []
    Xlist = []

    # One label per run, in the order the runs are executed below:
    # step size 3x / 1x / (1/3)x, each with simple_spars False then True.
    labels = [
        r'$\lambda = 3\lambda_0$, ours',
        r'$\lambda = 3\lambda_0$, direct',
        r'$\lambda = \lambda_0$, ours',
        r'$\lambda = \lambda_0$, direct',
        r'$\lambda = 0.33\lambda_0$, ours',
        r'$\lambda = 0.33\lambda_0$, direct'
    ]

    it = create_it(T=K, skip_it=skip_it, tau = 1)

    stepsizes = [3*alpha, alpha, alpha/3]

    if not plot_only:
        count = 0
        for stepsize in stepsizes:
            for simple_spars in [False, True]:
                print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")
                print("label: {}".format(labels[count]))
                print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")
                F, x = dist_GD_quad(A, stepsize, K, skip_it, simple_spars,  x0=x0, pre_text='')
                # Normalise by the first recorded value so curves are comparable.
                Flist.append((F)/(F[0]))
                Xlist.append(x)
                count += 1

        # Save the fresh results (backup copy and main copy). This only
        # happens after an actual run, so that plot_only mode never
        # overwrites stored results with empty lists.
        for backup in [True, False]:
            filename = createfilename(experiment, dataset, T,  0.0, backup)
            with open(filename, "wb") as pickle_out:
                pickle.dump((Flist, Xlist), pickle_out)

    # Always plot from what is stored on disk, in both modes.
    (Flist, Xlist) = load_pickle(experiment, dataset, T, 0.0)

    # Every curve shares the same iteration axis.
    it_list = [it for i in Flist]

    # plot
    visualize(Flist, it_list, "sparsif_quad_", dataset, labels, linestyles=linestyles, markers=markers,
              colors=colors, alphas=None, muT=0.0/T, ppe=n/T)



