# Get Python six functionality:
from __future__ import\
    absolute_import, print_function, division, unicode_literals
from builtins import range


###############################################################################
###############################################################################
###############################################################################


import keras.models
import keras.engine.topology


from ... import utils as iutils


__all__ = [
    "TestAnalysisHelper",
]


###############################################################################
###############################################################################
###############################################################################


class TestAnalysisHelper(object):

    def __init__(self, model, analyzer, weights=None):
        """ Helper class for retrieving output and analysis in test cases.
        

        :param model: A Keras layer object or a list of layer objects.
          In this case a sequntial model will be build. The first layer
          must have set input_shape or batch_input_shape.
          Alternatively a tuple with input and output tensors, in which
          case the keras modle api will be used.
        :param analyzer: Either an analyzer class or a function
          that takes a keras model and returns an analyzer.
        :param weights: After creating the model set the given weights.
        """

        if isinstance(model, keras.engine.topology.Layer):
            model = [model]

        if isinstance(model, list):
            self._model = keras.models.Sequential(model)
        else:
            self._model = keras.models.Model(*model)

        self._input_shapes = iutils.to_list(self._model.input_shape)

        if weights is not None:
            self._model.set_weights(weights)

        self._analyzer = analyzer(self._model)

    @property
    def weights(self):
        return self._model.get_weights()
    
    def run(self, inputs):
        """Runs the model given the inputs.

        :return: Tuple with model output and analyzer output.
        """
        return_list = True
        if not isinstance(inputs, list):
            return_list = False
            inputs = iutils.to_list(inputs)

        augmented = []
        for i in range(len(inputs)):
            if len(inputs[i].shape) == len(self._input_shapes[i])-1:
                # Augment by batch axis.
                augmented.append(i)
                inputs[i] = inputs[i].reshape((1,)+inputs[i].shape)

        outputs = iutils.to_list(self._model.predict_on_batch(inputs))
        analysis = iutils.to_list(self._analyzer.analyze(inputs))

        for i in augmented:
            # Remove batch axis.
            outputs[i] = outputs[i][0]
            analysis[i] = analysis[i][0]
        
        if return_list:
            return outputs, analysis
        else:
            return outputs[0], analysis[0]
