from neurips_algorithm import *


def gen_exp_data_params(d,k,m1,n1,m2,n2,sigma,con_num):
    """Draw a random parameter matrix, per-row covariances and two batch sets.

    A (k, d) matrix ``W`` is built from two consecutive uniform draws on
    [-1, 1): the first fills the leading ``k // 2`` rows, the second the
    remaining ``k - k // 2`` rows.  One matrix from
    ``generate_cov_matrix(d, con_num)`` is produced per row of ``W``.  Two
    batch collections are then generated with ``generate_batches``: one from
    ``(n1, m1)`` and one from ``(n2, m2)``, both using ``sigma``.

    The shape of ``W`` is printed to stdout as a side effect.

    Returns:
        Tuple ``(W, Sigmas, BS, BM)`` where ``Sigmas`` is a list of length
        ``k`` and ``BS`` / ``BM`` are whatever ``generate_batches`` returns
        for the ``(n1, m1)`` and ``(n2, m2)`` settings respectively.
    """
    first_rows = k // 2
    rest_rows = k - first_rows

    # Two separate draws, in this order, so the global NumPy random stream is
    # consumed exactly as a flat draw followed by a reshape would consume it.
    top = np.random.uniform(low=-1.0, high=1.0, size=(first_rows, d))
    bottom = np.random.uniform(low=-1.0, high=1.0, size=(rest_rows, d))
    W = np.concatenate((top, bottom), axis=0)
    print(W.shape)

    Sigmas = [generate_cov_matrix(d, con_num) for _ in range(k)]

    BS = generate_batches(W, Sigmas, n1, m1, sigma)
    BM = generate_batches(W, Sigmas, n2, m2, sigma)
    return W, Sigmas, BS, BM


def _format_rows(rows):
    """Render each inner sequence of ``rows`` as one ``', '``-joined line."""
    return [', '.join(str(stat) for stat in row) for row in rows]


def _write_error_report(path, prior_rows, new_rows, echo_prior=False):
    """Create ``path`` and write the prior/new prediction-error tables to it.

    The file is opened in exclusive-creation mode (``"x"``), so an existing
    file is never overwritten; ``FileExistsError`` is raised instead.  The
    ``with`` block guarantees the file is flushed and closed even if a write
    fails.

    Args:
        path: Destination file name.
        prior_rows: One sequence of errors per line for the "Prior algorithm"
            section.
        new_rows: One sequence of errors per line for the "new algorithm"
            section.
        echo_prior: When true, each "Prior algorithm" line is also printed to
            stdout.
    """
    with open(path, "x") as report:
        report.write("\n Prior algorithm \n")
        for line in _format_rows(prior_rows):
            report.write(line)
            if echo_prior:
                print(line)
            report.write("\n")
        report.write("\n")
        report.write("\n")
        report.write("\n new algorithm\n")
        for line in _format_rows(new_rows):
            report.write(line)
            report.write("\n")
        report.write("\n")
        report.write("\n")


def run_exp_comp_normal():
    """Run the prior-vs-new algorithm comparison and write two report files.

    For every ``n2`` in a fixed list, the experiment is repeated ``rep``
    times.  Each repetition draws fresh data with ``gen_exp_data_params``,
    runs ``main_algo_prior_work`` and ``main_algo_all_comp`` on it, and scores
    both resulting lists with ``clustering_using_list`` for the final-argument
    values 4 and 8.  The first value returned by ``clustering_using_list`` is
    collected; the other two are ignored.

    Results for the value 4 go to ``fig1_size_4.txt`` and results for 8 go to
    ``fig1_size_8.txt``.  Each file has one line per ``n2`` and one
    comma-separated entry per repetition.  The "Prior algorithm" rows of the
    first file are also echoed to stdout.

    Raises:
        FileExistsError: If either output file already exists.  This is
            checked before the experiment starts so a long run is not wasted,
            and enforced again at write time by the ``"x"`` open mode.
    """
    import os

    filenameb = "fig1_size_4"+ ".txt"
    filenamec = "fig1_size_8"+ ".txt"

    # Fail fast: the files are created with mode "x" after the (long)
    # experiment loop, so a leftover file would otherwise discard all the work.
    for path in (filenameb, filenamec):
        if os.path.exists(path):
            raise FileExistsError("output file already exists: %r" % (path,))

    d = 100
    k = 16
    sigma = 1
    n1 = 2
    m1 = 200000
    m2 = 256
    ell = k
    con_num = 1

    init_est = np.zeros((d, 1))

    n2_list = [4,6,8,12,16,24,32]
    rep = 10

    Max_R = 12
    m3 = 1600

    # Final argument handed to clustering_using_list; presumably a batch size
    # (it matches the "size_4"/"size_8" file names) -- TODO confirm.
    sizes = (4, 8)

    # size -> list (one entry per n2) of lists (one entry per repetition).
    pred_error_prior = {size: [] for size in sizes}
    pred_error = {size: [] for size in sizes}

    for n2 in n2_list:
        R = min(n2, Max_R)
        prior_n2 = {size: [] for size in sizes}
        new_n2 = {size: [] for size in sizes}
        for i in range(rep):
            print("rep ", i, " for ", n2)
            W, Sigmas, BS, BM = gen_exp_data_params(d,k,m1,n1,m2,n2,sigma,con_num)

            L_prior_work = main_algo_prior_work(BS, BM, k, d)
            L = main_algo_all_comp(BS, BM, sigma, ell, con_num, init_est, R)

            # Keep the evaluation order (prior then new, size 4 then size 8):
            # clustering_using_list may consume random state -- TODO confirm.
            for size in sizes:
                err_prior, _, _ = clustering_using_list(L_prior_work, W, Sigmas, sigma, m3, size)
                prior_n2[size].append(err_prior)
                err_new, _, _ = clustering_using_list(L, W, Sigmas, sigma, m3, size)
                new_n2[size].append(err_new)

        for size in sizes:
            pred_error_prior[size].append(prior_n2[size])
            pred_error[size].append(new_n2[size])

    _write_error_report(filenameb, pred_error_prior[4], pred_error[4], echo_prior=True)
    _write_error_report(filenamec, pred_error_prior[8], pred_error[8])



# Run the experiment only when this file is executed as a script, so that
# importing the module does not trigger the full (long-running) computation.
if __name__ == "__main__":
    run_exp_comp_normal()