# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
A library for loading 1B word benchmark dataset.
------------------------------------------------

"""


import random

import numpy as np

from textattack.shared.utils import LazyLoader

tf = LazyLoader("tensorflow", globals(), "tensorflow")


class Vocabulary(object):
    """Class that holds a vocabulary for the dataset.

    Words are read from a vocabulary file, one per line (surrounding
    whitespace stripped); a word's id is its position among the kept lines.
    The ``<S>``, ``</S>`` and ``UNK`` entries are remembered as the
    begin-of-sentence, end-of-sentence and unknown-word ids, and a
    ``!!!MAXTERMID`` line is skipped without consuming an id.
    """

    def __init__(self, filename):
        """Initialize vocabulary.

        Args:
          filename (str): Vocabulary file name.
        """

        self._id_to_word = []
        self._word_to_id = {}
        # Special-token ids stay -1 when the file does not contain them.
        self._unk = -1
        self._bos = -1
        self._eos = -1

        with tf.io.gfile.GFile(filename) as f:
            idx = 0
            for line in f:
                word_name = line.strip()
                if word_name == "<S>":
                    self._bos = idx
                elif word_name == "</S>":
                    self._eos = idx
                elif word_name == "UNK":
                    self._unk = idx
                # This line is dropped and does not consume an id.
                if word_name == "!!!MAXTERMID":
                    continue

                self._id_to_word.append(word_name)
                self._word_to_id[word_name] = idx
                idx += 1

    @property
    def bos(self):
        """Id of the ``<S>`` token, or -1 if the file had none."""
        return self._bos

    @property
    def eos(self):
        """Id of the ``</S>`` token, or -1 if the file had none."""
        return self._eos

    @property
    def unk(self):
        """Id of the ``UNK`` token, or -1 if the file had none."""
        return self._unk

    @property
    def size(self):
        """Number of words in the vocabulary."""
        return len(self._id_to_word)

    def word_to_id(self, word):
        """Return the id of ``word``, or :attr:`unk` if it is unknown."""
        return self._word_to_id.get(word, self.unk)

    def id_to_word(self, cur_id):
        """Converts an ID to the word it represents.

        Args:
          cur_id: The ID

        Returns:
          The word that :obj:`cur_id` represents, or ``"ERROR"`` if the id is
          outside ``[0, size)``.
        """
        # Negative ids must be rejected explicitly: plain list indexing would
        # count from the end (e.g. an ``unk`` of -1, when the file has no UNK
        # line, would otherwise decode to the last vocabulary word).
        if 0 <= cur_id < self.size:
            return self._id_to_word[cur_id]
        return "ERROR"

    def decode(self, cur_ids):
        """Convert a list of ids to a sentence, with space inserted."""
        return " ".join([self.id_to_word(cur_id) for cur_id in cur_ids])

    def encode(self, sentence):
        """Convert a whitespace-tokenized sentence to ids, with special tokens added.

        Returns:
          1-D ``np.int32`` array ``[bos, id_1, ..., id_k, eos]``; unknown
          tokens map to :attr:`unk`.
        """
        word_ids = [self.word_to_id(cur_word) for cur_word in sentence.split()]
        return np.array([self.bos] + word_ids + [self.eos], dtype=np.int32)


class CharsVocabulary(Vocabulary):
    """Vocabulary that also encodes every word as a fixed-length row of char codes.

    On top of the word-level mapping of :class:`Vocabulary`, each word becomes
    ``max_word_length`` integer codes: a begin-of-word marker, the (possibly
    truncated) characters of the word, an end-of-word marker, then padding.
    Five marker characters are chosen among the code points 0-255 that no
    vocabulary word uses.

    NOTE(review): characters are encoded with ``ord`` and are not remapped,
    so a word containing a code point above 255 yields codes above 255 —
    confirm the downstream character table can handle that.
    """

    def __init__(self, filename, max_word_length):
        """Build the word vocabulary and its character-code table.

        Args:
          filename (str): Vocabulary file name.
          max_word_length (int): Number of codes per encoded word, including
            the begin/end-of-word markers.

        Raises:
          ValueError: If fewer than five code points below 256 are unused by
            the vocabulary, so the marker characters cannot be chosen.
        """
        super(CharsVocabulary, self).__init__(filename)
        self._max_word_length = max_word_length

        # Every character occurring anywhere in the vocabulary.
        chars_set = set()
        for vocab_word in self._id_to_word:
            chars_set.update(vocab_word)

        # Code points below 256 that no word uses; the first five become the
        # special marker characters.
        unused = [chr(code) for code in range(256) if chr(code) not in chars_set]
        if len(unused) < 5:
            raise ValueError("Not enough free char ids: %d" % len(unused))

        markers = unused[:5]
        (
            self.bos_char,  # <begin sentence>
            self.eos_char,  # <end sentence>
            self.bow_char,  # <begin word>
            self.eow_char,  # <end word>
            self.pad_char,  # <padding>
        ) = markers
        chars_set.update(markers)
        self._char_set = chars_set

        self.bos_chars = self._convert_word_to_char_ids(self.bos_char)
        self.eos_chars = self._convert_word_to_char_ids(self.eos_char)

        # Row k of the table holds the character codes of word id k.
        table = np.zeros([len(self._id_to_word), max_word_length], dtype=np.int32)
        for word_id, vocab_word in enumerate(self._id_to_word):
            table[word_id] = self._convert_word_to_char_ids(vocab_word)
        self._word_char_ids = table

    @property
    def word_char_ids(self):
        """Table of shape ``[size, max_word_length]``; row k encodes word id k."""
        return self._word_char_ids

    @property
    def max_word_length(self):
        """Number of character codes per encoded word."""
        return self._max_word_length

    def _convert_word_to_char_ids(self, word):
        """Encode ``word`` as ``max_word_length`` character codes.

        Layout is ``[bow, c_0, ..., c_k, eow, pad, ...]``. The word is cut to
        ``max_word_length - 2`` characters so both markers fit (when
        ``max_word_length >= 2``).
        """
        width = self.max_word_length
        code = np.full([width], ord(self.pad_char), dtype=np.int32)
        framed = self.bow_char + word[: width - 2] + self.eow_char
        for pos, ch in enumerate(framed):
            code[pos] = ord(ch)
        return code

    def word_to_char_ids(self, word):
        """Return the character-code row for ``word``.

        Known words come from the precomputed table; unknown words are encoded
        on the fly.
        """
        word_id = self._word_to_id.get(word)
        if word_id is None:
            return self._convert_word_to_char_ids(word)
        return self._word_char_ids[word_id]

    def encode_chars(self, sentence):
        """Encode a whitespace-tokenized sentence as a 2-D array of char codes.

        Rows are the begin-of-sentence marker row, one row per token, and the
        end-of-sentence marker row.
        """
        rows = [self.bos_chars]
        rows.extend(self.word_to_char_ids(token) for token in sentence.split())
        rows.append(self.eos_chars)
        return np.vstack(rows)


def get_batch(generator, batch_size, num_steps, max_word_length, pad=False):
    """Read batches of input.

    Packs a stream of encoded sentences into ``[batch_size, num_steps]``
    windows. Each batch row keeps its own position in the stream; a sentence
    that does not fit in the current window continues in the same row of the
    next batch.

    Args:
      generator: Iterable (or iterator/generator) yielding
        ``(ids, char_ids, global_word_ids)`` triples, where ``global_word_ids``
        has one entry per id except the first (as produced by
        ``LM1BDataset._load_shard``).
      batch_size: Number of rows per batch.
      num_steps: Number of time steps (columns) per batch.
      max_word_length: Size of the last axis of the character-id batch.
      pad: If True, each row takes at most one sentence chunk per batch and
        the rest of the row stays zero with weight 0.

    Yields:
      ``(inputs, char_inputs, global_word_ids, targets, weights)`` where
      ``targets`` is ``inputs`` shifted left by one position within each
      sentence and ``weights`` is 1.0 on filled positions, 0.0 elsewhere.
      Iteration stops once the input is exhausted and a batch comes out empty.

      NOTE: the same five arrays are reused and overwritten for every batch;
      copy them if a batch must outlive the next iteration.
    """
    # Accept any iterable; iter() returns an iterator/generator unchanged.
    generator = iter(generator)
    cur_stream = [None] * batch_size

    inputs = np.zeros([batch_size, num_steps], np.int32)
    char_inputs = np.zeros([batch_size, num_steps, max_word_length], np.int32)
    global_word_ids = np.zeros([batch_size, num_steps], np.int32)
    targets = np.zeros([batch_size, num_steps], np.int32)
    weights = np.ones([batch_size, num_steps], np.float32)

    no_more_data = False
    while True:
        inputs[:] = 0
        char_inputs[:] = 0
        global_word_ids[:] = 0
        targets[:] = 0
        weights[:] = 0.0

        for i in range(batch_size):
            cur_pos = 0

            while cur_pos < num_steps:
                # A stream with <= 1 ids left has no (input, target) pair.
                if cur_stream[i] is None or len(cur_stream[i][0]) <= 1:
                    try:
                        # Built-in next(): Python 3 iterators have no .next().
                        cur_stream[i] = list(next(generator))
                    except StopIteration:
                        # No more data, exhaust current streams and quit
                        no_more_data = True
                        break

                how_many = min(len(cur_stream[i][0]) - 1, num_steps - cur_pos)
                next_pos = cur_pos + how_many

                inputs[i, cur_pos:next_pos] = cur_stream[i][0][:how_many]
                char_inputs[i, cur_pos:next_pos] = cur_stream[i][1][:how_many]
                global_word_ids[i, cur_pos:next_pos] = cur_stream[i][2][:how_many]
                targets[i, cur_pos:next_pos] = cur_stream[i][0][1 : how_many + 1]
                weights[i, cur_pos:next_pos] = 1.0

                cur_pos = next_pos
                # Keep the last consumed id: it is the first input of the
                # continuation (its target was the final id emitted above).
                cur_stream[i][0] = cur_stream[i][0][how_many:]
                cur_stream[i][1] = cur_stream[i][1][how_many:]
                cur_stream[i][2] = cur_stream[i][2][how_many:]

                if pad:
                    break

        if no_more_data and np.sum(weights) == 0:
            # There is no more data and this is an empty batch. Done!
            break
        yield inputs, char_inputs, global_word_ids, targets, weights


class LM1BDataset(object):
    """Utility class for 1B word benchmark dataset.

    Reads tokenized text shards (each line is encoded as one sentence)
    matching a file pattern and turns them into word-id / character-id
    streams through a vocabulary.
    """

    def __init__(self, filepattern, vocab):
        """Initialize LM1BDataset reader.

        Args:
          filepattern: Dataset file pattern, expanded with ``tf.io.gfile.glob``.
          vocab: Vocabulary used to encode sentences; it must provide
            ``encode``, ``encode_chars`` and ``max_word_length``.
        """
        self._vocab = vocab
        shards = tf.io.gfile.glob(filepattern)
        self._all_shards = shards
        tf.compat.v1.logging.info("Found %d shards at %s", len(shards), filepattern)

    def _load_random_shard(self):
        """Pick one shard uniformly at random and load it."""
        chosen = random.choice(self._all_shards)
        return self._load_shard(chosen)

    def _load_shard(self, shard_name):
        """Read one shard and encode every line of it.

        Args:
          shard_name: file path.

        Returns:
          A one-shot iterator (``zip``) of ``(ids, char_ids, global_word_ids)``
          tuples, one per line. ``global_word_ids`` numbers every id except
          the leading one consecutively across the whole shard.
        """
        tf.compat.v1.logging.info("Loading data from: %s", shard_name)
        with tf.io.gfile.GFile(shard_name) as f:
            lines = f.readlines()

        vocab = self.vocab
        chars_ids = [vocab.encode_chars(line) for line in lines]
        ids = [vocab.encode(line) for line in lines]

        global_word_ids = []
        offset = 0
        for sentence_ids in ids:
            n_counted = len(sentence_ids) - 1  # without <BOS> symbol
            global_word_ids.append(np.arange(offset, offset + n_counted))
            offset += n_counted

        tf.compat.v1.logging.info("Loaded %d words.", offset)
        tf.compat.v1.logging.info("Finished loading")
        return zip(ids, chars_ids, global_word_ids)

    def _get_sentence(self, forever=True):
        """Yield encoded sentences from randomly chosen shards.

        With a truthy ``forever`` this loads another random shard each time
        the current one is exhausted and never stops; otherwise it stops
        after a single random shard.
        """
        keep_going = True
        while keep_going:
            for sentence in self._load_random_shard():
                yield sentence
            keep_going = forever

    def get_batch(self, batch_size, num_steps, pad=False, forever=True):
        """Return a batch generator over this dataset via module-level ``get_batch``."""
        sentences = self._get_sentence(forever)
        return get_batch(
            sentences,
            batch_size,
            num_steps,
            self.vocab.max_word_length,
            pad=pad,
        )

    @property
    def vocab(self):
        """The vocabulary passed at construction."""
        return self._vocab
