import ot
import torch


def get_quantile_means_torch(Points,
                             prob=None,
                             n_quantile: int = 20,
                             sigma: float = 10,
                             tor = 0.1):
    """
    Get Quantile Distance Means (Torch version).

    Each point's Euclidean distance to the (prob-weighted) centroid is
    normalised by the prob-weighted mean distance. For ``n_quantile`` levels
    placed at the midpoints of equal bins, squeezed into ``[tor, 1 - tor]``,
    the corresponding quantile of those normalised distances is taken. A
    Gaussian kernel ``exp(-sigma * (d - d_q) ** 2)``, multiplied by ``prob``,
    then softly selects the points whose normalised distance is close to each
    quantile. The kernel-weighted average of those points is the quantile mean.

    Args:
        Points: tensor of shape ``(n, p)``.
        prob: optional per-point weights of shape ``(n,)`` with the same
            dtype/device as ``Points``. ``None`` means uniform weights
            ``1 / n``. Presumably non-negative and summing to 1 — TODO
            confirm with callers.
        n_quantile: number of quantile levels (rows of the result).
        sigma: kernel sharpness; larger values concentrate each quantile mean
            on fewer points.
        tor: trimming fraction; all levels lie in ``[tor, 1 - tor]``.

    Returns:
        Tensor of shape ``(n_quantile, p)``. Row ``k`` is the kernel-weighted
        mean of ``Points`` for the ``k``-th quantile level.
    """
    n = Points.shape[0]
    if prob is None:
        # Uniform weights reproduce the plain mean / unweighted kernel average.
        prob = torch.full((n,), 1.0 / n, dtype=Points.dtype, device=Points.device)

    # Midpoints of n_quantile equal bins, mapped linearly into [tor, 1 - tor].
    levels = torch.arange(1, n_quantile + 1, dtype=Points.dtype, device=Points.device)
    quantile = (1 - 2 * tor) * ((levels - 0.5) / n_quantile) + tor

    P_mean = Points.T @ prob
    P_norm = torch.linalg.norm(Points - P_mean, dim=1)
    # Scale-free distances; the epsilon avoids dividing by zero when all
    # points coincide with the centroid.
    P_norm = P_norm / ((P_norm * prob).sum() + 1e-6)
    # NOTE(review): torch.quantile is unweighted, so ``prob`` does not affect
    # the quantile values themselves — confirm this is intended.
    P_norm_q = torch.quantile(P_norm, quantile)

    sq_dist = (P_norm[:, None] - P_norm_q[None, :]) ** 2
    # Subtracting each column's minimum cancels in the per-column
    # normalisation below, but keeps the largest kernel value at exp(0) = 1.
    # Without it, a large ``sigma`` or a wide gap can underflow every weight
    # in a column to 0 and produce 0 / 0 = NaN. The shift is detached because
    # the normalised result does not depend on it.
    sq_dist = sq_dist - sq_dist.min(dim=0, keepdim=True).values.detach()
    weight = torch.exp(-sigma * sq_dist) * prob[:, None]
    weight = weight / weight.sum(dim=0)
    return weight.T @ Points


def _standardize_columns(Z):
    """Centre each column of ``Z`` and scale it to unit population std.

    Columns whose std is exactly 0 (constant features) are only centred,
    because dividing them by 0 would give 0 / 0 = NaN.
    """
    std = Z.std(dim=0, unbiased=False)
    std = torch.where(std > 0, std, torch.ones_like(std))
    return (Z - Z.mean(dim=0)) / std


def QDOT_torch(X,
               Y,
               a=None,
               b=None,
               n_quantile = 50,
               tor = 0.1,
               sigma = 10,
               initial = True,
               intergal = True,
               EMD = True,
               Sink_reg = 0.01):
    """Torch implementation of the 2-QDOT / IQDOT distance.

    Each sample set is summarised by its distances to ``n_quantile`` quantile
    means (see ``get_quantile_means_torch``). The resulting profiles of
    shape ``(n, n_quantile)`` and ``(m, n_quantile)`` are then compared with
    optimal transport. Only distances within each set are compared, so the
    feature dimensions of ``X`` and ``Y`` need not match.

    Args:
        X: tensor of shape ``(n, p)``.
        Y: tensor of shape ``(m, q)``.
        a, b: optional sample weights of shape ``(n,)`` / ``(m,)``; ``None``
            means uniform.
        n_quantile, tor, sigma: forwarded to ``get_quantile_means_torch``.
        initial: if True, standardise each column of X and Y first. The
            standardisation is unweighted even when ``a``/``b`` are given.
        intergal: if True, return the mean over quantile columns of the 1-D
            OT loss from ``ot.wasserstein_1d``; otherwise solve a full OT
            problem on the squared Euclidean cost between profiles.
        EMD: in the non-integral branch, use exact ``ot.emd`` (True) or
            entropic ``ot.sinkhorn`` (False).
        Sink_reg: Sinkhorn regularisation, applied to the max-normalised cost.

    Returns:
        ``intergal=True``: scalar tensor loss.
        ``intergal=False``: tuple ``(loss, P)`` with ``P`` the ``(n, m)``
        transport plan. ``loss`` is computed with the unnormalised cost.
    """
    n = X.shape[0]
    m = Y.shape[0]

    if initial:
        Xn = _standardize_columns(X)
        Yn = _standardize_columns(Y)
    else:
        Xn, Yn = X, Y

    X_Q = get_quantile_means_torch(Xn, a, n_quantile, sigma, tor)
    Y_Q = get_quantile_means_torch(Yn, b, n_quantile, sigma, tor)

    # Euclidean distance from every point to every quantile mean:
    # (n, n_quantile) and (m, n_quantile). The exact (non-matmul) mode gives
    # the same values as the explicit broadcast without an (n, p, k) temporary.
    X_QD = torch.cdist(Xn, X_Q, p=2, compute_mode="donot_use_mm_for_euclid_dist")
    Y_QD = torch.cdist(Yn, Y_Q, p=2, compute_mode="donot_use_mm_for_euclid_dist")

    if a is None:
        a = torch.ones(n, dtype=X.dtype, device=X.device) / n
    if b is None:
        b = torch.ones(m, dtype=Y.dtype, device=Y.device) / m

    if intergal:
        # NOTE(review): ot.wasserstein_1d defaults to p=1 (W1), whereas the
        # non-integral branch uses a squared (p=2) cost — confirm intended.
        per_quantile = [
            ot.wasserstein_1d(X_QD[:, i], Y_QD[:, i], a, b)
            for i in range(n_quantile)
        ]
        return torch.stack(per_quantile).mean()

    C = torch.cdist(X_QD, Y_QD, p=2) ** 2
    # Normalise for solver conditioning (keeps Sink_reg scale-free). An
    # all-zero cost (identical profiles) would otherwise become 0 / 0 = NaN.
    c_max = C.max()
    C1 = C / c_max if c_max > 0 else C
    if EMD:
        P = ot.emd(a, b, C1)
    else:
        P = ot.sinkhorn(a, b, C1, reg=Sink_reg)
    loss = (C * P).sum()
    return loss, P
        