"""
Transform families for generating dependent data from copulas.
"""
import numpy as np
from scipy.stats import norm


def _broadcast_b(b, d):
    """
    Broadcast the dependence parameter ``b`` to a length-``d`` float array.

    Parameters
    ----------
    b : float or array-like
        Either a scalar (shared by all ``d`` columns) or a 1-D sequence
        with exactly ``d`` entries (one per column).
    d : int
        Number of columns / dimensions.

    Returns
    -------
    np.ndarray, shape (d,), dtype float
        Per-column dependence parameters.

    Raises
    ------
    ValueError
        If ``b`` is neither a scalar nor of shape ``(d,)``.
    """
    b = np.asarray(b, dtype=float)
    if b.ndim == 0:
        return np.full(d, float(b))
    # Explicit check instead of ``assert`` so the validation still runs
    # when Python is started with -O (which strips assert statements).
    if b.shape != (d,):
        raise ValueError(
            f"b must be scalar or shape ({d},); got shape {b.shape}"
        )
    return b


def transform_trig_uniform(u, v, b):
    """
    Trigonometric-uniform transform.

    Each column of ``u`` is pushed through the standard-normal quantile
    function and a sine to form X; Y is the cosine of a ``b``-scaled X
    shifted by ``v``:

        x_j = sin(Phi^{-1}(u_j))
        y_j = cos(b_j * x_j + v_j)

    Parameters
    ----------
    u, v : np.ndarray, shape (n, d)
        Uniform(0,1) samples. ``u`` should lie strictly inside (0, 1):
        ``norm.ppf`` returns +/-inf at the endpoints, for which ``sin``
        yields NaN.
    b : float or array-like
        Dependence parameter(s): a scalar shared by all d columns, or one
        value per column.

    Returns
    -------
    X, Y : np.ndarray, shape (n, d)
        Transformed samples
    """
    _, n_cols = u.shape
    # Row vector so each column j is scaled by its own b_j.
    slopes = _broadcast_b(b, n_cols)[np.newaxis, :]

    gauss = norm.ppf(u)
    X = np.sin(gauss)
    phase = slopes * X + v
    return X, np.cos(phase)


def transform_expquad(u, v, b):
    """
    Exponential-quadratic transform.

    X is a Gaussian-bump of the standard-normal quantile of ``u``; Y is an
    exponential of a ``b``-weighted squared distance of X from 1, shifted
    by ``v``:

        x_j = exp(-(Phi^{-1}(u_j))^2)
        y_j = exp(-b_j * (x_j - 1)^2 + v_j)

    Parameters
    ----------
    u, v : np.ndarray, shape (n, d)
        Uniform(0,1) samples
    b : float or array-like
        Dependence parameter(s): a scalar shared by all d columns, or one
        value per column.

    Returns
    -------
    X, Y : np.ndarray, shape (n, d)
        Transformed samples
    """
    _, n_cols = u.shape
    # Row vector so each column j is weighted by its own b_j.
    scale = _broadcast_b(b, n_cols)[np.newaxis, :]

    z_sq = norm.ppf(u) ** 2
    X = np.exp(-z_sq)
    exponent = -scale * (X - 1.0) ** 2 + v
    return X, np.exp(exponent)


def transform_linear(u, v, b):
    """
    Linear transform.

    X is a copy of ``u``; Y is a per-column affine function of X with
    slope ``b`` and additive term ``v``:

        x_j = u_j
        y_j = b_j * x_j + v_j

    Parameters
    ----------
    u, v : np.ndarray, shape (n, d)
        Uniform(0,1) samples
    b : float or array-like
        Dependence parameter(s): a scalar shared by all d columns, or one
        value per column.

    Returns
    -------
    X, Y : np.ndarray, shape (n, d)
        Transformed samples. X is a copy, so mutating it does not
        affect ``u``.
    """
    _, n_cols = u.shape
    # Row vector so each column j gets its own slope b_j.
    slope = _broadcast_b(b, n_cols)[np.newaxis, :]

    X = u.copy()
    return X, slope * X + v


def transform_logquad(u, v, b):
    """
    Log-quadratic transform with phase and amplitude modulation.

    X is a bounded, saturating function of the squared standard-normal
    quantile of ``u``; Y is a ``b``-modulated cosine multiplied by a
    Gaussian-shaped envelope centred at X = 0.7:

        Z_j = Phi^{-1}(u_j)
        X_j = log1p(Z_j^2) / (1 + log1p(Z_j^2))
        Y_j = cos(b_j * X_j + v_j) * exp(-b_j * (X_j - 0.7)^2)

    Parameters
    ----------
    u, v : np.ndarray, shape (n, d)
        Uniform(0,1) samples
    b : float or array-like
        Dependence parameter(s): a scalar shared by all d columns, or one
        value per column.

    Returns
    -------
    X, Y : np.ndarray, shape (n, d)
        Transformed samples
    """
    _, n_cols = u.shape
    # Row vector so each column j uses its own b_j.
    coef = _broadcast_b(b, n_cols)[np.newaxis, :]

    # log1p(Z^2) >= 0, so t / (1 + t) keeps finite X within [0, 1).
    log_term = np.log1p(norm.ppf(u) ** 2)
    X = log_term / (1.0 + log_term)

    oscillation = np.cos(coef * X + v)
    envelope = np.exp(-coef * (X - 0.7) ** 2)
    return X, oscillation * envelope


# Transform registry: maps each string key accepted by ``apply_transform``
# to its transform function. Every entry takes (u, v, b) and returns (X, Y).
TRANSFORM_MAP = dict(
    trigU=transform_trig_uniform,
    expquad=transform_expquad,
    linear=transform_linear,
    logquad=transform_logquad,
)


def apply_transform(u, v, b, transform_key):
    """
    Apply specified transform to uniform samples.

    Looks ``transform_key`` up in ``TRANSFORM_MAP`` and forwards
    ``(u, v, b)`` unchanged to the matching transform function.

    Parameters
    ----------
    u, v : np.ndarray, shape (n, d)
        Uniform(0,1) samples from copula
    b : float or array-like
        Dependence parameter(s)
    transform_key : str
        Transform name: 'trigU', 'expquad', 'linear', or 'logquad'

    Returns
    -------
    X, Y : np.ndarray, shape (n, d)
        Transformed samples

    Raises
    ------
    ValueError
        If ``transform_key`` is not a key of ``TRANSFORM_MAP``.
    """
    # Only the lookup sits inside ``try``, so a KeyError raised by the
    # transform itself is not misreported as an unknown transform name.
    try:
        transform = TRANSFORM_MAP[transform_key]
    except KeyError:
        available = list(TRANSFORM_MAP.keys())
        raise ValueError(f"Unknown transform: {transform_key}. "
                         f"Available: {available}") from None
    return transform(u, v, b)
