# https://github.com/jindongwang/transferlearning/blob/master/code/distance/mmd_numpy_sklearn.py

import numpy as np
from sklearn import metrics


def mmd_rbf(x: np.ndarray, y: np.ndarray, gamma: float = 1.0) -> float:
    """Squared MMD between two samples using an RBF (Gaussian) kernel.

    The kernel is the one computed by ``sklearn.metrics.pairwise.rbf_kernel``:
    k(a, b) = exp(-gamma * ||a - b||^2). Note there is no ``/ 2`` in the
    exponent; to use a Gaussian bandwidth sigma, pass gamma = 1 / (2 * sigma^2).

    The value returned is the biased (V-statistic) estimate of MMD^2:
    mean(K_xx) + mean(K_yy) - 2 * mean(K_xy). The within-sample means are
    taken over the full kernel matrices, so they include the diagonal
    self-similarity terms (k(a, a) = 1). Take the square root if the MMD
    itself (rather than its square) is needed.

    Arguments:
        x {[n_sample1, dim]} -- [x matrix]
        y {[n_sample2, dim]} -- [y matrix]; must have the same ``dim`` as x

    Keyword Arguments:
        gamma {float} -- [kernel parameter] (default: {1.0})

    Returns:
        [scalar] -- [biased estimate of the squared MMD]
    """
    # Pass gamma by keyword so the call does not rely on its position in
    # rbf_kernel(X, Y=None, gamma=None).
    xx = metrics.pairwise.rbf_kernel(x, x, gamma=gamma)
    yy = metrics.pairwise.rbf_kernel(y, y, gamma=gamma)
    xy = metrics.pairwise.rbf_kernel(x, y, gamma=gamma)
    return xx.mean() + yy.mean() - 2 * xy.mean()
