import numpy as np

def path_wise_dataset_1(num_samples=500, seed=0):
    """
    Generate a synthetic dataset with a confounder, a binary treatment and two mediators.

    Each sample is drawn in this order:
        confounder ~ Normal(mean=0.3, std=0.5)
        treatment  ~ Bernoulli(clip(0.8 - confounder, 0, 1))
        mediator1  = 0.5 * treatment + Normal(mean=0, std=0.5)
        mediator2  = 0.7 * treatment + Normal(mean=0, std=0.5)
        outcome    = confounder + treatment - 0.8 * mediator1^2
                     - 0.5 * treatment * mediator1

    Side effect: seeds NumPy's global random generator with ``seed``.

    Args:
        num_samples: number of samples to draw.
        seed: value passed to ``np.random.seed``.

    Returns:
        A tuple ``(features, prediction)`` of NumPy arrays. Each feature row is
        ``[treatment, confounder, mediator1, mediator2]``; ``prediction`` holds
        the matching outcomes.
    """
    np.random.seed(seed)

    rows = []
    targets = []
    # Samples are drawn one at a time so the random stream is consumed in a
    # fixed per-sample order (confounder, treatment, mediator1, mediator2).
    for _ in range(num_samples):
        c = np.random.normal(loc=0.3, scale=0.5)
        propensity = np.clip(0.8 - c, 0, 1)
        t = np.random.binomial(n=1, p=propensity)
        m1 = 0.5 * t + np.random.normal(loc=0.0, scale=0.5)
        m2 = 0.7 * t + np.random.normal(loc=0.0, scale=0.5)

        rows.append([t, c, m1, m2])
        targets.append(c + t - 0.8 * m1 ** 2 - 0.5 * t * m1)

    return np.array(rows), np.array(targets)


# Column positions inside a feature row, which path_wise_dataset_1 builds as
# [treatment, confounder, mediator1, mediator2].
(
    treatment_col_index,
    confounder_col_index,
    mediator1_col_index,
    mediator2_col_index,
) = range(4)

class ModelWrapper:
    """Closed-form model exposing a ``predict`` method over feature rows.

    Rows are expected in the layout ``[treatment, confounder, mediator1, ...]``.
    The prediction is
    ``confounder + treatment - 0.8 * mediator1**2 - 0.5 * treatment * mediator1``.
    """

    def predict(self, data):
        """Evaluate the outcome formula for every row of ``data``.

        Args:
            data: 2-D array-like with one sample per row. Columns 0, 1 and 2
                are read as treatment, confounder and mediator1; any further
                columns (e.g. mediator2) do not enter the formula and are
                ignored.

        Returns:
            1-D NumPy array with one prediction per row; an empty array when
            ``data`` contains no rows.
        """
        rows = np.asarray(data)
        if len(rows) == 0:
            # Nothing to index column-wise (e.g. ``[]`` becomes a 1-D array).
            return np.array([])

        treatment = rows[:, 0]
        confounder = rows[:, 1]
        mediator1 = rows[:, 2]
        # One vectorized expression instead of a Python-level loop per sample.
        return confounder + treatment - 0.8 * mediator1 ** 2 - 0.5 * treatment * mediator1

def calculate_true_cate_but_mediator2(sample):
    """
    Return the closed-form CATE for one sample with mediator2 integrated out.

    The formula assumes the outcome model

        outcome = confounder + treatment - 0.8 * mediator1^2
                  + 0.7 * mediator2^2 - 0.5 * treatment * mediator1

    and evaluates

        E[Y(1) | confounder, mediator1] - E[Y(0) | confounder, mediator1]

    where mediator2^2 is replaced by its expectation given the treatment,
    E[mediator2^2 | t] = Var + mean^2 = 0.5^2 + (0.7 * t)^2, i.e. 0.74 for
    t = 1 and 0.25 for t = 0.

    NOTE(review): the assumed model above contains a ``+ 0.7 * mediator2^2``
    term, whereas the outcome formulas in ``path_wise_dataset_1`` and
    ``ModelWrapper.predict`` in this file do not. Confirm which model this
    ground truth is meant for before relying on it.

    Args:
        sample: indexable row laid out as ``[treatment, confounder,
            mediator1, ...]``. Only the confounder (index 1) and mediator1
            (index 2) are read; the observed treatment is not needed because
            both potential outcomes are evaluated.

    Returns:
        The CATE for the sample. The confounder and mediator1^2 terms cancel
        in the difference, leaving ``1 + 0.7 * 0.49 - 0.5 * mediator1`` up to
        floating-point rounding.
    """
    confounder = sample[1]
    mediator1 = sample[2]

    mediator2_noise_var = 0.5 ** 2  # variance of the mediator2 noise (std 0.5)
    mediator2_slope = 0.7           # mean of mediator2 is 0.7 * treatment

    def expected_outcome(t):
        # Expected outcome under treatment level t, with mediator2^2 replaced
        # by its conditional expectation.
        e_mediator2_squared = mediator2_noise_var + (mediator2_slope * t) ** 2
        return (confounder + t - 0.8 * mediator1 ** 2
                + 0.7 * e_mediator2_squared - 0.5 * t * mediator1)

    return expected_outcome(1) - expected_outcome(0)
