# NOTE: import order is deliberately kept as-is: several star imports appear
# between explicit imports, so reordering could change which definition of a
# shared name wins.
from benchmark_data import *
from causalml.inference.meta import BaseDRLearner, BaseRLearner, BaseXLearner
from causalml.inference.tf import DragonNet
from sklearn.neural_network import MLPClassifier, MLPRegressor
from sklearn.linear_model import LinearRegression
import json
import os
from data import *
import numpy as np
import tensorflow as tf
from sklearn.preprocessing import RobustScaler
import pickle
from data import *
from model.network import StoNet_Causal
import torch
from torch.utils.data import DataLoader, random_split
from sklearn.model_selection import KFold
########################################################################################################################
seed = 1
np.random.seed(seed)
tf.random.set_seed(seed)
tf.keras.utils.set_random_seed(seed)
# The StoNet evaluation below uses torch (DataLoader, and presumably the
# random_split used by data_preprocess); seed torch's global RNG as well so that
# part of the benchmark is reproducible like the numpy/TensorFlow parts.
torch.manual_seed(seed)

# network configure: hyper-parameters shared by the sklearn nuisance models
# (tanh units, plain SGD with an inverse-scaling step size, no L2 penalty,
# early stopping enabled).
common_args = dict(activation='tanh', solver='sgd', alpha=0, batch_size=200, learning_rate='invscaling', max_iter=2500,
                   early_stopping=True, random_state=seed)
# treatment_model is the propensity classifier (its predict_proba[:, 1] is used
# as P(treat = 1 | x)); outcome_model is handed to every meta-learner.
treatment_model = MLPClassifier(hidden_layer_sizes=(64,), **common_args)
outcome_model = MLPRegressor(hidden_layer_sizes=(64, 32), **common_args)

# true ATEs: entry i is the ground truth for dataset ACIC_bench(i + 1)
ate_trues = np.array([0.8, -0.8, -0.3429, 0, -1.432, 9.134, -3.159, -0.8486, -0.16058, 1])
n_datasets = len(ate_trues)

# containers of absolute ATE errors |estimate - truth|, one slot per dataset,
# for each method: out-of-sample (*_out) and in-sample (*_in). Each name gets
# its own independent array.
eps_dragon_out, eps_dr_out, eps_r_out, eps_x_out, eps_stonet_out = (np.zeros(n_datasets) for _ in range(5))
eps_dragon_in, eps_dr_in, eps_r_in, eps_x_in, eps_stonet_in = (np.zeros(n_datasets) for _ in range(5))

def _dragon_ate(model, features, scaler):
    """Return DragonNet's mean treatment effect on ``features`` in original y units.

    Columns 0 and 1 of ``model.predict`` are treated as the predicted outcomes
    under control and under treatment on the scaled-y scale. Both are mapped back
    through ``scaler`` (fitted on the training outcomes) before their difference
    is averaged over all rows.
    """
    preds = model.predict(X=features)
    q_t0 = scaler.inverse_transform(preds[:, 0].reshape(-1, 1).copy())
    q_t1 = scaler.inverse_transform(preds[:, 1].reshape(-1, 1).copy())
    return (q_t1 - q_t0).mean()


def _stonet_aipw_ate(net, loader, scaler):
    """Return the AIPW (doubly robust) ATE estimate of a StoNet_Causal model.

    For each batch ``(y, treat, x)`` yielded by ``loader``, the network is run
    on the observed treatment and on the flipped treatment ``1 - treat``.
    Outcomes are mapped back to original units with
    ``scaler.inverse_transform`` and combined as

        mean[(mu_1 - mu_0) + (t / e - (1 - t) / (1 - e)) * (y - mu_t)]

    where ``e`` is the propensity score returned by the network. The loaders
    built in this script hold a whole split in one batch; with several batches
    only the last batch's estimate would be returned.

    Returns:
        float: the ATE estimate.

    Raises:
        ValueError: if ``loader`` yields no batches (previously this silently
        reused the estimate left over from the preceding dataset).
    """
    ate = None
    with torch.no_grad():
        for y_b, treat_b, x_b in loader:
            y_b = torch.FloatTensor(np.array(scaler.inverse_transform(y_b.cpu())))
            pred, prop_score = net.forward(x_b, treat_b)
            pred = torch.FloatTensor(np.array(scaler.inverse_transform(pred.cpu())))
            counter_fact, _ = net.forward(x_b, 1 - treat_b)
            counter_fact = torch.FloatTensor(np.array(scaler.inverse_transform(counter_fact.cpu())))
            # Flatten treatment and propensity before the element-wise arithmetic:
            # if either comes back as (n, 1) while the other terms are (n,),
            # broadcasting would silently build an (n, n) matrix and a meaningless
            # mean. This is a no-op when they are already 1-D.
            t = torch.flatten(treat_b)
            e = torch.flatten(prop_score)
            # pred - counter_fact is mu_t - mu_{1-t}; multiplying by (2t - 1)
            # turns it into mu_1 - mu_0 for treated and control rows alike.
            outcome_contrast = torch.flatten(pred - counter_fact) * (2 * t - 1)
            prop_contrast = t / e - (1 - t) / (1 - e)
            pred_resid = torch.flatten(y_b - pred)
            ate = (outcome_contrast + prop_contrast * pred_resid).mean().item()
    if ate is None:
        raise ValueError("StoNet evaluation loader yielded no batches")
    return ate


# StoNet architecture and checkpoint selection do not depend on the dataset
# index, so resolve them once (this also fails fast if the files are missing).
# NOTE(review): input_dim=200 must match the feature count produced by
# data_preprocess; TODO confirm.
net_args = dict(num_hidden=2, hidden_dim=[64, 32], input_dim=200, output_dim=1, treat_layer=1, treat_node=1)
file_dir = "acic"
bic = np.loadtxt(os.path.join(file_dir, 'Overall_BIC.txt'))
prune_seed = np.argmin(bic)  # index of the run with the smallest overall BIC
model_name = 'model' + str(prune_seed) + '.pt'
model_dir = os.path.join(file_dir, str(prune_seed), model_name)
# NOTE(review): this single checkpoint is evaluated on every dataset i, although
# the StoNet data (acic_data_homo(dgp[i])) changes per dataset. Confirm this is
# intended rather than a per-dataset checkpoint directory.

for i, ate_true in enumerate(ate_trues):
    print(i)

    # load data: dataset ids passed to ACIC_bench run from 1 to len(ate_trues)
    y, treat, x = ACIC_bench(i + 1)

    # divide dataset: folds 0 and 1 form the training set (2/3), fold 2 the test set (1/3)
    cv = KFold(n_splits=3, shuffle=True, random_state=seed)
    split_indices = [index for _, index in cv.split(y)]
    train_idx = np.concatenate((split_indices[0], split_indices[1]))
    test_idx = split_indices[2]

    y_train, y_test = y[train_idx], y[test_idx]
    treat_train, treat_test = treat[train_idx], treat[test_idx]
    x_train, x_test = x[train_idx], x[test_idx]

    # fit the treatment (propensity) model first; its P(treat = 1 | x) feeds the learners
    treat_fit = treatment_model.fit(x_train, treat_train)
    propensity_train = treat_fit.predict_proba(x_train)[:, 1]
    propensity_test = treat_fit.predict_proba(x_test)[:, 1]

    # DR-Learner
    learner_dr = BaseDRLearner(control_outcome_learner=outcome_model, treatment_outcome_learner=outcome_model,
                               treatment_effect_learner=LinearRegression())
    learner_dr.fit(x_train, treat_train, y_train, p=propensity_train, seed=seed)

    # BaseDRLearner.predict takes no propensity: its 4th positional parameter is
    # return_components, so treatment/y are passed by keyword instead.
    cate_dr = learner_dr.predict(x_train, treatment=treat_train, y=y_train)
    eps_dr_in[i] = abs(cate_dr.mean() - ate_true)
    print('dr in-sample', eps_dr_in[i])

    cate_dr = learner_dr.predict(x_test, treatment=treat_test, y=y_test)
    eps_dr_out[i] = abs(cate_dr.mean() - ate_true)
    print('dr out-of-sample', eps_dr_out[i])

    # R-Learner
    learner_r = BaseRLearner(outcome_learner=outcome_model, effect_learner=LinearRegression(), n_fold=3)
    learner_r.fit(x_train, treat_train, y_train, p=propensity_train)

    cate_r = learner_r.predict(x_train, p=propensity_train)
    eps_r_in[i] = abs(cate_r.mean() - ate_true)
    print('r in-sample', eps_r_in[i])

    cate_r = learner_r.predict(x_test, p=propensity_test)
    eps_r_out[i] = abs(cate_r.mean() - ate_true)
    print('r out-of-sample', eps_r_out[i])

    # X-Learner
    learner_x = BaseXLearner(control_outcome_learner=outcome_model, treatment_outcome_learner=outcome_model,
                             control_effect_learner=LinearRegression(), treatment_effect_learner=LinearRegression())
    learner_x.fit(x_train, treat_train, y_train, p=propensity_train)

    cate_x = learner_x.predict(x_train, treatment=treat_train, y=y_train, p=propensity_train)
    eps_x_in[i] = abs(cate_x.mean() - ate_true)
    print('x in-sample', eps_x_in[i])

    cate_x = learner_x.predict(x_test, treatment=treat_test, y=y_test, p=propensity_test)
    eps_x_out[i] = abs(cate_x.mean() - ate_true)
    print('x out-sample', eps_x_out[i])

    # DragonNet: y is robust-scaled for training to prevent gradient explosion
    y_scaler = RobustScaler()
    y_scaler.fit(y_train.reshape(-1, 1))
    y_train_scaled = np.array(y_scaler.transform(y_train.reshape(-1, 1)))

    dragon = DragonNet(neurons_per_layer=200, targeted_reg=True, batch_size=128, learning_rate=1e-6,
                       adam_learning_rate=1e-4, verbose=False)
    dragon.fit(X=x_train, treatment=treat_train, y=y_train_scaled)

    eps_dragon_in[i] = abs(_dragon_ate(dragon, x_train, y_scaler) - ate_true)
    print('dragon in-sample', eps_dragon_in[i])

    eps_dragon_out[i] = abs(_dragon_ate(dragon, x_test, y_scaler) - ate_true)
    print('dragon out-of-sample', eps_dragon_out[i])

    # StoNet: a fresh network is built and loaded with the selected checkpoint
    # each iteration.
    # NOTE(review): net is never switched to eval mode; if StoNet_Causal has
    # dropout/batch-norm layers, net.eval() should be called before evaluation.
    net = StoNet_Causal(**net_args)
    net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))

    # NOTE(review): StoNet uses its own data/split (acic_data_homo + data_preprocess),
    # not the KFold split above, so its in-/out-of-sample errors are computed on
    # different rows than the other methods'. Confirm that is intended.
    data = acic_data_homo(dgp[i])
    train_set, test_set, x_scalar, y_scalar = data_preprocess(data, 1, 1)
    # one batch holding the whole split
    train_data = DataLoader(train_set, batch_size=len(train_set))
    test_data = DataLoader(test_set, batch_size=len(test_set))

    eps_stonet_in[i] = abs(_stonet_aipw_ate(net, train_data, y_scalar) - ate_true)
    print('stonet in-sample', eps_stonet_in[i])

    eps_stonet_out[i] = abs(_stonet_aipw_ate(net, test_data, y_scalar) - ate_true)
    print('stonet out-of-sample', eps_stonet_out[i])
########################################################################################################################
# save the results
def _summarize(err_in, err_out):
    """Summarize per-dataset absolute ATE errors for one method.

    ``error_*`` is the mean absolute error across datasets. ``sd_*`` is
    sqrt(var / n), computed with numpy's default population variance (ddof=0).
    That is the standard error of that mean, not the spread of the errors
    themselves.
    """
    return dict(error_in=err_in.mean(), sd_in=np.sqrt(err_in.var() / len(err_in)),
                error_out=err_out.mean(), sd_out=np.sqrt(err_out.var() / len(err_out)))


dr = _summarize(eps_dr_in, eps_dr_out)
r = _summarize(eps_r_in, eps_r_out)
# note: this rebinds the module-level name x (previously the last dataset's features)
x = _summarize(eps_x_in, eps_x_out)
dragon = _summarize(eps_dragon_in, eps_dragon_out)
stonet = _summarize(eps_stonet_in, eps_stonet_out)

ate_result = dict(dr=dr, r=r, x=x, dragon=dragon, stonet=stonet)

# save the results to json file; create the output directory if needed, and let
# the context manager close the file even if serialization fails
out_path = './benchmark/bench_acic.json'
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, "w", encoding="utf-8") as ate_file:
    json.dump(ate_result, ate_file, indent="")
