import os
import ctypes
import jax
import jax.numpy as jnp
import numpy as np

# Resolve the compiled CUDA shared object that sits next to this module and
# load it with ctypes so its exported symbols can be looked up.
_module_dir = os.path.dirname(__file__)
library_path = os.path.join(_module_dir, "argmin.so")
argmin_lib = ctypes.CDLL(library_path)

# Wrap the library's `Argmin` symbol in a PyCapsule and register it with JAX
# under the target name "argmin" for the CUDA platform.
_argmin_capsule = jax.ffi.pycapsule(argmin_lib.Argmin)
jax.ffi.register_ffi_target("argmin", _argmin_capsule, platform="CUDA")

def cuda_argmin(x):
    """Return argmin indices along the last axis via the "argmin" FFI target.

    Only a single fixed input configuration is accepted; anything else is
    rejected before the FFI call is made.

    Args:
        x: Input array of shape (32, 64) with dtype float32.

    Returns:
        The FFI call's output, declared as an int32 array of shape (32,).

    Raises:
        ValueError: If ``x`` does not have shape (32, 64), or if its dtype
            is not float32.
    """
    # Guard clauses: shape is checked first, then dtype.
    if x.shape != (32, 64):
        raise ValueError(f"Input must have shape [32, 64], got {x.shape}")
    if x.dtype != jnp.float32:
        raise ValueError(f"Input must be float32, got {x.dtype}")

    # Describe the output buffer JAX should allocate for the FFI call.
    out_spec = jax.ShapeDtypeStruct((32,), jnp.int32)

    # Build the callable for the registered "argmin" target, then apply it.
    argmin_call = jax.ffi.ffi_call(
        "argmin",
        out_spec,
        vmap_method="sequential",
    )
    return argmin_call(x)

# Self-test: runs only when this file is executed as a script.
if __name__ == "__main__":
    all_passed = True

    # Build an all-ones input and plant one unique minimum (0.0) per row.
    # The planted columns are 1, 3, 5, ..., 63, so the expected indices span
    # the full width of a row rather than only its first half.
    x = jnp.ones((32, 64), dtype=jnp.float32)
    rows = jnp.arange(32)
    test_indices = 2 * rows + 1
    x_modified = x.at[rows, test_indices].set(0.0)

    # Run the CUDA kernel and the JAX reference on the same input.
    cuda_result = cuda_argmin(x_modified)
    jax_result = jnp.argmin(x_modified, axis=-1)

    # array_equal compares shape as well as values; allclose would broadcast
    # and could report a match for a result of the wrong shape.
    known_match = bool(jnp.array_equal(cuda_result, jax_result))
    expected_match = bool(jnp.array_equal(cuda_result, test_indices))
    all_passed = all_passed and known_match and expected_match

    print(f"CUDA result shape: {cuda_result.shape}")
    print(f"JAX result shape: {jax_result.shape}")
    print(f"Expected indices: {test_indices}")
    print(f"CUDA result: {cuda_result}")
    print(f"JAX result: {jax_result}")
    print(f"Results match: {known_match}")
    print(f"Matches expected indices: {expected_match}")

    # Test with random data
    key = jax.random.PRNGKey(42)
    x_random = jax.random.normal(key, (32, 64), dtype=jnp.float32)

    cuda_result_random = cuda_argmin(x_random)
    jax_result_random = jnp.argmin(x_random, axis=-1)

    random_match = bool(jnp.array_equal(cuda_result_random, jax_result_random))
    all_passed = all_passed and random_match
    print(f"\nRandom test results match: {random_match}")

    # Test error handling. The `else` branches report the case where the
    # validation unexpectedly lets a bad input through.
    wrong_shape = jnp.ones((16, 64), dtype=jnp.float32)
    try:
        cuda_argmin(wrong_shape)
    except ValueError as e:
        print(f"\nCorrectly caught shape error: {e}")
    else:
        all_passed = False
        print("\nERROR: wrong input shape did not raise ValueError")

    wrong_dtype = jnp.ones((32, 64), dtype=jnp.bfloat16)
    try:
        cuda_argmin(wrong_dtype)
    except ValueError as e:
        print(f"Correctly caught dtype error: {e}")
    else:
        all_passed = False
        print("ERROR: wrong input dtype did not raise ValueError")

    # Signal failure to the caller (shell, CI) through the exit status.
    if not all_passed:
        raise SystemExit(1)
