import geomloss
import ot
import ot.utils
from ot.backend import get_backend
import numpy as np
import torch

def emd_samples(x, y, x_w = None, y_w = None):
    """Exact optimal-transport cost between two weighted point clouds.

    The ground cost is the squared Euclidean distance, so the value returned by
    ``ot.emd2`` is the squared 2-Wasserstein distance between the two weighted
    empirical measures.

    Args:
        x: samples of the first measure, one row per point.
        y: samples of the second measure, one row per point.
        x_w: optional non-negative weights for ``x``; normalised to sum to 1.
            Uniform weights are used when omitted.
        y_w: optional non-negative weights for ``y``; normalised likewise.

    Returns:
        The OT cost computed by ``ot.emd2`` (backend scalar).
    """
    backend = get_backend(x, y)

    def _normalise(weights, n_points):
        # Uniform mass when no weights are given, otherwise rescale to a probability vector.
        if weights is None:
            return backend.full((n_points, ), 1/n_points)
        return weights / weights.sum()

    cost = ot.utils.euclidean_distances(x, y, squared=True)
    return ot.emd2(_normalise(x_w, x.shape[0]), _normalise(y_w, y.shape[0]), cost)

def sinkhorn_divergence(x, y, x_w = None, y_w = None, reg = 1.0):
    """Sinkhorn divergence between two weighted point clouds, computed with geomloss.

    Uses ``geomloss.SamplesLoss(loss='sinkhorn')`` with geomloss's default settings.
    An alternative POT-based implementation is
    ``ot.bregman.empirical_sinkhorn_divergence``.

    Args:
        x: torch tensor of samples of the first measure, one row per point
            (assumed (N, D), the point-cloud layout geomloss expects).
        y: torch tensor of samples of the second measure, one row per point.
        x_w: optional non-negative weights for ``x``; normalised to sum to 1.
            Uniform weights are used when omitted.
        y_w: optional non-negative weights for ``y``; normalised likewise.
        reg: NOTE(review): currently unused. The geomloss loss below runs with
            its own default blur, so this value has no effect. It is kept only
            for signature compatibility; wire it to ``blur`` if intended.

    Returns:
        The scalar tensor returned by the geomloss Sinkhorn loss.
    """
    # Build the default uniform weights with the samples' dtype and device.
    # A bare torch.full would create default-dtype CPU tensors, which clash
    # with float64 or GPU point clouds inside geomloss.
    p = (torch.full((x.shape[0], ), 1/x.shape[0], dtype=x.dtype, device=x.device)
         if x_w is None else x_w / x_w.sum())
    q = (torch.full((y.shape[0], ), 1/y.shape[0], dtype=y.dtype, device=y.device)
         if y_w is None else y_w / y_w.sum())
    loss = geomloss.SamplesLoss(loss = 'sinkhorn')
    # geomloss weighted call convention: loss(weights_a, points_a, weights_b, points_b)
    return loss(p, x, q, y)

def energy_distance(x, y, x_w = None, y_w = None):
    """Weighted energy distance between two point clouds.

    Computes ``2 E|X - Y| - E|X - X'| - E|Y - Y'|`` where the expectations are
    taken under the (normalised) sample weights and ``|.|`` is the Euclidean
    distance.

    Args:
        x: samples of the first measure, one row per point.
        y: samples of the second measure, one row per point.
        x_w: optional non-negative weights for ``x``; normalised to sum to 1.
            Uniform weights are used when omitted.
        y_w: optional non-negative weights for ``y``; normalised likewise.

    Returns:
        The energy distance as a backend scalar.
    """
    nx = get_backend(x, y)
    dxy = ot.utils.euclidean_distances(x, y, squared=False)
    dxx = ot.utils.euclidean_distances(x, x, squared=False)
    dyy = ot.utils.euclidean_distances(y, y, squared=False)
    # type_as gives the default weights the dtype/device of the distance matrices.
    # Without it the torch backend yields float32 CPU weights, and the products
    # below then fail for float64 or GPU inputs. The (floating) distance matrix
    # is used rather than x so integer samples still get fractional weights.
    x_w = nx.full((x.shape[0], ), 1/x.shape[0], type_as=dxy) if x_w is None else x_w / x_w.sum()
    y_w = nx.full((y.shape[0], ), 1/y.shape[0], type_as=dxy) if y_w is None else y_w / y_w.sum()
    xy = nx.dot(x_w, dxy @ y_w)
    xx = nx.dot(x_w, dxx @ x_w)
    yy = nx.dot(y_w, dyy @ y_w)
    return 2*xy-xx-yy

def energy_distance_paths(x, y):
    """Energy distance between two ensembles of paths/trajectories.

    Every item along the leading axis is flattened into a single vector, and
    the resulting point clouds are compared with uniform weights.
    """
    x_flat = x.reshape(x.shape[0], -1)
    y_flat = y.reshape(y.shape[0], -1)
    return energy_distance(x_flat, y_flat)

def emd_paths(x, y):
    """Exact OT cost (squared Euclidean ground cost) between two ensembles of paths.

    Every item along the leading axis is flattened into a single vector, and
    the resulting point clouds are compared with uniform weights.
    """
    x_flat = x.reshape(x.shape[0], -1)
    y_flat = y.reshape(y.shape[0], -1)
    return emd_samples(x_flat, y_flat)

def get_centroid_probs(_x0, get_xt, centroids, n_sample = 10):
    """Estimate, per starting point, how often its propagated draws land nearest each centroid.

    Each row of ``_x0`` is repeated ``n_sample`` times and pushed through
    ``get_xt``. Every resulting point is then assigned to its nearest centroid
    (Euclidean distance, ties broken towards the lower centroid index).

    Args:
        _x0: tensor of starting points, one per row (assumed (N, D) - TODO confirm).
        get_xt: callable applied to the (N * n_sample, D) repeated tensor. It is
            assumed to return the same number of rows in the same order, so
            that row j of its output comes from row j of its input.
        centroids: tensor of centroids, shape (K, D).
        n_sample: number of draws per starting point.

    Returns:
        Tensor of shape (N, K). Row i holds the fraction of the ``n_sample``
        draws from ``_x0[i]`` whose nearest centroid is each column.
    """
    n_points = _x0.shape[0]
    n_centroids = centroids.shape[0]
    x0 = _x0.repeat_interleave(n_sample, 0)
    xs_t = get_xt(x0)
    # (K, N * n_sample) distance table -> index of the nearest centroid per draw.
    nearest_centroid = (xs_t.unsqueeze(0) - centroids[:, None, :]).norm(dim = 2).argmin(0)
    # repeat_interleave keeps the n_sample draws of each point contiguous, so a
    # reshape to (N, n_sample) groups them directly. This replaces a per-point
    # boolean mask over all draws, which costs O(N^2 * n_sample).
    onehot = torch.nn.functional.one_hot(nearest_centroid.reshape(n_points, n_sample), n_centroids)
    return (1.0 * onehot).mean(1)
