"""
Contains implementations of differential operators.

"""
import jax.numpy as jnp
from jax import hessian
from jax import jacrev
from jax import grad

from jaxtyping import Array, Float, jaxtyped, PyTree
from typeguard import typechecked as typechecker
from typing import Callable

@jaxtyped
# NOTE: runtime checking via typeguard's typechecker is currently disabled.
def del_i(
    func: Callable[[Float[Array, "d"]], Float[Array, ""]],
    argnum: int = 0,
) -> Callable[[Float[Array, "d"]], Float[Array, ""]]:
    """
    Build the partial derivative d func / d x_argnum.

    ``func`` maps a vector of shape (d,) to a scalar. The returned callable
    has the same signature and evaluates the partial derivative of ``func``
    with respect to component ``argnum`` of its input.
    """
    def _from_components(*components):
        # Reassemble the individual scalar components into the vector that
        # ``func`` expects, so jacrev can differentiate one component alone.
        return func(jnp.array(components))

    partial_wrt_component = jacrev(_from_components, argnum)

    def dfunc_di(x: Float[Array, "d"]) -> Float[Array, ""]:
        # Unpack the vector into d positional scalars for the split function.
        return partial_wrt_component(*x)

    return dfunc_di

# NOTE: this builds the full Hessian with respect to argument ``argnum``; for an
# argument with n entries that is an (..., n, n) array. A matrix-free variant
# (summing Hessian-vector products) could avoid storing it — TODO if needed.
def laplace(func, argnum=0):
    """
    Computes laplacian of func with respect to the argument argnum.

    The Laplacian is the trace of the Hessian with respect to argument
    ``argnum``, i.e. the sum over every entry x_i of that argument of
    d^2 func / d x_i^2.

    Parameters
    ----------
    func: Callable
        Function whose laplacian should be computed. It may be scalar- or
        array-valued; for array-valued functions the Laplacian is taken
        componentwise, so the result has the shape of ``func``'s output.

    argnum: int
        Index of the positional argument wrt which the laplacian should be
        computed. That argument must be a (scalar or n-dimensional) array,
        not a pytree.

    Returns
    -------
    Callable taking the same arguments as func and returning an array with
    the same shape as func's output.
    """
    hesse = hessian(func, argnum)

    def laplace_func(*args, **kwargs):
        H = hesse(*args, **kwargs)
        x = args[argnum]
        n_entries = jnp.size(x)
        # jax.hessian returns shape out_shape + x.shape + x.shape. Collapse
        # both copies of x.shape into one axis each so the trace sums over
        # every entry of x. This is a no-op for the usual 1-D x, and makes
        # scalar and matrix-shaped x work as well.
        out_shape = H.shape[: H.ndim - 2 * jnp.ndim(x)]
        H = H.reshape(out_shape + (n_entries, n_entries))
        return jnp.trace(H, axis1=-2, axis2=-1)

    return laplace_func

# Intended for vector fields (d,) -> (d,).
def symgrad(func, argnum=0):
    """
    Symmetric gradient (symmetric part of the Jacobian) of func.

    Returns a callable with the same arguments as ``func`` that evaluates
    0.5 * (J + J^T), where J is the Jacobian of ``func`` with respect to
    argument ``argnum`` (reverse-mode via ``jacrev``).
    """
    jac = jacrev(func, argnum)

    def eps_u(*args, **kwargs):
        Du = jac(*args, **kwargs)
        symmetrised = Du + jnp.transpose(Du)
        return 0.5 * symmetrised

    return eps_u

# Intended for vector fields (d,) -> (d,) and tensor fields (d,) -> (d, d).
def div(func, argnum=0):
    """
    Divergence of func with respect to the argument argnum.

    Returns a callable with the same arguments as ``func``. It computes the
    Jacobian of ``func`` with respect to argument ``argnum`` and contracts its
    last two axes. For a vector field this gives sum_j d f_j / d x_j; for a
    tensor field it gives the row-wise divergence sum_j d F_ij / d x_j.
    """
    jac = jacrev(func, argnum)

    def div_f(*args, **kwargs):
        # Jacobian has shape (d, d) for a vector field, (d, d, d) for a tensor
        # field. NOTE: the full Jacobian is formed only to read its diagonal.
        return jnp.trace(jac(*args, **kwargs), axis1=-2, axis2=-1)

    return div_f