from typing import Dict, Any, Tuple
import numpy as np
from betat_module import beta_paper_case
from g_module     import generate_y_from_g
from xt_module    import gen_X_cosine

def generate_demo_data(
    cfg: Dict[str, Any],
    seed: int,
    fine_t: np.ndarray,
    observed_t: np.ndarray,
    z: np.ndarray,
    *,
    return_X: str = "obs",        # "obs" or "fine"
) -> Tuple[np.ndarray, np.ndarray, Dict[str, Any]]:
    """
    Minimal in-memory generator for demo.

    X noise and y noise are both drawn from a single
    ``np.random.default_rng(seed)``. A given ``seed`` therefore reproduces
    the same draws, provided the helper generators are deterministic for a
    fixed RNG.

    Parameters
    ----------
    cfg : dict
        Scenario configuration. Keys read here: ``n_samples``, ``num_basis``,
        optional ``x_target_snr`` (default 10.0), ``betat.params.case_id``,
        ``g.mode``, optional ``g.params`` / ``g.standardize_eta`` /
        ``g.center_g_output`` / ``g.clip_eta``, and ``snr.target``.
    seed : int
        Seed for the shared random generator.
    fine_t : np.ndarray
        Fine time grid passed to the X, beta(t) and y generators.
    observed_t : np.ndarray
        Observation grid passed to the X generator.
    z : np.ndarray
        Forwarded unchanged to ``gen_X_cosine``.
    return_X : {"obs", "fine"}, keyword-only
        Which X to return: ``X_obs_noisy`` ("obs") or ``X_fine`` ("fine").

    Returns
    -------
    X : np.ndarray
        (n, P) ``X_obs_noisy`` if return_X="obs";
        (n, Tf) ``X_fine`` if return_X="fine".
    y : np.ndarray
        (n,) ``y_noisy`` cast to float.
    meta : dict
        Light metadata: ``seed``, a ``scenario`` summary, ``snr_actual``
        (empirical var(y_clean) / var(y_noisy - y_clean)) and
        ``intervals_true`` as a list of float pairs.

    Raises
    ------
    ValueError
        If ``return_X`` is not "obs" or "fine". This is checked before any
        data is generated. Also raised if the X generator did not return
        ``X_obs_noisy`` when return_X="obs".
    KeyError
        If a required config key is missing.
    """
    # Fail fast on a bad mode instead of after generating all of X.
    if return_X not in ("obs", "fine"):
        raise ValueError("return_X must be 'obs' or 'fine'.")

    rng = np.random.default_rng(seed)

    # Hoist config lookups used more than once; a missing key now surfaces
    # before any generation work is done.
    g_cfg = cfg["g"]
    case_id = cfg["betat"]["params"]["case_id"]
    target_snr = cfg["snr"]["target"]

    # 1) X. Observation noise is always requested (add_noise_obs=True), so
    #    X_obs_noisy is expected in the output for either return_X mode.
    Xdict = gen_X_cosine(
        n_samples=cfg["n_samples"],
        fine_t=fine_t,
        observed_t=observed_t,
        num_basis=cfg["num_basis"],
        z=z,
        rng=rng,
        add_noise_obs=True,
        target_snr_obs=cfg.get("x_target_snr", 10.0),
    )
    X_fine = Xdict["X_fine"]
    if return_X == "fine":
        X = X_fine
    else:
        X = Xdict.get("X_obs_noisy")
        if X is None:
            # add_noise_obs is hard-coded above, so the only caller-side
            # remedy is to switch to the fine-grid X.
            raise ValueError(
                "X_obs_noisy not found in gen_X_cosine output "
                "(add_noise_obs=True was requested); use return_X='fine'."
            )

    # 2) beta(t) plus the true intervals reported by the case definition.
    betat, bet_info = beta_paper_case(fine_t, case_id=case_id)
    intervals_true = bet_info["intervals_true"]

    # 3) y is always built from the fine-grid X, whichever X is returned.
    gout = generate_y_from_g(
        X_fine=X_fine,
        betat=betat,
        fine_t=fine_t,
        g_mode=g_cfg["mode"],
        g_kwargs=g_cfg.get("params", {}),
        standardize_eta=g_cfg.get("standardize_eta", False),
        center_g_output=g_cfg.get("center_g_output", False),
        clip_eta=g_cfg.get("clip_eta", None),
        target_snr=target_snr,
        rng=rng,
    )
    y_clean = gout["y_clean"]
    y_noisy = gout["y_noisy"]

    # Empirical SNR of the realized sample. The epsilon guards against a
    # zero-variance noise vector.
    sig_var = float(np.var(y_clean))
    noise_var = float(np.var(y_noisy - y_clean))
    snr_actual = float(sig_var / (noise_var + 1e-12))

    meta = {
        "seed": int(seed),
        "scenario": {
            "case_id": int(case_id),
            "g_mode": str(g_cfg["mode"]),
            "snr_target": float(target_snr),
            "n": int(cfg["n_samples"]),
            "k": int(cfg["num_basis"]),
        },
        "snr_actual": snr_actual,
        "intervals_true": [(float(a), float(b)) for a, b in intervals_true],
    }
    return X, y_noisy.astype(float), meta
