"""
Data generation for synthetic PD-cancellation demo.

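Data-generating process (X1, X2, X3 drawn independently):
  X1 ~ Normal(0, std_x1^2)
  X2 ~ Normal(-1, std_x2^2)
  X3 ~ Normal(-1, std_x3^2)
  y = X1^2*X2 + X1^2*X2*X3 + c + eps,  where eps ~ Normal(0, noise_std^2)
  and c is a constant offset (0 by default).

Key property (PD cancellation):
  PD_1(a) = E[y | do(X1=a)] = a^2*E[X2] + a^2*E[X2]*E[X3] + c
          = -a^2 + a^2 + c = c   for all a,
which is 0 with the default offset, even though X1 matters strongly via
interactions (ICE curves a^2*x2*(1 + x3) vary across points).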
"""

from __future__ import annotations

import numpy as np


def make_dataset(
    n: int,
    seed: int,
    noise_std: float = 0.25,
    constant_offset: float = 0.0,
    std_x1: float = 1.0,
    std_x2: float = 2.0,
    std_x3: float = 3.0,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Generate synthetic dataset with PD cancellation property.

    The target is ``y = x1**2 * x2 + x1**2 * x2 * x3 + constant_offset + eps``
    with independent ``x1 ~ N(0, std_x1^2)``, ``x2 ~ N(-1, std_x2^2)``,
    ``x3 ~ N(-1, std_x3^2)`` and ``eps ~ N(0, noise_std^2)``. Because
    ``E[x2] + E[x2] * E[x3] = -1 + 1 = 0``, the partial dependence of y on
    x1 is flat (equal to ``constant_offset``) for any choice of the stds.

    Parameters
    ----------
    n : int
        Number of samples
    seed : int
        Random seed. Features are drawn in the order x1, x2, x3, then noise,
        so a given seed always reproduces the same dataset.
    noise_std : float, default=0.25
        Standard deviation of the additive Gaussian noise term
    constant_offset : float, default=0.0
        Constant offset added to y (for compatibility with different scripts)
    std_x1 : float, default=1.0
        Standard deviation for X1
    std_x2 : float, default=2.0
        Standard deviation for X2
    std_x3 : float, default=3.0
        Standard deviation for X3

    Returns
    -------
    X : np.ndarray of shape (n, 3)
        Feature matrix with columns [x1, x2, x3] (float64)
    y : np.ndarray of shape (n,)
        Target vector, including ``constant_offset`` and noise (float64)
    """
    rng = np.random.default_rng(seed)
    # Draw order matters for reproducibility: x1, x2, x3, then the noise.
    # Generator.normal already returns float64, so no cast is needed.
    x1 = rng.normal(loc=0.0, scale=std_x1, size=n)
    x2 = rng.normal(loc=-1.0, scale=std_x2, size=n)
    x3 = rng.normal(loc=-1.0, scale=std_x3, size=n)
    X = np.column_stack([x1, x2, x3])
    # Keep this exact expression order so results stay bit-identical.
    y = x1**2 * x2 + x1**2 * x2 * x3 + constant_offset
    y = y + noise_std * rng.standard_normal(size=n)
    return X, y.astype(np.float64)
