from benchmark_data import *
from causalml.inference.meta import BaseXLearner
from sklearn.neural_network import MLPClassifier
from sklearn.linear_model import LogisticRegression
import json
import numpy as np
from sklearn.model_selection import KFold
from imblearn.over_sampling import SMOTENC

# Single seed for the whole script: it seeds NumPy's global RNG here and is
# also passed as ``random_state`` to the estimators built below.
seed = 1
np.random.seed(seed)

# Network configuration shared by the treatment and outcome classifiers.
# Only the hidden-layer layout differs between the two models.
common_args = {
    'activation': 'tanh',
    'solver': 'sgd',
    'alpha': 0,
    'batch_size': 64,
    'learning_rate': 'adaptive',
    'learning_rate_init': 1e-3,
    'max_iter': 5000,
    'early_stopping': True,
    'random_state': seed,
}

# Treatment model: one hidden layer of 64 units.
treatment_model = MLPClassifier(hidden_layer_sizes=(64,), **common_args)
# Outcome model: three hidden layers of 64, 16 and 2 units.
outcome_model = MLPClassifier(hidden_layer_sizes=(64, 16, 2), **common_args)

# load data; the three arrays are used below as outcome (y), treatment
# assignment (treat) and covariates (x)
y, treat, x = BRCA_bench()

# Shuffled K-fold split. Only the held-out index array of each split is
# kept; the training indices are rebuilt from these arrays later on.
cv = KFold(n_splits=3, shuffle=True, random_state=seed)
split_indices = [held_out for _, held_out in cv.split(y)]

# One ATE estimate per fold. The buffer is sized from the generated splits
# so the fold count is stated in exactly one place (``n_splits`` above).
ates_x = np.zeros(len(split_indices))

# Cross-validation loop: every iteration holds one fold out for evaluation and
# trains on all remaining folds. The fold count is read from ``split_indices``
# instead of being repeated as a literal.
n_folds = len(split_indices)
for ifold in range(n_folds):

    # Training folds are taken in cyclic order starting at ``ifold``; the last
    # fold of that rotation is the held-out one. For three folds this is
    # folds (i, i+1) for training and fold i+2 for validation.
    train_idx = np.concatenate(
        [split_indices[(ifold + k) % n_folds] for k in range(n_folds - 1)])
    val_idx = split_indices[(ifold + n_folds - 1) % n_folds]

    y_train, y_val = y[train_idx], y[val_idx]
    treat_train, treat_val = treat[train_idx], treat[val_idx]
    x_train, x_val = x[train_idx], x[val_idx]

    # use SMOTE to create a balanced training dataset
    # The treatment is prepended as column 0 so that it is resampled together
    # with the covariates, with ``y_train`` as the class being balanced.
    w = np.concatenate((treat_train.reshape((len(treat_train), 1)), x_train), axis=1)
    # Columns 0..16 of ``w`` (treatment + first 16 covariates) are declared
    # categorical for SMOTENC.
    sm = SMOTENC(random_state=seed, categorical_features=np.arange(17))
    w_train, y_train = sm.fit_resample(w, y_train)
    # Split the resampled matrix back into treatment and covariates.
    treat_train = w_train[:, 0]
    x_train = w_train[:, 1:]

    # Propensity model, fitted on the resampled training data only.
    # Column 1 of predict_proba is the probability of the second class in
    # ``classes_`` -- presumably treat == 1; confirm the treatment coding.
    treat_fit = treatment_model.fit(x_train, treat_train)
    prop_fit = treat_fit.predict_proba(x_train)[:, 1]
    prop_predict = treat_fit.predict_proba(x_val)[:, 1]

    # X-Learner
    # The same MLP configuration is used for both outcome learners and a
    # logistic regression for both effect learners.
    learner_x = BaseXLearner(control_outcome_learner=outcome_model, treatment_outcome_learner=outcome_model,
                             control_effect_learner=LogisticRegression(), treatment_effect_learner=LogisticRegression())
    learner_x.fit(X=x_train, treatment=treat_train, y=y_train, p=prop_fit)
    cate_x = learner_x.predict(X=x_val, treatment=treat_val, y=y_val, p=prop_predict)
    # Fold-level ATE: mean of the predicted effects on the held-out fold.
    ates_x[ifold] = cate_x.mean()

# Point estimate: average of the per-fold ATEs.
ate_x = ates_x.mean()
print('ate_x', ate_x)

# Spread of the fold estimates, computed as sqrt(var / n_folds) where ``var``
# is NumPy's default population variance (ddof=0). The divisor is read from
# the array itself so it always matches the number of folds.
std_x = np.sqrt(ates_x.var() / ates_x.size)
print('std_x', std_x)

# ``.tolist()`` turns the NumPy scalars into built-in Python floats.
ate_result = dict(x_learner=dict(ate=ate_x.tolist(), sd=std_x.tolist()))

# save the results to json file
# The path is relative to the current working directory and the ``benchmark``
# directory must already exist (``open`` does not create it). The ``with``
# block closes the file even if serialisation raises.
# ``indent=""`` puts each item on its own line without any leading spaces.
with open('./benchmark/bench_brca_ate.json', "w") as ate_file:
    json.dump(ate_result, ate_file, indent="")
