import jax
import jax.numpy as jnp
import numpy as onp
from numpy.typing import ArrayLike, NDArray

# Named aliases for numpy arrays, presumably for use in type annotations elsewhere.
# Every alias is the plain ``onp.ndarray`` type, so the suffix is documentation only
# and is never checked.
# NOTE(review): the suffixes look like intended shapes (N, Nx3, BxB, 2xBxB, BxBxBxB)
# -- confirm against the call sites.
FloatN = onp.ndarray
FloatNx3 = onp.ndarray
FloatBxB = onp.ndarray
Float2xBxB = onp.ndarray
FloatBxBxBxB = onp.ndarray


def relative_error(value: ArrayLike, target: ArrayLike, eps=1e-15) -> jax.Array:
    """Element-wise relative error ``|value - target| / (|target| + eps)``.

    ``eps`` keeps the quotient finite where ``target`` is zero.
    """
    numerator = jnp.abs(value - target)  # type: ignore
    denominator = jnp.abs(target) + eps
    return numerator / denominator


def is_close(value: jax.Array, target: jax.Array, tolerance: float, absolute=False):
    """
    Return True when every element's error is strictly below ``tolerance``.

    With ``absolute=True`` the error is ``|value - target|``; otherwise it is the
    relative error computed by ``relative_error``. The result is a JAX boolean scalar.
    """
    error = jnp.abs(value - target) if absolute else relative_error(value, target)
    return jnp.all(error < tolerance)


def assert_is_close(
    value: jax.Array | NDArray,
    target: jax.Array | NDArray,
    mask: jax.Array | None = None,
    tolerance: float = 1e-10,
    absolute: bool = False,
    name='',
):
    """
    Asserts that two scalars or arrays are close element-wise.

    If mask is not None, only compare the elements where mask is True. A mask that
    selects no elements makes the check pass trivially. A mask cannot be combined
    with a single-element (scalar) ``value``.
    "tolerance" is the maximum relative error allowed if absolute is False,
    otherwise it is the maximum absolute error allowed.

    On failure, the message reports the element with the largest error of the
    enforced kind, showing both its relative and its absolute error.
    """
    assert (
        value.shape == target.shape
    ), f'Shapes do not match: {value.shape} != {target.shape}'
    if mask is not None:
        # Judge "scalar" on the unmasked input: a mask that happens to select a
        # single element of a larger array is a legitimate use.
        assert value.size != 1, 'Mask cannot be used with scalar values'
        value = value[mask]
        target = target[mask]
    # Flatten so scalars and arrays of any rank share one path; indexing a 1-D
    # array yields 0-d elements, which support the ``:.2e`` format below.
    rel_err = jnp.ravel(relative_error(value, target))
    abs_err = jnp.ravel(jnp.abs(value - target))
    if rel_err.size == 0:
        # Nothing selected (or empty input): nothing to compare, and argmax
        # would fail on an empty array.
        return
    # Pick the worst element according to the metric being enforced.
    worst = jnp.argmax(abs_err if absolute else rel_err)
    worst_rel = rel_err[worst]
    worst_abs = abs_err[worst]
    if absolute:
        assert worst_abs < tolerance, f'{name}: max (rel) / abs error: ({worst_rel:.2e}) / {worst_abs:.2e}'
    else:
        assert worst_rel < tolerance, f'{name}: max rel / (abs) error: {worst_rel:.2e} / ({worst_abs:.2e})'


def assert_either_abs_or_rel_close(
    value: jax.Array | NDArray,
    target: jax.Array | NDArray,
    mask: jax.Array | None = None,
    relative_tolerance: float = 1e-10,
    absolute_tolerance: float = 1e-14,
    name='',
):
    """
    Asserts that two scalars or arrays are close element-wise, either by relative or absolute tolerance.

    Every compared element must satisfy
    ``rel_error < relative_tolerance or abs_error < absolute_tolerance``.
    An element whose error is NaN fails both comparisons. If a `mask` is provided,
    only the elements where the mask is True are compared. A mask that selects
    nothing makes the check pass trivially.

    Parameters:
    - value (jax.Array | NDArray): The array of values to compare.
    - target (jax.Array | NDArray): The target array to compare against.
    - mask (jax.Array | None, optional): A boolean array indicating which elements to compare. Defaults to None.
    - relative_tolerance (float, optional): The maximum allowed relative error. Defaults to 1e-10.
    - absolute_tolerance (float, optional): The maximum allowed absolute error. Defaults to 1e-14.
    - name (str, optional): A name to include in the assertion error message for identification. Defaults to ''.

    Raises:
    - AssertionError: If the shapes of `value` and `target` do not match.
    - AssertionError: If any compared element exceeds both tolerances. The message
      reports the failing element with the largest relative error.
    - AssertionError: If a mask is provided for a single-element (scalar) `value`.
    """
    assert (
        value.shape == target.shape
    ), f'Shapes do not match: {value.shape} != {target.shape}'
    if mask is not None:
        # Judge "scalar" on the unmasked input: a mask that happens to select a
        # single element of a larger array is a legitimate use.
        assert value.size != 1, 'Mask cannot be used with scalar values'
        value = value[mask]
        target = target[mask]
    # Flatten so scalars and arrays of any rank share one path.
    rel_err = jnp.ravel(relative_error(value, target))
    abs_err = jnp.ravel(jnp.abs(value - target))
    # Check EVERY element. Checking only the argmax-rel and argmax-abs elements
    # misses elements that are not extreme in either metric but still fail both.
    within = (rel_err < relative_tolerance) | (abs_err < absolute_tolerance)
    if bool(jnp.all(within)):  # also True when nothing is compared
        return
    # Among failing elements, report the one with the largest relative error.
    # jnp.argmax treats NaN as the maximum, so NaN failures are reported.
    worst = jnp.argmax(jnp.where(within, -jnp.inf, rel_err))
    raise AssertionError(
        f'{name}: max rel / (abs) error: {rel_err[worst]:.2e} / ({abs_err[worst]:.2e})'
    )
