import numpy as np
from sklearn.preprocessing import normalize
#from scipy.stats import wasserstein_distance
import ot

def sampled_sphere(ndirs, d):
    """Draw ``ndirs`` random unit vectors in R^d.

    Each vector is a standard (zero-mean, identity-covariance) Gaussian
    sample rescaled to unit Euclidean norm; by rotational invariance of the
    isotropic Gaussian this gives directions uniformly spread on the sphere.

    Parameters
    ----------
    ndirs : int
        Number of directions to draw.
    d : int
        Ambient dimension.

    Returns
    -------
    ndarray of shape (ndirs, d)
        One unit-norm direction per row.
    """
    gaussian_draws = np.random.multivariate_normal(
        mean=np.zeros(d), cov=np.identity(d), size=ndirs
    )
    # sklearn's ``normalize`` rescales every row to unit L2 norm by default.
    return normalize(gaussian_draws)


def Tukey_Depth(X, ndirs, U=None):
    """Approximate the Tukey (halfspace) depth of every row of X w.r.t. X.

    For a projection direction u, the one-dimensional depth of a point is the
    smaller of the fractions of sample points whose projection lies at or
    below / at or above its own projection (the point itself is counted on
    both sides). The Tukey depth is the minimum of this quantity over
    directions; here the minimum is taken over a finite set of directions,
    so the result approximates (from above) the exact depth.

    Parameters
    ----------
    X : array-like of shape (n, d)
        The training set; a depth is computed for each of its rows.
    ndirs : int
        Number of random directions drawn when ``U`` is None. Ignored when
        ``U`` is given (the number of rows of ``U`` is used instead).
    U : array-like of shape (k, d), optional
        Projection directions, one per row. When None, ``ndirs`` directions
        are drawn with ``sampled_sphere``.

    Returns
    -------
    depth : ndarray of shape (n,)
        Depth score of each element of X; values lie in [1/n, ceil(n/2)/n].
    Z : ndarray of shape (n, k)
        Projections ``X @ U.T`` of the data onto each direction.

    Notes
    -----
    Ties between projections are broken arbitrarily by the sort, so tied
    points may receive slightly different scores.
    """
    X = np.asarray(X)
    n, d = X.shape
    if U is None:
        U = sampled_sphere(ndirs, d)
    # Accept a single direction given as a 1-D vector.
    U = np.atleast_2d(U)

    Z = np.matmul(X, U.T)
    # 1-based rank of each projected point within its column (direction).
    order = np.argsort(Z, axis=0)
    ranks = np.empty_like(order)
    np.put_along_axis(ranks, order, np.arange(1, n + 1)[:, None], axis=0)
    # ``ranks`` points project at or below the point, ``n - ranks + 1`` at or
    # above it (the point itself included in both counts).
    depth_per_dir = np.minimum(ranks, n - ranks + 1) / n
    return np.amin(depth_per_dir, axis=1), Z

def Projection_Depth(X, ndirs, U=None):
    """Approximate the projection depth of every row of X w.r.t. X.

    The outlyingness of a point is the maximum, over projection directions,
    of ``|u.x - median(u.X)| / MAD(u.X)`` where MAD is the median absolute
    deviation of the projected sample. The depth is ``1 / (1 + outlyingness)``,
    with the maximum taken over a finite set of directions.

    Parameters
    ----------
    X : array-like of shape (n, d)
        The training set; a depth is computed for each of its rows.
    ndirs : int
        Number of random directions drawn when ``U`` is None. Ignored when
        ``U`` is given.
    U : array-like of shape (k, d), optional
        Projection directions, one per row. When None, ``ndirs`` directions
        are drawn with ``sampled_sphere``.

    Returns
    -------
    depth : ndarray of shape (n,)
        Depth score of each element of X, in [0, 1].
    Z : ndarray of shape (n, k)
        Projections ``X @ U.T`` of the data onto each direction.

    Notes
    -----
    If a direction has MAD == 0 (e.g. more than half of the projections
    coincide), points at the median get a standardised deviation of 0 and
    all other points get +inf, i.e. depth 0, instead of nan.
    """
    X = np.asarray(X)
    n, d = X.shape
    if U is None:
        U = sampled_sphere(ndirs, d)
    # Accept a single direction given as a 1-D vector.
    U = np.atleast_2d(U)

    Z = np.matmul(X, U.T)
    med = np.median(Z, axis=0)
    abs_dev = np.absolute(Z - med)
    mad = np.median(abs_dev, axis=0)
    # Plain division would warn and produce nan (0/0) or inf (x/0) for
    # directions with MAD == 0; silence that and map 0/0 to 0 explicitly.
    with np.errstate(divide="ignore", invalid="ignore"):
        scaled = abs_dev / mad
    scaled = np.where(abs_dev == 0, 0.0, scaled)
    outlyingness = np.amax(scaled, axis=1)
    return 1 / (1 + outlyingness), Z



def SW(X, Y, ndirs, p=2, max_sliced=False, U=None):
    """Monte-Carlo estimate of the (max-)sliced p-Wasserstein distance.

    Both point clouds are projected onto unit directions. For each direction,
    the exact one-dimensional OT cost between the two projected samples
    (uniform weights, ground cost |x - y|**p) is computed with
    ``ot.emd2_1d``. The per-direction costs are averaged (sliced) or
    maximised (max-sliced), and the 1/p power is taken.

    Parameters
    ----------
    X : array-like of shape (n, d)
        First point cloud.
    Y : array-like of shape (m, d)
        Second point cloud (m may differ from n).
    ndirs : int
        Number of random directions drawn when ``U`` is None; ignored
        otherwise.
    p : float, default 2
        Order of the Wasserstein distance.
    max_sliced : bool, default False
        If True, use the maximum over the directions instead of the mean.
        With random directions this is only a lower bound on the true
        max-sliced distance (the max is over the sampled directions only).
    U : array-like of shape (k, d), optional
        Explicit projection directions, one per row (e.g. to reuse the same
        directions across calls). When None, ``ndirs`` directions are drawn
        with ``sampled_sphere``.

    Returns
    -------
    float
        The (max-)sliced p-Wasserstein distance estimate.
    """
    X = np.asarray(X)
    Y = np.asarray(Y)
    _, d = X.shape
    if U is None:
        U = sampled_sphere(ndirs, d)
    U = np.atleast_2d(U)

    Z_x = np.matmul(X, U.T)
    Z_y = np.matmul(Y, U.T)
    # With its default 'sqeuclidean' metric, emd2_1d always uses |x - y|**2
    # and ignores its own ``p`` argument; 'minkowski' is required to get the
    # cost |x - y|**p for other orders.
    if p == 2:
        cost_kwargs = {"metric": "sqeuclidean"}
    else:
        cost_kwargs = {"metric": "minkowski", "p": p}
    sliced = np.array(
        [ot.emd2_1d(Z_x[:, k], Z_y[:, k], **cost_kwargs) for k in range(U.shape[0])]
    )
    aggregated = np.max(sliced) if max_sliced else np.mean(sliced)
    return aggregated ** (1 / p)

def Sinkhorn(X, Y, reg=0.01, large_scale=False):
    """Entropy-regularised OT cost between the empirical measures of X and Y.

    Both samples receive uniform weights and the ground-cost matrix is
    ``ot.dist(X, Y)`` (POT's default metric).

    Parameters
    ----------
    X : array of shape (n, d)
        First point cloud.
    Y : array of shape (m, d)
        Second point cloud.
    reg : float, default 0.01
        Entropic regularisation strength.
    large_scale : bool, default False
        When exactly ``True``, solve with POT's stochastic semi-dual solver
        (SAG, 1000 iterations) and return ``sum(plan * cost)``; otherwise
        return ``ot.sinkhorn2``'s loss.

    Returns
    -------
    The transport cost computed by the selected solver.
    """
    cost = ot.dist(X, Y)
    n_x, _ = X.shape
    n_y, _ = Y.shape
    weights_x = np.full(n_x, 1 / n_x)
    weights_y = np.full(n_y, 1 / n_y)

    if large_scale == True:
        # The stochastic solver returns the transport plan, so the cost
        # <plan, M> is assembled here.
        plan = ot.stochastic.solve_semi_dual_entropic(
            weights_x, weights_y, cost, reg, "SAG", 1000
        )
        return np.sum(plan * cost)
    return ot.sinkhorn2(weights_x, weights_y, cost, reg=reg)