import numpy as np


def synthetic_documents(topics, lam=30.0, n_docs=100):
    """Sample single-topic documents, holding out each one's final word.

    Every document is assigned one topic uniformly at random. Its length is
    Poisson(lam), clipped below at 4. All of its words are drawn i.i.d. from
    that topic's row of ``topics``. The last sampled word becomes the
    document's target; the remaining words form its bag of words.
    Randomness comes from the global ``np.random`` state.

    Parameters
    ----------
    topics : np.ndarray, shape (K, V)
        Row k is used as the word distribution (``p``) of topic k.
    lam : float
        Poisson mean of the document length (target word included).
    n_docs : int
        Number of documents to generate.

    Returns
    -------
    tuple
        ``(tokens, counts, targets)``:
        - tokens: per-document arrays of sorted unique word ids.
        - counts: per-document arrays of the matching word counts.
        - targets: array holding one target word id per document.
    """
    n_topics, vocab_size = topics.shape

    topic_of_doc = np.random.choice(n_topics, size=n_docs, replace=True)
    lengths = np.maximum(np.random.poisson(lam=lam, size=n_docs), 4)

    all_tokens = []
    all_counts = []
    targets = []
    for idx in range(n_docs):
        words = np.random.choice(
            vocab_size, size=lengths[idx], replace=True, p=topics[topic_of_doc[idx], :]
        )
        context, target = words[:-1], words[-1]
        uniq, freq = np.unique(context, return_counts=True)
        all_tokens.append(uniq)
        all_counts.append(freq)
        targets.append(target)

    return all_tokens, all_counts, np.array(targets)

def synthetic_multitarget_docs(topics, twords=2, lam=30.0, n_docs=100):
    """Sample single-topic documents, holding out the last ``twords`` words.

    Every document is assigned one topic uniformly at random. Its length is
    Poisson(lam), clipped below at ``4 + twords``, so at least 4 context
    words remain. All words are drawn i.i.d. from the topic's row of
    ``topics``. The final ``twords`` sampled words are the targets; the rest
    form the document's bag of words. Randomness comes from the global
    ``np.random`` state.

    Parameters
    ----------
    topics : np.ndarray, shape (K, V)
        Row k is used as the word distribution (``p``) of topic k.
    twords : int
        Number of held-out target words per document (may be 0).
    lam : float
        Poisson mean of the document length (targets included).
    n_docs : int
        Number of documents to generate.

    Returns
    -------
    tuple
        ``(syn_tokens, syn_counts, Y)``:
        - syn_tokens: per-document arrays of sorted unique context word ids.
        - syn_counts: per-document arrays of the matching counts.
        - Y: array of shape ``(n_docs, twords)`` with the target word ids.

    Raises
    ------
    ValueError
        If ``twords`` is negative.
    """
    if twords < 0:
        raise ValueError("twords must be non-negative, got {}".format(twords))
    K, V = topics.shape

    syn_tokens, syn_counts, Y = [], [], []

    doc_topics = np.random.choice(K, size=n_docs, replace=True)
    # Reserve room for the targets on top of the 4-word minimum context.
    doc_lengths = np.maximum(np.random.poisson(lam=lam, size=n_docs), 4 + twords)

    for length, topic in zip(doc_lengths, doc_topics):
        document = np.random.choice(V, size=length, replace=True, p=topics[topic, :])
        # Split at an explicit index instead of slicing with -twords. With
        # twords == 0, document[:-0] would be empty and document[-0:] would
        # be the whole document.
        split = length - twords
        tokens, counts = np.unique(document[:split], return_counts=True)
        target_words = document[split:]

        syn_tokens.append(tokens)
        syn_counts.append(counts)
        Y.append(target_words)

    return (syn_tokens, syn_counts, np.array(Y))

# generate documents based on the full LDA model
def synthetic_lda_docs(topic_matrix, twords=1, lam=30.0, n_docs=100):
    """Sample bag-of-words documents from a full LDA generative process.

    For each document:
    - Topic proportions are drawn from a symmetric Dirichlet with
      concentration 1/K per topic.
    - The length is Poisson(lam), clipped below at 4.
    - Each word first picks a topic from those proportions, then a word from
      that topic's row of ``topic_matrix``.
    - ``twords`` extra target words are drawn the same way, but are not part
      of the bag of words.

    Randomness comes from the global ``np.random`` state.

    Parameters
    ----------
    topic_matrix : np.ndarray, shape (K, V)
        Row k is used as the word distribution (``p``) of topic k.
    twords : int
        Number of extra target words per document.
    lam : float
        Poisson mean of the document length.
    n_docs : int
        Number of documents to generate.

    Returns
    -------
    tuple
        ``(tokens, counts, Y)``:
        - tokens: per-document arrays of sorted unique word ids.
        - counts: per-document arrays of the matching counts.
        - Y: array of the target word ids. Each target is sampled with
          ``size=1``, so Y has shape ``(n_docs, twords, 1)``.
    """
    n_topics, vocab_size = topic_matrix.shape

    lengths = np.maximum(np.random.poisson(lam=lam, size=n_docs), 4)
    # theta[i] is document i's topic proportions; each row sums to 1.
    theta = np.random.dirichlet(np.ones(n_topics) / n_topics, n_docs)

    token_lists, count_lists, targets = [], [], []
    for length, doc_theta in zip(lengths, theta):
        word_topics = np.random.choice(n_topics, size=length, replace=True, p=doc_theta)
        words = [np.random.choice(vocab_size, size=1, p=topic_matrix[z, :]) for z in word_topics]

        uniq, freq = np.unique(words, return_counts=True)
        token_lists.append(uniq)
        count_lists.append(freq)

        held_out_topics = np.random.choice(n_topics, size=twords, replace=True, p=doc_theta)
        targets.append(
            [np.random.choice(vocab_size, size=1, p=topic_matrix[z, :]) for z in held_out_topics]
        )

    return token_lists, count_lists, np.array(targets)

def synthetic_ctm_docs(topic_matrix, sigma, twords=1, lam=30.0, n_docs=100):
    """Sample bag-of-words documents from a correlated topic model (CTM).

    Each document's topic proportions are logistic-normal: a draw
    ``eta ~ N(0, sigma)`` passed through a softmax. Document length is
    Poisson(lam), clipped below at 4. Each word picks a topic from those
    proportions and then a word from that topic's row of ``topic_matrix``.
    ``twords`` extra target words are drawn the same way but are not part
    of the bag of words. Randomness comes from the global ``np.random``
    state.

    Parameters
    ----------
    topic_matrix : np.ndarray, shape (K, V)
        Row k is used as the word distribution (``p``) of topic k.
    sigma : array_like, shape (K, K)
        Covariance of the Gaussian draw behind the topic proportions.
    twords : int
        Number of extra target words per document.
    lam : float
        Poisson mean of the document length.
    n_docs : int
        Number of documents to generate.

    Returns
    -------
    tuple
        ``(syn_tokens, syn_counts, Y)``:
        - syn_tokens: per-document arrays of sorted unique word ids.
        - syn_counts: per-document arrays of the matching counts.
        - Y: array of shape ``(n_docs, twords)`` with the target word ids.
    """
    K, V = topic_matrix.shape

    syn_tokens, syn_counts, Y = [], [], []

    doc_lengths = np.maximum(np.random.poisson(lam=lam, size=n_docs), 4)

    # Logistic-normal topic proportions: row-wise softmax of a Gaussian draw.
    # Subtracting each row's max before exp leaves the softmax unchanged. It
    # prevents overflow to inf (or underflow of every entry to 0) when sigma
    # has large variances; otherwise the row would become NaN and break
    # np.random.choice.
    eta = np.random.multivariate_normal(np.zeros(K), sigma, size=n_docs)
    eta = eta - eta.max(axis=-1, keepdims=True)
    weights = np.exp(eta)
    topic_dist = weights / weights.sum(axis=-1, keepdims=True)  # rows sum to 1

    for i in range(n_docs):
        document = []
        # Topic index for every word of the current document.
        doc_topics = np.random.choice(K, size=doc_lengths[i], replace=True,
                                      p=topic_dist[i, :])
        for j in range(doc_lengths[i]):
            document.append(np.random.choice(V, size=1, p=topic_matrix[doc_topics[j], :]).item())

        tokens, counts = np.unique(document, return_counts=True)
        syn_tokens.append(tokens)
        syn_counts.append(counts)

        # Topic index for each of the twords extra target words.
        extra_topics = np.random.choice(K, size=twords, replace=True,
                                        p=topic_dist[i, :])
        target_words = []
        for t in range(twords):
            target_words.append(np.random.choice(V, size=1, p=topic_matrix[extra_topics[t], :]).item())

        Y.append(target_words)

    return (syn_tokens, syn_counts, np.array(Y))

def synthetic_ctm_docs_attn(topic_matrix, twords=1, lam=30.0, n_docs=100):
    """Sample fixed-length word sequences from a two-level supertopic model.

    For each document:
    - Supertopic weights are drawn from a symmetric Dirichlet over
      ``K_s = int(K/2)`` supertopics, with concentration 1/K_s each.
    - A fresh supertopic-to-topic matrix is drawn via
      ``correlated_supertopic_topic(K, K_s, 30)``. The two are combined into
      the document's topic distribution.
    - Every word picks a topic from that distribution, then a word from the
      topic's row of ``topic_matrix``.
    - ``twords`` extra target words are drawn the same way.

    Randomness comes from the global ``np.random`` state.

    Parameters
    ----------
    topic_matrix : np.ndarray, shape (K, V)
        Row k is used as the word distribution (``p``) of topic k.
    twords : int
        Number of extra target words per document.
    lam : int or float
        Exact length of every document. A float is truncated with ``int()``.
    n_docs : int
        Number of documents to generate.

    Returns
    -------
    tuple
        ``(ret_docs, Y)``:
        - ret_docs: list of per-document word-id lists, in sampling order
          (not bag-of-words).
        - Y: array of shape ``(n_docs, twords)`` with the target word ids.
    """
    K, V = topic_matrix.shape
    K_s = int(K/2)

    ret_docs, Y = [], []

    # Every document has the same length. np.random.choice(size=...) and
    # range() both reject floats, so the default lam=30.0 must be converted.
    doc_length = int(lam)

    for _ in range(n_docs):
        super_topics = np.random.dirichlet((1/K_s)*np.ones(K_s))
        supertopic_topics = correlated_supertopic_topic(K, K_s, 30)
        topic_dist = super_topics.dot(supertopic_topics)

        doc_topics = np.random.choice(K, size=doc_length, replace=True, p=topic_dist)
        document = [np.random.choice(V, size=1, p=topic_matrix[z, :]).item() for z in doc_topics]
        ret_docs.append(document)

        # Topic index for each of the twords extra target words.
        extra_topics = np.random.choice(K, size=twords, replace=True, p=topic_dist)
        target_words = [np.random.choice(V, size=1, p=topic_matrix[z, :]).item() for z in extra_topics]
        Y.append(target_words)

    return (ret_docs, np.array(Y))


def correlated_supertopic_topic(K=20, K_s=10, dirich_param=30):
    """Build a supertopic-to-topic weight matrix with paired, correlated topics.

    The K topics are split into K_s consecutive groups of size
    ``step = K // K_s``. Supertopic r puts all of its mass on the first two
    topics of its group, ``r*step`` and ``r*step + 1``. That mass comes from
    a single Dirichlet([dirich_param, dirich_param]) draw, and the same two
    weights are shared by every supertopic. Randomness comes from the global
    ``np.random`` state.

    Parameters
    ----------
    K : int
        Number of topics (columns).
    K_s : int
        Number of supertopics (rows). K must be a multiple of K_s with
        ``K // K_s >= 2``.
    dirich_param : float
        Concentration of both entries of the shared Dirichlet draw.

    Returns
    -------
    np.ndarray, shape (K_s, K)
        Each row has exactly two non-zero entries and sums to 1.

    Raises
    ------
    ValueError
        If K_s is not positive, K is not a multiple of K_s, or a group would
        hold fewer than two topics. Such inputs previously overflowed the
        row or produced the wrong number of rows.
    """
    if K_s <= 0 or K % K_s != 0 or K // K_s < 2:
        raise ValueError(
            "K must be a multiple of K_s with at least 2 topics per supertopic "
            "(got K={}, K_s={})".format(K, K_s)
        )
    step = int(K // K_s)

    dirich_sample = np.random.dirichlet([dirich_param, dirich_param])

    # First topic index of each supertopic's group.
    starts = np.arange(0, K, step)
    rows = np.arange(starts.size)
    supertopic_topics = np.zeros((starts.size, K))
    supertopic_topics[rows, starts] = dirich_sample[0]
    supertopic_topics[rows, starts + 1] = dirich_sample[1]

    return supertopic_topics