import os
import numpy as np
import scipy.stats as st

def _sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def _generate_mixed_correlated_vars(n, R, rng, n_continuous=5, n_discrete=5, 
                                   gamma_shape=1.0, gamma_scale=1.0, bernoulli_p=0.5):
    
    eigvals, eigvecs = np.linalg.eigh(R)
    eigvals[eigvals < 1e-8] = 1e-8
    R_corrected = (eigvecs * eigvals) @ eigvecs.T
    
    Z_gauss = rng.multivariate_normal(mean=np.zeros(R.shape[0]), cov=R_corrected, size=n)
    
    U = st.norm.cdf(Z_gauss) 
    
    Zc = np.column_stack([
        st.gamma.ppf(U[:, j], a=gamma_shape, scale=gamma_scale)
        for j in range(n_continuous)
    ])
    
    Zd = np.column_stack([
        (U[:, n_continuous + j] < bernoulli_p).astype(int)
        for j in range(n_discrete)
    ])
    
    return Zc, Zd

def _generate_m2_once(
    n_samples,
    causal_params,
    seed=None,
    n_continuous=5,     
    n_discrete=5,       
    conf_strength=1.0,   
    temp_ps=0.7,         
    hetero_scale=1.0,        
):

    rng = np.random.default_rng(seed)
    base_tau, const = causal_params
    

    R = np.array([
        [1.0, 0.3, 0.4, 0.5, 0.1, 0.2, 0.7, 0.5, 0.4, 0.5],
        [0.3, 1.0, 0.3, 0.6, 0.3, 0.4, 0.4, 0.6, 0.3, 0.2],
        [0.4, 0.3, 1.0, 0.5, 0.2, 0.1, 0.1, 0.0, 0.4, 0.4],
        [0.5, 0.6, 0.5, 1.0, 0.2, 0.2, 0.5, 0.5, 0.3, 0.4],
        [0.1, 0.3, 0.2, 0.2, 1.0, 0.1, 0.5, 0.6, 0.2, 0.4],
        [0.2, 0.4, 0.1, 0.2, 0.1, 1.0, 0.0, 0.4, 0.2, 0.5],
        [0.7, 0.4, 0.1, 0.5, 0.5, 0.0, 1.0, 0.4, 0.4, 0.4],
        [0.5, 0.6, 0.0, 0.5, 0.6, 0.4, 0.4, 1.0, 0.4, 0.4],
        [0.4, 0.3, 0.4, 0.3, 0.2, 0.2, 0.4, 0.4, 1.0, 0.4],
        [0.5, 0.2, 0.4, 0.4, 0.2, 0.5, 0.4, 0.4, 0.4, 1.0]
    ], dtype=float)

    Zc, Zd = _generate_mixed_correlated_vars(
        n=n_samples, R=R, rng=rng, 
        n_continuous=n_continuous, n_discrete=n_discrete
    )
    
    zc1, zc2, zc3, zc4, zc5 = Zc.T
    zd1, zd2, zd3, zd4, zd5 = Zd.T



    fX = -0.3+0.1*zc1+0.2*zc2+0.5*zc3-0.2*zc4+zc5+0.3*zd1-0.4*zd2+0.7*zd3-0.1*zd4+0.9*zd5+ (zc2**2) + 0.5*np.sin(zc3) + 0.6*(zc4*zc5)
    logit = -1.0 + conf_strength * fX
    ps = _sigmoid(temp_ps * logit)             
    t = rng.binomial(1, ps, size=n_samples).astype(int)

    g0 = (0.30*zc1 + 0.20*(zc2**2) + 0.20*np.sin(zc3) - 0.10*(zc4*zc5))
    h  = (0.50*zc1 - 0.30*zc2 + 0.20*(zc2**2) + 0.30*np.sin(zc4) + 0.20*(zc1*zc5))
    mu0 = const + g0
    tau = base_tau + hetero_scale * h
    mu1 = mu0 + tau

    sigma = 0.5 + 0.3*np.clip((zc1 + zc2)/2.0, 0.0, None)  # σ(x) ≥ 0.5

    y_mean = mu0 + t * tau
    y = rng.normal(loc=y_mean, scale=sigma, size=n_samples)

    X_features = np.column_stack([Zc, Zd])

    return {
        "X_features": X_features,   
        "Zc": Zc,                 
        "Zd": Zd,                  
        "t": t,                    
        "y": y,                    
        "ps": ps,                   
        "mu0": mu0,                
        "mu1": mu1,                
        "tau": tau,                
    }

def make_m2_replicates(
    n_reps=30,
    n_samples=3000,
    test_ratio=0.2,
    causal_params=(2.0, 1.0),
    seed=2025,
    save_dir=None,
    gen_kwargs=None,
):
    """Generate ``n_reps`` independent M2 datasets, each split into train/test.

    A master ``numpy.random.default_rng(seed)`` drives everything. For each
    replicate it first draws an integer seed for ``_generate_m2_once`` and then
    a random permutation of the rows. The first ``n_train`` permuted rows form
    the train split; the rest form the test split.

    Args:
        n_reps: number of replicates.
        n_samples: rows simulated per replicate.
        test_ratio: fraction of rows placed in the test split. It must lie
            strictly in (0, 1). The test size is
            ``int(round(n_samples * test_ratio))`` (Python's round-half-to-even),
            so very small ``n_samples`` can still yield an empty split.
        causal_params: ``(base_tau, const)`` forwarded to ``_generate_m2_once``.
        seed: seed of the master generator.
        save_dir: if given, the directory is created if needed. Each replicate
            is written to ``m2_rep{r:02d}_train.npz`` and ``m2_rep{r:02d}_test.npz``
            with arrays ``x, t, y, ps, mu0, mu1, tau``.
        gen_kwargs: optional dict of extra keyword arguments forwarded to
            ``_generate_m2_once`` (e.g. ``conf_strength``, ``temp_ps``,
            ``hetero_scale``). It must not repeat ``n_samples``,
            ``causal_params`` or ``seed``.

    Returns:
        dict with keys ``{X,t,y,tau,mu0,mu1,ps}_{train,test}``. Each value is a
        list of ``n_reps`` arrays, one per replicate.

    Raises:
        ValueError: if ``test_ratio`` is not strictly between 0 and 1.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not 0.0 < test_ratio < 1.0:
        raise ValueError(f"test_ratio must lie strictly between 0 and 1, got {test_ratio!r}")
    gen_kwargs = {} if gen_kwargs is None else dict(gen_kwargs)

    n_test = int(round(n_samples * test_ratio))
    n_train = n_samples - n_test

    master_rng = np.random.default_rng(seed)

    # Output-name prefix -> key in the dict returned by _generate_m2_once.
    # The insertion order fixes the key order of the returned dict.
    source_keys = {
        "X": "X_features", "t": "t", "y": "y", "tau": "tau",
        "mu0": "mu0", "mu1": "mu1", "ps": "ps",
    }
    out = {f"{name}_{split}": [] for split in ("train", "test") for name in source_keys}

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    for r in range(n_reps):
        # Draw order (seed, then permutation) from master_rng is part of the
        # reproducibility contract; keep it.
        rep_seed = int(master_rng.integers(0, 2**31 - 1))
        data = _generate_m2_once(
            n_samples=n_samples, causal_params=causal_params, seed=rep_seed, **gen_kwargs
        )

        idx = master_rng.permutation(n_samples)
        for split, split_idx in (("train", idx[:n_train]), ("test", idx[n_train:])):
            parts = {name: data[key][split_idx] for name, key in source_keys.items()}
            for name, arr in parts.items():
                out[f"{name}_{split}"].append(arr)

            if save_dir is not None:
                np.savez(
                    os.path.join(save_dir, f"m2_rep{r:02d}_{split}.npz"),
                    x=parts["X"], t=parts["t"], y=parts["y"], ps=parts["ps"],
                    mu0=parts["mu0"], mu1=parts["mu1"], tau=parts["tau"],
                )

    return out
