import numpy as np

def gaussian_tv(x, y, sigma=1.0):
    """Gaussian kernel over half the L1 distance between two 1-D vectors.

    The distance used is ``sum(|x - y|) / 2``, which is the total-variation
    distance when ``x`` and ``y`` are normalized histograms.

    Args:
        x: 1-D array-like of numbers (NumPy array, list, or tuple).
        y: 1-D array-like of numbers; may differ in length from ``x``.
        sigma: Bandwidth of the Gaussian kernel.

    Returns:
        ``exp(-dist**2 / (2 * sigma**2))`` as a NumPy float.
    """
    # np.asarray accepts plain sequences as well as arrays, whereas calling
    # ``.astype`` directly would fail on lists and tuples.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    # Zero-pad the shorter vector so both cover the same support.
    support_size = max(len(x), len(y))
    if len(x) < support_size:
        x = np.hstack((x, np.zeros(support_size - len(x))))
    elif len(y) < support_size:
        y = np.hstack((y, np.zeros(support_size - len(y))))

    dist = np.abs(x - y).sum() / 2.0
    return np.exp(-dist * dist / (2 * sigma * sigma))


def compute_mmd(samples1, samples2, kernel, is_hist=True, *args, **kwargs):
    """Estimate the MMD between two sets of samples.

    Computes ``disc(s1, s1) + disc(s2, s2) - 2 * disc(s1, s2)``, where
    ``disc`` averages ``kernel`` over all pairs, and returns its absolute
    value.

    Args:
        samples1: Sequence of samples from the first distribution.
        samples2: Sequence of samples from the second distribution.
        kernel: Callable ``kernel(a, b, *args, **kwargs)`` returning a number.
        is_hist: When true, each sample is divided by its own sum (plus a
            small epsilon) before the kernel is applied.
        *args: Extra positional arguments forwarded to ``kernel``.
        **kwargs: Extra keyword arguments forwarded to ``kernel``.

    Returns:
        The non-negative MMD estimate.
    """
    if is_hist:
        # The 1e-6 keeps the division finite for all-zero histograms.
        samples1 = [s1 / (np.sum(s1) + 1e-6) for s1 in samples1]
        samples2 = [s2 / (np.sum(s2) + 1e-6) for s2 in samples2]

    mmd = (
        disc(samples1, samples1, kernel, *args, **kwargs)
        + disc(samples2, samples2, kernel, *args, **kwargs)
        - 2 * disc(samples1, samples2, kernel, *args, **kwargs)
    )

    # Taking the absolute value keeps the result non-negative, so no
    # further sign check is needed afterwards.
    return np.abs(mmd)

def l2(x, y):
    """Return the order-2 norm of ``x - y`` as computed by ``np.linalg.norm``."""
    return np.linalg.norm(x - y, 2)

def disc(samples1, samples2, kernel, *args, **kwargs):
    """Average ``kernel`` over every pair drawn from the two sample sets.

    Each ``(a, b)`` with ``a`` from ``samples1`` and ``b`` from ``samples2``
    is passed to ``kernel(a, b, *args, **kwargs)`` and the results are
    averaged. If either set is empty, the sentinel value ``1e6`` is returned
    instead.
    """
    pair_count = len(samples1) * len(samples2)
    if pair_count <= 0:
        # No pairs to average over: return the sentinel.
        return 1e6

    total = 0
    for a in samples1:
        for b in samples2:
            total += kernel(a, b, *args, **kwargs)

    total /= pair_count
    return total