import ot
import numpy as np

def w2_weighted(X, Y, b, *, squared=True):
    """Exact OT cost between uniformly weighted ``X`` and ``b``-weighted ``Y``.

    The ground cost is the squared Euclidean distance between samples.
    By default the returned value is the *squared* 2-Wasserstein distance
    W2^2, not W2. Pass ``squared=False`` to get W2.

    Parameters
    ----------
    X : array_like, shape (n,) or (n, d)
        Source samples; each one carries mass ``1 / n``. 1-D input is
        treated as ``n`` one-dimensional points.
    Y : array_like, shape (m,) or (m, d)
        Target samples. 1-D input is treated as ``m`` one-dimensional points.
    b : array_like, shape (m,) or (m, k)
        Non-negative weights of the target samples (one column per target
        histogram when 2-D). Each column is rescaled to sum to 1 so that it
        carries the same total mass as the uniform source weights.
    squared : bool, keyword-only, default True
        If True (the original behaviour), return the transport cost W2^2.
        If False, return its square root W2.

    Returns
    -------
    The value returned by ``ot.emd2`` (W2^2), or its square root when
    ``squared=False``.

    Raises
    ------
    ValueError
        If ``X`` or ``Y`` has no samples, ``b`` does not have one row per
        sample of ``Y``, or ``b`` has negative entries or a non-positive
        (or NaN) column sum.
    """
    X = np.asarray(X)
    Y = np.asarray(Y)
    b = np.asarray(b, dtype=float)

    # ot.dist works on 2-D (samples, features) arrays; promote 1-D input
    # to a column of scalar points instead of letting it crash downstream.
    if X.ndim == 1:
        X = X[:, None]
    if Y.ndim == 1:
        Y = Y[:, None]

    n, m = X.shape[0], Y.shape[0]
    if n == 0 or m == 0:
        raise ValueError("X and Y must each contain at least one sample")
    if b.ndim == 0 or b.shape[0] != m:
        raise ValueError(
            f"b must have one weight per sample of Y ({m}), got shape {b.shape}"
        )
    if np.any(b < 0):
        raise ValueError("b must not contain negative weights")
    totals = b.sum(axis=0)
    # `not (totals > 0)` also rejects NaN sums.
    if not np.all(totals > 0):
        raise ValueError("each column of b must have a positive sum")
    # emd2 requires both marginals to carry equal mass; `a` below sums to 1.
    b = b / totals

    # Explicit metric: squared Euclidean ground cost => emd2 yields W2^2.
    M = ot.dist(X, Y, metric="sqeuclidean")
    a = np.full(n, 1.0 / n)
    cost = ot.emd2(a, b, M)
    return cost if squared else np.sqrt(cost)