from data_simulation import *
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import sys
sys.path.append("..")
import torch
import random
torch.manual_seed(0)
from engression import engression
from model import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import warnings
from scipy.stats import ks_2samp,cramervonmises_2samp

# Suppress FutureWarning
warnings.simplefilter(action='ignore', category=FutureWarning)

from tqdm import tqdm
from scipy.stats import ttest_1samp
import copy
import pickle

## Shared helper functions (the no-domain-shift experiment starts after these definitions)
from tqdm import tqdm
from scipy.stats import percentileofscore
import copy

def calculate_p_value(est_ate_list, true_ate_test):
    """Return the one-sample t-test p-value for H0: mean(est_ate_list) == true_ate_test.

    Args:
        est_ate_list: Sequence of ATE estimates (one per replication).
        true_ate_test: Hypothesised mean of those estimates.

    Returns:
        The p-value reported by ``scipy.stats.ttest_1samp`` (two-sided under
        scipy's default ``alternative``).
    """
    test_result = ttest_1samp(est_ate_list, true_ate_test)
    return test_result.pvalue
    
def _fit_and_estimate_ates(sim_Z_train, sim_X_train, sim_Y_train, sim_Z_test_b):
    """Fit every estimator on one training set and return its averaged estimate on the test covariates.

    Returns a dict mapping method name -> ``np.mean`` of element ``[0]`` of that
    method's predict output on ``sim_Z_test_b``. Element ``[0]`` is presumably
    the per-unit treatment-effect estimate -- TODO confirm against model.py.
    """
    # Fit all models first and predict afterwards, in a fixed order. The order
    # is kept stable because the estimators may share global RNG state (torch
    # is seeded once at import time) -- presumably; confirm in model.py.
    tarnet = TARNet_fit(Z_train=sim_Z_train, Y_train=sim_Y_train, X_train=sim_X_train)
    causal_forest = CausalForestDML_fit(Z_train=sim_Z_train, Y_train=sim_Y_train, X_train=sim_X_train)
    s_bart = S_BART_fit(Z_train=sim_Z_train, Y_train=sim_Y_train, X_train=sim_X_train)
    t_bart_control, t_bart_treated = T_BART_fit(Z_train=sim_Z_train, Y_train=sim_Y_train, X_train=sim_X_train)
    s_engressor = S_engression_fit(Z_train=sim_Z_train, Y_train=sim_Y_train, X_train=sim_X_train, num_epochs=200)
    t_engressor_control, t_engressor_treated = T_engression_fit(
        Z_train=sim_Z_train, Y_train=sim_Y_train, X_train=sim_X_train, num_epochs=200)

    # Dict displays are evaluated left to right, so predictions run in the listed order.
    return {
        'TARNet': np.mean(TARNet_predict(Z_test=sim_Z_test_b, tarnet=tarnet)[0]),
        'CausalForest': np.mean(CausalForestDML_predict(Z_test=sim_Z_test_b, causal_forest=causal_forest)[0]),
        'S_BART': np.mean(S_BART_predict(Z_test=sim_Z_test_b, bart_model=s_bart)[0]),
        'T_BART': np.mean(T_BART_predict(Z_test=sim_Z_test_b, bart_model_control=t_bart_control,
                                         bart_model_treatment=t_bart_treated)[0]),
        'S_engression': np.mean(S_engression_predict(Z_test=sim_Z_test_b, engressor=s_engressor)[0]),
        'T_engression': np.mean(T_engression_predict(Z_test=sim_Z_test_b, engressor_control=t_engressor_control,
                                                     engressor_treated=t_engressor_treated)[0]),
    }


def ate_testing(n_samples_train, n_samples_test, corr_matrix, eta, num_cont_covariates, num_disc_covariates, train_cont_margin_params,
                    train_cont_margin_family, train_disc_margin_params, train_disc_margin_family, test_cont_margin_params,
                    test_cont_margin_family, test_disc_margin_params, test_disc_margin_family, treatment_family,
                    prop_score_params, causal_effect_family, causal_effect_params, n_samples_test_b,
                    n_bootstrap=500, seed=42, true_ate_test=2):
    """Repeatedly simulate data, estimate the ATE with each method, and t-test the estimates.

    For each of ``n_bootstrap`` replications, ``generate_train_test_data`` draws
    a fresh train/test pair. Every estimator is then fitted on the train split,
    and the mean of its predictions on the test covariates is recorded. Finally,
    for each method, a one-sample t-test (``calculate_p_value``) compares the
    recorded estimates against ``true_ate_test``.

    Args:
        n_samples_train: Training-set size passed to the simulator.
        n_samples_test: Unused; kept for call compatibility. The test-set size
            actually simulated is ``n_samples_test_b``.
        corr_matrix, eta, num_cont_covariates, num_disc_covariates,
        train_*/test_* margin params and families, treatment_family,
        prop_score_params, causal_effect_family, causal_effect_params:
            Forwarded unchanged to ``generate_train_test_data``.
        n_samples_test_b: Test-set size for each replication.
        n_bootstrap: Number of simulate-fit-estimate replications.
        seed: Base seed. Replication ``b`` uses ``seed * n_bootstrap + b``, so
            calls with different non-negative seeds use disjoint seed ranges.
        true_ate_test: Reference ATE for the t-test (default 2).

    Returns:
        ``(p_values, bootstrap_ates)``. ``p_values`` maps method name to a
        scalar p-value. ``bootstrap_ates`` maps method name to the list of
        ``n_bootstrap`` per-replication estimates.
    """
    bootstrap_ates = {name: [] for name in
                      ('TARNet', 'CausalForest', 'S_BART', 'T_BART', 'S_engression', 'T_engression')}

    for b in range(n_bootstrap):
        # Fresh simulation for this replication (not a resample of fixed data).
        # Offsetting by ``seed`` makes each call draw its own datasets. Seeding
        # with ``b`` alone would reproduce the same n_bootstrap datasets on every
        # call, so repeated trials would not be independent.
        train_data, test_data_b = generate_train_test_data(
            n_samples_train, n_samples_test_b, corr_matrix, eta, num_cont_covariates, num_disc_covariates,
            train_cont_margin_params, train_cont_margin_family, train_disc_margin_params, train_disc_margin_family,
            test_cont_margin_params, test_cont_margin_family, test_disc_margin_params, test_disc_margin_family,
            treatment_family, prop_score_params, causal_effect_family, causal_effect_params,
            seed=seed * n_bootstrap + b)
        # NOTE(review): only columns Z1, Z2 are used as covariates, regardless
        # of num_cont_covariates / num_disc_covariates -- TODO generalize.
        sim_Z_train = train_data[['Z1', 'Z2']].values
        sim_X_train = train_data['X'].values
        sim_Y_train = train_data['Y'].values
        sim_Z_test_b = test_data_b[['Z1', 'Z2']].values

        estimates = _fit_and_estimate_ates(sim_Z_train, sim_X_train, sim_Y_train, sim_Z_test_b)
        for name, ate in estimates.items():
            bootstrap_ates[name].append(ate)

    p_values = {name: calculate_p_value(est_ate_list=ates, true_ate_test=true_ate_test)
                for name, ates in bootstrap_ates.items()}
    return p_values, bootstrap_ates



# Experiment 1: no domain shift -- train and test continuous margins use the
# same parameters ([1, 1]). The discrete-margin settings differ, but are
# presumably irrelevant here since num_disc_covariates = 0 -- confirm in
# data_simulation.
n_samples_train = 200
n_samples_test = 50
n_samples_test_b = 50
eta = 0.1
num_cont_covariates = 2
num_disc_covariates = 0
num_covariates = num_cont_covariates + num_disc_covariates
test_cont_margin_params = [1, 1]
test_cont_margin_family = 'gaussian'
test_disc_margin_params = [1, 0.5]
test_disc_margin_family = 'bernoulli'
train_cont_margin_params = [1, 1]
train_cont_margin_family = 'gaussian'
train_disc_margin_params = [1, 0.75]
train_disc_margin_family = 'bernoulli'
treatment_family = 'bernoulli'
# One coefficient per covariate plus one more (presumably an intercept), all zero.
prop_score_params = [0] * (num_covariates + 1)
causal_effect_family = 'gaussian'
causal_effect_params = [1, 2, 1]
random.seed(42)
corr_matrix = np.array([[1.0, 0, 0.1],
                        [0, 1.0, 0.9],
                        [0.1, 0.9, 1.0]])

# Collect one p-value per method per trial. Keys come from whatever ate_testing
# returns, so this stays in sync with the estimators it actually runs. A
# hard-coded key that ate_testing never produces (e.g. 'S_Linear') raises
# KeyError, or leaves an empty column that pd.DataFrame rejects.
p_values_track = {}
for i in tqdm(range(20), desc='Trial loop'):
    # seed=i is forwarded so each trial can draw its own simulated data.
    p_values, bootstrap_ates = ate_testing(
        n_samples_train, n_samples_test, corr_matrix, eta, num_cont_covariates, num_disc_covariates, train_cont_margin_params,
                    train_cont_margin_family, train_disc_margin_params, train_disc_margin_family, test_cont_margin_params,
                    test_cont_margin_family, test_disc_margin_params, test_disc_margin_family, treatment_family,
                    prop_score_params, causal_effect_family, causal_effect_params, n_samples_test_b=n_samples_test_b,
                    n_bootstrap=200, seed=i)

    for key, p_value in p_values.items():
        p_values_track.setdefault(key, []).append(p_value)
print(p_values_track)

# Summarise the no-domain-shift trials: one DataFrame column of p-values per
# method, one row per trial (p_values_track is filled by the trial loop above).
p_values_df = pd.DataFrame(p_values_track)

# Average p-value per method across trials.
average_p_values = p_values_df.mean()
print("Average p-values for each method:")
print(average_p_values)

# Boxplot of the per-trial p-values. The title names the experiment so this
# figure can be told apart from the domain-shift one.
plt.figure(figsize=(10, 6))
sns.boxplot(data=p_values_df)
plt.title('Boxplot of p-values for Different Models (no domain shift)')
plt.xlabel('Model')
plt.ylabel('p-value')
plt.xticks(rotation=45)
plt.tight_layout()  # keep the rotated tick labels inside the figure
plt.show()

## with domain shift
# Experiment 2: the test continuous margins ([3, 2]) differ from the train
# margins ([1, 1]), and smaller sample sizes are used. Each size is assigned
# exactly once so the value seen here is the value used.
n_samples_train = 100
n_samples_test = 10
n_samples_test_b = 10
eta = 0.1
num_cont_covariates = 2
num_disc_covariates = 0
num_covariates = num_cont_covariates + num_disc_covariates
test_cont_margin_params = [3, 2]
test_cont_margin_family = 'gaussian'
test_disc_margin_params = [1, 0.5]
test_disc_margin_family = 'bernoulli'
train_cont_margin_params = [1, 1]
train_cont_margin_family = 'gaussian'
train_disc_margin_params = [1, 0.75]
train_disc_margin_family = 'bernoulli'
treatment_family = 'bernoulli'
# One coefficient per covariate plus one more (presumably an intercept), all zero.
prop_score_params = [0] * (num_covariates + 1)
causal_effect_family = 'gaussian'
causal_effect_params = [1, 2, 1]
random.seed(42)
corr_matrix = np.array([[1.0, 0, 0.1],
                        [0, 1.0, 0.9],
                        [0.1, 0.9, 1.0]])

# Collect one p-value per method per trial. Keys come from whatever ate_testing
# returns, so this stays in sync with the estimators it actually runs. A
# hard-coded key that ate_testing never produces (e.g. 'S_Linear') raises
# KeyError, or leaves an empty column that pd.DataFrame rejects.
p_values_track = {}
for i in tqdm(range(50), desc='Trial loop'):
    # seed=i is forwarded so each trial can draw its own simulated data.
    p_values, bootstrap_ates = ate_testing(
        n_samples_train, n_samples_test, corr_matrix, eta, num_cont_covariates, num_disc_covariates, train_cont_margin_params,
                    train_cont_margin_family, train_disc_margin_params, train_disc_margin_family, test_cont_margin_params,
                    test_cont_margin_family, test_disc_margin_params, test_disc_margin_family, treatment_family,
                    prop_score_params, causal_effect_family, causal_effect_params, n_samples_test_b=n_samples_test_b,
                    n_bootstrap=200, seed=i)

    for key, p_value in p_values.items():
        p_values_track.setdefault(key, []).append(p_value)
print(p_values_track)

# Summarise the domain-shift trials: one DataFrame column of p-values per
# method, one row per trial (p_values_track is filled by the trial loop above).
p_values_df = pd.DataFrame(p_values_track)

# Average p-value per method across trials.
average_p_values = p_values_df.mean()
print("Average p-values for each method:")
print(average_p_values)

# Boxplot of the per-trial p-values. The title names the experiment so this
# figure can be told apart from the no-shift one.
plt.figure(figsize=(10, 6))
sns.boxplot(data=p_values_df)
plt.title('Boxplot of p-values for Different Models (with domain shift)')
plt.xlabel('Model')
plt.ylabel('p-value')
plt.xticks(rotation=45)
plt.tight_layout()  # keep the rotated tick labels inside the figure
plt.show()
