
"""Proximity operators."""

from typing import Any
from typing import Optional
from typing import Tuple

import torch


def prox_none(x: Any,
              hyperparams: Optional[Any] = None,
              scaling: float = 1.0) -> Any:
  r"""Proximal operator for :math:`g(x) = 0`, i.e., the identity function.

  Since :math:`g(x) = 0`, the output is:

  .. math::

    \underset{y}{\text{argmin}} ~ \frac{1}{2} ||x - y||_2^2 = x

  Args:
    x: input pytree.
    hyperparams: ignored.
    scaling: ignored.
  Returns:
    output pytree, with the same structure as ``x``.
  """
  del hyperparams, scaling
  return x


def prox_lasso(x: Any,
               l1reg: Optional[Any] = None,
               scaling: float = 1.0) -> Any:
  l1reg=torch.exp(l1reg)
  fun = lambda u, v: torch.sign(u) * torch.clip(torch.abs(u) - v * scaling, min=0)
  return fun(x, l1reg)

def jacobain_prox_lasso(x: Any,
               l1reg: Optional[Any] = None,
               scaling: float = 1.0) -> Any:
  '''
  The Jacobian matrix of the proximal operator w.r.t x
  :param x:
  :param l1reg:
  :param scaling:
  :return: 
  '''
  l1reg = torch.exp(l1reg)
  fun = lambda u, v: torch.sign(u) * torch.clip(torch.abs(u) - v * scaling, min=0)
  res = fun(x, l1reg)
  return  torch.diag(torch.abs(torch.sign(res)))

def jacobain_prox_lasso_lam(x: Any,
               l1reg: Optional[Any] = None,
               scaling: float = 1.0) -> Any:
  '''
  The Jacobian matrix of the proximal operator w.r.t x
  :param x:
  :param l1reg:
  :param scaling:
  :return:
  '''
  l1reg = torch.exp(l1reg)
  fun = lambda u, v: torch.sign(u) * torch.clip(torch.abs(u) - v * scaling, min=0)
  res = fun(x, l1reg)
  return  -l1reg*torch.sign(res)

def prox_non_negative_lasso(x: Any,
                            l1reg: Optional[float] = None,
                            scaling: float = 1.0) -> Any:
  r"""Proximal operator for the l1 norm on the non-negative orthant.

  .. math::

    \underset{y \ge 0}{\text{argmin}} ~ \frac{1}{2} ||x - y||_2^2
    + \text{scaling} \cdot \text{l1reg} \cdot ||y||_1

  Args:
    x: input pytree.
    l1reg: regularization strength.
    scaling: a scaling factor.

  Returns:
    output pytree, with the same structure as ``x``.
  """
  if l1reg is None:
    l1reg = 1.0

  pytree = tree_util.tree_add(x, -l1reg * scaling)
  return tree_util.tree_map(jax.nn.relu, pytree)


def prox_elastic_net(x: Any,
                     hyperparams: Optional[Tuple[Any, Any]] = None,
                     scaling: float = 1.0) -> Any:
  r"""Proximal operator for the elastic net.

  .. math::

    \underset{y}{\text{argmin}} ~ \frac{1}{2} ||x - y||_2^2
    + \text{scaling} \cdot \text{hyperparams[0]} \cdot g(y)

  where :math:`g(y) = ||y||_1 + \text{hyperparams[1]} \cdot 0.5 \cdot ||y||_2^2`.

  Args:
    x: input pytree.
    hyperparams: a tuple, where both ``hyperparams[0]`` and ``hyperparams[1]``
      can be either floats or pytrees with the same structure as ``x``.
    scaling: a scaling factor.

  Returns:
    output pytree, with the same structure as ``x``.
  """
  if hyperparams is None:
    hyperparams = (1.0, 1.0)

  prox_l1 = lambda u, lam: jnp.sign(u) * jax.nn.relu(jnp.abs(u) - lam)
  fun = lambda u, lam, gamma: (prox_l1(u, scaling * lam) /
                               (1.0 + scaling * lam * gamma))
  return tree_util.tree_map(fun, x, hyperparams[0], hyperparams[1])


def prox_group_lasso(x: Any,
                     l2reg: Optional[float] = 1.0,
                     scaling=1.0) -> Any:
  r"""Proximal operator for the l2 norm, i.e., block soft-thresholding operator.

  .. math::

    \underset{y}{\text{argmin}} ~ \frac{1}{2} ||x - y||_2^2
    + \text{scaling} \cdot \text{l2reg} \cdot ||y||_2

  Blocks can be grouped using ``jax.vmap``.

  Args:
    x: input pytree.
    l2reg: regularization strength.
    scaling: a scaling factor.

  Returns:
    output pytree, with the same structure as ``x``.
  """
  if l2reg is None:
    l2reg = 1.0

  l2_norm = tree_util.tree_l2_norm(x)
  factor = 1 - l2reg * scaling / l2_norm
  factor = jnp.where(factor >= 0, factor, 0)
  return tree_util.tree_scalar_mul(factor, x)


def prox_ridge(x: Any,
               l2reg: Optional[float] = 1.0,
               scaling=1.0) -> Any:
  r"""Proximal operator for the squared l2 norm.

  .. math::

    \underset{y}{\text{argmin}} ~ \frac{1}{2} ||x - y||_2^2
    + \text{scaling} \cdot \text{l2reg} \cdot ||y||_2^2

  Args:
    x: input pytree.
    l2reg: regularization strength.
    scaling: a scaling factor.

  Returns:
    output pytree, with the same structure as ``x``.
  """
  if l2reg is None:
    l2reg = 1.0

  factor = 1. / (1 + scaling * l2reg)
  return tree_util.tree_scalar_mul(factor, x)


def prox_non_negative_ridge(x: Any,
                            l2reg: Optional[float] = 1.0,
                            scaling: float = 1.0):
  r"""Proximal operator for the squared l2 norm on the non-negative orthant.

  .. math::

    \underset{y \ge 0}{\text{argmin}} ~ \frac{1}{2} ||x - y||_2^2
    + \text{scaling} \cdot \text{l2reg} \cdot ||y||_2^2

  Args:
    x: input pytree.
    l2reg: regularization strength.
    scaling: a scaling factor.

  Returns:
    output pytree, with the same structure as ``x``.
  """
  if l2reg is None:
    l2reg = 1.0

  pytree = tree_util.tree_scalar_mul(1./ (1 + l2reg * scaling), x)
  return tree_util.tree_map(jax.nn.relu, pytree)


def make_prox_from_projection(projection):
  """Transforms a projection into a proximal operator."""
  def prox(x, hyperparams=None, scaling=1.0):
    del scaling  # The scaling parameter is meaningless for projections.
    return projection(x, hyperparams)
  return prox
