import jax
from jax import numpy as jnp, Array, ShapeDtypeStruct
import os
import ctypes
from functools import partial

# The compiled kernel library is looked up next to this source file and
# loaded at import time (ctypes raises OSError if it cannot be loaded).
library_path = os.path.join(os.path.dirname(__file__), "assign_indices.so")
assign_indices_lib = ctypes.CDLL(library_path)

# Expose the library's `assign_indices` symbol to XLA as the FFI target named
# "assign_indices" (the name `ffi_call` uses below); registered for CUDA only.
jax.ffi.register_ffi_target(
    "assign_indices",
    jax.ffi.pycapsule(assign_indices_lib.assign_indices),
    platform="CUDA",
)

def assign_indices(CAP, costs: Array, *, num_warps=8):
    """Run the ``assign_indices`` CUDA FFI kernel on a cost matrix.

    Args:
        CAP: Length of the second axis of the ``fwd`` output, which has shape
            ``(K, CAP)``. Must be a non-negative integer; it is also passed to
            the kernel as the ``CAP`` attribute.
        costs: 2-D array of shape ``(K, N)``. NOTE(review): the demo passes
            float32; confirm which dtypes the kernel accepts. An earlier
            revision required ``K == 64``; confirm other ``K`` are supported.
        num_warps: Integer passed to the kernel as the ``num_warps`` attribute.

    Returns:
        Tuple ``(cnt, lab, fwd, bwd)`` of int32 arrays with shapes ``(K,)``,
        ``(N,)``, ``(K, CAP)`` and ``(N,)``. Their meaning is defined by the
        kernel. Presumably ``lab[n]`` is the row assigned to column ``n`` and
        ``cnt[k]`` counts the columns placed in row ``k``; TODO confirm
        against the kernel source. If the kernel leaves ``fwd[k, j]`` for
        ``j >= cnt[k]`` unspecified (unconfirmed), callers can mask it with
        ``jnp.arange(CAP)[None, :] < cnt[:, None]``.

    Raises:
        ValueError: if ``costs`` is not 2-D or ``CAP`` is negative.
        TypeError: if ``CAP`` or ``num_warps`` is not an integer.

    Under ``jax.vmap`` the kernel is invoked once on inputs with all batch
    axes broadcast (``vmap_method="broadcast_all"``). The kernel therefore
    has to handle the leading batch dimensions itself.
    """
    import operator  # local import keeps this block self-contained

    if costs.ndim != 2:
        raise ValueError(
            f"costs must be 2-D with shape (K, N), got shape {costs.shape}"
        )
    # Coerce NumPy integer scalars to plain Python ints. The FFI attribute type
    # then no longer depends on the caller's integer type; JAX lowers Python
    # ints as 64-bit attributes, which is what the demo path exercises.
    CAP = operator.index(CAP)
    num_warps = operator.index(num_warps)
    if CAP < 0:
        raise ValueError(f"CAP must be non-negative, got {CAP}")

    K, N = costs.shape
    out_types = (
        ShapeDtypeStruct((K,), jnp.int32),      # cnt
        ShapeDtypeStruct((N,), jnp.int32),      # lab
        ShapeDtypeStruct((K, CAP), jnp.int32),  # fwd
        ShapeDtypeStruct((N,), jnp.int32),      # bwd
    )
    cnt, lab, fwd, bwd = jax.ffi.ffi_call(
        "assign_indices",
        out_types,
        vmap_method="broadcast_all",
    )(costs, N=N, CAP=CAP, num_warps=num_warps)
    return cnt, lab, fwd, bwd

if __name__ == "__main__":
    # Smoke test. It needs a CUDA device because the FFI target is registered
    # for CUDA only. Column n has cost 0 at row n % K and cost 1 elsewhere.
    # The printed accuracy is the fraction of columns whose returned label
    # equals that row.
    N, K, CAP = 32, 64, 32
    columns = jnp.arange(N)
    correct_labels = jnp.arange(N, dtype=jnp.int32) % K

    def _shapes(tree):
        return jax.tree.map(lambda leaf: leaf.shape, tree)

    def _dtypes(tree):
        return jax.tree.map(lambda leaf: leaf.dtype, tree)

    # Unbatched call.
    costs = jnp.ones((K, N), dtype=jnp.float32).at[correct_labels, columns].set(0.0)
    result = assign_indices(CAP, costs)
    print(f"Result shape: {_shapes(result)}")
    print(f"Result dtype: {_dtypes(result)}")
    lab = result[1]
    print(f"Acc: {jnp.mean(lab == correct_labels)}")

    # Batched call through jax.vmap; every batch item uses the same labels.
    B = 4
    batch_costs = (
        jnp.ones((B, K, N), dtype=jnp.float32)
        .at[:, correct_labels, columns]
        .set(0.0)
    )
    batched_assign = jax.vmap(partial(assign_indices, CAP))
    batch_result = batched_assign(batch_costs)
    print(f"Batch result shape: {_shapes(batch_result)}")
    print(f"Batch result dtype: {_dtypes(batch_result)}")
    batch_lab = batch_result[1]
    print(f"Batch Acc: {jnp.mean(batch_lab == correct_labels)}")


