import numpy as np
from benchmark_data import *
from causalml.inference.meta import BaseDRLearner, BaseRLearner, BaseXLearner
from causalml.inference.tf import DragonNet
from sklearn.neural_network import MLPClassifier, MLPRegressor
from sklearn.linear_model import LinearRegression
from data import *
from torch.utils.data import DataLoader, random_split
from model.network import StoNet_Causal
import pickle
import json
import tensorflow as tf

########################################################################################################################
# Seed NumPy and TensorFlow/Keras with one shared value.
seed = 1
for _set_seed in (np.random.seed, tf.random.set_seed, tf.keras.utils.set_random_seed):
    _set_seed(seed)

# Settings shared by both scikit-learn MLPs; only the hidden layer sizes differ
# between the propensity (treatment) model and the outcome model.
common_args = {
    'activation': 'tanh',
    'solver': 'sgd',
    'alpha': 0,
    'batch_size': 100,
    'learning_rate': 'invscaling',
    'max_iter': 1500,
    'early_stopping': True,
}
treatment_model = MLPClassifier(hidden_layer_sizes=(64,), **common_args)
outcome_model = MLPRegressor(hidden_layer_sizes=(64, 32, 16), **common_args)

# Absolute ATE errors, one slot per DGP (10 in total), for every method.
# Each name gets its own independent zero-filled array.
# Out-of-sample (held-out test split):
eps_dragon_out, eps_dr_out, eps_r_out, eps_x_out, eps_stonet_out = (
    np.zeros(10) for _ in range(5)
)
# In-sample (training split):
eps_dragon_in, eps_dr_in, eps_r_in, eps_x_in, eps_stonet_in = (
    np.zeros(10) for _ in range(5)
)

# Read the pickled test-set indices; the main loop looks them up by the DGP
# number converted to a string.
with open('./sim_test_idx.pkl', mode='rb') as idx_file:
    test_idx = pickle.load(idx_file)

def _stonet_dr_ate(net, loader):
    """Return StoNet's doubly-robust style ATE estimate over ``loader``.

    For every batch the network is evaluated under the observed treatment and
    under the flipped treatment ``1 - treat``. The per-sample score is

        (pred - counter_fact) * (2 * treat - 1)
        + (treat / prop - (1 - treat) / (1 - prop)) * (y - pred)

    and the batch estimate is the mean of these scores.

    Args:
        net: model whose ``forward(x, treat)`` returns a
            ``(prediction, propensity_score)`` pair.
        loader: iterable yielding ``(y, treat, x, _)`` batches.

    Returns:
        The estimate from the last batch yielded by ``loader``. The callers
        below use loaders whose batch size equals the dataset length, so that
        is the estimate over the whole split.

    Raises:
        ValueError: if ``loader`` yields no batch at all.
    """
    ate = None
    with torch.no_grad():
        for y_batch, treat_batch, x_batch, _ in loader:
            pred, prop_score = net.forward(x_batch, treat_batch)
            counter_fact, _ = net.forward(x_batch, 1 - treat_batch)
            # (2 * treat - 1) flips the sign for control units so the contrast
            # is always "treated minus control".
            outcome_contrast = torch.flatten(pred - counter_fact) * (2 * treat_batch - 1)
            prop_contrast = treat_batch / prop_score - (1 - treat_batch) / (1 - prop_score)
            pred_resid = torch.flatten(y_batch - pred)
            ate = (outcome_contrast + prop_contrast * pred_resid).mean()
    if ate is None:
        raise ValueError('loader yielded no batches')
    return ate


# StoNet checkpoint selection does not depend on the DGP: pick the run with the
# smallest BIC and read its weights from disk once instead of once per DGP.
file_dir = "simulation"
bic = np.loadtxt(os.path.join(file_dir, 'Overall_BIC.txt'))
prune_seed = np.argmin(bic)

model_name = 'model' + str(prune_seed) + '.pt'
model_dir = os.path.join(file_dir, str(prune_seed), model_name)
stonet_state = torch.load(model_dir, map_location=torch.device('cpu'))
net_args = dict(num_hidden=3, hidden_dim=[6, 4, 3], input_dim=100, output_dim=1, treat_layer=1, treat_node=1)

for dgp in range(1, 11, 1):
    print('iter', dgp)

    # extract test set index
    idx = test_idx[str(dgp)]

    # load data
    y, treat, x, y_count = SimData_bench(100, dgp, 11000)

    # divide dataset: rows listed in idx form the test split, the rest the training split
    y_test, treat_test, x_test, y_count_test = y[idx], treat[idx], x[idx], y_count[idx]
    ate_true_out = true_cate(y_test, treat_test, y_count_test).mean()
    print('ate_true_out', ate_true_out)

    y_train, treat_train, x_train, y_count_train = np.delete(y, idx, 0), np.delete(treat, idx, 0), np.delete(x, idx, 0), \
                                                  np.delete(y_count, idx, 0)
    ate_true_in = true_cate(y_train, treat_train, y_count_train).mean()
    print('ate_true_in', ate_true_in)

    # fit the treatment model first; its propensity scores are handed to the
    # DR-, R- and X-learners below
    treat_fit = treatment_model.fit(x_train, treat_train)
    propensity_train = treat_fit.predict_proba(x_train)[:, 1]
    propensity_test = treat_fit.predict_proba(x_test)[:, 1]

    # DR-Learner
    learner_dr = BaseDRLearner(control_outcome_learner=outcome_model, treatment_outcome_learner=outcome_model,
                               treatment_effect_learner=LinearRegression())
    learner_dr.fit(x_train, treat_train, y_train, p=propensity_train, seed=seed)

    cate_dr = learner_dr.predict(x_train, treat_train, y_train, propensity_train)
    eps_dr_in[dgp-1] = abs(cate_dr.mean() - ate_true_in)
    print('dr in-sample', eps_dr_in[dgp-1])

    cate_dr = learner_dr.predict(x_test, treat_test, y_test, propensity_test)
    eps_dr_out[dgp-1] = abs(cate_dr.mean() - ate_true_out)
    print('dr out-of-sample', eps_dr_out[dgp-1])

    # R-Learner
    learner_r = BaseRLearner(outcome_learner=outcome_model, effect_learner=LinearRegression(), n_fold=3)
    learner_r.fit(x_train, treat_train, y_train, propensity_train)

    cate_r = learner_r.predict(x_train, propensity_train)
    eps_r_in[dgp-1] = abs(cate_r.mean() - ate_true_in)
    print('r in-sample', eps_r_in[dgp-1])

    cate_r = learner_r.predict(x_test, propensity_test)
    eps_r_out[dgp-1] = abs(cate_r.mean() - ate_true_out)
    print('r out-of-sample', eps_r_out[dgp-1])

    # X-Learner
    learner_x = BaseXLearner(control_outcome_learner=outcome_model, treatment_outcome_learner=outcome_model,
                             control_effect_learner=LinearRegression(), treatment_effect_learner=LinearRegression())
    learner_x.fit(x_train, treat_train, y_train, propensity_train)

    cate_x = learner_x.predict(x_train, treat_train, y_train, propensity_train)
    eps_x_in[dgp-1] = abs(cate_x.mean() - ate_true_in)
    print('x in-sample', eps_x_in[dgp-1])

    cate_x = learner_x.predict(x_test, treat_test, y_test, propensity_test)
    eps_x_out[dgp-1] = abs(cate_x.mean() - ate_true_out)
    print('x out-of-sample', eps_x_out[dgp-1])

    # DragonNet
    dragon_learner = DragonNet(neurons_per_layer=200, targeted_reg=True)
    dragon_learner.fit(x_train, treat_train, y_train)

    cate_dragon = dragon_learner.predict_tau(x_train)
    eps_dragon_in[dgp-1] = abs(cate_dragon.mean() - ate_true_in)
    print('dragon in-sample', eps_dragon_in[dgp-1])

    cate_dragon = dragon_learner.predict_tau(x_test)
    eps_dragon_out[dgp-1] = abs(cate_dragon.mean() - ate_true_out)
    print('dragon out-of-sample', eps_dragon_out[dgp-1])

    # stonet: the network object is still built per DGP; only the checkpoint
    # read above is shared between iterations
    net = StoNet_Causal(**net_args)
    net.load_state_dict(stonet_state)

    data = SimData_Causal(100, dgp, 11000)
    train_set, test_set = random_split(data, [10000, 1000],
                                       generator=torch.Generator().manual_seed(1))
    # One batch spanning the whole split, so each loader yields exactly once.
    train_data = DataLoader(train_set, batch_size=len(train_set))
    test_data = DataLoader(test_set, batch_size=len(test_set))

    # calculate ate
    eps_stonet_in[dgp - 1] = abs(_stonet_dr_ate(net, train_data) - ate_true_in)
    print('stonet in-sample', eps_stonet_in[dgp - 1])

    eps_stonet_out[dgp - 1] = abs(_stonet_dr_ate(net, test_data) - ate_true_out)
    print('stonet out-of-sample', eps_stonet_out[dgp - 1])

########################################################################################################################
# save the results
def _summarize(eps_in, eps_out):
    """Summarise one method's per-DGP absolute ATE errors.

    Args:
        eps_in: array of in-sample absolute errors, one entry per DGP.
        eps_out: array of out-of-sample absolute errors, one entry per DGP.

    Returns:
        dict with keys ``error_in``/``error_out`` (mean error) and
        ``sd_in``/``sd_out``. The ``sd_*`` values are ``sqrt(var / n)`` with
        numpy's default population variance (``ddof=0``) and ``n`` the number
        of entries, i.e. a standard error of the mean rather than a plain
        standard deviation; the key names are left unchanged.
    """
    def _spread(eps):
        return np.sqrt(eps.var() / eps.size)

    return dict(error_in=eps_in.mean(), sd_in=_spread(eps_in),
                error_out=eps_out.mean(), sd_out=_spread(eps_out))


ate_result = dict(
    dr=_summarize(eps_dr_in, eps_dr_out),
    r=_summarize(eps_r_in, eps_r_out),
    x=_summarize(eps_x_in, eps_x_out),
    dragon=_summarize(eps_dragon_in, eps_dragon_out),
    stonet=_summarize(eps_stonet_in, eps_stonet_out),
)

# save the results to json file
result_path = 'result/bench_sim.json'
# Create the output directory first so a missing folder cannot discard the
# results of the whole benchmark run at the very last step.
os.makedirs(os.path.dirname(result_path), exist_ok=True)
with open(result_path, "w") as ate_file:
    json.dump(ate_result, ate_file, indent="")
