import numpy as np
from typing import Optional, Dict, Any, Tuple, Union
# ----------------------------
# 1) Helpers: bandwidth + MC box
# ----------------------------

def _isotropic_scott_bandwidth(pooled: np.ndarray) -> float:
    """
    Isotropic Scott's rule bandwidth for Gaussian KDE:
        h = s * n^{-1/(d+4)},  s = sqrt(trace(cov)/d)
    """
    pooled = np.asarray(pooled, dtype=float)
    n, d = pooled.shape
    if n < 2:
        raise ValueError("Need at least 2 pooled points to estimate bandwidth.")
    centered = pooled - pooled.mean(axis=0, keepdims=True)
    cov_trace = np.sum(centered * centered) / max(n - 1, 1)
    s = np.sqrt(cov_trace / d)
    h = s * (n ** (-1.0 / (d + 4.0)))
    if not np.isfinite(h) or h <= 0:
        raise ValueError("Bandwidth selection failed; check your data scaling.")
    return float(h)


def _sample_uniform_box(
    pooled: np.ndarray,
    num_mc: int,
    *,
    margin: float,
    h: float,
    rng: np.random.Generator,
) -> Tuple[np.ndarray, float, np.ndarray, np.ndarray]:
    """
    Sample Z ~ Uniform(Ω), where Ω is an expanded axis-aligned box around pooled points.
    Expansion uses `margin * max(range, h)` per dimension.
    Returns (Z, volume, lo, hi).
    """
    pooled = np.asarray(pooled, dtype=float)
    lo = pooled.min(axis=0)
    hi = pooled.max(axis=0)
    span = hi - lo
    expand = margin * np.maximum(span, h)
    lo2 = lo - expand
    hi2 = hi + expand
    side = hi2 - lo2
    volume = float(np.prod(side))
    Z = lo2 + rng.random((int(num_mc), pooled.shape[1])) * side
    return Z, volume, lo2, hi2


# ----------------------------
# 2) KDE eval at MC points
# ----------------------------

def _kde_eval_gaussian_isotropic_chunked(
    X: np.ndarray,
    Z: np.ndarray,
    *,
    h: float,
    mc_batch: int = 1024,
) -> np.ndarray:
    """
    Evaluate isotropic Gaussian KDE from samples X at query points Z.

      fhat(z) = (1/n) * sum_i  (2π)^(-d/2) h^{-d} exp(-||z-x_i||^2/(2h^2))
    """
    X = np.asarray(X, dtype=float)
    Z = np.asarray(Z, dtype=float)
    n = X.shape[0]
    M, d = Z.shape
    if n == 0:
        return np.zeros(M, dtype=float)

    inv_2h2 = 1.0 / (2.0 * h * h)
    norm_const = (2.0 * np.pi) ** (-0.5 * d) * (h ** (-d))

    X2 = np.sum(X * X, axis=1)[None, :]  # (1,n)
    XT = X.T

    out = np.empty(M, dtype=float)
    for j0 in range(0, M, mc_batch):
        j1 = min(j0 + mc_batch, M)
        Zb = Z[j0:j1]                              # (b,d)
        Z2 = np.sum(Zb * Zb, axis=1)[:, None]      # (b,1)
        d2 = Z2 + X2 - 2.0 * (Zb @ XT)             # (b,n)
        out[j0:j1] = norm_const * np.exp(-inv_2h2 * d2).mean(axis=1)
    return out


# ----------------------------
# 3) Threshold selection on training data
# ----------------------------

def select_threshold_kde_cusum_l2(
    data,
    N_train: int,
    factor: float,
    num_mc: int,
    *,
    sigma: Optional[float] = None,
    margin: float = 3.0,
    min_seg_len: int = 20,
    s_step: int = 5,
    s_window: Optional[int] = None,
    mc_batch: int = 1024,
    max_points_per_time: Optional[int] = None,
    seed: Optional[int] = 0,
    threshold_quantile: float = 1.0,  # 1.0 means "max". Try 0.995 for robustness.
) -> Dict[str, Any]:
    """
    Calibrate the KDE-CUSUM alarm threshold on the training prefix data[0:N_train].

    For every training time i a Gaussian KDE is built from data[i] and evaluated
    on a fixed set of Monte-Carlo points Z drawn uniformly from a padded bounding
    box of the pooled training points. For each t in [2*min_seg_len, N_train] the
    scan statistic is

        T(t) = max_s || mean_{i<s} fhat_i - mean_{s<=i<t} fhat_i ||_{L2},

    with the L2 norm approximated by sqrt(volume * mean_Z(.^2)) and s ranging
    over min_seg_len .. t - min_seg_len in steps of `s_step`.

    NOTE: This function does NOT detect; it only calibrates the threshold.

    Parameters
    ----------
    data : sequence indexable by integer time
        data[i] is used as an (n_i, dim) array of points. Entries whose `size`
        attribute is missing or zero (e.g. None, an empty array) are treated as
        containing no points and contribute a zero density.
    N_train : number of leading time steps used for calibration (>= 2).
    factor : multiplier applied to `threshold_base` to obtain `threshold`.
    num_mc : number of Monte-Carlo points (>= 1).
    sigma : KDE bandwidth; when None, an isotropic Scott's-rule bandwidth is
        estimated from the pooled training points.
    margin : box padding factor, see `_sample_uniform_box`.
    min_seg_len : minimum number of time steps on each side of a split
        (values below 1 are treated as 1).
    s_step : stride over candidate splits (values below 1 are treated as 1).
    s_window : when given, only splits with s >= t - s_window are scanned.
    mc_batch : chunk size for KDE evaluation (values below 1 are treated as 1).
    max_points_per_time : when given, time steps with more points are randomly
        subsampled (without replacement) to this many points.
    seed : seed for the generator used for Z and for subsampling.
    threshold_quantile : values >= 1.0 take the maximum of the training
        statistics; otherwise the given quantile is used.

    Returns
    -------
    dict with keys
      - threshold_base, threshold (= factor * threshold_base)
      - sigma, Z, volume, box_lo, box_hi
      - phi_train (N_train x num_mc), prefix_train (N_train+1 x num_mc)
      - settings (the effective, normalised parameter values)

    Raises
    ------
    ValueError
        If N_train < 2, num_mc < 1, sigma is not a positive finite number, or
        the training data contains no points.

    Warns
    -----
    RuntimeWarning
        If no scan statistic could be evaluated (for instance when
        N_train < 2 * min_seg_len); `threshold_base` is then 0.0.
    """
    import warnings  # local import: only needed for the degenerate-calibration warning

    rng = np.random.default_rng(seed)
    N_train = int(N_train)
    if N_train < 2:
        raise ValueError("N_train must be >= 2.")
    M = int(num_mc)
    if M < 1:
        # With zero MC points every L2 estimate is the mean of an empty slice (NaN).
        raise ValueError("num_mc must be >= 1.")

    # Normalise scan parameters once. min_seg_len < 1 would put s = 0 in the
    # candidate set and divide by zero, silently yielding a NaN threshold.
    min_seg_len = max(1, int(min_seg_len))
    s_step = max(1, int(s_step))
    mc_batch = max(1, int(mc_batch))
    if s_window is not None:
        s_window = int(s_window)

    # Pool training points for the bandwidth and the MC box.
    pooled_parts = []
    for i in range(N_train):
        Xi = data[i]
        if getattr(Xi, "size", 0):
            pooled_parts.append(np.asarray(Xi, dtype=float))
    if not pooled_parts:
        raise ValueError("Training data contains no points (all empty).")
    pooled = np.concatenate(pooled_parts, axis=0)
    dim = pooled.shape[1]

    if sigma is None:
        sigma = _isotropic_scott_bandwidth(pooled)
    sigma = float(sigma)
    if not np.isfinite(sigma) or sigma <= 0.0:
        raise ValueError("sigma must be a positive, finite bandwidth.")

    Z, volume, box_lo, box_hi = _sample_uniform_box(
        pooled, M, margin=float(margin), h=sigma, rng=rng
    )

    def _maybe_subsample(Xi: np.ndarray) -> np.ndarray:
        # Cap the number of points per time step to bound KDE cost.
        Xi = np.asarray(Xi, dtype=float)
        if max_points_per_time is None or Xi.shape[0] <= max_points_per_time:
            return Xi
        idx = rng.choice(Xi.shape[0], size=int(max_points_per_time), replace=False)
        return Xi[idx]

    # Precompute phi_i = fhat_i(Z) for i=0..N_train-1; empty time steps stay zero.
    phi = np.zeros((N_train, M), dtype=float)
    for i in range(N_train):
        Xi = data[i]
        if getattr(Xi, "size", 0):
            phi[i] = _kde_eval_gaussian_isotropic_chunked(
                _maybe_subsample(Xi), Z, h=sigma, mc_batch=mc_batch
            )

    # Prefix sums for fast segment averages:
    # prefix[t] = sum_{i=0..t-1} phi_i, so prefix has shape (N_train+1, M).
    prefix = np.zeros((N_train + 1, M), dtype=float)
    prefix[1:] = np.cumsum(phi, axis=0)

    T_vals = []
    # t is the count of time points used so far: t in [2*min_seg_len, N_train].
    for t in range(2 * min_seg_len, N_train + 1):
        s_min = min_seg_len
        s_max = t - min_seg_len
        if s_window is not None:
            s_min = max(s_min, t - s_window)

        s_vals = np.arange(s_min, s_max + 1, s_step, dtype=int)
        if s_vals.size == 0:
            continue

        sum_1s = prefix[s_vals]                   # sum_{i=0..s-1}
        sum_st = prefix[t][None, :] - sum_1s      # sum_{i=s..t-1}

        # Difference of the two segment-mean densities at every MC point.
        g = sum_1s / s_vals[:, None] - sum_st / (t - s_vals)[:, None]

        # Monte-Carlo estimate of the squared L2 norm: volume * E_Z[g(Z)^2].
        l2_sq = volume * np.mean(g * g, axis=1)
        T_vals.append(float(np.sqrt(np.max(l2_sq))))

    if not T_vals:
        warnings.warn(
            "No training scan statistic could be evaluated (N_train is smaller "
            "than 2 * min_seg_len, or s_window excludes every split); "
            "threshold_base is set to 0.0.",
            RuntimeWarning,
            stacklevel=2,
        )
        threshold_base = 0.0
    else:
        T_arr = np.asarray(T_vals, dtype=float)
        if threshold_quantile >= 1.0:
            threshold_base = float(np.max(T_arr))
        else:
            threshold_base = float(np.quantile(T_arr, float(threshold_quantile)))

    threshold = float(factor) * float(threshold_base)

    return {
        "threshold_base": float(threshold_base),
        "threshold": float(threshold),
        "sigma": float(sigma),
        "Z": Z,
        "volume": float(volume),
        "box_lo": box_lo,
        "box_hi": box_hi,
        "phi_train": phi,
        "prefix_train": prefix,
        "settings": {
            "N_train": N_train,
            "factor": float(factor),
            "num_mc": M,
            "margin": float(margin),
            "min_seg_len": int(min_seg_len),
            "s_step": int(s_step),
            "s_window": None if s_window is None else int(s_window),
            "mc_batch": int(mc_batch),
            "max_points_per_time": None if max_points_per_time is None else int(max_points_per_time),
            "seed": seed,
            "dim": int(dim),
            "threshold_quantile": float(threshold_quantile),
        },
    }


# ----------------------------
# 4) Online detection (OUTPUT ONLY A TIME POINT)
# ----------------------------

def detect_change_kde_cusum_l2(
    data,
    N_train: int,
    factor: Union[float, np.ndarray],
    num_mc: int,
    *,
    sigma: Optional[float] = None,
    margin: float = 3.0,
    min_seg_len: int = 20,
    s_step: int = 5,
    s_window: Optional[int] = None,
    mc_batch: int = 1024,
    max_points_per_time: Optional[int] = None,
    seed: Optional[int] = 0,
    t_step: int = 1,
    threshold_quantile: float = 1.0,  # 1.0=max; try 0.995 for fewer false alarms
):
    """
    Online KDE-CUSUM detector.

    The threshold is calibrated on data[0:N_train] via
    `select_threshold_kde_cusum_l2`; the same scan statistic T(t) is then
    evaluated for t = N_train+1, N_train+1+t_step, ... and an alarm is raised
    the first time T(t) > factor * threshold_base (strict inequality).

    If `factor` is a scalar, returns a single integer alarm time.
    If `factor` is an array-like, returns an integer array `result` with the same shape,
    where `result[i]` is the alarm time corresponding to `factor[i]`.

    The alarm time is t - 1, where t is the number of time points consumed when
    the statistic first exceeded the threshold (i.e. the 0-based index of the
    last observation used).

    Convention: returns N_total (len(data)) when no change is detected within the horizon.

    NOTE: This returns the alarm time, not the split location s_hat.

    Parameters
    ----------
    data : sequence supporting len() and integer indexing
        data[i] is used as an (n_i, dim) array of points. Entries whose `size`
        attribute is missing or zero are treated as containing no points.
    N_train : number of leading time steps used for calibration.
    factor : non-negative scalar or array-like of threshold multipliers.
    num_mc, sigma, margin, s_window, max_points_per_time, seed,
    threshold_quantile : forwarded to the calibration step.
    min_seg_len, s_step, t_step, mc_batch : values below 1 are treated as 1.

    Raises
    ------
    ValueError
        If N_train is outside [2, len(data)] or any factor is negative.
    """
    N_total = len(data)
    N_train = int(N_train)
    if N_train < 2 or N_train > N_total:
        raise ValueError("Require 2 <= N_train <= len(data).")

    # Normalise scan parameters up front so calibration and detection use the
    # same effective values. min_seg_len < 1 would put s = 0 (and s = t) in the
    # candidate set and divide by zero.
    t_step = max(1, int(t_step))
    s_step = max(1, int(s_step))
    min_seg_len = max(1, int(min_seg_len))
    mc_batch = max(1, int(mc_batch))
    if s_window is not None:
        s_window = int(s_window)

    # --- Normalize factor input ---
    fac_arr = np.asarray(factor, dtype=float)
    is_scalar = (fac_arr.ndim == 0)
    fac_flat = fac_arr.reshape(-1)  # a scalar becomes a one-element vector
    if np.any(fac_flat < 0.0):
        if is_scalar:
            raise ValueError("factor must be >= 0.")
        raise ValueError("All factor values must be >= 0.")

    # --- Calibrate once to get threshold_base (independent of factor scaling) ---
    calib = select_threshold_kde_cusum_l2(
        data=data,
        N_train=N_train,
        factor=1.0,  # compute threshold_base once; scale later by factor(s)
        num_mc=num_mc,
        sigma=sigma,
        margin=margin,
        min_seg_len=min_seg_len,
        s_step=s_step,
        s_window=s_window,
        mc_batch=mc_batch,
        max_points_per_time=max_points_per_time,
        seed=seed,
        threshold_quantile=threshold_quantile,
    )

    threshold_base = float(calib["threshold_base"])
    sigma_used = float(calib["sigma"])
    Z = calib["Z"]
    volume = float(calib["volume"])
    M = int(num_mc)

    # Sort thresholds so alarms can be filled in one pass: since T(t) only has
    # to exceed a threshold once, thresholds are resolved in increasing order.
    thresholds = fac_flat * threshold_base
    order = np.argsort(thresholds)
    thr_sorted = thresholds[order]
    res_sorted = np.full(thr_sorted.shape[0], N_total, dtype=int)
    next_to_set = 0  # first index in thr_sorted not yet assigned

    # Fresh generator seeded with `seed`, used only for post-training subsampling.
    rng = np.random.default_rng(seed)

    def _maybe_subsample(Xi: np.ndarray) -> np.ndarray:
        # Cap the number of points per time step to bound KDE cost.
        Xi = np.asarray(Xi, dtype=float)
        if max_points_per_time is None or Xi.shape[0] <= max_points_per_time:
            return Xi
        idx = rng.choice(Xi.shape[0], size=int(max_points_per_time), replace=False)
        return Xi[idx]

    # Prefix sums: prefix[t] = sum_{i=0..t-1} phi_i, where phi_i = fhat_i(Z).
    # Rows are filled lazily as the scan advances, so time steps after the
    # final alarm are never evaluated. `filled` counts the time steps already
    # accumulated (rows 0..filled of `prefix` are valid).
    prefix = np.zeros((N_total + 1, M), dtype=float)

    # Reuse training phi from calibration (saves time); otherwise start from 0.
    phi_train = calib.get("phi_train", None)
    if phi_train is not None and np.asarray(phi_train).shape == (N_train, M):
        prefix[1:N_train + 1] = np.cumsum(np.asarray(phi_train, dtype=float), axis=0)
        filled = N_train
    else:
        filled = 0

    # --- Scan online: t is count of time points used so far; alarm time is (t-1) ---
    for t in range(N_train + 1, N_total + 1, t_step):
        if next_to_set >= thr_sorted.size:
            break  # every factor already has its alarm time
        if t < 2 * min_seg_len:
            continue

        # Bring the prefix sums up to time t (sequential, like a cumulative sum).
        while filled < t:
            Xi = data[filled]
            if getattr(Xi, "size", 0):
                phi_i = _kde_eval_gaussian_isotropic_chunked(
                    _maybe_subsample(Xi), Z, h=sigma_used, mc_batch=mc_batch
                )
                prefix[filled + 1] = prefix[filled] + phi_i
            else:
                prefix[filled + 1] = prefix[filled]  # empty step: zero density
            filled += 1

        s_min = min_seg_len
        s_max = t - min_seg_len
        if s_window is not None:
            s_min = max(s_min, t - s_window)

        s_vals = np.arange(s_min, s_max + 1, s_step, dtype=int)
        if s_vals.size == 0:
            continue

        sum_1s = prefix[s_vals]                  # sum_{i=0..s-1}
        sum_st = prefix[t][None, :] - sum_1s     # sum_{i=s..t-1}

        # Difference of the two segment-mean densities at every MC point.
        g = sum_1s / s_vals[:, None] - sum_st / (t - s_vals)[:, None]

        # Monte-Carlo estimate of the squared L2 norm: volume * E_Z[g(Z)^2].
        l2_sq = volume * np.mean(g * g, axis=1)
        Tt = float(np.sqrt(np.max(l2_sq)))

        # A NaN statistic never satisfies Tt > threshold, but searchsorted
        # orders NaN after every finite value and would fire all alarms.
        if np.isnan(Tt):
            continue

        # Strict rule: alarm if Tt > threshold  <=> threshold < Tt, so count
        # the thresholds strictly below Tt.
        idx = int(np.searchsorted(thr_sorted, Tt, side="left"))
        if idx > next_to_set:
            res_sorted[next_to_set:idx] = t - 1
            next_to_set = idx

    # Unsort back to the original factor order.
    res_flat = np.empty_like(res_sorted)
    res_flat[order] = res_sorted
    if is_scalar:
        return int(res_flat[0])
    return res_flat.reshape(fac_arr.shape)
