# Module metadata; the author and email fields are currently blank.
__author__ = ",  "
__version__ = "0.0"
__email__ = " "

import copy
from collections import deque
from enum import Enum
import matplotlib as matplotlib
import torch
from scipy.optimize import root
from sklearn.model_selection import KFold
from torch.optim.optimizer import Optimizer
from .regression_collection import *


class ELF_STATE(Enum):
    """
    The two operating modes of the ELF optimizer.

    MEASURE_LINE: losses are sampled along the current search direction to fit a loss approximation.
    SGD_TRAINING: regular update steps with the current update step size are performed.
    """
    MEASURE_LINE = 0  # sampling losses along a line
    SGD_TRAINING = 1  # performing momentum-SGD update steps


class _RequiredParameter(object):
    """Singleton class representing a required parameter for an Optimizer."""

    def __repr__(self):
        return "<required parameter>"


required = _RequiredParameter()


# noinspection SpellCheckingInspection,PyMethodOverriding,PyPep8,PyPep8,PyPep8,PyPep8,PyPep8,PyPep8,PyPep8,PyPep8,PyPep8,PyPep8
class ELF(Optimizer):
    """
    Optimizer that estimates step sizes by approximating the expected empirical loss along the update direction.

    It alternates between two states (see ``ELF_STATE``):

    * ``SGD_TRAINING``: momentum SGD steps with the current ``update_step_size``. The mean loss of the last
      ``losses_to_consider_for_plateau`` steps is checked periodically. If it increased, the optimizer resets to the
      last checkpoint. During the initial phase it then divides the step size; afterwards it measures new lines.
      If the improvement is smaller than expected (after the initial phase), a new line measurement is started.
    * ``MEASURE_LINE``: losses are sampled at random positions along the current direction. A function is fitted
      to them, and the position of its (shifted) minimum yields a step size. After ``amount_lines`` lines, the mean
      step size is used for further SGD training.

    Only one parameter group is supported.
    """

    def __init__(self, model=required, loss_improvement_factor=0.01, amount_measure_points=1500, momentum=0.4,
                 decrease_factor=0.2, l2_reg=0.0, fit_function_family="polynomial", backtracking=False, use_log10=False,
                 is_logging=True, plot_save_dir="./lines/"):
        """
        The ELF optimizer.
        Estimates step sizes by approximating the expected empirical loss.
        Only one parameter group is supported.

        :param model: a pytorch model; all of its parameters are optimized.
        :param loss_improvement_factor: an expected improvement is inferred from the approximated empirical loss.
            If the real improvement is smaller than expected_improvement * loss_improvement_factor, a new step size
            is measured.
        :param amount_measure_points: losses to measure for each step size evaluation (split over the measured
            lines). Must be at least 1500. Good values are 1500 or 3000.
        :param momentum: momentum factor of the direction buffer, in [0, 1).
        :param decrease_factor: in [0, 1). If 0, the step to the approximated minimum is chosen. Otherwise the step
            goes beyond the minimum to where the approximation reaches
            loss_at_min + decrease_factor * (loss_at_0 - loss_at_min).
        :param l2_reg: l2 loss regularization factor (>= 0).
        :param fit_function_family: one of "polynomial", "cubic spline", "fourier series".
        :param backtracking: do not allow increasing step sizes.
        :param use_log10: fit the approximation on the log10 of the loss.
        :param is_logging: enables the optional log messages and the saving of line plots.
        :param plot_save_dir: directory the line plots are written to when logging.
        :raises TypeError: if no model is given.
        """
        if model is required:
            raise TypeError("ELF requires a model whose parameters are optimized")
        # The default value 1500 must be accepted (a strict '>' used to reject it).
        assert amount_measure_points >= 1500
        assert fit_function_family in ("polynomial", "cubic spline", "fourier series")
        assert 0 <= momentum < 1.0
        assert 0 <= decrease_factor < 1.0
        assert l2_reg >= 0.0

        self.model = model
        self.params = list(self.model.parameters())

        self.loss_improvement_factor = loss_improvement_factor
        self.momentum = momentum
        self.loss_divider = decrease_factor
        self.is_logging = is_logging
        self.use_log10 = use_log10
        self.backtracking = backtracking
        self.plot_save_dir = plot_save_dir
        self.fit_function_family = fit_function_family  # "polynomial" "cubic spline" "fourier series"
        self.l2_reg = l2_reg

        # constants based on empirical evidence
        self.amount_lines = 3.0  # lines whose step sizes are averaged before SGD training resumes
        self.amount_readjustments = 5.0  # readjustments of the measure interval during a line measure
        self.amount_losses_to_measure_per_line = int(
            amount_measure_points // self.amount_lines // self.amount_readjustments * self.amount_readjustments)
        self.max_degree = 10  # max polynomial degree or max amount of fourier frequencies or max amount of splines
        self.min_degree = 3
        self.epsilon = 1e-15  # replaces a zero direction norm to avoid dividing by zero
        # To reach a position where simple function approximations are always applicable, train for a short time at
        # the beginning of the training with the maximal step size that still supports training.
        self.max_steps_for_initial_lr_search = 1500  # A good heuristic is to choose between 1% and 3% of the training set size.
        self.losses_to_consider_for_plateau = self.max_steps_for_initial_lr_search // 10
        self.initial_lr_divider = 10

        # internal state fields
        self.lfo_state = ELF_STATE.SGD_TRAINING
        self.update_step_size = 100  # in rare cases 100 leads to inf values; then a lower number has to be chosen
        self.old_update_step_size = self.update_step_size
        self.measure_interval_factor = 1.0  # initial value, will be changed during training
        self.current_line = 1
        self.update_steps_per_line = []
        self.loss_improvements_per_line = []
        self.last_line_measure_step = 0
        self.direction_norm = torch.Tensor([1.0])
        self.performed_training_steps = -1

        self.before_last_line_losses_mean = np.inf
        self.current_losses = deque(maxlen=self.losses_to_consider_for_plateau)

        self.measured_losses = []
        self.measured_locations = []
        self.mean_step_size = []
        self.current_position = 0.0
        self.points_to_measure = []
        self.loss_improvement = 0.1
        self.use_validation_set = False
        self.model_state_checkpoint = None

        defaults = dict()  # only relevant for parameter groups which are not supported
        super(ELF, self).__init__(self.params, defaults)

    def state_dict(self):
        """
        :return: a deep copy of the optimizer's ``__dict__`` representing the current state of the optimizer.
                 This includes a deep copy of the model, the per-parameter state and the optimizer bookkeeping.
        """
        dict_ = self.__dict__
        return copy.deepcopy(dict_)

    def load_state_dict(self, state_dict):
        """
        Set the current state of the optimizer from a dictionary created by ``state_dict``.

        The optimizer keeps driving the model it currently holds. The deep-copied model stored in ``state_dict`` is
        discarded. The per-parameter state (``self.state``) and the parameter groups are remapped, by parameter
        order, from the stored parameter copies onto the live parameters. Otherwise later steps would update a
        detached copy instead of the trained model.

        :param state_dict: dictionary returned by ``state_dict``.
        """
        from collections import defaultdict

        state_dict = copy.deepcopy(state_dict)
        live_model = self.model
        # deepcopy keeps shared references shared, so these stored parameter copies are the same objects that key
        # the stored ``state`` and fill the stored ``param_groups``.
        saved_params = state_dict.get("params", [])
        self.__dict__.update(state_dict)
        self.model = live_model
        self.params = list(live_model.parameters())

        old_to_new = {id(old): new for old, new in zip(saved_params, self.params)}
        remapped_state = defaultdict(dict)
        for param, param_state in self.state.items():
            remapped_state[old_to_new.get(id(param), param)] = param_state
        self.state = remapped_state
        for group in self.param_groups:
            group["params"] = [old_to_new.get(id(p), p) for p in group["params"]]

    def _param_device(self):
        """
        :return: the device of the first optimized parameter, or the CPU if there are no parameters.
        """
        return self.params[0].device if len(self.params) > 0 else torch.device("cpu")

    def _update_direction_vars(self, params):
        """
        Update the search or update direction. In SGD training mode it is used as the update direction.
        In measure line state the normalized direction is used as search direction.

        For every parameter with a gradient: dir_buffer = momentum * dir_buffer + grad (+ l2_reg * param).
        ``self.direction_norm`` is set to the L2 norm over all direction buffers (``self.epsilon`` if it is zero).

        :param params: the network parameters
        """
        with torch.no_grad():
            norm = torch.tensor(0.0)
            for p in params:
                if p.grad is None:
                    continue
                param_state = self.state[p]
                if 'dir_buffer' not in param_state:
                    buf = param_state['dir_buffer'] = torch.zeros_like(p.grad.data, device=p.device)
                else:
                    buf = param_state['dir_buffer']
                buf.mul_(self.momentum)  # _ = inplace
                buf.add_(p.grad.data)
                if self.l2_reg != 0:
                    buf.add_(self.l2_reg * p.data)
                # reshape (not view) also works for non-contiguous buffers
                flat_buf = buf.reshape(-1)
                norm = norm + torch.dot(flat_buf, flat_buf)
            torch.sqrt_(norm)
            if norm == 0.0:
                norm = self.epsilon
            # The norm accumulates dot products of the direction buffers, so it already lives on their device.
            # It is not moved to the default CUDA device, which broke CPU models on CUDA machines.
            self.direction_norm = norm

    def is_check_validation_score_valid(self):
        """
        Called by the training loop. Tells whether the optimizer is in a state where the validation or test
        accuracy can be measured. E.g. during a line search the positions in the parameter space might not be optimal.

        :return: True if in SGD training state and past the initial step size search, else False.
        """
        return (self.lfo_state == ELF_STATE.SGD_TRAINING
                and self.performed_training_steps > self.max_steps_for_initial_lr_search)

    def _set_direction_norm_to_one(self):
        """
        Used for SGD update steps: the direction is used unnormalized (norm 1) for the next update.
        """
        with torch.no_grad():
            self.direction_norm = torch.Tensor([1.0]).to(self._param_device())

    def _set_checkpoint(self):
        """
        Saves the current position in the parameter space (model state dict and direction buffers).
        """
        with torch.no_grad():
            self.model_state_checkpoint = copy.deepcopy(self.model.state_dict())
            for p in self.params:
                if p.grad is None:
                    continue
                param_state = self.state[p]
                param_state['ckpt_buffer'] = param_state['dir_buffer'].clone().detach()

    def _reset_to_best_checkpoint(self):
        """
        Resets the parameter position and the direction buffers to the last saved checkpoint (if any).
        """
        with torch.no_grad():
            if self.model_state_checkpoint is not None:
                self.model.load_state_dict(self.model_state_checkpoint)
                self.params = list(self.model.parameters())
                for p in self.params:
                    if p.grad is None:
                        continue
                    param_state = self.state[p]
                    assert "ckpt_buffer" in param_state
                    param_state['dir_buffer'] = param_state['ckpt_buffer'].clone().detach()

    def _perform_sgd_update(self, params, step_size):
        """
        Performs an SGD update step: p -= step_size * dir_buffer / direction_norm.
        """
        with torch.no_grad():
            for p in params:
                if p.grad is None:
                    continue
                mom = self.state[p]["dir_buffer"]
                p.data += step_size * -mom / self.direction_norm

    def _get_l2_loss(self, params):
        """
        Determines the current l2 loss: 0.5 * l2_reg * sum of squared parameter values.
        """
        with torch.no_grad():
            l2_loss = torch.tensor(0.0)
            # '!=' instead of 'is not': identity comparison with a float literal is always True.
            if self.l2_reg != 0.0:
                for p in params:
                    if p.data is None:
                        continue
                    flat_data = p.data.reshape(-1)
                    l2_loss = l2_loss + torch.dot(flat_data, flat_data)
            return l2_loss * self.l2_reg * 0.5

    def _perform_update_step(self, params, step_size, loss_fn, new_direction):
        """
        Performs an update step in the parameter space in the direction of the direction variables.

        :param params: the network parameters
        :param step_size: distance to move along the (normalized) direction
        :param loss_fn: closure ``loss_fn(backwards) -> (loss, output)``
        :param new_direction: if true, compute gradients and update the direction variables
        :return: current loss (including the l2 loss, as numpy value), model output
        """
        if self.performed_training_steps == 0:
            loss, net_out = loss_fn(
                backwards=True)  # at first step we have to call the loss_fn twice to initialize the direction buffer
            l2_loss = self._get_l2_loss(params)
            loss = loss + l2_loss
            self._update_direction_vars(params)
            self._set_checkpoint()
            self.before_last_line_losses_mean = loss.item()

        self._perform_sgd_update(params, step_size)
        if new_direction:
            loss, net_out = loss_fn(backwards=True)
        else:
            loss, net_out = loss_fn(backwards=False)
        l2_loss = self._get_l2_loss(params)
        loss = loss + l2_loss
        loss = self._to_numpy(loss)

        self.current_position += step_size
        if new_direction:
            self._update_direction_vars(params)
        return loss, net_out

    def step(self, loss_fn):
        """
        Optimizer update step. Either an SGD training step is done, or one loss for the loss approximation is
        measured. If enough losses for a valid line approximation are measured, a new step size is set.
        # No support of param groups since it conflicts with checkpoints of the parameters and every state field
        # must be param_group specific
        :param loss_fn: function of the form:
        >>>    def closure(backwards = True):
        >>>        self.optimizer.zero_grad()
        >>>        output = self.model(x)
        >>>        loss_ = self.loss(output,y)
        >>>       if backwards:
        >>>            loss_.backward()
        >>>        return loss_, output
        >>>    loss, outputs, step_size = self.optimizer.step(closure)
        >>>    return loss, outputs, step_size, y
        :return: current loss, current model output, current step size if SGD training state.
                 None, None, None if line search state.
        """

        params = self.params
        self.performed_training_steps += 1

        if self.performed_training_steps == self.max_steps_for_initial_lr_search:
            # end of the initial phase with a fixed step size: start measuring lines
            self._set_measure_line_state()

        if self.lfo_state == ELF_STATE.SGD_TRAINING:
            assert len(self.points_to_measure) == 0
            loss, net_out = self._perform_update_step(params, self.update_step_size, loss_fn, new_direction=True)
            self.current_losses.append(loss)
            if len(self.current_losses) >= self.losses_to_consider_for_plateau:
                current_losses_mean = np.nanmean(self.current_losses)
                expected_improvement = self.loss_improvement_factor * self.loss_improvement * (
                        self.performed_training_steps - self.last_line_measure_step)
                real_improvement = self.before_last_line_losses_mean - current_losses_mean
                if np.isnan(current_losses_mean) or current_losses_mean > self.before_last_line_losses_mean:
                    print("+" * 50)
                    print(
                        "current average loss is greater than before update step adaptation. saerching new update step. current: {0:2.5f} last: {1:2.5f}".format(
                            current_losses_mean, self.before_last_line_losses_mean))
                    print("+" * 50)
                    self._reset_to_best_checkpoint()
                    self.current_losses = deque(
                        maxlen=self.losses_to_consider_for_plateau)  # throw away the new losses and reuse old mean loss
                    if self.performed_training_steps >= self.max_steps_for_initial_lr_search:
                        self._perform_update_step(params, 0.0, loss_fn, new_direction=True)
                        self._set_measure_line_state()
                    else:
                        self.update_step_size /= self.initial_lr_divider
                        loss, _ = self._perform_update_step(params, 0.0, loss_fn, new_direction=True)
                        print(
                            "initial update step adaptation: divided update step by {0} the new step size is {1:2.5f}".format(
                                self.initial_lr_divider, self.update_step_size))

                elif expected_improvement > real_improvement and self.performed_training_steps >= self.max_steps_for_initial_lr_search:
                    print("plateau reached at mean loss: {0:2.5f}. Starting new line search".format(
                        current_losses_mean))
                    self._perform_update_step(params, 0.0, loss_fn, new_direction=True)
                    self._set_measure_line_state()
                    self.current_losses = deque(maxlen=self.losses_to_consider_for_plateau)
                    self.before_last_line_losses_mean = current_losses_mean.item()
                else:
                    self.current_losses = deque(maxlen=self.losses_to_consider_for_plateau)

            return loss, net_out, self.update_step_size.item() if isinstance(self.update_step_size,
                                                                             torch.Tensor) else self.update_step_size
        else:  # self.lfo_state == ELF_STATE.MEASURE_LINE:
            if len(self.points_to_measure) > 0:
                self._measure_next_loss(params, loss_fn)
            else:
                direction_norm = self.direction_norm
                x = np.array(self.measured_locations)
                if self.use_log10:
                    y = np.log10(self.measured_losses)
                else:
                    y = np.array(self.measured_losses)
                degree_of_best_fitting_polynom = self._get_best_fitting_function_with_cross_validation(x, y)
                fitted_function, minimum_location, loss_improvement = self._get_min_position_and_fitted_function_of_degree(
                    x, y,
                    degree_of_best_fitting_polynom,
                    direction_norm)
                if fitted_function is not None and minimum_location is not None and loss_improvement is not None:
                    self._decide_how_to_proceed_on_line(x, y, minimum_location, fitted_function,
                                                        loss_improvement, direction_norm,
                                                        degree_of_best_fitting_polynom, loss_fn, params)
                else:
                    self._handle_nonsuitable_fit(x, y, loss_fn, params, fitted_function)
            if len(self.update_steps_per_line) == self.amount_lines:
                self._set_new_update_step_originating_from_multiple_lines()
            return None, None, None

    def _set_measure_line_state(self):
        """
        Switches to the line measuring state. It checkpoints the current position and schedules the first batch of
        uniformly distributed measure points in [-2 * measure_interval_factor, 2 * measure_interval_factor].
        """
        self.lfo_state = ELF_STATE.MEASURE_LINE
        self.measured_losses = []
        self.measured_locations = []
        self.current_position = 0
        self._set_checkpoint()
        center = 0.0
        self.points_to_measure = self._get_uniformly_distributed_points_to_measure(
            center - self.measure_interval_factor * 2,
            center + self.measure_interval_factor * 2,
            self.amount_losses_to_measure_per_line // self.amount_readjustments)
        self.use_validation_set = True

    def set_sgd_training_state(self):
        """
        Switches to the SGD training state and resets the direction norm to one.

        NOTE(review): ``_update_direction_vars`` overwrites ``direction_norm`` after every backward pass. As a result,
        only the first SGD step after this call uses norm 1, and later SGD steps are normalized again. Confirm this
        is intended.
        """
        self.lfo_state = ELF_STATE.SGD_TRAINING
        self.points_to_measure = []
        self.use_validation_set = False
        self._set_direction_norm_to_one()
        self.last_line_measure_step = self.performed_training_steps

    @staticmethod
    def _get_uniformly_distributed_points_to_measure(min_, max_, amount_measure_points):
        """
        :return: list of ``int(amount_measure_points)`` uniformly random positions in [min_, max_).
        """
        assert amount_measure_points > 0
        return np.random.uniform(min_, max_, size=int(amount_measure_points)).tolist()

    def _measure_next_loss(self, params, loss_fn):
        """
        Moves to the next scheduled line position and records the loss measured there.

        :return: the measured loss
        """
        assert len(self.points_to_measure) > 0
        measure_location = self.points_to_measure[0]
        del self.points_to_measure[0]
        step = measure_location - self.current_position
        loss, _ = self._perform_update_step(params, step, loss_fn, new_direction=False)
        self.measured_locations.append(measure_location)
        self.measured_losses.append(loss)
        return loss

    def _get_best_fitting_function_with_cross_validation(self, x, y):
        """
        Performs 5-fold cross validation to find the best generalizing fit. Degrees are increased until the mean
        validation error increases.

        Returns: the degree of the best fitting function in range min_degree to max_degree.
        """
        last_mean_fold_val_error = None
        kFold = KFold(n_splits=5, shuffle=True, random_state=None)
        splits = list(kFold.split(x))
        for degree in range(self.min_degree, self.max_degree + 1):
            fold_val_errors = []
            for train_indexes, val_indexes in splits:
                fitted_function = self._get_fitted_function(x[train_indexes], y[train_indexes], degree)
                if fitted_function is None:
                    continue
                fold_val_errors.append(
                    self._get_mean_least_square_error(x[val_indexes], y[val_indexes], fitted_function))
            mean_fold_val_error = np.nanmean(fold_val_errors)
            if last_mean_fold_val_error is not None and mean_fold_val_error > last_mean_fold_val_error:
                self._log_print("function of degree: {0:d} was chosen".format(degree - 1))
                return degree - 1
            last_mean_fold_val_error = mean_fold_val_error
        self._log_print("function of degree: {0:d} was chosen".format(self.max_degree))
        return self.max_degree

    def _log_print(self, *messages):
        """Prints the messages only if logging is enabled."""
        if self.is_logging:
            print(*messages)

    def _get_fitted_function(self, x, y, degree):
        """
        Fits a function of the configured family and degree to (x, y) with ``curve_fit``.

        Returns: the fitted function as a callable ``f(z)``, or None if fitting failed.
        """
        assert len(x) > 0
        assert degree > 0
        if self.fit_function_family == "polynomial":
            function_to_fit, _ = get_polynomial(degree)
        elif self.fit_function_family == "cubic spline":
            function_to_fit, _ = get_smooth_cubic_spline(degree)
        elif self.fit_function_family == "fourier series":
            function_to_fit, _ = get_fourier_series(degree)
        else:
            raise Exception("fit function family not supported " + self.fit_function_family)
        try:
            popt, pcov = curve_fit(function_to_fit, x, y, maxfev=10000, check_finite=False)
        except Exception:  # a failed fit is tolerated; callers skip None results
            print("error occured while fitting function of degree ", degree)
            return None

        def fitted_function(z):
            return function_to_fit(z, *popt)

        return fitted_function

    @staticmethod
    def _get_mean_least_square_error(x, y, fitted_function):
        """
        :return: mean squared error of ``fitted_function`` on (x, y).
        """
        assert len(x) == len(y)
        assert len(x) > 0
        return np.sum((fitted_function(x) - y) ** 2) / len(x)

    @staticmethod
    def _get_step_to_loss(fitted_function, loss_divider=0.0, bounds=(-np.inf, np.inf)):
        """
        Args:
            fitted_function: callable approximation of the loss along the line.
            loss_divider: if 0 the exact minimum is estimated. Otherwise the position x >= minimum where the function
                reaches loss_at_min + loss_divider * (loss_at_0 - loss_at_min) is estimated.
            bounds: (lower, upper) bounds for the minimum search.
        Returns: position of the minimum or target loss, np.nan if no position was found.
        """
        assert callable(fitted_function)
        result = minimize(fitted_function, x0=0, options={"maxiter": 1000, "disp": False}, bounds=[bounds])
        if result.success and len(result.x) > 0:
            minimum = result.x[0]
            if loss_divider == 0.0:
                return minimum
            else:
                loss_at_0 = fitted_function(0)
                loss_at_min = fitted_function(minimum)
                shifted_function_squared = lambda x: (fitted_function(x) - ((
                                                                                    loss_at_0 - loss_at_min) * loss_divider) - loss_at_min) ** 2  # todo change name loss divider
                # root does not know bounds, therefore the least squares are minimized instead.
                # The problem is that the minimum is near the local maximum in this case; a working method is
                # Powell (binary line search) https://en.wikipedia.org/wiki/Powell%27s_method
                result_sf = minimize(shifted_function_squared, x0=minimum, options={"maxiter": 1000, "disp": False},
                                     method="Powell", bounds=[(minimum, np.inf)])
                if result_sf.success and len(result_sf.x) > 0:
                    shifted_function_root = result_sf.x[0]
                    return shifted_function_root
                else:
                    return np.nan
        else:
            return np.nan

    def _get_min_position_and_fitted_function_of_degree(self, x, y, degree, direction_norm):
        """
        :param x:  line positions
        :param y:  loss values
        :param degree: polynomial degree, or fourier frequency or number of splines
        :param direction_norm:  norm of the search direction
        :return: fitted function as callable (or None), the location of its minimum / target position found by
                 starting the search at 0, and the expected loss improvement. The last two are None if no fit or
                 no position was found.
        """
        fitted_function = self._get_fitted_function(x, y, degree)
        if fitted_function is not None:
            minimum_location = self._get_step_to_loss(fitted_function, self.loss_divider)
            loss_improvement = self._get_loss_improvement(fitted_function, minimum_location)
            train_error = self._get_mean_least_square_error(x, y, fitted_function)
            self._log_print("fit on full data:")
            self._log_print(
                "degree: {0} \t train_error: {1:3.4f} \t step_to_min: {2:3.4f} \t step_to_min_normalized: {3:3.4f} \t improvement: {4:3.4f}".format(
                    degree, train_error,
                    minimum_location / direction_norm if minimum_location is not None else -1000.0,
                    minimum_location if minimum_location is not None else -1000.0,
                    loss_improvement if loss_improvement is not None else -1000.0))
            if not np.isnan(minimum_location):
                return fitted_function, minimum_location, loss_improvement
        return fitted_function, None, None

    @staticmethod
    def _get_loss_improvement(fitted_function, target_position):
        """
        :return: expected loss improvement f(0) - f(target_position) if a step to the target_position is done,
                 None if target_position is None.
        """
        if target_position is None:
            return None
        else:
            improvement = (fitted_function(0.0) - fitted_function(target_position))
            return improvement

    def _decide_how_to_proceed_on_line(self, x, y, minimum_location, fitted_function, loss_improvement, grad_norm,
                                       degree, loss_fn, params):
        """
        Decide where and whether to measure new points, or to accept the line function approximation.

        If the estimated minimum lies outside the measured range, points around it are measured. Otherwise, once
        enough losses are measured, the optimizer steps to the minimum and stores the step size; if not, further
        points around the minimum are scheduled.
        """
        # extrapolation
        if minimum_location > max(self.measured_locations) or minimum_location < min(
                self.measured_locations):
            self._log_print("estimated minimum is extrapolating, measuring up to this location, too")
            self._plot(x, y, None, None, fitted_function, minimum_location, degree,
                       message="extrapolation")
            # Here it might happen that the algorithm measures more points than it should per line but this is intended
            self.points_to_measure = self._get_uniformly_distributed_points_to_measure(
                minimum_location - self.measure_interval_factor,
                minimum_location + self.measure_interval_factor,
                self.amount_losses_to_measure_per_line // self.amount_readjustments)
            self._measure_next_loss(params, loss_fn)
        else:
            self._plot(x, y, None, None, fitted_function, minimum_location, degree,
                       message="line_" + str(len(self.update_steps_per_line)))
            self._set_measure_interval(fitted_function, minimum_location, x, y)
            # enough losses sampled to fit line function approximation
            if len(x) >= self.amount_losses_to_measure_per_line:
                self._reset_to_best_checkpoint()
                step = minimum_location
                self._perform_update_step(params, step, loss_fn, new_direction=True)
                self.update_step_size = minimum_location / grad_norm
                self.update_steps_per_line.append(self.update_step_size)
                self.loss_improvements_per_line.append(loss_improvement)
                self._set_measure_line_state()
            # set further points to measure
            else:
                self.points_to_measure = self._get_uniformly_distributed_points_to_measure(
                    minimum_location - self.measure_interval_factor,
                    minimum_location + self.measure_interval_factor,
                    self.amount_losses_to_measure_per_line // self.amount_readjustments)

    def _set_measure_interval(self, fitted_function, min_position, x, y):
        """
        Determines the interval in which new losses are measured, uniformly distributed.

        The 0.75 quantile of the losses measured in [0, 2 * min_position] (or, if fewer than 50 points lie there,
        of the 50 points closest to min_position) is used as the target value. The distance from min_position to
        where the fitted function reaches that value sets ``measure_interval_factor``. Does nothing if min_position
        is None or nan, or if no such position is found.
        """
        assert len(x) == len(y)
        assert len(x) > 0
        assert fitted_function is not None
        # None must be checked first: np.isnan(None) raises a TypeError.
        if min_position is None or np.isnan(min_position):
            return

        x = np.asarray(x)
        y = np.asarray(y)
        interval_indexes = np.flatnonzero((x >= 0) & (x <= 2 * min_position))
        if len(interval_indexes) < 50:
            # the 50 measured points closest to the minimum (single sort instead of one per point)
            interval_indexes = np.argsort(np.abs(x - min_position))[:50]
        target_value = np.nanquantile(y[interval_indexes], 0.75)
        shifted_function_abs = lambda z: abs((abs(fitted_function(z)) - abs(target_value)))
        result = root(shifted_function_abs, [min_position], options={"disp": False},
                      method="df-sane")
        if result.success and len(result.x) > 0:
            root_pos = result.x[0]
            distance_to_root = np.abs(root_pos - min_position)
            distance_to_root = distance_to_root if distance_to_root < self.measure_interval_factor * 5 else self.measure_interval_factor * 5
            self.measure_interval_factor = np.max([1.5 * np.abs(min_position), distance_to_root])
            if self.is_logging:
                print("measure_interval_factor: {0:f}".format(self.measure_interval_factor))

    def _handle_nonsuitable_fit(self, x, y, loss_fn, params, fitted_function):
        """
        Decides whether to sample more losses, or to perform a new line search if the maximal number of losses to
        measure is reached.
        """
        self._plot(x, y, None, None, fitted_function, None, message="no suitable fit")
        if len(x) < self.amount_losses_to_measure_per_line:
            print("not able to find a suitable fit, sample more points")
            meanx = np.nanmean(x)
            interval = (np.max(x) - meanx) * 1.5
            self.points_to_measure = self._get_uniformly_distributed_points_to_measure(
                meanx - interval,
                meanx + interval,
                self.amount_losses_to_measure_per_line // self.amount_readjustments)
        else:
            print(
                "not able to find a suitable fit after reaching maximal number of points to measure, measure next line")
            step = 0 - self.current_position
            self._perform_update_step(params, step, loss_fn, new_direction=True)
            assert self.current_position == 0.0
            self._set_measure_line_state()

    def _set_new_update_step_originating_from_multiple_lines(self):
        """
        After measuring ``amount_lines`` consecutive line function approximations, the average update step and loss
        improvement are used in further training. Invalid results (nan or negative step, a larger step under
        backtracking, non-positive improvement) keep the old step size and reset to the last checkpoint.
        """
        mean_update_step_np = np.nanmean(self._to_numpy(self.update_steps_per_line))
        # keep the step size on the parameters' device (not the default CUDA device)
        mean_update_step = torch.Tensor([mean_update_step_np]).to(self._param_device())
        mean_loss_improvement = np.nanmean(self.loss_improvements_per_line)
        if np.isnan(mean_update_step_np) or mean_update_step_np < 0 or (
                self.backtracking and mean_update_step.item() > self.old_update_step_size) \
                or mean_loss_improvement <= 0:
            self._log_print("found step size or found loss improvement is not valid  step size:" + str(
                mean_update_step_np) + "loss_improvement:" + str(
                mean_loss_improvement) + " reuse old step size " + str(
                self.old_update_step_size))
            self.update_step_size = self.old_update_step_size
            self._reset_to_best_checkpoint()
        else:
            self.update_step_size = mean_update_step
            self.old_update_step_size = self.update_step_size
            self.loss_improvement = mean_loss_improvement
            # float() also handles plain float step sizes (zero direction norm), which have no .item()
            self._log_print("update steps for conescutive lines:", [float(x) for x in self.update_steps_per_line])
            self._log_print("--" * 30)
            self._log_print("new update step:", mean_update_step_np)
            self._log_print("new loss improvement:", self.loss_improvement)
            self._log_print("--" * 30)
        self.update_steps_per_line = []
        self.loss_improvements_per_line = []
        self.set_sgd_training_state()

    @staticmethod
    def _to_numpy(value):
        """
        Converts a tensor, or a list whose first element is a tensor, to numpy. Other values are returned unchanged.
        """
        if isinstance(value, torch.Tensor):
            return value.cpu().detach().numpy()
        if isinstance(value, list):
            if len(value) > 0 and isinstance(value[0], torch.Tensor):
                return [x.cpu().detach().numpy() for x in value]
        return value

    def _plot(self, x_train, y_train, x_eval, y_eval, fitted_function, estimated_min, degree=-1, message=""):
        """
        If logging is enabled, saves a plot of the measured losses, the fitted function and its estimated minimum
        to ``plot_save_dir``.

        :param x_train: measured line positions
        :param y_train: measured losses
        :param x_eval: optional further positions, plotted in green (None to skip)
        :param y_eval: losses belonging to x_eval
        :param fitted_function: fitted callable or None
        :param estimated_min: estimated minimum / target position or None
        :param degree: degree shown in the title
        :param message: suffix of the file name
        """
        import os

        matplotlib.use('Agg')
        if self.is_logging:
            plt.figure()
            x = list(self.measured_locations)
            if fitted_function is not None:
                approx_x = np.arange(min(x), max(x), 0.001)
                approx_y = fitted_function(approx_x)
                plt.plot(approx_x, approx_y, color="red", linewidth=3)

            plt.scatter(x_train, y_train, color="orange", marker="o", s=25)
            if x_eval is not None:
                plt.scatter(x_eval, y_eval, color="green", marker="o", s=25)

            if fitted_function is not None:
                approx_min_index = np.where(approx_y == np.amin(approx_y))[0]
                plt.scatter([approx_x[approx_min_index]],
                            fitted_function(approx_x[approx_min_index]),
                            color="lime", marker="o",
                            s=50, zorder=100)
                if estimated_min is not None:
                    plt.scatter([estimated_min], fitted_function(estimated_min), color="red",
                                marker="o",
                                s=50, zorder=100)
                plt.title(
                    "step {0}   degree: {1}".format(int(self.performed_training_steps), degree))
            plt.xlabel("step on line")
            plt.ylabel("loss")
            global_step = int(self.performed_training_steps)
            dir_ = self.plot_save_dir
            os.makedirs(dir_, exist_ok=True)
            # save into the configured directory (the path previously used the builtin ``dir``)
            plt.savefig(os.path.join(dir_, "measured_line_{0:d}_{1}.png".format(global_step, message)))
            plt.close()
