import torch
import numpy as np
import os

SQRT_CONST = 1e-10

# def safe_sqrt(x, lbound=SQRT_CONST):
#     ''' Numerically safe version of TensorFlow sqrt '''
#     return torch.sqrt(tf.clip_by_value(x, lbound, np.inf))

def mmd2_lin_dfr(X, t, p):
    """Sum of linear-MMD terms between consecutive treatment groups, over time.

    Shapes below are taken from the original author's notes, not enforced
    by the code:
        t: treatment variable [B, T, K]
        X: data               [B, T, D]
        p: probability        [T, K]

    For every time step (axis 1 of ``t``) the samples are split into one
    group per treatment index ``0 .. t.shape[-1] - 1``.  Group membership is
    decided by comparing the *first* channel of ``t``'s last axis with the
    treatment index.  The per-group feature means are weighted by
    ``2 * p[step, k]`` and the squared difference between each pair of
    neighbouring treatments ``(k - 1, k)`` is accumulated.

    NOTE: a treatment with no samples at some time step has an empty mean,
    which is NaN and propagates into the returned value.

    Returns the accumulated scalar (the int ``0`` when there is nothing to
    accumulate).
    """
    num_steps = t.shape[1]
    num_treatments = t.shape[-1]

    mmd = 0
    for step in range(num_steps):
        step_t = t[:, step, :]

        # Mean feature vector of every treatment group at this time step.
        group_means = [
            torch.mean(
                X[(step_t == treatment).squeeze()[:, 0], step, :],
                dim=0,
            )
            for treatment in range(num_treatments)
        ]

        # Compare each treatment with the one right before it.
        neighbours = zip(group_means[:-1], group_means[1:])
        for k, (prev_mean, cur_mean) in enumerate(neighbours, start=1):
            weighted_cur = 2.0*p[step, k]*cur_mean
            weighted_prev = 2.0*p[step, k-1]*prev_mean
            mmd += torch.sum(torch.square(weighted_cur - weighted_prev))

    return mmd

    # it = (t>0).nonzero().squeeze()[:, 0]   # index of treatment group
    # ic = (t<1).nonzero().squeeze()[:, 0]   # index of controlled group

    # Xc = X[ic, :]                          # samples of controlled group
    # Xt = X[it, :]                          # samples of treatment group

    # mean_control = torch.mean(Xc, dim=0)   # mean of controlled group
    # mean_treated = torch.mean(Xt, dim=0)   # mean of treatment group

    # mmd = torch.sum(torch.square(2.0*p*mean_treated - 2.0*(1.0-p)*mean_control))

    # return mmd

def lindisc(X, t, p):
    """Linear discrepancy between the treated and control groups.

    Args:
        X: data; rows are samples, indexed along the first axis.
        t: treatment variable; rows with ``t > 0`` are treated, rows with
            ``t < 1`` are controls.  May be ``[N, 1]`` or ``[N]``.
        p: weight applied to the treated mean (``1 - p`` weights the control
            mean).  Tensor or plain Python number.
            NOTE(review): the original note called this "portion of
            controlled groups", but the formula multiplies it with the
            treated mean -- confirm against the caller.

    Returns:
        Scalar tensor ``sign(p - 0.5) * (p - 0.5) + sqrt(c + d)`` where
        ``c = (2p - 1)^2 / 4`` and ``d`` is the squared norm of
        ``p * mean_treated - (1 - p) * mean_control``.
    """
    # torch.square / torch.sign only accept tensors; tensors pass through as-is.
    p = torch.as_tensor(p)

    # Row index of every matching entry.  No squeeze() here: it would collapse
    # the index matrix when a group holds exactly one sample and make the
    # column selection fail.
    it = (t > 0).nonzero()[:, 0]   # index of treatment group
    ic = (t < 1).nonzero()[:, 0]   # index of controlled group

    Xc = X[ic, :]                          # samples of controlled group
    Xt = X[it, :]                          # samples of treatment group

    mean_control = torch.mean(Xc, dim=0)   # mean of controlled group
    mean_treated = torch.mean(Xt, dim=0)   # mean of treatment group

    c = torch.square(2*p - 1)*0.25
    f = torch.sign(p - 0.5)

    mmd = torch.sum(torch.square(p*mean_treated - (1 - p)*mean_control))
    mmd = f*(p - 0.5) + torch.sqrt(c + mmd)

    return mmd

def mmd2_lin(X, t, p):
    """Linear MMD between the treated and control groups.

    Args:
        X: data; rows are samples, indexed along the first axis.
        t: treatment variable; rows with ``t > 0`` are treated, rows with
            ``t < 1`` are controls.  May be ``[N, 1]`` or ``[N]``.
        p: weight applied to the treated mean (``1 - p`` weights the control
            mean).

    Returns:
        Scalar tensor: squared norm of
        ``2 * p * mean_treated - 2 * (1 - p) * mean_control``.
    """
    # Row index of every matching entry.  No squeeze() here: it would collapse
    # the index matrix when a group holds exactly one sample and make the
    # column selection fail.
    it = (t > 0).nonzero()[:, 0]
    ic = (t < 1).nonzero()[:, 0]

    Xc = X[ic, :]
    Xt = X[it, :]

    mean_control = torch.mean(Xc, dim=0)
    mean_treated = torch.mean(Xt, dim=0)

    mmd = torch.sum(torch.square(2.0*p*mean_treated - 2.0*(1.0 - p)*mean_control))

    return mmd

def mmd2_rbf(X, t, p, sig):
    """Compute the l2-RBF MMD for X given t.

    Args:
        X: data; rows are samples, indexed along the first axis.
        t: treatment variable; rows with ``t > 0`` are treated, rows with
            ``t < 1`` are controls.  May be ``[N, 1]`` or ``[N]``.
        p: weight of the treated group (``1 - p`` weights the controls).
            Tensor or plain Python number.
        sig: RBF kernel width; the kernel is ``exp(-d^2 / sig^2)``.
            Tensor or plain Python number.

    Returns:
        Scalar tensor with the weighted squared-MMD estimate (scaled by 4).

    Each group must contain at least two samples; otherwise the
    ``m * (m - 1)`` / ``n * (n - 1)`` normalisers are zero.
    """
    # Row index of every matching entry.  No squeeze() here: it would collapse
    # the index matrix when a group holds exactly one sample.
    it = (t > 0).nonzero()[:, 0]
    ic = (t < 1).nonzero()[:, 0]

    Xc = X[ic, :]
    Xt = X[it, :]

    # ``** 2`` works for tensors and plain numbers alike (torch.square does not).
    sig_sq = sig ** 2
    Kcc = torch.exp(-pdist2sq(Xc, Xc) / sig_sq)
    Kct = torch.exp(-pdist2sq(Xc, Xt) / sig_sq)
    Ktt = torch.exp(-pdist2sq(Xt, Xt) / sig_sq)

    # Tensor.size() yields plain ints, which have no .float() method.
    m = float(Xc.size(0))
    n = float(Xt.size(0))

    # Subtracting m (resp. n) removes the diagonal k(x, x) = 1 entries.
    mmd = (1.0 - p) ** 2 / (m*(m - 1.0))*(torch.sum(Kcc) - m)
    mmd = mmd + p ** 2 / (n*(n - 1.0))*(torch.sum(Ktt) - n)
    mmd = mmd - 2.0*p*(1.0 - p)/(m*n)*torch.sum(Kct)
    mmd = 4.0*mmd

    return mmd

def pdist2sq(X, Y):
    """Squared Euclidean distance between every row of X and every row of Y.

    Uses the expansion ``|x - y|^2 = |x|^2 + |y|^2 - 2 x.y``.  Entry
    ``[i, j]`` of the result relates row ``i`` of ``X`` to row ``j`` of
    ``Y``.  Because of floating-point cancellation, entries for
    (near-)identical rows can come out slightly negative.
    """
    x_sq_norms = X.square().sum(dim=1, keepdim=True)   # one column entry per row of X
    y_sq_norms = Y.square().sum(dim=1, keepdim=True)   # one column entry per row of Y
    cross_term = -2 * (X @ Y.T)
    return (cross_term + y_sq_norms.T) + x_sq_norms

def pdist2(X, Y):
    """Return the pairwise Euclidean distance matrix between rows of X and Y.

    The squared distances are clamped from below at ``SQRT_CONST`` before the
    square root is taken.  The expansion used by ``pdist2sq`` can produce
    slightly negative values for (near-)identical rows, whose square root
    would be NaN, and the gradient of ``sqrt`` is infinite at exactly zero.
    As a consequence the smallest distance returned is ``sqrt(SQRT_CONST)``
    rather than 0.
    """
    return torch.sqrt(torch.clamp(pdist2sq(X, Y), min=SQRT_CONST))

def pop_dist(X, t):
    """Pairwise distance matrix between treated and control samples.

    Args:
        X: data; rows are samples, indexed along the first axis.
        t: treatment variable; rows with ``t > 0`` are treated, rows with
            ``t < 1`` are controls.  May be ``[N, 1]`` or ``[N]``.

    Returns:
        The result of ``pdist2(Xt, Xc)``: entry ``[i, j]`` relates the
        ``i``-th treated sample to the ``j``-th control sample.
    """
    # Row index of every matching entry.  No squeeze() here: it would collapse
    # the index matrix when a group holds exactly one sample.
    it = (t > 0).nonzero()[:, 0]
    ic = (t < 1).nonzero()[:, 0]
    Xc = X[ic, :]
    Xt = X[it, :]

    # Compute distance matrix
    return pdist2(Xt, Xc)

# def wasserstein(X,t,p,lam=10,its=10,sq=False,backpropT=False):
#     """ Returns the Wasserstein distance between treatment groups """

#     it = (t>0).nonzero().squeeze()[:,0]
#     ic = (t<1).nonzero().squeeze()[:,0]
#     Xc = tf.gather(X,ic)
#     Xt = tf.gather(X,it)
#     nc = tf.to_float(tf.shape(Xc)[0])
#     nt = tf.to_float(tf.shape(Xt)[0])

#     ''' Compute distance matrix'''
#     if sq:
#         M = pdist2sq(Xt,Xc)
#     else:
#         M = safe_sqrt(pdist2sq(Xt,Xc))

#     ''' Estimate lambda and delta '''
#     M_mean = tf.reduce_mean(M)
#     M_drop = tf.nn.dropout(M,10/(nc*nt))
#     delta = tf.stop_gradient(tf.reduce_max(M))
#     eff_lam = tf.stop_gradient(lam/M_mean)

#     ''' Compute new distance matrix '''
#     Mt = M
#     num_row = tf.shape(M[:,0:1])[0]
#     num_col = tf.shape(M[0:1,:])[1]
#     row = delta*tf.ones([1, num_col])
#     col = tf.concat([delta*tf.ones([num_row, 1]), tf.zeros((1,1))], axis=0)
#     Mt = tf.concat([M, row], axis=0)
#     Mt = tf.concat([Mt, col], axis=1)

#     ''' Compute marginal vectors '''
#     a = tf.concat([p*tf.ones(tf.shape(tf.where(t>0)[:,0:1]))/nt, (1-p)*tf.ones((1,1))], axis=0)
#     b = tf.concat([(1-p)*tf.ones(tf.shape(tf.where(t<1)[:,0:1]))/nc, p*tf.ones((1,1))], axis=0)

#     ''' Compute kernel matrix'''
#     Mlam = eff_lam*Mt
#     K = tf.exp(-Mlam) + 1e-6 # added constant to avoid nan
#     U = K*Mt
#     ainvK = K/a

#     u = a
#     for i in range(0,its):
#         u = 1.0/(tf.matmul(ainvK,(b/tf.transpose(tf.matmul(tf.transpose(u),K)))))
#     v = b/(tf.transpose(tf.matmul(tf.transpose(u),K)))

#     T = u*(tf.transpose(v)*K)

#     if not backpropT:
#         T = tf.stop_gradient(T)

#     E = T*Mt
#     D = 2*tf.reduce_sum(E)

#     return D, Mlam

def simplex_project(x, k):
    """Project the NumPy vector x onto the k-simplex.

    The entries are sorted in descending order and a candidate shift
    ``(cumsum - k) / rank`` is computed for every rank.  The shift at the
    last rank whose sorted entry still exceeds its candidate is subtracted
    from ``x``, and negative results are clipped to zero.

    Raises IndexError when no rank qualifies.
    """
    size = x.shape[0]
    descending = np.sort(x, axis=0)[::-1]
    candidates = (np.cumsum(descending) - k) / range(1, size + 1)

    # Ranks at which the sorted entry stays above its candidate shift.
    active = []
    for rank in range(size):
        if descending[rank] > candidates[rank]:
            active.append(rank)

    shift = candidates[active[-1]]
    return np.maximum(x - shift, 0)
