

import functools
import gzip
import logging
import multiprocessing
import os
import random
import time

import nltk
import numpy as np
from joblib.parallel import Parallel, delayed
from tqdm import tqdm

from ariel_tests.helpers.utils import setReproducible
from ariel_tests.language.nlp import postprocessSentence, NltkGrammarSampler, Vocabulary

logger = logging.getLogger(__name__)

# Absolute directory holding this module, and that directory's parent.
CDIR = os.path.dirname(os.path.realpath(__file__))
CDIR_ = os.path.dirname(CDIR)

# Generated artefacts live in a 'data' directory that is a sibling of CDIR.
DATADIR = os.path.join(CDIR_, 'data')
# Word co-occurrence matrix written by generate_bias().
BIASMATFILE = os.path.join(DATADIR, 'bias_matrix.npy')
# Gzipped sentence files, one per split, in train / test / val order.
BIASEDGZFILES = [os.path.join(DATADIR, 'questions_biased_' + split + '.gz')
                 for split in ('train', 'test', 'val')]


@functools.lru_cache(maxsize=8)
def _loadBiasMatrix(path):
    """Load the bias matrix stored at ``path``, reusing it for later calls in this process.

    evaluateIndicesBias is called once per sampled sentence, so without this the
    ``.npy`` file would be re-read from disk for every sentence. The cache is keyed
    on the path only: a file rewritten in place within the same process is not reloaded.
    """
    return np.load(path)


def evaluateIndicesBias(indices, adjacency=None):
    """Score a sentence, given as token indices, against the bias matrix.

    The score is the product of ``A[np.ix_(indices, indices)]``, i.e. of the
    matrix entries for every ordered pair of tokens in the sentence (each token
    paired with itself included). For a 0/1 matrix this is 1 only when every
    pair is allowed, and 0 otherwise.

    :param indices: sequence of integer token indices.
    :param adjacency: optional 2-D array-like matrix to score against. When
        ``None``, the matrix saved at ``BIASMATFILE`` is used.
    :return: the product of the sub-matrix entries (a numpy scalar), or ``1``
        when ``indices`` cannot index the matrix.
    :raises FileNotFoundError: if ``adjacency`` is ``None`` and ``BIASMATFILE``
        does not exist.
    """
    # ``is None`` rather than ``== None``: comparing an ndarray to None is
    # element-wise and cannot be used as a condition. A caller-supplied matrix
    # is used as given.
    A = _loadBiasMatrix(BIASMATFILE) if adjacency is None else np.asarray(adjacency)

    try:
        sentence_submatrix = A[np.ix_(indices, indices)]
    except (IndexError, ValueError):
        # The matrix only covers the HQ dataset's vocabulary. Indices it cannot
        # address (out of range / not 1-D) are not scored; the sentence is accepted.
        return 1
    return np.prod(sentence_submatrix)


def generateUniques(grammar, n):
    """Sample from ``grammar`` until ``n`` distinct sentences pass the bias filter.

    Each sample is post-processed, converted to vocabulary indices and kept only
    if evaluateIndicesBias (with its default matrix) scores it as truthy.
    Note: loops forever if the grammar cannot yield ``n`` accepted sentences.

    :param grammar: grammar used to build both the vocabulary and the sampler.
    :param n: number of unique accepted sentences to collect.
    :return: set of post-processed sentences.
    """
    vocab = Vocabulary.fromGrammar(grammar)
    sampler = NltkGrammarSampler(grammar)
    accepted = set()
    while len(accepted) < n:
        candidate = postprocessSentence(sampler.generate(1)[0])
        if evaluateIndicesBias(vocab.sentenceToIndices(candidate)):
            accepted.add(candidate)
    return accepted


def test_semanticity():
    """Smoke check: score a fixed index sequence against the saved bias matrix and print it."""
    print(evaluateIndicesBias([1, 1, 2, 3, 0, 0]))


def getComposedConcepts(grammar_filepath):
    """Return the multi-word, terminal-only productions of a grammar.

    A production counts as a composed concept when its right-hand side consists
    only of words (``str`` terminals) and has more than one of them, e.g. a
    rule producing ``'article' 'of' 'clothing'``.

    :param grammar_filepath: path of the grammar file, loaded through
        ``nltk.data.load`` with a ``file:`` URL.
    :return: list of word lists, one per qualifying production, in the order
        returned by ``grammar.productions()``.
    """
    grammar = nltk.data.load('file:' + grammar_filepath)

    composedConcepts = []
    for production in grammar.productions():
        # Use the public ``rhs()`` accessor instead of the private ``_rhs`` field.
        # Terminals are plain strings; any other symbol is treated as a non-terminal.
        rhs = production.rhs()
        if len(rhs) > 1 and all(isinstance(symbol, str) for symbol in rhs):
            composedConcepts.append(list(rhs))

    return composedConcepts


def generate_bias(grammar_filepath):
    """Build and save the word co-occurrence ("bias") matrix, unless it already exists.

    Sentences are later filtered so that some words may appear together while
    others may not. This mimics how objects are found in some places and never
    in others (a toaster is unlikely to be found on the bed).

    Matrix construction:
      * each row gets its own random density p, and its entries are drawn as 0/1 with P(1) = p;
      * columns of a fixed list of function/special tokens are set to 1;
      * every pair of words inside a composed concept (see getComposedConcepts) is set to 1;
      * the result is symmetrised.

    The matrix is written to ``BIASMATFILE``; nothing happens if that file exists.

    :param grammar_filepath: path of the grammar file used for the vocabulary and concepts.
    """
    if os.path.isfile(BIASMATFILE):
        return

    vocabulary = Vocabulary.fromGrammarFile(grammar_filepath)
    size = vocabulary.getMaxVocabularySize()

    def random_row():
        # Draw the density first, then the row, so the RNG call order matches row-by-row sampling.
        density = np.random.rand()
        return np.random.choice(2, size=size, replace=True, p=[1 - density, density])

    A = np.array([random_row() for _ in range(size)])

    # these tokens can be in any sentence (duplicates in the list are harmless)
    specialTokens = ['a', 'an', 'that', 'the', 'this', ',', 'and', 'or', 'made', 'of',
                     'and', '!', '.', '?', 'is', 'it', 'the', 'object', 'thing', ]
    A[:, [vocabulary.indicesByTokens[token] for token in specialTokens]] = 1

    # Word sequences such as 'article' 'of' 'clothing' or 'audio' 'system' form one
    # concept, so all of their words must be allowed together.
    for concept in getComposedConcepts(grammar_filepath):
        members = [vocabulary.indicesByTokens[word] for word in concept]
        A[np.ix_(members, members)] = 1

    # NOTE(review): the diagonal is only forced to 1 for special tokens and concept
    # words. evaluateIndicesBias multiplies over the sub-matrix including A[i, i], so a
    # word whose A[i, i] is 0 can never be accepted. Confirm this is intended.

    # Symmetrise: entries are 0/1, so the element-wise max equals clip(A + A.T, 0, 1).
    A = np.maximum(A, A.T)

    np.save(BIASMATFILE, A)


def generateBiasedNltk(nbUniqueSamples, batch_size=10000, grammar_filepath='house_questions_grammar.cfg'):
    """Sample ``nbUniqueSamples`` unique, bias-filtered sentences from a grammar.

    Each round runs one generateUniques job per CPU core, each producing
    ``batch_size`` unique sentences. The batches are merged, and rounds repeat
    until enough unique sentences exist. Ctrl-C stops sampling early and the
    sentences gathered so far are returned.

    :param nbUniqueSamples: number of unique sentences to return.
    :param batch_size: number of unique sentences each job produces per round.
    :param grammar_filepath: path of the grammar file, loaded through
        ``nltk.data.load`` with a ``file:`` URL.
    :return: shuffled list of at most ``nbUniqueSamples`` sentences (fewer if interrupted).
    """
    grammar = nltk.data.load('file:' + grammar_filepath)

    # Stochastic sampling of the grammar
    nbJobs = multiprocessing.cpu_count()
    sentences = set()
    startTime = time.time()
    try:
        with Parallel(n_jobs=nbJobs) as parallel:
            while len(sentences) < nbUniqueSamples:
                results = parallel(delayed(generateUniques)(grammar, batch_size) for _ in range(nbJobs))
                for batch in tqdm(results):
                    sentences |= batch

                elapsedTime = time.time() - startTime
                curNbUniqueSamples = min(len(sentences), nbUniqueSamples)
                logger.info('Number of unique sentences generated: %d (%0.1f %% completed, elapsed %0.1f sec)',
                            curNbUniqueSamples, curNbUniqueSamples / nbUniqueSamples * 100.0, elapsedTime)

        # Sort before clipping. Set iteration order for strings depends on hash
        # randomisation, so slicing list(set) would keep a different subset on each run
        # even with seeded RNGs.
        sentences = sorted(sentences)[:nbUniqueSamples]

    except KeyboardInterrupt:
        logger.info('Interrupted by user.')
        if isinstance(sentences, set):
            sentences = sorted(sentences)
        logger.warning('Number of unique sentences generated: %d (incomplete)', len(sentences))

    # Shuffle to make the order as iid as possible. One shuffle already yields a
    # uniformly random permutation; repeating it adds nothing.
    random.shuffle(sentences)
    return sentences


def main(grammar_filepath, n_train=int(1e6), n_test=int(1e4), n_val=512):
    """Generate the biased train/test/val question files unless they all already exist.

    Sentences come from generateBiasedNltk and are split in order into
    ``n_train`` / ``n_test`` / ``n_val`` sentences. Each split is written to
    ``DATADIR/questions_biased_<split>.gz``, one sentence per line, with
    ``'\\r\\n'`` line endings.

    :param grammar_filepath: path of the grammar file to sample from.
    :param n_train: number of training sentences.
    :param n_test: number of test sentences.
    :param n_val: number of validation sentences.
    """
    if all(os.path.isfile(path) for path in BIASEDGZFILES):
        return

    nbUniqueSamples = n_train + n_test + n_val
    batch_size = min(10000, nbUniqueSamples)
    sentences = generateBiasedNltk(
        nbUniqueSamples=nbUniqueSamples,
        batch_size=batch_size,
        grammar_filepath=grammar_filepath)

    if len(sentences) < nbUniqueSamples:
        # Files are still written. Because they will then exist, later runs will
        # not regenerate them; delete them to retry.
        logger.warning('Only %d of %d requested sentences were generated; later splits will be short or empty.',
                       len(sentences), nbUniqueSamples)

    # Display a subset of the sentences
    logger.info('Example sentences: \n%s', '\n'.join(sentences[:10]))

    train_test_val = {'train': sentences[:n_train],
                      'test': sentences[n_train:n_train + n_test],
                      'val': sentences[n_train + n_test:]}
    # Write sentences to file
    for key, split_sentences in train_test_val.items():
        datasetFilename = os.path.join(DATADIR, 'questions_biased_' + key + '.gz')
        # newline='' disables newline translation, so '\r\n' is written verbatim on every
        # platform. Otherwise text mode would turn it into '\r\r\n' on Windows.
        with gzip.open(datasetFilename, mode='wt', newline='') as f:
            for sentence in tqdm(split_sentences):
                f.write(sentence + '\r\n')

    logger.info('All done.')


def run_biased_generation(grammar_filepath='house_questions_grammar.cfg'):
    """Create DATADIR if needed, then build the bias matrix and the biased datasets.

    Also configures INFO-level logging and calls ``setReproducible(1)``
    (presumably seeds the RNGs; confirm against its definition).

    :param grammar_filepath: grammar file. Relative paths are resolved against
        this module's directory (CDIR); absolute paths are used as-is.
    """
    # exist_ok avoids the check-then-create race of isdir() followed by mkdir().
    os.makedirs(DATADIR, exist_ok=True)

    logging.basicConfig(level=logging.INFO)
    setReproducible(1)

    grammar_filepath = os.path.join(CDIR, grammar_filepath)
    generate_bias(grammar_filepath)
    main(grammar_filepath)


if __name__ == '__main__':
    # Running the module as a script builds the bias matrix and datasets from the default grammar.
    run_biased_generation(grammar_filepath='house_questions_grammar.cfg')
