from rwd_inference import *
from data_utils import *
import matplotlib.pyplot as plt
import seaborn as sns
import sys
sys.path.append("..")
import torch
torch.manual_seed(0)
from model import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import warnings
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import copy
from scipy.stats import ks_2samp, ttest_1samp, cramervonmises_2samp
import pickle

# Suppress FutureWarning
warnings.simplefilter(action='ignore', category=FutureWarning)


# Split the 1000 trial indices into two disjoint halves of 500 each (fixed random_state for
# reproducibility) - one used as training domain and the other as test.
# NOTE(review): in the code visible in this file only train_idx is ever indexed; test_idx
# appears solely in commented-out lines - confirm whether that is intended.
train_idx, test_idx = train_test_split(np.arange(1000), train_size=500, test_size=500, random_state=42)

def calculate_p_value(est_ate_list, true_ate_test):
    """Return the p-value of a one-sample t-test on a list of ATE estimates.

    The null hypothesis is that the mean of ``est_ate_list`` equals
    ``true_ate_test``. Only the p-value is returned; the t statistic is
    discarded.
    """
    test_result = ttest_1samp(est_ate_list, true_ate_test)
    return test_result.pvalue

## RCT 
def ate_testing(train_df, test_df, n_samples_train, n_samples_test_b, n_samples_test=10000,
                n_bootstrap=500, marginal_cdf_seed=2, training_data_seed=1, test_data_seed=2):
    """Refit every ATE estimator on repeated simulated draws and t-test the estimates.

    A reference draw (generated with the seeds passed in) defines
    ``true_ate_test`` as the difference between the mean test outcome where the
    treatment indicator equals 1 and where it equals 0. Then, for each ``b`` in
    ``range(n_bootstrap)``, a new dataset is generated with ``b`` as both the
    training and the test seed, all six estimators are refit, and the mean of
    the first element returned by each ``*_predict`` call on the new test
    covariates is recorded (presumably per-unit effect estimates - TODO
    confirm against the predict helpers).

    Args:
        train_df, test_df: Forwarded to ``generate_generate_realistic_data``.
        n_samples_train: Training sample size for the reference draw and for
            every iteration.
        n_samples_test_b: Test sample size used inside each iteration.
        n_samples_test: Test sample size of the reference draw.
        n_bootstrap: Number of generate / refit / predict iterations.
        marginal_cdf_seed: Forwarded unchanged to every data-generation call.
        training_data_seed, test_data_seed: Used for the reference draw only;
            inside the loop both seeds are replaced by the iteration index.

    Returns:
        dict mapping each estimator name to the p-value returned by
        ``calculate_p_value`` for its recorded estimates against
        ``true_ate_test``.
    """
    estimator_names = ('TARNet', 'CausalForest', 'S_BART', 'T_BART', 'S_engression', 'T_engression')
    bootstrap_ates = {name: [] for name in estimator_names}

    # Reference draw: only the test outcomes and treatment indicators are needed.
    _, _, _, _, ref_Y_test, ref_X_test = generate_generate_realistic_data(
        train_df=train_df,
        test_df=test_df,
        n_samples_test=n_samples_test,
        n_samples_train=n_samples_train,
        marginal_cdf_seed=marginal_cdf_seed,
        training_data_seed=training_data_seed,
        test_data_seed=test_data_seed,
    )
    treated_mean = np.mean(ref_Y_test[ref_X_test == 1])
    control_mean = np.mean(ref_Y_test[ref_X_test == 0])
    true_ate_test = treated_mean - control_mean

    for b in range(n_bootstrap):
        # New train/test draw for this iteration, seeded by the loop index.
        Z_fit, Y_fit, X_fit, Z_eval, _, _ = generate_generate_realistic_data(
            train_df=train_df,
            test_df=test_df,
            n_samples_test=n_samples_test_b,
            n_samples_train=n_samples_train,
            marginal_cdf_seed=marginal_cdf_seed,
            training_data_seed=b,
            test_data_seed=b,
        )
        fit_kwargs = dict(Z_train=Z_fit, Y_train=Y_fit, X_train=X_fit)

        # All fits run before any prediction, in the same order as before, so
        # that any random state shared between the libraries is consumed
        # identically.
        tarnet = TARNet_fit(**fit_kwargs)
        causal_forest = CausalForestDML_fit(**fit_kwargs)
        s_bart = S_BART_fit(**fit_kwargs)
        t_bart_control, t_bart_treated = T_BART_fit(**fit_kwargs)
        s_engressor = S_engression_fit(batch_size=200, **fit_kwargs)
        t_engressor_control, t_engressor_treated = T_engression_fit(batch_size=200, **fit_kwargs)

        predictions = (
            ('TARNet', TARNet_predict(Z_test=Z_eval, tarnet=tarnet)),
            ('CausalForest', CausalForestDML_predict(Z_test=Z_eval, causal_forest=causal_forest)),
            ('S_BART', S_BART_predict(Z_test=Z_eval, bart_model=s_bart)),
            ('T_BART', T_BART_predict(Z_test=Z_eval, bart_model_control=t_bart_control,
                                      bart_model_treatment=t_bart_treated)),
            ('S_engression', S_engression_predict(Z_test=Z_eval, engressor=s_engressor)),
            ('T_engression', T_engression_predict(Z_test=Z_eval, engressor_control=t_engressor_control,
                                                  engressor_treated=t_engressor_treated)),
        )
        for name, prediction in predictions:
            bootstrap_ates[name].append(np.mean(prediction[0]))

    return {
        name: calculate_p_value(est_ate_list=estimates, true_ate_test=true_ate_test)
        for name, estimates in bootstrap_ates.items()
    }

# RCT experiment over the first 50 entries of train_idx; tqdm tracks this outer loop.
p_values_track = {
    name: []
    for name in ('TARNet', 'CausalForest', 'S_BART', 'T_BART', 'S_engression', 'T_engression')
}

for i in tqdm(range(50), desc='Trial loop'):
    # Both frames come from the same trial; the test copy's Z1 column is then scaled by 1.5.
    # NOTE(review): the test frame is loaded with train_idx[i], not test_idx[i] - confirm
    # this is intended.
    train_df = process_data(path='', trial=train_idx[i])
    test_df = process_data(path='', trial=train_idx[i])
    test_df['Z1'] *= 1.5

    p_values = ate_testing(
        train_df=train_df,
        test_df=test_df,
        n_samples_train=1000,
        n_samples_test_b=200,
        n_samples_test=200,
        n_bootstrap=200,
        marginal_cdf_seed=11,
        training_data_seed=11,
        test_data_seed=11,
    )

    # Record this trial's p-value for every estimator.
    for key in p_values_track:
        p_values_track[key].append(p_values[key])

# Persist the per-trial p-values of every estimator.
with open('ihdp_mean_p.pkl', 'wb') as f:
    pickle.dump(p_values_track, f)

## Non-RCT
def ate_testing_obs(train_df, test_df, n_samples_train, n_samples_test_b, n_samples_test=10000,
                    n_bootstrap=500, marginal_cdf_seed=2, training_data_seed=1, test_data_seed=2,
                    prop_score_params=None):
    """Observational-data variant: refit every ATE estimator and t-test the estimates.

    A reference draw (generated with the seeds passed in) defines
    ``true_ate_test`` as the difference between the mean test outcome where the
    treatment indicator equals 1 and where it equals 0. Then, for each ``b`` in
    ``range(n_bootstrap)``, a new dataset is generated with the caller's
    ``training_data_seed`` and with ``b`` as the test seed, all six estimators
    are refit, and the mean of the first element returned by each
    ``*_predict`` call on the new test covariates is recorded.

    NOTE(review): ``ate_testing`` uses the iteration index as the training
    seed, whereas here the training seed is held fixed across iterations. If
    the generator is deterministic given its seeds, every iteration refits on
    the same training sample - confirm this difference is intended.
    NOTE(review): ``true_ate_test`` is a plain difference of observed group
    means; if ``prop_score_params`` makes treatment depend on the covariates
    this is not in general the ATE - confirm.

    Args:
        train_df, test_df: Forwarded to ``generate_generate_realistic_data``.
        n_samples_train: Training sample size for the reference draw and for
            every iteration.
        n_samples_test_b: Test sample size used inside each iteration.
        n_samples_test: Test sample size of the reference draw.
        n_bootstrap: Number of generate / refit / predict iterations.
        marginal_cdf_seed: Forwarded unchanged to every data-generation call.
        training_data_seed: Used for the reference draw and for every iteration.
        test_data_seed: Used for the reference draw only; inside the loop the
            test seed is the iteration index.
        prop_score_params: Forwarded unchanged to every data-generation call.

    Returns:
        dict mapping each estimator name to the p-value returned by
        ``calculate_p_value`` for its recorded estimates against
        ``true_ate_test``.
    """
    estimator_names = ('TARNet', 'CausalForest', 'S_BART', 'T_BART', 'S_engression', 'T_engression')
    bootstrap_ates = {name: [] for name in estimator_names}

    # Reference draw: only the test outcomes and treatment indicators are needed.
    _, _, _, _, ref_Y_test, ref_X_test = generate_generate_realistic_data(
        train_df=train_df,
        test_df=test_df,
        n_samples_test=n_samples_test,
        n_samples_train=n_samples_train,
        marginal_cdf_seed=marginal_cdf_seed,
        training_data_seed=training_data_seed,
        test_data_seed=test_data_seed,
        prop_score_params=prop_score_params,
    )
    treated_mean = np.mean(ref_Y_test[ref_X_test == 1])
    control_mean = np.mean(ref_Y_test[ref_X_test == 0])
    true_ate_test = treated_mean - control_mean

    for b in range(n_bootstrap):
        # Training seed stays as supplied by the caller; only the test seed varies with b.
        Z_fit, Y_fit, X_fit, Z_eval, _, _ = generate_generate_realistic_data(
            train_df=train_df,
            test_df=test_df,
            n_samples_test=n_samples_test_b,
            n_samples_train=n_samples_train,
            marginal_cdf_seed=marginal_cdf_seed,
            training_data_seed=training_data_seed,
            test_data_seed=b,
            prop_score_params=prop_score_params,
        )
        fit_kwargs = dict(Z_train=Z_fit, Y_train=Y_fit, X_train=X_fit)

        # All fits run before any prediction, in the same order as before, so
        # that any random state shared between the libraries is consumed
        # identically.
        tarnet = TARNet_fit(**fit_kwargs)
        causal_forest = CausalForestDML_fit(**fit_kwargs)
        s_bart = S_BART_fit(**fit_kwargs)
        t_bart_control, t_bart_treated = T_BART_fit(**fit_kwargs)
        s_engressor = S_engression_fit(batch_size=200, **fit_kwargs)
        t_engressor_control, t_engressor_treated = T_engression_fit(batch_size=200, **fit_kwargs)

        predictions = (
            ('TARNet', TARNet_predict(Z_test=Z_eval, tarnet=tarnet)),
            ('CausalForest', CausalForestDML_predict(Z_test=Z_eval, causal_forest=causal_forest)),
            ('S_BART', S_BART_predict(Z_test=Z_eval, bart_model=s_bart)),
            ('T_BART', T_BART_predict(Z_test=Z_eval, bart_model_control=t_bart_control,
                                      bart_model_treatment=t_bart_treated)),
            ('S_engression', S_engression_predict(Z_test=Z_eval, engressor=s_engressor)),
            ('T_engression', T_engression_predict(Z_test=Z_eval, engressor_control=t_engressor_control,
                                                  engressor_treated=t_engressor_treated)),
        )
        for name, prediction in predictions:
            bootstrap_ates[name].append(np.mean(prediction[0]))

    return {
        name: calculate_p_value(est_ate_list=estimates, true_ate_test=true_ate_test)
        for name, estimates in bootstrap_ates.items()
    }

# Observational experiment over the first 50 entries of train_idx; tqdm tracks this outer loop.
p_values_track_obs = {
    name: []
    for name in ('TARNet', 'CausalForest', 'S_BART', 'T_BART', 'S_engression', 'T_engression')
}

# 26 propensity-score parameters, all zero except positions 2, 3 and 4, which are set to 1.
prop_score_params = [1 if position in (2, 3, 4) else 0 for position in range(26)]

for i in tqdm(range(50), desc='Trial loop'):
    # Both frames come from the same trial; the test copy's Z1 column is then scaled by 1.5.
    # NOTE(review): the test frame is loaded with train_idx[i], not test_idx[i] - confirm
    # this is intended.
    train_df = process_data(path='', trial=train_idx[i])
    test_df = process_data(path='', trial=train_idx[i])
    test_df['Z1'] *= 1.5

    p_values = ate_testing_obs(
        train_df=train_df,
        test_df=test_df,
        n_samples_train=1000,
        n_samples_test_b=200,
        n_samples_test=200,
        n_bootstrap=200,
        marginal_cdf_seed=11,
        training_data_seed=11,
        test_data_seed=11,
        prop_score_params=prop_score_params,
    )

    # Record this trial's p-value for every estimator.
    for key in p_values_track_obs:
        p_values_track_obs[key].append(p_values[key])

# Persist the per-trial p-values of every estimator.
with open('ihdp_mean_p_obs.pkl', 'wb') as f:
    pickle.dump(p_values_track_obs, f)


## Distributional Regression test

def slice_data_by_x(sim_Y_test, sim_X_test, x_value):
    """Return the entries of ``sim_Y_test`` whose treatment indicator equals ``x_value``.

    ``sim_X_test`` is converted to a tensor when it is not one already. The
    selection uses the first-dimension indices at which ``sim_X_test`` equals
    ``x_value``; those indices are applied to ``sim_Y_test`` along its first
    dimension.
    """
    treatment = sim_X_test if torch.is_tensor(sim_X_test) else torch.tensor(sim_X_test)

    # First-dimension positions where the indicator matches the requested value.
    matching_rows = torch.nonzero(treatment == x_value, as_tuple=True)[0]

    return sim_Y_test[matching_rows]


# KS-test p-values collected once per simulation. Suffix _s is S_engression, _t is
# T_engression; p0 / p1 compare against test outcomes with X == 0 / X == 1, and p
# compares against all test outcomes.
p0_track_s = []
p1_track_s = []
p_track_s = []

p0_track_t = []
p1_track_t = []
p_track_t = []

n_samples_train = 500
n_samples_test = 100
n_samples_test_b = 50

num_sim = 50

# Every tensor built below is placed on the CPU (this rebinds the module-level `device`).
device = 'cpu'

for s in range(num_sim):
    # Load the same trial twice and scale the test copy's Z1 column by 1.5.
    # NOTE(review): the test frame is loaded with train_idx[s], not test_idx[s] - confirm
    # this is intended.
    train_df, test_df = process_data(path='', trial=train_idx[s]), process_data(path='', trial=train_idx[s])
    test_df['Z1'] *= 1.5

    # Generate train and test data
    sim_Z_train, sim_Y_train, sim_X_train, sim_Z_test, sim_Y_test, sim_X_test = generate_generate_realistic_data(
        train_df=train_df, test_df=test_df, n_samples_test=n_samples_test, n_samples_train=n_samples_train,
        marginal_cdf_seed=42, training_data_seed=42, test_data_seed=42)

    s_engressor = S_engression_fit(Z_train=sim_Z_train, Y_train=sim_Y_train, X_train=sim_X_train)
    t_engressor_control, t_engressor_treated = T_engression_fit(Z_train=sim_Z_train, Y_train=sim_Y_train, X_train=sim_X_train)

    Z_test_tensor = torch.Tensor(sim_Z_test).to(device)
    X_test_tensor = torch.Tensor(sim_X_test).unsqueeze(1).to(device)
    # Covariates with the observed treatment indicator appended as the last column.
    W_test_tensor = torch.cat([Z_test_tensor, X_test_tensor], dim=1).unsqueeze(2)

    # Create two test datasets: one with X=1 and one with X=0
    W_test_treated = torch.cat([Z_test_tensor, torch.ones(Z_test_tensor.size(0), 1).to(device)], dim=1).unsqueeze(2)
    W_test_control = torch.cat([Z_test_tensor, torch.zeros(Z_test_tensor.size(0), 1).to(device)], dim=1).unsqueeze(2)

    # S_engression: Predictions for treated (X=1) and control (X=0)
    y_pred_treated_s = s_engressor.sample(W_test_treated, sample_size=50).view(-1).cpu().numpy()
    y_pred_control_s = s_engressor.sample(W_test_control, sample_size=50).view(-1).cpu().numpy()

    # T_engression: each arm is sampled from its own fitted model. The control samples
    # previously came from the treated model as well, which made the Y(0) comparison
    # (and the pooled sample below) wrong.
    y_pred_treated_t = t_engressor_treated.sample(Z_test_tensor, sample_size=50).view(-1).cpu().numpy()
    y_pred_control_t = t_engressor_control.sample(Z_test_tensor, sample_size=50).view(-1).cpu().numpy()

    # Actual test data Y split by X
    y_true_po1 = slice_data_by_x(torch.Tensor(sim_Y_test).unsqueeze(1), sim_X_test, x_value=1)
    y_true_po0 = slice_data_by_x(torch.Tensor(sim_Y_test).unsqueeze(1), sim_X_test, x_value=0)

    # Samples compared against all test outcomes: S_engression is sampled at the observed
    # treatment, T_engression pools the treated and control samples of every test unit.
    # NOTE(review): the pooled T_engression sample weights both arms equally regardless of
    # the observed assignment - confirm this matches the intended P_Y comparison.
    y_pred_s = s_engressor.sample(W_test_tensor, sample_size=50).view(-1).cpu().numpy()
    y_pred_t = np.concatenate([y_pred_treated_t, y_pred_control_t], axis=0)

    # Flatten the true y values for comparison
    y_true_po1 = y_true_po1.view(-1).cpu().numpy()
    y_true_po0 = y_true_po0.view(-1).cpu().numpy()
    y_true = sim_Y_test

    # Two-sample KS tests between observed outcomes and model samples, for X=1 (treated),
    # X=0 (control) and all test outcomes.
    _, p_value_po1_s = ks_2samp(y_true_po1, y_pred_treated_s)
    _, p_value_po0_s = ks_2samp(y_true_po0, y_pred_control_s)
    _, p_value_po_s = ks_2samp(y_true, y_pred_s)

    _, p_value_po1_t = ks_2samp(y_true_po1, y_pred_treated_t)
    _, p_value_po0_t = ks_2samp(y_true_po0, y_pred_control_t)
    _, p_value_po_t = ks_2samp(y_true, y_pred_t)

    # Append p-values to the respective lists
    p0_track_s.append(p_value_po0_s)
    p1_track_s.append(p_value_po1_s)
    p_track_s.append(p_value_po_s)

    p0_track_t.append(p_value_po0_t)
    p1_track_t.append(p_value_po1_t)
    p_track_t.append(p_value_po_t)



# Combine the data into a long-format DataFrame for easier plotting: one row per p-value,
# labelled with the distribution it tests and the model that produced the samples.
data = {
    'p-value': p_track_s + p0_track_s + p1_track_s + p_track_t + p0_track_t + p1_track_t,
    'Marginal Distribution': (
        [r'$P_Y$'] * len(p_track_s)
        + [r'$P_{Y(0)}$'] * len(p0_track_s)
        + [r'$P_{Y(1)}$'] * len(p1_track_s)
        + [r'$P_Y$'] * len(p_track_t)
        + [r'$P_{Y(0)}$'] * len(p0_track_t)
        + [r'$P_{Y(1)}$'] * len(p1_track_t)
    ),
    'Model': (
        ['S_engression'] * (len(p_track_s) + len(p0_track_s) + len(p1_track_s))
        + ['T_engression'] * (len(p_track_t) + len(p0_track_t) + len(p1_track_t))
    ),
}

df = pd.DataFrame(data)

# Create the box plot using seaborn
plt.figure(figsize=(10, 6))
sns.boxplot(x='Marginal Distribution', y='p-value', hue='Model', data=df, palette='coolwarm')

# Set labels and title
plt.title(r'p values of testing $P_Y$, $P_{Y(0)}$, $P_{Y(1)}$', fontsize=24)
plt.xlabel('Marginal Distribution', fontsize=22)
plt.ylabel('p-value', fontsize=22)

# Adjust the tick label size for both x and y axes
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
# plt.legend() builds a new legend, so the title has to be passed explicitly; without it
# title_fontsize has nothing to apply to and the hue title is lost.
plt.legend(title='Model', fontsize=16, title_fontsize=18)

# Display the plot
plt.tight_layout()
plt.show()



