from typing import Dict, Any, Optional
import numpy as np

def phi(k: int, t: np.ndarray) -> np.ndarray:
    """
    Evaluate the k-th (1-indexed) cosine basis function at the points ``t``.

    The basis on [0, 1] is
      phi_1(t) = 1,
      phi_k(t) = sqrt(2) * cos((k - 1) * pi * t)   for k >= 2.

    ``t`` may be a scalar or any array-like; the result is a float array with
    the same shape as ``np.asarray(t)``.
    """
    grid = np.asarray(t, dtype=float)
    if k != 1:
        # Angular frequency of the (k-1)-th cosine harmonic.
        freq = (k - 1) * np.pi
        return np.sqrt(2.0) * np.cos(freq * grid)
    # The first basis function is the constant 1.
    return np.ones_like(grid)


def add_noise_snr_matrix(
    X: np.ndarray,
    target_snr: float,
    *,
    per_sample: bool = False,
    rng: Optional[np.random.Generator] = None,
) -> Dict[str, Any]:
    """
    Add i.i.d. Gaussian noise to the rows of X at a target signal-to-noise ratio.

    SNR is defined as Var(signal) / Var(noise). The signal variance of each row
    is its population variance (ddof=0) about the row mean.

    - per_sample=False (default): a single *global* noise std,
      sqrt(mean_i Var(row_i) / target_snr), is used for every row. The SNR
      therefore matches target_snr only on average across rows.
    - per_sample=True: row i gets noise std sqrt(Var(row_i) / target_snr).

    Parameters
    ----------
    X : 2-D array-like of shape (n, T); each row is one signal.
    target_snr : desired Var(signal) / Var(noise); must be > 0.
    per_sample : use row-specific instead of global noise std.
    rng : ``np.random.Generator``, integer seed (or anything else accepted by
        ``np.random.default_rng``), or None.
        - None means a fixed seed 0, so results are reproducible by default.
        - A Generator is used as-is, and its state is advanced.

    Returns
    -------
    dict with
      "X_noisy"   : X + noise, same shape as X,
      "noise_std" : (n,) noise std used for each row,
      "snr_vec"   : (n,) realized SNR per row, Var(signal_i) / Var(noise_i),
                    with the realized noise variance floored at 1e-12,
      "snr_mean"  : float mean of snr_vec.

    Raises
    ------
    ValueError
        If target_snr is not > 0 (NaN included).
    TypeError
        If rng is of a type ``np.random.default_rng`` does not accept.
    """
    target_snr = float(target_snr)
    # `not x > 0` also rejects NaN, which a plain `x <= 0` test lets through.
    if not target_snr > 0:
        raise ValueError("target_snr must be > 0.")

    # Honor integer seeds instead of silently replacing them with seed 0.
    # default_rng returns a Generator argument unaltered.
    rng = np.random.default_rng(0 if rng is None else rng)

    X = np.asarray(X, dtype=float)
    # np.var centers each row on its own mean, so no explicit centering is needed.
    sig_var = np.var(X, axis=1, ddof=0)  # (n,)

    if per_sample:
        noise_std = np.sqrt(sig_var / target_snr)
    else:
        global_std = float(np.sqrt(sig_var.mean() / target_snr))
        noise_std = np.full(X.shape[0], global_std, dtype=float)

    # Broadcast each row's std across its columns: noise[i, j] ~ N(0, noise_std[i]**2).
    noise = rng.normal(0.0, noise_std[:, None], size=X.shape)
    X_noisy = X + noise

    # Realized SNR per row. The floor avoids division by zero when a row
    # received zero noise (e.g. a constant row with per_sample=True).
    n_var = np.var(noise, axis=1, ddof=0)
    snr_vec = sig_var / np.maximum(n_var, 1e-12)

    return dict(
        X_noisy=X_noisy,
        noise_std=noise_std,
        snr_vec=snr_vec,
        snr_mean=float(np.mean(snr_vec)),
    )


def gen_X_cosine(
    n_samples: int,
    fine_t: np.ndarray,
    observed_t: np.ndarray,
    num_basis: int,
    z: np.ndarray,
    *,
    rng: Optional[np.random.Generator] = None,
    add_noise_obs: bool = True,
    target_snr_obs: float = 10.0,
    per_sample_noise: bool = False,
) -> Dict[str, Any]:
    """
    Generate random curves X(t) = sum_k C_k phi_k(t) on a fine grid and an observed grid.

    The coefficients are C_k = R_k * z_k, where R_k ~ Unif(-sqrt(3), sqrt(3)).
    R_k is drawn independently for every sample and basis index, so
    Var(R_k) = 1 and Var(C_k) = z_k**2.

    Parameters
    ----------
    n_samples : number of curves (rows) to draw.
    fine_t, observed_t : evaluation points of the fine and observed grids.
    num_basis : number K of basis functions phi_1..phi_K. Must be an integer
        >= 1; Python ints and NumPy integer scalars are accepted.
    z : coefficient scales, shape (K,).
    rng : ``np.random.Generator``, integer seed (or anything else accepted by
        ``np.random.default_rng``), or None.
        - None means a fixed seed 0.
        - A Generator is used as-is and is also passed to the noise step, so a
          single stream drives both coefficients and noise.
    add_noise_obs : if True, add Gaussian noise to the observed curves via
        ``add_noise_snr_matrix``.
    target_snr_obs, per_sample_noise : forwarded to ``add_noise_snr_matrix``
        as ``target_snr`` and ``per_sample``.

    Returns
    -------
    dict with
      "X_fine" : noiseless curves on the fine grid, (n_samples, len(fine_t))
                 for 1-D grids,
      "X_obs"  : noiseless curves on the observed grid, (n_samples, len(observed_t))
                 for 1-D grids,
    and, only when add_noise_obs is True,
      "X_obs_noisy", "X_obs_noise_std", "snr_obs_actual_vec",
      "snr_obs_actual_mean" (the outputs of ``add_noise_snr_matrix``).

    Raises
    ------
    ValueError
        If num_basis is not an integer >= 1 or z does not have shape (num_basis,).
    """
    # Honor integer seeds instead of silently replacing them with seed 0.
    # default_rng returns a Generator argument unaltered.
    rng = np.random.default_rng(0 if rng is None else rng)

    fine_t = np.asarray(fine_t, dtype=float)
    observed_t = np.asarray(observed_t, dtype=float)

    # np.integer admits NumPy scalar ints (e.g. from shapes or arange), which
    # isinstance(..., int) alone would reject.
    if not (isinstance(num_basis, (int, np.integer)) and num_basis >= 1):
        raise ValueError("num_basis must be int >= 1")
    num_basis = int(num_basis)
    z = np.asarray(z, dtype=float)
    if z.shape != (num_basis,):
        raise ValueError(f"z shape must be ({num_basis},), got {z.shape}")

    basis_idx = range(1, num_basis + 1)
    Phi_fine = np.stack([phi(k, fine_t) for k in basis_idx], axis=1)  # (Tf, K)
    Phi_obs = np.stack([phi(k, observed_t) for k in basis_idx], axis=1)  # (To, K)

    R = rng.uniform(-np.sqrt(3.0), np.sqrt(3.0), size=(n_samples, num_basis))  # Var=1
    C = R * z[None, :]  # (n, K): scale column k by z_k

    X_fine = C @ Phi_fine.T  # (n, Tf)
    X_obs = C @ Phi_obs.T    # (n, To)

    out: Dict[str, Any] = dict(X_fine=X_fine, X_obs=X_obs)

    if add_noise_obs:
        # The same Generator is reused, so noise continues the coefficient stream.
        res = add_noise_snr_matrix(X_obs, target_snr_obs, per_sample=per_sample_noise, rng=rng)
        out["X_obs_noisy"] = res["X_noisy"]
        out["X_obs_noise_std"] = res["noise_std"]
        out["snr_obs_actual_vec"] = res["snr_vec"]
        out["snr_obs_actual_mean"] = res["snr_mean"]

    return out
