import numpy as np
import ot
import scipy as sp

def get_quantile_means(Points, 
                       prob = None,
                       n_quantile = 50, 
                       sigma = 10, 
                       tor = 0.1, 
                       initial = True):
    """
    Get Quantile Distance Means.

    Each row of ``Points`` is reduced to a scalar radius: its distance to the
    (weighted) mean of ``Points`` when ``initial`` is True, otherwise its
    distance to the origin. The radii are then rescaled to have (weighted)
    mean 1. For each of ``n_quantile`` quantile levels spread evenly over
    ``[tor, 1 - tor]``, the function returns a Gaussian-kernel-weighted
    average of the rows. Rows whose radius is close to that quantile of the
    radii get the most weight.

    Parameters
    ----------
    Points : array-like, shape (n, d)
    prob : array-like, shape (n,), optional
        Row weights; ``None`` means uniform. When given, they weight the
        centre, the radius rescaling and the kernel weights.
    n_quantile : int
        Number of quantile levels (rows of the result).
    sigma : float
        Kernel sharpness: weight = exp(-sigma * (radius - quantile)**2).
    tor : float
        Trimming fraction at each tail of the quantile levels.
    initial : bool
        Measure radii from the mean (True) or from the origin (False).

    Returns
    -------
    ndarray, shape (n_quantile, d)
    """
    Points = np.asarray(Points)
    if prob is not None:
        prob = np.asarray(prob, dtype=float)

    # Midpoint quantile levels (k - 0.5) / n_quantile, squeezed into [tor, 1 - tor].
    quantile = (1 - 2 * tor) * ((np.arange(1, n_quantile + 1) - 0.5) / n_quantile) + tor

    if initial:
        center = Points.mean(axis=0) if prob is None else Points.T @ prob
        P_norm = np.linalg.norm(Points - center, axis=1)
    else:
        P_norm = np.linalg.norm(Points, axis=1)

    # Rescale radii to (weighted) mean 1 so that sigma is independent of the
    # data scale. All-zero radii (e.g. identical points) are left unchanged
    # instead of becoming 0/0 = NaN.
    norm_scale = P_norm.mean() if prob is None else (P_norm * prob).sum()
    if norm_scale > 0:
        P_norm = P_norm / norm_scale

    # NOTE(review): the quantiles are unweighted even when prob is given -- confirm intended.
    P_norm_q = np.quantile(P_norm, quantile)

    sq_gap = (P_norm[:, np.newaxis] - P_norm_q[np.newaxis, :]) ** 2
    # Subtracting each column's minimum cancels in the normalization below.
    # It keeps at least one kernel value per column at exp(0) = 1, so a large
    # sigma cannot underflow an entire column to 0 and produce 0/0 = NaN.
    weight = np.exp(-sigma * (sq_gap - sq_gap.min(axis=0)))
    if prob is not None:
        weight = weight * prob[:, np.newaxis]
    weight = weight / weight.sum(axis=0)
    return weight.T @ Points

def QDOT(X, 
         Y, 
         a = None,
         b = None,
         n_quantile = 50, 
         tor = 0.1,
         sigma = 10, 
         initial = True, 
         scale = True, 
         intergal = False,
         EMD = True, 
         Sink_reg = 0.01):
    """
    Numerical implementation of the 2-QDOT / IQDOT distance.

    Each sample is described by its Euclidean distances to the
    ``n_quantile`` quantile means of its own point cloud (from
    ``get_quantile_means``). The two clouds are compared only through these
    descriptors, so ``X`` and ``Y`` may have different numbers of columns.

    Parameters
    ----------
    X : array-like, shape (n, p)
    Y : array-like, shape (m, q)
    a, b : array-like of shape (n,) / (m,), optional
        Sample weights of ``X`` / ``Y``; ``None`` means uniform.
    n_quantile, tor, sigma :
        Passed to ``get_quantile_means``.
    initial : bool
        If True, center each cloud and divide it by its overall standard
        deviation before computing descriptors.
    scale : bool
        If True, divide each descriptor matrix by its largest entry.
    intergal : bool
        (sic) If True, compute IQDOT: the mean over quantiles of the squared
        2-Wasserstein distance between corresponding 1-D descriptor columns.
        Otherwise compute 2-QDOT via an optimal-transport plan.
    EMD : bool
        Exact OT (``ot.emd``) if True, else entropic ``ot.sinkhorn``.
    Sink_reg : float
        Sinkhorn regularization, relative to the max-normalized cost.

    Returns
    -------
    float
        When ``intergal`` is True.
    (float, ndarray of shape (n, m))
        Otherwise: the cost ``sum(C * P)``, where ``C`` is the squared
        Euclidean distance between descriptors, and the transport plan ``P``.
    """
    from scipy.spatial.distance import cdist

    def _rescale(arr, denom):
        # Skip the division for degenerate (constant) inputs, where denom == 0
        # would otherwise turn every entry into NaN.
        return arr / denom if denom > 0 else arr

    X = np.asarray(X)
    Y = np.asarray(Y)
    n, _ = np.shape(X)
    m, _ = np.shape(Y)
    if initial:
        # NOTE(review): centering uses the unweighted mean even when a/b are given.
        Xc = X - X.mean(axis=0)
        Yc = Y - Y.mean(axis=0)
        Xn = _rescale(Xc, Xc.std())
        Yn = _rescale(Yc, Yc.std())
    else:
        Xn = X
        Yn = Y

    if a is not None:
        a = np.asarray(a, dtype=float)
    if b is not None:
        b = np.asarray(b, dtype=float)

    X_Q = get_quantile_means(Xn, a, n_quantile, sigma, tor, initial)
    Y_Q = get_quantile_means(Yn, b, n_quantile, sigma, tor, initial)

    # Descriptor of each sample: its distances to the quantile means.
    X_QD = cdist(Xn, X_Q)
    Y_QD = cdist(Yn, Y_Q)
    if scale:
        X_QD = _rescale(X_QD, X_QD.max())
        Y_QD = _rescale(Y_QD, Y_QD.max())

    if a is None:
        a = np.ones(n) / n
    if b is None:
        b = np.ones(m) / m

    if intergal:
        if n == m and np.allclose(a, 1.0 / n) and np.allclose(b, 1.0 / m):
            # Equal-size uniform samples: in 1-D, W_2^2 is the mean squared gap
            # between the sorted values. This is only valid for uniform weights.
            loss = np.mean((np.sort(X_QD, axis=0) - np.sort(Y_QD, axis=0)) ** 2, axis=0).mean()
        else:
            # p=2 so that this branch also returns W_2^2, like the fast path
            # (ot.wasserstein_1d defaults to p=1).
            loss = np.mean([ot.wasserstein_1d(X_QD[:, i], Y_QD[:, i], a, b, p=2)
                            for i in range(X_QD.shape[1])])
        return loss

    C = cdist(X_QD, Y_QD) ** 2
    # The max-normalized cost is only fed to the solver (so Sink_reg is
    # scale-free); the reported loss uses the raw cost C.
    C1 = _rescale(C, C.max())
    if EMD:
        P = ot.emd(a, b, C1)
    else:
        P = ot.sinkhorn(a, b, C1, reg=Sink_reg)
    loss = (C * P).sum()
    return loss, P