import numpy as np


def synthetic_documents(topics, lam=30.0, n_docs=100):
    """Sample single-topic documents and hold out the last word as the target.

    Every document is assigned one topic uniformly at random and its words
    are drawn i.i.d. from that topic's row of ``topics``.

    Args:
        topics: 2-D array of shape (K, V); each row is used as the
            probability vector over the V vocabulary indices.
        lam: Rate of the Poisson distribution the document lengths are drawn
            from. Lengths are clamped to a minimum of 4.
        n_docs: Number of documents to generate.

    Returns:
        A tuple ``(syn_tokens, syn_counts, Y)`` where ``syn_tokens[i]`` holds
        the sorted unique word indices of document i (target word excluded),
        ``syn_counts[i]`` holds their occurrence counts, and ``Y`` is an
        array with the held-out last word of each document.
    """
    n_topics, vocab_size = topics.shape

    # The draw order (topics, then lengths, then words per document) is kept
    # fixed so that results under a given global seed stay reproducible.
    assigned_topics = np.random.choice(n_topics, size=n_docs, replace=True)
    lengths = np.maximum(np.random.poisson(lam=lam, size=n_docs), 4)

    syn_tokens = []
    syn_counts = []
    targets = []
    for n_words, topic_id in zip(lengths, assigned_topics):
        words = np.random.choice(
            vocab_size, size=n_words, replace=True, p=topics[topic_id, :]
        )
        # All words but the last form the context; the last one is the label.
        context, held_out = words[:-1], words[-1]
        unique_words, frequencies = np.unique(context, return_counts=True)

        syn_tokens.append(unique_words)
        syn_counts.append(frequencies)
        targets.append(held_out)

    return (syn_tokens, syn_counts, np.array(targets))

def synthetic_multitarget_docs(topics, twords=2, lam=30.0, n_docs=100):
    """Sample single-topic documents and hold out the last ``twords`` words.

    Every document is assigned one topic uniformly at random and its words
    are drawn i.i.d. from that topic's row of ``topics``.

    Args:
        topics: 2-D array of shape (K, V); each row is used as the
            probability vector over the V vocabulary indices.
        twords: Number of trailing words held out as targets. Must be
            non-negative; 0 yields an empty target row per document.
        lam: Rate of the Poisson distribution the document lengths are drawn
            from. Lengths are clamped to a minimum of ``4 + twords`` so at
            least 4 context words always remain.
        n_docs: Number of documents to generate.

    Returns:
        A tuple ``(syn_tokens, syn_counts, Y)`` where ``syn_tokens[i]`` holds
        the sorted unique context word indices of document i,
        ``syn_counts[i]`` holds their occurrence counts, and ``Y`` is an
        array of shape (n_docs, twords) with the held-out words.

    Raises:
        ValueError: If ``twords`` is negative.
    """
    if twords < 0:
        raise ValueError("twords must be non-negative")

    K, V = topics.shape

    syn_tokens, syn_counts, Y = [], [], []

    doc_topics = np.random.choice(K, size=n_docs, replace=True)
    doc_lengths = np.maximum(np.random.poisson(lam=lam, size=n_docs), 4 + twords)

    for length, topic in zip(doc_lengths, doc_topics):
        document = np.random.choice(V, size=length, replace=True, p=topics[topic, :])
        # Split on an explicit index: slicing with ``-twords`` breaks for
        # twords == 0 because ``document[:-0]`` is empty and ``document[-0:]``
        # is the whole document.
        split = length - twords
        tokens, counts = np.unique(document[:split], return_counts=True)

        syn_tokens.append(tokens)
        syn_counts.append(counts)
        Y.append(document[split:])

    return (syn_tokens, syn_counts, np.array(Y))

# generate documents based on the full LDA model
def synthetic_lda_docs(topic_matrix, twords=1, lam=30.0, n_docs=100):
    """Sample documents from an LDA-style generative process.

    Each document gets its own topic mixture drawn from a symmetric Dirichlet
    with concentration ``1 / K`` per component. Every word first draws a
    topic from that mixture and then a vocabulary index from the topic's row
    of ``topic_matrix``. The ``twords`` target words are drawn the same way,
    in addition to the document's own words.

    Args:
        topic_matrix: 2-D array of shape (K, V); each row is used as the
            probability vector over the V vocabulary indices.
        twords: Number of extra target words sampled per document.
        lam: Rate of the Poisson distribution the document lengths are drawn
            from. Lengths are clamped to a minimum of 4.
        n_docs: Number of documents to generate.

    Returns:
        A tuple ``(syn_tokens, syn_counts, Y)`` where ``syn_tokens[i]`` holds
        the sorted unique word indices of document i, ``syn_counts[i]`` holds
        their occurrence counts, and ``Y`` is an array of target words. Each
        target is kept as the length-1 array returned by the sampler, so
        ``Y`` has shape (n_docs, twords, 1).
    """
    n_topics, vocab_size = topic_matrix.shape

    lengths = np.maximum(np.random.poisson(lam=lam, size=n_docs), 4)
    # One mixture per row; each row sums to 1.
    mixtures = np.random.dirichlet(np.ones(n_topics) / n_topics, n_docs)

    syn_tokens = []
    syn_counts = []
    Y = []
    for n_words, mixture in zip(lengths, mixtures):
        # Topic assignment for every word position of this document.
        word_topics = np.random.choice(n_topics, size=n_words, replace=True, p=mixture)
        words = [
            np.random.choice(vocab_size, size=1, p=topic_matrix[z, :])
            for z in word_topics
        ]

        unique_words, frequencies = np.unique(words, return_counts=True)
        syn_tokens.append(unique_words)
        syn_counts.append(frequencies)

        # Targets come from fresh topic draws under the same mixture.
        target_topics = np.random.choice(n_topics, size=twords, replace=True, p=mixture)
        Y.append(
            [
                np.random.choice(vocab_size, size=1, p=topic_matrix[z, :])
                for z in target_topics
            ]
        )

    return (syn_tokens, syn_counts, np.array(Y))

def synthetic_ctm_docs(topic_matrix, sigma, twords=1, lam=30.0, n_docs=100):
    """Sample documents from a correlated-topic-model style process.

    Each document's topic mixture is the softmax of a zero-mean Gaussian
    draw with covariance ``sigma``. Every word first draws a topic from that
    mixture and then a vocabulary index from the topic's row of
    ``topic_matrix``. The ``twords`` target words are drawn the same way, in
    addition to the document's own words.

    Args:
        topic_matrix: 2-D array of shape (K, V); each row is used as the
            probability vector over the V vocabulary indices.
        sigma: (K, K) covariance matrix passed to
            ``np.random.multivariate_normal``.
        twords: Number of extra target words sampled per document.
        lam: Rate of the Poisson distribution the document lengths are drawn
            from. Lengths are clamped to a minimum of 4.
        n_docs: Number of documents to generate.

    Returns:
        A tuple ``(syn_tokens, syn_counts, Y)`` where ``syn_tokens[i]`` holds
        the sorted unique word indices of document i, ``syn_counts[i]`` holds
        their occurrence counts, and ``Y`` is an array of shape
        (n_docs, twords) with the target word indices.
    """
    K, V = topic_matrix.shape

    syn_tokens, syn_counts, Y = [], [], []

    doc_lengths = np.maximum(np.random.poisson(lam=lam, size=n_docs), 4)

    # Unnormalized Gaussian logits, shape (n_docs, K).
    logits = np.random.multivariate_normal(np.zeros(K), sigma, size=n_docs)
    # Softmax is shift-invariant, so subtracting the row maximum first gives
    # the same distribution while keeping exp() from overflowing to inf
    # (which would turn the row into NaNs and make the sampling below fail).
    weights = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
    topic_dist = weights / np.sum(weights, axis=-1, keepdims=True)

    for i in range(n_docs):
        # Topic assignment for every word position of the current document.
        doc_topics = np.random.choice(K, size=doc_lengths[i], replace=True,
                                      p=topic_dist[i, :])
        document = [
            np.random.choice(V, size=1, p=topic_matrix[topic, :]).item()
            for topic in doc_topics
        ]

        tokens, counts = np.unique(document, return_counts=True)
        syn_tokens.append(tokens)
        syn_counts.append(counts)

        # Targets come from fresh topic draws under the same mixture.
        extra_topics = np.random.choice(K, size=twords, replace=True,
                                        p=topic_dist[i, :])
        target_words = [
            np.random.choice(V, size=1, p=topic_matrix[topic, :]).item()
            for topic in extra_topics
        ]

        Y.append(target_words)

    return (syn_tokens, syn_counts, np.array(Y))

def synthetic_ctm_docs_attn(topic_matrix, sigma, twords=1, lam=30.0, n_docs=100):
    """Sample fixed-length, order-preserving documents from a CTM-style process.

    Each document's topic mixture is the softmax of a zero-mean Gaussian
    draw with covariance ``sigma``. Every word first draws a topic from that
    mixture and then a vocabulary index from the topic's row of
    ``topic_matrix``. Unlike the bag-of-words generators, the raw word
    sequence is returned and every document has the same length.

    Args:
        topic_matrix: 2-D array of shape (K, V); each row is used as the
            probability vector over the V vocabulary indices.
        sigma: (K, K) covariance matrix passed to
            ``np.random.multivariate_normal``.
        twords: Number of extra target words sampled per document.
        lam: Exact length of every document. It is converted with ``int()``,
            so a float such as the default ``30.0`` is truncated to 30.
        n_docs: Number of documents to generate.

    Returns:
        A tuple ``(ret_docs, Y)`` where ``ret_docs[i]`` is the list of word
        indices of document i in generation order and ``Y`` is an array of
        shape (n_docs, twords) with the target word indices.

    Raises:
        ValueError: If ``lam`` converts to a negative length.
    """
    K, V = topic_matrix.shape

    ret_docs, Y = [], []

    # A float length (including the default) is rejected by both
    # np.random.choice(size=...) and range(), so coerce it once up front.
    doc_length = int(lam)
    if doc_length < 0:
        raise ValueError("lam must be non-negative")

    # Unnormalized Gaussian logits, shape (n_docs, K).
    logits = np.random.multivariate_normal(np.zeros(K), sigma, size=n_docs)
    # Softmax is shift-invariant, so subtracting the row maximum first gives
    # the same distribution while keeping exp() from overflowing to inf
    # (which would turn the row into NaNs and make the sampling below fail).
    weights = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
    topic_dist = weights / np.sum(weights, axis=-1, keepdims=True)

    for i in range(n_docs):
        # Topic assignment for every word position of the current document.
        doc_topics = np.random.choice(K, size=doc_length, replace=True,
                                      p=topic_dist[i, :])
        document = [
            np.random.choice(V, size=1, p=topic_matrix[topic, :]).item()
            for topic in doc_topics
        ]

        ret_docs.append(document)

        # Targets come from fresh topic draws under the same mixture.
        extra_topics = np.random.choice(K, size=twords, replace=True,
                                        p=topic_dist[i, :])
        target_words = [
            np.random.choice(V, size=1, p=topic_matrix[topic, :]).item()
            for topic in extra_topics
        ]

        Y.append(target_words)

    return (ret_docs, np.array(Y))