from matplotlib import pyplot as plt
import numpy as np 
import statsmodels.api as sm 
from scipy.special import expit
import scipy.integrate as integrate
from joblib import Parallel, delayed
from scipy.integrate import nquad, quad
from statsmodels.nonparametric.kernel_regression import KernelReg
from semipara_helpers import * 
import time
import argparse
import os
import pandas as pd
import pickle

# Example invocation (small eps / lmbda values):
#   python epslambda_aipw.py --eps 0.001 0.0001 0.00001 0.000001 0.0000001 --lmbda 0.001 0.0001 0.00001 0.000001 0.0000001
# Quick test run:
#   python epslambda_aipw.py --eps 0.01 0.001 --lmbda 0.01 0.001

# Sweep over larger eps / lmbda values:
#   python epslambda_aipw.py --eps 0.1 0.01 0.001 0.0001 0.00001 --lmbda 0.1 0.01 0.001 0.0001 0.00001
out_dir = 'eps_lambda_aipw'
EXP_NAME = 'AIPW'
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(out_dir, exist_ok=True)

# Read arguments.
parser = argparse.ArgumentParser(description='evaluate finite differences')
# nargs='+' always produces a list, so the defaults must be lists as well: a
# scalar default made the float conversion below fail with
# "TypeError: 'float' object is not iterable" whenever a flag was omitted.
# type=float rejects malformed values at parse time with a clear message.
parser.add_argument('--eps', nargs='+', type=float, default=[0.01])
parser.add_argument('--lmbda', nargs='+', type=float, default=[0.1])
flags = parser.parse_args()
print('Flags:')
for k, v in sorted(vars(flags).items()):
    print("\t{}: {}".format(k, v))
# Values to sweep over. Each eps is forwarded to the int_gnd_* helpers, and
# each lmbda is the half-width of the uniform window used by
# tilde_phi_eps_sample.
eps__ = [float(x) for x in flags.eps]
lmbda__ = [float(x) for x in flags.lmbda]

# Size of the single long sample drawn by draw_data. Every sample size in the
# sweep takes a prefix of it, so n must be >= the largest size swept.
n = 1500
# Passed straight through to draw_data; its meaning is defined in
# semipara_helpers.
beta = 0.5




def oracle_mean(A,X):
    """Piecewise-linear mean function of treatment ``A`` and covariate ``X``.

    It is passed to ``draw_data`` below and is presumably the true
    E[Y | A, X] used to simulate outcomes (confirm in semipara_helpers).
    Works elementwise on NumPy arrays as well as on scalars.

    Pieces:
        X < 0.25          -> -5*A*X + 3*(2*A - 1)
        0.25 < X < 0.5    ->  5*A*X + 3
        0.5  < X < 0.75   -> -5*A*X
        0.75 < X < 1      ->  5*A*X
        otherwise         ->  0

    NOTE(review): all inequalities are strict, so X exactly equal to 0.25,
    0.5 or 0.75 (and any X >= 1) maps to 0. Values with X < 0 fall into the
    first piece.
    """
    ax = 5 * A * X
    in_first = X < 0.25
    in_second = (X > 0.25) * (X < 0.5)
    in_third = (X > 0.5) * (X < 0.75)
    in_fourth = (X > 0.75) * (X < 1)
    return ((-ax + (2 * A - 1) * 3) * in_first
            + (ax + 3) * in_second
            - ax * in_third
            + ax * in_fourth)


# Kernel bandwidths handed to fit_nuisances; bw[0] is also reused as the
# covariate bandwidth (xbw) for the perturbation integrals.
bw = [0.05] * 2


def tilde_phi_eps_sample(operturb,n_mc,xbw, eps,lmbda):
    """Monte Carlo estimate of the smoothed finite-difference value at one point.

    Draws ``n_mc`` points uniformly from ``[x_ - lmbda, x_ + lmbda]`` around
    the covariate of the target observation. For each draw it sums
    ``int_gnd_tildekreps(...)`` and ``int_gnd_resid_term(...)``, then returns
    the mean of those sums.

    Args:
        operturb: target observation ``[x_, a_, y_]``; it must have exactly
            three entries. Only ``x_`` is used here directly, and the whole
            triple is forwarded to the helpers.
        n_mc: number of uniform Monte Carlo draws.
        xbw: covariate bandwidth forwarded to the helpers.
        eps: value forwarded to the helpers (presumably the finite-difference
            step size; confirm in semipara_helpers).
        lmbda: half-width of the uniform sampling window around ``x_``.

    Returns:
        ``np.mean`` over the ``n_mc`` per-draw sums.

    NOTE(review): this reads the module-level globals ``X``, ``Y`` and ``A``,
    which the main loop rebinds for each sample size. It also draws from
    NumPy's global RNG without a seed, so results differ from run to run.
    """
    # Unpacking keeps the check that operturb has exactly three entries.
    x_, _a, _y = operturb
    x__ = np.random.uniform(low=x_ - lmbda, high=x_ + lmbda, size=n_mc)
    return np.mean([
        int_gnd_tildekreps(xbw, X, Y, A, x, operturb, eps)
        + int_gnd_resid_term(xbw, X, Y, A, x, operturb, eps)
        for x in x__
    ])

# Draw one long sample of size n. Each sample size in the sweep below uses a
# prefix of it, and the raw draw is also stored in the aggregated results.
Xlong, Along, Ylong = draw_data(n, beta, oracle_mean)
final_res = {'X': Xlong, 'A': Along, 'Y': Ylong}

# Sample sizes to sweep. Each is a prefix of the single long draw above, so
# none may exceed n. (An alternative grid is np.logspace(1.7, 3, 10).)
ns = np.linspace(300, 1200, 10)
# Aggregate results file. It is rewritten after every (eps, lmbda) pair as a
# checkpoint, so a crash mid-sweep keeps all finished results.
final_outfile = out_dir + '/' + 'final_AIPW_overns_rerun' + '.p'
for ind_n, n__ in enumerate(ns):
    print('n', n__)
    n__ = int(n__)
    # NOTE: tilde_phi_eps_sample reads X, A and Y as module globals, so these
    # names must not change.
    X = Xlong[0:n__]
    A = Along[0:n__]
    Y = Ylong[0:n__]

    [pyxa1, pxa1, pyxa0, pxa0, kr1, kr0, px] = fit_nuisances(X, A, Y, bw)

    # AIPW influence-function values for the treated arm. Presumably pxa1 / px
    # estimate p(x | A=1) / p(x), which makes `propensity` P(A=1 | X=x) by
    # Bayes' rule, and kr1 is a KernelReg of Y on X among the treated.
    # Confirm against fit_nuisances.
    propensity = pxa1.pdf(X) * np.mean(A) / px.pdf(X)
    # KernelReg.fit re-runs the regression, so evaluate it only once.
    mu1 = kr1.fit(X)[0]
    phis = (A == 1) * (Y - mu1) / propensity + mu1
    xbw = bw[0]

    fn = out_dir + '/' + 'AIPW_' + 'n_' + str(n__) + '_' + str(int(time.time() * 1e6)) + '.p'
    # phis is stored with the data so the finite-difference estimates below
    # can be compared against it offline.
    data = {'X': X, 'A': A, 'Y': Y, 'phis': phis}
    with open(fn, 'wb') as fh:
        pickle.dump(data, fh)

    for eps in eps__:
        for lmbda in lmbda__:
            print((eps, lmbda))
            res = Parallel(n_jobs=-1, verbose=30)(
                delayed(tilde_phi_eps_sample)([x_i, a_i, y_i], 1000, xbw, eps, lmbda)
                for x_i, a_i, y_i in zip(X, A, Y)
            )
            # Label with the per-iteration sample size n__, not the global n.
            EXP_NAME = 'AIPW_' + 'n_' + str(n__) + '_eps_' + str(eps) + '_lmbda_' + str(lmbda) + '_'
            outfile = out_dir + '/' + EXP_NAME + str(int(time.time() * 1e6)) + '.p'
            result = {
                'phis_epsformula': np.array(res),
                'eps': eps, 'lmbda': lmbda, 'bw': bw,
            }
            with open(outfile, 'wb') as fh:
                pickle.dump(result, fh)
            final_res[(ind_n, eps, lmbda)] = result
            # Checkpoint: continually overwrite the aggregate file.
            with open(final_outfile, 'wb') as fh:
                pickle.dump(final_res, fh)

with open(final_outfile, 'wb') as fh:
    pickle.dump(final_res, fh)
