import torch


def wasserstein_distance_vec(u_values, v_values, u_weights=None, v_weights=None):
    """Batched counterpart of ``scipy.stats.wasserstein_distance``.

    Thin wrapper that forwards every argument unchanged to
    ``_cdf_distance_vec`` with the order ``p`` fixed to 1.

    Args:
        u_values: Sample values of the first distribution(s).
        v_values: Sample values of the second distribution(s).
        u_weights: Optional weights for ``u_values``; ``None`` by default.
        v_weights: Optional weights for ``v_values``; ``None`` by default.

    Returns:
        Whatever ``_cdf_distance_vec`` returns for ``p=1``.
    """
    order = 1
    return _cdf_distance_vec(
        order,
        u_values,
        v_values,
        u_weights=u_weights,
        v_weights=v_weights,
    )


def _cdf_distance_vec(p, u_values, v_values, u_weights=None, v_weights=None):
    """Batched implementation of scipy.stats._cdf_distance.

    Computes ``(integral |U(x) - V(x)|**p dx) ** (1 / p)`` where ``U`` and ``V``
    are the empirical CDFs of the two (optionally weighted) samples. Every
    operation acts on the last dimension, so any leading dimensions are
    treated as independent batch entries.

    Args:
        p: Positive order of the distance (1 gives the Wasserstein distance).
        u_values: Tensor whose last dimension holds the samples of the first
            distribution.
        v_values: Tensor whose last dimension holds the samples of the second
            distribution; leading dimensions must match ``u_values``.
        u_weights: Optional weights, broadcastable to the shape of
            ``u_values``. ``None`` means every sample has equal weight.
        v_weights: Optional weights, broadcastable to the shape of
            ``v_values``. ``None`` means every sample has equal weight.

    Returns:
        Tensor of distances with the last dimension reduced away.
    """
    u_sorter = torch.argsort(u_values, dim=-1)
    v_sorter = torch.argsort(v_values, dim=-1)

    all_values = torch.cat([u_values, v_values], dim=-1)
    all_values, _ = torch.sort(all_values, dim=-1)

    # Widths of the intervals between consecutive points of the pooled sample.
    deltas = torch.diff(all_values, dim=-1)

    # The CDFs are evaluated at the left edge of every interval. The slice is
    # made contiguous because searchsorted otherwise warns and copies.
    left_edges = all_values[..., :-1].contiguous()
    u_cdf_indices = torch.searchsorted(u_values, left_edges, side="right", sorter=u_sorter)
    v_cdf_indices = torch.searchsorted(v_values, left_edges, side="right", sorter=v_sorter)

    if u_weights is None:
        u_cdf = u_cdf_indices / u_values.size(-1)
    else:
        u_cdf = _weighted_cdf(u_weights, u_sorter, u_cdf_indices)

    if v_weights is None:
        v_cdf = v_cdf_indices / v_values.size(-1)
    else:
        v_cdf = _weighted_cdf(v_weights, v_sorter, v_cdf_indices)

    # Compute the value of the integral based on the CDFs.
    # p = 1 and p = 2 are special-cased to avoid the generic power operation.
    cdf_gap = torch.abs(u_cdf - v_cdf)
    if p == 1:
        return torch.sum(cdf_gap * deltas, dim=-1)
    if p == 2:
        return torch.sqrt(torch.sum(torch.square(cdf_gap) * deltas, dim=-1))
    return torch.sum(cdf_gap ** p * deltas, dim=-1) ** (1 / p)


def _weighted_cdf(weights, sorter, cdf_indices):
    """Evaluate a weighted empirical CDF at the given insertion indices.

    Args:
        weights: Sample weights, broadcastable to the shape of ``sorter``.
        sorter: Indices that sort the samples along the last dimension.
        cdf_indices: For each query point, the number of samples that are
            less than or equal to it (int64, indexed along the last dimension).

    Returns:
        Tensor shaped like ``cdf_indices`` with the normalized cumulative
        weight at each query point.
    """
    weights = torch.broadcast_to(weights, sorter.shape)
    # Reorder the weights per batch entry so they follow the sorted samples.
    sorted_weights = torch.gather(weights, -1, sorter)
    cumweights = torch.cumsum(sorted_weights, dim=-1)
    # Prepend a zero so that index k maps to the total weight of the first k samples.
    leading_zero = torch.zeros_like(cumweights[..., :1])
    cumweights = torch.cat([leading_zero, cumweights], dim=-1)
    return torch.gather(cumweights, -1, cdf_indices) / cumweights[..., -1:]
