"""Smarter replacements for some jax functions plus other useful utilities

v0.0.24

Copyright (c) 2023 <>
Distributed under MIT License (https://opensource.org/license/mit/)
"""

_LICENSE = \
"""
Copyright (c) 2023 <>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from functools import partial, wraps
import inspect
from typing import Any, Callable, NamedTuple

# Public API exported by `from <this module> import *`.
__all__ = [
  'vmap', 'with_subkeys', 'with_sharing', 'tree_flatten_once',
  'tree_transpose', 'tree_zip', 'tree_unzip', 'tree_diff',
  'tree_error_if_all_not_equal',
]

class _VmapWrapper(NamedTuple):
  """Callable returned by `vmap(...)`; see `_VmapFunc` for the user-facing API.

  Fields:
    fun: the function being mapped.
    in_axes: mapping axes for positional arguments (an int, None, or a
      sequence whose entries are PyTree prefixes of the arguments).
    kw_axes: mapping axes for keyword arguments; a dict may use the key `...`
      for the axis of keywords it does not name.
    split_argnames: keyword name -> number of keys to split into (0 means
      one key per mapped call).
    split_argnums: positional index -> number of keys to split into.
    vmap: the underlying vmap transformation (e.g. `jax.vmap`).
    vmap_kwargs: extra keyword arguments forwarded to `vmap`.
  """
  fun: Callable
  in_axes: Any
  kw_axes: Any
  split_argnames: dict
  split_argnums: dict
  vmap: Callable
  vmap_kwargs: dict

  @classmethod
  def wrap(cls, fun, in_axes, kw_axes, vmap, vmap_kwargs):
    """Build a wrapper, extracting `split<name>` options from `vmap_kwargs`.

    Each truthy `split<name>` entry is removed from `vmap_kwargs` (mutated in
    place) and recorded as a request to call `jax.random.split` on argument
    `<name>`.  `True` means one key per mapped call; an int greater than 1
    means exactly that many keys.

    Raises:
      ValueError: if a split count is invalid, or the named argument is a
        `*args`/`**kwargs` parameter.
    """
    params = inspect.signature(fun).parameters
    param_names = tuple(params)
    splitargs = {k[len('split'):]: v for k, v in vmap_kwargs.items()
                 if k.startswith('split') and v}

    split_argnames, split_argnums = {}, {}

    # Figure out which argument type each is
    for name, reps in splitargs.items():
      vmap_kwargs.pop('split' + name)

      if reps is True:
        reps = 0  # resolved to the number of mapped calls at call time
      elif not isinstance(reps, int) or reps <= 1:
        raise ValueError(
          f"'split{name}' must be True or an int greater than 1, "
          f"got {reps!r}")

      p = params.get(name)
      if p is None:
        # Not a named parameter (e.g. accepted through **kwargs): it can only
        # arrive by keyword.
        split_argnames[name] = reps
      elif p.kind == p.POSITIONAL_ONLY:
        split_argnums[param_names.index(name)] = reps
      elif p.kind == p.KEYWORD_ONLY:
        split_argnames[name] = reps
      elif p.kind == p.POSITIONAL_OR_KEYWORD:
        split_argnums[param_names.index(name)] = reps
        split_argnames[name] = reps
      else:
        raise ValueError(f"argument named '{name}' cannot be split: {p.kind}")

    return cls(fun, in_axes, kw_axes, split_argnames, split_argnums,
               vmap, vmap_kwargs)

  @staticmethod
  def _process_axes(tree, axes):
    """Complete `axes` to match `tree` and apply the broadcasting rules.

    `axes` is a PyTree prefix of `tree` whose leaves are ints or None.  The
    returned axes has an entry for every leaf of `tree`: an int if the leaf
    is mapped, or None if it is broadcast.  A leaf is broadcast when its axis
    is None, when it is not an array, when the axis is out of range for its
    rank, or when its size along the axis is 1.  In the last case the
    singleton axis is also removed from the returned `tree`.

    Returns:
      A `(tree, axes)` pair.
    """
    is_none = lambda x: x is None

    # Handle simple case of no mapping
    if not jtu.tree_leaves(axes):
      return tree, jtu.tree_map(lambda _: None, tree)

    # Complete the axes PyTree.  None entries are kept as leaves so that a
    # None axis can stand for a non-None argument (as in `jax.vmap`);
    # otherwise flatten_up_to would demand None in `tree` at that position.
    axes_vs, treedef = jtu.tree_flatten(axes, is_leaf=is_none)
    tree_vs = treedef.flatten_up_to(tree)

    def _complete(a, v):
      # Broadcast the prefix entry `a` to every leaf of subtree `v`.
      return jtu.tree_map(lambda _: a, v)

    axes = jtu.tree_unflatten(
      treedef, [_complete(a, v) for a, v in zip(axes_vs, tree_vs)])

    def _mapped_axis(a, t):
      # Keep `a` only for array leaves that actually have that axis
      # (negative axes included).
      if a is None or not isinstance(t, (jax.Array, np.ndarray, np.generic)):
        return None
      return a if -len(t.shape) <= a < len(t.shape) else None

    axes = jtu.tree_map(_mapped_axis, axes, tree, is_leaf=is_none)

    # Strip singleton dimensions for broadcast
    is_single = jtu.tree_map(
      lambda t, a: a is not None and t.shape[a] == 1,
      tree, axes)

    tree = jtu.tree_map(lambda t, a, s: jnp.take(t, 0, a) if s else t,
                        tree, axes, is_single)
    axes = jtu.tree_map(lambda a, s: None if a is None or s else a,
                        axes, is_single, is_leaf=is_none)

    return tree, axes

  @staticmethod
  def _inner_fun(fun, kwargs, *args):
    # The vmap transformation is given kwargs as a positional argument (with
    # its own axes); unpack them here.
    return fun(*args, **kwargs)

  def __call__(self, *args, **kwargs):
    """Call `fun` mapped over the non-broadcast axes of the arguments.

    Array outputs carry a leading mapped axis even when only a single call
    is needed (in which case `fun` is called directly, without `vmap`).
    """
    # Resolve the `...` default entry of a dict-valued kw_axes.
    if isinstance(self.kw_axes, dict) and ... in self.kw_axes:
      default_axis = self.kw_axes[...]
      kw_axes = {k: self.kw_axes.get(k, default_axis) for k in kwargs}
    else:
      kw_axes = self.kw_axes

    if hasattr(self.in_axes, '__iter__'):
      in_axes = list(self.in_axes)
    else:
      in_axes = self.in_axes

    # Update axes, given the arguments
    args, in_axes = self._process_axes(list(args), in_axes)
    kwargs, kw_axes = self._process_axes(kwargs, kw_axes)

    # Arguments that will be replaced by split RNG keys.
    split_kw = [k for k, v in kwargs.items()
                if v is not None and k in self.split_argnames]
    split_pos = [i for i, v in enumerate(args)
                 if v is not None and i in self.split_argnums]

    # An unsplit key (e.g. a PRNGKey of shape (2,)) must not be mistaken for
    # a batch when counting the mapped calls.
    for k in split_kw:
      kw_axes[k] = None
    for i in split_pos:
      in_axes[i] = None

    # Determine the number of mapped calls: the largest mapped axis size, or
    # an explicit split count (so `splitkey=3` alone still maps 3 times).
    sizes = jtu.tree_leaves(
      jtu.tree_map(
        lambda a, t: None if a is None else t.shape[a],
        (in_axes, kw_axes), (args, kwargs),
        is_leaf=lambda x: x is None))
    counts = ([self.split_argnames[k] for k in split_kw] +
              [self.split_argnums[i] for i in split_pos])
    sizes += [c for c in counts if c]
    reps = max(sizes, default=1)

    if reps == 1:
      # Single call: run `fun` directly.  Explicit split counts are > 1, so
      # here every split argument asks for one key per call; pass the single
      # key `split(key, 1)` produces rather than a (1, ...) batch.
      for k in split_kw:
        kwargs[k] = jax.random.split(kwargs[k], 1)[0]
      for i in split_pos:
        args[i] = jax.random.split(args[i], 1)[0]
      rv = self.fun(*args, **kwargs)
      return jtu.tree_map(
        lambda t: (jnp.expand_dims(t, 0)
                   if isinstance(t, (jax.Array, np.ndarray, np.generic))
                   else t),
        rv)

    # Split RNG keys: one per mapped call unless an explicit count was given.
    for k in split_kw:
      kwargs[k] = jax.random.split(kwargs[k], self.split_argnames[k] or reps)
      kw_axes[k] = 0
    for i in split_pos:
      args[i] = jax.random.split(args[i], self.split_argnums[i] or reps)
      in_axes[i] = 0

    return self.vmap(self._inner_fun,
                     in_axes=(None, kw_axes) + tuple(in_axes),
                     **self.vmap_kwargs)(self.fun, kwargs, *args)


class _VmapFunc:
  """A smarter replacement for `jax.vmap`.

  Compared with `jax.vmap`, the returned wrapper:

  * broadcasts arguments without needing `in_axes`: non-array leaves, and
    array leaves whose size along the requested axis is 1 (that axis is
    dropped), are passed unmapped to every call;
  * accepts `kw_axes`, the mapping axes for keyword arguments; in a dict the
    key `...` gives the default axis for keywords it does not list;
  * treats keyword options beginning with `split` (e.g. `splitkey=True`) as a
    request to call `jax.random.split` on the matching argument (e.g. `key`).

  The underlying transformation is selected with the `vmap` argument, or by
  indexing, e.g. `vmap[equinox.filter_vmap](f)`.
  """

  def __call__(self, fun, in_axes=0, kw_axes=0, vmap=jax.vmap, **vmap_kwargs):
    wrapper = _VmapWrapper.wrap(fun, in_axes, kw_axes, vmap, vmap_kwargs)
    return wrapper

  def __getitem__(self, vmap):
    # `vmap[impl]` pre-binds the underlying transformation.
    bound = partial(self, vmap=vmap)
    return bound


vmap = _VmapFunc()


class Subkeys:
  """Source of fresh PRNG subkeys derived from a single key.

  `next(subkeys)` returns a new subkey on each call (None if the wrapped key
  is None) and `subkeys * n` returns `n` subkeys split from the next one.
  Subkeys come from splitting `self.key` (never modified) into 2, 4, 8, ...
  keys; each refill hands out only keys that no earlier refill produced.

  Instances can be pickled; the restored object continues the same sequence.
  """

  def __init__(self, key=None):
    self.key = key
    self._keys = []    # unused keys of the current refill, popped from the end
    self._split_n = 2  # split size for the next refill

  @staticmethod
  def _batch(key, n):
    """Keys of the refill that splits `key` into `n` (n = 2, 4, 8, ...).

    The first refill uses all of `split(key, 2)`; later ones only the second
    half of `split(key, n)`.  With JAX's partitionable threefry
    (`jax_threefry_partitionable`, on by default in recent releases) row i
    of `split(key, n)` does not depend on n, so reusing the first half would
    return keys already handed out.
    """
    keys = jax.random.split(key, n)
    return list(keys if n == 2 else keys[n // 2:])

  def __next__(self):
    if self.key is None:
      return None
    if not self._keys:
      self._keys = self._batch(self.key, self._split_n)
      self._split_n *= 2
    return self._keys.pop()

  def __bool__(self):
    # True when wrapping an actual key.
    return self.key is not None

  def __mul__(self, num):
    """Return `num` keys split from the next subkey (None for a None key)."""
    subkey = next(self)
    return None if subkey is None else jax.random.split(subkey, num)

  # NOTE(review): pickling assumes a raw uint32 key array (e.g. from
  # `jax.random.PRNGKey`); typed keys would need `jax.random.key_data`.
  def __getstate__(self):
    key = None if self.key is None else self.key.tolist()
    return (repr(self.__class__), key, self._split_n, len(self._keys))

  def __setstate__(self, state):
    assert state[0] == repr(self.__class__)
    _, key, self._split_n, remaining = state
    if key is None:
      self.key, self._keys = None, []
      return
    self.key = jnp.array(key, dtype=jnp.uint32)
    # Rebuild the last refill (split size `_split_n // 2`); the keys not yet
    # popped are its first `remaining` entries.
    self._keys = (self._batch(self.key, self._split_n // 2)[:remaining]
                  if remaining else [])


def with_subkeys(arg, *args):
  """A smarter way to split a PRNGkey

  Examples:

  key = with_subkeys(jax.random.PRNGKey(0))
  a = jax.random.normal(next(key), (4,))
  b = jax.random.normal(next(key), (4,))
  any(a == b)  # False

  @with_subkeys
  def func(key, shape):
    a = jax.random.normal(next(key), shape)
    b = jax.random.normal(next(key), shape)
    return a, b
  a, b = func(jax.random.PRNGKey(0), (4,))
  any(a == b)  # False

  `with_subkeys(key)` wraps a key array (or None) in a `Subkeys` object.
  As a decorator, `@with_subkeys` handles the function's `key` argument and
  `@with_subkeys('k1', 'k2')` the named arguments instead.  On each call,
  key-array arguments are wrapped in `Subkeys`; a None argument (passed or
  omitted) becomes `Subkeys(None)` when the parameter's default is None;
  any other value (e.g. an existing `Subkeys`) is passed through unchanged.

  Raises:
    ValueError: if an explicitly named argument is not a parameter of the
      decorated function.
    TypeError: if `arg` is not a name, a callable, None or a key array.
  """
  if isinstance(arg, str):
    names = [arg, *args]
    def dec(f):
      return with_subkeys(f, *names)
    return dec

  if callable(arg):
    f = arg
    f_params = inspect.signature(f).parameters

    missing = [n for n in args if n not in f_params]
    if missing:
      raise ValueError(
        f"invalid argument names for `with_subkeys` decorated function: "
        f"{missing}")

    key_names = args or ('key',)
    # (name, positional index, or -1 if it can only be passed by keyword)
    key_args = [
      (p.name, i if p.kind in (p.POSITIONAL_ONLY,
                               p.POSITIONAL_OR_KEYWORD) else -1)
      for i, p in enumerate(f_params.values()) if p.name in key_names]

    def _convert(value, name):
      # Wrap key arrays; map None to Subkeys(None) only where the parameter
      # defaults to None; leave anything else untouched.
      if isinstance(value, jax.Array):
        return with_subkeys(value)
      if value is None and f_params[name].default is None:
        return with_subkeys(None)
      return value

    @wraps(f)
    def _f(*args, **kwargs):
      args = list(args)
      for n, i in key_args:
        if 0 <= i < len(args):
          args[i] = _convert(args[i], n)
        elif n in kwargs:
          kwargs[n] = _convert(kwargs[n], n)
        elif f_params[n].default is None:
          # Omitted argument whose default is None.
          kwargs[n] = with_subkeys(None)
      return f(*args, **kwargs)

    return _f

  if arg is None or isinstance(arg, jax.Array):
    return Subkeys(arg)

  raise TypeError('argument not a PRNGKey')


class with_sharing:
  """Decorator for passing a PyTree with sharing through a jax transformation

  JAX transformations assume an input PyTree is a tree in the strict sense:
  its leaves are distinct objects.  When the same jax.Array appears several
  times in a PyTree, copies are made within the transformation.

  Wrapping the transformation with this decorator (e.g.
  `@with_sharing(jax.jit)`) replaces every repeated leaf with a `_ShareIndex`
  pointing at its first occurrence, and restores the shared references on
  both sides of the transformation, so the sharing structure is kept.

  NOTE(review): the `_ShareIndex` markers are `int` subclasses passed through
  the transformation as leaves; whether they arrive unchanged (rather than
  converted to arrays) depends on the wrapped transformation -- confirm for
  the transformations this is used with.
  """

  class _ShareIndex(int):
    """Marker leaf meaning "same object as the leaf at this flat index"."""

    def __repr__(self):
      return f'_ShareIndex({int(self)})'

  @classmethod
  def _tree_unshare(cls, x):
    """Replace each repeated leaf (by identity) with a `_ShareIndex`."""
    leaves, treedef = jtu.tree_flatten(x)
    first_seen = {}
    encoded = []
    for index, leaf in enumerate(leaves):
      if id(leaf) in first_seen:
        encoded.append(cls._ShareIndex(first_seen[id(leaf)]))
      else:
        first_seen[id(leaf)] = index
        encoded.append(leaf)
    return jtu.tree_unflatten(treedef, encoded)

  @classmethod
  def _tree_share(cls, x):
    """Inverse of `_tree_unshare`: resolve each `_ShareIndex` to its leaf."""
    leaves, treedef = jtu.tree_flatten(x)
    marker = cls._ShareIndex
    resolved = [leaves[leaf] if isinstance(leaf, marker) else leaf
                for leaf in leaves]
    return jtu.tree_unflatten(treedef, resolved)

  def __init__(self, dec):
    # The transformation to wrap, e.g. `jax.jit`.
    self.dec = dec

  def __call__(self, f):
    share, unshare = self._tree_share, self._tree_unshare

    def _f_inner(*args, **kwargs):
      # Inside the transformation: rebuild sharing, run f, encode the result.
      args = tuple(map(share, args))
      kwargs = {name: share(value) for name, value in kwargs.items()}
      return unshare(f(*args, **kwargs))

    transformed = self.dec(_f_inner)

    @wraps(f)
    def _f_outer(*args, **kwargs):
      # Outside: encode the inputs' sharing, decode the output's sharing.
      args = tuple(map(unshare, args))
      kwargs = {name: unshare(value) for name, value in kwargs.items()}
      return share(transformed(*args, **kwargs))

    return _f_outer


def tree_flatten_once(t):
  """Flatten a PyTree by exactly one level.

  Every node below the root is treated as a leaf, so the root's immediate
  children come back unflattened.

  Returns:
    `(children, treedef)` where `treedef` describes only the root node.
  """
  def _below_root(node):
    return node is not t
  return jtu.tree_flatten(t, is_leaf=_below_root)


def tree_transpose(t):
  """A smarter replacement for `jax.tree_util.tree_transpose`

  Transposes the top two "axes" of the PyTree: `t` is an outer node whose
  children all share the same inner node structure, and the result is that
  inner structure holding outer-structured nodes, e.g. `[(1, 2), (3, 4)]`
  becomes `([1, 3], [2, 4])`.  Unlike the builtin `tree_transpose`, the
  PyTreeDefs are inferred from the input.

  Raises:
    AssertionError: if the children do not all share the same top-level
      PyTreeDef.
  """
  outer_children, outer = tree_flatten_once(t)
  inner_children, inners = zip(*map(tree_flatten_once, outer_children))

  assert len(set(inners)) == 1, \
    f'Not all children have the same PyTreeDef:\n{set(inners)}'

  # Group the i-th child of every outer child, reassembled with the outer
  # structure, then place the groups into the shared inner structure.
  regrouped = [jtu.tree_unflatten(outer, group)
               for group in zip(*inner_children)]
  return jtu.tree_unflatten(inners[0], regrouped)


def tree_zip(trees, axis=0, is_leaf=None):
  """Stack a sequence of identically structured PyTrees leaf by leaf.

  Returns one PyTree with the common structure whose leaves are
  `jnp.stack(...)` (along `axis`) of the corresponding leaves of `trees`.
  All trees must share the same PyTreeDef (checked with an assert), and
  corresponding leaves must be arrays of the same shape.
  """
  flat_trees, treedefs = zip(*(jtu.tree_flatten(tree, is_leaf)
                               for tree in trees))
  first_def = treedefs[0]
  assert all(td == first_def for td in treedefs[1:])
  # zip(*flat_trees) gathers the i-th leaf of every tree.
  stacked = [jnp.stack(column, axis=axis) for column in zip(*flat_trees)]
  return jtu.tree_unflatten(first_def, stacked)


def tree_unzip(tree, axis=0, is_leaf=None):
  """Split a PyTree of stacked arrays into a list of PyTrees.

  The i-th returned PyTree holds, in place of each leaf, that leaf's slice at
  index i along `axis`.  All leaves must be arrays with the same length on
  `axis` (with differing lengths, the list stops at the shortest).
  """
  leaves, treedef = jtu.tree_flatten(tree, is_leaf)
  leading = [jnp.moveaxis(leaf, axis, 0) for leaf in leaves]
  return [jtu.tree_unflatten(treedef, row) for row in zip(*leading)]


def tree_diff(trees, distinguish_numpy=True, distinguish_weak_type=True):
  """Looks for differences between a list of PyTrees

  Returns a PyTree whose leaves are sets containing the different values of
  that leaf found across the trees, or a set of the different PyTreeDefs
  where the trees differ in structure.

  For any part of the tree (possibly the entire tree) where there are no
  differences, the value is None.

  Array leaves are compared only by dtype and shape (and optionally by
  whether they are numpy rather than JAX arrays, and whether they are weakly
  typed); they appear in the result as short strings such as `'f32[3]'`.
  """
  def _type_and_shape_if_array(a):
    if isinstance(a, (jax.Array, np.ndarray, np.generic)):
      s = f"{a.dtype}{list(a.shape)}"
      s = s.replace('float', 'f').replace('int', 'i').replace('complex', 'c')
      s = s.replace(' ', '')
      if distinguish_numpy and isinstance(a, (np.ndarray, np.generic)):
        s = 'np.' + s
      if distinguish_weak_type and isinstance(a, jax.Array) and a.weak_type:
        s = s + 'w'
      return s
    else:
      return a

  trees = jtu.tree_map(_type_and_shape_if_array, trees)

  if all(tree == trees[0] for tree in trees):
    return None

  vs, ts = zip(*(tree_flatten_once(tree) for tree in trees))

  # Different node types or arities at this level: report the structures.
  if any(t != ts[0] for t in ts):
    return set(ts)

  # Each tree is a leaf (flattening once yields the tree itself): report the
  # distinct values.  `trees` may be a list at the top level, so compare as
  # tuples.
  if vs[0] and tuple(v[0] for v in vs) == tuple(trees):
    return {v[0] for v in vs}

  diff_leaves = [tree_diff(v, distinguish_numpy, distinguish_weak_type)
                 for v in zip(*vs)]

  # This covers the case where an object's equality is not consistent with
  # equality over its flattened representation (e.g., jax.tree_util.Partial)
  if all(l is None for l in diff_leaves):
    return None

  return jtu.tree_unflatten(ts[0], diff_leaves)


def tree_error_if_all_not_equal(vs):
  """Checks that a list of PyTrees are all equal

  Equality is judged by `tree_diff` with its default options, so array
  leaves are compared by dtype and shape only.

  Raises:
    ValueError: listing the differences reported by `tree_diff`.
  """
  diff = tree_diff(vs)

  # tree_diff returns None when it finds no differences; test for that
  # explicitly instead of relying on the truthiness of an arbitrary PyTree.
  if diff is not None:
    msg = 'PyTrees not all equal; found the following differences:\n'
    msg += repr(diff).replace('\n', '\n\t')
    raise ValueError(msg)
