import numpy as np
import seaborn as sns
import pandas as pd


def plot_samples_joint(sample_dict, save_path=None, **kwargs):
    """
    Plot multiple sets of 2D samples using seaborn's jointplot.

    Each entry of ``sample_dict`` becomes one hue group on a shared joint
    plot. Any legend present on the joint axes is removed after plotting.

    Args:
        sample_dict: dict, keys are labels and values are array-likes of shape
            (n_samples, 2). Only the first two columns are used. Entries with
            no samples contribute no points.
        save_path: str, path to save the figure (saved with
            ``bbox_inches='tight'`` and ``dpi=150``). If None, nothing is
            written. The figure is never shown or closed here; that is left
            to the caller.
        **kwargs: forwarded to ``seaborn.jointplot``. ``alpha`` defaults to
            0.5 and may be overridden.

    Returns:
        The object returned by ``seaborn.jointplot``.
    """
    frames = []
    for label, samples in sample_dict.items():
        arr = np.asarray(samples)
        if arr.size == 0:
            # An empty list has shape (0,) and cannot be column-indexed;
            # the original per-point loop simply added nothing for it.
            continue
        frames.append(pd.DataFrame({
            'x': arr[:, 0],
            'y': arr[:, 1],
            # Explicit list so tuple labels stay scalar cells.
            'label': [label] * len(arr),
        }))

    if frames:
        # Concatenate in dict order so the hue order matches insertion order.
        df = pd.concat(frames, ignore_index=True)
    else:
        df = pd.DataFrame(columns=['x', 'y', 'label'])

    # Default alpha without clashing with a caller-supplied alpha
    # (passing alpha=0.5 and **kwargs together raised TypeError).
    kwargs.setdefault('alpha', 0.5)
    plot = sns.jointplot(data=df, x='x', y='y', hue='label', **kwargs)

    # Drop the legend from the joint axes, if one was drawn.
    legend = plot.ax_joint.get_legend()
    if legend:
        legend.remove()

    if save_path is not None:
        plot.savefig(save_path, bbox_inches='tight', dpi=150)
    return plot