from functools import partial

import jax

# Thin aliases over jax.numpy element-wise and comparison helpers.
add, zeros_like = jax.numpy.add, jax.numpy.zeros_like
# ``equal`` is whole-array equality (one boolean result), not element-wise.
equal = jax.numpy.array_equal
# NOTE: ``min``/``max`` shadow the builtins inside this module; they are the
# element-wise binary minimum/maximum, not reductions over one array.
min, max = jax.numpy.minimum, jax.numpy.maximum
# Loose comparison: relative and absolute tolerance are both 5e-3.
allclose = partial(jax.numpy.allclose, atol=5e-3, rtol=5e-3)


def concatenate(tensors, dim=0):
    """Join a sequence of arrays along an existing axis.

    Args:
        tensors: sequence of arrays whose shapes match except along ``dim``.
        dim: axis along which to join (forwarded as ``axis``).

    Returns:
        The concatenated array.
    """
    axis = dim
    joined = jax.numpy.concatenate(tensors, axis=axis)
    return joined


def chunk(input, chunks, dim=0):
    """Split ``input`` into ``chunks`` pieces along ``dim``.

    Args:
        input: array to split.
        chunks: number of pieces to produce.
        dim: axis along which to split (forwarded as ``axis``).

    Returns:
        A list of sub-arrays from ``jax.numpy.array_split``.

    NOTE(review): ``array_split`` always returns exactly ``chunks`` pieces and
    spreads any remainder over the leading pieces (sizes differ by at most
    one, possibly empty). ``torch.chunk`` instead uses ceil-sized pieces and
    may return fewer. The two agree only when the axis length divides evenly.
    Confirm callers never depend on the torch behaviour for uneven lengths.
    """
    axis = dim
    pieces = jax.numpy.array_split(input, chunks, axis=axis)
    return pieces


def narrow(input, dim, start, length):
    """Return the window ``[start, start + length)`` of ``input`` along ``dim``.

    Follows ``torch.narrow`` semantics: ``dim`` and ``start`` may be negative
    (counted from the end of the shape / axis), and the requested window must
    lie entirely inside the axis.

    Args:
        input: array to slice.
        dim: axis to narrow.
        start: first index of the window along ``dim``.
        length: number of elements to keep (may be 0).

    Returns:
        An array shaped like ``input`` except that axis ``dim`` has size
        ``length``.

    Raises:
        IndexError: if ``dim`` is not a valid axis of ``input``.
        ValueError: if ``length`` is negative or the window does not fit.
    """
    shape = jax.numpy.shape(input)
    size = shape[dim]  # raises IndexError for an invalid dim
    axis = dim % len(shape)

    begin = start + size if start < 0 else start
    if length < 0 or not 0 <= begin <= size or begin + length > size:
        raise ValueError(
            f"narrow: start={start}, length={length} does not fit in "
            f"dimension {dim} of size {size}"
        )
    # A static slice rather than an index-array gather: jnp.take's default
    # out-of-bounds mode silently fills invalid rows (NaN / sentinel values),
    # and a negative start with a long length would wrap around the axis.
    return jax.lax.slice_in_dim(input, begin, begin + length, axis=axis)


# Pytree flattening, re-exported unchanged; returns ``(leaves, treedef)``.
tree_flatten = jax.tree_util.tree_flatten

# Array type alias.
Tensor = jax.Array


def tree_unflatten(values, spec):
    """Rebuild a pytree from its leaves and tree definition.

    Takes ``(leaves, treedef)`` in that order and swaps them to match the
    ``(treedef, leaves)`` signature of ``jax.tree_util.tree_unflatten``.

    Args:
        values: flat sequence of leaves.
        spec: tree definition, e.g. as returned by ``tree_flatten``.

    Returns:
        The reconstructed pytree.
    """
    treedef, leaves = spec, values
    return jax.tree_util.tree_unflatten(treedef, leaves)


# Aliases for copying an array and converting a NumPy array to a JAX array.
# ``jax.numpy.array`` copies its input by default, so ``from_numpy`` does
# not share memory with the source NumPy array.
clone, from_numpy = jax.numpy.copy, jax.numpy.array
