import torch
import numpy as np
import os

SQRT_CONST = 1e-10

def safe_sqrt(x, lbound=SQRT_CONST):
    ''' Numerically safe square root.

    Clamps ``x`` from below at ``lbound`` before ``torch.sqrt``, so
    non-positive inputs do not produce NaN and the gradient near zero stays
    finite.

    Args:
        x: input tensor.
        lbound: lower bound applied to ``x`` before the square root.

    Returns:
        Elementwise ``sqrt(max(x, lbound))``.
    '''
    # torch.clamp with only ``min`` is the PyTorch equivalent of
    # tf.clip_by_value(x, lbound, np.inf); ``tf`` is not available here.
    return torch.sqrt(torch.clamp(x, min=lbound))

def lindisc(X, t, p):
    ''' Linear MMD (linear discrepancy between treated and control means).

    Args:
        X: data; rows are samples (indexed as ``X[idx, :]``).
        t: treatment variable of shape (N,) or (N, 1); rows with ``t > 0`` are
           the treatment group and rows with ``t < 1`` the controlled group.
        p: portion of treatment and controlled groups (Python float or tensor).

    Returns:
        Scalar tensor
        ``sign(p - 0.5) * (p - 0.5)
          + sqrt(0.25 * (2p - 1)^2 + ||p * mean_t - (1 - p) * mean_c||^2)``.
        An empty group yields a NaN mean (``torch.mean`` over zero rows).
    '''
    # Row indices of each group. Column 0 of nonzero() is the row index for
    # both (N,) and (N, 1) inputs; a .squeeze() here would break on 1-D ``t``
    # and on groups that contain exactly one sample.
    it = (t > 0).nonzero()[:, 0]   # index of treatment group
    ic = (t < 1).nonzero()[:, 0]   # index of controlled group

    Xc = X[ic, :]                  # samples of controlled group
    Xt = X[it, :]                  # samples of treatment group

    mean_control = torch.mean(Xc, dim=0)
    mean_treated = torch.mean(Xt, dim=0)

    # torch.square / torch.sign require tensors, so accept a plain float p.
    p = torch.as_tensor(p, dtype=X.dtype, device=X.device)

    c = torch.square(2 * p - 1) * 0.25
    f = torch.sign(p - 0.5)

    mmd = torch.sum(torch.square(p * mean_treated - (1 - p) * mean_control))
    # Floor the radicand (same 1e-10 value as SQRT_CONST). When c + mmd == 0,
    # e.g. p == 0.5 and equal weighted means, a bare sqrt has an infinite
    # gradient and backprop yields NaN.
    mmd = f * (p - 0.5) + torch.sqrt(torch.clamp(c + mmd, min=1e-10))

    return mmd

def mmd2_lin(X, t, p):
    ''' Linear MMD (squared distance between weighted group means).

    Args:
        X: data; rows are samples (indexed as ``X[idx, :]``).
        t: treatment variable of shape (N,) or (N, 1); rows with ``t > 0`` are
           treated and rows with ``t < 1`` are controls.
        p: portion of treated units (Python float or tensor).

    Returns:
        Scalar tensor ``sum((2 * p * mean_t - 2 * (1 - p) * mean_c) ** 2)``.
    '''
    # Column 0 of nonzero() gives row indices for (N,) and (N, 1) ``t`` alike;
    # a .squeeze() before it fails for 1-D ``t`` and single-sample groups.
    it = (t > 0).nonzero()[:, 0]
    ic = (t < 1).nonzero()[:, 0]

    Xc = X[ic, :]
    Xt = X[it, :]

    mean_control = torch.mean(Xc, dim=0)
    mean_treated = torch.mean(Xt, dim=0)

    mmd = torch.sum(torch.square(2.0 * p * mean_treated - 2.0 * (1.0 - p) * mean_control))

    return mmd

def mmd2_rbf(X, t, p, sig):
    """ Computes the l2-RBF MMD for X given t.

    Within-group kernel sums exclude the (approximately unit) diagonal by
    subtracting the group size, and are normalized by m*(m-1) and n*(n-1).

    Args:
        X: samples; rows are indexed by sample (``X[idx, :]``).
        t: treatment indicator of shape (N,) or (N, 1); rows with ``t > 0``
           are treated and rows with ``t < 1`` are controls.
        p: portion of treated units (Python float or tensor).
        sig: RBF bandwidth (Python float or tensor); the kernel is
             ``exp(-||x - y||^2 / sig^2)``.

    Returns:
        Scalar tensor: 4 times the p-weighted RBF MMD estimate.
        Each group needs at least two samples; otherwise the m*(m-1) or
        n*(n-1) denominator is zero and the result is inf/NaN.
    """
    # Column 0 of nonzero() gives row indices for (N,) and (N, 1) ``t`` alike.
    it = (t > 0).nonzero()[:, 0]
    ic = (t < 1).nonzero()[:, 0]

    Xc = X[ic, :]
    Xt = X[it, :]

    # torch.square requires tensors; accept plain floats for p and sig.
    p = torch.as_tensor(p, dtype=X.dtype, device=X.device)
    sig2 = torch.square(torch.as_tensor(sig, dtype=X.dtype, device=X.device))

    Kcc = torch.exp(-pdist2sq(Xc, Xc) / sig2)
    Kct = torch.exp(-pdist2sq(Xc, Xt) / sig2)
    Ktt = torch.exp(-pdist2sq(Xt, Xt) / sig2)

    # Group sizes as Python floats (a Python int has no .float() method).
    m = float(Xc.shape[0])
    n = float(Xt.shape[0])

    mmd = torch.square(1.0 - p) / (m * (m - 1.0)) * (torch.sum(Kcc) - m)
    mmd = mmd + torch.square(p) / (n * (n - 1.0)) * (torch.sum(Ktt) - n)
    mmd = mmd - 2.0 * p * (1.0 - p) / (m * n) * torch.sum(Kct)
    mmd = 4.0 * mmd

    return mmd

def pdist2sq(X, Y):
    """ Squared Euclidean distance between every row of X and every row of Y.

    Uses the expansion ||x - y||^2 = -2 x.y + ||y||^2 + ||x||^2, giving a
    matrix of shape (X.shape[0], Y.shape[0]). Because of floating-point
    cancellation, entries that should be zero can come out slightly negative.
    """
    cross = X @ Y.t()
    row_norms_x = (X * X).sum(dim=1).unsqueeze(1)   # column vector, one per row of X
    row_norms_y = (Y * Y).sum(dim=1).unsqueeze(0)   # row vector, one per row of Y
    return (-2 * cross + row_norms_y) + row_norms_x

def pdist2(X, Y):
    """ Returns the pairwise Euclidean distance matrix between rows of X and Y.

    The squared distances from ``pdist2sq`` are clamped below at
    ``SQRT_CONST`` before the square root, for two reasons:

    - The expansion ``||x||^2 + ||y||^2 - 2 x.y`` can round to small negative
      values, which give NaN under sqrt.
    - sqrt has an infinite gradient at exactly zero, e.g. on the diagonal
      when X is Y.
    """
    return torch.sqrt(torch.clamp(pdist2sq(X, Y), min=SQRT_CONST))

def pop_dist(X, t):
    """ Returns the pairwise distance matrix between treated and control rows.

    Args:
        X: samples; rows are indexed by sample (``X[idx, :]``).
        t: treatment indicator of shape (N,) or (N, 1); rows with ``t > 0``
           are treated and rows with ``t < 1`` are controls.

    Returns:
        Tensor of shape (n_treated, n_control) computed by ``pdist2(Xt, Xc)``.
    """
    # Column 0 of nonzero() gives row indices for (N,) and (N, 1) ``t`` alike.
    it = (t > 0).nonzero()[:, 0]
    ic = (t < 1).nonzero()[:, 0]
    Xc = X[ic, :]
    Xt = X[it, :]

    # Compute distance matrix (treated rows x control rows).
    M = pdist2(Xt, Xc)
    return M

# def wasserstein(X,t,p,lam=10,its=10,sq=False,backpropT=False):
#     """ Returns the Wasserstein distance between treatment groups """

#     it = (t>0).nonzero().squeeze()[:,0]
#     ic = (t<1).nonzero().squeeze()[:,0]
#     Xc = tf.gather(X,ic)
#     Xt = tf.gather(X,it)
#     nc = tf.to_float(tf.shape(Xc)[0])
#     nt = tf.to_float(tf.shape(Xt)[0])

#     ''' Compute distance matrix'''
#     if sq:
#         M = pdist2sq(Xt,Xc)
#     else:
#         M = safe_sqrt(pdist2sq(Xt,Xc))

#     ''' Estimate lambda and delta '''
#     M_mean = tf.reduce_mean(M)
#     M_drop = tf.nn.dropout(M,10/(nc*nt))
#     delta = tf.stop_gradient(tf.reduce_max(M))
#     eff_lam = tf.stop_gradient(lam/M_mean)

#     ''' Compute new distance matrix '''
#     Mt = M
#     num_row = tf.shape(M[:,0:1])[0]
#     num_col = tf.shape(M[0:1,:])[1]
#     row = delta*tf.ones([1, num_col])
#     col = tf.concat([delta*tf.ones([num_row, 1]), tf.zeros((1,1))], axis=0)
#     Mt = tf.concat([M, row], axis=0)
#     Mt = tf.concat([Mt, col], axis=1)

#     ''' Compute marginal vectors '''
#     a = tf.concat([p*tf.ones(tf.shape(tf.where(t>0)[:,0:1]))/nt, (1-p)*tf.ones((1,1))], axis=0)
#     b = tf.concat([(1-p)*tf.ones(tf.shape(tf.where(t<1)[:,0:1]))/nc, p*tf.ones((1,1))], axis=0)

#     ''' Compute kernel matrix'''
#     Mlam = eff_lam*Mt
#     K = tf.exp(-Mlam) + 1e-6 # added constant to avoid nan
#     U = K*Mt
#     ainvK = K/a

#     u = a
#     for i in range(0,its):
#         u = 1.0/(tf.matmul(ainvK,(b/tf.transpose(tf.matmul(tf.transpose(u),K)))))
#     v = b/(tf.transpose(tf.matmul(tf.transpose(u),K)))

#     T = u*(tf.transpose(v)*K)

#     if not backpropT:
#         T = tf.stop_gradient(T)

#     E = T*Mt
#     D = 2*tf.reduce_sum(E)

#     return D, Mlam

def simplex_project(x, k):
    """ Euclidean projection of the vector ``x`` onto the k-simplex.

    The entries are sorted in descending order, and a candidate threshold
    (cumulative sum - k) / rank is computed for each rank. The threshold at
    the largest rank whose sorted value still exceeds it is subtracted from
    ``x``, and negative results are clipped to zero.
    """
    dim = x.shape[0]
    mu = np.sort(x, axis=0)[::-1]                  # descending order
    ranks = np.arange(1, dim + 1)
    nu = (np.cumsum(mu) - k) / ranks               # candidate thresholds
    active = [idx for idx, (m_i, n_i) in enumerate(zip(mu, nu)) if m_i > n_i]
    theta = nu[active[-1]]
    return np.maximum(x - theta, 0)
