"""
TensorFlow Model Wrapper
--------------------------
"""


import numpy as np

from .model_wrapper import ModelWrapper


class TensorFlowModelWrapper(ModelWrapper):
    """Loads a TensorFlow model and tokenizer.

    TensorFlow models can use many different architectures and
    tokenization strategies. This assumes that the model takes an
    np.array of strings as input and returns a tf.Tensor of outputs, as
    is typical with Keras modules. You may need to subclass this for
    models that have dedicated tokenizers or otherwise take input
    differently.
    """

    def __init__(self, model):
        self.model = model

    def __call__(self, text_input_list, **kwargs):
        text_array = np.array(text_input_list)
        preds = self.model(text_array)
        return preds.numpy()

    def get_grad(self, text_input):
        raise NotImplementedError()
