from data_simulation import *
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import sys
sys.path.append("..")
import torch
import random
torch.manual_seed(0)
from engression import engression
from model import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import warnings
from scipy.stats import ks_2samp,cramervonmises_2samp

# Suppress FutureWarning
warnings.simplefilter(action='ignore', category=FutureWarning)

from tqdm import tqdm
from scipy.stats import ttest_1samp
import copy
import pickle



def calculate_p_value(est_ate_list, true_ate_test):
    """Return the p-value of a one-sample t-test on the estimated ATEs.

    The null hypothesis is that the mean of ``est_ate_list`` equals
    ``true_ate_test``; the test is delegated to ``scipy.stats.ttest_1samp``
    and only its p-value is returned (the t statistic is discarded).
    """
    return ttest_1samp(est_ate_list, true_ate_test).pvalue
    
def ate_testing(n_samples_train, n_samples_test, corr_matrix, eta, num_cont_covariates, num_disc_covariates, train_cont_margin_params,
                    train_cont_margin_family, train_disc_margin_params, train_disc_margin_family, test_cont_margin_params,
                    test_cont_margin_family, test_disc_margin_params, test_disc_margin_family, treatment_family,
                    prop_score_params, causal_effect_family, causal_effect_params, n_samples_test_b,
                    n_bootstrap=500,seed = 42):
    """Estimate the ATE on repeated simulated datasets and t-test it against a reference value.

    For each of ``n_bootstrap`` replicates a fresh train/test pair is drawn with
    ``generate_train_test_data``. Six estimators (TARNet, CausalForestDML, S-BART,
    T-BART, S-engression, T-engression) are fitted on the training part, and the
    mean of the first element returned by each ``*_predict`` call on the test
    covariates is stored as that replicate's ATE estimate. After all replicates,
    each model's estimates are compared with the reference ATE via
    ``calculate_p_value``.

    Args:
        n_samples_train: Training-set size passed to the data generator.
        n_samples_test: Accepted for call compatibility; not used by this function.
        corr_matrix, eta, num_cont_covariates, num_disc_covariates,
        train_*/test_* margin arguments, treatment_family, prop_score_params,
        causal_effect_family, causal_effect_params: Forwarded unchanged to
            ``generate_train_test_data``.
        n_samples_test_b: Test-set size of every replicate.
        n_bootstrap: Number of replicates.
        seed: Base seed; replicate ``b`` is generated with ``seed + b``
            (``seed=0`` yields the data seeds ``0 .. n_bootstrap - 1``).

    Returns:
        ``(p_values, bootstrap_ates)``: two dicts keyed by model name.
        ``p_values[name]`` is a single p-value, ``bootstrap_ates[name]`` is the
        list of per-replicate ATE estimates.
    """
    model_names = ('TARNet', 'CausalForest', 'S_BART', 'T_BART', 'S_engression', 'T_engression')
    bootstrap_ates = {name: [] for name in model_names}

    # Reference ATE is hard-coded.
    # NOTE(review): presumably this matches the treatment coefficient in
    # causal_effect_params -- confirm against the data generator.
    true_ate_test = 2

    for b in range(n_bootstrap):
        # Draw an independent train/test pair for this replicate; the base seed
        # is offset by the replicate index so that `seed` is honoured.
        train_data, test_data_b = generate_train_test_data(
                    n_samples_train, n_samples_test_b, corr_matrix, eta, num_cont_covariates, num_disc_covariates, train_cont_margin_params,
                    train_cont_margin_family, train_disc_margin_params, train_disc_margin_family, test_cont_margin_params,
                    test_cont_margin_family, test_disc_margin_params, test_disc_margin_family, treatment_family,
                    prop_score_params, causal_effect_family, causal_effect_params, seed=seed + b)

        # The covariate columns are fixed to 'Z1' and 'Z2' regardless of the
        # num_*_covariates arguments.
        train_kwargs = dict(
            Z_train=train_data[['Z1', 'Z2']].values,
            Y_train=train_data['Y'].values,
            X_train=train_data['X'].values,
        )
        sim_Z_test_b = test_data_b[['Z1', 'Z2']].values

        # Fit every estimator first, then predict, in a fixed order.
        tarnet = TARNet_fit(**train_kwargs)
        causal_forest = CausalForestDML_fit(**train_kwargs)
        s_bart = S_BART_fit(**train_kwargs)
        t_bart_control, t_bart_treated = T_BART_fit(**train_kwargs)
        s_engressor = S_engression_fit(num_epochs=200, **train_kwargs)
        t_engressor_control, t_engressor_treated = T_engression_fit(num_epochs=200, **train_kwargs)

        bootstrap_ates['TARNet'].append(
            np.mean(TARNet_predict(Z_test=sim_Z_test_b, tarnet=tarnet)[0]))
        bootstrap_ates['CausalForest'].append(
            np.mean(CausalForestDML_predict(Z_test=sim_Z_test_b, causal_forest=causal_forest)[0]))
        bootstrap_ates['S_BART'].append(
            np.mean(S_BART_predict(Z_test=sim_Z_test_b, bart_model=s_bart)[0]))
        bootstrap_ates['T_BART'].append(
            np.mean(T_BART_predict(Z_test=sim_Z_test_b, bart_model_control=t_bart_control,
                                   bart_model_treatment=t_bart_treated)[0]))
        bootstrap_ates['S_engression'].append(
            np.mean(S_engression_predict(Z_test=sim_Z_test_b, engressor=s_engressor)[0]))
        bootstrap_ates['T_engression'].append(
            np.mean(T_engression_predict(Z_test=sim_Z_test_b, engressor_control=t_engressor_control,
                                         engressor_treated=t_engressor_treated)[0]))

    # One t-test per model over its replicate estimates.
    p_values = {
        name: calculate_p_value(est_ate_list=estimates, true_ate_test=true_ate_test)
        for name, estimates in bootstrap_ates.items()
    }
    return p_values, bootstrap_ates


# setting 1
# NOTE(review): every parameter below is identical to the "setting 2" block later
# in this file, while the distributional-testing section uses test gamma params
# [2, 1] for the runs it plots as 'Setting 1' -- confirm that [4, 1] is intended here.
n_samples_train, n_samples_test, n_samples_test_b = 200, 50, 50
eta = 0.1
num_cont_covariates, num_disc_covariates = 2, 0
num_covariates = num_cont_covariates + num_disc_covariates

# Margin family/parameter settings handed to the simulator for the train and test sets.
train_cont_margin_family, train_cont_margin_params = 'gamma', [1, 1]
train_disc_margin_family, train_disc_margin_params = 'bernoulli', [1, 0.75]
test_cont_margin_family, test_cont_margin_params = 'gamma', [4, 1]
test_disc_margin_family, test_disc_margin_params = 'bernoulli', [1, 0.5]

treatment_family = 'bernoulli'
prop_score_params = [0] * (num_covariates + 1)
causal_effect_family, causal_effect_params = 'gaussian', [1, 2, 1]

random.seed(1024)
corr_matrix = np.array([[1.0, 0, 0.1],
                        [0, 1.0, 0.9],
                        [0.1, 0.9, 1.0]])

# One list of p-values per model; each trial appends one entry per model.
p_values_track2 = {
    name: []
    for name in ('TARNet', 'CausalForest', 'S_BART', 'T_BART', 'S_engression', 'T_engression')
}
for i in tqdm(range(50), desc='Trial loop'):
    p_values, bootstrap_ates = ate_testing(
        n_samples_train=n_samples_train,
        n_samples_test=n_samples_test,
        corr_matrix=corr_matrix,
        eta=eta,
        num_cont_covariates=num_cont_covariates,
        num_disc_covariates=num_disc_covariates,
        train_cont_margin_params=train_cont_margin_params,
        train_cont_margin_family=train_cont_margin_family,
        train_disc_margin_params=train_disc_margin_params,
        train_disc_margin_family=train_disc_margin_family,
        test_cont_margin_params=test_cont_margin_params,
        test_cont_margin_family=test_cont_margin_family,
        test_disc_margin_params=test_disc_margin_params,
        test_disc_margin_family=test_disc_margin_family,
        treatment_family=treatment_family,
        prop_score_params=prop_score_params,
        causal_effect_family=causal_effect_family,
        causal_effect_params=causal_effect_params,
        n_samples_test_b=n_samples_test_b,
        n_bootstrap=200,
        seed=400,
    )
    for key, tracked in p_values_track2.items():
        tracked.append(p_values[key])
print(p_values_track2)

# Wide table: one column per model, one row per trial.
p_values_df2 = pd.DataFrame(p_values_track2)

# Box plot of the per-trial p-values for each model.
plt.figure(figsize=(10, 6))
ax = sns.boxplot(data=p_values_df2)
ax.set_title('Boxplot of p-values for Different Models')
ax.set_xlabel('Model')
ax.set_ylabel('p-value')
plt.xticks(rotation=45)  # tilt the model names so they do not overlap
plt.show()


# setting 2
# eta and the covariate counts are not redefined here; the values already in
# module scope are reused.
train_cont_margin_family, train_cont_margin_params = 'gamma', [1, 1]
train_disc_margin_family, train_disc_margin_params = 'bernoulli', [1, 0.75]
test_cont_margin_family, test_cont_margin_params = 'gamma', [4, 1]
test_disc_margin_family, test_disc_margin_params = 'bernoulli', [1, 0.5]

treatment_family = 'bernoulli'
prop_score_params = [0] * (num_covariates + 1)
causal_effect_family, causal_effect_params = 'gaussian', [1, 2, 1]

random.seed(1024)
corr_matrix = np.array([[1.0, 0, 0.1],
                        [0, 1.0, 0.9],
                        [0.1, 0.9, 1.0]])

n_samples_train, n_samples_test, n_samples_test_b = 200, 50, 50

# One list of p-values per model; each trial appends one entry per model.
p_values_track = {
    name: []
    for name in ('TARNet', 'CausalForest', 'S_BART', 'T_BART', 'S_engression', 'T_engression')
}
for i in tqdm(range(50), desc='Trial loop'):
    p_values, bootstrap_ates = ate_testing(
        n_samples_train=n_samples_train,
        n_samples_test=n_samples_test,
        corr_matrix=corr_matrix,
        eta=eta,
        num_cont_covariates=num_cont_covariates,
        num_disc_covariates=num_disc_covariates,
        train_cont_margin_params=train_cont_margin_params,
        train_cont_margin_family=train_cont_margin_family,
        train_disc_margin_params=train_disc_margin_params,
        train_disc_margin_family=train_disc_margin_family,
        test_cont_margin_params=test_cont_margin_params,
        test_cont_margin_family=test_cont_margin_family,
        test_disc_margin_params=test_disc_margin_params,
        test_disc_margin_family=test_disc_margin_family,
        treatment_family=treatment_family,
        prop_score_params=prop_score_params,
        causal_effect_family=causal_effect_family,
        causal_effect_params=causal_effect_params,
        n_samples_test_b=n_samples_test_b,
        n_bootstrap=200,
        seed=400,
    )
    for key, tracked in p_values_track.items():
        tracked.append(p_values[key])
print(p_values_track)

# Persist the raw p-values so the plots can be rebuilt without rerunning the trials.
with open('synthetic_setting2_mean.pkl', 'wb') as f:
    pickle.dump(p_values_track, f)

# Wide table: one column per model, one row per trial.
p_values_df = pd.DataFrame(p_values_track)

# Box plot of the per-trial p-values for each model.
plt.figure(figsize=(10, 6))
ax = sns.boxplot(data=p_values_df)
ax.set_title('Boxplot of p-values for Different Models')
ax.set_xlabel('Model')
ax.set_ylabel('p-value')
plt.xticks(rotation=45)  # tilt the model names so they do not overlap
plt.show()


## CATE mean testing
# Build a long-format table with one row per (setting, model, trial) p-value.
# Values and their 'Model' / 'Setting' labels are appended together, so the
# labels stay aligned even if the two settings hold a different number of trials.
model_order = ['TARNet', 'CausalForest', 'S_BART', 'T_BART', 'S_engression', 'T_engression']
data = {'p-value': [], 'Model': [], 'Setting': []}
for setting_label, tracked in (('Setting 1', p_values_track2), ('Setting 2', p_values_track)):
    for model_name in model_order:
        model_p_values = tracked[model_name]
        data['p-value'].extend(model_p_values)
        data['Model'].extend([model_name] * len(model_p_values))
        data['Setting'].extend([setting_label] * len(model_p_values))

df = pd.DataFrame(data)

# Grouped box plot: one group per model, one box per setting.
plt.figure(figsize=(10, 6))
sns.boxplot(x='Model', y='p-value', hue='Setting', data=df, palette='coolwarm')

# Titles and axis labels.
plt.title('p-values of CATE testing', fontsize=24)
plt.xlabel('Model', fontsize=22)
plt.ylabel('p-value', fontsize=20)

# Tick labels: enlarge both axes and tilt the model names.
plt.xticks(fontsize=20, rotation=45)
plt.yticks(fontsize=20)

plt.legend(title='Setting', fontsize=16, title_fontsize=18)

plt.tight_layout()
plt.show()


## Varying n
def ate_testing_varying_n(n_samples_test, corr_matrix, eta, num_cont_covariates, num_disc_covariates, train_cont_margin_params,
                    train_cont_margin_family, train_disc_margin_params, train_disc_margin_family, test_cont_margin_params,
                    test_cont_margin_family, test_disc_margin_params, test_disc_margin_family, treatment_family,
                    prop_score_params, causal_effect_family, causal_effect_params, n_samples_test_b,
                    n_bootstrap=500,seed = 42, n_train_values=(40, 100, 200, 400)):
    """Repeat the ATE t-test for T-BART and T-engression at several training-set sizes.

    For every size in ``n_train_values``, ``n_bootstrap`` train/test pairs are
    drawn with ``generate_train_test_data``; both T-learners are fitted on each
    training set and the mean of the first element returned by their
    ``*_predict`` call on the test covariates is recorded. The recorded
    estimates are then compared with the reference ATE via ``calculate_p_value``.

    Args:
        n_samples_test: Accepted for call compatibility; not used by this function.
        corr_matrix, eta, num_cont_covariates, num_disc_covariates,
        train_*/test_* margin arguments, treatment_family, prop_score_params,
        causal_effect_family, causal_effect_params: Forwarded unchanged to
            ``generate_train_test_data``.
        n_samples_test_b: Test-set size of every replicate.
        n_bootstrap: Number of replicates per training-set size.
        seed: Base seed; replicate ``b`` is generated with ``seed + b`` for
            every training-set size.
        n_train_values: Iterable of training-set sizes to evaluate.

    Returns:
        ``(p_values_tbart, p_values_tengression)``: two dicts mapping each
        training-set size to a single p-value.
    """
    p_values_tbart = {}
    p_values_tengression = {}

    # Reference ATE is hard-coded.
    # NOTE(review): presumably this matches the treatment coefficient in
    # causal_effect_params -- confirm against the data generator.
    true_ate_test = 2

    for n_train in n_train_values:
        bootstrap_ates_tbart = []
        bootstrap_ates_tengression = []

        for b in range(n_bootstrap):
            # Independent train/test pair for this replicate; the base seed is
            # offset by the replicate index so that `seed` is honoured.
            train_data, test_data_b = generate_train_test_data(
                        n_train, n_samples_test_b, corr_matrix, eta, num_cont_covariates, num_disc_covariates, train_cont_margin_params,
                        train_cont_margin_family, train_disc_margin_params, train_disc_margin_family, test_cont_margin_params,
                        test_cont_margin_family, test_disc_margin_params, test_disc_margin_family, treatment_family,
                        prop_score_params, causal_effect_family, causal_effect_params, seed=seed + b)

            # The covariate columns are fixed to 'Z1' and 'Z2'.
            train_kwargs = dict(
                Z_train=train_data[['Z1', 'Z2']].values,
                Y_train=train_data['Y'].values,
                X_train=train_data['X'].values,
            )
            sim_Z_test_b = test_data_b[['Z1', 'Z2']].values

            t_bart_control, t_bart_treated = T_BART_fit(**train_kwargs)
            t_engressor_control, t_engressor_treated = T_engression_fit(num_epochs=200, **train_kwargs)

            # ATE estimate from T-BART.
            bootstrap_ates_tbart.append(
                np.mean(T_BART_predict(Z_test=sim_Z_test_b, bart_model_control=t_bart_control,
                                       bart_model_treatment=t_bart_treated)[0]))

            # ATE estimate from T-engression.
            bootstrap_ates_tengression.append(
                np.mean(T_engression_predict(Z_test=sim_Z_test_b, engressor_control=t_engressor_control,
                                             engressor_treated=t_engressor_treated)[0]))

        # One t-test per model for this training-set size.
        p_values_tbart[n_train] = calculate_p_value(
            est_ate_list=bootstrap_ates_tbart, true_ate_test=true_ate_test)
        p_values_tengression[n_train] = calculate_p_value(
            est_ate_list=bootstrap_ates_tengression, true_ate_test=true_ate_test)

    return p_values_tbart, p_values_tengression

# Repeat the varying-n experiment over 50 trials, collecting one p-value per
# training-set size and model in each trial. The simulation parameters are the
# ones currently in module scope from the preceding section.
n_train_values = [40, 100, 200, 400]
p_values_tbart_track = {size: [] for size in n_train_values}
p_values_tengression_track = {size: [] for size in n_train_values}
n_samples_test, n_samples_test_b = 50, 50

for i in tqdm(range(50), desc='Trial loop'):
    p_values_tbart, p_values_tengression = ate_testing_varying_n(
        n_samples_test=n_samples_test,
        corr_matrix=corr_matrix,
        eta=eta,
        num_cont_covariates=num_cont_covariates,
        num_disc_covariates=num_disc_covariates,
        train_cont_margin_params=train_cont_margin_params,
        train_cont_margin_family=train_cont_margin_family,
        train_disc_margin_params=train_disc_margin_params,
        train_disc_margin_family=train_disc_margin_family,
        test_cont_margin_params=test_cont_margin_params,
        test_cont_margin_family=test_cont_margin_family,
        test_disc_margin_params=test_disc_margin_params,
        test_disc_margin_family=test_disc_margin_family,
        treatment_family=treatment_family,
        prop_score_params=prop_score_params,
        causal_effect_family=causal_effect_family,
        causal_effect_params=causal_effect_params,
        n_samples_test_b=n_samples_test_b,
        n_bootstrap=200,
        seed=42,
        n_train_values=n_train_values,
    )

    for size in n_train_values:
        p_values_tbart_track[size].append(p_values_tbart[size])
        p_values_tengression_track[size].append(p_values_tengression[size])


# Wide tables: one column per training-set size, one row per trial.
p_values_tbart_df = pd.DataFrame(p_values_tbart_track)
p_values_tengression_df = pd.DataFrame(p_values_tengression_track)

# Reshape to long format (trial index, n_train, p_value) and tag each table with its model.
tbart_long = (
    p_values_tbart_df.reset_index()
    .melt(id_vars='index', var_name='n_train', value_name='p_value')
    .assign(Model='T-BART')
)
tengression_long = (
    p_values_tengression_df.reset_index()
    .melt(id_vars='index', var_name='n_train', value_name='p_value')
    .assign(Model='T-Engression')
)

# Stack both models into a single table for plotting.
combined_df = pd.concat([tbart_long, tengression_long])

# Horizontal box plot: one group per training-set size, one box per model.
plt.figure(figsize=(12, 8))
sns.boxplot(x='p_value', y='n_train', hue='Model', data=combined_df, palette='coolwarm', orient='h')

# Titles and axis labels.
plt.title('Boxplot of p-values for Different Training Set Sizes', fontsize=22)
plt.xlabel('p-value', fontsize=20)
plt.ylabel('Training Set Size', fontsize=20)

# Enlarge the tick labels on both axes.
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.legend(fontsize=16, title_fontsize=18)

plt.tight_layout()
plt.show()


## Distributional Testing
# First batch: test continuous margin is gamma with params [4, 1]. The p-values
# collected here (the *_1 tracks) are plotted as 'Setting 2' at the end of the file.
n_samples_train, n_samples_test = 200, 50
eta = 0.1
num_cont_covariates, num_disc_covariates = 2, 0
num_covariates = num_cont_covariates + num_disc_covariates

train_cont_margin_family, train_cont_margin_params = 'gamma', [1, 1]
train_disc_margin_family, train_disc_margin_params = 'bernoulli', [1, 0.75]
test_cont_margin_family, test_cont_margin_params = 'gamma', [4, 1]
test_disc_margin_family, test_disc_margin_params = 'bernoulli', [1, 0.5]

treatment_family = 'bernoulli'
prop_score_params = [0] * (num_covariates + 1)
causal_effect_family, causal_effect_params = 'gaussian', [1, 2, 1]

random.seed(1024)
corr_matrix = np.array([[1.0, 0, 0.1],
                        [0, 1.0, 0.9],
                        [0.1, 0.9, 1.0]])

# Positional arguments shared by every call to the data generator in this batch.
sim_args = (
    n_samples_train, n_samples_test, corr_matrix, eta, num_cont_covariates, num_disc_covariates,
    train_cont_margin_params, train_cont_margin_family, train_disc_margin_params, train_disc_margin_family,
    test_cont_margin_params, test_cont_margin_family, test_disc_margin_params, test_disc_margin_family,
    treatment_family, prop_score_params, causal_effect_family, causal_effect_params,
)

# Initial draw; both frames are replaced by the first loop iteration below.
train_data, test_data = generate_train_test_data(*sim_args, seed=42)

num_sim = 200
p_track_1, p0_track_1, p1_track_1 = [], [], []
feature_cols = ['X', 'Z1', 'Z2']
for s in range(num_sim):
    # Fresh train/test pair for this simulation run.
    train_data, test_data = generate_train_test_data(*sim_args, seed=s)

    # Fit an engression model of Y on (X, Z1, Z2).
    x_train = torch.Tensor(train_data[feature_cols].to_numpy()).unsqueeze(2).to(device)
    y_train = torch.Tensor(train_data['Y'].to_numpy()).unsqueeze(1).to(device)
    engressor = engression(x=x_train, y=y_train, lr=0.01, num_epochs=200, batch_size=100,
                           device=device, verbose=False)

    # y_test is placed on the device but not read afterwards.
    x_test = torch.Tensor(test_data[feature_cols].to_numpy()).unsqueeze(2).to(device)
    y_test = torch.Tensor(test_data['Y'].to_numpy()).unsqueeze(1).to(device)

    # Draw 100 samples per test row and compare the pooled draws with observed Y.
    y_pred = engressor.sample(x_test, sample_size=100)
    pooled_draws = y_pred.view(-1).cpu().numpy()
    p_track_1.append(ks_2samp(test_data['Y'].values, pooled_draws).pvalue)

    # Same comparison restricted to the rows selected for X = 1 and X = 0.
    test_data_po1, y_pred_po1 = slice_data_by_x(test_data, y_pred, x_value=1)
    test_data_po0, y_pred_po0 = slice_data_by_x(test_data, y_pred, x_value=0)
    pooled_draws_po0 = y_pred_po0.view(-1).cpu().numpy()
    pooled_draws_po1 = y_pred_po1.view(-1).cpu().numpy()
    p0_track_1.append(ks_2samp(test_data_po0['Y'].values, pooled_draws_po0).pvalue)
    p1_track_1.append(ks_2samp(test_data_po1['Y'].values, pooled_draws_po1).pvalue)


# Second batch: identical to the first except that the test continuous margin is
# gamma with params [2, 1]. The p-values collected here (the *_2 tracks) are
# plotted as 'Setting 1' at the end of the file. num_sim is reused from above.
n_samples_train, n_samples_test = 200, 50
eta = 0.1
num_cont_covariates, num_disc_covariates = 2, 0
num_covariates = num_cont_covariates + num_disc_covariates

train_cont_margin_family, train_cont_margin_params = 'gamma', [1, 1]
train_disc_margin_family, train_disc_margin_params = 'bernoulli', [1, 0.75]
test_cont_margin_family, test_cont_margin_params = 'gamma', [2, 1]
test_disc_margin_family, test_disc_margin_params = 'bernoulli', [1, 0.5]

treatment_family = 'bernoulli'
prop_score_params = [0] * (num_covariates + 1)
causal_effect_family, causal_effect_params = 'gaussian', [1, 2, 1]

random.seed(1024)
corr_matrix = np.array([[1.0, 0, 0.1],
                        [0, 1.0, 0.9],
                        [0.1, 0.9, 1.0]])

# Positional arguments shared by every call to the data generator in this batch.
sim_args = (
    n_samples_train, n_samples_test, corr_matrix, eta, num_cont_covariates, num_disc_covariates,
    train_cont_margin_params, train_cont_margin_family, train_disc_margin_params, train_disc_margin_family,
    test_cont_margin_params, test_cont_margin_family, test_disc_margin_params, test_disc_margin_family,
    treatment_family, prop_score_params, causal_effect_family, causal_effect_params,
)

p_track_2, p0_track_2, p1_track_2 = [], [], []
feature_cols = ['X', 'Z1', 'Z2']
for s in range(num_sim):
    # Fresh train/test pair for this simulation run.
    train_data, test_data = generate_train_test_data(*sim_args, seed=s)

    # Fit an engression model of Y on (X, Z1, Z2).
    x_train = torch.Tensor(train_data[feature_cols].to_numpy()).unsqueeze(2).to(device)
    y_train = torch.Tensor(train_data['Y'].to_numpy()).unsqueeze(1).to(device)
    engressor = engression(x=x_train, y=y_train, lr=0.01, num_epochs=200, batch_size=100,
                           device=device, verbose=False)

    # y_test is placed on the device but not read afterwards.
    x_test = torch.Tensor(test_data[feature_cols].to_numpy()).unsqueeze(2).to(device)
    y_test = torch.Tensor(test_data['Y'].to_numpy()).unsqueeze(1).to(device)

    # Draw 100 samples per test row and compare the pooled draws with observed Y.
    y_pred = engressor.sample(x_test, sample_size=100)
    pooled_draws = y_pred.view(-1).cpu().numpy()
    p_track_2.append(ks_2samp(test_data['Y'].values, pooled_draws).pvalue)

    # Same comparison restricted to the rows selected for X = 1 and X = 0.
    test_data_po1, y_pred_po1 = slice_data_by_x(test_data, y_pred, x_value=1)
    test_data_po0, y_pred_po0 = slice_data_by_x(test_data, y_pred, x_value=0)
    pooled_draws_po0 = y_pred_po0.view(-1).cpu().numpy()
    pooled_draws_po1 = y_pred_po1.view(-1).cpu().numpy()
    p0_track_2.append(ks_2samp(test_data_po0['Y'].values, pooled_draws_po0).pvalue)
    p1_track_2.append(ks_2samp(test_data_po1['Y'].values, pooled_draws_po1).pvalue)

# Long-format table of the KS-test p-values. The *_2 tracks are labelled
# 'Setting 1' and the *_1 tracks 'Setting 2'; each value is appended together
# with its distribution and setting labels.
distribution_labels = (r'$P_Y$', r'$P_{Y(0)}$', r'$P_{Y(1)}$')
data = {'p-value': [], 'Marginal Distribution': [], 'Setting': []}
for setting_label, tracks in (
        ('Setting 1', (p_track_2, p0_track_2, p1_track_2)),
        ('Setting 2', (p_track_1, p0_track_1, p1_track_1))):
    for dist_label, track in zip(distribution_labels, tracks):
        data['p-value'].extend(track)
        data['Marginal Distribution'].extend([dist_label] * len(track))
        data['Setting'].extend([setting_label] * len(track))

df = pd.DataFrame(data)

# Grouped box plot: one group per tested distribution, one box per setting.
plt.figure(figsize=(10, 6))
sns.boxplot(x='Marginal Distribution', y='p-value', hue='Setting', data=df, palette='coolwarm')

# Titles and axis labels (the x-axis label previously read 'Marignal').
plt.title(r'p-values of Testing $P_Y$, $P_{Y(0)}$, $P_{Y(1)}$', fontsize=24)
plt.xlabel('Marginal Distribution', fontsize=22)
plt.ylabel('p-value', fontsize=20)

plt.xticks(fontsize=20)
plt.yticks(fontsize=20)

plt.legend(title='Setting', fontsize=16, title_fontsize=18)

plt.tight_layout()
plt.show()


