import numpy as np
from methods.base_method import BaseMethod


class NIDS(BaseMethod):
    """
    NIDS-type decentralized method with optional per-node backtracking of the step sizes.

    Row i of X holds the point of node i. Every iteration
      * builds a trial point d = X - gamma * (gradF(X) + beta * Y) and updates
        Y <- Y + c * (d - consensus(d)) (optionally divided by the per-node step size);
      * re-estimates the per-node step sizes by backtracking (unless const_gamma is True);
      * moves X <- X - gamma * (gradF(X) + beta * Y) using the new Y and gamma.
    """

    def __init__(self, gamma0, N=1000, type_gamma=1, alpha=1 / 2, c=1,
                 beta=1, eta=1, min_gamma=False,
                 history_with_gamma=False, normalized_by_gamma=True, const_gamma=False,
                 *args, **kwargs):
        """
        :param gamma0: float or np.array[n], starting step size (a scalar is broadcast to all nodes)
        :param N: int, maximal number of iterations
        :param type_gamma: int, type of step size. Possible values: 1, 2, 3 (see article).
            NOTE(review): stored, but not read anywhere in this class.
        :param alpha: float, multiplicative factor of backtracking (the step is divided by alpha
            before the search and multiplied by alpha on every failed trial)
        :param c: float, parameter in d^nu updating (scale of the Y update)
        :param beta: float, weight of the linear term in tilde{F}(X) = F(X) + beta * <Y, X - X_k>
        :param eta: float, coefficient of the sufficient-decrease term in the backtracking condition
            tilde{F}(X - gamma d) <= tilde{F}(X) - eta * gamma * ||d||^2.
            NOTE(review): assuming d is the gradient of tilde{F}, this holds for small gamma only when
            eta < 1; with eta >= 1 the search stops only through the EPS tolerance.
        :param min_gamma: bool, if True, we choose minimal step size on all nodes
        :param history_with_gamma: bool, if True, method saves gamma values into history
        :param normalized_by_gamma: bool, if True, the Y update of node i is divided by gamma_i
            (from the second iteration on)
        :param const_gamma: bool, if True, backtracking is disabled and gamma0 is used throughout
        """
        super().__init__(*args, **kwargs)
        self.gamma0 = gamma0
        self.N = N
        self.type_gamma = type_gamma
        self.alpha = alpha
        self.beta = beta
        self.eta = eta
        self.c = c
        self.history_with_gamma = history_with_gamma
        self.min_gamma = min_gamma
        self.gamma_list = []  # step sizes accepted by every backtracking call (before clipping)
        self.normalized_by_gamma = normalized_by_gamma
        self.const_gamma = const_gamma
        if const_gamma:
            self.gamma = gamma0

    @staticmethod
    def _change_gamma(_cond_value, gamma, alpha):
        """
        One backtracking step: multiply gamma by alpha on every node whose condition is violated.

        A node counts as satisfied only when its condition value is <= 0. NaN values therefore count
        as violations and get a smaller step. Otherwise they would keep their step size while never
        passing the loop test, and the backtracking loop would never terminate.
        :param _cond_value: np.array[n], values for backtracking in each node
        :param gamma: np.array[n], current step sizes (modified in place when it is an array)
        :param alpha: float, parameter for updating
        :return: updated gamma
        """
        violated = ~(np.asarray(_cond_value) <= 0)
        gamma *= np.where(violated, alpha, 1)
        return gamma

    @staticmethod
    def _get_step(X, gamma, D):
        """Return X - gamma_i * D_i for every node i (gamma scales the leading axis of D)."""
        Dgamma = gamma * D.T
        return X - Dgamma.T

    @staticmethod
    def _scale_rows(weights, A):
        """Multiply the leading-axis slice A[i] by weights[i], using the same broadcasting as _get_step."""
        return (weights * A.T).T

    def _initial_gamma(self, n):
        """Return a float copy of gamma0; a scalar gamma0 is broadcast to shape (n,)."""
        gamma = np.array(self.gamma0, dtype=float)
        if gamma.ndim == 0:
            gamma = np.full(n, float(gamma))
        return gamma

    def _backtracking_condition(self, F_tilde, F_tilde_x, gamma, d_tilde):
        """
        Per-node sufficient-decrease residual of the trial step X - gamma * d_tilde.

        The step is acceptable on a node when the value is <= 0, i.e. when
        tilde{F}(X - gamma d) <= tilde{F}(X) - eta * gamma * ||d||^2.
        Both the first trial and every later trial use this one formula.
        """
        X_new = self._get_step(self.X, gamma, d_tilde)
        return F_tilde(X_new) - F_tilde_x + gamma * self.eta * (d_tilde * d_tilde).sum(axis=-1)

    def update_gamma(self, F, gradF):
        """
        Function for updating Gamma through backtracking in each node.

        The search starts from the previous step size divided by alpha. It multiplies the step of
        every violating node by alpha until all residuals are <= self.EPS. The result is stored in
        gamma_list, optionally replaced by its minimum, and clipped from below by self.EPS_GAMMA.
        Does nothing when const_gamma is True.
        :param F: callable, object function; presumably returns one value per node
        :param gradF: callable, function for gradient calculating (not used here; kept for interface)
        :raises ValueError: if F is None
        """
        if self.const_gamma:
            return
        if F is None:
            raise ValueError("NIDS: F must be provided when const_gamma is False (needed for backtracking)")

        def F_tilde(X):
            # F plus the linear term beta * <Y, X - X_k>, summed over the last axis (per node for 2-D X).
            return F(X) + self.beta * (self.Y * (X - self.X)).sum(axis=-1)

        d_tilde = self._new_grad + self.beta * self.Y
        F_tilde_x = F_tilde(self.X)

        gamma = np.copy(self.gamma) / self.alpha
        cond_value = self._backtracking_condition(F_tilde, F_tilde_x, gamma, d_tilde)
        while not (cond_value <= self.EPS).all():
            gamma = self._change_gamma(cond_value, gamma, self.alpha)
            cond_value = self._backtracking_condition(F_tilde, F_tilde_x, gamma, d_tilde)
        # The loop tolerates violations up to EPS; if any node is still (slightly) violating,
        # shrink all steps once more.
        if not (cond_value <= 0).all():
            gamma = gamma * self.alpha
        self.gamma_list.append(np.copy(gamma))
        self.gamma = np.copy(gamma)
        if self.min_gamma:
            self.gamma = self.gamma.min() * np.ones(self.gamma.shape)
        self.gamma = np.maximum(self.gamma, self.EPS_GAMMA)

    def __call__(self, X0, gradF, consensus, F=None, grad_sum=None):
        """
        Run at most N iterations starting from X0 and return the final point.

        Iteration stops early when ||X_{k+1} - X_k|| <= self.EPS.
        :param X0: np.array[n, ...], starting points in each node
        :param gradF: callable, function for gradient calculating of function F(X)=sum_i F_i(x_i)
        :param consensus: callable, function that implements consensus procedure
        :param F: callable, object function; required by update_gamma unless const_gamma is True
        :param grad_sum: not used by this method; kept for interface compatibility
        :return: np.array[n, ...], obtained point with the same dimension as X0
        """
        # Y starts at zero, so no gradient evaluation is needed to initialise it.
        self.Y = np.zeros(X0.shape)
        self.X = X0.copy()
        self._old_grad = self.Y.copy()
        self.gamma = self._initial_gamma(X0.shape[0])
        for i in range(self.N):
            self._new_grad = gradF(self.X)
            self._old_Y = self.Y
            h = self._new_grad + self.beta * self.Y
            d = self._get_step(self.X, self.gamma, h)
            y_increment = self.c * (d - consensus(d))
            if i >= 1 and self.normalized_by_gamma:
                # Divide node i's increment by its own step size (the first iteration is not normalised).
                y_increment = self._scale_rows(1 / self.gamma, y_increment)
            self.Y = self.Y + y_increment
            self.update_gamma(F, gradF)
            self._old_grad = self._new_grad
            self._old_X = self.X
            if self.return_history:
                if self.history_with_gamma:
                    elem = (self.X.copy(), self.gamma.copy())
                else:
                    elem = self.X.copy()
                self.history.append(elem)
            h = self._new_grad + self.beta * self.Y
            self.X = self._get_step(self.X, self.gamma, h)
            if np.linalg.norm(self.X - self._old_X) <= self.EPS:
                break
        if self.return_history:
            if self.history_with_gamma:
                elem = (self.X.copy(), None)
            else:
                elem = self.X.copy()
            self.history.append(elem)
        return self.X
