# Get Python six functionality:
from __future__ import\
    absolute_import, print_function, division, unicode_literals
from builtins import range


###############################################################################
###############################################################################
###############################################################################


import matplotlib.pyplot as plt
import numpy as np


# Public names of this module; this list controls what a star-import exposes.
__all__ = [
    "project",
    "heatmap",
    "graymap",
    "gamma",
    "clip_quantile",
]


###############################################################################
###############################################################################
###############################################################################


def project(X, output_range=(0, 1), absmax=None, input_is_positive_only=False):
    """Projects a tensor into a value range.

    The values are first divided by ``absmax`` (entries where ``absmax``
    is zero are left untouched), then mapped from [-1, 1] to [0, 1]
    (skipped if ``input_is_positive_only`` is set), clipped to [0, 1] and
    finally scaled linearly into ``output_range``.

    The input is never modified; a new array is returned.

    :param X: A tensor. Non-floating inputs are converted to float64.
    :param output_range: The output value range as ``(low, high)``.
    :param absmax: A scalar or tensor specifying the absmax used for
      normalizing. Its axes are aligned with the leading axes of ``X``.
      By default the maximum absolute value of each entry along the
      first axis (i.e. taken over all remaining axes) is used.
    :param input_is_positive_only: Is the input value range only positive.
    :return: The tensor with the values projected into the output range.
    """

    X = np.asarray(X)
    # Work on a private floating point copy: the caller's array must not
    # be altered and integer arrays cannot hold the normalized values.
    if np.issubdtype(X.dtype, np.inexact):
        X = X.copy()
    else:
        X = X.astype(np.float64)

    if absmax is None:
        absmax = np.max(np.abs(X),
                        axis=tuple(range(1, X.ndim)))
    absmax = np.asarray(absmax)

    # Align absmax with the leading axes of X by appending singleton axes.
    # Without this a per-sample vector of shape (N,) would be broadcast
    # against the last axis of X instead of the first one.
    if absmax.ndim < X.ndim:
        absmax = absmax.reshape(
            absmax.shape + (1,) * (X.ndim - absmax.ndim))

    # Divide only where absmax is non-zero; other entries keep their value.
    np.divide(X, absmax, out=X, where=(absmax != 0))

    if not input_is_positive_only:
        X = (X + 1) / 2  # [-1, 1] -> [0, 1]
    X = X.clip(0, 1)

    low, high = output_range[0], output_range[1]
    return low + (X * (high - low))


def heatmap(X, cmap_type="seismic", reduce_op="sum", reduce_axis=-1, alpha_cmap=False, **kwargs):
    """Creates a heatmap/color map.

    Create a heatmap or colormap out of the input tensor: the axis
    ``reduce_axis`` is reduced to a single value per position, the result
    is projected to the integer range [0, 255] via :func:`project` and
    looked up in the color map. The color channels take the place of the
    reduced axis in the output.

    :param X: A image tensor with 4 axes.
    :param cmap_type: The color map to use. Default 'seismic'.
    :param reduce_op: Operation to reduce the color axis.
      Either 'sum' or 'absmax'. 'absmax' keeps, per position, the value
      with the largest magnitude including its sign.
    :param reduce_axis: Axis to reduce.
    :param alpha_cmap: Should the alpha component of the cmap be included.
    :param kwargs: Arguments passed on to :func:`project`
    :return: The tensor as color-map (float32), with 3 channels, or 4 if
      ``alpha_cmap`` is set, along ``reduce_axis``.
    :raises NotImplementedError: If ``reduce_op`` is not supported.
    """
    # pyplot.get_cmap is the supported lookup; the older cm.get_cmap
    # helper is no longer available in recent matplotlib releases.
    cmap = plt.get_cmap(cmap_type)

    if reduce_op == "sum":
        reduced = X.sum(axis=reduce_axis)
    elif reduce_op == "absmax":
        pos_max = X.max(axis=reduce_axis)
        neg_min = X.min(axis=reduce_axis)
        # Pick the extreme with the larger magnitude and keep its sign.
        reduced = np.select([pos_max >= -neg_min, pos_max < -neg_min],
                            [pos_max, neg_min])
    else:
        raise NotImplementedError()

    # Integer inputs index directly into the color map's lookup table.
    indices = project(reduced, output_range=(0, 255), **kwargs)
    indices = indices.astype(np.int64)

    colors = cmap(indices.flatten())  # (n_positions, RGBA)
    if not alpha_cmap:
        colors = colors[:, :3]

    # The lookup puts the channels last; restore the reduced layout first
    # and only then move the channels to where the reduced axis was, so
    # that reduce_axis values other than the last axis are handled too.
    colors = colors.reshape(tuple(indices.shape) + (colors.shape[-1],))
    colors = np.moveaxis(colors, -1, reduce_axis)
    return colors.astype(np.float32)


def graymap(X, **kwargs):
    """Gray-scale variant of :func:`heatmap`.

    Forwards ``X`` and all keyword arguments to :func:`heatmap` with the
    color map fixed to ``"gray"``.
    """
    gray_image = heatmap(X, cmap_type="gray", **kwargs)
    return gray_image


def gamma(X, gamma=0.5, minamp=0, maxamp=None):
    """
    Apply gamma correction to an input array X
    while maintaining the relative order of entries,
    also for negative vs positive values in X.
    The function first determines the max
    amplitude in both positive and negative
    direction and then applies gamma scaling
    to the positive and negative values of the
    array separately, according to the common amplitude.

    The input is not modified; a new floating point array is returned.

    :param X: the input array. Non-floating inputs are converted
        to float64 so that the scaled values are not truncated.
    :param gamma: the gamma parameter for gamma scaling
    :param minamp: the smallest absolute value to consider.
        if not given assumed to be zero (neutral value for relevance,
        min value for saliency, ...). values above and below
        minamp are treated separately.
    :param maxamp: the largest absolute value to consider relative
        to the neutral value minamp
        if not given determined from the given data.
    :return: the gamma corrected array, in the original scale and
        center. If the amplitude is zero (e.g. all entries equal
        minamp) or X is empty, a copy of the input values is returned.
    """

    X = np.asarray(X)
    # An integer result buffer would silently truncate the corrected values.
    if not np.issubdtype(X.dtype, np.inexact):
        X = X.astype(np.float64)
    original = X

    # Nothing to scale; also avoids taking the max of an empty array.
    if X.size == 0:
        return original.copy()

    # prepare return array
    Y = np.zeros_like(X)

    X = X - minamp  # shift to given/assumed center
    if maxamp is None:
        maxamp = np.abs(X).max()  # infer maxamp if not given

    # A zero amplitude would turn the linear scaling into 0/0 (NaN).
    if maxamp == 0:
        return original.copy()

    X = X / maxamp  # scale linearly

    # apply gamma correction for both positive and negative values.
    i_pos = X > 0
    i_neg = np.invert(i_pos)
    Y[i_pos] = X[i_pos]**gamma
    Y[i_neg] = -(-X[i_neg])**gamma

    # reconstruct original scale and center
    Y *= maxamp
    Y += minamp

    return Y


def clip_quantile(X, quantile=1):
    """Clip the values of X into the given quantile.

    Values below the lower percentile are set to the lower percentile
    and values above the upper percentile are set to the upper one.

    If X is a numpy array it is modified in place and the same object
    is returned; other inputs are converted to a new array first.

    :param X: The values to clip.
    :param quantile: Either a single percentile ``q`` (in [0, 100]),
      meaning the range between the percentiles ``q`` and ``100 - q``,
      or a list/tuple with the two percentiles. The two percentiles
      may be given in any order.
    :return: The clipped array.
    """
    if not isinstance(quantile, (list, tuple)):
        quantile = (quantile, 100 - quantile)
    # Order the bounds: with the lower bound above the upper one the two
    # assignments below would collapse all values to a single number.
    q_low, q_high = sorted((quantile[0], quantile[1]))

    X = np.asarray(X)  # no copy for ndarrays, keeps the in-place contract
    if X.size == 0:
        return X

    # Both percentiles in a single pass over the data.
    low, high = np.percentile(X, [q_low, q_high])
    X[X < low] = low
    X[X > high] = high

    return X
