import warnings
from typing import Callable, Dict, Optional, Text, Tuple, Union

import jax
import jax.numpy as jnp
import optax
import scipy as sp
import logging
from tqdm import tqdm

# scipy libraries need separate import
import scipy.optimize
# import jax.scipy.optimize


# Type aliases used by the annotations throughout this module.
# Callable of two arrays returning a loss value (an array or a float).
Loss = Callable[[jnp.ndarray, jnp.ndarray], Union[jnp.ndarray, float]]
# Callable of two arrays returning a `(loss value, gradient array)` pair.
LossAndGrad = Callable[[jnp.ndarray, jnp.ndarray],
                       Tuple[Union[float, jnp.ndarray], jnp.ndarray]]
# Array -> array map (e.g., a reparametrisation).
ArrToArr = Callable[[jnp.ndarray], jnp.ndarray]
# Array -> number map (e.g., a regulariser).
ArrToNum = Callable[[jnp.ndarray], float]
# Regulariser specification: nothing, a numeric multiplier, or a callable.
NumFunc = Union[None, float, int, ArrToNum]
# Reparametrisation specification: nothing, a string name, or a callable.
StrFunc = Union[None, Text, ArrToArr]
# Same as `StrFunc` but additionally admitting a float.
NumStrFunc = Union[StrFunc, float]
# Callable of an array and an optional integer returning `(bool, float)`.
ConvCheck = Callable[[jnp.ndarray, Optional[int]], Tuple[bool, float]]
# Output of an `Optimise` callable: the optimised array together with a
# `(success, excess, losses)` triple, where `losses` may be `None`, an array,
# or a dictionary from integers to arrays.
OptimiseOut = Tuple[
  jnp.ndarray,
  Tuple[bool, float, Union[None, jnp.ndarray, Dict[int, jnp.ndarray]]]]
# Callable of the parameter array and an optional producer id (`None` is used
# in this module to request a simultaneous update of all producers).
Optimise = Callable[[jnp.ndarray, Optional[int]], OptimiseOut]


def get_reparam_fn(reparam: StrFunc) -> ArrToArr:
  """Builds the map taking raw parameters into the strategy space.

  Args:
    reparam: One of the following:
        * `'rescale'`: every vector is divided by its `l2` norm computed along
          the last axis;
        * a `callable`: handed back untouched and used as the
          reparametrisation;
        * `None` or `'identity'`: the input is passed through unchanged.

  Returns:
    A callable taking an array and returning its reparametrised version.

  Raises:
    NotImplementedError: If `reparam` matches none of the options above.
  """
  if reparam == 'rescale':
    def rescale(v):
      norm = jnp.linalg.norm(v, ord=2, axis=-1, keepdims=True)
      return v / norm
    return rescale

  if callable(reparam):
    return reparam

  if reparam is None or reparam == 'identity':
    return lambda v: v

  raise NotImplementedError(reparam)


def get_regulariser_fn(regulariser: NumFunc) -> ArrToNum:
  """Constructs regularisation function to softly enforce the sphere constraint.

  Args:
    regulariser: Either a number specifying the multiplier of the `l2` penalty
        (non-positive values translate into no penalty), a `callable` already
        implementing the regulariser (returned without changes), or `None`
        (results in no regularisation).

  Returns:
    A callable taking an array and returning the penalty. For a numeric
    `regulariser > 0` the penalty is the squared `l2` norm along the last axis
    divided by the size of that axis (the embedding dimension), times
    `regulariser`. The value is thus the same whether the callable is applied
    to a full `n x d` array or to its rows one at a time.

  Raises:
    NotImplementedError: If `regulariser` is neither `None`, a callable, a
        non-positive value, nor a positive `int`/`float`.
  """
  if callable(regulariser):
    return regulariser
  if regulariser is None:
    return lambda v: 0.0

  try:
    non_positive = bool(regulariser <= 0.0)
  except TypeError:
    # values which cannot be ordered against a number (e.g., strings) are
    # simply unsupported specifications
    raise NotImplementedError(regulariser) from None

  if non_positive:
    return lambda v: 0.0
  if isinstance(regulariser, (int, float)):
    # normalise by the embedding dimension (last axis), NOT by the leading
    # axis, which is the number of producers when `v` is an `n x d` array
    return lambda v: regulariser * jnp.sum(v ** 2, axis=-1) / v.shape[-1]

  raise NotImplementedError(regulariser)


def get_util_loss(
    utility: ArrToArr,
    reparam: StrFunc = 'rescale',
    regulariser: NumFunc = 0.0
) -> Tuple[Loss, LossAndGrad]:
  """
  Turns a utility function into a pair of `loss` and `loss_and_grad` functions.
  The loss is the negative utility plus an optional regularisation term.

  Args:
    utility: Function mapping an `n x d` array (one producer strategy per row)
        to an `n`-dimensional array of producer utilities.
    reparam: Specification of the reparametrisation applied to the parameters
        *before* they are passed to `utility`; it is *not* applied to the
        input of the regulariser. Accepts anything `get_reparam_fn` accepts:
        the string `'rescale'` (`v -> v / norm(v)`, supporting the sphere
        constraint), a custom callable, or `None` for no reparametrisation.
    regulariser: Specification of the regulariser whose value is *added* to
        the loss; it receives the raw (not reparametrised) parameters. Accepts
        anything `get_regulariser_fn` accepts: a number used as a multiplier of
        an `l2` penalty, a custom callable, or `None`/non-positive number for
        no regularisation.

  Returns: Tuple `(loss, loss_and_grad)`. `loss(param)` evaluates
      `- utility(reparam_fn(param)) + reg_fn(param)`. `loss_and_grad(param)`
      returns that value together with an array whose `i`-th row is the
      gradient of the `i`-th loss entry with respect to the `i`-th row of
      `param`, i.e., each producer's own gradient with all other producers
      held at their current position.
  """
  to_strategy = get_reparam_fn(reparam)  # returns callables unchanged
  penalty = get_regulariser_fn(regulariser)  # returns callables unchanged

  def loss(param: jnp.ndarray) -> jnp.ndarray:
    strategies = to_strategy(param)
    # the penalty is evaluated on the raw parameters on purpose
    return - utility(strategies) + penalty(param)

  def loss_and_grad(param: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    n_producers = len(param)
    val, pullback = jax.vjp(loss, param)
    # pulling back the `i`-th basis vector yields the gradient of the `i`-th
    # loss entry w.r.t. the whole `param`; stacked over `i` by `vmap`
    (jacobian,) = jax.vmap(pullback)(jnp.eye(n_producers))
    # non-cooperative agents: keep only d loss_i / d param_i
    own_grad = jnp.diagonal(jacobian, axis1=0, axis2=1).T
    return val, own_grad

  return loss, loss_and_grad


def get_exposure_loss(
    utility, reparam, regulariser, consumers, tau, logits_and_probs):
  """Exposure loss with a hand-written (cheaper) gradient computation.

  :param utility:
    The exposure utility. Takes in an `n x d` array, where each row represents
    a single producer's strategy vector, and returns an `n`-dimensional array
    of the corresponding producer utilities. Only used by the returned `loss`.
  :param reparam:
    Either a string (describing the type of reparametrisation that should be
    constructed), a `callable` already implementing the reparametrisation
    (used without changes), or `None` (identity reparametrisation).
  :param regulariser:
    Either a float specifying the multiplier of the `l2` penalty (non-positive
    values translate into no penalty), a `callable` already implementing the
    regulariser (used without changes), or `None` (no regularisation). In
    `loss_and_grad` the regulariser is differentiated one row at a time, so it
    has to return a scalar when given a single `d`-dimensional row.
  :param consumers:
    An `m x d` array of consumer embedding locations.
  :param tau:
    A positive float specifying the temperature scaling for softmax.
  :param logits_and_probs:
    A function which takes in an `n x d` array of strategies and returns *two*
    `m x n` arrays: the `logits` and the corresponding `probabilities` of
    recommending each producer to a given consumer. The temperature used to
    compute both should be the same as the `tau` argument of this function.
  :return:
    Tuple of `loss` and `loss_and_grad` functions. `loss(param)` evaluates
    `- utility(reparam_fn(param)) + reg_fn(param)`. `loss_and_grad(param)`
    returns `- probs.mean(0) + reg_fn(param)` together with an `n x d` array of
    per-producer gradients (each producer's own gradient with the remaining
    producers held at their current position).
  """
  to_strategy = get_reparam_fn(reparam)  # returns callables unchanged
  penalty = get_regulariser_fn(regulariser)  # returns callables unchanged

  def loss(param: jnp.ndarray) -> jnp.ndarray:
    strategies = to_strategy(param)
    # the penalty sees the raw parameters, not the reparametrised ones!
    return - utility(strategies) + penalty(param)

  def loss_and_grad(param):
    # TODO: relies on reparam universality
    strategies = to_strategy(param)
    _, probs = logits_and_probs(strategies)

    # loss value: negative average recommendation probability plus penalty
    val = - probs.mean(0) + penalty(param)

    # gradient w.r.t. the (reparametrised) strategies
    weights = probs * (1 - probs)
    strategy_grad = - jnp.einsum('mn,md->nd', weights, consumers)
    # uniform average over consumers
    strategy_grad = strategy_grad / (tau * len(consumers))

    def pull_back(row, cotangent):
      # chain rule through the reparametrisation, one producer at a time
      # (jvp and vjp are similarly fast here)
      _, row_vjp = jax.vjp(to_strategy, row)
      return row_vjp(cotangent)[0]

    grad = jax.vmap(pull_back)(param, strategy_grad)
    grad = grad + jax.vmap(jax.grad(penalty))(param)

    return val, grad

  return loss, loss_and_grad


def get_exposure_bound_loss(
    logits_and_probs: ArrToArr,
    reparam: StrFunc = 'rescale',
    regulariser: NumFunc = 0.0
) -> Tuple[Loss, LossAndGrad]:
  """
  Constructs `loss` and `loss_and_grad` functions for the lower-bound (EM)
  treatment of the *exposure* based utility, potentially with
  reparametrisation and regularisation.

  :param logits_and_probs:
    Function which returns a tuple `(logits, probs)`, both of which are arrays
    of size `m x n` where each row corresponds to the logits (/probabilities)
    of recommendation of all producers to a given consumer. Only the logits
    are used here.
  :param reparam:
    Either a `string` or a custom function which takes in the `n x d` array of
    producer strategies, and returns its reparametrisation. It is applied
    before feeding the vectors into `logits_and_probs`, but not when feeding
    into the regularisation function. The only supported `string` values are
    those accepted by `get_reparam_fn`, e.g., `'rescale'` which implements
    `v -> v / norm(v)` to support the sphere constraint. Pass `None` if no
    reparametrisation is to be applied.
  :param regulariser:
    Either a number or a function which takes in the `n x d` array of producer
    strategies, and returns a vector of `n` regularisation values, one per
    producer. These are *added* to each producer's loss. A number is used as
    the multiplier of an `l2` penalty (see `get_regulariser_fn`). Values of
    zero or `None` result in no regularisation.
  :return:
    Tuple of `loss` and `loss_and_grad` functions. Both take in the parameter
    and the auxiliary `q` array (same `m x n` shape as the outputs of
    `logits_and_probs`). `loss(param, q)` evaluates
    `- (q * logits).mean(0) + reg_fn(param)` with `logits` computed at
    `reparam_fn(param)`. NOTE(review): the previous documentation described
    this as `- log[utility(s)] + KL(q || p_tilde) + reg_fn(param)`; the code
    only evaluates the `q`-weighted logits term, presumably the part of the
    bound depending on `param` -- confirm against the derivation.
    `loss_and_grad(param, q)` additionally returns the per-producer gradients
    of this loss, assuming all other producers remain at their current
    position.
  """
  to_strategy = get_reparam_fn(reparam)  # returns callables unchanged
  penalty = get_regulariser_fn(regulariser)  # returns callables unchanged

  def loss(param: jnp.ndarray, q: jnp.ndarray):
    logits, _ = logits_and_probs(to_strategy(param))
    weighted_logits = q * logits
    return - weighted_logits.mean(0) + penalty(param)

  def loss_and_grad(param: jnp.ndarray, q: jnp.ndarray):
    def loss_at_q(p):
      return loss(p, q)

    val, pullback = jax.vjp(loss_at_q, param)
    # one pull-back per loss entry, stacked by `vmap`
    (jacobian,) = jax.vmap(pullback)(jnp.eye(len(param)))
    # non-cooperative agents: keep only d loss_i / d param_i
    own_grad = jnp.diagonal(jacobian, axis1=0, axis2=1).T
    return val, own_grad

  return loss, loss_and_grad


def get_pid_loss_and_grad(
    pid: int,
    param: jnp.ndarray,
    loss: Optional[Loss],
    loss_and_grad: Optional[LossAndGrad]
) -> Tuple[Optional[Loss], Optional[LossAndGrad]]:
  """
  Restricts all-producer `loss` and `loss_and_grad` functions (e.g., those
  returned by `get_util_loss` or `get_exposure_bound_loss`) to a single
  producer, with the strategies of all other producers frozen. Useful, e.g., in
  `sequential_ascent` which cycles through the producers one at a time.

  Args:
    pid: Integer in `[0, n-1]` identifying the producer for which the loss and
      gradient should be computed, keeping the others fixed.
    param: An `n x d` array of producer strategies supplying the frozen rows;
      the `pid` row is replaced by the argument of the returned functions.
    loss: A function taking the full `n x d` array (plus any further
      arguments) and returning one loss per producer. Pass `None` to skip.
    loss_and_grad: A function taking the full `n x d` array (plus any further
      arguments) and returning `(losses, gradients)` indexed by producer. Pass
      `None` to skip.

  Returns:
    A tuple of `loss` and `loss_and_grad` functions which take a
    `d`-dimensional array representing the strategy of the `pid` producer
    (further arguments are forwarded), and return the `pid` entry of the loss
    (resp. of the loss and the gradient). An entry is `None` when the
    corresponding input was `None`.
  """
  def with_pid_row(row):
    # functional update: `param` itself is left untouched
    return param.at[pid].set(row)

  def pid_loss(p, *args, **kwargs):
    return loss(with_pid_row(p), *args, **kwargs)[pid]

  def pid_loss_and_grad(p, *args, **kwargs):
    val, grad = loss_and_grad(with_pid_row(p), *args, **kwargs)
    return val[pid], grad[pid]

  return (
    None if loss is None else pid_loss,
    None if loss_and_grad is None else pid_loss_and_grad)


def _iter_wrapper(
    get_optimiser: Callable[[Loss, LossAndGrad], Tuple[Callable, Callable]],
    loss: Loss,
    loss_and_grad: LossAndGrad,
    n_rounds: int,
    tol: float = 1e-5
) -> Optimise:
  """Wrapper around optimisers implementing training loop and result logging.

  Args:
    get_optimiser: Function taking a `value` and `value_and_grad` as arguments,
        and returning a tuple of `(init, update)`. The `value` function returns
        losses for all updated producers, whereas the `value_and_grad` function
        returns both the losses and gradients (one per row). The `init` function
        accepts the initial value of the optimised parameter and returns the
        initial `state` of the optimiser. The `state` is then accepted by the
        `update` function together with `param`, the current value of the
        optimised parameters; it returns `(state, (val, param))` where `state`
        and `param` are updated versions of the supplied arguments, and `val`
        is the loss value reported by the optimiser for that step.
    loss: A `loss` function, e.g., as produced by `get_util_loss`.
    loss_and_grad: A `loss_and_grad` function, e.g., as produced by
        `get_util_loss`.
    n_rounds: The *maximum* number of update steps to take. Also see the `tol`
        parameter.
    tol: The loop stops (and reports success) as soon as the `l2` norm of the
        change of every updated row in a single step falls below this value.

  Returns:
    A function which takes in the initial `n x d` array of producer strategies,
    and `pid` of the producer for whom to simulate the dynamics (pass `None` if
    all producers are to be updated in each round simultaneously), and runs at
    most `n_rounds` update steps. If `pid is None`, the 1st returned value is
    the optimised `n x d` array of producer strategies; if `pid is not None`,
    the 1st returned value is just the optimised `d`-dim vector. The 2nd
    returned value is a `(success, excess, losses)` triple: `success` is `True`
    only if the `tol` criterion was met within `n_rounds` steps, `excess` is
    the `l2` norm of the parameter change in the last executed step, and
    `losses` is an array of the loss values from all executed steps.

  Raises:
    ValueError: If `n_rounds < 1`.
  """
  if n_rounds < 1:
    raise ValueError(
      f'must take at least one optimisation step; got `n_rounds == {n_rounds}`')

  # `pid == None` means simultaneous optimisation
  def optimise(param: jnp.ndarray, pid: Optional[int]) -> OptimiseOut:
    value, value_and_grad = loss, loss_and_grad
    if pid is not None:
      # freeze everyone but the `pid` producer
      value, value_and_grad = get_pid_loss_and_grad(
        pid, param=param, loss=loss, loss_and_grad=loss_and_grad)

    init, update = get_optimiser(value, value_and_grad)

    opt_param = jnp.array(param if pid is None else param[pid])
    state = init(opt_param)
    # `success` must start as `False`: it is only earned by meeting `tol`,
    # exhausting `n_rounds` without doing so is a failure to converge
    losses, diff, success = [], None, False
    for _ in tqdm(range(n_rounds)):
      state, (l, new_param) = update(state, opt_param)
      diff = jnp.linalg.norm(opt_param - new_param, ord=2, axis=-1)
      opt_param = new_param

      losses.append(l)
      if (diff < tol).all():
        success = True
        break

    return opt_param, (success, jnp.sum(diff**2)**0.5, jnp.array(losses))

  return optimise


def _em_wrapper(
    loss: Loss,
    loss_and_grad: LossAndGrad,
    logits_and_probs,
    n_em_rounds: int,
    tol: float = 1e-5,
    **iter_kwargs
) -> Optimise:
  """Wrapper alternating E-steps with M-steps run through `_iter_wrapper`.

  Args:
    loss: Function of `(param, q)` returning the per-producer loss, e.g., as
        produced by `get_exposure_bound_loss`.
    loss_and_grad: Function of `(param, q)` returning the per-producer loss and
        gradients, e.g., as produced by `get_exposure_bound_loss`.
    logits_and_probs: Function taking the raw `n x d` parameter array and
        returning `(logits, probs)`, two `m x n` arrays (consumers in rows,
        producers in columns). Only `probs` is used here.
    n_em_rounds: The *maximum* number of EM rounds. Also see `tol`.
    tol: EM stops (and reports success) as soon as the `l2` norm of the change
        of every updated row over a whole EM round falls below this value.
        NOTE: this value is consumed here and is not forwarded to
        `_iter_wrapper`, so the M-step uses the default `tol` of
        `_iter_wrapper`.
    **iter_kwargs: Remaining keyword arguments of `_iter_wrapper` (i.e.,
        `get_optimiser` and `n_rounds`), used to run each M-step.

  Returns:
    A function with the same signature and output structure as that returned by
    `_iter_wrapper`: it takes the initial `n x d` parameter array and a `pid`
    (`None` for simultaneous updates), and returns the optimised parameter (the
    full array, or only the `pid` row), and a `(success, excess, losses)`
    triple. `success` is `True` only if the `tol` criterion was met within
    `n_em_rounds` rounds, `excess` is the `l2` norm of the change in the last
    EM round, and `losses` stacks the M-step losses of all EM rounds.

  Raises:
    ValueError: If `n_em_rounds < 1`.
  """
  if n_em_rounds < 1:
    raise ValueError(
      f'must take at least one optimisation step; '
      f'got `n_em_rounds == {n_em_rounds}`')

  # `pid == None` means simultaneous optimisation
  def optimise(param: jnp.ndarray, pid: Optional[int]) -> OptimiseOut:
    opt_param = jnp.array(param if pid is None else param[pid])
    # `success` must start as `False`: it is only earned by meeting `tol`
    losses, diff, success = [], None, False
    for _ in tqdm(range(n_em_rounds)):
      # E-step (optimise wrt the auxiliary distribution `q` w/ `param` fixed).
      # For each producer (column), `q` is a distribution over consumers
      # (rows), hence the normalisation over axis 0. Normalising over the last
      # axis would be a no-op since rows of `probs` already sum to one.
      _, probs = logits_and_probs(param)
      q = probs / jnp.sum(probs, axis=0, keepdims=True)

      # M-step (optimise wrt the `param` w/ `q` fixed)
      value = lambda param: loss(param=param, q=q)
      value_and_grad = lambda param: loss_and_grad(param=param, q=q)
      m_optimise = _iter_wrapper(
        loss=value, loss_and_grad=value_and_grad, **iter_kwargs)

      new_param, (_, _, l) = m_optimise(param, pid)
      diff = jnp.linalg.norm(opt_param - new_param, ord=2, axis=-1)
      # keep the full array up to date for the next E-step
      param = new_param if pid is None else param.at[pid].set(new_param)
      opt_param = new_param

      losses.append(l)
      if (diff < tol).all():
        success = True
        break

    return opt_param, (success, jnp.sum(diff**2)**0.5, jnp.vstack(losses))

  return optimise


def optax_minimisation(opt: optax.GradientTransformation):
  """
  Wraps an `optax` optimiser into the `get_optimiser` factory consumed by
  `_iter_wrapper`. The resulting update rule works both for simultaneous
  updating of all producer strategies, and for optimisation of a single
  producer while keeping others fixed (useful in the sequential ascent
  approach), depending on the functions it is given.

  Args:
    opt: An `optax` optimiser, i.e., an instance of
        `optax.GradientTransformation`. Determines the local update rule
        employed by the producers.

  Returns:
    A function `get_optimiser(value, value_and_grad)` returning the pair
    `(init, update)`. `init(param)` creates the optimiser state;
    `update(state, param)` is jit-compiled, takes a single optimiser step, and
    returns `(state, (val, param))` with the updated state and parameter, and
    the loss value computed before the step. `value` is not used.

  Raises:
    NotImplementedError: If `opt` is not an `optax.GradientTransformation`.
  """
  if not isinstance(opt, optax.GradientTransformation):
    raise NotImplementedError(opt)

  def get_optimiser(value, value_and_grad):
    def make_state(param):
      return opt.init(param)

    def step(state, param):
      val, grad = value_and_grad(param)
      updates, new_state = opt.update(updates=grad, state=state, params=param)
      new_param = optax.apply_updates(params=param, updates=updates)
      return new_state, (val, new_param)

    return make_state, jax.jit(step)

  return get_optimiser


def riemann_gd_minimisation(lr: Union[float, Callable]):
  """
  Combination of vanilla gradient descent with post-update retraction back onto
  the sphere. Also known as the Riemannian gradient descent. The retraction is
  fixed to be the renormalisation by norm: `x -> x / || x ||_2`.

  CAVEAT:
    Do not use together with either reparametrisation or regularisation.

  Args:
    lr: Either a positive number representing a fixed learning rate, or a
        callable which takes in the current step (counted from one) and
        returns the stepsize for that step. The callable is invoked inside a
        jit-compiled function, so it has to be traceable by `jax`.

  Returns:
    A function `get_optimiser(value, value_and_grad)` returning the pair
    `(init, update)` in the format consumed by `_iter_wrapper`. `init(param)`
    returns the state `{'step': 0}`; `update(state, param)` is jit-compiled,
    takes one gradient step followed by the retraction, and returns
    `(state, (val, param))` with the step counter incremented, the loss value
    computed before the step, and the updated parameter. `value` is not used.

  Raises:
    ValueError: If `lr` is a number which is not strictly positive.
  """
  if isinstance(lr, (float, int)) and lr <= 0.0:
    # zero is rejected as well, hence "positive" rather than "non-negative"
    raise ValueError(f'learning rate must be positive; was {lr}')

  def project(v):
    return v / jnp.linalg.norm(v, ord=2, axis=-1, keepdims=True)

  def get_optimiser(value, value_and_grad):
    def init(param):
      return {'step': 0}

    @jax.jit
    def update(state, param):
      # build a new state instead of mutating the argument
      step = state['step'] + 1
      stepsize = lr(step) if callable(lr) else lr

      val, grad = value_and_grad(param)
      param = project(param - stepsize * grad)
      return {**state, 'step': step}, (val, param)

    return init, update

  return get_optimiser


def scipy_minimisation(tol: float = 1e-5, method: str = 'BFGS'):
  """
  Wraps `scipy.optimize.minimize` into the `get_optimiser` factory consumed by
  `_iter_wrapper`. A single `update` call runs a full `scipy` minimisation
  started from the supplied parameter.

  Args:
    tol: Passed on as the `tol` argument of `scipy.optimize.minimize`.
    method: Passed on as the `method` argument of `scipy.optimize.minimize`.

  Returns:
    A function `get_optimiser(value, value_and_grad)` returning the pair
    `(init, update)`. `init(param)` returns an empty state; `update(state,
    param)` returns `(state, (val, param))` where `val` and `param` are the
    `fun` and `x` attributes of the `scipy` result. If `value_and_grad` is
    `None`, only `value` is minimised (with `jac=None`); otherwise
    `value_and_grad` is minimised with `jac=True`.
  """
  def get_optimiser(value, value_and_grad):
    has_grad = value_and_grad is not None
    jac = True if has_grad else None

    def init(param):
      return {}

    # deliberately not jit-compiled: plain scipy is faster at small scale
    def update(state, param):
      objective = value_and_grad if has_grad else value
      result = sp.optimize.minimize(
        objective, x0=param, method=method, tol=tol, jac=jac)
      return state, (result.fun, result.x)

    return init, update

  return get_optimiser


def leap_minimisation(tol: float = 1e-5):
  """
  A custom optimizer based on the necessary condition `(I - s_i s_i^T) g_i = 0`
  for a strategy profile `s = (s_1, ... , s_n)` to be a critical point. Here
  `g_i` is the gradient of the loss function. To satisfy the condition, either
  `g_i` is zero or it is proportional to `s_i`. Since the loss is to be
  *minimised*, this optimiser sets `s_i` to `- g_i / || g_i ||_2` (the point
  of the unit sphere minimising the first-order approximation of the loss)
  whenever `|| (I - s_i s_i^T) g_i ||_2 >= tol * sqrt(dim)`, and otherwise
  keeps it fixed.

  CAVEAT:
    Do not use together with either reparametrisation or regularisation.

  Args:
    tol: A `float` specifying how large the `l2` norm of the Riemannian
        gradient `(I - s_i s_i^T) g_i` has to be for it to not be considered
        zero (if it is considered zero, no update to given producer's embedding
        is executed). The effective cutoff is `tol * sqrt(dim)`. Should match
        the `tol` used internally in `convergence_checker`, otherwise may stop
        changing locations w/o satisfying the checker.

  Returns:
    A function `get_optimiser(value, value_and_grad)` returning the pair
    `(init, update)` in the format consumed by `_iter_wrapper`. `init(param)`
    returns an empty state; `update(state, param)` is jit-compiled and returns
    `(state, (val, param))` with the loss value computed before the update and
    the updated parameter. `value` is not used.
  """
  def get_optimiser(value, value_and_grad):
    def init(param):
      return {}

    @jax.jit
    def update(state, param):
      dim = param.shape[-1]
      val, grad = value_and_grad(param)

      # projection of the Euclidean gradient onto the tangent space at `s`
      euclid_to_riemann_grad = lambda s, g: (jnp.eye(dim) - jnp.outer(s, s)) @ g
      if param.ndim > 1:
        euclid_to_riemann_grad = jax.vmap(euclid_to_riemann_grad)

      rg = euclid_to_riemann_grad(param, grad)
      rn = jnp.linalg.norm(rg, ord=2, axis=-1, keepdims=True)

      # `grad` is the gradient of the *loss*, so the jump has to go against it:
      # jumping to `+ grad / ||grad||` would increase the loss to first order
      leap = - grad / jnp.linalg.norm(grad, ord=2, axis=-1, keepdims=True)

      # only update where Riemannian grad is not approx zero
      zero_grad = rn < (tol * (dim**0.5))  # should match `convergence_checker`
      param = jnp.where(zero_grad, param, leap)

      return state, (val, param)

    return init, update

  return get_optimiser


def optimiser_from_config(
    C,
    utility,
    logits_and_probs,
    reparam_fn,
    reg_fn,
    consumers
):
  """Instantiates an optimiser based on the supplied config and losses.

  :param C:
    Config object. The attributes read here are `reparam`, `regulariser`, `lr`,
    `tau`, `scale_lr_by_temperature`, `optimiser`, `dynamics`,
    `n_inner_rounds`, `tol`, `em`, `utility`, and `n_em_rounds`. Note that
    `C.n_inner_rounds` is overwritten with `1` (with a warning) when the
    `scipy` optimiser is selected with a different value.
  :param utility:
    The utility function. Takes in an `n x d` array, where each row represents
    a single producer's strategy vector, and returns an `n`-dimensional array
    of the corresponding producer utilities.
  :param logits_and_probs:
    A function which takes in an `n x d` array of strategies and returns *two*
    `m x n` arrays: the `logits` and the corresponding `probabilities` of
    recommending each producer to a given consumer. The temperature used to
    compute both should match `C.tau`.
  :param reparam_fn:
    A function which takes in an `n x d` array of parameters, and returns an
    `n x d` array of strategies.
  :param reg_fn:
    A function which takes in an `n x d` array of parameters, and returns an
    `n`-dimensional vector of regularisation values.
  :param consumers:
    An `m x d` array of consumer embedding locations.
  :return:
    A function with the same semantics as a product of, e.g., `_iter_wrapper`.
  :raises ValueError:
    If the selected optimiser is incompatible with the presence/absence of
    reparametrisation and regularisation, or if EM is requested for a utility
    other than "exposure".
  :raises NotImplementedError:
    If the optimiser is unknown, or `scipy` is combined with "simultaneous"
    dynamics.
  """
  use_reparam = C.reparam is not None
  use_regulariser = not (C.regulariser is None or C.regulariser <= 0)
  constrained = use_reparam or use_regulariser

  base_lr = C.lr
  lr_scale = C.tau if C.scale_lr_by_temperature else 1.0
  lr = base_lr * lr_scale
  logging.info(f'Effective learning rate: {lr}')

  def require_constraint(name):
    # optimisers unaware of the sphere need it enforced through the loss
    if not constrained:
      raise ValueError(
        f'`{name}` does not incorporate the sphere constraint so will not '
        'converge without either reparametrisation or regularisation')

  def forbid_constraint(name):
    # optimisers handling the sphere themselves must see the raw loss
    if constrained:
      raise ValueError(
        f'`{name}` optimiser does not work with reparam or regularisation')

  logging.info(f'Setting up {C.optimiser} optimiser')
  name = C.optimiser
  if name == 'optax':
    require_constraint(name)
    get_optimiser = optax_minimisation(opt=optax.sgd(lr))
  elif name == 'riemann':
    forbid_constraint(name)
    get_optimiser = riemann_gd_minimisation(lr=lr)
  elif name == 'scipy':
    require_constraint(name)
    if C.dynamics == 'simultaneous':
      raise NotImplementedError(
        '`scipy` optimiser with "simultaneous" dynamics is discouraged')
    if C.n_inner_rounds != 1:
      warnings.warn(
        f'running `scipy` optimiser within an iterative optimiser for more '
        f'than one round is wasteful; setting `C.n_inner_rounds == 1` (was '
        f'`C.n_inner_rounds == {C.n_inner_rounds})`')
      C.n_inner_rounds = 1
    get_optimiser = scipy_minimisation(tol=C.tol)
  elif name == 'leap':
    forbid_constraint(name)
    get_optimiser = leap_minimisation(tol=C.tol)
  else:
    raise NotImplementedError(f'unknown optimiser "{C.optimiser}"')

  if C.em:
    if C.utility != 'exposure':
      raise ValueError(
        f'EM only derived for the "exposure" utility; was "{C.utility}"')

    def param_to_logits_and_probs(param):
      # the E-step works with raw parameters, hence the reparametrisation
      return logits_and_probs(reparam_fn(param))

    loss, loss_and_grad = get_exposure_bound_loss(
      logits_and_probs=logits_and_probs, reparam=reparam_fn, regulariser=reg_fn)

    return _em_wrapper(
      get_optimiser=get_optimiser, loss=loss, loss_and_grad=loss_and_grad,
      logits_and_probs=param_to_logits_and_probs, n_rounds=C.n_inner_rounds,
      n_em_rounds=C.n_em_rounds, tol=C.tol)

  if C.utility == 'exposure':
    # dedicated implementation with a hand-written gradient
    loss, loss_and_grad = get_exposure_loss(
      utility=utility, reparam=reparam_fn, regulariser=reg_fn, tau=C.tau,
      consumers=consumers, logits_and_probs=logits_and_probs)
  else:
    loss, loss_and_grad = get_util_loss(
      utility=utility, reparam=reparam_fn, regulariser=reg_fn)

  return _iter_wrapper(
    get_optimiser=get_optimiser, loss=loss, loss_and_grad=loss_and_grad,
    n_rounds=C.n_inner_rounds, tol=C.tol)




def simultaneous_ascent(optimise: Optimise, param: jnp.ndarray) -> OptimiseOut:
  """
  Simulates the *simultaneous* updating dynamics, i.e., all producers are
  updated together in every step of the supplied `optimise` function.

  CAVEAT:
    Does not work properly with optimisers based on `scipy_minimisation`.

  Args:
    optimise: A function taking the parameter array and a producer id, e.g.,
        as returned by `_iter_wrapper`, or satisfying the same API.
    param: An `n x d` array of the initial producer strategies.

  Returns:
    The output of the supplied `optimise` function called with no producer id.
  """
  all_producers = None  # `pid == None` requests the simultaneous update
  return optimise(param, all_producers)


def sequential_ascent(
    optimise: Optimise,
    param: jnp.ndarray,
    n_rounds: int,
    tol: float = 1e-5,
    verbose: int = -1
) -> OptimiseOut:
  """
  Simulates the *sequential* updating dynamics: producers are optimised one
  after another with the supplied `optimise` function, each seeing the latest
  strategies of the others.

  Args:
    optimise: A function taking the parameter array and a producer id, e.g.,
        as returned by `_iter_wrapper`, or satisfying the same API.
    param: An `n x d` array of the initial producer strategies.
    n_rounds: The "outer-loop" number of rounds, i.e., the maximum number of
        times we will cycle through all the producers. In other words, the
        maximum number of times `optimise` will be applied to each producer.
    tol: Threshold for early stopping. A round is successful if every call of
        `optimise` in it reported success *and* the `l2` norm of the `excess`
        values reported within the round is below `tol`.
    verbose: An integer which specifies the frequency (in rounds) of printing
        the round number. Pass any non-positive value to switch it off.

  Returns:
    The array of optimised producer strategies, and a `(success, excess,
    losses)` triple. `success` says whether the last executed round was
    successful (see `tol`), in which case the optimisation was stopped early;
    `excess` is the aggregated excess of the last executed round; `losses`
    holds the last loss reported by `optimise` for every round and producer
    (`nan` where `optimise` reported no losses), with rounds along the first
    axis.

  Raises:
    ValueError: If `n_rounds < 1`.
  """
  opt_param = jnp.array(param)  # copy
  if n_rounds < 1:
    raise ValueError(f'no. of rounds smaller than one: {n_rounds}')

  pids = range(len(param))
  losses = {pid: [] for pid in pids}
  for rid in tqdm(range(n_rounds)):  # outer for-loop over rounds
    should_print = verbose > 0 and rid % verbose == 0
    if should_print:
      print(f'round {rid}')

    # statistics are reset so that they describe the last executed round
    success, squared_excess = True, 0.0
    for pid in pids:  # inner for-loop over producers
      new_row, (row_success, row_excess, row_losses) = optimise(opt_param, pid)

      opt_param = opt_param.at[pid].set(new_row)
      success &= row_success
      squared_excess += row_excess**2  # accumulates the squared l2 norm
      last_loss = jnp.nan if row_losses is None else row_losses[-1]
      losses[pid].append(last_loss)

    excess = squared_excess ** 0.5
    success &= excess < tol
    if not success:
      # success != small changes which add up to a large number
      success = False
      continue

    print(f'break at round {rid}')
    break

  # convert losses into a format similar to the simultaneous dynamics
  losses = jnp.array([jnp.array(per_pid) for per_pid in losses.values()]).T
  return opt_param, (success, excess, losses)
