from typing import Callable, Dict, Any, Optional, Tuple
import numpy as np


def _ensure_rng(rng: Optional[np.random.Generator], seed: Optional[int]) -> np.random.Generator:
    if isinstance(rng, np.random.Generator):
        return rng
    return np.random.default_rng(0 if seed is None else int(seed))


def build_gfunc(mode: str, *, freq: float = 1.0) -> Callable[[np.ndarray], np.ndarray]:
    """
    Return the link function g(u) selected by ``mode`` (case-insensitive).

      - linear : u
      - logic  : 1/(1+exp(u))   (note: decreasing logistic)
      - sin    : sin(freq*u)
      - complex: tanh(u) + sin(4u)*exp(-0.01*u^2)

    Args:
        mode: One of 'linear', 'logic', 'sin', 'complex'; converted with
            ``str()`` and lower-cased before matching.
        freq: Frequency multiplier used by the 'sin' mode; ignored otherwise.

    Returns:
        A callable applying g element-wise via NumPy ufuncs.

    Raises:
        ValueError: if ``mode`` is not one of the supported names.
    """
    mode = str(mode).lower()

    if mode == "linear":
        return lambda u: u

    if mode == "logic":
        # 1/(1+exp(u)) == exp(-log(1+exp(u))). logaddexp(0, u) computes
        # log(1+exp(u)) without overflowing for large positive u, so this
        # avoids the overflow RuntimeWarning of the naive expression while
        # keeping full relative precision in the tails.
        return lambda u: np.exp(-np.logaddexp(0.0, u))

    if mode == "sin":
        f = float(freq)
        return lambda u: np.sin(f * u)

    if mode == "complex":
        return lambda u: np.tanh(u) + np.sin(4.0 * u) * np.exp(-0.01 * (u ** 2))

    raise ValueError("Unknown g mode. Use one of: ['linear','logic','sin','complex'].")


def add_noise_by_snr(
    signal: np.ndarray,
    target_snr: float,
    *,
    assume_centered: bool = False,
    rng: Optional[np.random.Generator] = None,
    seed: Optional[int] = None,
) -> Tuple[np.ndarray, float]:
    """
    Add i.i.d. Gaussian noise so that P(signal)/Var(noise) = target_snr.

    By default the signal power P is Var(signal). With
    ``assume_centered=True`` the caller asserts the signal is already
    zero-mean, so P is the raw second moment mean(signal**2) with no mean
    removal (identical to the variance when that assertion holds).

    Args:
        signal: Clean values; converted to a float array, any shape.
        target_snr: Desired signal-power / noise-variance ratio. Must be
            > 0; ``inf`` yields zero noise.
        assume_centered: Skip mean removal when computing signal power.
        rng: Generator to draw the noise from; if it is not an
            ``np.random.Generator`` one is built from ``seed`` by
            ``_ensure_rng``.
        seed: Seed used only when ``rng`` is not a Generator.

    Returns:
        (noisy_signal, noise_std): the noisy float array (same shape as
        ``signal``) and the noise standard deviation that was used.

    Raises:
        ValueError: if ``target_snr`` is not a positive number (incl. NaN).
    """
    snr = float(target_snr)
    # `not snr > 0` also rejects NaN; previously 0 raised ZeroDivisionError
    # and negative values silently produced noiseless output.
    if not snr > 0.0:
        raise ValueError(f"target_snr must be > 0, got {target_snr!r}")

    rng = _ensure_rng(rng, seed)
    sig = np.asarray(signal, dtype=float)
    if sig.size == 0:
        # Nothing to perturb; avoids mean/var of an empty array (NaN + warnings).
        return sig.copy(), 0.0

    # np.var always subtracts the mean itself, so the centered case must use
    # the raw second moment for `assume_centered` to have any effect.
    if assume_centered:
        sig_power = float(np.mean(sig ** 2))
    else:
        sig_power = float(np.var(sig))

    noise_std = float(np.sqrt(sig_power / snr))
    noisy = sig + rng.normal(0.0, noise_std, size=sig.shape)
    return noisy, noise_std


def generate_y_from_g(
    X_fine: np.ndarray,
    betat: np.ndarray,
    fine_t: np.ndarray,
    *,
    g_mode: str = "linear",
    g_kwargs: Optional[Dict[str, Any]] = None,
    standardize_eta: bool = False,
    clip_eta: Optional[Tuple[float, float]] = None,
    center_g_output: bool = False,
    unitvar_g_output: bool = False,
    target_snr: float = 10.0,
    rng: Optional[np.random.Generator] = None,
    seed: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Simulate a response through a link function g of a functional index.

        eta     = ∫ X(t) beta(t) dt   (trapezoidal rule on ``fine_t``)
        y_clean = g(eta)
        y_noisy = y_clean + eps,  with Var(y_clean)/Var(eps) = target_snr

    Args:
        X_fine: Curves sampled on ``fine_t``. Integration runs along axis 1
            after broadcasting against ``betat[None, :]``, so rows are
            presumably samples, shape (n, len(fine_t)) — confirm with callers.
        betat: Coefficient function sampled on ``fine_t`` (1-D).
        fine_t: Abscissae of the integration grid.
        g_mode: Link name passed to ``build_gfunc``.
        g_kwargs: Extra keyword arguments forwarded to ``build_gfunc``.
        standardize_eta: Z-score eta before applying g.
        clip_eta: Optional (lo, hi) bounds applied to eta (after any
            standardization).
        center_g_output: Subtract the sample mean of g(eta).
        unitvar_g_output: Divide g(eta) (after optional centering) by its
            standard deviation.
        target_snr: Variance ratio Var(y_clean)/Var(noise).
        rng: Generator for the noise; built from ``seed`` if not a Generator.
        seed: Seed used only when ``rng`` is not a Generator.

    Returns:
        Dict with ``y_clean``, ``y_noisy``, ``eta``, ``g_name``,
        ``noise_std``, ``snr_target``, ``g_var_before_scale``,
        ``g_var_after_scale`` and the preprocessing options echoed back.
    """
    rng = _ensure_rng(rng, seed)
    g_kwargs = {} if g_kwargs is None else dict(g_kwargs)

    # g(u)
    g = build_gfunc(g_mode, **g_kwargs)

    # eta
    X = np.asarray(X_fine)
    betat = np.asarray(betat, dtype=float)
    # NumPy 2.0 deprecated np.trapz in favour of np.trapezoid; prefer the new
    # name and fall back to trapz on NumPy 1.x.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    eta = trapezoid(X * betat[None, :], fine_t, axis=1)

    if standardize_eta:
        # Small epsilon guards against a constant eta (zero std).
        eta = (eta - eta.mean()) / (eta.std() + 1e-12)
    if clip_eta is not None:
        lo, hi = clip_eta
        eta = np.clip(eta, float(lo), float(hi))

    # astype(float) copies, so y_clean never aliases eta (linear g returns
    # its input unchanged).
    y_clean = g(eta).astype(float)

    if center_g_output:
        y_clean = y_clean - y_clean.mean()

    g_var_before = float(np.var(y_clean))
    if unitvar_g_output:
        y_clean = y_clean / (np.sqrt(g_var_before) + 1e-12)
    g_var_after = float(np.var(y_clean))

    y_noisy, noise_std = add_noise_by_snr(
        y_clean, target_snr, assume_centered=False, rng=rng
    )

    return dict(
        y_clean=y_clean,
        y_noisy=y_noisy,
        eta=eta,
        g_name=str(g_mode),
        noise_std=float(noise_std),
        snr_target=float(target_snr),
        g_var_before_scale=float(g_var_before),
        g_var_after_scale=float(g_var_after),
        standardize_eta=bool(standardize_eta),
        clip_eta=clip_eta,
        center_g_output=bool(center_g_output),
        unitvar_g_output=bool(unitvar_g_output),
    )
