import os
import numpy as np
from scipy import sparse
from scipy.io import savemat
import pickle

def _generate_documents(num, A, K, V, lamb):
    """Sample ``num`` synthetic documents from the topic-word matrix ``A``.

    Each document gets a Poisson(``lamb``) length, floored at 4 words, and a
    topic-proportion vector drawn from a symmetric Dirichlet(1/K). Every word
    position first draws a topic from that vector, then a word id from the
    matching row of ``A``.

    Returns:
        (topic_dist, documents): ``topic_dist`` is a (num, K) array of
        per-document topic proportions. ``documents`` is a list of ``num``
        lists, each holding one length-1 integer array per word (the shape
        ``np.random.choice(..., size=1)`` returns). That shape is kept so
        existing consumers of ``pretrain_docs`` see the same objects.
    """
    doc_lengths = np.maximum(np.random.poisson(lam=lamb, size=num), 4)
    topic_dist = np.random.dirichlet(np.ones(K) / K, num)

    documents = []
    for i in range(num):
        # Topic assignment for every word position of document i.
        word_topics = np.random.choice(K, size=doc_lengths[i], replace=True,
                                       p=topic_dist[i, :])
        # One draw per word, in order, so seeded runs reproduce the same corpus.
        documents.append([np.random.choice(V, size=1, p=A[z, :]) for z in word_topics])
    return topic_dist, documents


def _documents_to_bow(documents, vocab_size):
    """Build a (n_docs, vocab_size) CSR matrix of word counts per document."""
    # Word entries may be length-1 arrays; ravel flattens them to plain ids and
    # avoids int() on 1-element arrays (deprecated since NumPy 1.25).
    word_ids = [np.asarray(doc, dtype=int).ravel() for doc in documents]
    n_docs = len(word_ids)
    rows = np.repeat(np.arange(n_docs), [len(ids) for ids in word_ids])
    cols = np.concatenate(word_ids) if word_ids else np.zeros(0, dtype=int)
    data = np.ones(len(cols), dtype=int)
    # COO -> CSR sums duplicate (doc, word) entries, turning ones into counts.
    return sparse.coo_matrix((data, (rows, cols)), shape=(n_docs, vocab_size)).tocsr()


def _split_bow(bow):
    """Split a CSR matrix into per-row arrays of column ids and values.

    Returns two 1-D object arrays with one entry per row. An explicit object
    dtype makes ``savemat`` store them as cell arrays even though rows differ
    in length. NumPy >= 1.24 refuses to build an array from a ragged nested
    list without it.
    """
    n_docs = bow.shape[0]
    tokens = np.empty(n_docs, dtype=object)
    counts = np.empty(n_docs, dtype=object)
    for d in range(n_docs):
        start, stop = bow.indptr[d], bow.indptr[d + 1]
        tokens[d] = bow.indices[start:stop]
        counts[d] = bow.data[start:stop]
    return tokens, counts


def make_data(alpha, data_path, emb_type=None, N_train=30000, N_test=50, lamb=30,
              topics_path='TopicsMatrix.pkl'):
    """Generate a synthetic topic-model corpus and write it under ``data_path``.

    Loads a sequence of pre-computed topic-word matrices from ``topics_path``
    and selects ``A = topicmat[int(alpha) - 1]``. The topic count K and the
    vocabulary size V are taken from ``A.shape`` (K x V). Documents are then
    sampled from ``A`` (see ``_generate_documents``).

    Args:
        alpha: 1-based index of the topic matrix to use (``int()`` applied).
            Despite the name, it is not used as a Dirichlet parameter here.
        data_path: Output directory; created (including parents) if missing.
        emb_type: If ``'word2vec'``, also sample ``N_train`` documents and
            return them as ``pretrain_docs``. Otherwise none are generated.
        N_train: Number of pretraining documents (word2vec only).
        N_test: Number of test documents; also part of the test file names.
        lamb: Poisson mean of document length (lengths are floored at 4).
        topics_path: Pickle file holding the topic matrices. The default,
            ``'TopicsMatrix.pkl'`` relative to the working directory, keeps
            the original behaviour.

    Files written to ``data_path``:
        - Always: ``topics.pkl`` (A), ``vocab.pkl``, ``id2word.pkl``.
        - Only when the two test ``.mat`` files do not already exist:
          ``test_topics_dist.pkl``, ``bow_test_tokens_{N_test}.mat`` and
          ``bow_test_counts_{N_test}.mat``.

        The test set is cached by ``N_test`` alone. Reusing a ``data_path``
        with a different ``alpha`` keeps the old test documents; delete the
        ``.mat`` files to regenerate them.

    Returns:
        ``(A, id2word, pretrain_docs)``. ``id2word`` maps each word id to its
        string form. ``pretrain_docs`` is ``None`` unless ``emb_type`` is
        ``'word2vec'``.

    Raises:
        IndexError: if ``int(alpha) - 1`` is not a valid index into the
            loaded topic matrices.
    """
    os.makedirs(data_path, exist_ok=True)

    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(topics_path, 'rb') as f:
        topicmat = pickle.load(f)
    idx = int(alpha) - 1
    # Reject alpha <= 0 explicitly; a negative index would silently wrap to
    # the last matrix.
    if not 0 <= idx < len(topicmat):
        raise IndexError('alpha=%r selects no topic matrix (have %d)' % (alpha, len(topicmat)))
    A = topicmat[idx]  # word-distribution-by-topic matrix, size=(K*V)
    K, V = np.shape(A)

    vocab = [str(i) for i in range(V)]
    id2word = {w: str(w) for w in range(V)}

    with open(os.path.join(data_path, 'topics.pkl'), 'wb') as f:
        pickle.dump(A, f)

    pretrain_docs = None
    # Pretraining documents are sampled before the test set, preserving the
    # original RNG draw order.
    if emb_type == 'word2vec':
        _, pretrain_docs = _generate_documents(N_train, A, K, V, lamb)

    # savemat appends '.mat', so the cache check must include the extension
    # (without it the check never matched and data was always regenerated).
    tokens_path = os.path.join(data_path, 'bow_test_tokens_' + str(N_test) + '.mat')
    counts_path = os.path.join(data_path, 'bow_test_counts_' + str(N_test) + '.mat')
    if not (os.path.exists(tokens_path) and os.path.exists(counts_path)):
        test_topic_dist, test_documents = _generate_documents(N_test, A, K, V, lamb)

        print('creating and saving new test documents')
        with open(os.path.join(data_path, 'test_topics_dist.pkl'), 'wb') as f:
            pickle.dump(test_topic_dist, f)

        bow_test = _documents_to_bow(test_documents, V)

        print('splitting bow into token/value pairs and saving to disk...')
        bow_test_tokens, bow_test_counts = _split_bow(bow_test)
        savemat(tokens_path, {'tokens': bow_test_tokens}, do_compression=True)
        savemat(counts_path, {'counts': bow_test_counts}, do_compression=True)

    with open(os.path.join(data_path, 'vocab.pkl'), 'wb') as f:
        pickle.dump(vocab, f)

    with open(os.path.join(data_path, 'id2word.pkl'), 'wb') as f:
        pickle.dump(id2word, f)

    return (A, id2word, pretrain_docs)
