import os
import numpy as np
from scipy import sparse
from scipy.io import savemat
import pickle

def _generate_documents(num, doc_length, topics, sigma):
    """Sample ``num`` synthetic documents from a logistic-normal topic model.

    Each document's topic proportions are a K-dimensional Gaussian draw
    ``N(0, sigma)`` pushed through a row-wise softmax.  Every word then picks a
    topic from those proportions and a word id from that topic's row of
    ``topics``.

    Args:
        num: number of documents to generate.
        doc_length: number of words in every document (all documents have the
            same length).
        topics: ``(K, V)`` array whose rows are word distributions.
        sigma: ``(K, K)`` covariance of the Gaussian over topic logits.

    Returns:
        ``(topic_dist, documents)``: ``topic_dist`` is a ``(num, K)`` array
        whose rows sum to 1; ``documents`` is a list of ``num`` lists of
        integer word ids.
    """
    n_topics, vocab_size = topics.shape

    # Gaussian logits; rows do NOT sum to 1 until the softmax below.
    logits = np.random.multivariate_normal(np.zeros(n_topics), sigma, size=num)
    weights = np.exp(logits)
    topic_dist = weights / np.sum(weights, axis=-1)[:, np.newaxis]

    documents = []
    for proportions in topic_dist:
        # The sampling calls and their order are intentionally unchanged so a
        # seeded np.random state reproduces the same corpus as before.
        word_topics = np.random.choice(n_topics, size=doc_length, replace=True, p=proportions)
        document = []
        for topic in word_topics:
            word = np.random.choice(vocab_size, size=1, p=topics[topic, :])
            document.append(word.item())
        documents.append(document)
    return topic_dist, documents


def _bag_of_words(documents, vocab_size):
    """Return a ``(len(documents), vocab_size)`` CSR matrix of word counts."""
    doc_ids = [d for d, doc in enumerate(documents) for _ in doc]
    word_ids = [int(w) for doc in documents for w in doc]
    # Duplicate (doc, word) entries are summed by tocsr(), yielding counts.
    return sparse.coo_matrix(
        ([1] * len(doc_ids), (doc_ids, word_ids)),
        shape=(len(documents), vocab_size),
    ).tocsr()


def _split_bow(bow):
    """Split each CSR row into its word ids and counts.

    Returns two 1-D object arrays (one entry per document) so that ``savemat``
    always writes them as MATLAB cell arrays, even when there is a single
    document or every document has the same number of distinct words.
    """
    n_docs = bow.shape[0]
    tokens = np.empty(n_docs, dtype=object)
    counts = np.empty(n_docs, dtype=object)
    for d in range(n_docs):
        start, stop = bow.indptr[d], bow.indptr[d + 1]
        tokens[d] = bow.indices[start:stop].copy()
        counts[d] = bow.data[start:stop].copy()
    return tokens, counts


def make_data(alpha, data_path, N_test=50, lamb=30, gen_new_tests=False, *,
              topic_matrices_path='TopicMatrices.npy',
              cov_matrix_path='cov_matrix.npy'):
    """Build a synthetic topic-model test corpus and save it under ``data_path``.

    Args:
        alpha: 1-based index into the stack of topic matrices stored in
            ``topic_matrices_path`` (``int(alpha) - 1`` is used).  Presumably
            each stacked matrix was generated with a different Dirichlet
            concentration -- TODO confirm.
        data_path: output directory; created (with parents) if missing.
        N_test: number of test documents.
        lamb: number of words in every generated document.
        gen_new_tests: if True, regenerate the test documents even when cached
            ``bow_test_*_<N_test>.mat`` files already exist.
        topic_matrices_path: ``.npy`` file holding a stack of ``(K, V)`` topic
            matrices.
        cov_matrix_path: ``.npy`` file holding the ``(K, K)`` covariance of the
            Gaussian over topic logits.

    Side effects:
        Always writes ``topics.pkl``, ``vocab.pkl`` and ``id2word.pkl``.  When
        test documents are (re)generated, also writes ``test_topics_dist.pkl``,
        ``bow_test_tokens_<N_test>.mat`` and ``bow_test_counts_<N_test>.mat``.

    Returns:
        ``(A, id2word)``: the selected ``(K, V)`` topic matrix and a dict
        mapping each word id to its string form.
    """
    os.makedirs(data_path, exist_ok=True)

    topic_matrices = np.load(topic_matrices_path)
    # NOTE(review): int(alpha) < 1 wraps around to the end of the stack.
    A = topic_matrices[int(alpha) - 1]  # word-distribution-by-topic matrix
    # Derive sizes from the data so they can never disagree with A.
    K, V = A.shape
    sigma = np.load(cov_matrix_path)

    vocab = [str(i) for i in range(V)]
    id2word = {w: str(w) for w in range(V)}

    with open(os.path.join(data_path, 'topics.pkl'), 'wb') as f:
        pickle.dump(A, f)

    # savemat appends '.mat' to bare names, so the cache check must include
    # the extension or it never finds the files it wrote.
    tokens_path = os.path.join(data_path, 'bow_test_tokens_' + str(N_test) + '.mat')
    counts_path = os.path.join(data_path, 'bow_test_counts_' + str(N_test) + '.mat')
    cached = os.path.isfile(tokens_path) and os.path.isfile(counts_path)

    if gen_new_tests or not cached:
        test_topic_dist, test_documents = _generate_documents(N_test, lamb, A, sigma)

        print('creating and saving new test documents')
        with open(os.path.join(data_path, 'test_topics_dist.pkl'), 'wb') as f:
            pickle.dump(test_topic_dist, f)

        bow_test = _bag_of_words(test_documents, V)

        print('splitting bow into token/value pairs and saving to disk...')
        bow_test_tokens, bow_test_counts = _split_bow(bow_test)
        savemat(tokens_path, {'tokens': bow_test_tokens}, do_compression=True)
        savemat(counts_path, {'counts': bow_test_counts}, do_compression=True)

    with open(os.path.join(data_path, 'vocab.pkl'), 'wb') as f:
        pickle.dump(vocab, f)

    with open(os.path.join(data_path, 'id2word.pkl'), 'wb') as f:
        pickle.dump(id2word, f)

    return (A, id2word)
