# Get Python six functionality:
from __future__ import\
    absolute_import, print_function, division, unicode_literals


###############################################################################
###############################################################################
###############################################################################


import keras.backend as K
import keras.utils
import math


__all__ = [
    "model_wo_softmax",
    "to_list",

    "BatchSequence",
    "TargetAugmentedSequence",

    "preprocess_images",
    "postprocess_images",
]


###############################################################################
###############################################################################
###############################################################################


def model_wo_softmax(*args, **kwargs):
    # Break cyclic import
    from .keras.graph import model_wo_softmax

    return model_wo_softmax(*args, **kwargs)


###############################################################################
###############################################################################
###############################################################################


def to_list(l):
    """ If not list, wraps parameter into a list."""
    if not isinstance(l, list):
        return [l, ]
    else:
        return l


###############################################################################
###############################################################################
###############################################################################


class BatchSequence(keras.utils.Sequence):
    """Batch sequence generator.

    Take a (list of) input tensors and a batch size
    and creates a generators that creates a sequence of batches.

    :param Xs: One or a list of tensors. First axis needs to have same length.
    :param batch_size: Batch size. Default 32.
    """

    def __init__(self, Xs, batch_size=32):
        self.Xs = to_list(Xs)
        self.single_tensor = len(Xs) == 1
        self.batch_size = batch_size

        if not self.single_tensor:
            for X in self.Xs[1:]:
                assert X.shape[0] == self.Xs[0].shape[0]
        super(BatchSequence, self).__init__()

    def __len__(self):
        return int(math.ceil(float(len(self.Xs[0])) / self.batch_size))

    def __getitem__(self, idx):
        ret = [X[idx*self.batch_size:(idx+1)*self.batch_size]
               for X in self.Xs]

        if self.single_tensor:
            return ret[0]
        else:
            return tuple(ret)


class TargetAugmentedSequence(keras.utils.Sequence):
    """Augments a sequence with a target on the fly.

    Takes a sequence/generator and a function that
    creates on the fly for each batch a target.
    The generator takes a batch from that sequence,
    computes the target and returns both.

    :param sequence: A sequence or generator.
    :param augment_f: Takes a batch and returns a target.
    """

    def __init__(self, sequence, augment_f):
        self.sequence = sequence
        self.augment_f = augment_f

        super(TargetAugmentedSequence, self).__init__()

    def __len__(self):
        return len(self.sequence)

    def __getitem__(self, idx):
        inputs = self.sequence[idx]
        if isinstance(inputs, tuple):
            assert len(inputs) == 1
            inputs = inputs[0]

        targets = self.augment_f(to_list(inputs))
        return inputs, targets


###############################################################################
###############################################################################
###############################################################################


def preprocess_images(images, color_coding=None):
    """Image preprocessing

    Takes a batch of images and:
    * Adjust the color axis to the Keras format.
    * Fixes the color coding.

    :param images: Batch of images with 4 axes.
    :param color_coding: Determines the color coding.
      Can be None, 'RGBtoBGR' or 'BGRtoRGB'.
    :return: The preprocessed batch.
    """

    ret = images
    image_data_format = K.image_data_format()
    # todo: not very general:
    channels_first = images.shape[1] in [1, 3]
    if image_data_format == "channels_first" and not channels_first:
        ret = ret.transpose(0, 3, 1, 2)
    if image_data_format == "channels_last" and channels_first:
        ret = ret.transpose(0, 2, 3, 1)

    assert color_coding in [None, "RGBtoBGR", "BGRtoRGB"]
    if color_coding in ["RGBtoBGR", "BGRtoRGB"]:
        if image_data_format == "channels_first":
            ret = ret[:, ::-1, :, :]
        if image_data_format == "channels_last":
            ret = ret[:, :, :, ::-1]

    return ret


def postprocess_images(images, color_coding=None, channels_first=None):
    """Image postprocessing

    Takes a batch of images and reverts the preprocessing.

    :param images: A batch of images with 4 axes.
    :param color_coding: The initial color coding,
      see :func:`preprocess_images`.
    :param channels_first: The output channel format.
    :return: The postprocessed images.
    """

    ret = images
    image_data_format = K.image_data_format()
    assert color_coding in [None, "RGBtoBGR", "BGRtoRGB"]
    if color_coding in ["RGBtoBGR", "BGRtoRGB"]:
        if image_data_format == "channels_first":
            ret = ret[:, ::-1, :, :]
        if image_data_format == "channels_last":
            ret = ret[:, :, :, ::-1]

    if image_data_format == "channels_first" and not channels_first:
        ret = ret.transpose(0, 2, 3, 1)
    if image_data_format == "channels_last" and channels_first:
        ret = ret.transpose(0, 3, 1, 2)

    return ret
