import elegy as eg
import jax
from jax import numpy as jn


def laplacian(fn):
    """Return a function computing the Laplacian (trace of the Hessian) of ``fn``.

    ``fn`` must produce exactly one output element (e.g. shape ``()`` or
    ``(1,)``). Its input ``x`` may have any shape, including a scalar: the
    input is treated as a flat vector of ``size(x)`` coordinates.

    Args:
        fn: Function of a single array argument returning a single value.

    Returns:
        A function ``x -> lap`` where ``lap`` has shape ``(1,)``.
    """
    hess = jax.hessian(fn)

    def lap(x):
        n = jn.size(x)
        # Flatten the Hessian to an (n, n) matrix so inputs of any rank are
        # supported; reshaping to ``x.shape * 2`` only yields a 2-D matrix
        # (which ``jn.diag`` requires) when ``x`` itself is 1-D.
        h = hess(x).reshape(n, n)
        return jn.diagonal(h).sum(keepdims=True)

    return lap


def poisson(u, x, *args):
    """Evaluate the Laplacian of ``u`` at every point of the batch ``x``.

    Args:
        u: Model called as ``u(point, *args)``.
        x: Batch of points; the Laplacian is mapped over the leading axis.
        *args: Extra positional arguments forwarded to ``u``.

    Returns:
        One Laplacian value per point, stacked along the leading axis.
    """
    def field(point):
        return u(point, *args)

    return jax.vmap(laplacian(field))(x)


def helmholtz(u, x, *args, k_sqr=1):
    """Evaluate the Helmholtz operator ``lap(u) + k_sqr * u`` on the batch ``x``.

    Args:
        u: Model called as ``u(points, *args)``.
        x: Batch of points; the Laplacian is mapped over the leading axis.
        *args: Extra positional arguments forwarded to ``u``.
        k_sqr: Coefficient of the zeroth-order term.

    Returns:
        The per-point operator value.
    """
    def field(points):
        return u(points, *args)

    second_order = jax.vmap(laplacian(field))(x)
    # NOTE: ``u`` is evaluated twice, once inside the Hessian computation and
    # once below for the zeroth-order term. Sharing that forward pass would
    # require computing the value and the Hessian together.
    return second_order + k_sqr * field(x)


def reaction_diffusion(u, x, *args):
    """Evaluate ``lap(u) + u * (1 - u)`` on the batch ``x``.

    Args:
        u: Model called as ``u(points, *args)``.
        x: Batch of points; the Laplacian is mapped over the leading axis.
        *args: Extra positional arguments forwarded to ``u``.

    Returns:
        The per-point sum of the diffusion and reaction terms.
    """
    def field(points):
        return u(points, *args)

    values = field(x)
    diffusion = jax.vmap(laplacian(field))(x)
    return diffusion + values * (1 - values)


def dirichlet(u, x, n, *args):
    """Dirichlet boundary operator: the value of ``u`` at the points ``x``.

    ``n`` is accepted but unused, so this has the same call signature as
    ``neumann``.
    """
    del n  # unused; kept for a uniform boundary-operator signature
    return u(x, *args)


def neumann(u, x, n, *args):
    """Neumann boundary operator: the derivative of ``u`` at ``x`` along ``n``.

    Computed as a forward-mode Jacobian-vector product with tangent ``n``.
    For a batch ``x``, this gives a per-point directional derivative only if
    ``u`` treats batch rows independently (presumably true for the models
    used here; confirm). ``n`` is presumably the boundary normal.

    Args:
        u: Model called as ``u(points, *args)``.
        x: Boundary points.
        n: Tangent direction(s); must match ``x`` in shape and dtype.
        *args: Extra positional arguments forwarded to ``u``.

    Returns:
        The tangent output of the JVP.
    """
    return jax.jvp(lambda pts: u(pts, *args), (x,), (n,))[1]


class PDE(eg.Module):
    """Bundle a solution model ``u`` with a PDE residual and a boundary operator.

    Calling the module returns a dict containing the interior residual under
    ``"pde"`` and, when boundary inputs are supplied, the boundary operator
    output under ``"bc"``.
    """

    u: eg.Module

    def __init__(self, u, pde, bc):
        """
        Args:
            u: Model approximating the solution.
            pde: Residual function called as ``pde(u, *x_int)``.
            bc: Boundary operator called as ``bc(u, *x_bc)``.
        """
        self.u = u
        self.pde = pde
        self.bc = bc

    def __call__(self, x_int, x_bc=None):
        """Evaluate the interior residual and, if requested, the boundary term.

        Args:
            x_int: Interior inputs. A tuple is unpacked as positional
                arguments to ``pde``; anything else is passed as one argument.
            x_bc: Boundary inputs, unpacked the same way into ``bc``. When
                ``None`` (the default), the boundary operator is skipped.
                Calling it with ``None`` would pass ``(u, None)``, which lacks
                the ``n`` argument that ``dirichlet``/``neumann`` require.

        Returns:
            ``{"pde": ...}`` plus ``"bc": ...`` when ``x_bc`` is given.
        """
        if not isinstance(x_int, tuple):
            x_int = (x_int,)
        outputs = {"pde": self.pde(self.u, *x_int)}
        if x_bc is not None:
            if not isinstance(x_bc, tuple):
                x_bc = (x_bc,)
            outputs["bc"] = self.bc(self.u, *x_bc)
        return outputs
