
import torch as th
from scipy.optimize import linear_sum_assignment
def bipartite_matching_mse(tensor1, tensor2):
    """
    Compute an order-invariant MSE loss via optimal bipartite matching.

    For every batch element, build the [N, M] matrix of pairwise MSEs between
    the rows of ``tensor1`` and ``tensor2`` (mean over the last axis). Then find
    the minimum-cost one-to-one assignment with the Hungarian algorithm
    (``scipy.optimize.linear_sum_assignment``) and average the matched costs.
    When N != M, only ``min(N, M)`` pairs are matched and the surplus rows are
    ignored.

    :param tensor1: Tensor of shape [batch, N, C]
    :param tensor2: Tensor of shape [batch, M, C] (N and M can be different).
        The batch and channel axes follow PyTorch broadcasting rules, so e.g.
        a batch of 1 is matched against every element of the other batch.
    :return: 0-dim tensor holding the mean, over batch elements, of each
        element's mean matched MSE. The assignment is computed on detached
        values, so gradients flow only through the selected entries. If there
        is nothing to match (batch, N or M is 0), a zero loss that stays
        attached to the autograd graph is returned instead of NaN.
    """
    # [B, N, 1, C] - [B, 1, M, C] -> [B, N, M, C] -> mean over C -> [B, N, M]
    mse_matrix = th.mean(
        (tensor1[:, :, None, :] - tensor2[:, None, :, :]) ** 2, dim=-1
    )

    if mse_matrix.numel() == 0:
        # Without this guard, an empty selection would yield NaN from .mean(),
        # and an empty batch would make th.stack fail on an empty list.
        return mse_matrix.sum()

    # The assignment is solved in NumPy, outside autograd. Move to CPU first
    # (some accelerators lack float64), then upcast: NumPy cannot represent
    # bfloat16, and float64 holds every half/float32 value exactly.
    cost = mse_matrix.detach().cpu().double().numpy()

    # Iterate over the broadcast batch axis of the cost matrix rather than
    # tensor1's, so a broadcast batch of 1 is not silently truncated.
    per_batch_losses = []
    for b, cost_b in enumerate(cost):
        # Hungarian algorithm: rows/cols index the optimal matched pairs.
        rows, cols = linear_sum_assignment(cost_b)
        # Index the live tensor (not the NumPy copy) so the loss stays
        # differentiable.
        per_batch_losses.append(mse_matrix[b, rows, cols].mean())

    return th.stack(per_batch_losses).mean()

