import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import importlib
from scipy.special import expit, logit

def generate_data(
    n,
    p,
    alpha0,
    alpha1,
    alpha2,
    zeta1,
    gamma0,
    gamma1,
    gamma2,
    beta0,
    beta1,
    beta2,
    beta3,
    rho0,
    rho1,
    sigma_epsilon_Y,
    sigma_epsilon_W,
    seed=0,
):
    """
    Simulate covariates, study/treatment indicators and two outcomes from a
    linear data generating procedure.

    Parameters:
    n : int
        Number of samples to draw.
    p : int
        Number of covariates; columns 0 and 1 of X are used, so p must be >= 2.
    alpha0, alpha1, alpha2 : float
        Intercept and slopes (on X1, X2) of the logit of P(S = 1).
    zeta1 : float
        Slope on X1 of the logit of P(T = 1) when S = 1.
    gamma0, gamma1, gamma2 : float
        Coefficients of phi(X) = gamma0 + gamma1 * X1 + gamma2 * X2, the
        multiplier of T in the mean of W.
    beta0, beta1, beta2, beta3 : float
        Coefficients of the part of the mean of W that does not involve T
        (intercept, X1, X2, and the X1 * X2 interaction).
    rho0, rho1 : float
        Coefficients of alpha(X) = rho1 * X1 + rho0, which scales the mean of
        W to give the mean of Y.
    sigma_epsilon_Y, sigma_epsilon_W : float
        Standard deviations of the Gaussian noise added to Y and W.
    seed : int, optional
        Value passed to np.random.seed; this reseeds NumPy's global RNG.

    Returns:
    tuple
        (X, S, T, Y, W, theta), where theta = alpha(X) * phi(X).
    """
    # Reseed the global generator so repeated calls with one seed agree.
    np.random.seed(seed)

    # X ~ N(0, I_p)
    zero_mean = np.zeros(p)
    identity_cov = np.eye(p)
    X = np.random.multivariate_normal(zero_mean, identity_cov, n)
    x1 = X[:, 0]
    x2 = X[:, 1]

    # S ~ Bernoulli(expit(alpha0 + alpha1 * X1 + alpha2 * X2))
    study_prob = expit(alpha0 + alpha1 * x1 + alpha2 * x2)
    S = np.random.binomial(1, study_prob, n)

    # T ~ Bernoulli(0.5) when S = 0, Bernoulli(expit(zeta1 * X1)) when S = 1
    treat_prob = (1 - S) * 0.5 + S * expit(zeta1 * x1)
    T = np.random.binomial(1, treat_prob, n)

    # mu_W = phi(X) * T + beta1 * X1 + beta2 * X2 + beta3 * X1 * X2 + beta0,
    # accumulated term by term in the same left-to-right order.
    phi = gamma0 + gamma1 * x1 + gamma2 * x2
    mean_w = phi * T
    mean_w = mean_w + beta1 * x1
    mean_w = mean_w + beta2 * x2
    mean_w = mean_w + beta3 * x1 * x2
    mean_w = mean_w + beta0

    # Y = alpha(X) * mu_W + epsilon_Y; the Y noise is drawn before the W noise.
    scale = rho1 * x1 + rho0
    noise_y = np.random.normal(0, sigma_epsilon_Y, n)
    Y = scale * mean_w + noise_y

    theta = scale * phi

    # W = mu_W + epsilon_W
    noise_w = np.random.normal(0, sigma_epsilon_W, n)
    W = mean_w + noise_w

    return X, S, T, Y, W, theta

import numpy as np
from scipy.special import expit

def generate_data_2(n, p, sigma_W=1.0, sigma_Y=1.0, seed=None):
    """
    Generates synthetic data based on the specified Data Generating Procedure (DGP).

    Unlike a purely linear DGP, the multiplier of T in the mean of W is
    quadratic in X1 and alpha(X) is a log of a quadratic in X1.

    Parameters:
    n : int
        Number of samples to generate.
    p : int
        Number of covariates (features). Columns 0 and 1 of X are used, so
        p must be at least 2.
    sigma_W : float, optional
        Standard deviation of noise in W.
    sigma_Y : float, optional
        Standard deviation of noise in Y.
    seed : int or None, optional
        If not None, passed to np.random.seed before any draw so the output
        is reproducible (this reseeds NumPy's global RNG). If None (the
        default), the global RNG state is used as-is.

    Returns:
    tuple
        (X, S, T, Y, W, theta, alpha_X) where X has shape (n, p) and every
        other element is a length-n array; theta is the product of alpha_X
        and the multiplier of T in the mean of W.
    """
    # Only reseed on request, so existing callers that manage the global RNG
    # themselves see no change in behavior.
    if seed is not None:
        np.random.seed(seed)

    # Step 1: Generate covariates X from a standard multivariate normal
    X = np.random.multivariate_normal(mean=np.zeros(p), cov=np.identity(p), size=n)
    X1, X2 = X[:, 0], X[:, 1]

    # Step 2: Generate study indicator S with constant probability expit(1);
    # it does not depend on the covariates in this DGP.
    prob_S = expit(1)
    S = np.random.binomial(1, prob_S, size=n)

    # Step 3: Generate treatment assignment T: probability 0.5 when S = 0,
    # expit(0.7 * X1) when S = 1
    prob_T = (1 - S) * 0.5 + S * expit(0.7 * X1)
    T = np.random.binomial(1, prob_T, size=n)

    # Step 4: Define the outcome W; tau is the multiplier of T in its mean
    tau = 1.2 * X1**2 - 0.5 * X2
    mu_W = tau * T + 0.8 * X1 + 0.6 * X2 + 0.4 * X1 * X2 + 1
    W = mu_W + np.random.normal(0, sigma_W, size=n)

    # Step 5: Define alpha(X); the argument of log is >= 0.5, so it is finite
    alpha_X = np.log(1.5 * X1**2 + 0.5)

    # Step 6: Define the outcome Y as alpha(X) * mu_W plus Gaussian noise
    Y = alpha_X * mu_W + np.random.normal(0, sigma_Y, size=n)

    theta = tau * alpha_X

    return X, S, T, Y, W, theta, alpha_X

def create_dataframe(X, S, T, Y, W):
    """
    Assemble the simulated arrays into a single pandas DataFrame.

    Covariate columns are named X1, X2, ... in the column order of X, and are
    followed by S, T, Y, W and a combined column V = S * W + (1 - S) * Y
    (W where S is 1, Y where S is 0).
    """
    covariate_names = [f"X{idx + 1}" for idx in range(X.shape[1])]
    frame = pd.DataFrame(X, columns=covariate_names)

    # Appended one at a time so the resulting column order is S, T, Y, W, V.
    extra_columns = {
        "S": S,
        "T": T,
        "Y": Y,
        "W": W,
        "V": S * W + (1 - S) * Y,
    }
    for label, values in extra_columns.items():
        frame[label] = values
    return frame