import numpy as np

def data_generation(p=5, n=1, step=500, burn_in=100, scale=0.1, graph=None,
                    non_linear=None, activation="leaky_relu",
                    hidden_dim=16, return_confounders=False, seed=42):
    """
    Simulate multivariate time series data from a (linear or non-linear) causal graph.

    The full state has ``dim = p + n`` variables and evolves as
    ``X[:, t] = f(graph @ X[:, t-1]) + e[:, t]``, where ``e`` is zero-mean
    Gaussian noise with standard deviation ``scale`` and ``f`` is chosen by
    ``non_linear``. The first state ``X[:, 0]`` is pure noise.

    Parameters
    ----------
    p : int
        Number of observed variables.
    n : int
        Number of latent confounders. They occupy the first ``n`` rows of the
        state (and therefore the first ``n`` rows/columns of ``graph``).
    step : int
        Number of effective time steps (rows of the returned array).
    burn_in : int
        Warm-up steps (discarded from output).
    scale : float
        Standard deviation of the Gaussian residuals; varying it between calls
        is how different environments are simulated.
    graph : ndarray (dim x dim), optional
        Weighted causal graph (dim = p+n). Defaults to ``0.5 * identity``.
    non_linear : None | "activation" | "mlp_after"
        Nonlinear mode: None = linear, "activation" = element-wise nonlinearity,
        "mlp_after" = pass WX through a small MLP.
    activation : str
        Activation function if non_linear="activation". One of "leaky_relu",
        "relu", "tanh", "sigmoid"; any other value falls back to the identity.
    hidden_dim : int
        Hidden dimension for the MLP if non_linear="mlp_after".
    return_confounders : bool
        If True, return confounders as well.
    seed : int
        Random seed. NOTE: this reseeds NumPy's *global* legacy RNG via
        ``np.random.seed``, so it also affects later ``np.random`` calls made
        by the caller.

    Returns
    -------
    ndarray
        Array of shape ``(step, p)`` holding the observed variables, or
        ``(step, p + n)`` with the confounders in the leading columns when
        ``return_confounders`` is True.

    Raises
    ------
    ValueError
        If ``step`` or ``burn_in`` is negative, if ``non_linear`` is not a
        known mode, or if ``graph`` is not of shape ``(p + n, p + n)``.
    """
    if step < 0 or burn_in < 0:
        # A negative burn_in would silently slice from the end of the series.
        raise ValueError(
            f"step and burn_in must be non-negative, got step={step}, "
            f"burn_in={burn_in}"
        )
    # Validate the mode up front so a typo is reported even when the time loop
    # below never executes (total_steps <= 1).
    if non_linear not in (None, "activation", "mlp_after"):
        raise ValueError(f"Unknown non_linear mode: {non_linear}")

    np.random.seed(seed)
    dim = p + n
    total_steps = step + burn_in

    if graph is None:
        graph = np.eye(dim) * 0.5  # fallback
    else:
        graph = np.asarray(graph)
    # Without this check a 1-D or (1, dim) graph would be silently broadcast
    # by `graph @ x` + slice assignment and produce meaningless data.
    if graph.shape != (dim, dim):
        raise ValueError(
            f"graph must have shape ({dim}, {dim}) for p={p}, n={n}, "
            f"got {graph.shape}"
        )

    # Gaussian residuals (drawn before the MLP weights; the draw order fixes
    # the output for a given seed)
    residuals = np.random.normal(scale=scale, size=(dim, total_steps))
    X = np.zeros((dim, total_steps))
    if total_steps > 0:
        X[:, 0] = residuals[:, 0]

    # element-wise activation
    def apply_activation(x):
        if activation == "leaky_relu":
            return np.where(x > 0, x, 0.01 * x)
        elif activation == "relu":
            return np.maximum(0, x)
        elif activation == "tanh":
            return np.tanh(x)
        elif activation == "sigmoid":
            return 1 / (1 + np.exp(-x))
        else:
            # Unrecognised names fall back to the identity map.
            return x

    # Resolve the transition function once instead of re-dispatching per step.
    if non_linear is None:
        transform = None
    elif non_linear == "activation":
        transform = apply_activation
    else:
        # build MLP parameters (2-layer, tanh hidden layer, zero biases)
        W1 = np.random.randn(hidden_dim, dim) * 0.1
        b1 = np.zeros((hidden_dim, 1))
        W2 = np.random.randn(dim, hidden_dim) * 0.1
        b2 = np.zeros((dim, 1))

        def transform(x_vec):
            h = np.tanh(W1 @ x_vec.reshape(-1, 1) + b1)
            out = W2 @ h + b2
            return out.ravel()

    # generate dynamics
    for t in range(1, total_steps):
        drive = graph @ X[:, t - 1]
        if transform is not None:
            drive = transform(drive)
        X[:, t] = drive + residuals[:, t]

    # drop burn-in; rows [0, n) are the confounders, rows [n, dim) the observed
    # variables. Transpose so that time runs along axis 0.
    first_row = 0 if return_confounders else n
    return X[first_row:, burn_in:].T


if __name__ == "__main__":
    # Build two 3-environment datasets (linear and leaky-ReLU) with 1 confounder.
    G = np.load("synthetic/graph/conf_latent1.npy")
    G_itv = np.load("synthetic/graph/itv_latent1_env1.npy")

    # (graph, noise scale, seed) for each environment.
    # Each environment needs its own seed: with a shared seed the Gaussian
    # draws are identical, so under linear or leaky-ReLU dynamics the
    # scale=0.15 environment would be an exact 1.5x rescaling of the
    # scale=0.1 one (up to rounding) rather than an independent sample.
    # Environment 1 keeps the default seed 42.
    env_specs = [
        (G, 0.1, 42),
        (G, 0.15, 43),
        (G_itv, 0.1, 44),
    ]

    # Linear with 1 confounder
    envs = np.stack(
        [data_generation(p=5, n=1, graph=g, scale=s, seed=sd)
         for g, s, sd in env_specs],
        axis=0,
    )
    np.save("synthetic/linear_data/linear_envs3_conf1_itv1.npy", envs)

    # Non-linear with 1 confounder
    envs = np.stack(
        [data_generation(p=5, n=1, graph=g, scale=s, seed=sd,
                         non_linear="activation", activation="leaky_relu")
         for g, s, sd in env_specs],
        axis=0,
    )
    np.save("synthetic/non_linear_data/non_linear_envs3_conf1_itv1_leaky.npy", envs)