''' Main functions for estimation. '''

import numpy as np
import sys

from scipy.optimize import linear_sum_assignment
from scipy.special import expit, logit
from scipy.stats import bernoulli

from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestRegressor 

import utils
import reduced_rank_regressor


def _fit_and_evaluate(method, eval_pol, right_node_num, W_samples, Z_samples,
                      Y_samples, test_W_samples, outcome_degree, mu_degree):
    '''
    Fit the nuisance models required by `method` on the given training
    samples and return v hat evaluated on `test_W_samples`.
    '''

    if method == 'direct':
        # NOTE(review): the fit_outcome currently defined in this file takes
        # (X_samples, T_samples, Y_samples, rank, rho, method) and returns a
        # pair of arrays, so this call does not match it (no `outcome_degree`
        # keyword). It appears to target the earlier polynomial fit_outcome
        # that is now commented out -- confirm before relying on 'direct'.
        outcome_model = fit_outcome(W_samples, Z_samples, Y_samples,
                                    outcome_degree=outcome_degree)

        return eval_value(eval_pol=eval_pol, right_node_num=right_node_num, method=method,
                          test_W_samples=test_W_samples, outcome_degree=outcome_degree,
                          mu_degree=mu_degree, outcome_model=outcome_model)

    if method == 'WDM':
        fit_mu = fit_weighted_direct_mu
    elif method == 'GRDR':
        fit_mu = fit_robust_mu
    else:
        raise ValueError("method must be one of 'direct', 'WDM', 'GRDR'; got %r" % (method,))

    # WDM and GRDR share the same pipeline: propensity model first, then the
    # two per-arm outcome models fitted with that propensity model.
    prop_model = fit_prop_score(W_samples=W_samples, Z_samples=Z_samples)
    mu_0, mu_1 = fit_mu(Z_samples=Z_samples, W_samples=W_samples, Y_samples=Y_samples,
                        mu_degree=mu_degree, prop_model=prop_model)

    return eval_value(eval_pol=eval_pol, right_node_num=right_node_num, method=method,
                      test_W_samples=test_W_samples, outcome_degree=outcome_degree,
                      mu_degree=mu_degree, mu_0=mu_0, mu_1=mu_1, prop_model=prop_model)


def perturbation_alg_bootstrap(eval_pol=None, h=1, 
                               W_samples=None, Z_samples=None, 
                               Y_samples=None, outcome_degree=2, test_W_samples=None,
                              mu_degree=2, S=10, method='WDM',right_node_num=200, bootstrap=True):
    
    '''
    Main perturbation algorithm (Alg 1 in paper) based on repeated subsampling.
    
    The value estimate v_hat_zero is first computed from all samples. Then S
    subsamples of size int(N/((1+h)**2)) are drawn WITHOUT replacement
    (np.random.choice(..., replace=False)), the estimate is recomputed on each,
    and the result is
    
        rho = v_hat_zero - (1/h) * (v_hat_zero - mean(subsample estimates))
    
    eval_pol: pi(W) --> prob_0, prob_1 that is to be evaluated
              (passed through to eval_value, which calls eval_pol.get_pi_t)
    h: perturbation size; controls the subsample size and the 1/h correction,
       so it must be non-zero when bootstrap is enabled
    W_samples, Z_samples, Y_samples: training arrays, indexed along axis 0
    test_W_samples: test covariates the value is evaluated on
    outcome_degree, mu_degree: polynomial degrees forwarded to the fit/eval helpers
    S: number of draws
    method: direct, WDM, GRDR
    right_node_num: forwarded to eval_value
    bootstrap: if falsy, skip the subsampling and return v_hat_zero
    
    Return: rho (or v_hat_zero when bootstrap is disabled)
    
    Raises: ValueError if method is not one of direct, WDM, GRDR
    '''
    
    # Fail early with a clear message instead of hitting an unbound
    # v_hat_zero further down.
    if method not in ('direct', 'WDM', 'GRDR'):
        raise ValueError("method must be one of 'direct', 'WDM', 'GRDR'; got %r" % (method,))
    
    num_samples = Z_samples.shape[0]
    draw_size = int(num_samples/((1+h)**2))
    
    # Compute v hat zero on the full sample
    v_hat_zero = _fit_and_evaluate(method, eval_pol, right_node_num, W_samples, Z_samples,
                                   Y_samples, test_W_samples, outcome_degree, mu_degree)
    
    if not bootstrap:
        return v_hat_zero
    
    # Subsampling draws
    v_hat_list = []
    
    for _ in range(S):
        
        indices = np.random.choice(num_samples, draw_size, replace=False)
        
        v_hat_fold = _fit_and_evaluate(method, eval_pol, right_node_num,
                                       W_samples[indices], Z_samples[indices],
                                       Y_samples[indices], test_W_samples,
                                       outcome_degree, mu_degree)
        
        v_hat_list.append(v_hat_fold)
    
    rho = v_hat_zero - (1/float(h)) * (v_hat_zero - np.mean(v_hat_list))
    
    return rho
    

def eval_value(eval_pol=None, right_node_num=200, method='WDM',
               test_W_samples=None, outcome_degree=2, mu_degree=2,
              outcome_model=None, mu_0=None, mu_1=None, prop_model=None):
    
    '''
    Compute v hat on a test set for a given policy and a given estimator.
    
    For every test covariate w the per-sample cost is
        pi_0(w) * mu_0(w) + pi_1(w) * mu_1(w)
    where pi_t comes from eval_pol.get_pi_t and mu_t from the estimator
    selected by `method` (direct, WDM or GRDR). The cost vector is repeated
    across right_node_num columns and handed to utils.compute_matching; the
    first element of its result is returned.
    '''
    
    # Choose the outcome predictor once; an unrecognised method gets none,
    # leaving the mu lists empty.
    if method == 'direct':
        def predict_mu(w, z):
            return direct_mu(Z=z, W=w, theta_reg_model=outcome_model,
                             degree=outcome_degree)
    elif method == 'WDM':
        def predict_mu(w, z):
            return weighted_direct_mu(W=w, Z=z, mu_1_model=mu_1,
                                      mu_0_model=mu_0, mu_degree=mu_degree)
    elif method == 'GRDR':
        def predict_mu(w, z):
            return robust_mu(W=w, Z=z, mu_1_model=mu_1, mu_0_model=mu_0,
                             mu_degree=mu_degree, prop_model=prop_model)
    else:
        predict_mu = None
    
    policy_prob_0, policy_prob_1 = [], []
    outcome_0, outcome_1 = [], []
    
    for w in test_W_samples:
        policy_prob_0.append(eval_pol.get_pi_t(W=w, Z=0))
        policy_prob_1.append(eval_pol.get_pi_t(W=w, Z=1))
        
        if predict_mu is not None:
            outcome_1.append(predict_mu(w, 1))
            outcome_0.append(predict_mu(w, 0))
    
    # Policy-weighted outcome of each arm, then summed per test sample.
    arm_0_cost = np.array(policy_prob_0) * np.array(outcome_0)
    arm_1_cost = np.array(policy_prob_1) * np.array(outcome_1)
    per_sample_cost = arm_0_cost + arm_1_cost
    
    # One identical column per right node.
    cost_matrix = np.transpose([per_sample_cost] * right_node_num)
    matching_result = utils.compute_matching(cost_matrix)
    v_hat, _, _ = matching_result
    
    return v_hat
    
    

def robust_mu(W=None, Z=None, mu_1_model=None, mu_0_model=None, mu_degree=2, prop_model=None):
    
    '''
    Evaluate the robust (GRDR) hat mu for a single covariate W and arm Z.
    
    The feature row is the degree-mu_degree polynomial expansion of W followed
    by two inverse-propensity columns, (1-Z)/e_0(W) and Z/e_1(W), where e_t
    are the class probabilities returned by prop_model.predict_proba.
    mu_0_model is used when Z == 0, mu_1_model otherwise.
    '''
    
    poly_terms = PolynomialFeatures(degree=mu_degree).fit_transform(np.array([[W]]))
    
    class_probs = prop_model.predict_proba([[W]])  # 1 by 2 array
    prob_0, prob_1 = class_probs[0][0], class_probs[0][1]
    
    inverse_prop_terms = np.array([[float(1-Z)/(prob_0), float(Z)/(prob_1)]])
    
    feature_row = np.concatenate((poly_terms, inverse_prop_terms), axis=1)
    
    chosen_model = mu_0_model if Z==0 else mu_1_model
    
    return chosen_model.predict(feature_row)[0]

def fit_robust_mu(Z_samples=None, W_samples=None, Y_samples=None, mu_degree=2, 
                        prop_model=None, printmodel=False):
    
    '''
    Doubly robust estimator.
    
    Fits hat mu_t(W) for t=0, 1 by weighted linear regression. The design
    matrix is the degree-mu_degree polynomial expansion of W with two extra
    columns, (1-Z)/e_0(W) and Z/e_1(W), built from the class probabilities of
    prop_model. Each arm is fitted on its own rows (Z == 0 / Z == 1) with
    sample weights equal to the inverse propensity of the observed arm.
    
    If printmodel is set, the weighted R^2, intercept and coefficients of
    each arm are printed.
    
    Return: two models hat mu_t(W) for t=0, 1, in the order (mu_0, mu_1)
    '''
    
    W_column = W_samples.reshape(-1,1)
    poly_terms = PolynomialFeatures(degree=mu_degree).fit_transform(W_column)
    
    prop_scores = prop_model.predict_proba(W_column)
    
    treated_term = np.divide(Z_samples, prop_scores[:,1])
    control_term = np.divide(np.subtract(1,Z_samples), prop_scores[:,0])
    
    # Column order matters: control term first, treated term second.
    design = np.concatenate(
        (poly_terms, control_term.reshape(-1,1), treated_term.reshape(-1,1)), axis=1)
    
    control_rows = np.nonzero(Z_samples == 0)[0]
    treated_rows = np.nonzero(Z_samples == 1)[0]
    
    # Inverse-propensity weights; anything that is not Z == 1 is weighted as control.
    treated_weights = [float(1)/float(prop_scores[row][1])
                       for row, z in enumerate(Z_samples) if z == 1]
    control_weights = [float(1)/float(prop_scores[row][0])
                       for row, z in enumerate(Z_samples) if z != 1]
    
    X_treated, Y_treated = design[treated_rows], Y_samples[treated_rows]
    mu_1 = LinearRegression()
    mu_1.fit(X_treated, Y_treated, treated_weights)
    r_sq_1 = mu_1.score(X_treated, Y_treated, treated_weights)
    
    if printmodel:
        print('\nGRDR R^2 score (Z=1):', r_sq_1)
        print('GRDR intercept (Z=1):', mu_1.intercept_)
        print('GRDR slope (Z=1):', mu_1.coef_)
    
    X_control, Y_control = design[control_rows], Y_samples[control_rows]
    mu_0 = LinearRegression()
    mu_0.fit(X_control, Y_control, control_weights)
    r_sq_0 = mu_0.score(X_control, Y_control, control_weights)
    
    if printmodel:
        print('\nGRDR R^2 score (Z=0):', r_sq_0)
        print('GRDR intercept (Z=0):', mu_0.intercept_)
        print('GRDR slope (Z=0):', mu_0.coef_)
    
    return mu_0, mu_1
    


def weighted_direct_mu(W=None, Z=None, mu_1_model=None, mu_0_model=None, mu_degree=2):
    
    '''
    Evaluate the weighted direct method hat mu for a single W and arm Z.
    
    W is expanded into polynomial features of degree mu_degree and passed to
    mu_0_model when Z == 0, or to mu_1_model otherwise.
    '''
    
    feature_row = PolynomialFeatures(degree=mu_degree).fit_transform(np.array([[W]]))
    
    chosen_model = mu_0_model if Z==0 else mu_1_model
    
    return chosen_model.predict(feature_row)[0]

def fit_weighted_direct_mu(Z_samples=None, W_samples=None, Y_samples=None, mu_degree=2, 
                        prop_model=None, printmodel=False):
    
    '''
    Fit hat mu_t(W) for t=0, 1 by inverse-propensity-weighted polynomial
    regression on W.
    
    Each arm is fitted on its own rows (Z == 0 / Z == 1) using polynomial
    features of degree mu_degree, with sample weights equal to one over the
    prop_model probability of the observed arm.
    
    If printmodel is set, the weighted R^2, intercept and coefficients of
    each arm are printed.
    
    Return: two models hat mu_t(W) for t=0, 1, in the order (mu_0, mu_1)
    '''
    
    W_column = W_samples.reshape(-1,1)
    design = PolynomialFeatures(degree=mu_degree).fit_transform(W_column)
    
    control_rows = np.nonzero(Z_samples == 0)[0]
    treated_rows = np.nonzero(Z_samples == 1)[0]
    
    prop_scores = prop_model.predict_proba(W_column)
    
    # Inverse-propensity weights; anything that is not Z == 1 is weighted as control.
    treated_weights = [float(1)/float(prop_scores[row][1])
                       for row, z in enumerate(Z_samples) if z == 1]
    control_weights = [float(1)/float(prop_scores[row][0])
                       for row, z in enumerate(Z_samples) if z != 1]
    
    X_treated, Y_treated = design[treated_rows], Y_samples[treated_rows]
    mu_1 = LinearRegression()
    mu_1.fit(X_treated, Y_treated, treated_weights)
    r_sq_1 = mu_1.score(X_treated, Y_treated, treated_weights)
    
    if printmodel:
        print('\nWDM R^2 score (Z=1):', r_sq_1)
        print('WDM intercept (Z=1):', mu_1.intercept_)
        print('WDM slope (Z=1):', mu_1.coef_)
    
    X_control, Y_control = design[control_rows], Y_samples[control_rows]
    mu_0 = LinearRegression()
    mu_0.fit(X_control, Y_control, control_weights)
    r_sq_0 = mu_0.score(X_control, Y_control, control_weights)
    
    if printmodel:
        print('\nWDM R^2 score (Z=0):', r_sq_0)
        print('WDM intercept (Z=0):', mu_0.intercept_)
        print('WDM slope (Z=0):', mu_0.coef_)
    
    return mu_0, mu_1
    


def fit_prop_score(W_samples=None, Z_samples=None, printmodel=False):
    
    '''
    Fit a logistic-regression propensity score model for Z given W.
    
    W_samples: covariate array; reshaped to a single feature column
    Z_samples: treatment labels; flattened to a 1-D target
    printmodel: if set, print the training accuracy of the fitted model
    
    Return: the fitted LogisticRegression classifier
    
    To turn off regularisation, change the penalty argument to None
    (the string 'none' on older scikit-learn releases).
    '''
    
    features = W_samples.reshape(-1,1)
    # scikit-learn expects a 1-D target; an (n, 1) column is accepted but
    # triggers a DataConversionWarning on every call.
    labels = Z_samples.ravel()
    
    clf = LogisticRegression(random_state=0, penalty='l2').fit(features, labels)
    
    if printmodel:
 
        print('score of fitted prop model: ', clf.score(features, labels))
    
    return clf
    
    
def direct_mu(Z=None, W=None, theta_reg_model=None, degree=2):
    
    '''
    Evaluate the direct method hat mu for a single W and Z.
    
    The pair (W, Z) is expanded into joint polynomial features of the given
    degree and passed to theta_reg_model.predict; the first prediction is
    returned.
    '''
    
    point = np.array([[W, Z]])
    expanded = PolynomialFeatures(degree=degree).fit_transform(point)
    
    return theta_reg_model.predict(expanded)[0]

class oracle_outcome_model:
    
    '''
    Outcome model with a known coefficient vector, exposing the same
    predict(features) call as the fitted regression models.
    '''
    
    def __init__(self, true_theta=None):
        
        # Coefficients applied to the feature vector in predict().
        self.true_theta = true_theta
    
    def predict(self, test_features):
        
        '''
        Return the dot product of true_theta with test_features laid out as a
        single column.
        '''
        
        feature_column = test_features.reshape(-1,1)
        
        return np.dot(self.true_theta, feature_column)
    
class my_logistic_policy:
    
    '''
    Logistic policy: the probability of Z=1 given W is expit(phi * W + b).
    '''
    
    def __init__(self, phi=None, b=None):
        
        # Slope and intercept of the logit.
        self.phi, self.b = phi, b
    
    def get_pi_t(self, W=None, Z=None):
        
        '''
        Return the policy probability of arm Z at covariate W: the
        complement of the logistic probability when Z == 0, the logistic
        probability itself otherwise.
        '''
        
        treat_prob = expit(self.phi * W + self.b)
        
        return 1 - treat_prob if Z==0 else treat_prob



#def fit_outcome(W_samples, Z_samples, Y_samples, outcome_degree=2, printmodel=False):
    
#    '''
#    Estimate the outcome model by Polynomial regression;
#    Obtain hat mu by the direct method.
    

#    By definition, $mu_z(W) = E[Y | Z=z, W]$.

#    So we can compute $tilde \mu_z(W) = poly_{\hat \theta}(z, W)$.
    
#    Return: \hat \mu_0, \hat \mu_1
    
   
#    '''
    

#    poly = PolynomialFeatures(degree=outcome_degree)
#    train_features = poly.fit_transform(np.column_stack((W_samples,Z_samples)))
#    theta_reg_model = LinearRegression()
#    theta_reg_model.fit(train_features, Y_samples)
    
#    r_sq = theta_reg_model.score(train_features, Y_samples)
    
#    if printmodel:
#        print('outcome model R^2 score:', r_sq)
#        print('outcome model intercept:', theta_reg_model.intercept_)
#        print('outcome model slope:', theta_reg_model.coef_)
    
    
#     poly_mu = PolynomialFeatures(degree=mu_degree)
#     test_features_mu = poly_mu.fit_transform(np.array([[W, Z]]))
    
#     mu = theta_reg_model.predict(test_features)[0]
    
#    return theta_reg_model

def solve_ols(X, Y):
    '''
    Ordinary least squares coefficients for Y ~ X, computed from the normal
    equations with a pseudo-inverse of X^T X.
    '''
    gram_pinv = np.linalg.pinv(X.T @ X)
    return gram_pinv @ X.T @ Y

def scalarize(rho, z):
    '''
    Combine the columns of z with the weights rho and return the result
    flattened to one dimension.
    '''
    # temporary hack for dimensions
    weighted = rho.dot(z.T)
    return np.ravel(weighted)


def fit_outcome(X_samples=None, T_samples=None, Y_samples=None, rank=2, rho=None, method = 'direct'): 
    '''
    Estimate per-sample scalarized outcomes for the control (T=0) and
    treated (T=1) arms.
    
    Two kinds of per-arm models are fitted on the rows of each arm
    (T_samples < 0.5 / T_samples > 0.5) and evaluated on all of X_samples:
      * a RandomForestRegressor, giving Y0_hat / Y1_hat;
      * a ReducedRankRegressor, whose B factor gives Z0_hat / Z1_hat = X B^T
        and whose A factor maps them back via Z A^T.
        NOTE(review): the meaning of A and B is taken from how they are used
        here; confirm against reduced_rank_regressor.
    
    rho is applied through scalarize() to turn each outcome row into a scalar.
    
    method:
      'direct'  - scalarized forest predictions
      'IPW-obs' - inverse-propensity-weighted scalarized Y_samples
      'IPW'     - inverse-propensity-weighted scalarized forest predictions
      'DR-obs'  - weights * scalarize(Y_samples - Z A^T) + scalarize(Z A^T)
      'DR'      - as 'DR-obs' but with forest predictions in place of Y_samples
      'CV'      - 'IPW' estimate plus a control variate fitted by solve_ols
    
    Propensities for the weighted methods come from a LogisticRegression of
    T_samples on X_samples.
    
    Return: (mu0, mu1), each flattened to one dimension
    
    Raises: ValueError if method is not one of the names above
    '''

    if method not in ('direct', 'IPW-obs', 'IPW', 'DR-obs', 'DR', 'CV'):
        raise ValueError("unknown method %r; expected one of 'direct', 'IPW-obs', "
                         "'IPW', 'DR-obs', 'DR', 'CV'" % (method,))

    control = T_samples < 0.5
    treated = T_samples > 0.5

    ########### NONLINEAR #################################
    # Each arm needs its own estimator instance: fit() returns the estimator
    # itself, so fitting one shared instance twice would leave both arms
    # pointing at the model fitted last (the treated arm).
    rf_params = dict(n_estimators=200, max_depth=None, random_state=42, n_jobs=-1)
    rf0 = RandomForestRegressor(**rf_params).fit(X_samples[control], Y_samples[control])
    rf1 = RandomForestRegressor(**rf_params).fit(X_samples[treated], Y_samples[treated])

    Y0_hat = rf0.predict(X_samples)
    Y1_hat = rf1.predict(X_samples)

    # Estimate A,B on the control arm and F,E on the treated arm
    rrr0 = reduced_rank_regressor.ReducedRankRegressor(X_samples[control], Y_samples[control], rank=rank, reg=None)
    A = rrr0.A
    B = rrr0.B

    rrr1 = reduced_rank_regressor.ReducedRankRegressor(X_samples[treated], Y_samples[treated], rank=rank, reg=None)
    F = rrr1.A
    E = rrr1.B

    # Estimate Z-hat-0 and Z-hat-1 on the entire X dataset
    Z0_hat = np.dot(X_samples,B.T)
    Z1_hat = np.dot(X_samples,E.T)

    if method == 'direct': 

        mu0 = scalarize(rho, Y0_hat)
        mu1 = scalarize(rho, Y1_hat)

        return np.array(mu0).reshape(-1), np.array(mu1).reshape(-1)

    # Every remaining method uses the same inverse-propensity weights.
    logreg = LogisticRegression()
    logreg.fit(X_samples, T_samples)
    propensity_scores = logreg.predict_proba(X_samples)[:, 1]
    weights1 = T_samples / propensity_scores
    weights0 = (1-T_samples)/(1-propensity_scores)

    if method == 'IPW-obs': 

        # TODO: element-wise multiplication and multiple every dimension of Y by the weights  
        # TODO: scalarize by Y 
        mu1 = weights1 * scalarize(rho, Y_samples)
        mu0 = weights0 * scalarize(rho, Y_samples)

    elif method == 'IPW': 

        # TODO: element-wise multiplication and multiple every dimension of Y by the weights  
        # TODO: scalarize by Y 
        mu1 = weights1 * scalarize(rho, Y1_hat)
        mu0 = weights0 * scalarize(rho, Y0_hat)

    elif method in ('DR-obs', 'DR'):

        # 'DR-obs' corrects with the observed outcomes, 'DR' with the forest
        # predictions; the rest of the formula is shared.
        if method == 'DR-obs':
            target0, target1 = Y_samples, Y_samples
        else:
            target0, target1 = Y0_hat, Y1_hat

        fitted1 = np.dot(Z1_hat,F.T)
        fitted0 = np.dot(Z0_hat,A.T)

        mu1 = weights1*scalarize(rho, target1 - fitted1) + scalarize(rho, fitted1)
        mu0 = weights0*scalarize(rho, target0 - fitted0) + scalarize(rho, fitted0)

    else:

        # CV
        mu1_ipw = weights1 * scalarize(rho, Y1_hat)
        mu0_ipw = weights0 * scalarize(rho, Y0_hat)

        # Control variates: predictions scaled row-wise by (1 - weight).
        C1 = np.multiply(Y1_hat, (1-weights1)[:, np.newaxis])
        C0 = np.multiply(Y0_hat, (1-weights0)[:, np.newaxis])

        D1_hat = solve_ols(C1, scalarize(rho, Y1_hat)).T
        D0_hat = solve_ols(C0, scalarize(rho, Y0_hat)).T

        CV1 = np.dot(C1,D1_hat)
        CV0 = np.dot(C0,D0_hat)

        mu0 = mu0_ipw + CV0
        # The treated estimate must start from the treated IPW term
        # (previously the control term mu0_ipw was used here).
        mu1 = mu1_ipw + CV1

    return np.array(mu0).reshape(-1), np.array(mu1).reshape(-1)

def fit_latent_outcome(X_samples=None, T_samples=None, Y_samples=None, rank=2, rho=None, method = 'direct'): 
    '''
    Estimate per-sample scalarized outcomes for the control (T=0) and
    treated (T=1) arms from per-arm reduced rank regressions.
    
    A ReducedRankRegressor is fitted on the rows of each arm
    (T_samples < 0.5 / T_samples > 0.5). Its B factor gives
    Z0_hat / Z1_hat = X B^T on all of X_samples and its A factor maps them
    back via Z A^T.
    NOTE(review): the meaning of A and B is taken from how they are used
    here; confirm against reduced_rank_regressor.
    
    rho is applied through scalarize() to turn each row into a scalar.
    
    method:
      'direct'  - scalarize(Z A^T)
      'IPW-obs' - inverse-propensity-weighted scalarize(Z)
      'IPW'     - inverse-propensity-weighted scalarize(Z A^T)
      'CV'      - 'IPW' estimate plus a control variate fitted by solve_ols
    
    Propensities for the weighted methods come from a LogisticRegression of
    T_samples on X_samples.
    
    Return: (mu0, mu1), each flattened to one dimension
    
    Raises: ValueError if method is not one of the names above
    '''

    if method not in ('direct', 'IPW-obs', 'IPW', 'CV'):
        raise ValueError("unknown method %r; expected one of 'direct', 'IPW-obs', "
                         "'IPW', 'CV'" % (method,))

    control = T_samples < 0.5
    treated = T_samples > 0.5

    # Estimate A,B on the control arm and F,E on the treated arm
    rrr0 = reduced_rank_regressor.ReducedRankRegressor(X_samples[control], Y_samples[control], rank=rank, reg=None)
    A = rrr0.A
    B = rrr0.B

    rrr1 = reduced_rank_regressor.ReducedRankRegressor(X_samples[treated], Y_samples[treated], rank=rank, reg=None)
    F = rrr1.A
    E = rrr1.B

    # Estimate Z-hat-0 and Z-hat-1 on the entire X dataset
    Z0_hat = np.dot(X_samples,B.T)
    Z1_hat = np.dot(X_samples,E.T)

    # A full-sample regression used to be fitted here, but its factors were
    # never read (the control-arm factors were copied instead) and nothing
    # derived from it reached any return value, so it has been dropped.

    if method != 'direct':
        # Every non-direct method uses the same inverse-propensity weights.
        logreg = LogisticRegression()
        logreg.fit(X_samples, T_samples)
        propensity_scores = logreg.predict_proba(X_samples)[:, 1]
        weights1 = T_samples / propensity_scores
        weights0 = (1-T_samples)/(1-propensity_scores)

    if method == 'IPW-obs': 

        # NOTE(review): rho is applied to the Z-hat scores here, whereas
        # fit_outcome's 'IPW-obs' applies it to the observed Y_samples --
        # confirm which is intended.
        mu0 = weights0*scalarize(rho, Z0_hat)
        mu1 = weights1*scalarize(rho, Z1_hat)

    else:

        hat_A_hat_Z0_scalarized = scalarize(rho, np.dot(Z0_hat,A.T))
        hat_A_hat_Z1_scalarized = scalarize(rho, np.dot(Z1_hat,F.T))

        if method == 'direct': 

            mu0 = hat_A_hat_Z0_scalarized
            mu1 = hat_A_hat_Z1_scalarized

        elif method == 'IPW': 

            mu0 = weights0*hat_A_hat_Z0_scalarized
            mu1 = weights1*hat_A_hat_Z1_scalarized

        else:

            # CV
            Z1_hat_ipw = weights1*hat_A_hat_Z1_scalarized
            Z0_hat_ipw = weights0*hat_A_hat_Z0_scalarized

            # Control variates: Z-hat scores scaled row-wise by (1 - weight).
            C1 = np.multiply(Z1_hat, (1-weights1)[:, np.newaxis])
            C0 = np.multiply(Z0_hat, (1-weights0)[:, np.newaxis])

            D1_hat = solve_ols(C1, hat_A_hat_Z1_scalarized).T
            D0_hat = solve_ols(C0, hat_A_hat_Z0_scalarized).T

            CV1 = np.dot(C1,D1_hat)
            CV0 = np.dot(C0,D0_hat)

            mu0 = Z0_hat_ipw + CV0
            mu1 = Z1_hat_ipw + CV1

    # Flatten every method's result the same way (the CV branch previously
    # skipped this step).
    return np.array(mu0).reshape(-1), np.array(mu1).reshape(-1)


