

import gzip
import logging
import multiprocessing
import os
import random
import time

import nltk
from joblib.parallel import Parallel, delayed

from ariel_tests.helpers.utils import setReproducible
from ariel_tests.language.nlp import postprocessSentence, NltkGrammarSampler, Vocabulary
from ariel_tests.language.main_biased import evaluateIndicesBias

# Module-level logger; handlers/levels are configured by the entry point.
logger = logging.getLogger(__name__)

# Directory containing this source file (symlinks resolved by realpath).
CDIR = os.path.dirname(os.path.realpath(__file__))
# Parent directory of CDIR; the trailing path component is discarded.
CDIR_, _ = os.path.split(CDIR)
# Data directory located next to (not inside) the directory of this file.
DATADIR = os.path.join(CDIR_, 'data')
# Output path of the gzip-compressed sentence file written by main().
UNBIASEDGZFILE = os.path.join(DATADIR, 'questions_unbiased.gz')


def generateUniques(grammar, n):
    """Sample ``grammar`` until ``n`` unique sentences passing the bias check are collected.

    Each sampled sentence is post-processed, converted to vocabulary indices
    and passed to ``evaluateIndicesBias``. Only sentences for which that
    check returns a falsy value are kept.

    Args:
        grammar: Grammar object accepted by ``Vocabulary.fromGrammar`` and
            ``NltkGrammarSampler``.
        n: Number of unique accepted sentences to collect.

    Returns:
        A ``set`` of ``n`` unique post-processed sentences (empty when
        ``n <= 0``).

    Note:
        The loop only terminates once ``n`` distinct accepted sentences have
        been sampled; if the grammar cannot produce that many, it runs
        forever.
    """
    vocabulary = Vocabulary.fromGrammar(grammar)
    sentences = set()
    sampler = NltkGrammarSampler(grammar)
    while len(sentences) < n:
        sentence = sampler.generate(1)[0]
        sentence = postprocessSentence(sentence)
        # Do NOT add the sentence before evaluating it: an unconditional add
        # here would let rejected sentences through and defeat the filter.
        indices = vocabulary.sentenceToIndices(sentence)
        is_semantic = evaluateIndicesBias(indices)
        if not is_semantic:
            sentences.add(sentence)
    return sentences


def generateUnbiasedNltk(nbUniqueSamples, batch_size=10000, grammar_filepath='house_questions_grammar.cfg'):
    """Generate up to ``nbUniqueSamples`` unique sentences in parallel.

    The grammar is loaded with ``nltk.data.load`` and sampled by one
    ``generateUniques`` job per CPU core. Rounds of jobs are repeated until
    enough unique sentences have been accumulated or the user interrupts the
    process with Ctrl-C, in which case the sentences gathered so far are
    returned.

    Args:
        nbUniqueSamples: Desired total number of unique sentences.
        batch_size: Number of unique sentences each job produces per round.
        grammar_filepath: Resource name/URL of the grammar. The default
            carries the ``.cfg`` extension because ``nltk.data.load`` infers
            the resource format from the file extension.

    Returns:
        A shuffled ``list`` of at most ``nbUniqueSamples`` unique sentences
        (fewer if generation was interrupted).
    """
    grammar = nltk.data.load(grammar_filepath)

    # Stochastic sampling of the grammar, one job per CPU core.
    nbJobs = multiprocessing.cpu_count()
    sentences = set()
    startTime = time.time()
    try:
        # Reuse the same worker pool across rounds.
        with Parallel(n_jobs=nbJobs) as parallel:
            while len(sentences) < nbUniqueSamples:
                results = parallel(delayed(generateUniques)(grammar, batch_size) for _ in range(nbJobs))
                # Merging into a set removes duplicates produced by different jobs.
                for s in results:
                    sentences |= s

                elapsedTime = time.time() - startTime
                curNbUniqueSamples = min(len(sentences), nbUniqueSamples)
                logger.info(
                    'Number of unique sentences generated: %d (%0.1f %% completed, elapsed %0.1f sec)',
                    curNbUniqueSamples, curNbUniqueSamples / nbUniqueSamples * 100.0, elapsedTime)

    except KeyboardInterrupt:
        logger.info('Interrupted by user.')
        logger.warning('Number of unique sentences generated: %d (incomplete)', len(sentences))

    # Clip to the desired total number of unique sentences. This is done on
    # both the normal and the interrupted path so the result is always a
    # list holding no more than what was requested.
    # NOTE(review): iteration order of a set of strings can vary between
    # interpreter runs unless PYTHONHASHSEED is fixed, so which sentences
    # survive the clip may not be reproducible.
    sentences = list(sentences)[:nbUniqueSamples]

    # Shuffle sentences to make them as iid as possible. The repeat count is
    # kept from the original implementation so that seeded runs consume the
    # random generator identically.
    for _ in range(7):
        random.shuffle(sentences)
    return sentences


def main(grammar_filepath):
    """Create the gzip-compressed sentence file if it is not already on disk.

    When ``UNBIASEDGZFILE`` exists the function returns immediately.
    Otherwise it generates the sentences with ``generateUnbiasedNltk``, logs
    the first ten as a preview and writes one sentence per line (terminated
    by ``'\\r\\n'``) to ``UNBIASEDGZFILE``.

    Args:
        grammar_filepath: Grammar resource forwarded to
            ``generateUnbiasedNltk``.
    """
    # Nothing to do when the dataset has already been generated.
    if os.path.isfile(UNBIASEDGZFILE):
        return

    totalSamples = int(1e4)
    # A job never needs to produce more sentences than requested overall.
    perJobBatch = min(totalSamples, 10000)

    sentences = generateUnbiasedNltk(
        nbUniqueSamples=totalSamples,
        batch_size=perJobBatch,
        grammar_filepath=grammar_filepath)

    # Display a subset of the sentences
    preview = '\n'.join(sentences[:10])
    logger.info('Example sentences: \n' + preview)

    # Write sentences to file, one per line
    with gzip.open(UNBIASEDGZFILE, mode='wt') as out:
        out.writelines(sentence + '\r\n' for sentence in sentences)

    logger.info('All done.')


def run_unbiased_generation(grammar_filepath='house_questions_grammar.cfg'):
    """Prepare the environment and run the sentence generation.

    Ensures the data directory exists, configures logging at INFO level,
    calls ``setReproducible(1)`` and then delegates to ``main``.

    Args:
        grammar_filepath: Grammar resource forwarded to ``main``.
    """
    # exist_ok avoids the check-then-create race of isdir() + mkdir().
    os.makedirs(DATADIR, exist_ok=True)

    logging.basicConfig(level=logging.INFO)
    # Presumably seeds the random generators with 1 — confirm in helpers.utils.
    setReproducible(1)

    main(grammar_filepath)


# Script entry point: run the generation with the default grammar file.
if __name__ == '__main__':
    run_unbiased_generation()
