import ctypes
import os
# Turn off JAX's GPU memory preallocation; set before jax is imported below.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import jax.numpy as jnp
import jax.extend as jex
import numpy as np

# Open the compiled Hungarian kernel. The path contains a directory component
# and is not absolute, so it is resolved against the current working directory.
hungarian_lib = ctypes.CDLL("build/libhungarian_kernel.so")

# Wrap the library's HungarianMatching symbol in a PyCapsule and register it
# with XLA under the name "hungarian_matching" for the CUDA platform.
_hungarian_matching_capsule = jex.ffi.pycapsule(hungarian_lib.HungarianMatching)
jex.ffi.register_ffi_target(
    "hungarian_matching",
    _hungarian_matching_capsule,
    platform="CUDA",
)

def hungarian_jax(cost_matrices):
    """Run the ``hungarian_matching`` FFI target on a batch of cost matrices.

    Args:
        cost_matrices: Array-like of shape ``(B, N, N)``. It is converted to a
            float32 JAX array before being handed to the kernel.

    Returns:
        A tuple ``(matched_pair_i, matched_pair_j)`` of int32 arrays, each of
        shape ``(B, N)``. Presumably these hold the row and column indices of
        the matched pairs for every batch entry -- confirm against the kernel.

    Raises:
        AssertionError: If the input is not 3-D or its last two dimensions
            differ.
    """
    costs = jnp.asarray(cost_matrices, dtype=jnp.float32)
    assert costs.ndim == 3
    batch, n_rows, n_cols = costs.shape
    assert n_rows == n_cols, "Cost matrices must be square"

    # Both kernel outputs share the same (batch, n) int32 layout.
    output_specs = tuple(
        jax.ShapeDtypeStruct((batch, n_rows), jnp.int32) for _ in range(2)
    )

    # vectorized=False: batching is handled inside the kernel.
    pair_i, pair_j = jex.ffi.ffi_call(
        "hungarian_matching",
        output_specs,
        costs,
        vectorized=False,
    )
    return pair_i, pair_j

# Example usage
if __name__ == "__main__":
    batch_size = 10
    n = 5

    # Seeded integer-valued costs in [0, 100), stored as float32 on the host.
    rng = np.random.default_rng(3)
    host_costs = rng.integers(0, 100, size=(batch_size, n, n)).astype(np.float32)

    # Hand the costs over to JAX.
    device_costs = jnp.array(host_costs)

    # JIT-compile the wrapper; call hungarian_jax directly instead to run
    # without jit.
    solve = jax.jit(hungarian_jax)
    all_i, all_j = solve(device_costs)

    # Report the matching for each batch entry.
    for b, (pairs_i, pairs_j) in enumerate(zip(all_i, all_j)):
        print(f"Batch {b}:")
        print("Matched pairs (i):", pairs_i)
        print("Matched pairs (j):", pairs_j)
