'''Functions for in-sample causal effect estimation'''

import numpy as np

from scipy.optimize import linear_sum_assignment
from scipy.special import expit, logit
from scipy.stats import bernoulli

from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression

import data
import utils
import estimation



# NOTE(review): OUTCOME_THETA (the default for ``theta_true``) is neither
# defined nor imported in this module as shown; since default values are
# evaluated when the ``def`` runs, this raises NameError unless the name is
# provided some other way -- confirm where it is supposed to come from.
def in_sample_estimation(num_runs=2, outcome_degree=2,
                         mu_degree=2, theta_true=OUTCOME_THETA,
                         cov_dist='gaussian',
                         cov_mean=0, cov_sigma=1,
                         noise_sigma=1, pol_type='logistic',
                         num_samples=200,
                         pol_theta=1, pob_b=0.5,
                         test_data=False,
                         test_W_samples=None, fit_outcome_degree=2):
    """Estimate mu_1(w) - mu_0(w) at fixed query points, repeated over runs.

    Each run draws a fresh training sample with ``data.generate_data``, fits
    an outcome regression, a propensity-score model, and the weighted-direct
    (WDM) and doubly-robust (DR) mu models from ``estimation``, and then
    evaluates the estimated difference mu_1(w) - mu_0(w) at every ``w`` in
    ``test_W_samples`` with each of the three estimators. The oracle
    difference is computed from the polynomial outcome model with
    coefficients ``theta_true``.

    Parameters
    ----------
    num_runs : int
        Number of independent data draws / model refits.
    outcome_degree : int
        Polynomial degree of the data-generating outcome model; also used
        to build the oracle features.
    mu_degree : int
        Degree passed to the WDM and DR mu fits and evaluations.
    theta_true : numpy.ndarray
        Coefficients of the true outcome model over
        ``PolynomialFeatures(outcome_degree)`` of the row ``[w, z]``.
    cov_dist, cov_mean, cov_sigma, noise_sigma, pol_type, num_samples, pol_theta
        Forwarded unchanged to ``data.generate_data``.
    pob_b : float
        Forwarded to ``data.generate_data`` as ``pol_b`` (note the spelling
        of this parameter; it is kept so existing callers still work).
    test_data : bool
        Currently ignored: training data is always generated with
        ``test_data=False``.
    test_W_samples : iterable
        Required. Covariate values at which to evaluate; each element ``w``
        is paired with z in {0, 1} as the feature row ``[w, z]``. It is
        materialized once, so a one-shot iterator is fine.
    fit_outcome_degree : int
        Polynomial degree of the fitted outcome model used by the direct
        method.

    Returns
    -------
    dict
        Keys ``'direct'``, ``'WDM'``, ``'GRDR'`` (the DR estimator) and
        ``'oracle'``. Each value is a numpy array whose first axis indexes
        runs and whose second axis indexes the test points.

    Raises
    ------
    TypeError
        If ``test_W_samples`` is None.
    """
    if test_W_samples is None:
        raise TypeError(
            'test_W_samples must be an iterable of covariate values, got None')
    # The test points are reused in every run; materialize them so that a
    # one-shot iterator is not exhausted after the first run.
    test_W_samples = list(test_W_samples)

    # The oracle depends only on theta_true and the test points, not on the
    # sampled data, so compute it once instead of once per run.
    true_diffs = _true_cate_diffs(test_W_samples, theta_true, outcome_degree)

    est_diff_all_direct = []  # direct method
    est_diff_all_WDM = []
    est_diff_all_DR = []
    true_diff_all = []

    for _ in range(num_runs):
        # The ``test_data`` argument is not forwarded: training data is
        # always generated with test_data=False.
        W_samples, Z_samples, Y_samples = data.generate_data(
            degree=outcome_degree,
            theta_true=theta_true,
            cov_dist=cov_dist,
            cov_mean=cov_mean, cov_sigma=cov_sigma,
            noise_sigma=noise_sigma,
            pol_type=pol_type,
            num_samples=num_samples,
            pol_theta=pol_theta, pol_b=pob_b,
            test_data=False)

        fit_outcome_model = estimation.fit_outcome(
            W_samples, Z_samples, Y_samples,
            outcome_degree=fit_outcome_degree)

        prop_model = estimation.fit_prop_score(
            W_samples=W_samples, Z_samples=Z_samples)

        mu_0_WDM, mu_1_WDM = estimation.fit_weighted_direct_mu(
            Z_samples=Z_samples, W_samples=W_samples, Y_samples=Y_samples,
            mu_degree=mu_degree, prop_model=prop_model)

        mu_0_DR, mu_1_DR = estimation.fit_robust_mu(
            Z_samples=Z_samples, W_samples=W_samples, Y_samples=Y_samples,
            mu_degree=mu_degree, prop_model=prop_model)

        est_mu_diff_list = []
        est_mu_diff_list_WDM = []
        est_mu_diff_list_DR = []

        for test_W_sample in test_W_samples:
            # Direct method: plug-in predictions from the fitted outcome model.
            est_mu_0_direct = estimation.direct_mu(
                Z=0, W=test_W_sample, theta_reg_model=fit_outcome_model,
                degree=fit_outcome_degree)
            est_mu_1_direct = estimation.direct_mu(
                Z=1, W=test_W_sample, theta_reg_model=fit_outcome_model,
                degree=fit_outcome_degree)
            est_mu_diff_list.append(est_mu_1_direct - est_mu_0_direct)

            # Weighted direct method.
            est_mu_1_WDM = estimation.weighted_direct_mu(
                W=test_W_sample, Z=1, mu_1_model=mu_1_WDM,
                mu_0_model=mu_0_WDM, mu_degree=mu_degree)
            est_mu_0_WDM = estimation.weighted_direct_mu(
                W=test_W_sample, Z=0, mu_1_model=mu_1_WDM,
                mu_0_model=mu_0_WDM, mu_degree=mu_degree)
            est_mu_diff_list_WDM.append(est_mu_1_WDM - est_mu_0_WDM)

            # Doubly-robust method (reported under the 'GRDR' key).
            est_mu_1_DR = estimation.robust_mu(
                W=test_W_sample, Z=1, mu_1_model=mu_1_DR, mu_0_model=mu_0_DR,
                mu_degree=mu_degree, prop_model=prop_model)
            est_mu_0_DR = estimation.robust_mu(
                W=test_W_sample, Z=0, mu_1_model=mu_1_DR, mu_0_model=mu_0_DR,
                mu_degree=mu_degree, prop_model=prop_model)
            est_mu_diff_list_DR.append(est_mu_1_DR - est_mu_0_DR)

        est_diff_all_direct.append(est_mu_diff_list)
        est_diff_all_WDM.append(est_mu_diff_list_WDM)
        est_diff_all_DR.append(est_mu_diff_list_DR)
        # Copy so each run's row is an independent list.
        true_diff_all.append(list(true_diffs))

    results_dict = {'direct': np.array(est_diff_all_direct),
                    'WDM': np.array(est_diff_all_WDM),
                    'GRDR': np.array(est_diff_all_DR),
                    'oracle': np.array(true_diff_all)}

    return results_dict


def _true_cate_diffs(test_W_samples, theta_true, outcome_degree):
    """Return the oracle mu_1(w) - mu_0(w) for each w in ``test_W_samples``.

    mu_z(w) is the dot product of ``PolynomialFeatures(outcome_degree)``
    applied to the row ``[w, z]`` with ``theta_true`` (as a column vector).
    """
    poly = PolynomialFeatures(degree=outcome_degree)
    theta_col = theta_true.reshape(-1, 1)
    diffs = []
    for w in test_W_samples:
        true_mu_0 = np.dot(poly.fit_transform([[w, 0]]), theta_col)
        true_mu_1 = np.dot(poly.fit_transform([[w, 1]]), theta_col)
        diffs.append(true_mu_1[0][0] - true_mu_0[0][0])
    return diffs

