from neurips_algorithm import *


def gen_exp_data_params_spe(d, k, m1, n1, m2, n2, sigma, con_num, kspe=0, alpha=None):
    """Draw model parameters and the small/medium batch datasets for one run.

    Args:
        d: Dimension; each parameter row in ``W`` has ``d`` entries.
        k: Number of components (rows of ``W`` and number of covariances).
        m1, n1: Passed to ``generate_batches`` for the "small" batches
            (presumably m1 = number of batches, n1 = samples per batch --
            confirm against ``neurips_algorithm``).
        m2, n2: Same as above, for the "medium" batches.
        sigma: Passed through to ``generate_batches``.
        con_num: Passed to ``generate_cov_matrix`` (presumably a condition
            number bound -- confirm).
        kspe: When non-zero, an extra batch set ``Bspe`` is generated from
            the first ``kspe`` components only, and only those components'
            parameters/covariances are returned.
        alpha: Passed through to ``batches_count_each_type``.

    Returns:
        Tuple ``(W, Sigmas, BS, BM, Bspe)``. ``W`` is a ``(k, d)`` array
        (``(kspe, d)`` when ``kspe`` is non-zero), ``Sigmas`` the matching list
        of covariances, and ``Bspe`` is ``None`` when ``kspe == 0``.
    """
    # Parameter matrix with one row per component; entries ~ Uniform[-1, 1).
    W = np.random.uniform(low=-1.0, high=1.0, size=d * k).reshape(k, d)
    Sigmas = [generate_cov_matrix(d, con_num) for _ in range(k)]

    small_counts = batches_count_each_type(m1, k, kspe, alpha)
    BS = generate_batches(W, Sigmas, n1, m1, sigma, small_counts)
    medium_counts = batches_count_each_type(m2, k, kspe, alpha)
    BM = generate_batches(W, Sigmas, n2, m2, sigma, medium_counts)

    if kspe == 0:
        return W, Sigmas, BS, BM, None

    # Restrict to the first kspe components, both for the extra batches and
    # for the parameters handed back to the caller.
    Bspe = generate_batches(W[:kspe], Sigmas[:kspe], n2, kspe, sigma)
    return W[:kspe], Sigmas[:kspe], BS, BM, Bspe


def _write_pred_errors(filename, rows):
    """Write one comma-separated line per entry of ``rows`` to a new file.

    The file is opened with mode ``"x"`` so an existing result file is never
    overwritten (``FileExistsError`` is raised instead). ``with`` guarantees
    the handle is closed even if a write fails.
    """
    with open(filename, "x", encoding="utf-8") as f:
        f.write("\n new algorithm\n")
        for row in rows:
            f.write(', '.join(str(stat) for stat in row))
            f.write("\n")
        f.write("\n")
        f.write("\n")


def run_inf_exp_comp():
    """Run the Figure 4 experiment: prediction error as a function of n2.

    For every ``n2`` in ``n2_list`` the experiment is repeated ``rep`` times.
    Each repetition draws fresh data with ``gen_exp_data_params_spe``, builds
    a candidate list with ``main_algo_multiple_comp`` and scores it with
    ``clustering_using_list`` for final list sizes 4 and 8.

    Results go to ``fig4_size_4.txt`` and ``fig4_size_8.txt``: a header line,
    then one row of per-repetition prediction errors for each ``n2`` (in
    ``n2_list`` order).

    Raises:
        FileExistsError: if either output file already exists. This is checked
            before any computation so a long run is not thrown away at the end.
    """
    import os

    d = 100
    k = 100
    kspe = 4
    alpha = 1/16
    sigma = 1
    n1 = 2
    m1 = 200000
    m2 = 256
    ell = 4 * kspe
    con_num = 4

    init_est = np.zeros((d, 1))
    n2_list = [4, 6, 8, 12, 16, 24, 32]
    rep = 10

    max_rounds = 10  # upper bound on the number of rounds R

    m3 = 1600  # passed to clustering_using_list

    # Final list sizes scored by clustering_using_list, in the original call
    # order (4 then 8), each with its own output file.
    list_sizes = (4, 8)
    filenames = {size: f"fig4_size_{size}.txt" for size in list_sizes}

    # Fail fast: the files are only written (mode "x") after the whole
    # experiment, so an existing file would otherwise discard every result.
    for filename in filenames.values():
        if os.path.exists(filename):
            raise FileExistsError(f"{filename} already exists; refusing to overwrite it")

    pred_errors = {size: [] for size in list_sizes}
    for n2 in n2_list:
        R = min(n2, max_rounds)
        pred_errors_n2 = {size: [] for size in list_sizes}

        for i in range(rep):
            print("rep ", i, " for ", n2)
            W, Sigmas, BS, BM, Bspe = gen_exp_data_params_spe(
                d, k, m1, n1, m2, n2, sigma, con_num, kspe, alpha
            )

            L = main_algo_multiple_comp(BS, BM, Bspe, sigma, ell, con_num, R, init_est)

            for size in list_sizes:
                # Only the prediction error (first element) is recorded.
                r_pred_error, _, _ = clustering_using_list(L, W, Sigmas, sigma, m3, size)
                pred_errors_n2[size].append(r_pred_error)

        for size in list_sizes:
            pred_errors[size].append(pred_errors_n2[size])

    for size in list_sizes:
        _write_pred_errors(filenames[size], pred_errors[size])


# Only launch the (long-running, file-writing) experiment when executed as a
# script, not when this module is imported for its helpers.
if __name__ == "__main__":
    run_inf_exp_comp()