import warnings
import numpy as np
from scipy import optimize, linalg, sparse

from copt import utils
# from copt import splitting


def minimize_tos(
    f_grad,
    x0,
    prox_1=None,
    prox_2=None,
    tol=1e-6,
    max_iter=1000,
    verbose=0,
    callback=None,
    adaptive=True,
    gamma_0=1,
    args_prox=(),
):
    """Adaptive three operator splitting method.

    This algorithm can solve problems of the form

                minimize_x f(x) + g(x) + h(x)

    where the proximal operator is known for g and h.

    Args:
      f_grad: callable
        Called as ``f_grad(z)``; must return a pair
        ``(function value, gradient)`` of f at z.

      x0 : array-like
        Initial guess. It is converted with ``np.asarray``.

      prox_1 : callable or None, optional
        prox_1(x, alpha, *args) returns the proximal operator of g at x
        with parameter alpha. If None, the identity is used.

      prox_2 : callable or None, optional
        prox_2(x, alpha, *args) returns the proximal operator of h at x
        with parameter alpha. If None, the identity is used.

      tol: float, optional
        Tolerance of the stopping criterion. The iteration stops once
        ``norm(x - z) < tol`` (checked from the second iteration on).

      max_iter : int, optional
        Maximum number of iterations.

      verbose : int, optional
        If truthy, a progress line is printed every 10 iterations.

      callback : callable, optional
        Callback function. It is called once per iteration with a single
        argument: the dictionary of the local variables of this function
        (``locals()``), which includes ``x``, ``y``, ``z``, ``it``,
        ``step_size`` and ``certificate``. The algorithm will exit if
        callback returns False.

      adaptive : boolean, optional
        Whether to use the adaptive step-size
        ``gamma_0 / sqrt(sum of squared gradient norms so far)``.
        Otherwise ``gamma_0 / sqrt(it + 1)`` is used.

      gamma_0 : float, optional
        Step-size parameter. Also the step-size passed to ``prox_2`` on
        the first iteration.

      args_prox : tuple, optional
        Optional extra arguments passed to the prox functions.


    Returns:
      res : OptimizeResult
        The optimization result represented as a
        ``scipy.optimize.OptimizeResult`` object. Its attributes are:
        ``x`` the solution array, ``success`` a Boolean flag indicating
        whether the stopping criterion was met, ``nit`` the index of the
        last iteration performed (0 if no iteration was run),
        ``certificate`` the last value of ``norm(x - z)`` (``inf`` if no
        iteration was run) and ``step_size`` the last step-size used.


    References:
      [1] Davis, Damek, and Wotao Yin. `"A three-operator splitting scheme and
      its optimization applications."
      <https://doi.org/10.1007/s11228-017-0421-z>`_ Set-Valued and Variational
      Analysis, 2017.

      [2] Pedregosa, Fabian, and Gauthier Gidel. `"Adaptive Three Operator
      Splitting." <https://arxiv.org/abs/1804.02339>`_ Proceedings of the 35th
      International Conference on Machine Learning, 2018.
    """
    # NOTE: local variable names are part of the callback contract because
    # the callback receives ``locals()``; do not rename them.
    success = False

    if prox_1 is None:

        def prox_1(x, s, *args):
            return x

    if prox_2 is None:

        def prox_2(x, s, *args):
            return x

    # Accept any array-like (e.g. a list): plain sequences would otherwise
    # break the arithmetic below (``2 * z`` on a list is repetition).
    y = np.asarray(x0)
    step_size = gamma_0
    u_sum = 0.0

    # Defaults so the result is well defined even if the loop never runs
    # (max_iter <= 0).
    x = y
    it = 0
    certificate = np.inf

    for it in range(max_iter):

        z = prox_2(y, step_size, *args_prox)
        fk, grad_fk = f_grad(z)

        if adaptive:
            u_sum += np.sum(grad_fk**2)
            # While every gradient seen so far is exactly zero, u_sum is 0
            # and the adaptive rule would divide by zero (inf step, then
            # inf * 0 = nan in the update). Keep the previous step-size.
            if u_sum > 0:
                step_size = gamma_0 / np.sqrt(u_sum)
        else:
            step_size = gamma_0 / np.sqrt(it+1)

        x = prox_1(2*z - y - step_size * grad_fk, step_size, *args_prox)
        y = y + x - z

        # x and z coincide at a fixed point of the iteration.
        certificate = np.linalg.norm(x - z)

        if callback is not None:
            if callback(locals()) is False:
                break

        if it > 0 and certificate < tol:
            success = True
            break

        if verbose and it % 10 == 0:
            print(f"Iteration {it}: f(z)={fk}: ny={np.sqrt(np.linalg.norm(y))}")

    return optimize.OptimizeResult(
        x=x, success=success, nit=it, certificate=certificate, step_size=step_size
    )

