
import numpy as np
from scipy.special import expit as sigmoid  # Sigmoid function

def gen_observational_trial(n=1000, d=10, tau=4.0, gamma=0.5, c=2.0, sigma=0.2, sg_size=0.3, rule_size=2, seed=42, mean_shift=False):
    """
    Generate synthetic observational data with confounded treatment assignment
    and a contrastive subgroup.

    Covariates are drawn i.i.d. from Uniform(0, 1). The subgroup ``s_star`` is
    an axis-aligned box over ``rule_size`` randomly chosen features. On each of
    those features, the box is the open interval between two empirical
    quantiles that are ``sg_size ** (1 / rule_size)`` apart. The box therefore
    holds roughly ``sg_size`` of the samples.

    Treatment is confounded: ``A ~ Bernoulli(sigmoid(X @ beta_A))``.
    The outcome mean is::

        mu = X @ beta_Y + (2A - 1) * (tau / 2 if in subgroup else gamma / 2)

    and ``Y ~ N(mu, sigma**2)``.

    The global NumPy RNG is re-seeded with ``seed`` on every call. This makes
    the output reproducible per seed, and it also modifies global RNG state.

    Parameters:
        n (int): Number of samples
        d (int): Number of features
        tau (float): Treated-minus-control effect inside the subgroup
        gamma (float): Treated-minus-control effect outside the subgroup
        c (float): Strength of a subgroup-dependent covariate term. Currently
            unused: that term is disabled, so ``c`` does not affect the output.
        sigma (float): Standard deviation of the noise in Y
        sg_size (float): Target fraction of samples in the subgroup, in (0, 1]
        rule_size (int): Number of features in the box rule, in [1, d]
        seed (int): Random seed for reproducibility
        mean_shift (bool): Currently unused; accepted for API compatibility.

    Returns:
        dict with keys:
            "X": (n, d) covariates.
            "s_star": (n,) float array of 0.0/1.0 subgroup membership.
            "A": (n,) int array of 0/1 treatment.
            "Y": (n,) noisy outcome.
            "mu": (n,) noiseless outcome mean.
            "rule": dict mapping feature index -> (lower, upper), the
                open-interval bounds of the box in X's scale.

    Raises:
        ValueError: If ``rule_size`` is not in [1, d] or ``sg_size`` is not
            in (0, 1].
    """
    if not 1 <= rule_size <= d:
        raise ValueError(f"rule_size must be in [1, d={d}], got {rule_size}")
    if not 0 < sg_size <= 1:
        raise ValueError(f"sg_size must be in (0, 1], got {sg_size}")

    np.random.seed(seed)

    # 1. Features: X ~ Uniform(0, 1)^d, i.i.d. per entry.
    X = np.random.uniform(0, 1, size=(n, d))

    # 2. Contrastive subgroup: a box over `rule_size` random features. Each side
    # spans `interval_size` of that feature's empirical quantile range, so the
    # box covers about interval_size ** rule_size == sg_size of the samples.
    rule_feats = np.random.choice(d, size=rule_size, replace=False)
    interval_size = sg_size ** (1 / rule_size)

    rule_descriptor = {}
    in_box = np.ones(n, dtype=bool)
    for feat in rule_feats:
        q_low = np.random.uniform(0, 1 - interval_size)
        # Clamp: float rounding of q_low + interval_size may land just above 1,
        # which np.quantile rejects.
        q_high = min(q_low + interval_size, 1.0)
        low = np.quantile(X[:, feat], q_low)
        high = np.quantile(X[:, feat], q_high)
        in_box &= (X[:, feat] > low) & (X[:, feat] < high)
        rule_descriptor[int(feat)] = (float(low), float(high))
    s_star = in_box.astype(float)  # float so it can weight effects directly

    # 3. Confounded treatment: propensity depends on X through beta_A.
    beta_A = np.random.randn(d)
    A = np.random.binomial(1, sigmoid(X @ beta_A))

    # 4. Outcome: linear baseline plus +/- half the effect for treated/control.
    # The effect is tau inside the subgroup and gamma outside it.
    beta_Y = np.random.uniform(-1, 1, size=d)
    # NOTE: this draw picks the feature for the disabled term
    # c * s_star * X[:, feat]. Its result is unused, but the call is kept so the
    # RNG stream (and hence Y for a given seed) stays unchanged.
    np.random.choice(d, size=1, replace=False)

    sign = 2 * A - 1  # +1 for treated, -1 for control
    base_effect = X @ beta_Y
    subgroup_treatment = tau / 2 * sign * s_star
    non_subgroup_treatment = gamma / 2 * sign * (1 - s_star)
    mu = base_effect + subgroup_treatment + non_subgroup_treatment

    Y = np.random.normal(loc=mu, scale=sigma)
    return {
        "X": X,
        "s_star": s_star,
        "A": A,
        "Y": Y,
        "mu": mu,
        "rule": rule_descriptor
    }

def gen_interventional_trial(n=1000, d=10, tau=4.0, gamma=0.5, c=2.0, sigma=0.2, sg_size=0.2, rule_size=2, seed=42, mean_shift=False):
    """
    Generate synthetic randomized-trial data with a contrastive subgroup.

    Treatment is randomized, ``A ~ Bernoulli(0.5)``. Covariates are drawn
    i.i.d. from Uniform(0, 1), independently of A. The subgroup ``s_star`` is
    an axis-aligned box over ``rule_size`` randomly chosen features. Each side
    is an open interval of width ``sg_size ** (1 / rule_size)`` in absolute
    [0, 1] coordinates, so the box has volume ``sg_size``.

    The outcome mean is::

        mu = X @ beta_Y + (2A - 1) * (tau / 2 if in subgroup else gamma / 2)

    and ``Y ~ N(mu, sigma**2)``.

    The global NumPy RNG is re-seeded with ``seed`` on every call. This makes
    the output reproducible per seed, and it also modifies global RNG state.

    Parameters:
        n (int): Number of samples
        d (int): Number of features
        tau (float): Treated-minus-control effect inside the subgroup
        gamma (float): Treated-minus-control effect outside the subgroup
        c (float): Strength of a subgroup-dependent covariate term. Currently
            unused: that term is disabled, so ``c`` does not affect the output.
        sigma (float): Standard deviation of the noise in Y
        sg_size (float): Volume (expected sample fraction) of the subgroup box,
            in (0, 1]
        rule_size (int): Number of features in the rule, in [1, d]
        seed (int): Random seed for reproducibility
        mean_shift (bool): Currently unused; accepted for API compatibility.

    Returns:
        dict with keys:
            "X": (n, d) covariates.
            "s_star": (n,) float array of 0.0/1.0 subgroup membership.
            "A": (n,) int array of 0/1 treatment.
            "Y": (n,) noisy outcome.
            "mu": (n,) noiseless outcome mean.
            "rule": dict mapping feature index -> (lower, upper), the
                open-interval bounds of the box.

    Raises:
        ValueError: If ``rule_size`` is not in [1, d] or ``sg_size`` is not
            in (0, 1].
    """
    if not 1 <= rule_size <= d:
        raise ValueError(f"rule_size must be in [1, d={d}], got {rule_size}")
    if not 0 < sg_size <= 1:
        raise ValueError(f"sg_size must be in (0, 1], got {sg_size}")

    np.random.seed(seed)

    # 1. Randomized treatment: A ~ Bernoulli(0.5).
    A = np.random.binomial(1, 0.5, size=n)

    # 2. Features: X ~ Uniform(0, 1)^d, drawn independently of A.
    X = np.random.uniform(0, 1, (n, d))

    # 3. Contrastive subgroup: an axis-aligned box with absolute bounds. Because
    # X is Uniform(0, 1), each side keeps about interval_size of the samples.
    rule_feats = np.random.choice(d, size=rule_size, replace=False)
    interval_size = sg_size ** (1 / rule_size)
    lower_bounds = np.random.uniform(0, 1 - interval_size, size=rule_size)
    upper_bounds = lower_bounds + interval_size

    rule_descriptor = {}
    in_box = np.ones(n, dtype=bool)
    for feat, low, high in zip(rule_feats, lower_bounds, upper_bounds):
        in_box &= (X[:, feat] > low) & (X[:, feat] < high)
        rule_descriptor[int(feat)] = (float(low), float(high))
    s_star = in_box.astype(float)  # float so it can weight effects directly

    # 4. Outcome: linear baseline plus +/- half the effect for treated/control.
    # The effect is tau inside the subgroup and gamma outside it.
    beta_Y = np.random.uniform(-1, 1, size=d)
    # NOTE: this draw picks the feature for the disabled term
    # c * s_star * X[:, feat]. Its result is unused, but the call is kept so the
    # RNG stream (and hence Y for a given seed) stays unchanged.
    np.random.choice(d, size=1, replace=False)

    sign = 2 * A - 1  # +1 for treated, -1 for control
    base_effect = X @ beta_Y
    subgroup_treatment = tau / 2 * sign * s_star
    non_subgroup_treatment = gamma / 2 * sign * (1 - s_star)
    mu = base_effect + subgroup_treatment + non_subgroup_treatment

    Y = np.random.normal(loc=mu, scale=sigma)
    return {
        "X": X,
        "s_star": s_star,
        "A": A,
        "Y": Y,
        "mu": mu,
        "rule": rule_descriptor
    }

def gen_demographic_data(n=1000, d=10, tau=4.0, gamma=0.5, c=2.0, sigma=0.2,
                             sg_size=0.2, rule_size=2, seed=42, mean_shift=False):
    """
    Generate synthetic data where treatment A shifts the covariate distribution.

    Treatment is ``A ~ Bernoulli(0.5)``. Covariates are
    ``X = U + A * shift``, where ``U ~ Uniform(0, 1)^d`` i.i.d. and ``shift`` is
    a per-feature offset drawn once from Uniform(-0.3, 0.3). Treated rows are
    therefore translated relative to control rows.

    The subgroup ``s_star`` is an axis-aligned box over ``rule_size`` randomly
    chosen features. On each of those features, the box is the open interval
    between two empirical quantiles of the pooled (shifted) X that are
    ``sg_size ** (1 / rule_size)`` apart.

    The outcome mean is::

        mu = X @ beta_Y + (2A - 1) * (tau / 2 if in subgroup else gamma / 2)

    and ``Y ~ N(mu, sigma**2)``.

    The global NumPy RNG is re-seeded with ``seed`` on every call. This makes
    the output reproducible per seed, and it also modifies global RNG state.

    Parameters:
        n (int): Number of samples
        d (int): Number of features
        tau (float): Treated-minus-control effect inside the subgroup
        gamma (float): Treated-minus-control effect outside the subgroup
        c (float): Strength of a subgroup-dependent covariate term. Currently
            unused: that term is disabled, so ``c`` does not affect the output.
        sigma (float): Standard deviation of the noise in Y
        sg_size (float): Target fraction of samples in the subgroup, in (0, 1]
        rule_size (int): Number of features in the box rule, in [1, d]
        seed (int): Random seed for reproducibility
        mean_shift (bool): Currently unused; accepted for API compatibility.

    Returns:
        dict with keys:
            "X": (n, d) covariates, including the treatment shift.
            "s_star": (n,) float array of 0.0/1.0 subgroup membership.
            "A": (n,) int array of 0/1 treatment.
            "Y": (n,) noisy outcome.
            "mu": (n,) noiseless outcome mean.
            "rule": dict mapping feature index -> (lower, upper), the
                open-interval bounds of the box in X's scale.

    Raises:
        ValueError: If ``rule_size`` is not in [1, d] or ``sg_size`` is not
            in (0, 1].
    """
    if not 1 <= rule_size <= d:
        raise ValueError(f"rule_size must be in [1, d={d}], got {rule_size}")
    if not 0 < sg_size <= 1:
        raise ValueError(f"sg_size must be in (0, 1], got {sg_size}")

    np.random.seed(seed)

    # 1. Treatment: A ~ Bernoulli(0.5).
    A = np.random.binomial(1, 0.5, size=n)

    # 2. Features: Uniform(0, 1)^d, translated by a per-feature shift for
    # treated rows only (A causes X).
    feature_shift = np.random.uniform(-0.3, 0.3, size=d)
    X = np.random.uniform(0, 1, (n, d)) + A[:, None] * feature_shift

    # 3. Contrastive subgroup: a box over `rule_size` random features. Its
    # bounds are quantiles of the pooled X, so roughly interval_size of the
    # samples fall within each side.
    rule_feats = np.random.choice(d, size=rule_size, replace=False)
    interval_size = sg_size ** (1 / rule_size)

    rule_descriptor = {}
    in_box = np.ones(n, dtype=bool)
    for feat in rule_feats:
        q_low = np.random.uniform(0, 1 - interval_size)
        # Clamp: float rounding of q_low + interval_size may land just above 1,
        # which np.quantile rejects.
        q_high = min(q_low + interval_size, 1.0)
        low = np.quantile(X[:, feat], q_low)
        high = np.quantile(X[:, feat], q_high)
        in_box &= (X[:, feat] > low) & (X[:, feat] < high)
        rule_descriptor[int(feat)] = (float(low), float(high))
    s_star = in_box.astype(float)  # float so it can weight effects directly

    # 4. Outcome: linear baseline plus +/- half the effect for treated/control.
    # The effect is tau inside the subgroup and gamma outside it.
    beta_Y = np.random.uniform(-1, 1, size=d)
    # NOTE: this draw picks the feature for the disabled term
    # c * s_star * X[:, feat]. Its result is unused, but the call is kept so the
    # RNG stream (and hence Y for a given seed) stays unchanged.
    np.random.choice(d, size=1, replace=False)

    sign = 2 * A - 1  # +1 for treated, -1 for control
    base_effect = X @ beta_Y
    subgroup_treatment = tau / 2 * sign * s_star
    non_subgroup_treatment = gamma / 2 * sign * (1 - s_star)
    mu = base_effect + subgroup_treatment + non_subgroup_treatment

    Y = np.random.normal(loc=mu, scale=sigma)

    return {
        "X": X,
        "s_star": s_star,
        "A": A,
        "Y": Y,
        "mu": mu,
        "rule": rule_descriptor
    }