import functools
import itertools
import logging
from typing import Optional

import numpy as np
import sklearn.metrics._regression as skregression

# Module-level logger, named after this module.
_log = logging.getLogger(__name__)


def exponential_bounds(y_true, y_pred, *, y_lim: float, argmin: Optional[float] = 0.5,
                       c: Optional[float] = None, min_val: Optional[float] = 1e-6,
                       multioutput: str = "raw_values"):
    """
    Computes the gradient and hessian of an exponential boundary penalty,
    $e^{c * (y_pred - y_lim)}$, that can be used to impose regression value boundaries
    in the loss function. Whether this function generates an "upper" or "lower" bound
    is determined by the values of `y_lim` and `argmin`, wherein if `y_lim > argmin` an
    upper bound is generated, otherwise a lower bound is generated (assuming
    `0 < min_val < 1`). `y_lim` == `argmin` is rejected when `c` has to be derived
    from them, since the slope would be infinite.

    Exactly one of `c` and `min_val` may be in effect. Because `min_val` has a non-zero
    default, callers that want to pass `c` explicitly must also pass `min_val=None`
    (or `min_val=0.`).

    :param y_true: array-like
        The true labels/expected regression targets. Only used for input validation;
        the penalty itself depends on `y_pred` alone.
    :param y_pred: array-like
        The predicted regression values.
    :param y_lim: float
        The upper/lower limit on the regressand. The penalty equals 1 at
        `y_pred == y_lim`.
    :param argmin: float or None
        The value of y_pred at which the boundary function should decay to `min_val`.
        Only used (and then required) when `c` is derived from `min_val`.
    :param c: float or None
        The constant multiplier for the slope of the exponential ($e^{c * x}$),
        determines how quickly the loss value grows or decays away from the boundary.
        If None, it is computed as `log(min_val) / (argmin - y_lim)`.
    :param min_val: float or None
        The value of the boundary loss at `y_pred == argmin`. If None or 0, `c` must be
        given explicitly.
    :param multioutput: str
        Corresponds with the equivalent parameter of SKLearn compatible metrics.
        Currently present for compatibility purposes only and effectively only implements
        the default "raw_values" feature.
    :return: tuple
        `(grad, hess)` - the first and second derivatives of the penalty with respect
        to `y_pred`, evaluated element-wise.
    :raises ValueError:
        If neither `c` nor a usable `min_val` is given, or if `c` cannot be derived
        because `argmin` is None, `min_val` is negative, or `y_lim == argmin`.
    :raises RuntimeError:
        If both `c` and a non-zero `min_val` are specified.
    """

    _y_type, y_true, y_pred, multioutput = skregression._check_reg_targets(
        y_true, y_pred, multioutput
    )
    skregression.check_consistent_length(y_true, y_pred)

    if min_val is None or min_val == 0.:
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if c is None:
            raise ValueError(
                f"When no minimum loss value is specified (parameter 'min_val' - was "
                f"{min_val}), the parameter 'c' must be given (was {c})."
            )
    elif c is None:
        if argmin is None:
            raise ValueError(
                "The parameter 'argmin' must be given in order to derive 'c' from "
                f"'min_val' (was {min_val})."
            )
        if min_val < 0:
            raise ValueError(
                f"The parameter 'min_val' must be positive (was {min_val})."
            )
        if argmin == y_lim:
            # The slope would be log(min_val) / 0, i.e. infinite.
            raise ValueError(
                f"The parameters 'y_lim' and 'argmin' must differ (both were {y_lim})."
            )
        # Solve exp(c * (argmin - y_lim)) == min_val for c.
        c = np.log(min_val) / (argmin - y_lim)
    else:
        raise RuntimeError("Only one of `c` and `min_val` should be specified.")

    # TODO: Check for numerical overflows and implement a solution e.g. cap this value
    loss = np.exp(c * (y_pred - y_lim))
    # d/dy exp(c * (y - y_lim)) = c * loss; the second derivative is c^2 * loss.
    grad = c * loss
    hess = (c ** 2) * loss

    return grad, hess


def squared_error(y_true, y_pred, multioutput: str = "raw_values"):
    """
    Squared error objective. Produces the first and second derivatives, with respect
    to the predictions, of the halved squared error 0.5 * ||y_pred - y_true||^2.

    :param y_true: array-like
        The true labels/expected regression targets.
    :param y_pred: array-like
        The predicted regression values.
    :param multioutput: str
        Corresponds with the equivalent parameter of SKLearn compatible metrics.
        Currently present for compatibility purposes only and effectively only implements
        the default "raw_values" feature.
    :return: tuple
        `(grad, hess)` where `grad` is `y_pred - y_true` and `hess` is an array of
        ones shaped like the validated `y_pred`.
    """

    # Validate/convert the inputs using the same helpers as the SKLearn metrics.
    _kind, targets, preds, _multioutput = skregression._check_reg_targets(
        y_true, y_pred, multioutput
    )
    skregression.check_consistent_length(targets, preds)

    # For loss = 0.5 * (targets - preds)^2 the gradient is the residual and the
    # hessian is constant.
    residual = preds - targets
    curvature = np.ones_like(preds)
    return residual, curvature


def mix_objectives(*funcs, weights=None):
    """
    Mix any number of cost functions (objectives), passed as positional arguments,
    into a new objective function that assumes that the component objectives are to be
    added together. The individual functions will all be called using the same inputs -
    (y_true, y_pred) - and must each return a `(grad, hess)` pair. If any other
    arguments need to be passed to the individual functions, the functions themselves
    should be converted to partial functions using `functools.partial`.

    :param funcs: callables
        The objectives to be mixed. At least one is required.
    :param weights: array-like or None
        One weight per objective; each objective's gradient and hessian are multiplied
        by its weight before being summed. Singleton dimensions are squeezed out, so
        e.g. a row vector is accepted. If None, every objective is weighted equally
        with `1 / len(funcs)`.
    :return: callable
        A function `mix_fn(y_true, y_pred)` returning the weighted sums
        `(grad, hess)` over all objectives.
    :raises ValueError:
        If no objectives are given, or if `weights` does not reduce to a 1-D sequence
        with exactly one entry per objective.
    """

    if not funcs:
        raise ValueError("At least one objective function must be provided.")

    if weights is None:
        n = len(funcs)
        weights = np.ones(n) / n
    else:
        # squeeze() turns a single weight into a 0-d array, which can neither be
        # indexed nor iterated; atleast_1d restores the 1-D shape in that case.
        weights = np.atleast_1d(np.array(weights).squeeze())
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if weights.ndim != 1 or weights.shape[0] != len(funcs):
            raise ValueError(
                f"Expected exactly one weight per objective ({len(funcs)}), but "
                f"'weights' has shape {weights.shape} after squeezing."
            )

    def mix_fn(y_true, y_pred):
        grad = hess = None
        for func, weight in zip(funcs, weights):
            res = func(y_true, y_pred)
            weighted_grad, weighted_hess = res[0] * weight, res[1] * weight
            if grad is None:
                # First objective seeds the running totals.
                grad, hess = weighted_grad, weighted_hess
            else:
                grad, hess = grad + weighted_grad, hess + weighted_hess
        return grad, hess

    return mix_fn
