import jax.numpy as jnp
import jax
import numpy as np
import einops
import jax.debug as jdb
from functools import partial
import os, sys
from typing import Optional, Tuple
# Make the repository root (the parent of this file's directory) importable so
# that the project-local ``util`` package resolves when this file is run directly.
BASEDIR = os.path.dirname(os.path.dirname(__file__))
if BASEDIR not in sys.path:
    sys.path.insert(0, BASEDIR)

# distribute 2d points and define Sinkhorn
import util.latent_obj_util as loutil
try:
    import util.Auction as auction
    auction_on = True
except Exception:
    # Best-effort optional dependency: any import-time failure disables the
    # auction solver, but KeyboardInterrupt/SystemExit are no longer swallowed.
    auction = None
    auction_on = False
    print("Auction not found")


def debug_callback(inputs):
    """Print ``inputs`` followed by the sums of its first and second entries."""
    first_total = np.sum(inputs[0])
    second_total = np.sum(inputs[1])
    print(inputs, first_total, second_total)

def Sinkhorn(cost_matrix, epsilon=0.01, max_iter=1000, tol=1e-3, epsilon_tiny=1e-10):
    """Sinkhorn scaling of ``exp(-C / epsilon)`` with all-ones target marginals.

    Args:
        cost_matrix: array of shape (..., na, nb); it does not need to be square.
        epsilon: entropic regularisation strength (temperature of the kernel).
        max_iter: maximum number of alternating row/column scaling passes.
        tol: the loop stops once both scaling vectors move less than this
            (Euclidean norm) between two consecutive passes.
        epsilon_tiny: lower clamp on the denominators to avoid division by zero.

    Returns:
        Array with the same shape as ``cost_matrix``: ``diag(u) @ K @ diag(v)``.

    Note:
        The convergence test converts an array to a Python bool, so this
        function is meant for eager execution, not for tracing under ``jax.jit``.
    """
    na = cost_matrix.shape[-2]
    nb = cost_matrix.shape[-1]  # columns; reading shape[-2] here broke non-square costs

    # Shift by the global minimum so the largest kernel entry is exp(0) = 1.
    C_rescaled = cost_matrix - jnp.min(cost_matrix)

    Kmat = jnp.exp(-C_rescaled / epsilon)
    u = jnp.ones(Kmat.shape[:-2] + (na,))
    v = jnp.ones(Kmat.shape[:-2] + (nb,))

    for _ in range(max_iter):
        u_new = 1.0 / jnp.maximum(jnp.einsum('...ij,...j->...i', Kmat, v), epsilon_tiny)
        v_new = 1.0 / jnp.maximum(jnp.einsum('...ij,...i->...j', Kmat, u_new), epsilon_tiny)

        converged = bool(jnp.linalg.norm(u_new - u) < tol) and bool(jnp.linalg.norm(v_new - v) < tol)
        # Always keep the freshest scalings, also on the pass that triggers the stop.
        u, v = u_new, v_new
        if converged:
            break

    gamma = jnp.einsum('...i,...ij,...j->...ij', u, Kmat, v)

    return gamma



def Sinkhorn_LSE(cost_matrix, mu=1., nu=1., epsilon=0.01, max_iter=1000, tol=1e-3, epsilon_tiny=1e-9):
    """Log-domain Sinkhorn iterations run as a plain Python loop.

    Args:
        cost_matrix: array of shape (..., na, nb).
        mu: row marginal(s), used through ``jnp.log(mu)``.
        nu: column marginal(s), used through ``jnp.log(nu)``.
        epsilon: entropic regularisation strength.
        max_iter: number of update passes; all of them are always executed.
        tol: accepted but unused (there is no early stopping).
        epsilon_tiny: accepted but unused.

    Returns:
        ``exp(log_u + log_K + log_v)``, same shape as ``cost_matrix``.
    """
    lse = jax.scipy.special.logsumexp
    n_rows = cost_matrix.shape[-2]
    n_cols = cost_matrix.shape[-1]
    batch_dims = cost_matrix.shape[:-2]

    # Log of the Gibbs kernel.
    log_kernel = -cost_matrix / epsilon

    # Dual potentials in log space, both starting at zero.
    log_u = jnp.zeros((*batch_dims, n_rows))
    log_v = jnp.zeros((*batch_dims, n_cols))

    step = 0
    while step < max_iter:
        # Row update from the current column potential ...
        log_u = jnp.log(mu) - lse(log_kernel + log_v[..., None, :], axis=-1)
        # ... then column update from the fresh row potential.
        log_v = jnp.log(nu) - lse(log_kernel.swapaxes(-1, -2) + log_u[..., None, :], axis=-1)
        step += 1

    # Transport plan, assembled in log space and exponentiated once.
    log_plan = log_u[..., :, None] + log_kernel + log_v[..., None, :]
    return jnp.exp(log_plan)



def Sinkhorn_LSE_fori(cost_matrix, mu=1.0, nu=1.0, epsilon=0.001, max_iter=10):
    """Log-domain Sinkhorn iterations driven by ``jax.lax.fori_loop``.

    Args:
        cost_matrix: array of shape (..., na, nb).
        mu: row marginal(s), used through ``jnp.log(mu)``.
        nu: column marginal(s), used through ``jnp.log(nu)``.
        epsilon: entropic regularisation strength.
        max_iter: number of update passes (no early stopping).

    Returns:
        ``exp(log_u + log_K + log_v)``, same shape as ``cost_matrix``.
    """
    lse = jax.scipy.special.logsumexp
    batch_dims = cost_matrix.shape[:-2]
    n_rows = cost_matrix.shape[-2]
    n_cols = cost_matrix.shape[-1]

    # Log of the Gibbs kernel.
    log_kernel = -cost_matrix / epsilon

    # Both potentials start at zero in log space.
    start = (
        jnp.zeros((*batch_dims, n_rows)),
        jnp.zeros((*batch_dims, n_cols)),
    )

    def sinkhorn_step(_, potentials):
        # Only the column potential of the previous pass is needed.
        _, prev_log_v = potentials
        next_log_u = jnp.log(mu) - lse(log_kernel + prev_log_v[..., None, :], axis=-1)
        next_log_v = jnp.log(nu) - lse(log_kernel.swapaxes(-1, -2) + next_log_u[..., None, :], axis=-1)
        return next_log_u, next_log_v

    log_u, log_v = jax.lax.fori_loop(0, max_iter, sinkhorn_step, start)

    # Transport plan, assembled in log space and exponentiated once.
    return jnp.exp(log_u[..., :, None] + log_kernel + log_v[..., None, :])


def greedy_bipartite_matching(cost_matrix, sort_by_a=False):
    """Greedy bipartite matching: repeatedly take the cheapest unused (row, col).

    Args:
        cost_matrix: array of shape (..., N, M); leading dims are batch dims.
        sort_by_a: if True, order the matched pairs by their row index.

    Returns:
        int32 array of (row, col) pairs.
        * General case: shape (..., max(N, M), 2). The first min(N, M) entries
          are matches, the remaining entries are filled with -1.
        * N == 1 or M == 1: shape (..., 1, 2), the single argmin pair, with no
          padding (kept as-is for existing callers).
    """
    outer_shape = cost_matrix.shape[:-2]
    N, M = cost_matrix.shape[-2:]

    # Trivial cases: the only row/column is matched to its cheapest partner.
    if N == 1:
        col_idx = jnp.argmin(cost_matrix, axis=-1)  # shape (..., 1)
        return jnp.stack([jnp.zeros_like(col_idx), col_idx], axis=-1).astype(jnp.int32)
    if M == 1:
        row_idx = jnp.argmin(cost_matrix, axis=-2)  # shape (..., 1)
        return jnp.stack([row_idx, jnp.zeros_like(row_idx)], axis=-1).astype(jnp.int32)

    max_matches = min(N, M)
    nm_max = max(N, M)
    row_ids = jnp.arange(N)
    col_ids = jnp.arange(M)

    def body_fun(carry, _):
        used_row, used_col = carry
        # Entries in an already matched row or column can no longer be chosen.
        blocked = used_row[..., :, None] | used_col[..., None, :]
        masked_cost = jnp.where(blocked, jnp.inf, cost_matrix)
        # Global argmin over the flattened (N*M) grid, then back to (row, col).
        min_idx = jnp.argmin(masked_cost.reshape(outer_shape + (N * M,)), axis=-1)
        row_idx = min_idx // M
        col_idx = min_idx % M
        used_row = used_row | (row_idx[..., None] == row_ids)
        used_col = used_col | (col_idx[..., None] == col_ids)
        return (used_row, used_col), jnp.stack([row_idx, col_idx], axis=-1)

    init_carry = (
        jnp.zeros(outer_shape + (N,), dtype=jnp.bool_),
        jnp.zeros(outer_shape + (M,), dtype=jnp.bool_),
    )
    _, matched_pairs = jax.lax.scan(body_fun, init_carry, None, length=max_matches)

    # scan stacks the per-step pairs on axis 0; move that axis next to the pair axis.
    matched_pairs = jnp.moveaxis(matched_pairs, 0, -2).astype(jnp.int32)

    if sort_by_a:
        sorted_idx = jnp.argsort(matched_pairs[..., 0:1], axis=-2)
        matched_pairs = jnp.take_along_axis(matched_pairs, sorted_idx, axis=-2)

    # Pad with -1 so the pair axis always has max(N, M) entries.
    padding = jnp.full(outer_shape + (nm_max - max_matches, 2), -1, dtype=jnp.int32)
    return jnp.concatenate([matched_pairs, padding], axis=-2)





def beam_search_bipartite_matching_nobatch(cost_matrix, beam_width, sort_by_a=False):
    """
    Beam-search bipartite matching for a single (unbatched) cost matrix.

    Args:
        cost_matrix: array of shape (N, M) containing the cost for matching
                     row i with column j.
                     NOTE(review): most of the body is written for arbitrary
                     leading dims, but the final `matched_pairs[best_idx]`
                     selection only yields the documented shape without them;
                     batched inputs should go through the vmapped wrapper.
        beam_width: integer, the beam size (number of candidate partial matchings
                    to keep at each step).
        sort_by_a: if True, sort final matched pairs by the row indices.

    Returns:
        matched_pairs: int32 array of shape (nm_max, 2) with nm_max = max(N, M),
                       containing the matched index pairs (row, col). Entries
                       beyond the min(N, M) matches are padded with -1.
    """
    outer_shape = cost_matrix.shape[:-2]
    N, M = cost_matrix.shape[-2:]
    nm_max = max(N, M)

    # Handle trivial cases (when one side has only one element).
    # argmin already keeps the length-1 pair axis, so the stacked result is
    # (..., 1, 2) and can be concatenated with the (..., nm_max - 1, 2) padding
    # directly; inserting another axis here made the ranks disagree.
    if N == 1:
        col_idx = jnp.argmin(cost_matrix, axis=-1)
        matched_pairs = jnp.stack([jnp.zeros_like(col_idx), col_idx], axis=-1).astype(jnp.int32)
        pad = -1 * jnp.ones(outer_shape + (nm_max - 1, 2), dtype=jnp.int32)
        return jnp.concatenate([matched_pairs, pad], axis=-2)
    if M == 1:
        row_idx = jnp.argmin(cost_matrix, axis=-2)
        matched_pairs = jnp.stack([row_idx, jnp.zeros_like(row_idx)], axis=-1).astype(jnp.int32)
        pad = -1 * jnp.ones(outer_shape + (nm_max - 1, 2), dtype=jnp.int32)
        return jnp.concatenate([matched_pairs, pad], axis=-2)

    max_matches = min(N, M)

    # --- Initialize beam-search state.
    # We store in the state:
    #   total_cost: cumulative cost so far (shape: (beam_width, *outer_shape))
    #   mask_row: binary mask for rows already matched (shape: (beam_width, *outer_shape, N))
    #   mask_col: binary mask for cols already matched (shape: (beam_width, *outer_shape, M))
    #   matched_pairs: the sequence of matched pairs so far (shape: (beam_width, *outer_shape, max_matches, 2))
    #   count: number of pairs selected so far (shape: (beam_width, *outer_shape))

    # Only beam slot 0 starts with a finite cost; the other slots start at +inf
    # so they lose against any expansion of the first slot.
    init_total_cost = jnp.concatenate([
        jnp.zeros(outer_shape)[None, ...],
        jnp.full((beam_width - 1,) + outer_shape, jnp.inf)
    ], axis=0)
    init_mask_row = jnp.zeros(outer_shape + (N,), dtype=jnp.int32)
    init_mask_row = jnp.broadcast_to(init_mask_row, (beam_width,) + outer_shape + (N,))
    init_mask_col = jnp.zeros(outer_shape + (M,), dtype=jnp.int32)
    init_mask_col = jnp.broadcast_to(init_mask_col, (beam_width,) + outer_shape + (M,))
    init_matched_pairs = -1 * jnp.ones(outer_shape + (max_matches, 2), dtype=jnp.int32)
    init_matched_pairs = jnp.broadcast_to(init_matched_pairs, (beam_width,) + outer_shape + (max_matches, 2))
    init_count = jnp.zeros(outer_shape, dtype=jnp.int32)
    init_count = jnp.broadcast_to(init_count, (beam_width,) + outer_shape)

    state = (init_total_cost, init_mask_row, init_mask_col, init_matched_pairs, init_count)

    # --- The body function for beam search.
    # We run a scan for max_matches iterations. At iteration "step", we expect that for
    # any candidate that has not yet been "completed" the count equals step.
    # For each candidate we expand the top local moves (here the local beam equals
    # beam_width). Then we merge the beam dimensions and select the best
    # beam_width candidates to form the new state.
    def beam_body_fn(state, step):
        total_cost, mask_row, mask_col, matched_pairs, count = state
        # active candidates are those that still need to be expanded (i.e. count == step)
        active = (count == step)  # shape: (beam_width, *outer_shape)

        # Expand cost_matrix to have a beam axis:
        expanded_cost = jnp.broadcast_to(cost_matrix, (total_cost.shape[0],) + cost_matrix.shape)
        # Compute the cost matrix masked so that already-used rows/cols are set to inf:
        candidate_mask = (mask_row[..., :, None] + mask_col[..., None, :]).astype(jnp.bool_)
        masked_cost = jnp.where(candidate_mask, jnp.inf, expanded_cost)
        # Flatten the last two dimensions so that each candidate sees a (N*M,) vector:
        flat_cost = masked_cost.reshape(masked_cost.shape[:-2] + (N * M,))

        # For each candidate, select the best local_beam moves.
        local_beam = beam_width  # you could choose a different expansion factor here.
        sorted_indices = jnp.argsort(flat_cost, axis=-1)[..., :local_beam]  # shape: (beam_width, *outer_shape, local_beam)
        selected_cost = jnp.take_along_axis(flat_cost, sorted_indices, axis=-1)  # same shape
        new_rows = sorted_indices // M  # shape: (beam_width, *outer_shape, local_beam)
        new_cols = sorted_indices % M   # shape: (beam_width, *outer_shape, local_beam)

        # New total cost for each expansion is the candidate cost plus the cost for the move.
        new_total_cost = total_cost[..., None] + selected_cost  # shape: (beam_width, *outer_shape, local_beam)

        # Update the row mask: set the new matched row to 1.
        new_mask_row_update = jax.nn.one_hot(new_rows, N, dtype=mask_row.dtype)  # (beam_width, *outer_shape, local_beam, N)
        new_mask_row_expanded = jnp.maximum(mask_row[..., None, :], new_mask_row_update)
        # Similarly for the column mask.
        new_mask_col_update = jax.nn.one_hot(new_cols, M, dtype=mask_col.dtype)  # (beam_width, *outer_shape, local_beam, M)
        new_mask_col_expanded = jnp.maximum(mask_col[..., None, :], new_mask_col_update)

        # Update the list of matched pairs: record the new pair at index "step".
        new_pair = jnp.stack([new_rows, new_cols], axis=-1)  # shape: (beam_width, *outer_shape, local_beam, 2)
        # To insert these new pairs at position step in the matched_pairs array,
        # we first expand matched_pairs along a new axis (for the local expansion).
        matched_pairs_expanded = jnp.broadcast_to(matched_pairs[..., None, :, :],
                                                  mask_row.shape[:-1] + (local_beam, max_matches, 2))
        # Create a one-hot vector for the "step" index
        step_one_hot = jax.nn.one_hot(step, max_matches, dtype=matched_pairs.dtype)
        step_one_hot = step_one_hot.reshape((1,) * (matched_pairs_expanded.ndim - 2) + (max_matches, 1))
        new_matched_pairs_expanded = matched_pairs_expanded * (1 - step_one_hot) + new_pair[..., None, :] * step_one_hot
        # Increase the count (number of moves taken) by one.
        new_count = count[..., None] + 1  # shape: (beam_width, *outer_shape, local_beam)

        # For candidates that are not active (i.e. already complete) we simply replicate
        # the original state along the new "local" dimension.
        active_expanded = jnp.broadcast_to(active[..., None], new_total_cost.shape)
        rep_total_cost = jnp.broadcast_to(total_cost[..., None], new_total_cost.shape)
        rep_mask_row = jnp.broadcast_to(mask_row[..., None, :], new_mask_row_expanded.shape)
        rep_mask_col = jnp.broadcast_to(mask_col[..., None, :], new_mask_col_expanded.shape)
        rep_matched_pairs = jnp.broadcast_to(matched_pairs[..., None, :, :], new_matched_pairs_expanded.shape)
        rep_count = jnp.broadcast_to(count[..., None], new_count.shape)

        final_total_cost = jnp.where(active_expanded, new_total_cost, rep_total_cost)
        final_mask_row = jnp.where(active_expanded[..., None], new_mask_row_expanded, rep_mask_row)
        final_mask_col = jnp.where(active_expanded[..., None], new_mask_col_expanded, rep_mask_col)
        final_matched_pairs = jnp.where(active_expanded[..., None, None], new_matched_pairs_expanded, rep_matched_pairs)
        final_count = jnp.where(active_expanded, new_count, rep_count)

        # Now, each candidate in the old beam has expanded into local_beam candidates.
        # We merge the two beam dimensions (the "old" beam and the new local branch) into one.
        final_total_cost = einops.rearrange(final_total_cost, 'b ... l -> (b l) ...')
        final_mask_row   = einops.rearrange(final_mask_row,   'b ... l n -> (b l) ... n')
        final_mask_col   = einops.rearrange(final_mask_col,   'b ... l m -> (b l) ... m')
        final_matched_pairs = einops.rearrange(final_matched_pairs, 'b ... l max t -> (b l) ... max t')
        final_count = einops.rearrange(final_count, 'b ... l -> (b l) ...')

        # Finally, we select the best beam_width candidates (lowest total_cost) along the merged beam dimension.
        sorted_beam_indices = jnp.argsort(final_total_cost, axis=0)
        top_beam_indices = sorted_beam_indices[:beam_width, ...]

        new_total_cost = jnp.take_along_axis(final_total_cost, top_beam_indices, axis=0)
        new_mask_row = jnp.take_along_axis(final_mask_row, top_beam_indices[...,None], axis=0)
        new_mask_col = jnp.take_along_axis(final_mask_col, top_beam_indices[...,None], axis=0)
        new_matched_pairs = jnp.take_along_axis(final_matched_pairs, top_beam_indices[...,None,None], axis=0)
        new_count = jnp.take_along_axis(final_count, top_beam_indices, axis=0)

        new_state = (new_total_cost, new_mask_row, new_mask_col, new_matched_pairs, new_count)
        return new_state, None

    # Run the beam-search for max_matches steps.
    state, _ = jax.lax.scan(beam_body_fn, state, jnp.arange(max_matches))

    total_cost, mask_row, mask_col, matched_pairs, count = state
    # Choose the best candidate from the final beam (lowest total cost).
    best_idx = jnp.argmin(total_cost, axis=0)
    best_matched_pairs = matched_pairs[best_idx]

    if sort_by_a:
        sorted_idx = jnp.argsort(best_matched_pairs[..., 0:1], axis=-2)
        best_matched_pairs = jnp.take_along_axis(best_matched_pairs, sorted_idx, axis=-2)

    # Pad with -1 so that the number of rows equals nm_max.
    pad_length = nm_max - max_matches
    if pad_length > 0:
        pad = -1 * jnp.ones(outer_shape + (pad_length, 2), dtype=jnp.int32)
        best_matched_pairs = jnp.concatenate([best_matched_pairs, pad], axis=-2)

    return best_matched_pairs

def beam_search_bipartite_matching(cost_matrix, beam_width=4, sort_by_a=False):
    """Batched front-end for ``beam_search_bipartite_matching_nobatch``.

    All leading dims of ``cost_matrix`` (shape (..., N, M)) are collapsed into
    one batch axis, the single-matrix solver is vmapped over it, and the
    leading dims are restored on the returned pair array.
    """
    leading_dims = cost_matrix.shape[:-2]
    n_rows, n_cols = cost_matrix.shape[-2], cost_matrix.shape[-1]

    solve_one = partial(
        beam_search_bipartite_matching_nobatch,
        beam_width=beam_width,
        sort_by_a=sort_by_a,
    )
    stacked_costs = cost_matrix.reshape((-1, n_rows, n_cols))
    stacked_pairs = jax.vmap(solve_one)(stacked_costs)

    # Keep the solver's trailing (pairs, 2) dims and put the batch dims back.
    return stacked_pairs.reshape(leading_dims + stacked_pairs.shape[-2:])



from jax.experimental import io_callback
from scipy.optimize import linear_sum_assignment

def hungarian_host(cost_matrix, valid_pair_mask=None):
    """Solve one assignment problem on the host with SciPy.

    Args:
        cost_matrix: array whose last two dims are (N, M). When the problem is
            actually solved it must be a single 2-D matrix.
        valid_pair_mask: optional flag. ``None`` (the default) or a truthy value
            means the problem is solved; a falsy value skips the solver.

    Returns:
        int32 array of shape (min(N, M), 2) holding (row, col) pairs. When the
        mask is falsy this is the identity pairing ``(k, k)``.
    """
    min_dim = min(cost_matrix.shape[-2:])

    # Only an explicit falsy mask skips the solve. np.all also covers the case
    # where the flag arrives as a 0-d or 1-element array.
    if valid_pair_mask is not None and not np.all(valid_pair_mask):
        diagonal = np.arange(min_dim)
        return np.stack([diagonal, diagonal], axis=-1).astype(np.int32)

    # linear_sum_assignment rejects NaN and -inf entries, so map NaN/+inf to a
    # large finite cost and clamp everything into [-1e6, 1e6].
    cost = np.nan_to_num(np.asarray(cost_matrix), nan=1e6, posinf=1e6, neginf=-1e6)
    cost = np.clip(cost, -1e6, 1e6)

    row_ind, col_ind = linear_sum_assignment(cost)
    return np.stack([row_ind, col_ind], axis=-1).astype(np.int32)

def hungarian_jax(cost_matrix, valid_pair_mask=None):
    """Run ``hungarian_host`` from JAX code through ``io_callback``.

    The cost matrix is detached with ``stop_gradient`` before the call, and the
    returned (min(N, M), 2) int32 pair array is detached as well.
    """
    n_pairs = min(cost_matrix.shape[-2:])
    # Shape/dtype contract the host callback has to honour.
    result_spec = jax.ShapeDtypeStruct((n_pairs, 2), jnp.int32)
    detached_cost = jax.lax.stop_gradient(cost_matrix)

    def _on_host(cost_np, mask_np):
        return hungarian_host(cost_np, mask_np)

    pairs = io_callback(_on_host, result_spec, detached_cost, valid_pair_mask)
    return jax.lax.stop_gradient(pairs)

def bipartite_matching_sp(cost_matrix, valid_pair_mask=None):
    """Bipartite matching via SciPy's assignment solver, vmapped over batches.

    Args:
        cost_matrix: array of shape (..., N, M).
        valid_pair_mask: optional boolean array with one flag per batch element
            (shape equal to the leading dims); defaults to all True.

    Returns:
        int32 array of (row, col) pairs.
        * General case: shape (..., max(N, M), 2), padded with -1 after the
          first min(N, M) matches.
        * N == 1 or M == 1: shape (..., 1, 2), the argmin pair without padding.
    """
    # Must be known before the early returns below, which also use it.
    outer_shape = cost_matrix.shape[:-2]
    N, M = cost_matrix.shape[-2:]

    if N == 1:
        return jnp.stack([jnp.broadcast_to(jnp.array([0.]), outer_shape+(1,)), jnp.argmin(cost_matrix, axis=-1)], axis=-1).astype(jnp.int32)
    if M == 1:
        return jnp.stack([jnp.argmin(cost_matrix, axis=-2),jnp.broadcast_to(jnp.array([0.]), outer_shape+(1,))], axis=-1).astype(jnp.int32)

    min_dim = min(N, M)
    nm_max = max(N, M)
    if valid_pair_mask is None:
        valid_pair_mask = jnp.ones(outer_shape, dtype=jnp.bool_)
    valid_pair_mask = valid_pair_mask.reshape((-1,))
    # Collapse all batch dims so a single vmap covers them.
    cost_matrix_flat = cost_matrix.reshape((-1, N, M))
    matched_pairs = jax.vmap(hungarian_jax)(cost_matrix_flat, valid_pair_mask)
    matched_pairs = matched_pairs.reshape(outer_shape + (min_dim, 2))

    # Pad with -1 so the pair axis has max(N, M) entries.
    padding = -1*jnp.ones(outer_shape + (nm_max-min_dim, 2), dtype=jnp.int32)
    matched_pairs = jnp.concatenate([matched_pairs, padding], axis=-2)

    return matched_pairs


def bipartite_matching_optax(cost_matrix):
    """Bipartite matching with ``optax.assignment.hungarian_algorithm``.

    Args:
        cost_matrix: array of shape (..., N, M); leading dims are flattened
            into one batch axis and the solver is vmapped over it.

    Returns:
        int32 array of (row, col) pairs with the leading dims restored and the
        solver's pair axis followed by a trailing axis of size 2.
    """
    # optax is otherwise only imported by this file's __main__ demo, so import
    # it here to keep the function usable when the module itself is imported.
    import optax

    outer_shape = cost_matrix.shape[:-2]
    cost_matrix_flat = cost_matrix.reshape((-1, cost_matrix.shape[-2], cost_matrix.shape[-1]))
    matching_pair = jax.vmap(optax.assignment.hungarian_algorithm)(cost_matrix_flat)
    # The solver's outputs are stacked on a new last axis to form the pairs.
    matching_pair = jnp.stack(matching_pair, axis=-1).astype(jnp.int32)
    return matching_pair.reshape(outer_shape + matching_pair.shape[-2:])




def extract_min(pwdif, matched_pair):
    """Sum the pairwise costs selected by a matching, ignoring padded pairs.

    Args:
        pwdif: pairwise cost array of shape (..., N, M).
        matched_pair: integer array of shape (..., K, 2) of (row, col) indices;
            a pair containing a negative index is treated as padding.

    Returns:
        Array of shape (...): the sum of ``pwdif[row, col]`` over valid pairs.
    """
    valid_pair_mask = jnp.all(matched_pair >= 0, axis=-1)
    # Gather the matched rows first, then the matched column inside each row.
    picked_rows = jnp.take_along_axis(pwdif, matched_pair[..., None, 0], axis=-2)
    picked = jnp.take_along_axis(picked_rows, matched_pair[..., None, 1], axis=-1).squeeze(-1)
    # Zero the padded pairs before summing; the mask used to be applied to a
    # variable that was never read, so padded pairs leaked into the sum.
    picked = jnp.where(valid_pair_mask, picked, 0)
    return jnp.sum(picked, axis=-1)



def obj_dif(obj1:loutil.LatentObjects, obj2:loutil.LatentObjects, pos_loss_coef, dc_pos_loss_coef, loss_func=None, dif_type='hg', fps_matched_pair=None, fps_only=False): # ot for particles
    """Combine the position loss and the ``fps_matching`` loss of two objects.

    Returns:
        ``(entire_loss, em_loss_log)`` where
        ``entire_loss = chloss + pos_loss_coef * pos_dif`` and ``em_loss_log``
        is a dict with the individual terms. ``len_dif`` is gradient-stopped
        and only logged; it is not part of ``entire_loss``.
    """
    if loss_func is None:
        # Default: squared Euclidean distance over the last axis.
        def _squared_l2(x, y):
            return ((x-y)**2).sum(axis=-1)
        loss_func = _squared_l2

    pos_dif = loss_func(obj1.pos, obj2.pos)

    # Only the loss terms of the matching are used here.
    _, matching_terms, _ = fps_matching(
        obj1, obj2, dif_type,
        loss_func=loss_func,
        dc_pos_loss_coef=dc_pos_loss_coef,
        fps_matched_pair=fps_matched_pair,
        fps_only=fps_only,
    )
    chloss, ch_fps, ch_z = matching_terms

    len_dif = jax.lax.stop_gradient((obj1.len - obj2.len)**2)

    # All per-object terms must line up before they are combined or logged.
    assert chloss.shape == pos_dif.shape
    assert pos_dif.shape == len_dif.shape

    entire_loss = chloss + pos_loss_coef * pos_dif
    em_loss_log = {
        'pos_dif': pos_dif,
        'len_dif': len_dif,
        'chloss': chloss,
        'ch_fps': ch_fps,
        'ch_z': ch_z,
    }
    return entire_loss, em_loss_log


def fps_matching(obj1:"loutil.LatentObjects", obj2:"loutil.LatentObjects", 
                 dif_type, dc_pos_loss_coef=10.0, loss_func=None, 
                 fps_only=False, fps_matched_pair=None)->Tuple[Optional[jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], Optional["loutil.LatentObjects"]]:
    """Compare the point sets of two objects under the chosen distance.

    The combined per-pair cost is ``z_cost + dc_pos_loss_coef * fps_cost``,
    computed with ``loss_func`` on ``fps_tf`` and ``z_flat``.

    Args:
        obj1, obj2: objects exposing ``fps_tf`` and ``z_flat`` (plus ``rel_fps``,
            ``z``, ``nobj`` and ``replace`` for some modes).
        dif_type: ``'cf'`` (Chamfer), ``'ec'`` (index-aligned), ``'em'``
            (Sinkhorn plan) or one of the pair-based modes ``'gbp'``, ``'hg'``,
            ``'hg_sp'``, ``'beam2'``, ``'beam3'``, ``'beam4'``.
        dc_pos_loss_coef: weight of the ``fps_tf`` term in the combined cost.
        loss_func: ``f(x, y)`` reducing the last axis; defaults to squared L2.
        fps_only: if True, the ``z_flat`` term is dropped (treated as zero).
        fps_matched_pair: optional precomputed (..., K, 2) matching. When given,
            the pair-based path is used with this matching.

    Returns:
        ``(matched_pair, (chloss, ch_fps, ch_z), obj2_aligned)``.
        ``matched_pair`` and ``obj2_aligned`` are ``None`` for the modes that do
        not produce an explicit matching (``'cf'``, ``'ec'``, ``'em'``).

    Raises:
        ValueError: if ``dif_type`` is not one of the supported modes and no
            ``fps_matched_pair`` is supplied.
    """
    if loss_func is None:
        loss_func = lambda x, y: jnp.sum((x-y)**2, axis=-1)

    pair_based_types = ('gbp', 'hg', 'hg_sp', 'beam2', 'beam3', 'beam4')

    # Defaults for the modes without an explicit matching; previously these
    # names were left unbound on some branches and the return statement raised.
    matched_pair = None
    obj2_aligned = None

    def pairwise_costs():
        # All-pairs costs between the points of obj1 (axis -2) and obj2 (axis -1).
        cen = loss_func(obj1.fps_tf[..., :, None, :], obj2.fps_tf[..., None, :, :])
        if fps_only:
            z = jnp.zeros_like(cen)
        else:
            z = loss_func(obj1.z_flat[..., :, None, :], obj2.z_flat[..., None, :, :])
        return cen, z, z + dc_pos_loss_coef * cen

    if dif_type == 'cf' and fps_matched_pair is None:
        # Chamfer distance
        cen_pwdif, z_pwdif, pwdif = pairwise_costs()
        i_argmin = jnp.argmin(pwdif, axis=-1, keepdims=True)
        j_argmin = jnp.argmin(pwdif, axis=-2, keepdims=True)

        def extract_min_cf(costs):
            # Half the sum of both nearest-neighbour directions; the neighbours
            # are always chosen on the combined cost.
            forward = jnp.take_along_axis(costs, i_argmin, axis=-1).squeeze(-1)
            backward = jnp.take_along_axis(costs, j_argmin, axis=-2).squeeze(-2)
            return 0.5*jnp.sum(forward + backward, axis=-1)

        chloss = extract_min_cf(pwdif)
        ch_fps = extract_min_cf(cen_pwdif)
        ch_z = extract_min_cf(z_pwdif)
    elif dif_type == 'ec' and fps_matched_pair is None:
        # Euclidean distance between index-aligned points
        ch_fps = jnp.sum(loss_func(obj1.fps_tf, obj2.fps_tf), axis=(-1,))
        if fps_only:
            ch_z = 0
        else:
            ch_z = jnp.sum(loss_func(obj1.z_flat, obj2.z_flat), axis=(-1,))
        chloss = dc_pos_loss_coef*ch_fps + ch_z
    elif dif_type == 'em' and fps_matched_pair is None:
        # Earth Mover's distance with a gradient-stopped Sinkhorn plan
        cen_pwdif, z_pwdif, pwdif = pairwise_costs()
        ot_matrix = Sinkhorn_LSE_fori(pwdif, epsilon=0.01, max_iter=4)
        ot_matrix = jax.lax.stop_gradient(ot_matrix)
        chloss = jnp.sum(ot_matrix * pwdif, axis=(-1,-2)) * obj2.nobj
        ch_fps = jnp.sum(ot_matrix * cen_pwdif, axis=(-1,-2))
        ch_z = jnp.sum(ot_matrix * z_pwdif, axis=(-1,-2))
    elif fps_matched_pair is not None or dif_type in pair_based_types:
        # Explicit one-to-one matching (supplied by the caller or computed here)
        cen_pwdif, z_pwdif, pwdif = pairwise_costs()
        if fps_matched_pair is not None:
            matched_pair = fps_matched_pair
        else:
            detached_cost = jax.lax.stop_gradient(pwdif)
            # NOTE(review): without the auction module every pair-based mode
            # (including beam/hg_sp) falls back to greedy matching.
            if dif_type=='gbp' or auction_on==False:
                matched_pair = greedy_bipartite_matching(detached_cost)
            elif dif_type[:-1] == 'beam':
                beamk = int(dif_type[-1])
                matched_pair = beam_search_bipartite_matching(detached_cost, beam_width=beamk)
            elif dif_type == 'hg_sp':
                matched_pair = bipartite_matching_sp(detached_cost)
            elif dif_type == 'hg':
                matched_pair = auction.auction_jax(detached_cost)

        # sort by obj1
        sorted_idx = jnp.argsort(matched_pair[...,0:1], axis=-2)
        matched_pair = jnp.take_along_axis(matched_pair, sorted_idx, axis=-2)

        matched_pair = jax.lax.stop_gradient(matched_pair)
        chloss = extract_min(pwdif, matched_pair)
        ch_fps = extract_min(cen_pwdif, matched_pair)
        ch_z = extract_min(z_pwdif, matched_pair)

        # Reorder obj2's points so that they line up with obj1's order.
        obj2_aligned = obj2.replace(rel_fps=jnp.take_along_axis(obj2.rel_fps, matched_pair[...,1:2], axis=-2),
                                        z=jnp.take_along_axis(obj2.z, matched_pair[...,1:2,None], axis=-3))
    else:
        raise ValueError(f"Invalid dif_type: {dif_type}")

    return matched_pair, (chloss, ch_fps, ch_z), obj2_aligned


def obj_matching(obj_pred:"loutil.LatentObjects", obj_target:"loutil.LatentObjects", 
                 target_obj_valid_mask=None, dif_type='hg', pos_loss_coef=1.0, dc_pos_loss_coef=1.0, base_order='target', fps_only=False, callback=None):
    """Match predicted objects to target objects and gather both in matched order.

    Args:
        obj_pred, obj_target: object sets; ``nobj`` is read from both.
        target_obj_valid_mask: boolean mask over target objects; defaults to
            all True with the shape of ``obj_target.obj_valid_mask``.
        dif_type: matching algorithm (``'gbp'``, ``'hg'``, ``'hg_sp'``,
            ``'beam2'``/``'beam3'``/``'beam4'``). For the ``hg*``/``beam*``
            modes the object-to-object cost itself is computed with ``'cf'``.
        pos_loss_coef, dc_pos_loss_coef, fps_only: forwarded to ``obj_dif``.
        base_order: ``'target'`` or ``'pred'``; which index column the matched
            pairs are sorted by.
        callback: optional function invoked through ``jax.debug.callback`` with
            intermediate values (only when a matching is actually computed).

    Returns:
        ``(x0_pred_selected, x0_target_selected, obj_matching_pair,
        matched_obj_valid_target_mask, obj_cost_selected)``.

    Raises:
        ValueError: for an unknown ``dif_type`` or ``base_order``.
    """
    nobj_pred = obj_pred.nobj
    nobj_target = obj_target.nobj

    if target_obj_valid_mask is None:
        target_obj_valid_mask = jnp.ones_like(obj_target.obj_valid_mask).astype(jnp.bool_)

    if nobj_pred == 1:
        # A single predicted object: nothing to match, pair (0, 0) everywhere.
        x0_pred_selected = obj_pred
        x0_target_selected = obj_target
        obj_matching_pair = jnp.zeros(obj_pred.shape + (2,), dtype=jnp.int32)
        matched_obj_valid_target_mask = jnp.ones_like(target_obj_valid_mask)
        # No cost matrix is built on this path; this name used to be unbound and
        # the return raised. NOTE(review): zeros are a placeholder with the same
        # leading shape as the pairs - confirm whether callers want a real cost.
        obj_cost_selected = jnp.zeros(obj_matching_pair.shape[:-1])
    else:
        obj_cost_matrix, em_loss_log = obj_dif(
            obj_pred.extend_and_repeat_outer_shape(nobj_target, -1),
            obj_target.extend_and_repeat_outer_shape(nobj_pred, -2), pos_loss_coef, dc_pos_loss_coef, 
            loss_func=None,
            fps_only=fps_only,
            dif_type='cf' if dif_type in ['hg_sp', 'hg', 'beam2', 'beam3', 'beam4'] else dif_type,
            )
        # Two-step masking: first push invalid targets to -inf so they cannot
        # win the max below, then give them a cost 5% above the largest valid one.
        obj_cost_matrix = jnp.where(target_obj_valid_mask[..., None, :], obj_cost_matrix, -jnp.inf)
        obj_cost_matrix = jnp.where(target_obj_valid_mask[..., None, :], obj_cost_matrix, 
                                    jax.lax.stop_gradient(jnp.max(obj_cost_matrix, axis=(-1,-2), keepdims=True))*1.05)
        # NOTE(review): without the auction module every mode falls back to greedy.
        if dif_type == 'gbp' or auction_on==False:
            obj_matching_pair = greedy_bipartite_matching(obj_cost_matrix)
        elif dif_type[:-1] == 'beam':
            beamk = int(dif_type[-1])
            obj_matching_pair = beam_search_bipartite_matching(obj_cost_matrix, beam_width=beamk)
        elif dif_type == 'hg':
            # padding to square matrix and use auction jax
            assert len(obj_pred.shape) == 2
            assert nobj_pred >= nobj_target
            obj_matching_pair = auction.auction_jax(obj_cost_matrix)
        elif dif_type == 'hg_sp':
            assert len(obj_pred.shape) == 2
            assert nobj_pred >= nobj_target
            obj_matching_pair = bipartite_matching_sp(obj_cost_matrix)
        else:
            raise ValueError(f"Invalid dif_type: {dif_type}")

        if base_order == 'target':
            order_idx = jnp.argsort(obj_matching_pair[..., 1:2], axis=-2)
        elif base_order == 'pred':
            order_idx = jnp.argsort(obj_matching_pair[..., 0:1], axis=-2)
        else:
            raise ValueError(f"Invalid base_order: {base_order}")
        obj_matching_pair = jnp.take_along_axis(obj_matching_pair, order_idx, axis=-2)

        # Gather objects; pairs padded with -1 are marked invalid.
        matched_obj_valid_target_mask = jnp.take_along_axis(target_obj_valid_mask, obj_matching_pair[..., 1], axis=-1)
        matched_obj_valid_target_mask = jnp.where(jnp.any(obj_matching_pair==-1, axis=-1), False, matched_obj_valid_target_mask)
        x0_pred_selected = obj_pred.take_along_outer_axis(obj_matching_pair[..., 0], axis=obj_pred.ndim-1)
        x0_target_selected = obj_target.take_along_outer_axis(obj_matching_pair[..., 1], axis=obj_pred.ndim-1)

        # selected costs (zeroed for invalid matches)
        obj_cost_selected = jnp.take_along_axis(obj_cost_matrix, obj_matching_pair[..., 0:1], axis=-2)
        obj_cost_selected = jnp.take_along_axis(obj_cost_selected, obj_matching_pair[..., 1:2], axis=-1).squeeze(-1)
        obj_cost_selected = jnp.where(matched_obj_valid_target_mask, obj_cost_selected, 0)

        if callback is not None:
            jdb.callback(callback, (em_loss_log, 
                                    obj_cost_matrix, order_idx,
                                    obj_matching_pair, target_obj_valid_mask, matched_obj_valid_target_mask))

    return x0_pred_selected, x0_target_selected, obj_matching_pair, matched_obj_valid_target_mask, obj_cost_selected



# Demo / manual test script. It (1) compares greedy and beam-search matchings
# on random point sets, (2) optimises the points `a` towards `b` under a
# matching-based loss, and (3) plots several transport plans.
if __name__ == '__main__':
    # import util.Hungarian as hg
    
    jkey = jax.random.PRNGKey(0)
    num_batch = 2
    n = 64
    m = 64
    d = 3
    jkey, subkey = jax.random.split(jkey)
    # Source points: a tight cluster around (1, 1, 1).
    a = jax.random.normal(subkey, (n, d)) * 0.001 + np.array([1, 1,1])
    # a = jnp.ones((n, d))
    jkey, subkey = jax.random.split(jkey)
    # Target points: standard normal samples.
    b = jax.random.normal(subkey, (m, d))

    # wa = jnp.ones(n).at[-2:].set(0)
    # wb = jnp.ones(m).at[-2:].set(0)
    wa = jnp.ones(n)
    wb = jnp.ones(m)

    def cal_cost(cost_mat, matched_pair_indices):
        # Sum cost_mat over the matched pairs, skipping pairs padded with -1.
        cost = 0
        for i in range(matched_pair_indices.shape[-2]):
            iidx = matched_pair_indices[...,i,0]
            jidx = matched_pair_indices[...,i,1]
            if iidx == -1 or jidx == -1:
                continue
            # NOTE(review): this check uses `>`, so an index equal to the axis
            # size is not skipped here.
            if iidx > cost_mat.shape[-2] or jidx > cost_mat.shape[-1]:
                continue
            cost += cost_mat[...,iidx,jidx]
        return cost

        # NOTE(review): unreachable; the function has already returned above.
        return jnp.sum(cost_mat[matched_pair_indices[:,0], matched_pair_indices[:,1]])
    
    # Squared Euclidean distance between every point of `a` and every point of `b`.
    cost_matrix = jnp.sum((a[...,None,:] - b[...,None,:,:])**2, axis=-1)
    # cost_matrix = cost_matrix.at[...,-2:].set(jnp.max(cost_matrix)*2.0)
    # matched_pairs = bipartite_matching_sp(cost_matrix)
    matched_pairs = greedy_bipartite_matching(cost_matrix, sort_by_a=True)
    
    # Beam search with beam widths 1, 2 and 3.
    matched_pairs2 = beam_search_bipartite_matching(cost_matrix, 1, sort_by_a=True)
    matched_pairs3 = beam_search_bipartite_matching(cost_matrix, 2, sort_by_a=True)
    matched_pairs4 = beam_search_bipartite_matching(cost_matrix, 3, sort_by_a=True)

    # Total matching cost of each variant.
    print(cal_cost(cost_matrix, matched_pairs), cal_cost(cost_matrix, matched_pairs2), cal_cost(cost_matrix, matched_pairs3), cal_cost(cost_matrix, matched_pairs4))

    # test optimization
    def loss_func(source, target, jkey, matched_pairs=None):
        # The `jkey` and `matched_pairs` arguments are not used by the active
        # code path below; the commented blocks are alternative losses.

        # # source = source + jax.random.normal(jkey, source.shape)*0.1
        # cost_matrix = jnp.sum((source[...,None,:] - target[...,None,:,:])**2, axis=-1)
        # # cost_matrix = jnp.sum((target[...,None,:] - source[...,None,:,:])**2, axis=-1)
        # op_mat =  Sinkhorn_LSE_fori(cost_matrix, epsilon=1e-4, max_iter=4)
        # # op_mat = (jnp.argmax(op_mat, axis=-1, keepdims=True) == jnp.arange(op_mat.shape[-1])).astype(jnp.float32)
        # op_mat = (jnp.argmax(op_mat, axis=-2, keepdims=True) == jnp.arange(op_mat.shape[-1]))[...,None].astype(jnp.float32)
        # op_mat = jax.lax.stop_gradient(op_mat)
        # return jnp.sum(cost_matrix.min(axis=-1) + cost_matrix.min(axis=-2))

        # geom = pointcloud.PointCloud(source, target, epsilon=1e-6)
        # cost_matrix = geom.cost_matrix
        # ot = linear.solve(geom)
        # op_mat = ot.matrix*100
        # op_mat = jax.lax.stop_gradient(op_mat)
        # return jnp.sum(cost_matrix * op_mat)

        # cf
        # cost_matrix = jnp.sum((source[...,None,:] - target[...,None,:,:])**2, axis=-1)
        # return jnp.sum(cost_matrix.min(axis=-1) + cost_matrix.min(axis=-2))

        # sec
        # nsec = 1000
        # def unique_permutations(nsample, M):
        #     permutations_set = set()
        #     nprng = np.random.default_rng(0)
        #     while len(permutations_set) < nsample:
        #         perm = tuple(nprng.permutation(M))
        #         permutations_set.add(perm)
        #     return np.array(list(permutations_set))
        # permutation_matrix = unique_permutations(nsec, b.shape[0])
        # dist = jnp.sum((source[...,None,:,:] - target[...,permutation_matrix,:])**2, axis=(-1,-2))
        # return jnp.sum(dist.min(axis=-1))


        # ec
        # return jnp.sum((source - target)**2)
    
        # bipartite matching
        cost_matrix = jnp.sum((source[...,None,:] - target[...,None,:,:])**2, axis=-1)
        # matched_pairs = bipartite_matching_sp(cost_matrix)
        # matched_pairs = greedy_bipartite_matching(cost_matrix) # (N, 2)
        # matched_pairs = hg.hungarian_jax_cuda(cost_matrix[None]).squeeze(0)
        # NOTE(review): uses `auction` directly, so this needs the optional
        # util.Auction import at the top of the file to have succeeded.
        matched_pairs = auction.auction_jax(cost_matrix)
        source_matched = source[matched_pairs[:,0]]
        target_matched = target[matched_pairs[:,1]]
        return jnp.sum((source_matched - target_matched)**2)
    
        # hungarian matching



    
    # loss_func_jit = jax.jit(loss_func)
    loss_func_jit = loss_func
    # Gradient is taken with respect to the first argument (`source`).
    grad_fn = jax.jit(jax.grad(loss_func))
    import optax
    opt = optax.adam(1e-2)
    opt_state = opt.init(a)
    jkey = jax.random.PRNGKey(0)
    for i in range(2000):
        jkey, subkey = jax.random.split(jkey)
        grad = grad_fn(a, b, subkey)
        updates, opt_state = opt.update(grad, opt_state)
        a = optax.apply_updates(a, updates)
        if i%10 == 0:
            print(i, loss_func_jit(a, b, subkey))
        # perform random permutation
        # NOTE(review): `a` is shuffled while `opt_state` is not, so the
        # per-row Adam statistics no longer correspond to the same points.
        jkey, subkey = jax.random.split(jkey)
        a = a[jax.random.permutation(subkey, jnp.arange(a.shape[0]))]
    
    # draw two points
    import matplotlib.pyplot as plt
    plt.figure()
    plt.scatter(b[:,0], b[:,1], c='b')
    plt.scatter(a[:,0], a[:,1], c='r')
    plt.show()



    # res1 =  Sinkhorn_LSE(cost_matrix)
    # Same cost matrix, different regularisation strengths / iteration counts.
    res1 =  Sinkhorn_LSE_fori(cost_matrix, wa, wb, epsilon=1e-6, max_iter=1000)
    res2 =  Sinkhorn_LSE_fori(cost_matrix, wa, wb, epsilon=0.1, max_iter=1000)
    res3 =  Sinkhorn_LSE_fori(cost_matrix, wa, wb, epsilon=1e-6, max_iter=3)


    # NOTE(review): `pointcloud` and `linear` are not imported anywhere in the
    # visible file, so this raises NameError as written. They look like modules
    # of an external optimal-transport package - confirm and add the import.
    geom = pointcloud.PointCloud(a, b, cost_fn=None, epsilon=1e-6)
    solve_fn = jax.jit(linear.solve)
    ot = solve_fn(geom, wa, wb)
    res7 = ot.matrix

    import matplotlib.pyplot as plt
    plt.figure()
    plt.subplot(2,2,1)
    plt.imshow(ot.matrix)
    plt.subplot(2,2,2)
    plt.imshow(res1)
    plt.subplot(2,2,3)
    plt.imshow(res2)
    plt.subplot(2,2,4)
    plt.imshow(res3)
    # plt.colorbar()
    plt.title("Optimal Coupling Matrix")
    plt.show()

    print(1)
