import numpy as np
from typing import Optional, Literal, Union

# Which timestamp of the firing window is reported as the change-point.
ReturnTime = Literal["start", "mid", "end"]


class MMD_detection:
    """
    Sliding-window change detector built on an exact RBF-kernel MMD statistic.

    The first ``num_block`` consecutive windows of length ``B`` are pooled into
    baseline blocks. Every later window is compared against each valid baseline
    block (a block with at least two points) with the statistic

        xx_k + yy - 2 * mean(k(X_k, Y))

    where ``xx_k`` / ``yy`` are the off-diagonal within-sample kernel means.
    The per-block values are aggregated ("mean" or "sum") and clamped at 0.
    The remaining training windows provide an empirical quantile which,
    multiplied by ``factor``, becomes the detection threshold.

    Stability choices:
      1) Deterministic subsampling per window index (threshold does not jitter).
      2) Training windows default to non-overlapping (``train_step = B``).
      3) Default quantile is 0.99.
      4) The quantile is clamped to ``(n - 1) / n`` so it never degenerates to
         the maximum when few training statistics are available.
      5) Optional "consecutive hits" requirement (default 1 = off).

    Parameters
    ----------
    data : iterable
        One entry per time step; each entry is array-like (ndarray or nested
        list/tuple). ``data[0]`` must be 2D ``(n_t, dim)``. Empty entries and
        ``None`` entries are ignored when pooling.
    N_total : int
        Number of time steps scanned; also the "no change detected" return value.
    N_train : int
        Number of leading time steps used for baseline + threshold calibration.
    sigma : float or None
        RBF bandwidth. ``None`` estimates it with a median heuristic on a random
        subsample of the pooled training points.
    B : int
        Window length in time steps.
    num_block : int
        Number of baseline blocks (``num_block * B <= N_train`` is required).
    factor : float or 1D array
        Threshold multiplier(s); arrays must be finite and nondecreasing.
    max_points : int, optional
        If given, windows with at least this many points are subsampled (without
        replacement) down to exactly ``max_points``. Smaller windows are kept as is.
    seed : int, optional
        Seed for ``self.rng`` and base seed of the per-window keyed generators.
    threshold_quantile : float
        Quantile in (0, 1) of the training statistics used as base threshold.
    sample_with_replacement : bool
        Stored for backward compatibility; it has no effect because windows are
        never upsampled.
    stat_aggregation : {"mean", "sum"}
        How the per-baseline-block statistics are combined.
    train_step, detect_step : int, optional
        Strides of the training / detection scans (defaults: ``B`` and 1).
    chunk_size : int
        Row-chunk size used when summing kernel matrices.
    deterministic_subsample : bool
        Use a generator keyed on the window index instead of the shared ``self.rng``.
    min_train_windows : int
        If fewer training statistics are available, the thresholds are set to
        ``inf`` and nothing is ever detected.
    consecutive : int
        Number of consecutive threshold exceedances required to report a change.

    Raises
    ------
    ValueError
        On invalid arguments (see messages), or when ``sigma`` is None and fewer
        than two training points are available.
    """

    def __init__(
        self,
        data,
        N_total: int,
        N_train: int,
        sigma: Optional[float],
        B: int,
        num_block: int,
        factor: Union[float, np.ndarray],
        *,
        max_points: Optional[int] = None,
        seed: Optional[int] = None,
        threshold_quantile: float = 0.99,      # <-- less extreme default
        sample_with_replacement: bool = True,
        stat_aggregation: str = "mean",        # "mean" or "sum"
        train_step: Optional[int] = None,      # default: B (non-overlapping for stability)
        detect_step: Optional[int] = 1,        # default 1
        chunk_size: int = 2048,
        deterministic_subsample: bool = True,  # <-- key stability switch
        min_train_windows: int = 20,           # <-- if too few stats, be conservative
        consecutive: int = 1,                  # <-- 1 = original behavior
    ):
        self.data = list(data)
        self.N = int(N_total)
        self.N_train = int(N_train)
        self.B = int(B)
        self.num_block = int(num_block)
        self.max_points = None if max_points is None else int(max_points)

        self.rng = np.random.default_rng(seed)
        self._base_seed = 0 if seed is None else (int(seed) & 0xFFFFFFFF)

        # Allow a single float factor or an increasing 1D array of factors.
        factors = np.asarray(factor, dtype=float).ravel()
        if factors.size == 0:
            raise ValueError("factor must be a float or a non-empty 1D array.")
        if not np.all(np.isfinite(factors)):
            raise ValueError("factor(s) must be finite.")
        if factors.size > 1 and np.any(np.diff(factors) < 0):
            raise ValueError("factor array must be sorted from small to large (nondecreasing).")
        self.factors = factors

        if len(self.data) == 0:
            raise ValueError("data is empty.")
        if self.N <= 0:
            raise ValueError("N_total must be positive.")
        if self.N_train <= 0 or self.N_train > self.N:
            raise ValueError("N_train must be in {1,...,N_total}.")
        if self.B <= 0:
            raise ValueError("B must be positive.")
        if self.num_block <= 0:
            raise ValueError("num_block must be positive.")
        if self.num_block * self.B > self.N_train:
            raise ValueError("Need num_block*B <= N_train so baseline fits in training.")
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive.")
        self.chunk_size = int(chunk_size)

        d0 = np.asarray(self.data[0])
        if d0.ndim != 2:
            raise ValueError("data[t] must be a 2D array (n_t, dim).")
        self.dim = int(d0.shape[1])

        if not (0.0 < threshold_quantile < 1.0):
            raise ValueError("threshold_quantile must be in (0,1).")
        self.threshold_quantile = float(threshold_quantile)
        self.sample_with_replacement = bool(sample_with_replacement)

        stat_aggregation = str(stat_aggregation).lower()
        if stat_aggregation not in ("mean", "sum"):
            raise ValueError("stat_aggregation must be 'mean' or 'sum'.")
        self.stat_aggregation = stat_aggregation

        # Stability: default to non-overlapping training windows
        default_train_step = max(1, self.B)
        self.train_step = default_train_step if train_step is None else max(1, int(train_step))
        self.detect_step = 1 if detect_step is None else max(1, int(detect_step))

        self.deterministic_subsample = bool(deterministic_subsample)
        self.min_train_windows = max(1, int(min_train_windows))
        self.consecutive = max(1, int(consecutive))

        # ---------- choose sigma (median heuristic) ----------
        if sigma is None:
            sigma = self._estimate_sigma()

        self.sigma = float(sigma)
        if self.sigma <= 0:
            raise ValueError("sigma must be > 0 (or None to estimate).")
        self.gamma = 1.0 / (2.0 * self.sigma * self.sigma)

        # ---------- baseline blocks ----------
        # Negative keys keep the baseline generators distinct from the
        # window generators, which are keyed by the window start index i >= 0.
        self.X = [
            self._subsample_fixed(
                self._pool_window(self.data[(k * self.B): ((k + 1) * self.B)]),
                self.max_points,
                key=-(k + 1),
            )
            for k in range(self.num_block)
        ]
        self.valid_k = [k for k, Xk in enumerate(self.X) if Xk.shape[0] >= 2]

        # Per-block caches, always defined so attribute access never fails.
        self.X2 = {}   # k -> squared row norms, shape (m_k, 1)
        self.m = {}    # k -> number of points m_k
        self.xx = {}   # k -> off-diagonal within-block kernel mean

        if not self.valid_k:
            self._disable_detection()
            return

        for k in self.valid_k:
            Xk = self.X[k]
            mk = Xk.shape[0]
            X2k = np.sum(Xk * Xk, axis=1)[:, None]
            Kxx_sum = self._rbf_sum_chunked_BT(Xk, Xk.T, X2k, X2k.T, self.gamma)
            self.X2[k] = X2k
            self.m[k] = mk
            # The diagonal contributes k(x, x) = 1 per point; remove it.
            self.xx[k] = (Kxx_sum - mk) / (mk * (mk - 1))

        # ---------- training stats -> threshold ----------
        first = self.num_block * self.B
        stop = self.N_train - self.B + 1
        train_stats = []
        for i in range(first, stop, self.train_step):
            stat = self._window_stat(i)
            if stat is not None:
                train_stats.append(stat)

        if len(train_stats) < self.min_train_windows:
            # too few points => extreme quantiles are meaningless; be conservative
            self._disable_detection()
            return

        train_stats = np.asarray(train_stats, dtype=float)

        # Clamp q so it cannot behave like "max" when sample size is small.
        # Example: if n=100 and q=0.995, clamp to 0.99.
        n_ts = train_stats.size
        q_eff = min(self.threshold_quantile, (n_ts - 1) / n_ts)

        base = float(np.quantile(train_stats, q_eff))
        self.threshold_base = base
        self.thresholds = self.factors * base
        self.threshold = float(self.thresholds[0])

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------
    def _disable_detection(self) -> None:
        """Set every threshold to +inf so that no window can ever fire."""
        self.threshold_base = float("inf")
        # Built directly (not factors * inf) so a zero factor cannot yield NaN.
        self.thresholds = np.full(self.factors.shape, np.inf)
        self.threshold = float("inf")

    def _pool_window(self, arr_list) -> np.ndarray:
        """Stack the non-empty entries of ``arr_list`` row-wise as float64.

        Entries are converted with ``np.asarray`` *before* the emptiness check,
        so nested lists/tuples are pooled as well. ``None`` entries are skipped.
        Returns an empty ``(0, dim)`` array when nothing is left.
        """
        parts = []
        for a in arr_list:
            if a is None:
                continue
            arr = np.asarray(a, dtype=np.float64)
            if arr.size:
                parts.append(arr)
        if not parts:
            return np.empty((0, self.dim), dtype=np.float64)
        return np.concatenate(parts, axis=0)

    def _keyed_rng(self, key: int) -> np.random.Generator:
        """Return a generator whose seed depends only on the base seed and ``key``."""
        s = (self._base_seed + 1000003 * int(key)) & 0xFFFFFFFF
        return np.random.default_rng(s)

    def _subsample_fixed(self, Z: np.ndarray, n_fixed: Optional[int], key: Optional[int] = None) -> np.ndarray:
        """Return ``n_fixed`` rows of ``Z`` drawn without replacement.

        ``Z`` is returned unchanged when ``n_fixed`` is None, when ``Z`` is empty,
        or when it has fewer than ``n_fixed`` rows (no upsampling, so the
        estimator scale is not altered). With ``deterministic_subsample`` and a
        ``key``, a generator keyed on ``key`` is used instead of ``self.rng``.
        """
        if n_fixed is None:
            return Z
        n = Z.shape[0]
        if n == 0 or n < n_fixed:
            return Z

        rng = self.rng
        if self.deterministic_subsample and (key is not None):
            rng = self._keyed_rng(key)
        idx = rng.choice(n, size=n_fixed, replace=False)
        return Z[idx]

    def _rbf_sum_chunked_BT(self, A: np.ndarray, BT: np.ndarray, A2: np.ndarray, B2T: np.ndarray, gamma: float) -> float:
        """Return sum_ij exp(-gamma * ||A_i - B_j||^2), processed in row chunks.

        ``BT`` is the transposed second sample ``(dim, m_b)``; ``A2`` ``(m_a, 1)``
        and ``B2T`` ``(1, m_b)`` are the precomputed squared row norms.
        """
        CHUNK = self.chunk_size
        total = 0.0
        m = A.shape[0]
        for i0 in range(0, m, CHUNK):
            i1 = min(i0 + CHUNK, m)
            # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, built in place.
            G = A[i0:i1] @ BT
            G *= -2.0
            G += A2[i0:i1]
            G += B2T
            np.maximum(G, 0.0, out=G)   # clamp negatives from roundoff
            G *= -gamma
            np.exp(G, out=G)
            total += G.sum(dtype=np.float64)
        return float(total)

    def _estimate_sigma(self) -> float:
        """Median-heuristic bandwidth from random pairs of pooled training points.

        Uses at most 2000 points and at most 20000 random pairs drawn from
        ``self.rng``; zero distances are discarded. Returns
        ``sqrt(median(d^2) / 2)`` (with the median floored at 1e-12, and 1.0
        used when no positive distance was sampled).
        """
        pooled = self._pool_window(self.data[: self.N_train])
        if pooled.shape[0] < 2:
            raise ValueError("Not enough pooled training points to estimate sigma.")
        M = min(2000, pooled.shape[0])
        idx = self.rng.choice(pooled.shape[0], size=M, replace=False)
        S = pooled[idx]
        P = min(20000, M * (M - 1) // 2)
        ia = self.rng.integers(0, M, size=P)
        ib = self.rng.integers(0, M, size=P)
        d2 = np.sum((S[ia] - S[ib]) ** 2, axis=1)
        d2 = d2[d2 > 0]
        med = float(np.median(d2)) if d2.size else 1.0
        return float(np.sqrt(max(med, 1e-12) / 2.0))

    def _window_stat(self, i: int) -> Optional[float]:
        """Aggregated MMD statistic of the window ``data[i : i + B]``.

        Returns None when the (subsampled) window holds fewer than two points;
        otherwise the "mean"/"sum" over valid baseline blocks, clamped at 0.
        """
        Y = self._pool_window(self.data[i: (i + self.B)])
        Y = self._subsample_fixed(Y, self.max_points, key=i)  # deterministic per window
        n = Y.shape[0]
        if n < 2:
            return None

        Y2 = np.sum(Y * Y, axis=1)[:, None]
        YT = Y.T
        Y2T = Y2.T
        gamma = self.gamma
        rbf_sum = self._rbf_sum_chunked_BT

        yy = (rbf_sum(Y, YT, Y2, Y2T, gamma) - n) / (n * (n - 1))
        vals = [
            self.xx[k] + yy - 2.0 * (rbf_sum(self.X[k], YT, self.X2[k], Y2T, gamma) / (self.m[k] * n))
            for k in self.valid_k
        ]
        rec = float(np.sum(vals)) if self.stat_aggregation == "sum" else float(np.mean(vals))
        return max(rec, 0.0)

    def _event_time(self, i: int, return_time: ReturnTime) -> int:
        """Map the start index ``i`` of the firing window to the reported time.

        "start" -> ``i``; "mid" -> ``i + B // 2``; any other value is treated as
        "end" -> ``i + B``. Results are capped at ``N_total``.
        """
        if return_time == "start":
            return i
        if return_time == "mid":
            return min(i + self.B // 2, self.N)
        return min(i + self.B, self.N)

    # ------------------------------------------------------------------
    # detection
    # ------------------------------------------------------------------
    def detect_many(
        self,
        factors: np.ndarray,
        *,
        step: Optional[int] = None,
        return_time: ReturnTime = "end",
    ) -> np.ndarray:
        """
        Compute change-point estimates for a *sorted* array of factors.

        Each factor ``f`` uses the threshold ``f * threshold_base`` and keeps its
        own consecutive-hit counter. Windows with fewer than two points are
        skipped without resetting the counters.

        Returns an int array with one entry per factor; an entry equals N_total
        when no change is detected under that threshold. An empty ``factors``
        gives an empty array.

        Raises ValueError if ``factors`` is not finite or not nondecreasing.
        """
        factors = np.asarray(factors, dtype=float).ravel()
        if factors.size == 0:
            return np.empty((0,), dtype=int)
        if not np.all(np.isfinite(factors)):
            raise ValueError("factors must be finite.")
        if factors.size > 1 and np.any(np.diff(factors) < 0):
            raise ValueError("factors must be sorted from small to large (nondecreasing).")

        base = float(self.threshold_base)
        if not self.valid_k or not np.isfinite(base):
            return np.full(factors.shape, self.N, dtype=int)

        thresholds = factors * base
        out = np.full(thresholds.shape, self.N, dtype=int)
        hits = np.zeros(thresholds.shape, dtype=np.int32)

        step = self.detect_step if step is None else max(1, int(step))

        for i in range(self.N_train, self.N - self.B + 1, step):
            rec = self._window_stat(i)
            if rec is None:
                continue

            exceeded = rec > thresholds
            hits[exceeded] += 1
            hits[~exceeded] = 0

            # out == N marks "not yet detected" for that factor.
            newly = (out == self.N) & (hits >= self.consecutive)
            if np.any(newly):
                out[newly] = self._event_time(i, return_time)
                if np.all(out != self.N):
                    break

        return out

    def detect(self, *, step: Optional[int] = None, return_time: ReturnTime = "end"):
        """
        Scan the windows after the training period and report the first change.

        With a single factor, returns an int: the reported time of the first
        window whose statistic exceeds ``self.threshold`` for ``consecutive``
        windows in a row, or N_total when nothing fires. With several factors,
        returns the array produced by ``detect_many(self.factors, ...)``; in the
        degenerate cases (no valid baseline block, non-finite threshold) that
        array is filled with N_total.
        """
        multi = self.factors.size > 1

        if not self.valid_k or not np.isfinite(self.threshold):
            if multi:
                return np.full(self.factors.shape, self.N, dtype=int)
            return self.N

        step = self.detect_step if step is None else max(1, int(step))

        # If multiple factors were provided, return an array of change-points.
        if multi:
            return self.detect_many(self.factors, step=step, return_time=return_time)

        hits = 0
        for i in range(self.N_train, self.N - self.B + 1, step):
            rec = self._window_stat(i)
            if rec is None:
                continue

            if rec > self.threshold:
                hits += 1
                if hits >= self.consecutive:
                    return self._event_time(i, return_time)
            else:
                hits = 0

        return self.N
