import torch

def quantile(samples):
    """Return the empirical quantile (inverse CDF) function of ``samples``.

    The returned function maps a probability ``p`` to the sample value at
    sorted position ``floor(p * n)``, where ``n = len(samples)``. The result
    is constant on each interval ``[k/n, (k+1)/n)``. Indices are clamped to
    ``[0, n - 1]``, so ``p <= 0`` gives the minimum and ``p >= 1`` gives the
    maximum.

    Args:
        samples: 1-D tensor of samples. It is sorted once, up front.

    Returns:
        ``quantile_func(p)``, which accepts a float or a tensor of
        probabilities and returns the matching sample values with the same
        shape as ``p``.

    Raises:
        ValueError: if ``samples`` is empty.
    """
    samples_sorted, _ = samples.sort()
    n = len(samples_sorted)
    if n == 0:
        raise ValueError("quantile() requires at least one sample")

    def quantile_func(p):
        # Accept plain Python floats as well as tensors.
        p = torch.as_tensor(p, device=samples_sorted.device)
        # p == 1, or float round-up just below 1, would give index n. Clamp it,
        # and clamp negative p too, which would otherwise wrap to the end.
        idx = torch.floor(p * n).long().clamp(0, n - 1)
        return samples_sorted[idx]

    return quantile_func

def W22(u_samples, v_samples):
    """Return the squared 2-Wasserstein distance between two empirical samples.

    Each sample is treated as a uniform empirical distribution on the real
    line. The value is computed exactly as the integral over ``p`` in
    ``[0, 1]`` of ``(F_u^{-1}(p) - F_v^{-1}(p)) ** 2``. Both inverse CDFs are
    piecewise constant, so the integral reduces to a finite sum over the
    union of their breakpoints. Sizes may differ, and the result is
    differentiable with respect to both inputs.

    Args:
        u_samples: tensor of samples. It is flattened to 1-D.
        v_samples: tensor of samples. It is flattened to 1-D.

    Returns:
        A 0-d tensor holding ``W_2^2(u, v)``.

    Raises:
        ValueError: if either sample is empty.
    """
    # Adapted from https://github.com/nklb/wasserstein-distance
    u_sorted, _ = u_samples.reshape(-1).sort()
    v_sorted, _ = v_samples.reshape(-1).sort()
    n, m = len(u_sorted), len(v_sorted)
    if n == 0 or m == 0:
        raise ValueError("W22() requires non-empty u_samples and v_samples")

    device = u_sorted.device
    # An n-point empirical quantile is constant on [k/n, (k+1)/n). Its
    # breakpoints are k/n for k = 0..n, which is n + 1 points, not n.
    grids = torch.unique(
        torch.cat(
            (
                torch.linspace(0, 1, steps=n + 1, device=device),
                torch.linspace(0, 1, steps=m + 1, device=device),
            )
        ),
        sorted=True,
    )
    # Evaluate at interval midpoints. At a breakpoint k/n, p * n can round
    # just below k and select the previous sample.
    mids = (grids[:-1] + grids[1:]) / 2
    u_icdf = u_sorted[torch.floor(mids * n).long().clamp(max=n - 1)]
    v_icdf = v_sorted[torch.floor(mids * m).long().clamp(max=m - 1)]
    return torch.sum((u_icdf - v_icdf) ** 2 * torch.diff(grids))

if __name__ == "__main__":
    # Demo: take one Adam step on u_samples to reduce W_2^2(u, v).
    # Seed the RNG so the printed numbers are reproducible.
    torch.manual_seed(0)

    # Two samples from N(0, 1). Only u_samples is a trainable leaf.
    u_samples = torch.randn(1000).requires_grad_(True)
    v_samples = torch.randn(1000)

    # Build the optimizer only after the tensor it updates exists.
    optimizer = torch.optim.Adam([u_samples], lr=0.1)

    # Compute the distance once and reuse it for the backward pass.
    optimizer.zero_grad()
    loss = W22(u_samples, v_samples)
    print(f"W2^2 before step: {loss.item():.6f}")
    loss.backward()
    print(f"|dW2^2/du|: {u_samples.grad.norm().item():.6f}")
    optimizer.step()

    # Only the value is needed after the step, so skip graph construction.
    with torch.no_grad():
        print(f"W2^2 after step:  {W22(u_samples, v_samples).item():.6f}")