''' Main functions for estimation. '''

import numpy as np

from scipy.optimize import linear_sum_assignment
from scipy.special import expit, logit
from scipy.stats import bernoulli

from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression

import utils


def _fit_and_eval_value(method, W_samples, Z_samples, Y_samples, eval_pol=None,
                        right_node_num=200, test_W_samples=None,
                        outcome_degree=2, mu_degree=2):
    '''
    Fit the nuisance models required by `method` on (W_samples, Z_samples,
    Y_samples) and return eval_value's estimate on `test_W_samples`.

    Used for both the full-sample estimate and every subsample draw in
    perturbation_alg_bootstrap, so the two are always computed the same way.

    Raises ValueError if `method` is not 'direct', 'WDM' or 'GRDR'.
    '''

    shared = dict(eval_pol=eval_pol, right_node_num=right_node_num, method=method,
                  test_W_samples=test_W_samples, outcome_degree=outcome_degree,
                  mu_degree=mu_degree)

    if method == 'direct':
        outcome_model = fit_outcome(W_samples, Z_samples, Y_samples,
                                    outcome_degree=outcome_degree)
        return eval_value(outcome_model=outcome_model, **shared)

    # WDM and GRDR share the same pipeline and differ only in how mu is fitted.
    if method == 'WDM':
        fit_mu = fit_weighted_direct_mu
    elif method == 'GRDR':
        fit_mu = fit_robust_mu
    else:
        raise ValueError(
            "method must be 'direct', 'WDM' or 'GRDR', got {!r}".format(method))

    prop_model = fit_prop_score(W_samples=W_samples, Z_samples=Z_samples)
    mu_0, mu_1 = fit_mu(Z_samples=Z_samples, W_samples=W_samples, Y_samples=Y_samples,
                        mu_degree=mu_degree, prop_model=prop_model)
    return eval_value(mu_0=mu_0, mu_1=mu_1, prop_model=prop_model, **shared)


def perturbation_alg_bootstrap(eval_pol=None, h=1,
                               W_samples=None, Z_samples=None,
                               Y_samples=None, outcome_degree=2, test_W_samples=None,
                               mu_degree=2, S=10, method='WDM', right_node_num=200, bootstrap=True):
    '''
    Main perturbation algorithm (Alg 1 in paper).

    Computes the full-sample value estimate v_hat_zero. If `bootstrap` is
    true, it also computes S further estimates on random subsamples drawn
    WITHOUT replacement (np.random.choice(..., replace=False)) and returns
        rho = v_hat_zero - (1/h) * (v_hat_zero - mean(subsample estimates)).

    eval_pol: policy object exposing get_pi_t(W=..., Z=...) (see eval_value)
    h: perturbation size; each subsample has int(N / (1 + h)**2) points,
       with N = Z_samples.shape[0]. Must be non-zero when bootstrap is true.
    W_samples, Z_samples, Y_samples: training arrays, subsampled along axis 0
    test_W_samples: points at which eval_value evaluates the policy
    outcome_degree: polynomial degree of the 'direct' outcome model
    mu_degree: polynomial degree of the WDM / GRDR mu models
    S: number of subsample draws
    method: 'direct', 'WDM' or 'GRDR'
    right_node_num: number of columns of the cost matrix built by eval_value
    bootstrap: if false, skip the draws and return v_hat_zero

    Raises ValueError for an unrecognised `method` (checked before any fitting).
    '''

    if method not in ('direct', 'WDM', 'GRDR'):
        raise ValueError(
            "method must be 'direct', 'WDM' or 'GRDR', got {!r}".format(method))

    eval_kwargs = dict(eval_pol=eval_pol, right_node_num=right_node_num,
                       test_W_samples=test_W_samples, outcome_degree=outcome_degree,
                       mu_degree=mu_degree)

    # Estimate on the full sample.
    v_hat_zero = _fit_and_eval_value(method, W_samples, Z_samples, Y_samples, **eval_kwargs)

    if not bootstrap:
        return v_hat_zero

    num_samples = Z_samples.shape[0]
    draw_size = int(num_samples / ((1 + h) ** 2))

    # Re-estimate on S random subsamples (uses NumPy's global RNG).
    v_hat_list = []
    for _ in range(S):
        indices = np.random.choice(num_samples, draw_size, replace=False)
        v_hat_list.append(_fit_and_eval_value(method, W_samples[indices], Z_samples[indices],
                                              Y_samples[indices], **eval_kwargs))

    # Combine the full-sample and subsample estimates as in Alg 1.
    rho = v_hat_zero - (1 / float(h)) * (v_hat_zero - np.mean(v_hat_list))

    return rho
    

def eval_value(eval_pol=None, right_node_num=200, method='WDM',
               test_W_samples=None, outcome_degree=2, mu_degree=2,
               outcome_model=None, mu_0=None, mu_1=None, prop_model=None):
    '''
    Evaluate v hat for a given test set, policy and estimator.

    For each test point w the per-sample cost is
        pi(0 | w) * mu_hat(w, 0) + pi(1 | w) * mu_hat(w, 1).
    Here pi comes from eval_pol.get_pi_t and mu_hat from the estimator
    selected by `method`. The cost vector is copied into right_node_num
    identical columns. The resulting matrix is passed to
    utils.compute_matching, whose first return value is v hat.

    eval_pol: object exposing get_pi_t(W=..., Z=...)
    method: 'direct' -- uses outcome_model and outcome_degree (direct_mu)
            'WDM'    -- uses mu_0, mu_1 and mu_degree (weighted_direct_mu)
            'GRDR'   -- uses mu_0, mu_1, mu_degree and prop_model (robust_mu)

    Raises ValueError if `method` is not one of the three names above.
    '''

    # Select the mu estimator once instead of re-testing `method` per sample.
    if method == 'direct':
        def mu_hat(w, z):
            return direct_mu(Z=z, W=w, theta_reg_model=outcome_model,
                             degree=outcome_degree)
    elif method == 'WDM':
        def mu_hat(w, z):
            return weighted_direct_mu(W=w, Z=z, mu_1_model=mu_1,
                                      mu_0_model=mu_0, mu_degree=mu_degree)
    elif method == 'GRDR':
        def mu_hat(w, z):
            return robust_mu(W=w, Z=z, mu_1_model=mu_1, mu_0_model=mu_0,
                             mu_degree=mu_degree, prop_model=prop_model)
    else:
        raise ValueError(
            "method must be 'direct', 'WDM' or 'GRDR', got {!r}".format(method))

    pi_0_list, pi_1_list, mu_0_list, mu_1_list = [], [], [], []

    for test_w_sample in test_W_samples:
        pi_0_list.append(eval_pol.get_pi_t(W=test_w_sample, Z=0))
        pi_1_list.append(eval_pol.get_pi_t(W=test_w_sample, Z=1))
        mu_1_list.append(mu_hat(test_w_sample, 1))
        mu_0_list.append(mu_hat(test_w_sample, 0))

    # Policy-weighted mu_hat for each test point.
    test_sample_costs = np.add(np.multiply(np.array(pi_0_list), np.array(mu_0_list)),
                               np.multiply(np.array(pi_1_list), np.array(mu_1_list)))

    # Every one of the right_node_num columns is a copy of the cost vector.
    cost_matrix = np.transpose([test_sample_costs] * right_node_num)
    v_hat, _, _ = utils.compute_matching(cost_matrix)

    return v_hat
    
    

def robust_mu(W=None, Z=None, mu_1_model=None, mu_0_model=None, mu_degree=2, prop_model=None):
    '''
    Evaluate the GRDR (doubly robust) mu_hat_Z at a single covariate value W.

    The feature row mirrors the training design built in fit_robust_mu:
    polynomial features of [[W]] up to `mu_degree`, then the two
    inverse-propensity columns (1 - Z) / e_0(W) and Z / e_1(W). Here
    (e_0(W), e_1(W)) is the first row of prop_model.predict_proba([[W]]).

    Z: 0 selects mu_0_model; any other value selects mu_1_model.
    Returns: the first prediction of the selected model on that row.
    '''

    poly = PolynomialFeatures(degree=mu_degree)
    poly_features = poly.fit_transform(np.array([[W]]))

    proba = prop_model.predict_proba([[W]])[0]  # one row: [e_0(W), e_1(W)]
    ipw_features = np.array([[float(1 - Z) / proba[0], float(Z) / proba[1]]])

    test_features = np.append(poly_features, ipw_features, 1)

    model = mu_0_model if Z == 0 else mu_1_model
    return model.predict(test_features)[0]

def fit_robust_mu(Z_samples=None, W_samples=None, Y_samples=None, mu_degree=2,
                  prop_model=None, printmodel=False):
    r'''
    Doubly robust (GRDR) estimator.

    Fit \hat \mu_t(W) for t = 0, 1, given a propensity score model.

    The design matrix has polynomial features of W (up to `mu_degree`),
    followed by two inverse-propensity columns: (1 - Z) / e_0(W) and
    Z / e_1(W). Here (e_0, e_1) come from prop_model.predict_proba.
    \hat \mu_t is a LinearRegression on the rows with Z == t, weighted
    by 1 / e_t(W).

    printmodel: if True, print each model's weighted R^2, intercept and coefficients.

    Return: (mu_0, mu_1), the fitted LinearRegression models for t = 0 and t = 1.
    '''

    W_col = W_samples.reshape(-1, 1)
    poly = PolynomialFeatures(degree=mu_degree)
    poly_features = poly.fit_transform(W_col)

    prop_scores = prop_model.predict_proba(W_col)

    # Inverse-propensity columns. robust_mu appends the same two columns, in
    # the same order (e_0 then e_1), at evaluation time.
    e_0_feature = np.divide(np.subtract(1, Z_samples), prop_scores[:, 0])
    e_1_feature = np.divide(Z_samples, prop_scores[:, 1])
    features = np.column_stack((poly_features, e_0_feature, e_1_feature))

    Z0_idx = np.where(Z_samples == 0)[0]
    Z1_idx = np.where(Z_samples == 1)[0]

    # Per-arm inverse-propensity weights. They are taken from exactly the rows
    # used to fit each arm, so they always line up with features[Z*_idx].
    weights_Z0 = 1.0 / prop_scores[Z0_idx, 0]
    weights_Z1 = 1.0 / prop_scores[Z1_idx, 1]

    mu_1 = LinearRegression()
    mu_1.fit(features[Z1_idx], Y_samples[Z1_idx], sample_weight=weights_Z1)

    if printmodel:
        r_sq_1 = mu_1.score(features[Z1_idx], Y_samples[Z1_idx], sample_weight=weights_Z1)
        print('\nGRDR R^2 score (Z=1):', r_sq_1)
        print('GRDR intercept (Z=1):', mu_1.intercept_)
        print('GRDR slope (Z=1):', mu_1.coef_)

    mu_0 = LinearRegression()
    mu_0.fit(features[Z0_idx], Y_samples[Z0_idx], sample_weight=weights_Z0)

    if printmodel:
        r_sq_0 = mu_0.score(features[Z0_idx], Y_samples[Z0_idx], sample_weight=weights_Z0)
        print('\nGRDR R^2 score (Z=0):', r_sq_0)
        print('GRDR intercept (Z=0):', mu_0.intercept_)
        print('GRDR slope (Z=0):', mu_0.coef_)

    return mu_0, mu_1
    


def weighted_direct_mu(W=None, Z=None, mu_1_model=None, mu_0_model=None, mu_degree=2):
    '''
    Evaluate the weighted-direct-method (WDM) mu_hat_Z at a single covariate value W.

    W is wrapped as [[W]] and expanded with PolynomialFeatures(mu_degree).
    The arm model (mu_0_model when Z == 0, mu_1_model otherwise) predicts on
    that row, and its first prediction is returned.
    '''
    expanded_w = PolynomialFeatures(degree=mu_degree).fit_transform(np.array([[W]]))
    arm_model = mu_0_model if Z == 0 else mu_1_model
    return arm_model.predict(expanded_w)[0]

def fit_weighted_direct_mu(Z_samples=None, W_samples=None, Y_samples=None, mu_degree=2,
                           prop_model=None, printmodel=False):
    r'''
    Weighted direct method (WDM): fit \hat \mu_t(W) for t = 0, 1.

    Each \hat \mu_t is a polynomial regression on W, given a propensity
    score model. It is a LinearRegression of Y on polynomial features of W
    (up to `mu_degree`), fitted on the rows with Z == t. Each row is
    weighted by 1 / e_t(W), where (e_0, e_1) come from
    prop_model.predict_proba.

    printmodel: if True, print each model's weighted R^2, intercept and coefficients.

    Return: (mu_0, mu_1), the fitted LinearRegression models for t = 0 and t = 1.
    '''

    W_col = W_samples.reshape(-1, 1)
    poly = PolynomialFeatures(degree=mu_degree)
    features = poly.fit_transform(W_col)

    Z0_idx = np.where(Z_samples == 0)[0]
    Z1_idx = np.where(Z_samples == 1)[0]

    # Per-arm inverse-propensity weights. They are taken from exactly the rows
    # used to fit each arm, so they always line up with features[Z*_idx].
    prop_scores = prop_model.predict_proba(W_col)
    weights_Z0 = 1.0 / prop_scores[Z0_idx, 0]
    weights_Z1 = 1.0 / prop_scores[Z1_idx, 1]

    mu_1 = LinearRegression()
    mu_1.fit(features[Z1_idx], Y_samples[Z1_idx], sample_weight=weights_Z1)

    if printmodel:
        r_sq_1 = mu_1.score(features[Z1_idx], Y_samples[Z1_idx], sample_weight=weights_Z1)
        print('\nWDM R^2 score (Z=1):', r_sq_1)
        print('WDM intercept (Z=1):', mu_1.intercept_)
        print('WDM slope (Z=1):', mu_1.coef_)

    mu_0 = LinearRegression()
    mu_0.fit(features[Z0_idx], Y_samples[Z0_idx], sample_weight=weights_Z0)

    if printmodel:
        r_sq_0 = mu_0.score(features[Z0_idx], Y_samples[Z0_idx], sample_weight=weights_Z0)
        print('\nWDM R^2 score (Z=0):', r_sq_0)
        print('WDM intercept (Z=0):', mu_0.intercept_)
        print('WDM slope (Z=0):', mu_0.coef_)

    return mu_0, mu_1
    


def fit_prop_score(W_samples=None, Z_samples=None, printmodel=False, penalty='l2'):
    '''
    Fit a logistic-regression propensity score model to the given data.

    W_samples: covariates, reshaped into a single feature column.
    Z_samples: treatment labels, flattened to 1-D. LogisticRegression expects
        a 1-D target; a column vector makes it emit DataConversionWarning on
        every fit.
    printmodel: if True, print the model's accuracy on the training data.
    penalty: regularisation passed to LogisticRegression (default 'l2').
        Pass penalty=None to turn regularisation off (scikit-learn
        releases before 1.2 spell this 'none').

    Return: the fitted LogisticRegression (random_state=0).
    '''

    X = np.reshape(W_samples, (-1, 1))
    y = np.ravel(Z_samples)

    clf = LogisticRegression(random_state=0, penalty=penalty).fit(X, y)

    if printmodel:
        print('score of fitted prop model: ', clf.score(X, y))

    return clf
    
    
def direct_mu(Z=None, W=None, theta_reg_model=None, degree=2):
    '''
    Evaluate the direct-method mu_hat at one (W, Z) pair.

    The pair is laid out as the row [W, Z], the same column order that
    fit_outcome trains on. It is expanded with PolynomialFeatures(degree),
    and the first element of theta_reg_model.predict on that row is
    returned.
    '''
    expander = PolynomialFeatures(degree=degree)
    design_row = expander.fit_transform(np.array([[W, Z]]))
    predictions = theta_reg_model.predict(design_row)
    return predictions[0]

class oracle_outcome_model:
    '''
    Stand-in for a fitted regression model whose coefficients are known.

    Provides the same predict(features) call that direct_mu makes on a fitted
    outcome model. It returns the dot product of each feature row with
    `true_theta`.
    '''

    def __init__(self, true_theta=None):
        # Coefficient vector; one entry per feature column.
        self.true_theta = true_theta

    def predict(self, test_features):
        '''
        Return test_features @ true_theta, one value per feature row.

        test_features: a single feature row (1-D) or a 2-D array of rows,
            each with len(true_theta) entries.
        Returns: 1-D array of predictions. A single row gives a length-1 array,
            which direct_mu indexes with [0].
        '''
        theta = np.ravel(self.true_theta)
        rows = np.asarray(test_features).reshape(-1, theta.shape[0])
        return rows.dot(theta)
    
class my_logistic_policy:
    '''
    Logistic treatment policy:
        pi(Z=1 | W) = expit(phi * W + b),  pi(Z=0 | W) = 1 - pi(Z=1 | W).
    '''

    def __init__(self, phi=None, b=None):
        # Slope (phi) and intercept (b) of the logit phi * W + b.
        self.phi = phi
        self.b = b

    def get_pi_t(self, W=None, Z=None):
        '''Return the probability of treatment Z at covariate W (Z == 0 means control; any other Z means treated).'''
        p_treated = expit(self.phi * W + self.b)
        return 1 - p_treated if Z == 0 else p_treated



def fit_outcome(W_samples, Z_samples, Y_samples, outcome_degree=2, printmodel=False):
    r'''
    Estimate the outcome model by polynomial regression, for use with the
    direct method.

    By definition, $\mu_z(W) = E[Y | Z=z, W]$, so we fit
    $\tilde \mu_z(W) = poly_{\hat \theta}(z, W)$. This is a LinearRegression
    of Y on PolynomialFeatures(outcome_degree) of the column pair [W, Z].

    printmodel: if True, print the training R^2, intercept and coefficients.

    Return: the single fitted LinearRegression (theta_reg_model).
    \hat \mu_0(W) and \hat \mu_1(W) are obtained from it with
    direct_mu(Z=z, W=w, theta_reg_model=..., degree=outcome_degree).
    '''

    poly = PolynomialFeatures(degree=outcome_degree)
    train_features = poly.fit_transform(np.column_stack((W_samples, Z_samples)))
    theta_reg_model = LinearRegression()
    theta_reg_model.fit(train_features, Y_samples)

    if printmodel:
        r_sq = theta_reg_model.score(train_features, Y_samples)
        print('outcome model R^2 score:', r_sq)
        print('outcome model intercept:', theta_reg_model.intercept_)
        print('outcome model slope:', theta_reg_model.coef_)

    return theta_reg_model

