import numpy as np
import scipy.sparse
import pickle
import os
import torch


def synthetic_documents(topics, lam=30.0, n_docs=100):
    """Sample bag-of-words documents, each with one held-out target word.

    Every document is assigned a single topic (a row of ``topics``) chosen
    uniformly at random, and its words are drawn i.i.d. from that row. The
    last sampled word is held out as the target; the remaining words are
    summarised as unique token ids with their counts.

    Args:
        topics: 2-D array whose rows are used as the probability vector
            ``p`` of ``np.random.choice`` over the column indices.
        lam: Poisson rate for the raw document length. Lengths are clamped
            to at least 4, so each document keeps at least 3 context words.
        n_docs: Number of documents to generate.

    Returns:
        A tuple ``(syn_tokens, syn_counts, Y)`` where ``syn_tokens`` and
        ``syn_counts`` are lists (one entry per document) of the unique
        context word ids and their counts, and ``Y`` is an array holding
        the target word id of each document.
    """
    n_topics, vocab_size = topics.shape

    # Draw order (topics, then lengths, then words) is kept fixed so that
    # results are reproducible under a given global NumPy seed.
    assigned_topics = np.random.choice(n_topics, size=n_docs, replace=True)
    sizes = np.maximum(np.random.poisson(lam=lam, size=n_docs), 4)

    samples = [
        np.random.choice(vocab_size, size=size, replace=True, p=topics[k, :])
        for size, k in zip(sizes, assigned_topics)
    ]

    # Everything except the final word forms the bag-of-words context.
    bags = [np.unique(words[:-1], return_counts=True) for words in samples]
    syn_tokens = [uniq for uniq, _ in bags]
    syn_counts = [freq for _, freq in bags]
    targets = np.array([words[-1] for words in samples])

    return (syn_tokens, syn_counts, targets)

def synthetic_multitarget_docs(topics, twords=2, lam=30.0, n_docs=100):
    """Sample bag-of-words documents, each with ``twords`` held-out targets.

    Every document is assigned a single topic (a row of ``topics``) chosen
    uniformly at random, and its words are drawn i.i.d. from that row. The
    last ``twords`` sampled words are held out as targets; the remaining
    words are summarised as unique token ids with their counts.

    Args:
        topics: 2-D array whose rows are used as the probability vector
            ``p`` of ``np.random.choice`` over the column indices.
        twords: Number of trailing words held out as targets per document.
            Must be non-negative; ``0`` yields documents with no targets.
        lam: Poisson rate for the raw document length. Lengths are clamped
            to at least ``4 + twords``, so each document keeps at least 4
            context words.
        n_docs: Number of documents to generate.

    Returns:
        A tuple ``(syn_tokens, syn_counts, Y)`` where ``syn_tokens`` and
        ``syn_counts`` are lists (one entry per document) of the unique
        context word ids and their counts, and ``Y`` is an array of shape
        ``(n_docs, twords)`` with the target word ids (an empty 1-D array
        when ``n_docs`` is 0).

    Raises:
        ValueError: If ``twords`` is negative.
    """
    if twords < 0:
        raise ValueError(f"twords must be non-negative, got {twords}")

    K, V = topics.shape

    syn_tokens, syn_counts, Y = [], [], []

    doc_topics = np.random.choice(K, size=n_docs, replace=True)
    doc_lengths = np.maximum(np.random.poisson(lam=lam, size=n_docs), 4 + twords)

    for length, topic in zip(doc_lengths, doc_topics):
        document = np.random.choice(V, size=length, replace=True, p=topics[topic, :])

        # Split at an explicit index rather than with negative slices:
        # ``document[:-0]`` is empty and ``document[-0:]`` is the whole
        # document, which would swap context and targets when twords == 0.
        split = length - twords
        tokens, counts = np.unique(document[:split], return_counts=True)
        target_words = document[split:]

        syn_tokens.append(tokens)
        syn_counts.append(counts)
        Y.append(target_words)

    return (syn_tokens, syn_counts, np.array(Y))

