

import gzip

import numpy as np
import tensorflow as tf
from nltk import CFG
from nltk.parse.generate import generate
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical

# grammar cannot have recursion!
from ariel_tests.language.nlp import postprocessSentence

# Small toy context-free grammar: a sentence is a noun phrase followed by a
# verb, optionally with an object noun phrase. None of the rules refers back
# to itself, so the set of sentences it describes is finite.
grammar = CFG.fromstring("""
                         S -> NP VP | NP V
                         VP -> V NP
                         NP -> Det N
                         Det -> 'a' | 'the'
                         N -> 'dog' | 'cat'
                         V -> 'chased' | 'saw'
                         """)


def _basicGenerator(grammar, batch_size=3):
    """Yield, forever, batches of sentences enumerated from ``grammar``.

    Every batch is a list of single-element lists; each inner list holds one
    sentence whose tokens are joined with single spaces. A fresh
    ``generate(grammar, n=batch_size)`` call is made for every batch.
    """
    while True:
        batch = []
        for tokens in generate(grammar, n=batch_size):
            batch.append([' '.join(tokens)])
        yield batch


def indicesToOneHot(indices, num_tokens):
    """Return the one-hot encoding of ``indices`` over ``num_tokens`` classes.

    Rows of a ``num_tokens`` x ``num_tokens`` identity matrix are picked by
    ``indices``, so an integer index array of shape ``s`` gives a result of
    shape ``s + (num_tokens,)``.
    """
    identity_rows = np.eye(num_tokens)
    return identity_rows[indices]


def getNoisyInput(keep, input_indices):
    """Randomly replace entries of ``input_indices`` by 0.

    Each entry is kept with probability ``keep`` and multiplied by 0
    otherwise. When ``keep`` equals 1 the input is returned untouched and no
    random numbers are drawn; otherwise ``keep`` must lie in ``[0, 1]``.
    """
    if keep == 1.:
        return input_indices

    assert keep <= 1.
    assert keep >= 0.
    drop_probability = 1 - keep
    # One independent 0/1 draw per entry; 1 means "keep this entry".
    keep_mask = np.random.choice([0, 1],
                                 size=input_indices.shape,
                                 p=[drop_probability, keep])
    return input_indices * keep_mask


class BaseGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that streams sentences from a gzip file, one per line.

    Every line is decoded, post-processed and converted to vocabulary indices.
    ``data_generation`` turns up to ``batch_size`` consecutive lines into
    ``([encoder_input, decoder_input], one_hot_decoder_target)``.

    Batches are read sequentially from an open file handle, so the ``index``
    passed to ``__getitem__`` is ignored. When ``'val'`` occurs in
    ``gzip_filepath`` a single batch is read at construction time and that
    same batch is returned for every request.
    """

    def __init__(
            self,
            gzip_filepath,
            vocabulary,
            batch_size,
            steps_per_epoch=None,
            maxlen=None,
            nb_lines=None,
            reverse_input=True,
            keep=1.):
        """Store the configuration and open the data file.

        Args:
            gzip_filepath: path of the gzip-compressed text file.
            vocabulary: object providing ``getMaxVocabularySize()``,
                ``sentenceToIndices(sentence)`` and the ``padIndex``,
                ``startIndex`` and ``endIndex`` attributes.
            batch_size: number of lines read per batch.
            steps_per_epoch: value returned by ``__len__``; ``'all'`` selects
                ``nb_lines // batch_size``. It is forced to 1 for validation
                files. With the default ``None`` on a non-validation file,
                ``__len__`` returns ``None``, which ``len()`` rejects.
            maxlen: each index list is truncated to this length (``None``
                keeps it whole).
            nb_lines: number of lines in the file; counted when ``None``.
            reverse_input: reverse the token order of the encoder input.
            keep: probability of keeping each input index, forwarded to
                ``getNoisyInput``.
        """
        super().__init__()

        self.gzip_filepath = gzip_filepath
        self.vocabulary = vocabulary
        self.batch_size = batch_size
        self.maxlen = maxlen
        self.nb_lines = nb_lines
        self.reverse_input = reverse_input
        self.keep = keep

        self.count_lines_in_gzip()
        self.on_epoch_end()

        self.vocab_size = self.vocabulary.getMaxVocabularySize()

        self.PAD = self.vocabulary.padIndex
        self.START = self.vocabulary.startIndex
        self.END = self.vocabulary.endIndex

        if steps_per_epoch == 'all':
            self.steps_per_epoch = int(self.nb_lines // self.batch_size)
        else:
            self.steps_per_epoch = steps_per_epoch

        if 'val' in self.gzip_filepath:
            # Validation data: read one batch now and serve it on every call.
            self.X_val, self.y_val = self.data_generation()
            self.steps_per_epoch = 1
            # NOTE(review): nb_lines and batch_size are left as given. Earlier
            # code compared them against 512 with `==` at this point, which
            # had no effect; if an override was intended it has to happen
            # before the batch above is generated - confirm the intent.

    def count_lines_in_gzip(self):
        """Count the lines of the gzip file unless ``nb_lines`` was supplied."""
        if self.nb_lines is None:
            with gzip.open(self.gzip_filepath, 'rb') as f:
                self.nb_lines = sum(1 for _ in f)

    def __len__(self):
        """Return the number of batches per epoch."""
        return self.steps_per_epoch

    def __getitem__(self, index=0):
        """Return one batch ``(X, y)``; ``index`` is not used."""
        if 'val' in self.gzip_filepath:
            return self.X_val, self.y_val
        return self.data_generation()

    def on_epoch_end(self):
        """Rewind the data file by reopening it, closing any previous handle."""
        previous = getattr(self, 'f', None)
        if previous is not None:
            previous.close()
        self.f = gzip.open(self.gzip_filepath, 'rb')

    def _read_batch_indices(self):
        """Read up to ``batch_size`` lines and convert each to index lists."""
        batch = []
        for count, line in enumerate(self.f, start=1):
            sentence = postprocessSentence(line.strip().decode('cp437'))
            indices = [self.PAD, self.START] + \
                      self.vocabulary.sentenceToIndices(sentence) + \
                      [self.END]
            batch.append(indices[:self.maxlen])
            if count >= self.batch_size:
                break
        return batch

    def data_generation(self):
        """Build one batch ``([x_enc, x_dec], y_dec_oh)`` from the next lines.

        Returns:
            ``x_enc``: int32 index matrix, optionally token-reversed, padded
            after the tokens to the longest sentence of the batch.
            ``x_dec``: int32 index matrix with one pad index prepended to each
            sentence, padded to the longest length plus one.
            ``y_dec_oh``: float32 one-hot encoding (over ``vocab_size``) of
            each sentence with one pad index appended, same padded length as
            ``x_dec``.

        Raises:
            ValueError: if the file yields no line even after rewinding.
        """
        indices = self._read_batch_indices()
        if not indices:
            # The handle is exhausted: start again from the first line rather
            # than failing on an empty batch.
            self.on_epoch_end()
            indices = self._read_batch_indices()
        if not indices:
            raise ValueError(
                'no sentences could be read from %s' % self.gzip_filepath)

        longest = max(len(tokens) for tokens in indices)
        pad = self.vocabulary.padIndex

        # Encoder input: the sentence itself, reversed when requested.
        if self.reverse_input:
            encoder_tokens = [tokens[::-1] for tokens in indices]
        else:
            encoder_tokens = indices
        x_enc = pad_sequences(encoder_tokens,
                              maxlen=longest,
                              value=pad,
                              padding='post')
        x_enc = np.array(x_enc, dtype=np.int32)

        # Decoder input: the sentence shifted right by one pad index.
        x_dec = pad_sequences([[pad] + tokens for tokens in indices],
                              maxlen=longest + 1,
                              value=pad,
                              padding='post')
        x_dec = np.array(x_dec, dtype=np.int32)

        # Decoder target: the sentence followed by one pad index, one-hot.
        y_dec = pad_sequences([tokens + [pad] for tokens in indices],
                              maxlen=longest + 1,
                              value=pad,
                              padding='post')
        y_dec_oh = np.array(indicesToOneHot(y_dec, self.vocab_size),
                            dtype=np.float32)

        x_enc = getNoisyInput(self.keep, x_enc)
        x_dec = getNoisyInput(self.keep, x_dec)
        return [x_enc, x_dec], y_dec_oh


class AeGenerator(BaseGenerator):
    """Subclass that adds nothing to ``BaseGenerator``; all behavior is inherited."""


class VaeGenerator(BaseGenerator):
    """``BaseGenerator`` variant that returns the target twice, as ``[y, y]``."""

    def __getitem__(self, index=0):
        """Return one batch as ``(X, [y, y])``; ``index`` is not used."""
        # Validation files serve the batch cached at construction time.
        is_validation = 'val' in self.gzip_filepath
        X, y = (self.X_val, self.y_val) if is_validation else self.data_generation()
        return X, [y, y]


class TransformerGenerator(BaseGenerator):
    """Generator returning shifted index matrices as inputs and no target."""

    def data_generation(self):
        """Build one batch ``([input_indices, output_indices], None)``.

        Up to ``batch_size`` lines are read, converted to
        ``[PAD, START] + sentence indices + [END]``, truncated to ``maxlen``
        and post-padded with ``PAD``. ``input_indices`` is the padded matrix
        without its last column (with noise from ``getNoisyInput`` applied),
        ``output_indices`` the padded matrix without its first column.
        """
        batch = []
        for count, line in enumerate(self.f, start=1):
            text = postprocessSentence(line.strip().decode('cp437'))
            tokens = [self.PAD, self.START] + self.vocabulary.sentenceToIndices(text) + [self.END]
            batch.append(tokens[:self.maxlen])
            if count >= self.batch_size:
                break

        padded = pad_sequences(batch,
                               maxlen=self.maxlen,
                               value=self.PAD,
                               padding='post')
        # Inputs and outputs are the same rows, offset by one position.
        shifted_in, shifted_out = padded[:, :-1], padded[:, 1:]

        noisy_in = getNoisyInput(self.keep, shifted_in)
        return [noisy_in, shifted_out], None


class ArielGenerator(BaseGenerator):
    """Generator of (sentence prefix, next token) training pairs."""

    def data_generation(self):
        """Build one batch ``(input_indices, output_indices)``.

        For each of up to ``batch_size`` lines, the index list
        ``[PAD, START] + sentence indices + [END]`` (truncated to ``maxlen``)
        is cut at a uniformly random position: everything before the cut is
        the input, the index at the cut is the target. Inputs are post-padded
        to ``maxlen`` and passed through ``getNoisyInput``; targets are
        encoded with ``to_categorical`` over ``vocab_size`` classes.
        """
        prefixes = []
        targets = []
        for count, line in enumerate(self.f, start=1):
            text = postprocessSentence(line.strip().decode('cp437'))
            tokens = [self.PAD, self.START] + self.vocabulary.sentenceToIndices(text) + [self.END]
            tokens = tokens[:self.maxlen]

            # Random cut point; the prefix may be empty when it is 0.
            cut = np.random.randint(len(tokens))
            prefixes.append(tokens[:cut])
            targets.append([tokens[cut]])

            if count >= self.batch_size:
                break

        padded_prefixes = pad_sequences(prefixes, maxlen=self.maxlen,
                                        value=self.vocabulary.padIndex,
                                        padding='post')
        one_hot_targets = to_categorical(np.array(targets),
                                         num_classes=self.vocab_size)

        noisy_prefixes = getNoisyInput(self.keep, padded_prefixes)
        return noisy_prefixes, one_hot_targets


class SimpleGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding line-aligned batches from two files.

    Line ``k`` of ``filepath_spikes`` is paired with line ``k`` of
    ``filepath_sound``. Batches are read sequentially from open handles, so
    the ``index`` passed to ``__getitem__`` is ignored.
    """

    def __init__(
            self,
            filepath_spikes,
            filepath_sound,
            batch_size,
            steps_per_epoch=None,
    ):
        """Store the configuration, count the lines and open both files.

        Args:
            filepath_spikes: path of the file providing the inputs.
            filepath_sound: path of the file providing the targets.
            batch_size: number of line pairs per batch.
            steps_per_epoch: value returned by ``__len__``; ``'all'``
                (surrounding whitespace ignored) selects
                ``nb_lines // batch_size``.
        """
        super().__init__()

        self.filepath_spikes = filepath_spikes
        self.filepath_sound = filepath_sound
        self.batch_size = batch_size

        self.count_lines_in_file()
        self.on_epoch_end()

        # The string is stripped so that both 'all' and the historical
        # 'all  ' spelling are accepted.
        if isinstance(steps_per_epoch, str) and steps_per_epoch.strip() == 'all':
            self.steps_per_epoch = int(self.nb_lines // self.batch_size)
        else:
            self.steps_per_epoch = steps_per_epoch

    def count_lines_in_file(self):
        """Set ``nb_lines`` to the number of line pairs the two files provide."""
        counts = []
        for path in (self.filepath_spikes, self.filepath_sound):
            with open(path, 'rb') as f:
                counts.append(sum(1 for _ in f))
        # Pairs are formed with zip(), which stops at the shorter file.
        self.nb_lines = min(counts)

    def __len__(self):
        """Return the number of batches per epoch."""
        return self.steps_per_epoch

    def __getitem__(self, index=0):
        """Return one batch ``(X, y)``; ``index`` is not used."""
        return self.data_generation()

    def on_epoch_end(self):
        """Rewind both files by reopening them, closing any previous handles."""
        for attribute in ('spikes_f', 'sound_f'):
            previous = getattr(self, attribute, None)
            if previous is not None:
                previous.close()
        self.spikes_f = open(self.filepath_spikes, 'rb')
        self.sound_f = open(self.filepath_sound, 'rb')

    def data_generation(self):
        """Read up to ``batch_size`` line pairs and return them as two arrays.

        NOTE(review): the lines are appended as the raw ``bytes`` read from
        the files, without any parsing - confirm that this is what the model
        consuming the batches expects.
        """
        list_spikes = []
        list_sounds = []
        # Counting from 1 makes the loop stop after exactly batch_size pairs.
        for count, (sp, so) in enumerate(zip(self.spikes_f, self.sound_f), start=1):
            list_spikes.append(sp)
            list_sounds.append(so)
            if count >= self.batch_size:
                break

        spikes_batch = np.array(list_spikes)
        sounds_batch = np.array(list_sounds)
        return spikes_batch, sounds_batch


if __name__ == '__main__':
    # Manual check: build a generator over the REBER training data.
    # NOTE(review): GzipToNextStepGenerator is neither defined nor imported in
    # this file as shown, so running the module as a script presumably fails
    # with NameError - confirm where that class lives or which generator
    # should be used here.
    grammar_filepath = '../data/simplerREBER_grammar.cfg'
    gzip_filepath = '../data/REBER_biased_train.gz'
    batch_size = 3
    generator = GzipToNextStepGenerator(gzip_filepath, grammar_filepath, batch_size)
    # check REBER generator
