# Effect of aggregation probability on the convergence rate



from algorithms import *

# Set to True to skip the optimisation runs further down and go straight to
# the load-and-plot stage.
plot_only = False

# Per-curve plot styling handed to visualize().
linestyles = ['-', '-.', '--', ':', '-.', '-'] * 2
markers = [
    'o', '*', 'd', 'v',
    'P', '1', 'p', 'X',
]
colors = [
    'tab:blue',
    'tab:brown',
    'tab:green',
    'tab:red',
    'tab:purple',
    'tab:gray',
    'tab:olive',
    'tab:cyan',
]

# Experiment tag used when building the result file names.
experiment = "quadratic_imp"

# Dataset names and a per-dataset integer table; neither is referenced in the
# visible part of this script.
datasets = [
    "a1a",
    "duke",
    "phishing",
    "madelon",
    "mushrooms",
    "a8a",
    "gisette_scale",
]
T_dict = {
    "a1a": 321,
    "mushrooms": 677,
    "phishing": 1005,
    "duke": 11,
    "madelon": 500,
    "gisette_scale": 60,
    "a8a": 2837,
}

# Number of rows of the random data matrix generated below.
n = 10

# Problem dimensions to sweep over; alternative sweep: [10, 50, 200, 1000].
ds = [1000]

# characterization of the problem
# For every dimension d: build a random quadratic problem, run dist_GD_quad
# for three step sizes and both values of `simple_spars`, store the normalised
# traces on disk, then reload and plot them.
for d in ds:
    print("#############################")
    print(d)
    print("#############################")

    # Random n-by-d matrix with standard-normal entries, passed through the
    # project's normalize_data helper and halved.
    A = normalize_data(np.random.randn(n, d)) / 2

    A, _ = rearrange_data(A, "random", np.zeros(n))

    K = 500              # iteration budget handed to dist_GD_quad / create_it
    skip_it = K // 50    # passed as skip_it to dist_GD_quad / create_it

    x0 = np.random.randn(d)

    # Largest eigenvalue of the symmetric matrix A^T A.  eigvalsh returns the
    # eigenvalues in ascending order, so the last entry is the maximum.  This
    # replaces eigh(..., eigvals=(d-1, d-1)); the `eigvals` keyword is
    # presumably scipy.linalg.eigh's, which recent SciPy releases reject.
    Lmat = np.dot(A.T, A)
    L = np.linalg.eigvalsh(Lmat)[-1]
    Lmax = np.mean(A * A) * n

    # Base step size lambda_0 referred to in the plot labels.
    alpha = 1 / (L + Lmax * d / n) / 2
    print("L: {}, LLL: {}".format(L, Lmax*d/n))

    Flist = []
    Xlist = []

    # One label per run, in the order the runs are executed below:
    # step size outermost, then simple_spars=False ("ours") / True ("direct").
    labels = [
#        r'$\lambda = 25\lambda_0$, ours',
#        r'$\lambda = 25\lambda_0$, direct',
        r'$\lambda = 3\lambda_0$, ours',
        r'$\lambda = 3\lambda_0$, direct',
        r'$\lambda = \lambda_0$, ours',
        r'$\lambda = \lambda_0$, direct',
        r'$\lambda = 0.3\lambda_0$, ours',
        r'$\lambda = 0.3\lambda_0$, direct',
    ]

    it = create_it(T=K, skip_it=skip_it, tau=1)

    alphas = [3 * alpha, alpha, 0.3 * alpha]

    if not plot_only:
        count = 0
        # `step` (not `alpha`) so the base step size is not clobbered.
        for step in alphas:
            for simple_spars in [False, True]:
                print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")
                # The original format string had no placeholder, so the label
                # was never actually printed.
                print("label: {}".format(labels[count]))
                print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")
                F, x = dist_GD_quad(A, step, K, skip_it, simple_spars, imp=True, x0=x0, pre_text='')
                # Normalise the trace by its first entry.
                Flist.append(F / F[0])
                Xlist.append(x)
                count += 1

        # Save the results (backup copy and primary copy).  This is done only
        # when new results were computed: in plot-only mode Flist/Xlist are
        # empty and writing them would overwrite the stored results.
        for backup in [True, False]:
            filename = createfilename(experiment, str(d), 0, 0.0, backup)
            with open(filename, "wb") as pickle_out:
                pickle.dump((Flist, Xlist), pickle_out)

    # Load the results back from disk (freshly written or from a previous run).
    (Flist, Xlist) = load_pickle(experiment, str(d), 0, 0.0)

    # Every curve shares the same iteration axis.
    it_list = [it for _ in Flist]

    # plot
    visualize(Flist, it_list, "sparsif_quad_imp", "d = {}".format(d), labels,
              linestyles=linestyles, markers=markers, colors=colors,
              alphas=None, muT=0.0, ppe=n)



