import jax
import jax.numpy as jnp

# Import the compiled extension module.
from jax_libccd import _gjk_epa_module

# Register the FFI handler with JAX.
# The extension module returns a dict mapping "compute_penetration" to a nanobind capsule.
# Every handler the extension exposes is registered under the name the
# extension chose, so ffi_call can later look it up by that same string.
# No platform is passed, so JAX's default platform for FFI targets applies.
for target_name, capsule in _gjk_epa_module.registrations().items():
    jax.ffi.register_ffi_target(target_name, capsule)

# Python wrapper that invokes the registered "compute_penetration" FFI target.
def compute_penetration(mesh1, mesh2):
    shape = mesh1.shape[:-2]
    out = jax.ffi.ffi_call(
        "compute_penetration",
        (
            jax.ShapeDtypeStruct(shape, jnp.float32),
            jax.ShapeDtypeStruct((*shape, 3,), jnp.float32),
            jax.ShapeDtypeStruct((*shape, 3,), jnp.float32),
        ),
        vmap_method="broadcast_all",
    )(mesh1, mesh2)
    return out

def create_cube(c, s):
    """Return the eight corner vertices of an axis-aligned cube.

    Args:
        c: Position of the cube's minimum corner (the vertex at unit offset
            ``(0, 0, 0)``); it is added to every scaled offset, so it is a
            corner, not the centre.
        s: Edge length; each unit offset is multiplied by it.

    Returns:
        The eight vertices stacked along a new leading axis, i.e. shape
        ``(8, 3)`` when ``c`` is a 3-vector and ``s`` is a scalar.
    """
    # Unit-cube corners, in the vertex order the rest of the file relies on.
    unit_offsets = (
        (0.0, 0.0, 0.0),
        (1.0, 0.0, 0.0),
        (0.0, 1.0, 0.0),
        (0.0, 0.0, 1.0),
        (1.0, 1.0, 0.0),
        (1.0, 0.0, 1.0),
        (0.0, 1.0, 1.0),
        (1.0, 1.0, 1.0),
    )
    return jnp.array([c + s * jnp.array(offset) for offset in unit_offsets])

def main():
    """Run a batched smoke test of ``compute_penetration`` and print results.

    Builds a batch of 16 identical unit cubes on the CPU device, passes the
    same batch as both meshes, and prints the three arrays returned by the
    FFI call.
    """
    batch_size = 16

    # Only the CPU device is used. The previous, unused lookup of a GPU
    # device was dropped because jax.devices("gpu") raises on CPU-only
    # machines.
    cpu_device = jax.devices("cpu")[0]

    # Both meshes are the same unit cube with its minimum corner at the
    # origin.
    # NOTE(review): an earlier revision also built a second cube at
    # (0.2, 0.2, 0.2) with edge length 0.6 and then immediately overwrote it
    # with mesh1. The dead construction is removed and the runtime behaviour
    # (identical cubes) is kept -- confirm whether the partially overlapping
    # cube was the intended second input.
    unit_cube = create_cube(
        jnp.array([0.0, 0.0, 0.0], device=cpu_device), 1.0
    )

    # Replicate along a new leading batch axis: (8, 3) -> (batch_size, 8, 3).
    mesh1 = jnp.stack([unit_cube] * batch_size, axis=0)
    mesh2 = jnp.stack([unit_cube] * batch_size, axis=0)

    # Call the custom function. This will trigger the FFI handler.
    depth, penetration_dir, contact_point = compute_penetration(mesh1, mesh2)

    # Print the results.
    print("Penetration depth:", depth)
    print("Penetration direction:", penetration_dir)
    print("Contact point:", contact_point)

if __name__ == "__main__":
    main()
