import jax
import jax.numpy as jnp


def chart_north(s):
  """From sphere to local coords; valid everywhere except s = (0, ..., 0, -1).

  Stereographic projection from the south pole: `s[:-1] / (1 + s[-1])`. The
  north pole `(0, ..., 0, 1)` maps to the origin, while the denominator
  vanishes at the south pole `(0, ..., 0, -1)`. Left inverse of `embed_north`
  on the unit sphere.
  """
  return s[:-1] / (1 + s[-1])


def embed_north(x):
  """From local coords to sphere; image is all but s = (0, ..., 0, -1).

  Inverse stereographic projection matching `chart_north`:
  `x -> (2 x, 1 - |x|^2) / (1 + |x|^2)`. The origin maps to the north pole
  `(0, ..., 0, 1)`. The south pole is only approached as `|x| -> inf`.
  """
  sq_norm = (x**2).sum()
  return jnp.hstack([x, (1 - sq_norm) / 2]) * 2 / (1 + sq_norm)


def chart_south(s):
  """From sphere to local coords; valid everywhere except s = (0, ..., 0, 1).

  Stereographic projection from the north pole, `s[:-1] / (1 - s[-1])`,
  followed by negating the first local coordinate. The reflection makes the
  transition map with `chart_north` orientation-preserving. The south pole
  `(0, ..., 0, -1)` maps to the origin, while the denominator vanishes at the
  north pole `(0, ..., 0, 1)`. Left inverse of `embed_south` on the unit sphere.
  """
  ret = s[:-1] / (1 - s[-1])
  return ret.at[0].multiply(-1)


def embed_south(x):
  """From local coords to sphere; image is all but s = (0, ..., 0, 1).

  Inverse of `chart_south`: `x -> (2 x, |x|^2 - 1) / (1 + |x|^2)`, with the
  first output coordinate negated. The origin maps to the south pole
  `(0, ..., 0, -1)`. The north pole is only approached as `|x| -> inf`.
  """
  sq_norm = (x**2).sum()
  ret = jnp.hstack([x, (sq_norm - 1) / 2]) * 2 / (1 + sq_norm)
  return ret.at[0].multiply(-1)


def coord_jac(s, parametrisation=None):
  """Computes Jacobian of the map from local coordinates to the unit sphere.

  The embedding (`embed_north` or `embed_south`) is differentiated at the
  local coordinates of `s` under the matching chart.

  :param s:
    `d`-dimensional array containing a *single* producer embedding.
  :param parametrisation:
    String describing which parametrisation should be used (case-insensitive):
    - `'north'` for the chart valid for all vectors but the south pole
      `(0, ..., 0, -1)`;
    - `'south'` for the chart valid for all vectors but the north pole
      `(0, ..., 0, 1)`;
    - `None` picks the parametrisation automatically from the last coordinate
      ('north' for embeddings with positive last component, 'south'
      otherwise), staying away from the chart's singularity for numerical
      stability.
  :return:
    `d x (d - 1)` array; column `k` is the derivative of the embedding with
    respect to the `k`-th local coordinate.
  :raises ValueError:
    If `parametrisation` is not `None`, `'north'` or `'south'`.
  """
  def _jac(s, chart, embed):
    return jax.jacfwd(embed)(chart(s))

  jac_north = lambda s: _jac(s, chart=chart_north, embed=embed_north)
  jac_south = lambda s: _jac(s, chart=chart_south, embed=embed_south)

  if parametrisation is None:
    # `jnp.where` keeps the choice traceable under jit/vmap. Both Jacobians
    # are evaluated (the discarded one may be non-finite at a pole) and the
    # one from the chart farther from its singularity is kept.
    return jnp.where(s[-1] > 0, jac_north(s), jac_south(s))

  choice = parametrisation.lower()
  if choice == 'north':
    return jac_north(s)
  if choice == 'south':
    return jac_south(s)
  raise ValueError(f'unknown sphere parametrisation `{parametrisation}`')


def rieman_jac(s, utility):
  """Riemannian gradient of every producer's utility wrt its own embedding.

  :param s:
    `n x d` array of producer embeddings. Assumed unit norm.
  :param utility:
    Function which takes `n x d` array of producer embeddings and returns
    an `n`-dimensional array with the corresponding producer utilities.
  :return:
    `n x d` array. Row `i` is `(I - s_i s_i^T) g_i`, where `g_i` is the
    Euclidean gradient of `utility(s)[i]` wrt row `i` of `s` (all other rows
    held fixed). This is `g_i` orthogonally projected onto the tangent space
    of the sphere at `s_i`.
  """
  eye = jnp.eye(s.shape[-1])

  def _riemannian_grad(idx):
    s_i = s[idx]

    def own_utility(v):
      # utility of producer `idx` when only its own strategy is replaced
      return utility(s.at[idx].set(v))[idx]

    euclid = jax.grad(own_utility)(s_i)
    tangent_proj = eye - jnp.outer(s_i, s_i)
    return tangent_proj @ euclid

  producer_ids = jnp.arange(len(s))
  return jax.vmap(_riemannian_grad)(producer_ids)


def rieman_jac_and_hess(s, utility):
  """Riemannian gradients and Hessians of every producer's own utility.

  :param s:
    `n x d` array of producer embeddings. Assumed unit norm.
  :param utility:
    Function which takes `n x d` array of producer embeddings and returns
    an `n`-dimensional array with the corresponding producer utilities.
  :return:
    Pair `(gradients, hessians)` of shapes `n x d` and `n x d x d`. With
    `P_i = I - s_i s_i^T` the (symmetric) projection onto the tangent space at
    `s_i`, `g_i` the Euclidean gradient and `H_i` the Euclidean Hessian of
    `utility(s)[i]` wrt row `i`:
      gradient_i = P_i g_i
      hessian_i  = P_i H_i P_i - <g_i, s_i> P_i
  """
  basis = jnp.eye(s.shape[-1])

  def _per_producer(idx):
    s_i = s[idx]

    def own_utility(v):
      # utility of producer `idx` when only its own strategy is replaced
      return utility(s.at[idx].set(v))[idx]

    own_grad = jax.grad(own_utility)

    # Forward-over-reverse: pushing each basis vector through the gradient
    # yields one Hessian-vector product per row. The primal output (the
    # gradient itself) is identical in every row, so keep the first.
    grads, hess = jax.vmap(lambda t: jax.jvp(own_grad, (s_i,), (t,)))(basis)
    g_i = grads[0]

    proj = basis - jnp.outer(s_i, s_i)
    riem_grad = proj @ g_i
    riem_hess = proj @ hess @ proj - (g_i @ s_i) * proj
    return riem_grad, riem_hess

  return jax.vmap(_per_producer)(jnp.arange(len(s)))


def rieman_jac_and_hess_exposure(s, consumers, tau, logits_and_probs):
  """Closed-form Riemannian gradients and Hessians for exposure utilities.

  :param s:
    `n x d` array of producer embeddings. Assumed unit norm.
  :param consumers:
    `m x d` array of consumer embeddings.
  :param tau:
    Temperature; the gradient is scaled by `1 / tau` and the Hessian by
    `1 / tau**2`.
  :param logits_and_probs:
    Function of `s` returning a pair `(logits, probs)`. Only `probs` is used,
    and it must have shape `m x n` (consumers x producers).
  :return:
    Pair `(jacobians, hessians)` of shapes `n x d` and `n x d x d`. With
    `P_n = I - s_n s_n^T` and Euclidean derivatives
      g_n = sum_m p_mn (1 - p_mn) c_m / (m * tau)
      H_n = sum_m p_mn (1 - p_mn) (1 - 2 p_mn) c_m c_m^T / (m * tau**2)
    the results are `g_n - s_n <s_n, g_n>` and `P_n H_n P_n - <s_n, g_n> P_n`.

  NOTE(review): these formulas are the derivatives of the mean exposure
  `probs[:, n].mean()` when `probs` is a softmax over producers of the logits
  `consumers @ s.T / tau`. Confirm that `logits_and_probs` matches this.
  """
  _, probs = logits_and_probs(s)
  n_consumers = len(consumers)

  # per-(consumer, producer) weights of the first and second derivatives
  w_grad = probs * (1 - probs)
  w_hess = w_grad * (1 - 2 * probs)

  # Euclidean gradient and Hessian of each producer's utility
  grad = jnp.einsum('mn,md->nd', w_grad, consumers) / (n_consumers * tau)
  hess = jnp.einsum('mn,mi,mj->nij', w_hess, consumers, consumers)
  hess = hess / (n_consumers * (tau**2))

  # projections onto the tangent spaces and the curvature (radial) term
  proj = jnp.eye(s.shape[-1])[None] - jnp.einsum('ni,nj->nij', s, s)
  curvature = jnp.einsum('ni,ni,nkl->nkl', s, grad, proj)

  jacobians = grad - jnp.einsum('ni,nj,nj->ni', s, s, grad)
  hessians = jnp.einsum('nij,njk,nkl->nil', proj, hess, proj) - curvature
  return jacobians, hessians


def _first_order_violation(rg, tol):
  """Flag a Riemannian gradient that is not (close to) zero.

  The size of `rg` is measured as its `l2` norm divided by `sqrt(rg.size)`,
  i.e. the root-mean-square of its entries.

  :return:
    Tuple `(violation, excess)`: `excess` is that normalised norm and
    `violation` is `excess >= tol`.
  """
  scale = rg.size ** 0.5
  excess = jnp.linalg.norm(rg, ord=2) / scale
  return excess >= tol, excess


def _second_order_violation(rH, tol):
  """Check if the Riemannian Hessian corresponds to a strict local maximum.

  `rH` (ambient `d x d` form of the Riemannian Hessian at `s_i`) always has
  one eigenvalue that is zero up to rounding, belonging to the radial
  direction `s_i`, which is not a feasible direction on the sphere. That
  eigenvalue is identified as the one of smallest magnitude and excluded.
  Among the remaining (tangent) eigenvalues, those with magnitude below `tol`
  are treated as exactly zero, and every non-negative one is a violation.

  :return:
    Tuple `(violations, excess)`. `violations` is the number of non-negative
    tangent eigenvalues (never negative). `excess` is the sum of the positive
    tangent eigenvalues.
  """
  eigvals = jnp.linalg.eigvalsh(rH)

  # Drop the radial eigenvalue explicitly instead of subtracting 1 from the
  # count. Rounding can push it below `-tol`, which previously gave -1
  # violations for a valid maximum and hid a genuine positive eigenvalue.
  radial = jnp.argmin(jnp.abs(eigvals))
  tangent = jnp.arange(eigvals.shape[-1]) != radial

  eigvals = jnp.where(jnp.abs(eigvals) < tol, 0., eigvals)
  # masked reductions instead of eigvals[mask]: boolean indexing crashes JAX
  # under jit/vmap
  excess = jnp.where(tangent, jnp.maximum(0., eigvals), 0.).sum()
  violations = jnp.sum((eigvals >= 0.) & tangent)
  return violations, excess


def first_order_riemann_test(s, pid, utility, tol):
  """Test whether `s` has Riemannian gradient with (close to) zero norm.

  :param s:
    `n x d` array of producer embeddings. Assumed unit norm.
  :param pid:
    Integer specifying the ID of the producer that should be checked (agrees
    with the row in `s` and the position within the return of `utility`),
    or `None` to check all producers.
  :param utility:
    Function which takes `n x d` array of producer embeddings and returns
    an `n`-dimensional array with the corresponding producer utilities.
  :param tol:
    Float. A producer violates the condition when the normalised norm of its
    Riemannian gradient (`l2` norm divided by `sqrt(d)`) is at least `tol`.
  :return:
    Tuple `(success, excess)`. For an integer `pid`, `success` flags whether
    that producer satisfies the condition, and `excess` is its normalised
    Riemannian gradient norm. If `pid` is `None`, `success` is True only when
    no producer violates the condition, and `excess` is the `l2` norm of the
    vector of per-producer excesses.
  """
  _first_order = jax.vmap(lambda v: _first_order_violation(v, tol))

  jacobians = rieman_jac(s, utility)
  violations, excesses = _first_order(jacobians)
  if pid is None:
    return violations.sum() == 0, jnp.linalg.norm(excesses, ord=2)
  return violations[pid] == 0, excesses[pid]


def second_order_riemann_test(
    s, pid, utility, tol, utility_type, consumers, tau, logits_and_probs):
  """Test whether `s` is a local PNE.

  :param s:
    `n x d` array of producer embeddings. Assumed unit norm.
  :param pid:
    Integer specifying the ID of the producer that should be checked (agrees
    with the row in `s` and the position within the return of `utility`),
    or `None` to check all producers.
  :param utility:
    Function which takes `n x d` array of producer embeddings and returns
    an `n`-dimensional array with the corresponding producer utilities.
    Not used when `utility_type == 'exposure'`.
  :param tol:
    Float. Smaller values are considered to be zero within the test.
  :param utility_type:
    `'exposure'` selects the closed-form derivatives from
    `rieman_jac_and_hess_exposure`. Any other value differentiates `utility`
    automatically via `rieman_jac_and_hess`.
  :param consumers:
    Forwarded to `rieman_jac_and_hess_exposure` (exposure utility only).
  :param tau:
    Forwarded to `rieman_jac_and_hess_exposure` (exposure utility only).
  :param logits_and_probs:
    Forwarded to `rieman_jac_and_hess_exposure` (exposure utility only).
  :return:
    Tuple `(success, excess)`.
    - `success` flags whether every checked producer has (close to) zero
      Riemannian gradient and a Riemannian Hessian whose eigenvalues are all
      strictly negative, except for a single zero corresponding to the
      projected-away radial direction.
    - `excess` is a pair. Its first entry is the normalised Riemannian
      gradient norm of the checked producer, or, if `pid is None`, the `l2`
      norm of these values across producers. Its second entry is the sum of
      positive (i.e. violating) Riemannian Hessian eigenvalues, additionally
      summed over producers if `pid is None`.
  """
  _first_order = jax.vmap(lambda v: _first_order_violation(v, tol))
  _second_order = jax.vmap(lambda M: _second_order_violation(M, tol))

  if utility_type == 'exposure':
    jacobians, hessians = rieman_jac_and_hess_exposure(
      s=s, consumers=consumers, tau=tau, logits_and_probs=logits_and_probs)
  else:
    jacobians, hessians = rieman_jac_and_hess(s, utility)

  v1, e1 = _first_order(jacobians)
  v2, e2 = _second_order(hessians)

  # Check each condition per producer rather than summing counts. A negative
  # second-order count could otherwise cancel a genuine violation of the
  # other order or of another producer.
  first_ok = jnp.logical_not(v1)
  second_ok = v2 <= 0

  if pid is None:
    success = jnp.all(first_ok) & jnp.all(second_ok)
    excess = (jnp.linalg.norm(e1, ord=2), e2.sum())
    return success, excess
  return first_ok[pid] & second_ok[pid], (e1[pid], e2[pid])


def get_riemann_checkers(
    utility, reparam_fn, tol, jit_compile=False, utility_type=None,
    consumers=None, tau=None, logits_and_probs=None):
  """Build first- and second-order equilibrium checkers on raw parameters.

  Wraps `first_order_riemann_test` and `second_order_riemann_test` so that
  they can be invoked with the optimised parameter instead of the producer
  embeddings. The parameter is passed through `reparam_fn`, and every row is
  then normalised to unit `l2` norm. `utility`, `tol` and the
  exposure-specific arguments are bound once here, so callers need not supply
  them.

  See documentation of the wrapped methods for details.

  :param utility:
    Function which takes `n x d` array of producer embeddings and returns
    an `n`-dimensional array with the corresponding producer utilities.
  :param reparam_fn:
    Function mapping the optimised parameter to (not necessarily
    normalised) producer embeddings. `None` means no reparametrisation is
    employed (identity).
  :param tol:
    Float tolerance forwarded to both tests.
  :param jit_compile:
    If `True`, both checkers are wrapped in `jax.jit`.
  :param utility_type:
    Forwarded to `second_order_riemann_test` (`'exposure'` selects the
    closed-form derivatives).
  :param consumers:
    Forwarded to `second_order_riemann_test`.
  :param tau:
    Forwarded to `second_order_riemann_test`.
  :param logits_and_probs:
    Forwarded to `second_order_riemann_test`.
  :return:
    Tuple `(first_order_checker, second_order_checker)`. Each is called as
    `checker(param, pid=None)`, where `pid=None` checks all producers.
  """
  cmpl = jax.jit if jit_compile else (lambda fun: fun)
  # `None` stands for "no reparametrisation employed"
  to_embeddings = (lambda p: p) if reparam_fn is None else reparam_fn

  def param_to_strategy(p):
    s = to_embeddings(p)
    # project every row onto the unit sphere (an all-zero row yields NaNs)
    return s / jnp.linalg.norm(s, axis=-1, ord=2, keepdims=True)

  @cmpl
  def first_order_checker(param, pid=None):  # `None` means check all
    s = param_to_strategy(param)
    return first_order_riemann_test(s, pid=pid, utility=utility, tol=tol)

  @cmpl
  def second_order_checker(param, pid=None):  # `None` means check all
    s = param_to_strategy(param)
    return second_order_riemann_test(
      s, pid=pid, utility=utility, tol=tol, utility_type=utility_type,
      consumers=consumers, tau=tau, logits_and_probs=logits_and_probs)

  return first_order_checker, second_order_checker
