import pickle
import random
import numpy as np

# Global sizes for the synthetic corpus.
V = 5000  # vocabulary size
lamb = 30  # average document length
K = 20  # number of topics

def fill_in(mylist: list, indices: list, vals: list):
    '''
    Scatter ``vals`` into ``mylist`` in place: ``mylist[indices[k]] = vals[k]``.

    Args:
        mylist: mutable sequence to write into (modified in place).
        indices: positions in ``mylist`` to overwrite.
        vals: values to store, paired with ``indices`` by position.

    Raises:
        ValueError: if ``indices`` and ``vals`` have different lengths.
    '''
    # An explicit check (not `assert`) so the guard survives `python -O`.
    if len(indices) != len(vals):
        raise ValueError(
            f"fill_in: got {len(indices)} indices but {len(vals)} values"
        )
    for ind, val in zip(indices, vals):
        mylist[ind] = val


def generate_vector_group(vector: np.ndarray, start: int, threshold=0.001,
                          n_top=250, n_window=1000, top_span=300):
    '''
    Return a group of 4 permutations of ``vector``: vector 1 ~ vector 2, both
    differ from vectors 3 and 4, and vector 3 differs from vector 4.

    Every returned vector holds exactly the entries of ``vector`` (same
    length, same multiset of values), rearranged. The entries are ranked from
    largest to smallest and split into three tiers: the ``n_top`` largest
    ("top"), the next ``n_window - n_top`` ("mid"), and the remainder ("rest").

    * Vectors 1 and 2 ("similar pair"): the top tier is placed, in descending
      order, on a random ``n_top``-subset of ``[start, start + top_span)``;
      the mid tier fills the remaining slots of ``[start, start + n_window)``
      in descending order; the rest tier is shuffled over every index outside
      that window.
    * Vectors 3 and 4 ("different pair"): the top and mid tiers are shuffled
      together over ``[start, start + n_window)``; the rest tier is shuffled
      over every index outside it.

    Randomness comes from ``np.random.choice`` (top-tier slots) and the
    stdlib ``random.shuffle``; seed both RNGs for reproducible output.

    Args:
        vector: 1-D sequence of values to permute.
        start: first index of the window that receives the top/mid tiers.
        threshold: unused; kept only for backward compatibility.
        n_top: size of the top tier.
        n_window: width of the window holding the top + mid tiers.
        top_span: width of the sub-window where the similar pair puts its top tier.

    Returns:
        Tuple of four Python lists, each of length ``len(vector)``.

    Raises:
        ValueError: if ``0 <= n_top <= top_span <= n_window`` does not hold or
            the window ``[start, start + n_window)`` does not fit in ``vector``.
    '''
    dim = len(vector)
    if not 0 <= n_top <= top_span <= n_window:
        raise ValueError(
            f"need 0 <= n_top <= top_span <= n_window, got "
            f"{n_top}, {top_span}, {n_window}"
        )
    if start < 0 or start + n_window > dim:
        raise ValueError(
            f"window [{start}, {start + n_window}) does not fit in a vector of length {dim}"
        )

    def scatter(target, idxs, vals):
        # Lengths always match by construction of the tiers and index lists.
        for idx, val in zip(idxs, vals):
            target[idx] = val

    # Rank values from largest to smallest and split them into three tiers.
    ranked = sorted(vector)[::-1]
    top_vals = ranked[:n_top]
    mid_vals = ranked[n_top:n_window]
    rest = ranked[n_window:]

    window = range(start, start + n_window)
    outside = list(range(0, start)) + list(range(start + n_window, dim))

    sim_vecs = []
    for _ in range(2):
        v = [0] * dim
        # Largest values go (in descending order) onto random slots near `start`.
        top_idx = sorted(
            int(i) for i in np.random.choice(range(start, start + top_span), n_top, replace=False)
        )
        scatter(v, top_idx, top_vals)

        # Smallest values are shuffled over everything outside the window.
        random.shuffle(rest)
        scatter(v, outside, rest)

        # Middle values fill the window slots the top tier did not take. Track
        # taken slots explicitly rather than testing `v[i] == 0`, which would
        # misfire if a placed value happened to be exactly zero.
        taken = set(top_idx)
        scatter(v, [i for i in window if i not in taken], mid_vals)

        sim_vecs.append(v)

    diff_vecs = []
    for _ in range(2):
        v = [0] * dim
        # Top + mid values shuffled together across the whole window.
        head = top_vals + mid_vals
        random.shuffle(head)
        scatter(v, window, head)

        random.shuffle(rest)
        scatter(v, outside, rest)

        diff_vecs.append(v)

    return sim_vecs[0], sim_vecs[1], diff_vecs[0], diff_vecs[1]

def _grouped_topic_matrix(alpha, row_order):
    '''
    Build a 20 x 5000 matrix from 5 groups of 4 rows made by generate_vector_group.

    Group ``i`` starts from one draw of a symmetric Dirichlet with concentration
    ``alpha / 20`` per component. It is expanded into the 4 permutations
    returned by ``generate_vector_group(..., start=1000 * i)``, and those fill
    rows ``4*i .. 4*i + 3``. ``row_order[offset]`` selects which of the 4
    returned vectors (0-based) goes into row ``4*i + offset``.
    '''
    A = np.zeros((20, 5000))
    dir_vectors = np.random.dirichlet(np.ones(5000)*alpha/20, 5)
    for i, base in enumerate(dir_vectors):
        group = generate_vector_group(base, start=1000*i)
        for offset, which in enumerate(row_order):
            A[4*i + offset] = np.array(group[which])
    return A


def get_matrix_CTM(alpha):
    '''
    Return a 20 x 5000 topic matrix for the CTM setting.

    Within each block of 4 rows, the "similar" pair occupies rows 4i and 4i+1
    and the "different" pair occupies rows 4i+2 and 4i+3.
    '''
    return _grouped_topic_matrix(alpha, row_order=(0, 1, 2, 3))


def get_matrix_PAM(alpha):
    '''
    Return a 20 x 5000 topic matrix for the PAM setting.

    This matches get_matrix_CTM except that within each block of 4 rows the
    middle two are swapped. The "similar" pair sits in rows 4i and 4i+2, and
    the "different" pair sits in rows 4i+1 and 4i+3.
    '''
    return _grouped_topic_matrix(alpha, row_order=(0, 2, 1, 3))

if __name__ == '__main__':
    # One topic matrix is generated per concentration value. Axis 0 of every
    # array below is indexed by position in this list.
    alphas = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
    # Derived from `alphas` so that editing the list cannot desync the shapes.
    n_alphas = len(alphas)

    # Pure / LDA: K independent symmetric Dirichlet(alpha / K) topics over V words.
    res = np.zeros((n_alphas, K, V))
    for i, alpha in enumerate(alphas):
        res[i] = np.random.dirichlet((alpha / K) * np.ones(V), K)

    # NOTE: the src_* output directories must already exist; neither open()
    # nor np.save() creates them.

    # Pure and LDA are written the very same topic matrices.
    with open('src_pure/TopicsMatrix.pkl', 'wb') as f:
        pickle.dump(res, f)
    with open('src_LDA/TopicsMatrix.pkl', 'wb') as f:
        pickle.dump(res, f)

    # CTM / PAM: get_matrix_* build fixed 20 x 5000 matrices, which assumes
    # K == 20 and V == 5000.
    ctm_mat = np.zeros((n_alphas, K, V))
    for i, alpha in enumerate(alphas):
        ctm_mat[i] = get_matrix_CTM(alpha)
    np.save('src_CTM/TopicMatrices.npy', ctm_mat)

    pam_mat = np.zeros((n_alphas, K, V))
    for i, alpha in enumerate(alphas):
        pam_mat[i] = get_matrix_PAM(alpha)
    np.save('src_PAM/TopicMatrices.npy', pam_mat)

