import jax
import jax.numpy as jnp
import numpy as np

import os, sys
# BASEDIR is the parent of the directory that contains this file. It is put at
# the front of sys.path (once) so that modules located there can be imported
# regardless of where the interpreter was started from.
BASEDIR = os.path.dirname(os.path.dirname(__file__))
if BASEDIR not in sys.path:
    sys.path.insert(0, BASEDIR)
    
import ctypes
import os
# Turn off JAX GPU memory preallocation (uncomment the line below to enable this)
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# import jax.extend as jex

# Load the compiled auction kernel as a shared library. The path is relative,
# so it is resolved against the current working directory of the process, not
# against the location of this file.
auction_lib = ctypes.cdll.LoadLibrary("bindings/auction/libauction_kernel.so")

# Register the library's `AuctionMatching` symbol with JAX under the target
# name "auction_matching" for the CUDA platform. `auction_jax` refers to the
# kernel by this name when it calls `jax.ffi.ffi_call`.
jax.ffi.register_ffi_target(
    "auction_matching",
    jax.ffi.pycapsule(auction_lib.AuctionMatching),
    platform="CUDA",
)

def auction_jax(cost_matrices):
    """
    Solve a batch of minimum-cost assignment problems with the CUDA auction kernel.

    Parameters:
    - cost_matrices: array-like of shape (..., N, M); it is converted to float32.
      Rows are treated as agents and columns as tasks.

    Returns:
    - int32 array of (row, column) index pairs:
      * N == 1: shape (..., 1, 2), holding [0, argmin over the columns].
      * M == 1: shape (..., 1, 2), holding [argmin over the rows, 0].
        NOTE(review): this shortcut returns a single pair, whereas the general
        N > M path returns N rows with -1 for unmatched agents — confirm that
        callers expect this difference.
      * otherwise: shape (..., N, 2), where entry k is [k, column matched to
        row k]. When N > M, rows matched to a padding column get column -1.

    Notes:
    - Each matrix is rescaled to the range [0, top_cost] before it is handed to
      the kernel, and rectangular problems are padded to square with a cost
      strictly above top_cost.
    - The kernel input and output are wrapped in `stop_gradient`, so no
      gradient flows through the matching.
    - The general path requires the "auction_matching" FFI target to be
      registered (done at import time in this module).
    """
    # Ensure the input is a JAX array
    cost_matrices = jnp.asarray(cost_matrices, dtype=jnp.float32)
    outer_shape = cost_matrices.shape[:-2]
    N, M = cost_matrices.shape[-2:]

    # A single agent simply takes its cheapest task; no kernel call is needed.
    if N == 1:
        return jnp.stack([
            jnp.broadcast_to(jnp.array([0]), outer_shape + (1,)),
            jnp.argmin(cost_matrices, axis=-1)
        ], axis=-1).astype(jnp.int32)

    # A single task goes to the cheapest agent; no kernel call is needed.
    if M == 1:
        return jnp.stack([
            jnp.argmin(cost_matrices, axis=-2),
            jnp.broadcast_to(jnp.array([0]), outer_shape + (1,))
        ], axis=-1).astype(jnp.int32)

    # Flatten the batch dimensions
    cost_matrix_flat = cost_matrices.reshape((-1, N, M))
    B = cost_matrix_flat.shape[0]
    max_dim = max(N, M)

    # Rescale every matrix independently to [0, top_cost].
    # NOTE(review): presumably this keeps the kernel's bid arithmetic in a
    # fixed numeric range — confirm against the kernel source.
    top_cost = 80
    cost_max = jnp.max(cost_matrix_flat, axis=(-1, -2), keepdims=True)
    cost_min = jnp.min(cost_matrix_flat, axis=(-1, -2), keepdims=True)
    cost_range = cost_max - cost_min
    # A constant matrix has zero range; dividing by it would feed NaNs to the
    # kernel. Use a divisor of 1 there, which maps such a matrix to all zeros.
    safe_range = jnp.where(cost_range > 0, cost_range, 1.0)
    cost_matrix_flat = (cost_matrix_flat - cost_min) / safe_range * top_cost

    # Padding cost: strictly greater than every rescaled real cost.
    high_cost = top_cost + 0.1

    # Pad the cost matrix to make it square
    cost_matrix_padded = cost_matrix_flat
    if N != M:
        # N < M adds dummy rows (agents); N > M adds dummy columns (tasks).
        padding = ((0, 0), (0, max_dim - N), (0, max_dim - M))
        cost_matrix_padded = jnp.pad(cost_matrix_flat, padding, constant_values=high_cost)

    # Define the output shape and type
    matched_col_type = jax.ShapeDtypeStruct((B, max_dim), jnp.int32)

    # Call the FFI function (auction algorithm implementation)
    cost_matrix_padded = jax.lax.stop_gradient(cost_matrix_padded)
    matched_col = jax.ffi.ffi_call("auction_matching", matched_col_type)(cost_matrix_padded)
    matched_col = jax.lax.stop_gradient(matched_col)

    if N > M:
        # Columns >= M are padding: report those agents as unmatched (-1).
        matched_col = jnp.where(matched_col < M, matched_col, -1)

    # Build the (row, column) pairs and keep only the N real agents; when
    # N < M this drops the dummy rows added by the padding.
    row_indices = jnp.broadcast_to(jnp.arange(max_dim)[None, :], (B, max_dim))
    matched_pair = jnp.stack([row_indices, matched_col], axis=-1)[:, :N, :]

    # Reshape back to the original batch dimensions
    matched_pair = matched_pair.reshape(outer_shape + (N, 2))
    return matched_pair.astype(jnp.int32)


def _auction_single_numpy(cost, epsilon):
    """Run the auction on one (N, N) cost matrix and return the task index assigned to each worker."""
    n = cost.shape[0]
    prices = np.zeros(n)
    worker_to_task = np.full(n, -1, dtype=int)
    task_to_worker = np.full(n, -1, dtype=int)
    unassigned = np.ones(n, dtype=bool)

    while unassigned.any():
        bidders = np.flatnonzero(unassigned)
        rows = np.arange(len(bidders))

        # Bidding phase: each unassigned worker targets its most profitable
        # task and bids the margin over its second-best option plus epsilon.
        profits = -cost[bidders, :] - prices[np.newaxis, :]
        best_task = np.argmax(profits, axis=1)
        best_profit = profits[rows, best_task]
        # Mask out the best entry so the next max is the second-best profit.
        profits[rows, best_task] = -np.inf
        second_profit = np.max(profits, axis=1)
        bids = best_profit - second_profit + epsilon

        # Assignment phase: per task keep the highest bid. The strict ">"
        # means the lowest-indexed bidder wins a tie.
        highest_bid = np.full(n, -np.inf)
        winner = np.full(n, -1, dtype=int)
        for worker, task, bid in zip(bidders, best_task, bids):
            if bid > highest_bid[task]:
                highest_bid[task] = bid
                winner[task] = worker

        has_bid = highest_bid > -np.inf
        prices[has_bid] += highest_bid[has_bid]

        # Hand each contested task to its winner and release the previous owner.
        for task in np.flatnonzero(has_bid):
            new_worker = winner[task]
            prev_worker = task_to_worker[task]
            if prev_worker != -1 and prev_worker != new_worker:
                worker_to_task[prev_worker] = -1
                unassigned[prev_worker] = True
            worker_to_task[new_worker] = task
            task_to_worker[task] = new_worker
            unassigned[new_worker] = False

    return worker_to_task


def auction_algorithm_numpy(cost_matrices, epsilon=None):
    """
    Implements the auction algorithm for the assignment problem using NumPy.

    Parameters:
    - cost_matrices: array-like of shape (B, N, N), where B is the batch size and N is
      the dimension. Anything accepted by np.asarray works; the input is not modified.
    - epsilon: Optional bid increment; defaults to 1/(N+1) if not provided.

    Returns:
    - Integer NumPy array of shape (B, N, 2); entry [b, i] is
      (i, task assigned to worker i) for batch element b.

    Raises:
    - ValueError: if cost_matrices is not a 3-D batch of square matrices, or if it
      contains NaN (a NaN cost can never win a bid, so the loop would not terminate).
    """
    # Work in float64: this matches the promotion against the float64 prices
    # and lets lists or other array types be passed in.
    cost_matrices = np.asarray(cost_matrices, dtype=np.float64)
    if cost_matrices.ndim != 3 or cost_matrices.shape[1] != cost_matrices.shape[2]:
        raise ValueError(
            f"cost_matrices must have shape (B, N, N), got {cost_matrices.shape}"
        )
    if np.isnan(cost_matrices).any():
        raise ValueError("cost_matrices must not contain NaN")

    B, N, _ = cost_matrices.shape
    if epsilon is None:
        epsilon = 1.0 / (N + 1)

    # Batch elements share no state, so each one is solved independently.
    worker_to_task = np.empty((B, N), dtype=int)
    for b in range(B):
        worker_to_task[b] = _auction_single_numpy(cost_matrices[b], epsilon)

    worker_ids = np.broadcast_to(np.arange(N)[None], (B, N))
    return np.stack([worker_ids, worker_to_task], axis=-1)

if __name__ == '__main__':
    # Benchmark / consistency check: for many random cost matrices, time the
    # CUDA auction kernel, SciPy's linear_sum_assignment and optax's Hungarian
    # solver, then assert that all three reach the same total cost.

    import time
    import optax
    from scipy.optimize import linear_sum_assignment
    import util.bp_matching_util as bmutil

    # Problem size: B matrices, each with N rows and M columns.
    N = 9
    M = 7
    B = 1

    def assignment_cost(idx_pair):
        """Total cost of the pairs of the first batch entry, read from the current `cost_matrix`; pairs containing -1 are skipped."""
        return sum(
            (cost_matrix[0, i, j] for i, j in idx_pair[0] if i != -1 and j != -1),
            0,
        )

    # Jitted solvers under comparison.
    auction_solver = jax.jit(auction_jax)
    sp_solver = jax.jit(bmutil.bipartite_matching_sp)  # built but not used in the loop below
    optax_solver = jax.jit(jax.vmap(optax.assignment.hungarian_algorithm))

    for seed in range(10000):
        rng = np.random.default_rng(seed)
        cost_matrix = jnp.array(rng.random(size=(B, N, M)).astype(np.float32))

        # optax Hungarian algorithm (vmapped over the batch)
        optax_t0 = time.time()
        x_optax = jnp.stack(optax_solver(cost_matrix), axis=-1)
        x_optax = jax.block_until_ready(x_optax)
        optax_t1 = time.time()

        # SciPy reference on the first (only) batch entry
        sp_t0 = time.time()
        sp_rows, sp_cols = linear_sum_assignment(cost_matrix[0])
        x_sp = np.stack([sp_rows, sp_cols], axis=-1).astype(np.int32)[None]
        sp_t1 = time.time()

        # CUDA auction kernel
        cuda_t0 = time.time()
        x = jax.block_until_ready(auction_solver(cost_matrix))
        cuda_t1 = time.time()

        print(f"jax_cuda: {cuda_t1-cuda_t0:.6f}, jax_sp: {sp_t1-sp_t0:.6f}, optax: {optax_t1-optax_t0:.6f}")

        # All solvers must reach exactly the same total assignment cost.
        auction_total = assignment_cost(x)
        assert auction_total == assignment_cost(x_sp)
        assert auction_total == assignment_cost(x_optax)