# This code is from https://github.com/automl/pybnn
# pybnn authors: Aaron Klein, Moritz Freidank

import inspect
import logging
import traceback

import numpy as np
from scipy.optimize import fmin_l_bfgs_b, leastsq
from scipy.stats import norm


def recency_weights(num):
    """
    Exponentially increasing per-sample weights that favour recent samples.

    Returns an array ``w`` of length ``num`` with ``w[i] = growth ** i`` where
    ``growth = 10 ** (1 / num)``: the first sample gets weight 1 and the last
    one a weight just below 10.

    num: number of samples. ``num <= 0`` yields an empty array.
    """
    if num <= 0:
        # Avoid the division by zero in the growth factor for empty inputs.
        return np.ones(0)
    growth = 10 ** (1.0 / num)
    return growth ** np.arange(num)


class CurveModel(object):
    """
    Base class for a parametric curve model.

    Wraps a function ``function(x, p1, p2, ...)`` and adds a noise parameter
    ``sigma``. A full parameter vector ``theta`` is ordered like
    ``all_param_names``: the function's own parameters (in signature order,
    without ``x``) followed by ``sigma``. Subclasses implement ``fit`` and
    ``predict``.
    """

    def __init__(
        self, function, function_der=None, min_vals=None, max_vals=None, default_vals=None
    ):
        """
        function: the function to be fit; must take a parameter named ``x``.
        function_der: derivative of that function. Not supported yet; anything
            other than None raises NotImplementedError.
        min_vals: optional dict of lower bounds keyed by parameter name.
        max_vals: optional dict of upper bounds keyed by parameter name.
        default_vals: optional dict of default values keyed by parameter name.
            Function parameters that are not listed default to 1.0. Unless
            overridden, ``sigma`` defaults to 0.05 with bounds [0.0, 1.0].
        """
        # None sentinels instead of shared mutable ``{}`` default arguments.
        if min_vals is None:
            min_vals = {}
        if max_vals is None:
            max_vals = {}
        if default_vals is None:
            default_vals = {}

        self.function = function
        if function_der is not None:
            raise NotImplementedError(
                "function derivate is not implemented yet...sorry!"
            )
        self.function_der = function_der
        assert isinstance(min_vals, dict)
        self.min_vals = min_vals.copy()
        assert isinstance(max_vals, dict)
        self.max_vals = max_vals.copy()
        # inspect.getargspec was removed in Python 3.11; getfullargspec exposes
        # the same list of positional parameter names in ``.args``.
        function_args = inspect.getfullargspec(function).args
        assert "x" in function_args, "The function needs 'x' as a parameter."
        for default_param_name in default_vals:
            if default_param_name == "sigma":
                continue
            msg = "function %s doesn't take default param %s" % (
                function.__name__,
                default_param_name,
            )
            assert default_param_name in function_args, msg
        self.function_params = [param for param in function_args if param != "x"]
        # set default values:
        self.default_vals = default_vals.copy()
        for param_name in self.function_params:
            if param_name not in default_vals:
                logging.info(
                    "setting function parameter %s to default of 1.0 for "
                    "function %s",
                    param_name,
                    self.function.__name__,
                )
                self.default_vals[param_name] = 1.0
        self.all_param_names = self.function_params + ["sigma"]
        self.name = self.function.__name__
        self.ndim = len(self.all_param_names)

        # uniform noise prior over interval:
        self.min_vals.setdefault("sigma", 0.0)
        self.max_vals.setdefault("sigma", 1.0)
        self.default_vals.setdefault("sigma", 0.05)

    def default_function_param_array(self):
        """Default values of the function parameters (without sigma) as an array."""
        return np.asarray(
            [self.default_vals[param_name] for param_name in self.function_params]
        )

    def are_params_in_bounds(self, theta):
        """
        Are the parameters in their respective bounds?

        ``theta`` is matched positionally against ``all_param_names``; a shorter
        theta (e.g. the function parameters without sigma) only checks the
        entries it has. A parameter without an entry in ``min_vals`` /
        ``max_vals`` is unbounded on that side.
        """
        for param_name, param_value in zip(self.all_param_names, theta):
            if param_name in self.min_vals and param_value < self.min_vals[param_name]:
                return False
            if param_name in self.max_vals and param_value > self.max_vals[param_name]:
                return False
        return True

    def split_theta(self, theta):
        """Split theta into the function parameters (dict) and sigma."""
        params = {}
        sigma = None
        for param_name, param_value in zip(self.all_param_names, theta):
            if param_name in self.function_params:
                params[param_name] = param_value
            elif param_name == "sigma":
                sigma = param_value
        return params, sigma

    def split_theta_to_array(self, theta):
        """Split theta into the function parameters (array) and sigma (last entry)."""
        params = theta[:-1]
        sigma = theta[-1]
        return params, sigma

    def fit(self, x, y):
        raise NotImplementedError()

    def predict(self, x):
        raise NotImplementedError()

    def predict_given_theta(self, x, theta):
        """
        Make predictions given a single theta.

        Returns ``(function(x, **params), sigma)``.
        """
        params, sigma = self.split_theta(theta)
        predictive_mu = self.function(x, **params)
        return predictive_mu, sigma

    def likelihood(self, x, y):
        """
        for each y_i in y:
            p(y_i|x, model)

        Gaussian density of the residuals ``y - function(x)`` with scale sigma.
        Reads ``self.ml_params``, which this base class never sets; a subclass
        fit must have stored it first.
        """
        params, sigma = self.split_theta(self.ml_params)
        return norm.pdf(y - self.function(x, **params), loc=0, scale=sigma)


class MLCurveModel(CurveModel):
    """
    ML fit of a curve.

    The function parameters are estimated by (optionally weighted) non-linear
    least squares. Sigma is then set to the weighted root-mean-square
    residual. The result is stored in ``self.ml_params`` as
    ``[function params..., sigma]``.
    """

    def __init__(self, recency_weighting=True, **kwargs):
        """
        recency_weighting: if True, multiply every sample weight by
            ``recency_weights(len(y))`` in the fit and in the sigma estimate.
        kwargs: forwarded to ``CurveModel.__init__``.
        """
        super(MLCurveModel, self).__init__(**kwargs)

        # Maximum Likelihood values of the parameters
        self.ml_params = None
        self.recency_weighting = recency_weighting

    def fit(self, x, y, weights=None, start_from_default=True):
        """
        weights: None or weight for each sample.
        start_from_default: start from the default values (True) or from the
            current ``ml_params`` (False).

        Returns True on success, False otherwise.
        """
        return self.fit_ml(x, y, weights, start_from_default)

    def predict(self, x):
        """Evaluate the fitted function at x using ``ml_params``."""
        params, _ = self.split_theta_to_array(self.ml_params)
        return self.function(x, *params)

    def _sample_weights(self, num, weights):
        """Combine user weights and recency weights; None means unweighted."""
        if weights is not None:
            weights = np.asarray(weights, dtype=float)
        if not self.recency_weighting:
            return weights
        if weights is None:
            return recency_weights(num)
        return recency_weights(num) * weights

    def _initial_params(self, start_from_default):
        """Starting point for the optimizers (function parameters only)."""
        if start_from_default:
            return self.default_function_param_array()
        params, _ = self.split_theta_to_array(self.ml_params)
        return params

    def fit_ml(self, x, y, weights, start_from_default):
        """
        non-linear least-squares fit of the data.

        First tries Levenberg-Marquardt and falls back
        to BFGS in case that fails.

        Start from default values or from previous ml_params?
        """
        if self.fit_leastsq(x, y, weights, start_from_default):
            return True
        return self.fit_bfgs(x, y, weights, start_from_default)

    def ml_sigma(self, x, y, popt, weights):
        """
        Given the ML parameters (popt) get the ML estimate of sigma.

        This is the (weighted) root-mean-square residual about zero. Every
        weighting mode uses the same estimator; the standard deviation of the
        residuals would subtract their mean and underestimate sigma for
        biased fits.
        """
        residuals = y - self.function(x, *popt)
        sample_weights = self._sample_weights(len(y), weights)
        # np.average with weights=None is the plain mean.
        variance = np.average(residuals ** 2, weights=sample_weights)
        return np.sqrt(variance)

    def fit_leastsq(self, x, y, weights, start_from_default):
        """
        Fit the function parameters with Levenberg-Marquardt (``leastsq``).

        On success, stores ``popt + [sigma]`` in ``self.ml_params`` and returns
        True. Otherwise returns False and leaves ``ml_params`` untouched.
        Failure cases are NaNs, non-convergence, out-of-bounds parameters, or
        any exception, which is logged.
        """
        try:
            sample_weights = self._sample_weights(len(y), weights)
            if sample_weights is None:

                def residuals(p):
                    return self.function(x, *p) - y

            else:
                # leastsq squares the returned residuals, hence we need to
                # take the sqrt of the weights here (computed once, not per call).
                sqrt_weights = np.sqrt(sample_weights)

                def residuals(p):
                    return sqrt_weights * (self.function(x, *p) - y)

            initial_params = self._initial_params(start_from_default)

            popt, _cov_popt, info, msg, status = leastsq(
                residuals, x0=initial_params, full_output=True
            )

            if np.any(np.isnan(info["fjac"])):
                return False

            # leastsq reports a found solution with ier in 1..4.
            if status not in (1, 2, 3, 4):
                logging.warning(
                    "leastsq NOT successful for model %s, msg: %s",
                    self.function.__name__,
                    msg,
                )
                logging.warning("best parameters found: %s", popt)
                return False

            if np.any(np.isnan(popt)):
                return False
            # within bounds?
            if not self.are_params_in_bounds(popt):
                return False

            sigma = self.ml_sigma(x, y, popt, weights)
            self.ml_params = np.append(popt, [sigma])

            logging.info("leastsq successful for model %s", self.function.__name__)
            return True
        except Exception as e:
            logging.error(e)
            logging.error(traceback.format_exc())
            return False

    def fit_bfgs(self, x, y, weights, start_from_default):
        """
        Fit the function parameters with bounded L-BFGS-B.

        The objective is the (weighted) sum of squared residuals, and the box
        bounds come from ``min_vals`` / ``max_vals``. On success, stores
        ``popt + [sigma]`` in ``self.ml_params`` and returns True; otherwise
        returns False.
        """
        try:
            sample_weights = self._sample_weights(len(y), weights)

            def objective(params):
                squared_errors = (self.function(x, *params) - y) ** 2
                if sample_weights is None:
                    return np.sum(squared_errors)
                return np.sum(sample_weights * squared_errors)

            # A missing bound is None, i.e. unbounded on that side.
            bounds = [
                (self.min_vals.get(param_name), self.max_vals.get(param_name))
                for param_name in self.function_params
            ]

            initial_params = self._initial_params(start_from_default)

            popt, _fval, info = fmin_l_bfgs_b(
                objective, x0=initial_params, bounds=bounds, approx_grad=True
            )
            if info["warnflag"] != 0:
                logging.warning(
                    "BFGS not converged! (warnflag %d) for model %s",
                    info["warnflag"],
                    self.name,
                )
                logging.warning(info)
                return False

            if popt is None:
                return False
            if np.any(np.isnan(popt)):
                logging.info(
                    "bfgs NOT successful for model %s, parameter NaN", self.name
                )
                return False
            sigma = self.ml_sigma(x, y, popt, weights)
            self.ml_params = np.append(popt, [sigma])
            logging.info("bfgs successful for model %s", self.name)
            return True
        except Exception as e:
            # Best-effort fit: report the failure instead of swallowing it
            # silently, but still signal it via the return value.
            logging.error(e)
            logging.error(traceback.format_exc())
            return False

    def aic(self, x, y):
        """
        Akaike information criterion
        http://en.wikipedia.org/wiki/Akaike_information_criterion

        Uses the stored ``ml_params``, so a successful ``fit`` must come first.
        """
        params, sigma = self.split_theta_to_array(self.ml_params)
        y_model = self.function(x, *params)
        log_likelihood = norm.logpdf(y - y_model, loc=0, scale=sigma).sum()
        # NOTE: k counts only the function parameters, not sigma. Every
        # CurveModel carries exactly one sigma, so between CurveModels this is
        # a constant offset.
        return 2 * len(self.function_params) - 2 * log_likelihood
