import os
import numpy as np
from scipy import sparse
from scipy.io import savemat
import pickle

def correlated_supertopic_topic(K=20, K_s=10, dirich_param=30):
    """Build a supertopic-by-topic weight matrix made of adjacent topic pairs.

    One pair of weights is drawn from a symmetric two-component
    Dirichlet(dirich_param, dirich_param) and shared by every supertopic.
    With ``step = int(K / K_s)``, row ``i`` puts the first weight on topic
    ``i * step`` and the second on topic ``i * step + 1``; every other entry
    is zero, so each row sums to one.

    Args:
        K: Total number of topics (number of columns).
        K_s: Number of supertopics; only used to derive the stride
            ``int(K / K_s)`` between consecutive pairs.
        dirich_param: Concentration used for both Dirichlet components.

    Returns:
        A float ndarray of shape ``(n, K)`` with one row per start index in
        ``range(0, K - 1, step)``. When ``K`` is ``2 * K_s`` or
        ``2 * K_s + 1`` this gives exactly ``K_s`` rows.

    Raises:
        ValueError: If ``int(K / K_s)`` is smaller than 1, i.e. there are
            not enough topics to space the pairs out.
    """
    pair = np.random.dirichlet([dirich_param, dirich_param])

    step = int(K / K_s)
    if step < 1:
        raise ValueError(
            "K=%r and K_s=%r give a stride of %d; need int(K / K_s) >= 1"
            % (K, K_s, step)
        )

    # Stop at K - 1 (not K): a pair starting on the very last topic would
    # need column K, which does not exist. For every input where all pairs
    # fit, this yields the same start indices as range(0, K, step).
    starts = np.arange(0, K - 1, step)
    rows = np.arange(len(starts))

    weights = np.zeros((len(starts), K))
    weights[rows, starts] = pair[0]
    weights[rows, starts + 1] = pair[1]
    return weights

def make_data(alpha, data_path, N_test=50, lamb=30, gen_new_tests=False):
    """Generate a synthetic test corpus from a stored topic matrix and save it.

    Loads ``TopicMatrices.npy`` from the current working directory, selects
    the ``alpha``-th matrix (1-based), and writes the following files into
    ``data_path``:

    * ``topics.pkl``  - the selected topic matrix ``A``
    * ``vocab.pkl``   - list of vocabulary strings ``'0' .. str(V - 1)``
    * ``id2word.pkl`` - dict mapping word id -> the same string
    * ``test_topics_dist.pkl``, ``bow_test_tokens_<N_test>.mat`` and
      ``bow_test_counts_<N_test>.mat`` - only when test documents are
      (re)generated

    Test documents are regenerated when ``gen_new_tests`` is true or when
    either bag-of-words ``.mat`` file for this ``N_test`` is missing.

    Args:
        alpha: 1-based index into the first axis of ``TopicMatrices.npy``
            (converted with ``int``).
        data_path: Output directory; created (with parents) if missing.
        N_test: Number of test documents to generate.
        lamb: Number of word tokens drawn for every document.
        gen_new_tests: Force regeneration even if saved files exist.

    Returns:
        Tuple ``(A, id2word)``: the selected topic matrix and the id->word dict.
    """
    # makedirs (unlike mkdir) also creates missing parent directories.
    os.makedirs(data_path, exist_ok=True)

    ## generate topics
    topicmat = np.load('TopicMatrices.npy')
    A = topicmat[int(alpha) - 1]  # topic-by-word matrix, shape (K, V)
    K, V = A.shape
    vocab = [str(i) for i in range(V)]
    id2word = {w: str(w) for w in range(V)}

    def generate_documents(num, K_s=10):
        """Sample ``num`` documents of ``lamb`` word ids each.

        Returns ``(doc_topics, documents)``: per-document topic mixtures and
        the sampled word-id arrays.
        """
        doc_topics = []
        documents = []
        for _ in range(num):
            # Sparse mixture over supertopics, then expand it to topics
            # through a freshly sampled supertopic-by-topic matrix.
            super_topics = np.random.dirichlet((1 / K_s) * np.ones(K_s))
            supertopic_topics = correlated_supertopic_topic(K, K_s, 30)
            topic_dist = super_topics.dot(supertopic_topics)

            # topic_dist.dot(A) is the document's distribution over words.
            doc = np.random.choice(V, p=topic_dist.dot(A), size=lamb)
            doc_topics.append(topic_dist)
            documents.append(doc)
        return (doc_topics, documents)

    with open(os.path.join(data_path, 'topics.pkl'), 'wb') as f:
        pickle.dump(A, f)

    ## Documents to test on
    tokens_name = 'bow_test_tokens_' + str(N_test)
    counts_name = 'bow_test_counts_' + str(N_test)

    def already_saved(name):
        # savemat appends '.mat' to the file name by default, so that is the
        # name found on disk; the bare name is still accepted as well.
        base = os.path.join(data_path, name)
        return os.path.exists(base + '.mat') or os.path.exists(base)

    make_new = bool(gen_new_tests) or not (
        already_saved(tokens_name) and already_saved(counts_name)
    )

    if make_new:
        test_topic_dist, test_documents = generate_documents(N_test, K_s=K // 2)

        print('creating and saving new test documents')
        with open(os.path.join(data_path, 'test_topics_dist.pkl'), 'wb') as f:
            pickle.dump(test_topic_dist, f)

        indexed_test = [[int(word) for word in doc] for doc in test_documents]
        n_docs_test = len(indexed_test)

        # Flatten to parallel (document index, word id) lists, one entry per token.
        words_test = [word for doc in indexed_test for word in doc]
        doc_indices_test = [
            doc_id for doc_id, doc in enumerate(indexed_test) for _ in doc
        ]

        # Converting COO -> CSR sums duplicate (doc, word) entries, turning
        # the per-token ones into per-document word counts.
        bow_test = sparse.coo_matrix(
            ([1] * len(doc_indices_test), (doc_indices_test, words_test)),
            shape=(n_docs_test, V),
        ).tocsr()

        print('splitting bow into token/value pairs and saving to disk...')

        # Row d of a CSR matrix occupies indptr[d]:indptr[d + 1] in
        # .indices (word ids) and .data (counts).
        row_bounds = list(zip(bow_test.indptr[:-1], bow_test.indptr[1:]))
        bow_test_tokens = [list(bow_test.indices[lo:hi]) for lo, hi in row_bounds]
        bow_test_counts = [list(bow_test.data[lo:hi]) for lo, hi in row_bounds]

        savemat(os.path.join(data_path, tokens_name), {'tokens': bow_test_tokens}, do_compression=True)
        savemat(os.path.join(data_path, counts_name), {'counts': bow_test_counts}, do_compression=True)

    with open(os.path.join(data_path, 'vocab.pkl'), 'wb') as f:
        pickle.dump(vocab, f)

    with open(os.path.join(data_path, 'id2word.pkl'), 'wb') as f:
        pickle.dump(id2word, f)

    return (A, id2word)
