from functools import partial
from typing import Any
from typing import Callable
from typing import Tuple
import torch

def _projection_unit_simplex(x1):
  """Projection onto the unit simplex."""
  x = torch.reshape(x1,[-1])
  s = 1.0
  n_features = x.shape[0]
  u = torch.sort(x,0,descending=True)[0]
  cumsum_u = torch.cumsum(u, dim=0)
  ind = (torch.arange(n_features)+1).to(x.device)
  ind = torch.reshape(ind, cumsum_u.shape)
  cond = s / ind + (u - cumsum_u / ind) > 0
  idx = torch.count_nonzero(cond)
  tmp = torch.clip(s / idx + (x - cumsum_u[idx - 1] / idx), min=0)
  return torch.reshape(tmp, x1.shape)


def projection_simplex(x, value: float = 1.0):
  r"""Projection onto a simplex:

  .. math::

    \underset{p}{\text{argmin}} ~ ||x - p||_2^2 \quad \textrm{subject to} \quad
    p \ge 0, p^\top 1 = \text{value}

  By default, the projection is onto the probability simplex.

  Args:
    x: tensor to project. It is flattened and projected jointly, so any shape
      is accepted.
    value: value p should sum to (default: 1.0). ``None`` is treated as 1.0.
      When ``value`` is 0, the only feasible point is the zero vector, and it
      is returned directly.
  Returns:
    projected vector, an array with the same shape as ``x``.
  """
  if value is None:
    value = 1.0
  if value == 0:
    # The rescaling by 1 / value below would divide by zero and yield NaNs.
    # The degenerate simplex {0} has the zero vector as its projection. Use
    # the same dtype that ``x / value`` would have produced.
    return torch.zeros_like(x, dtype=torch.result_type(x, 1.0))
  # Rescale to the unit simplex, project, then scale back.
  return value * _projection_unit_simplex(x / value)

def projection_l1_sphere(x, value: float = 1.0):
  r"""Projection onto the l1 sphere:

  .. math::

    \underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad
    ||y||_1 = \text{value}

  The magnitudes ``|x|`` are projected onto the simplex of radius ``value``,
  and the signs of ``x`` are restored afterwards. Zero entries of ``x`` are
  treated as positive.

  Args:
    x: array to project.
    value: radius of the sphere.

  Returns:
    output array, with the same shape as ``x``.
  """
  # torch.sign(0) == 0 would be wrong here. When ||x||_1 < value, the simplex
  # projection can give zero entries positive mass. Multiplying that mass by
  # 0 would drop it, and the result would no longer lie on the sphere.
  # Mapping zeros to +1 keeps it; for nonzero entries nothing changes.
  ones = torch.ones_like(x)
  signs = torch.where(x < 0, -ones, ones)
  return signs * projection_simplex(torch.abs(x), value)


def projection_l1_ball(x, max_value: float = 1.0):
  r"""Projection onto the l1 ball:

  .. math::

    \underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad
    ||y||_1 \le \text{max\_value}

  Inputs already inside the ball are returned unchanged (the very same tensor
  object). Otherwise, the input is projected onto the ball's boundary, the
  l1 sphere of radius ``max_value``.

  Args:
    x: array to project; its l1 norm is taken over all entries.
    max_value: radius of the ball.

  Returns:
    output array, with the same structure as ``x``.
  """
  inside = torch.norm(x, 1) <= max_value
  if inside:
    return x
  return projection_l1_sphere(x, max_value)


def projection_l2_ball(x, max_value: float = 1.0):
  r"""Projection onto the l2 ball:

  .. math::

    \underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad
    ||y||_2 \le \text{max\_value}

  Inputs already inside the ball are returned unchanged (the very same tensor
  object). Otherwise, the input is rescaled radially onto the sphere of
  radius ``max_value``.

  Args:
    x: array to project; its l2 norm is taken over all entries.
    max_value: radius of the ball.

  Returns:
    output array, with the same structure as ``x``.
  """
  l2_norm = torch.norm(x)
  if l2_norm <= max_value:
    return x
  # Outside the ball, the nearest feasible point lies on the ray through x.
  return max_value * x / l2_norm


def jvp_projection_l2_ball(x, xdot, max_value):
  """Jacobian-vector product of the l2-ball projection at ``x``.

  Inside the ball the projection is the identity, so the tangent passes
  through unchanged. Outside, the projection is ``y = r * x / ||x||``, whose
  directional derivative is

      dy = r * (xdot / ||x|| - <x, xdot> * x / ||x||**3),

  which is the component of ``xdot`` orthogonal to ``x``, scaled by
  ``r / ||x||``.

  Args:
    x: point at which the projection is differentiated.
    xdot: tangent vector, expected to have the same shape as ``x``.
    max_value: radius ``r`` of the ball.

  Returns:
    the tangent of the projection's output.
  """
  l2_norm = torch.norm(x)
  if l2_norm <= max_value:
    return torch.ones_like(x) * xdot
  # Remove the radial component of xdot, then scale by r / ||x||.
  radial = torch.sum(x * xdot) * x / (l2_norm ** 3)
  return max_value * (xdot / l2_norm - radial)


def jvp_projection_l1_ball(x, xdot, max_value):
  """Jacobian-vector product of the l1-ball projection at ``x``.

  Inside the ball the projection is the identity, so ``xdot`` passes through.
  Outside, the projection is ``y = sign(x) * max(|x| - theta, 0)``, with
  ``theta`` chosen so that ``||y||_1 = max_value``. Let ``S`` be the support
  of ``y`` and ``card = |S|``. Differentiating on ``S`` gives

      dy_i = xdot_i - sign(x_i) * (sum_{j in S} sign(x_j) * xdot_j) / card

  for ``i`` in ``S``, and ``dy_i = 0`` elsewhere.

  Args:
    x: point at which the projection is differentiated.
    xdot: tangent vector, expected to have the same shape as ``x``.
    max_value: radius of the ball.

  Returns:
    the tangent of the projection's output.
  """
  l1_norm = torch.norm(x, 1)
  if l1_norm <= max_value:
    return torch.ones_like(x) * xdot
  # Magnitudes of the projected entries (zero wherever x is zero).
  magnitudes = projection_l1_sphere(x, max_value) * torch.sign(x)
  supp = magnitudes > 0
  card = torch.count_nonzero(supp)
  if card == 0:
    # Empty support (e.g. a degenerate radius): the projection is locally
    # constant, so its derivative is zero rather than 0 / 0.
    return torch.zeros_like(xdot)
  # The threshold shift enters negative coordinates with the opposite sign,
  # so the correction term must use the signed support indicator.
  signed_supp = torch.sign(x) * supp
  return supp * xdot - signed_supp * torch.sum(signed_supp * xdot) / card

def jvp_projection_unit_simplex(x, xdot, max_value):
  """Jacobian-vector product of the l1-sphere projection of radius ``max_value``.

  NOTE(review): despite the name, this differentiates
  ``projection_l1_sphere(x, max_value)`` without checking whether ``x`` lies
  inside the ball. It assumes ``||x||_1 >= max_value``. For non-negative
  ``x`` in that regime, it matches the JVP of the simplex projection of
  radius ``max_value``.

  With ``S`` the support of the projection and ``card = |S|``:

      dy_i = xdot_i - sign(x_i) * (sum_{j in S} sign(x_j) * xdot_j) / card

  for ``i`` in ``S``, and ``dy_i = 0`` elsewhere.

  Args:
    x: point at which the projection is differentiated.
    xdot: tangent vector, expected to have the same shape as ``x``.
    max_value: radius of the sphere.

  Returns:
    the tangent of the projection's output.
  """
  # Magnitudes of the projected entries (zero wherever x is zero).
  magnitudes = projection_l1_sphere(x, max_value) * torch.sign(x)
  supp = magnitudes > 0
  card = torch.count_nonzero(supp)
  if card == 0:
    # Empty support: avoid 0 / 0 and report a zero derivative.
    return torch.zeros_like(xdot)
  # Signed indicator, so that negative coordinates get the correct sign in
  # the threshold-correction term (a no-op for non-negative x).
  signed_supp = torch.sign(x) * supp
  return supp * xdot - signed_supp * torch.sum(signed_supp * xdot) / card

# def jvp_projection_l1_ball(model, x, xdot, max_value):
#   start = 0
#   jvp_list = []
#   for param in model.parameters():
#     if len(param.shape)==1:
#       offset = len(torch.reshape(param, [-1]))
#       x_sub = x[start:start + offset]
#       xdot_sub = xdot[start:start + offset]
#       x_sub_l1norm = torch.linalg.norm(x_sub, ord=1)
#       if x_sub_l1norm <= max_value:
#         jvp = torch.ones_like(x_sub, device=x.device) * x_sub
#       else:
#         jvp = jvp_projection_unit_simplex(x_sub, xdot_sub, max_value)
#       jvp_list.append(jvp)
#       start = start + offset
#     else:
#       for p in param:
#         offset = len(torch.reshape(p, [-1]))
#         x_sub = x[start:start+offset]
#         xdot_sub = xdot[start:start+offset]
#         x_sub_l1norm = torch.linalg.norm(x_sub, ord=1)
#         if x_sub_l1norm<= max_value:
#           jvp = torch.ones_like(x_sub, device=x.device) * x_sub
#         else:
#           jvp = jvp_projection_unit_simplex(x_sub, xdot_sub, max_value)
#         jvp_list.append(jvp)
#         start = start + offset
#   jvp_final = torch.cat(jvp_list)
#   return jvp_final


