"""
Generating high-fidelity privacy-conscious synthetic patient data for causal effect estimation with multiple treatments
Evaluates three causal inference models on a synthetic dataset.
"""

import dowhy
from dowhy import CausalModel
import dowhy.datasets
from tqdm import tqdm
from econml.dr import DRLearner, ForestDRLearner
import pandas as pd
from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.model_selection import GridSearchCV

def df_standardize(df):
    """Min-max scale every column of ``df`` and report the column ranges.

    Each column is shifted by its minimum and divided by its range
    (``max - min``) plus a small constant, so a constant column maps to
    zeros instead of triggering a division by zero.

    Args:
        df: DataFrame of numeric columns.

    Returns:
        A ``(df_scaled, scales)`` tuple: the rescaled frame and the
        per-column ``max - min`` values (without the added constant).
    """
    col_max = df.max()
    col_min = df.min()
    scales = col_max - col_min
    shifted = df - col_min
    # The small offset keeps the denominator non-zero for constant columns.
    df_scaled = shifted / (scales + 0.00001)
    return df_scaled, scales

def infer_effects_all_drugs(df_in, covar_col, all_t_col, y_col):
    """Estimate the average treatment effect (ATE) of each drug on the outcome.

    The input is min-max scaled with ``df_standardize``.  For every drug
    column, the analysed cohort is the rows flagged as treated with that
    drug plus the rows whose unscaled drug columns sum to less than 0.5
    (presumably patients on none of the drugs).  ``infer_effects_dr`` is
    run on that cohort and its estimate is multiplied by the outcome range
    so the effect is reported (approximately) on the original outcome scale.

    Args:
        df_in: DataFrame holding covariates, drug columns and the outcome.
            It is copied, not modified.
        covar_col: Covariate column names passed through to the estimator.
        all_t_col: Names of the drug (treatment) columns.
        y_col: Name of the outcome column.

    Returns:
        dict mapping each drug column name to its ATE.  Drugs whose treated
        rows make up no more than ``CUTOFF`` of all rows map to 0.
    """
    df = df_in.copy()
    df_scaled, scales = df_standardize(df)
    # Per-row drug total, taken from the unscaled columns.
    df_scaled['drug_sum'] = df[all_t_col].sum(axis=1)

    # Minimum fraction of rows that must be treated for a drug to be estimated.
    CUTOFF = 0.002
    all_ate_effects = dict()
    print("Inferring causal effect for all drugs")
    for col in tqdm(all_t_col):
        # Convert the scaled drug column to a boolean treatment flag.
        df_scaled[col] = df_scaled[col] > 0.5

        # Rows with (almost) no drug total form the control pool; the rest
        # are treated with at least one drug.
        df_controls = df_scaled.loc[df_scaled['drug_sum'] < 0.5, :]
        df_treated = df_scaled.loc[df_scaled['drug_sum'] > 0.5, :]

        # Controls, excluding any label that also belongs to a row treated
        # with this drug.
        # NOTE: the treated labels must never be looked up in ``df_controls``
        # with ``.loc`` -- they are absent from it and pandas >= 1.0 raises
        # KeyError for missing labels.
        treated_idx = df_treated.loc[df_treated[col], :].index
        df_controls_not_this_drug = df_controls.loc[
            df_controls.index.difference(treated_idx), :]
        df_treated_this_drug = df_scaled.loc[df_scaled[col], :]

        if df_treated_this_drug.shape[0] > len(df_scaled[col]) * CUTOFF:
            # Build the cohort only when it is actually going to be analysed.
            df_selected = pd.concat([df_treated_this_drug, df_controls_not_this_drug])
            df_selected.sort_index(inplace=True)

            ate_effect = infer_effects_dr(df_selected, covar_col, col, y_col)

            # Undo the outcome scaling so the effect is in original units.
            all_ate_effects[col] = ate_effect * scales[y_col]
            print('finished ' + col + ' ate: {:4.2f}'.format(all_ate_effects[col]))
        else:
            # Too few treated rows for a meaningful estimate.
            all_ate_effects[col] = 0

    return all_ate_effects

def infer_effects_propensity(df_in, covar_col, t_col, y_col):
    """Estimate a treatment effect with dowhy propensity-score stratification.

    Args:
        df_in: DataFrame containing the treatment, outcome and covariates.
        covar_col: Column names declared as common causes of treatment
            and outcome.
        t_col: Name of the treatment column.
        y_col: Name of the outcome column.

    Returns:
        The ``value`` attribute of the dowhy effect estimate.
    """
    # Other dowhy method names that can be swapped in here:
    #   backdoor.dowhy.propensity_score_matching
    #       (with method_params = {'max_iter': 100000})
    #   backdoor.dowhy.propensity_score_weighting
    #   backdoor.dowhy.generalized_linear_model
    estimator_name = "backdoor.dowhy.propensity_score_stratification"

    causal_model = CausalModel(
        data=df_in,
        treatment=t_col,
        outcome=y_col,
        common_causes=covar_col,
        # NOTE(review): this is handed to the CausalModel constructor rather
        # than to estimate_effect -- confirm it reaches the estimator.
        method_params={'max_iter': 100000},
    )

    estimand = causal_model.identify_effect(proceed_when_unidentifiable=True)
    estimate = causal_model.estimate_effect(estimand, method_name=estimator_name)
    return estimate.value

def infer_effects_dr_simple(df_in, covar_col, t_col, y_col):
    """Estimate the ATE with an econml ``DRLearner`` left at its defaults.

    Args:
        df_in: DataFrame containing the treatment, outcome and covariates.
        covar_col: Covariate column names, supplied to the learner as
            controls (``W``).
        t_col: Name of the treatment column.
        y_col: Name of the outcome column.

    Returns:
        The result of ``DRLearner.ate()`` after fitting.
    """
    learner = DRLearner()

    outcome = df_in[y_col]
    treatment = df_in[t_col]
    controls = df_in[covar_col]

    # Covariates enter only as controls (W); no effect modifiers (X) are
    # given.  ``inference='bootstrap'`` is an optional extra fit argument.
    learner.fit(outcome, treatment, W=controls)
    return learner.ate()

def infer_effects_linear_dr(df_in, covar_col, t_col, y_col):
    """Estimate a treatment effect with econml's ``LinearDRLearner`` via dowhy.

    The propensity model is a 3-fold cross-validated logistic regression;
    the remaining ``LinearDRLearner`` settings are left at their defaults.

    Args:
        df_in: DataFrame containing the treatment, outcome and covariates.
        covar_col: Column names declared as common causes of treatment
            and outcome.
        t_col: Name of the treatment column.
        y_col: Name of the outcome column.

    Returns:
        The ``value`` attribute of the dowhy effect estimate.
    """
    model = CausalModel(
        data=df_in,
        treatment=t_col,
        outcome=y_col,
        common_causes=covar_col)

    identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)

    # ``multi_class`` is deliberately left unset: passing it explicitly is
    # deprecated in recent scikit-learn releases, and leaving it out gives
    # the same 'auto' behaviour.
    propensity_model = LogisticRegressionCV(cv=3, solver='lbfgs', max_iter=1000)

    drlearner_estimate = model.estimate_effect(
        identified_estimand,
        method_name="backdoor.econml.dr.LinearDRLearner",
        confidence_intervals=False,
        method_params={
            "init_params": {'model_propensity': propensity_model},
            "fit_params": {},
        })

    return drlearner_estimate.value


def infer_effects_dr(df_in, covar_col, t_col, y_col):
    """Estimate the ATE of ``t_col`` on ``y_col`` with an econml ``ForestDRLearner``.

    Both nuisance models are small random forests (10 trees, maximum depth
    5).  The estimate comes straight from econml, so no dowhy model is
    built here.

    Args:
        df_in: DataFrame containing the treatment, outcome and covariates.
        covar_col: Covariate column names.
        t_col: Name of the treatment column.
        y_col: Name of the outcome column.

    Returns:
        The result of ``ForestDRLearner.ate`` evaluated on the covariates
        of ``df_in``.
    """
    covariates = df_in[covar_col]

    # A LogisticRegression(solver='lbfgs', max_iter=1000) is a possible
    # alternative propensity model.
    drlearner_estimate = ForestDRLearner(
        model_regression=RandomForestRegressor(n_estimators=10, max_depth=5, n_jobs=-1),
        model_propensity=RandomForestClassifier(n_estimators=10, max_depth=5, n_jobs=-1),
    )

    # NOTE(review): the same columns are passed as both effect modifiers (X)
    # and controls (W); confirm the duplication is intended before changing
    # it, since dropping either argument would alter the fitted models.
    drlearner_estimate.fit(df_in[y_col], df_in[t_col], X=covariates, W=covariates)
    return drlearner_estimate.ate(X=covariates)