from functools import partial
import operator

from chex import ArrayTree
import jax
from jax import jit, numpy as jnp


@jit
def soft_update(params: ArrayTree, target_params: ArrayTree, tau: float):
    """Blend two pytrees leaf by leaf as ``tau * params + (1 - tau) * target``.

    Args:
        params: pytree whose leaves are weighted by ``tau``.
        target_params: pytree with the same structure, weighted by ``1 - tau``.
        tau: mixing coefficient (traced by ``jit``).

    Returns:
        A pytree with the structure of ``params`` holding the blended leaves.
    """

    def _blend(online, target):
        return tau * online + (1.0 - tau) * target

    return jax.tree.map(_blend, params, target_params)


@jit
def tree_add(tree0: ArrayTree, tree1: ArrayTree, *, alpha: float = 1):
    """Compute ``tree0 + alpha * tree1`` leaf by leaf.

    Args:
        tree0: base pytree.
        tree1: pytree with the same structure, scaled by ``alpha`` before adding.
        alpha: keyword-only scale factor for ``tree1`` (default 1).

    Returns:
        A pytree with the structure of ``tree0``.
    """

    def _axpy(base, incr):
        return base + alpha * incr

    return jax.tree.map(_axpy, tree0, tree1)


@jit
def tree_inner(tree0: ArrayTree, tree1: ArrayTree):
    """Return the inner product of two pytrees viewed as flat vectors.

    Each pair of corresponding leaves is flattened and combined with
    ``jnp.inner`` (no complex conjugation is applied). The per-leaf results
    are then summed.

    Args:
        tree0: first pytree.
        tree1: second pytree; ``jax.tree.map`` requires a matching structure.

    Returns:
        A scalar array. An empty pytree yields 0.
    """

    def _leaf_inner(lhs, rhs):
        return jnp.inner(jnp.ravel(lhs), jnp.ravel(rhs))

    # Mapping (rather than zipping leaves) keeps jax.tree.map's structure check.
    per_leaf = jax.tree.leaves(jax.tree.map(_leaf_inner, tree0, tree1))
    total = 0
    for value in per_leaf:
        total = total + value.sum()
    return total


@jit
def tree_norm(tree: ArrayTree, p: float):
    """Return the entrywise p-norm of a pytree viewed as one flat vector.

    Computes ``(sum_i |x_i| ** p) ** (1 / p)`` over every element of every
    leaf.

    Args:
        tree: pytree of numeric arrays.
        p: finite, positive exponent. It is passed through ``jit`` as a
            traced (non-static) value, so ``p == inf`` is not supported.

    Returns:
        A scalar array. An empty pytree yields 0.
    """
    # The absolute value is essential. Without it, negative entries cancel
    # positive ones for odd p (e.g. p=1) and give NaN for non-integer p.
    total = sum(jnp.power(jnp.abs(leaf), p).sum() for leaf in jax.tree.leaves(tree))
    return jnp.power(total, 1 / p)


@jit
def tree_ones_like(tree: ArrayTree) -> ArrayTree:
    """Build a pytree shaped like ``tree`` whose leaves are all ones.

    Each leaf is produced by ``jnp.ones_like``, so it keeps that leaf's shape
    and dtype.
    """

    def _ones(leaf):
        return jnp.ones_like(leaf)

    return jax.tree.map(_ones, tree)


@jit
def tree_scale(tree: ArrayTree, scale: float) -> ArrayTree:
    """Multiply every leaf of ``tree`` by ``scale``.

    The product is formed as ``scale * leaf`` and the tree structure is
    preserved.
    """
    return jax.tree.map(lambda leaf: scale * leaf, tree)


def tree_shape(tree: ArrayTree):
    """Return a pytree with each leaf replaced by its shape tuple.

    Deliberately not wrapped in ``jit``. Shapes are static metadata, and
    routing them through ``jit`` turned every dimension into a JAX array.
    When called from inside another jitted function, ``jit`` made them
    abstract tracers, which cannot be used as shapes (e.g. in ``jnp.zeros``).
    Without ``jit`` the dimensions are plain Python ints. Calling this from
    traced code still works because shapes are known at trace time.

    Note: each shape is a tuple, which JAX treats as a pytree node, so the
    result's tree structure is not the same as ``tree``'s.
    """
    return jax.tree.map(jnp.shape, tree)


@jit
def tree_zeros_like(tree: ArrayTree) -> ArrayTree:
    """Build a pytree shaped like ``tree`` whose leaves are all zeros.

    Each leaf is produced by ``jnp.zeros_like``, so it keeps that leaf's shape
    and dtype.
    """

    def _zeros(leaf):
        return jnp.zeros_like(leaf)

    return jax.tree.map(_zeros, tree)
