from matplotlib import pyplot as plt
import numpy as np 
import statsmodels.api as sm 
from scipy.special import expit
import scipy.integrate as integrate
from statsmodels.nonparametric.kernel_regression import KernelReg


def gsn_kl(h, Xi, x):
    """
    Gaussian kernel for continuous variables.

    Parameters
    ----------
    h : scalar or ndarray
        Bandwidth(s) used to scale the distance between `Xi` and `x`.
    Xi : scalar or ndarray
        Training-set value(s).
    x : scalar or ndarray
        Point(s) at which the kernel is evaluated.

    Returns
    -------
    kernel_value : scalar or ndarray
        ``exp(-(Xi - x)**2 / (2 * h**2)) / sqrt(2 * pi)``, following NumPy
        broadcasting of the inputs. No ``1 / h`` factor is applied here.
    """
    diff = Xi - x
    exponent = -diff**2 / (h**2 * 2.)
    norm_const = 1. / np.sqrt(2 * np.pi)
    return norm_const * np.exp(exponent)

def box_kl(h, Xi, x):
    """
    Box (uniform) kernel.

    Returns 0.5 wherever ``|Xi - x| / h < 1`` and 0 elsewhere; inputs follow
    NumPy broadcasting.
    """
    scaled_dist = np.abs((Xi - x) * 1. / h)
    inside_window = scaled_dist < 1
    return 1. / 2 * inside_window
#     return 1. / (2*h) * (np.abs(Xi-x) < 2*h)





def get_densities(Y_,A_,X_,bw='cv_ls'): 
    """
    Fit kernel density estimates on the ``A_ == 1`` subsample.

    Parameters
    ----------
    Y_, A_, X_ : ndarray
        Outcome, group indicator and covariate. Only the entries selected by
        ``A_ == 1`` are used.
    bw : str or sequence of float, optional
        Either the name of a bandwidth-selection method (default 'cv_ls'),
        which is passed unchanged to both estimators, or explicit bandwidths.
        With explicit bandwidths the full sequence is used for the joint
        density of (Y, X) and ``[bw[0]]`` for the marginal density of X.

    Returns
    -------
    list
        ``[pyxa1, pxa1]``: the joint KDE of (Y, X) and the marginal KDE of X,
        both fitted on the ``A_ == 1`` subsample.
    """
    selected = A_ == 1
    y_sel = Y_[selected]
    x_sel = X_[selected]
    pyxa1 = sm.nonparametric.KDEMultivariate(data=[y_sel, x_sel],
                                             var_type='cc', bw=bw)
    # Test the type instead of `bw != 'cv_ls'`: comparing an ndarray of
    # bandwidths with a string is element-wise in recent NumPy versions and
    # cannot serve as an `if` condition. This also lets any method-name
    # string pass through instead of being indexed like a sequence.
    # NOTE(review): the joint data order is (Y, X), so bw[0] is the first
    # (Y) entry; confirm that it is the intended bandwidth for X.
    marginal_bw = bw if isinstance(bw, str) else [bw[0]]
    pxa1 = sm.nonparametric.KDEMultivariate(data=[x_sel],
                                            var_type='c', bw=marginal_bw)
    return [pyxa1,pxa1]




def fit_nuisances(X,A,Y,bw, bwcv = False): 
    """
    Fit the density and regression nuisance estimators.

    Parameters
    ----------
    X, A, Y : ndarray
        Covariate, binary group indicator and outcome.
    bw : sequence of float
        Explicit bandwidths. ``bw[0]`` is used for the marginal density of X
        and for both kernel regressions; the full sequence is forwarded to
        ``get_densities``. Ignored when `bwcv` is True.
    bwcv : bool, optional
        When True, every estimator selects its bandwidth with least-squares
        cross-validation ('cv_ls') instead of using `bw`.

    Returns
    -------
    list
        ``[pyxa1, pxa1, pyxa0, pxa0, kr1, kr0, px]``: the joint (Y, X) and
        marginal X densities for A == 1 and for A == 0, the local-constant
        kernel regressions of Y on X for A == 1 and A == 0, and the marginal
        density of X over all rows.
    """
    if bwcv:
        # All estimators must be built in this mode as well; leaving px, kr1
        # and kr0 unassigned makes the return statement fail.
        bw = 'cv_ls'
        x_bw = 'cv_ls'
    else:
        x_bw = [bw[0]]
    px = sm.nonparametric.KDEMultivariate(data=X, var_type='c', bw=x_bw)
    kr1 = KernelReg(Y[A==1], X[A==1], 'c', reg_type='lc', bw=x_bw)
    kr0 = KernelReg(Y[A==0], X[A==0], 'c', reg_type='lc', bw=x_bw)
    # get_densities fits on the rows where its indicator equals 1, so passing
    # 1 - A yields the A == 0 densities.
    [pyxa1,pxa1]=get_densities(Y,A,X,bw)
    [pyxa0,pxa0]=get_densities(Y,1-A,X,bw)
    return [pyxa1,pxa1,pyxa0,pxa0,kr1,kr0,px]


def e1_cons(pxa1,px,A): 
    """
    Build the function ``x -> pxa1.pdf(x) * mean(A) / px.pdf(x)``.

    Parameters
    ----------
    pxa1, px : objects exposing ``pdf(x)``
        Density estimates used in the numerator and the denominator.
    A : ndarray
        Indicator whose mean scales the numerator.

    Returns
    -------
    callable
        Function of ``x`` evaluating the ratio above.
    """
    def e1(x): 
        numer = pxa1.pdf(x) * np.mean(A)
        denom = px.pdf(x)
        return numer / denom
    return e1

def replicate_dgp(noracle,nns,beta,oracle_mean, bw, dgp):
    """
    Draw one data set and compute three estimates for each sample size.

    Parameters
    ----------
    noracle : int
        Number of observations requested from `dgp`.
    nns : sequence of int
        Sample sizes; for each ``n_`` the nuisances are fitted on the first
        ``n_`` rows of the drawn data.
    beta, oracle_mean
        Forwarded unchanged to `dgp`.
    bw
        Not used: the bandwidths are set to ``[0.1, 0.1] * n_ ** -0.2`` for
        every sample size. The parameter is kept so existing calls still work.
    dgp : callable
        Called as ``dgp(noracle, beta, oracle_mean)``; must return ``X, A, Y``.

    Returns
    -------
    ndarray, shape (len(nns), 3)
        One row per sample size with columns ``[dm, ipw, aipw]``, each
        evaluated on the full drawn sample.
    """
    X,A,Y=dgp(noracle,beta,oracle_mean)
    Y_=Y.reshape((-1,1));A_=A.reshape((-1,1));X_=X.reshape((-1,1))
    nmeths=3
    treated = (A == 1)
    estimates = np.zeros([len(nns),nmeths])
    for (i_n,n_) in enumerate(nns): 
        bw_n = (np.asarray([0.1,.1]))*n_**-0.2
        _, pxa1, _, _, kr1, _, px = fit_nuisances(
            X_[0:n_,:], A_[0:n_,:], Y_[0:n_,:], bw_n)
        e1 = e1_cons(pxa1,px,A)
        # Evaluate the regression and the propensity once per sample size;
        # both are reused by several estimators below.
        mu1 = kr1.fit(X)[0]
        prop = e1(X)
        dm = np.mean(mu1)
        ipw = np.mean(Y*treated/prop)
        aipw = np.mean(treated*(Y-mu1)/prop + mu1)
        estimates[i_n,:] = [dm,ipw,aipw]
    return estimates 

def int_gnd_tildekreps(xbw,X,Y,A,x,operturb,eps): 
    """
    Ratio mixing kernel-weighted A == 1 sample terms with a perturbation point.

    Parameters
    ----------
    xbw : bandwidth passed to the module-level kernel ``kl``.
    X, Y, A : ndarray
        Sample covariates, outcomes and group indicator.
    x : point at which the kernel weights are evaluated.
    operturb : sequence of three values ``[x_, a_, y_]``
        Perturbation point; only ``a_`` and ``y_`` are used here.
    eps : float
        Mixing weight: sample terms get ``1 - eps``, the perturbation ``eps``.

    Returns
    -------
    ``((1-eps) * sum_i K_i Y_i * mean(A) + eps * y_ * [a_ == 1])
    / ((1-eps) * sum_i K_i * mean(A) + eps * [a_ == 1])`` with the sums taken
    over the A == 1 observations.
    """
    _, a_pert, y_pert = operturb
    in_group = A == 1
    pert_in_group = a_pert == 1
    weights = kl(xbw, X[in_group], x)
    sample_wt = 1 - eps
    group_share = np.mean(A)
    numerator = sample_wt * weights.dot(Y[in_group]) * group_share \
        + eps * y_pert * pert_in_group
    denominator = sample_wt * np.sum(weights) * group_share \
        + eps * pert_in_group
    return numerator / denominator

def int_gnd_resid_term(xbw,X,Y,A,x,operturb,eps): 
    """
    Residual term at ``x`` for a perturbation point ``[x_, a_, y_]``.

    Parameters
    ----------
    xbw : bandwidth passed to the module-level kernel ``kl``.
    X, Y, A : ndarray
        Sample covariates, outcomes and group indicator.
    x : point at which the kernel weights are evaluated.
    operturb : sequence of three values ``[x_, a_, y_]``
        Perturbation point; only ``a_`` and ``y_`` are used here.
    eps : float
        Mixing weight: sample terms get ``1 - eps``, the perturbation ``eps``.

    Returns
    -------
    ``(1-eps) * mean(K_all) * [a_ == 1] * (y_ - krx)
    / ((1-eps) * mean(K_group) * mean(A) + eps * [a_ == 1])`` where ``krx`` is
    the kernel-weighted average of Y over the A == 1 observations.
    """
    _, a_pert, y_pert = operturb
    in_group = A == 1
    pert_in_group = a_pert == 1
    w_group = kl(xbw, X[in_group], x)
    w_all = kl(xbw, X, x)
    # Kernel-weighted average of Y among the A == 1 observations.
    krx = Y[in_group].dot(w_group) / np.sum(w_group)
    sample_wt = 1 - eps
    numerator = sample_wt * np.mean(w_all) * pert_in_group * (y_pert - krx)
    denominator = sample_wt * np.mean(w_group) * np.mean(A) \
        + eps * pert_in_group
    return numerator / denominator

def draw_data(n,beta,oracle_mean, bwcv = False): 
    """
    Simulate covariates, a binary group indicator and outcomes.

    Parameters
    ----------
    n : int
        Number of observations.
    beta
        Not used by this generator; accepted for interface compatibility.
    oracle_mean : callable
        Called as ``oracle_mean(A, X)``; its result is the outcome mean to
        which standard normal noise is added.
    bwcv : bool, optional
        Not used by this generator.

    Returns
    -------
    list
        ``[X, A, Y]`` with ``X`` uniform on [0, 1), ``A`` a 0/1 integer array
        drawn with probability ``expit(sin(20 * X) + 0.5)`` and
        ``Y = oracle_mean(A, X) + noise``.
    """
    X = np.random.uniform(size=n)
    propensity = expit(np.sin(20*X)+0.5)
    A = (np.random.uniform(size=n) < propensity).astype(int)
    # Redraw until both groups are present: a single retry can still return
    # a one-armed sample. Samples with fewer than two units can never contain
    # both groups, so they are returned as drawn instead of looping forever.
    while A.size >= 2 and min(A.sum(), A.size - A.sum()) == 0:
        A = (np.random.uniform(size=n) < propensity).astype(int)
    Y = oracle_mean(A,X) + np.random.normal(size=n)
    return [X,A,Y]

# unit test 
def int_py_x(py_x,ygrid,x): 
    """
    Grid approximation of ``int y p(y|x) dy``.

    Evaluates ``py_x.pdf(ygrid, x)`` with ``x`` repeated for every grid point
    and returns ``sum_i ygrid[i] * pdf_i / len(ygrid)``.

    NOTE(review): the sum is scaled by ``1 / len(ygrid)`` rather than by the
    grid spacing, which presumably assumes a grid covering an interval of
    roughly unit length -- confirm against callers.
    """
    n_grid = len(ygrid)
    x_repeated = x * np.ones(n_grid)
    cond_density = py_x.pdf(ygrid, x_repeated)
    return ygrid.dot(cond_density) / n_grid
def int_py_x_(x,py_x,a,b): 
    """
    Numerically integrate ``y * p(y|x)`` over ``[a, b]``.

    Parameters
    ----------
    x : float
        Conditioning value passed to ``py_x.pdf``.
    py_x : object exposing ``pdf(endog, exog)``
        Conditional density; it is evaluated as ``py_x.pdf([y], [x])``.
    a, b : float
        Integration limits.

    Returns
    -------
    float
        The integral estimate from ``scipy.integrate.quad``; the error
        estimate returned by ``quad`` is discarded.
    """
    def integrand(y):
        # pdf is evaluated at a single point and may come back as a length-1
        # array; extract the scalar explicitly so quad receives a plain float
        # instead of relying on implicit array-to-float conversion.
        density = np.asarray(py_x.pdf([y],[x])).item()
        return y * density
    val, _abserr = integrate.quad(integrand, a, b)
    return val

# Element-wise wrapper around gsn_kl; this is the kernel that
# int_gnd_tildekreps and int_gnd_resid_term call as `kl`.
# NOTE(review): gsn_kl is built from NumPy array operations, so it presumably
# already broadcasts without np.vectorize -- confirm before simplifying.
kl = np.vectorize(gsn_kl)



# Numerical error from smoothed delta
# def FNL_Peps(eps,lmbda,xtarget,ytarget, px,pxa1,pyxa1, bnds, full_output=False): 
#     s_delta_x = lambda xint: gsn_kl(lmbda,xint,xtarget)
#     s_delta_yx = lambda yint,xint: gsn_kl(lmbda,yint,ytarget)*gsn_kl(lmbda,xint,xtarget)
#     # density evaluation (at x_int, y_int)
#     p_epsx = lambda x: ((1-eps)*pxa1.pdf([x])+eps*s_delta_x(x))
#     p_eps_yx = lambda y,x:  y *((1-eps)*pyxa1.pdf([y,x]) + eps * s_delta_yx(y,x))
    
#     # nquad arguments: x0, xn; integration carried out in order s.t. 
#     # x0 is innermost integral, xn is outermost

#     integrand = lambda y,x: px.pdf([x]) / p_epsx(x) * p_eps_yx(y,x)
#     bnds = [[min(Y),max(Y)],[0,1]]

#     output = nquad(integrand,bnds,full_output=full_output)
#     return output


