import functools
from typing import Callable, List, Optional

import chex
import jax
import jax.numpy as jnp

# Signature shared by kernel functions: a callable that takes a pair of
# arrays and returns a scalar.
Kernel = Callable[[chex.Array, chex.Array], chex.Scalar]


@functools.partial(jax.custom_jvp, nondiff_argnums=(1,))
def norm_alpha(x: chex.Array, alpha: float) -> float:
    """Return the Euclidean norm of ``x`` raised to the power ``alpha``.

    The value is computed as ``sum(x ** 2) ** (alpha / 2)``, reducing over
    every element of ``x`` regardless of its shape.

    Args:
        x: Input array.
        alpha: Exponent applied to the norm. Marked non-differentiable via
            ``nondiff_argnums``.

    Returns:
        A scalar holding ``||x|| ** alpha``.
    """
    return jnp.square(x).sum() ** (alpha / 2)


@norm_alpha.defjvp
def _norm_alpha_dx(alpha: float, primal: tuple, tangent: tuple) -> tuple:
    """JVP rule for ``norm_alpha`` that stays finite when ``x`` is all zeros.

    The directional derivative is
    ``alpha * ||x|| ** (alpha - 2) * <x, x_dot>``; when every entry of ``x``
    is zero the tangent is defined to be ``0`` instead.

    Args:
        alpha: Exponent of the norm (non-differentiable argument).
        primal: 1-tuple holding the input ``x``.
        tangent: 1-tuple holding the tangent ``x_dot`` for ``x``.

    Returns:
        A pair ``(primal_out, tangent_out)``.
    """
    (x,) = primal
    (x_dot,) = tangent
    # True only when every entry of x is zero.
    grad_problem = jnp.all(jnp.logical_not(x))
    # Substitute a harmless non-zero point so the formula below never
    # evaluates a negative power of zero; that branch is discarded at the end.
    x_processed = jnp.where(grad_problem, jnp.ones_like(x), x)
    # Elementwise product summed over all entries: unlike ``.dot`` this is the
    # inner product for inputs of any rank, matching the primal's reduction.
    inner = jnp.sum(x_processed * x_dot)
    grad = alpha * (jnp.linalg.norm(x_processed) ** (alpha - 2)) * inner
    ans = norm_alpha(x, alpha)
    return ans, jnp.where(grad_problem, 0.0, grad)


def energy_distance(
    x: chex.Array, y: chex.Array, alpha: float = 1.0, zero: Optional[chex.Array] = None
) -> float:
    """Evaluate the ``norm_alpha``-based similarity between ``x`` and ``y``.

    With no reference point the result is ``-norm_alpha(x - y, alpha)``.
    When ``zero`` is given the result is
    ``norm_alpha(x - zero) + norm_alpha(y - zero) - norm_alpha(x - y)``.

    Args:
        x: First array.
        y: Second array.
        alpha: Exponent forwarded to ``norm_alpha``.
        zero: Optional reference point subtracted from both inputs.

    Returns:
        The scalar described above.
    """
    pairwise = norm_alpha(x - y, alpha)
    if zero is not None:
        from_zero_x = norm_alpha(x - zero, alpha)
        from_zero_y = norm_alpha(y - zero, alpha)
        return from_zero_x + from_zero_y - pairwise
    return -pairwise


@jax.custom_jvp
def l1_norm(x: chex.Array) -> float:
    """Return the sum of absolute values of every element of ``x``."""
    return jnp.sum(jnp.abs(x))


@l1_norm.defjvp
def _l1_dx(primal: tuple, tangent: tuple) -> tuple:
    """JVP rule for ``l1_norm`` using ``sign(x)`` as the gradient.

    Args:
        primal: 1-tuple holding the input ``x``.
        tangent: 1-tuple holding the tangent ``x_dot`` for ``x``.

    Returns:
        A pair ``(primal_out, tangent_out)`` where the tangent is the sum of
        ``sign(x) * x_dot`` over all entries.
    """
    (x,) = primal
    (x_dot,) = tangent
    grad = jnp.sign(x)
    # Elementwise product summed over all entries: unlike ``.dot`` this is the
    # inner product for inputs of any rank, matching the primal's reduction.
    return l1_norm(x), jnp.sum(grad * x_dot)


def l1(x: chex.Array, y: chex.Array) -> float:
    """Return the negated ``l1_norm`` of the difference ``x - y``."""
    difference = x - y
    return -l1_norm(difference)


def gaussian_kernel(x: chex.Array, y: chex.Array, bandwidth: float = 1.0) -> float:
    """Evaluate ``exp(-sum((x - y) ** 2) / (2 * bandwidth))``.

    Args:
        x: First array.
        y: Second array.
        bandwidth: Divisor scale; it enters the exponent as ``2 * bandwidth``
            (it is not squared here).

    Returns:
        The scalar kernel value.
    """
    difference = x - y
    squared_distance = jnp.square(difference).sum()
    scale = 2 * bandwidth
    return jnp.exp(-squared_distance / scale)


def imq_kernel(
    x: chex.Array, y: chex.Array, bandwidth: float = 1.0, alpha: float = 0.5
) -> float:
    """Evaluate ``(1 + sum((x - y) ** 2) / bandwidth) ** (-alpha)``.

    Args:
        x: First array.
        y: Second array.
        bandwidth: Divisor applied to the squared distance.
        alpha: Positive exponent; the base is raised to ``-alpha``.

    Returns:
        The scalar kernel value.
    """
    squared_distance = jnp.sum(jnp.square(x - y))
    base = 1 + squared_distance / bandwidth
    return jnp.power(base, -alpha)


def kernel_matrix(k: Kernel, xs: List[chex.Array], ys: List[chex.Array]) -> chex.Array:
    """Evaluate ``k`` on every pair drawn from ``xs`` and ``ys``.

    Both inputs are mapped over their leading axis, so entry ``[i, j]`` of the
    result is ``k(xs[i], ys[j])``.

    Args:
        k: Kernel applied to a single pair of points.
        xs: Batch of first arguments, indexed along axis 0.
        ys: Batch of second arguments, indexed along axis 0.

    Returns:
        Array whose first axis follows ``xs`` and second axis follows ``ys``.
    """
    # Inner map: hold the first argument fixed and sweep over ys.
    against_all_ys = jax.vmap(k, in_axes=(None, 0))
    # Outer map: sweep over xs while passing ys through unbatched.
    all_pairs = jax.vmap(against_all_ys, in_axes=(0, None))
    return all_pairs(xs, ys)
