import numpy as np
from scipy.special import expit 
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class MultivarOutcomeRegressionParams:
    """Parameters of the multivariate outcome regression simulation model.

    The model uses these fields as follows (row-vector convention, ``X`` of
    shape ``(n, d_x)``):

    * propensity:  ``expit(gamma0 + X @ gamma)``
    * baseline:    ``mu(X)  = X @ B.T``
    * effect:      ``tau(X) = X @ T.T + delta``
    * noise:       zero-mean Gaussian with covariance ``Sigma``

    Array-valued fields may be left as ``None``; ``_check_and_prepare_params``
    substitutes zeros for ``gamma``, ``B``, ``T`` and ``delta`` and the
    identity matrix for ``Sigma``.
    """

    d_x: int = 1                         # number of covariates
    d_y: int = 1                         # number of outcome components
    gamma0: float = 0.0                  # propensity intercept
    gamma: Optional[np.ndarray] = None   # propensity slopes, shape (d_x,)
    B: Optional[np.ndarray] = None       # baseline mean coefficients, shape (d_y, d_x)
    T: Optional[np.ndarray] = None       # covariate-dependent effect coefficients, shape (d_y, d_x)
    delta: Optional[np.ndarray] = None   # constant treatment-effect offset, shape (d_y,)
    Sigma: Optional[np.ndarray] = None   # outcome noise covariance, shape (d_y, d_y)


def _check_and_prepare_params(p: MultivarOutcomeRegressionParams) -> MultivarOutcomeRegressionParams:
    """Validate ``p`` and return a new instance with all defaults filled in.

    ``None`` array fields are replaced by zeros (``gamma``, ``B``, ``T``,
    ``delta``) or the identity matrix (``Sigma``); supplied arrays are
    converted with ``np.asarray(..., dtype=float)``.

    Parameters
    ----------
    p : MultivarOutcomeRegressionParams
        Possibly partially specified parameters.

    Returns
    -------
    MultivarOutcomeRegressionParams
        A new instance whose array fields are float ndarrays of the expected
        shapes and whose ``gamma0`` is a Python float.

    Raises
    ------
    TypeError
        If ``d_x`` or ``d_y`` is not an integer.
    ValueError
        If a dimension is not positive, an array has the wrong shape, any
        parameter contains NaN/inf, or ``Sigma`` is not symmetric positive
        definite.
    """
    # A float such as 2.0 would pass the shape comparisons below
    # ((2,) == (2.0,)) and only blow up later inside the sampler.
    for name, dim in (("d_x", p.d_x), ("d_y", p.d_y)):
        if not isinstance(dim, (int, np.integer)):
            raise TypeError(f"{name} must be an integer. Got {dim!r}")
    if p.d_x <= 0 or p.d_y <= 0:
        raise ValueError(f"d_x and d_y must be positive. Got d_x={p.d_x}, d_y={p.d_y}")
    d_x, d_y = int(p.d_x), int(p.d_y)

    # Defaults
    gamma0 = float(p.gamma0)
    gamma = np.zeros((d_x,), dtype=float) if p.gamma is None else np.asarray(p.gamma, dtype=float)
    B = np.zeros((d_y, d_x), dtype=float) if p.B is None else np.asarray(p.B, dtype=float)
    T = np.zeros((d_y, d_x), dtype=float) if p.T is None else np.asarray(p.T, dtype=float)
    delta = np.zeros((d_y,), dtype=float) if p.delta is None else np.asarray(p.delta, dtype=float)
    Sigma = np.eye(d_y, dtype=float) if p.Sigma is None else np.asarray(p.Sigma, dtype=float)

    # (name, array, symbolic shape, concrete shape)
    specs = (
        ("gamma", gamma, "(d_x,)", (d_x,)),
        ("B", B, "(d_y, d_x)", (d_y, d_x)),
        ("T", T, "(d_y, d_x)", (d_y, d_x)),
        ("delta", delta, "(d_y,)", (d_y,)),
        ("Sigma", Sigma, "(d_y, d_y)", (d_y, d_y)),
    )

    # Shape checks
    for name, arr, label, shape in specs:
        if arr.shape != shape:
            raise ValueError(f"{name} must have shape {label}={shape}. Got {arr.shape}")

    # Finiteness checks. These must precede the covariance checks: a NaN in
    # Sigma would otherwise be reported as "not symmetric", and a NaN
    # eigenvalue would slip past the ``<= 0`` test (NaN comparisons are False).
    if not np.isfinite(gamma0):
        raise ValueError(f"gamma0 must be finite. Got {gamma0}")
    for name, arr, _, _ in specs:
        if not np.all(np.isfinite(arr)):
            raise ValueError(f"{name} must contain only finite values.")

    # Basic covariance validity checks (symmetric positive definite)
    if not np.allclose(Sigma, Sigma.T, atol=1e-10, rtol=1e-10):
        raise ValueError("Sigma must be symmetric.")
    eig_min = np.linalg.eigvalsh(Sigma).min()
    if eig_min <= 0:
        raise ValueError(f"Sigma must be positive definite. Smallest eigenvalue={eig_min:g}")

    return MultivarOutcomeRegressionParams(
        d_x=d_x, d_y=d_y, gamma0=gamma0,
        gamma=gamma, B=B, T=T, delta=delta, Sigma=Sigma,
    )


def simulate_multivariate_outcome_regression(
    n: int,
    params: MultivarOutcomeRegressionParams = MultivarOutcomeRegressionParams(),
    x_loc: float = 0.0,
    x_scale: float = 1.0,
    clip_propensity: Optional[Tuple[float, float]] = None,
    rng: Optional[np.random.Generator] = None,
    return_propensity: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray] | Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Simulate covariates, a binary treatment and a multivariate outcome.

    Data-generating process::

        X        ~ N(x_loc, x_scale^2 I)
        A | X    ~ Bernoulli(expit(gamma0 + X gamma))
        Y | A, X ~ N(mu(X) + A * tau(X), Sigma)
            where mu(X) = X B^T, tau(X) = X T^T + delta

    Parameters
    ----------
    n : int
        Number of samples; must be positive.
    params : MultivarOutcomeRegressionParams
        Model parameters; validated and completed by
        ``_check_and_prepare_params``.
    x_loc, x_scale : float
        Mean and standard deviation of every covariate.
    clip_propensity : tuple of (float, float), optional
        ``(lo, hi)`` with ``0 < lo < hi < 1``. When given, propensities are
        clipped to this interval before the treatment is drawn.
    rng : numpy.random.Generator, optional
        Source of randomness; a fresh ``np.random.default_rng()`` is used
        when omitted.
    return_propensity : bool
        If true, the propensity scores are returned as a fourth element.

    Returns
    -------
    X : (n, d_x) float
    A : (n,) int64 in {0, 1}
    Y : (n, d_y) float
    e : (n,) float
        Propensities actually used to draw ``A`` (i.e. after clipping).
        Only returned when ``return_propensity`` is true.

    Raises
    ------
    ValueError
        If ``n`` is not positive, ``params`` fails validation, or
        ``clip_propensity`` does not satisfy ``0 < lo < hi < 1``.
    """
    if n <= 0:
        raise ValueError(f"n must be positive. Got n={n}")

    p = _check_and_prepare_params(params)

    # Validate everything before drawing anything, so that a bad argument
    # neither wastes the covariate draw nor advances the caller's generator.
    if clip_propensity is not None:
        lo, hi = clip_propensity
        if not (0.0 < lo < hi < 1.0):
            raise ValueError(f"clip_propensity must satisfy 0 < lo < hi < 1. Got {clip_propensity}")

    rng = np.random.default_rng() if rng is None else rng

    # NOTE: the order of the three draws below (X, A, noise) determines the
    # random stream and must not change, or seeded results will differ.

    # Covariates; asarray is a no-op when the draw is already float64.
    X = np.asarray(rng.normal(loc=x_loc, scale=x_scale, size=(n, p.d_x)), dtype=float)

    # Propensity and treatment
    logits = p.gamma0 + X @ p.gamma  # (n,)
    e = expit(logits)                # (n,)
    if clip_propensity is not None:
        e = np.clip(e, lo, hi)

    A = np.asarray(rng.binomial(n=1, p=e, size=n), dtype=np.int64)  # (n,)

    # Outcome mean structure
    mu = X @ p.B.T                    # (n, d_y)
    tau = X @ p.T.T + p.delta         # (n, d_y)
    mean_Y = mu + A[:, None] * tau    # (n, d_y); tau only contributes where A == 1

    # Multivariate noise
    eps = rng.multivariate_normal(mean=np.zeros(p.d_y), cov=p.Sigma, size=n)  # (n, d_y)
    Y = eps + mean_Y

    if return_propensity:
        return X, A, Y, e
    return X, A, Y

