

import logging
import time

import nltk
import numpy as np
from tqdm import tqdm

from ariel_tests.language.nlp import tokenize
from ariel_tests.language.main_biased import evaluateIndicesBias
from ariel_tests.language.utils import get_encoded_sentences, get_samples_and_volume, \
    parse2grammarLabel
from ariel_tests.test_analysis.disentanglement import disentanglement_measure

logger = logging.getLogger(__name__)


class ReconstructionAndGeneralizationStudy(object):
    """Reconstruction study over the biased and the unbiased test sentences.

    All the work is done by the constructor: every test sentence is encoded
    and decoded again, the reconstructions are scored, the raw text is dumped
    under ``<experiment_folder>/text/`` and the resulting ratios are written
    both into ``config['results']['reconstruction']`` and back to
    ``DataClass`` through ``DataClass.get_data(name, value)``.
    """

    def __init__(self, DataClass, config):
        """Run the study and record its metrics.

        Args:
            DataClass: experiment container. The attributes read here are
                ``biasedSentences_Test``, ``unbiasedSentences_Test``,
                ``encoder``, ``decoder``, ``vocabulary``, ``grammar`` and
                ``experiment_folder``.
            config: mapping that already holds a ``'results'`` entry;
                ``config['results']['reconstruction']`` is replaced by a
                fresh dict that receives the metrics.
        """

        ############################################################
        #     Tests: Model for Reconstruction
        ############################################################

        self.config = config
        self.config['results']['reconstruction'] = {}
        results = self.config['results']['reconstruction']

        # Test semantic correctness, coverage and grammatical correctness over biased dataset
        logger.info(
            "\n\nTesting:\n"
            "        - latent Hypercube\n"
            "        - semantic correctness, \n"
            "        - coverage \n"
            "        - grammatical correctness \n"
            "        over biased dataset"
        )

        text_dir = DataClass.experiment_folder + '/text/'
        biased_sentences = DataClass.biasedSentences_Test

        # sometimes we explore concatenations of states, so
        # DataClass.latDim won't be enough: measure the size on a real encoding
        example_sentence = biased_sentences[0]
        latDim_study = DataClass.encoder.encode(example_sentence).shape[0]

        # Running bounds of the latent hypercube. Starting from +/-inf (rather
        # than an arbitrary finite constant) keeps the bounds correct whatever
        # the scale of the latent coordinates.
        zmin = np.full(latDim_study, np.inf)
        zmax = np.full(latDim_study, -np.inf)

        parser = nltk.ChartParser(DataClass.grammar)

        nbSuccess_semantic = 0
        nbSuccess_coverage = 0
        nbSuccess_grammatical = 0
        encoded_sentences = []

        # save generated samples
        with open(text_dir + 'biased_reconstruction.txt', 'w') as out:

            for sentence in tqdm(biased_sentences):

                z = DataClass.encoder.encode(sentence)
                encoded_sentences.append(z)
                reconstruction = DataClass.decoder.decode(z)

                # presumably a 0/1 (or bool) verdict per sentence -- it is
                # summed and later divided by the number of sentences
                nbSuccess_semantic += evaluateIndicesBias(
                    DataClass.vocabulary.sentenceToIndices(reconstruction))

                # coverage: the sentence is reproduced exactly
                if reconstruction == sentence:
                    nbSuccess_coverage += 1

                # grammatical: the grammar yields at least one parse tree
                if list(parser.parse(tokenize(reconstruction))):
                    nbSuccess_grammatical += 1

                zmin = np.minimum(zmin, z)
                zmax = np.maximum(zmax, z)

                out.write('\n\n')
                out.write(sentence + '\n')
                out.write(reconstruction)

        np.savez(text_dir + 'biased_encoded', encoded_sentences)

        nb_biased = len(biased_sentences)

        semantic = nbSuccess_semantic / nb_biased
        logger.info('\n\nSemantic correctness ratio over biased dataset: %0.3f', semantic)
        results['semantic'] = semantic

        coverageBiased = nbSuccess_coverage / nb_biased
        logger.info('\n\nCoverage ratio over biased dataset: %0.3f', coverageBiased)
        results['coverage-biased'] = coverageBiased

        grammaticalReconstruction = nbSuccess_grammatical / nb_biased
        logger.info('\n\nGrammatical correctness ratio over biased dataset: %0.3f', grammaticalReconstruction)
        results['grammatical'] = grammaticalReconstruction

        np.save(text_dir + 'hypercube-zmin.npy', zmin)
        np.save(text_dir + 'hypercube-zmax.npy', zmax)

        # Test coverage over unbiased dataset
        logger.info('\n\nTesting coverage over unbiased dataset')
        unbiased_sentences = DataClass.unbiasedSentences_Test
        nbSuccess = 0
        nbSuccess_semantic = 0

        # save generated samples
        with open(text_dir + 'unbiased_reconstruction.txt', 'w') as out:

            for sentence in tqdm(unbiased_sentences):
                z = DataClass.encoder.encode(sentence)
                reconstruction = DataClass.decoder.decode(z)
                if reconstruction == sentence:
                    nbSuccess += 1

                is_this_semantic = evaluateIndicesBias(
                    DataClass.vocabulary.sentenceToIndices(reconstruction))
                nbSuccess_semantic += is_this_semantic

                out.write('\n\n')
                out.write(sentence + '\n')
                out.write(reconstruction + '\n')
                out.write('is it semantic: {}'.format(is_this_semantic))

        nb_unbiased = len(unbiased_sentences)

        coverageUnbiased = nbSuccess / nb_unbiased
        logger.info('\n\nCoverage ratio over unbiased dataset: %0.3f', coverageUnbiased)
        results['coverage-unbiased'] = coverageUnbiased

        # Normalise by the unbiased set: that is the set this count was
        # accumulated over (the biased set may have a different size).
        ub_semantic = nbSuccess_semantic / nb_unbiased
        logger.info('\n\nSemantic correctness ratio over unbiased dataset: %0.3f', ub_semantic)
        results['unbiased_semantic'] = ub_semantic

        # Hand the results back to the experiment container.
        self.DataClass = DataClass
        self.DataClass.get_data('encoded_sentences', encoded_sentences)
        self.DataClass.get_data('semantic_reconstruction', semantic)
        self.DataClass.get_data('semantic_generalization', ub_semantic)
        self.DataClass.get_data('coverageBiased', coverageBiased)
        self.DataClass.get_data('grammaticalReconstruction', grammaticalReconstruction)
        self.DataClass.get_data('coverageUnbiased', coverageUnbiased)


class EncodingStudy(object):
    """Score the encoded sentences against their length and grammar labels.

    The constructor passes the length labels, the grammar labels and the
    encoded sentences to ``disentanglement_measure``, logs the two returned
    scores, stores them in ``config['results']['encoding']`` and hands them
    to ``DataClass.get_data``.
    """

    def __init__(self, DataClass, config):
        """Compute and record the two clustering scores.

        Args:
            DataClass: experiment container providing ``grammar_labels`` and
                ``length_labels``; its ``encoded_sentences`` attribute is used
                when present, otherwise ``get_encoded_sentences(DataClass)``
                supplies the encodings.
            config: mapping that already holds a ``'results'`` entry;
                ``config['results']['encoding']`` is replaced by a fresh dict.
        """
        self.config = config
        self.config['results']['encoding'] = {}

        grammar, lengths = DataClass.grammar_labels, DataClass.length_labels

        # Reuse encodings that an earlier study stored on the container,
        # otherwise compute them now.
        try:
            codes = DataClass.encoded_sentences
        except AttributeError:
            codes = get_encoded_sentences(DataClass)

        # One score per label set, returned in the order the labels are given.
        scores = disentanglement_measure([lengths, grammar], codes)
        length_score, grammar_score = scores

        logger.info('\n\n')

        length_msg = 'Length clustering through the encoding: %0.3f' % length_score
        logger.info(length_msg)
        self.config['results']['encoding']['length-clustering-encoding'] = length_score

        grammar_msg = 'Grammar clustering through the encoding: %0.3f' % grammar_score
        logger.info(grammar_msg)
        self.config['results']['encoding']['grammar-clustering-encoding'] = grammar_score

        self.DataClass = DataClass
        for key, value in (('gra_enc_score', grammar_score),
                           ('len_enc_score', length_score)):
            self.DataClass.get_data(key, value)


class DecodingStudy(object):
    """Generation study: decode latent samples and score the sentences.

    The constructor decodes every sample returned by
    ``get_samples_and_volume(DataClass)``, parses each distinct sentence with
    the experiment grammar, and records vocabulary coverage, uniqueness,
    grammatical correctness, grammar-rule coverage, decoding time and the
    semantic score in ``config['results']['generation']`` and through
    ``DataClass.get_data``.
    """

    def __init__(self, DataClass, config):
        """Run the study and record its metrics.

        Args:
            DataClass: experiment container. The attributes read here are
                ``decoder``, ``vocabulary``, ``grammar`` and
                ``nb_samples_test`` (used as the denominator of every ratio;
                presumably equal to the number of samples -- confirm against
                ``get_samples_and_volume``).
            config: mapping that already holds a ``'results'`` entry;
                ``config['results']['generation']`` is replaced by a fresh
                dict.
        """

        self.config = config
        self.config['results']['generation'] = {}

        samples, volume = get_samples_and_volume(DataClass)

        # check uniqueness, correctness and parsing of sentences
        # generated from the latent space
        logger.info('\n\nCheck uniqueness, correctness and parsing of sentences generated from the latent space')

        parser = nltk.ChartParser(DataClass.grammar)
        nb_samples = DataClass.nb_samples_test

        words = []
        sentences = []                    # one entry per sample, duplicates kept
        parses = []                       # parse recorded for each sample
        uniqueParses = []                 # parse of each distinct parsable sentence
        correct_and_unseen_sentence = []  # 1 for the first occurrence of a parsable sentence
        samples_grammar_rule = []

        # sentence -> parse found at its first occurrence; gives constant-time
        # duplicate detection instead of scanning the whole `sentences` list
        parse_by_sentence = {}

        nbSuccess_semantic = 0
        MeanTimePerSampleGenerated = 0
        for sample in tqdm(samples):

            # perf_counter is the clock intended for measuring short intervals
            begin = time.perf_counter()
            sentence = DataClass.decoder.decode(sample)
            end = time.perf_counter()
            MeanTimePerSampleGenerated += (end - begin) / nb_samples

            nbSuccess_semantic += evaluateIndicesBias(DataClass.vocabulary.sentenceToIndices(sentence))

            is_new_and_correct = 0
            if sentence in parse_by_sentence:
                # duplicate: reuse the parse of the first occurrence
                parsed = parse_by_sentence[sentence]
            else:
                parsed = []
                if len(sentence) > 0:
                    # parse once and keep the first tree, if any
                    trees = list(parser.parse(tokenize(sentence)))
                    if trees:
                        parsed = str(list(trees[0]))
                        is_new_and_correct = 1
                        uniqueParses.append(parsed)
                parse_by_sentence[sentence] = parsed

            correct_and_unseen_sentence.append(is_new_and_correct)
            samples_grammar_rule.append(parse2grammarLabel(parsed))
            sentences.append(sentence)
            parses.append(parsed)
            words += sentence.split(' ')

        # check how many grammar rules have been covered

        GrammarRulesUsed = [0] * 4
        for parse in uniqueParses:
            label = parse2grammarLabel(parse)
            GrammarRulesUsed[label] = 1

            # every rule has been seen, no need to look further
            if all(GrammarRulesUsed):
                break

        vocabulary_size = DataClass.vocabulary.getMaxVocabularySize()

        coverageVocabulary = len(set(words)) / vocabulary_size
        rtUniqueSentences = len(set(sentences)) / nb_samples
        rtCorrectAndUniqueSentences = sum(correct_and_unseen_sentence) / nb_samples
        coverageGrammar = sum(GrammarRulesUsed) / len(GrammarRulesUsed)
        semantic = nbSuccess_semantic / nb_samples

        logger.info('\n\nVocabulary size: %0.3f\n\n', vocabulary_size)

        # NOTE(review): the key below is misspelled ('vabulary'); it is kept
        # as is because consumers of the config may already rely on it.
        self.config['vabulary size'] = vocabulary_size

        generation = self.config['results']['generation']
        generation['nb_samples_test'] = nb_samples
        generation['coverageVocabulary'] = coverageVocabulary
        generation['nbUniqueSentences'] = rtUniqueSentences
        generation['nbCorrectAndUniqueSentences'] = rtCorrectAndUniqueSentences
        generation['coverageGrammar'] = coverageGrammar
        generation['GrammarRulesUsed'] = GrammarRulesUsed
        generation['density'] = rtCorrectAndUniqueSentences / volume
        generation['MeanTimePerSampleGenerated'] = MeanTimePerSampleGenerated
        generation['semantic'] = semantic

        # Hand the results back to the experiment container.
        self.DataClass = DataClass

        self.DataClass.get_data('coverageVocabulary', coverageVocabulary)
        self.DataClass.get_data('rtUniqueSentences', rtUniqueSentences)
        self.DataClass.get_data('rtCorrectAndUniqueSentences', rtCorrectAndUniqueSentences)
        self.DataClass.get_data('coverageGrammar', coverageGrammar)
        self.DataClass.get_data('samples', samples)
        self.DataClass.get_data('sentences', sentences)
        self.DataClass.get_data('parses', parses)
        self.DataClass.get_data('semantic_generation', semantic)
        self.DataClass.get_data('samples_grammar_rule', samples_grammar_rule)
