import numpy as np
from functools import reduce
from typing import Optional
from polynomial_optimized import generate_basis_mat


class matrix_detection:
    """
    Online change detector over tensor features of the incoming data.

    Supports:
      - factor as a scalar (backward compatible, used by ``detect``)
      - factor as a NumPy array: ``detect_alarm_times`` returns the first alarm
        time for each factor in a single pass (multi-threshold scan)

    Statistic:
      Each observation is mapped to a tensor ``T`` (see ``compute_tensor``).
      For every lag ``j < min_lag`` the newest ``max_len - j`` observations
      form the "right" segment and all earlier observations the "left"
      segment. With segment sums ``S_L``, ``S_R``, sizes ``n_L``, ``n_R`` and
      ``n = n_L + n_R``, the CUSUM difference is

          D_j = S_L * sqrt(n_R / (n_L * n)) - S_R * sqrt(n_L / (n_R * n))

      which is reduced to a scalar ``stat_j`` (see ``_tensor_norm``).

    Key idea:
      The training phase produces a *base* threshold vector (per lag) that
      does NOT depend on factor:
          threshold0[j] = threshold_train[j] / (log(N_train)^2)

      Online, with ``cur_index`` observations seen (including the newest one),
      the effective threshold is:
          threshold0[j] * factor * (log(cur_index)^2)

      So a single "ratio score" is computed each time step:
          score(t) = max_j  stat_j(t) / (threshold0[j] * log(cur_index)^2)

      and a factor alarms when score(t) > factor.
    """

    def __init__(self, dim, m, max_len, min_lag, rank, x_train, index, factor=1.0):
        """
        Build the detector and calibrate the per-lag base thresholds.

        Parameters
        ----------
        dim : int
            Number of axes of the feature tensor (one per input coordinate).
        m : int
            Length of every tensor axis; passed to ``generate_basis_mat``
            (presumably the number of basis functions per coordinate).
        max_len : int
            Number of per-lag segment sums kept in the sliding window (>= 1).
        min_lag : int
            Number of lags ``j = 0 .. min_lag - 1`` scanned per step;
            must satisfy ``0 <= min_lag <= max_len``.
        rank : int
            Number of singular values kept by ``poisson_svd`` when ``dim > 1``.
        x_train : sequence
            Training observations; each item goes through ``scale`` and
            ``compute_tensor``. Needs at least ``max_len`` items.
        index : pair of axis lists
            Row axes (``index[0]``) and column axes (``index[1]``) used by
            ``poisson_svd`` to matricize the tensor.
        factor : float or array-like
            Default multiplier for ``detect``; for array input the first
            element is used.

        Raises
        ------
        ValueError
            If ``max_len < 1``, ``min_lag`` is outside ``[0, max_len]`` or
            ``x_train`` has fewer than ``max_len`` items.
        """
        self.dim = int(dim)
        self.m = int(m)
        self.shapes = [self.m] * self.dim
        self.ranks = int(rank)
        self.min_lag = int(min_lag)
        self.max_len = int(max_len)

        if self.max_len < 1:
            raise ValueError(f"max_len must be >= 1, got {self.max_len}")
        if not 0 <= self.min_lag <= self.max_len:
            raise ValueError(
                f"min_lag must lie in [0, max_len={self.max_len}], got {self.min_lag}"
            )

        # right_coef[j] = max_len - j = size of the right segment at lag j.
        self.right_coef = np.arange(self.max_len, 0, -1, dtype=float)
        self.index = index

        # store default factor (scalar) for detect()
        self.factor = float(np.asarray(factor).reshape(-1)[0]) if np.ndim(factor) != 0 else float(factor)

        self.polynomial = generate_basis_mat(self.m, self.dim, 1)

        # domain scaling: fixed [0, 1]^d
        self.mins = np.zeros(self.dim, dtype=float)
        self.maxs = np.ones(self.dim, dtype=float)

        self._fit_training(x_train)

    # ----------------------------
    # Training / shared helpers
    # ----------------------------
    def _fit_training(self, x_train):
        """Fold ``x_train`` into the segment sums and derive ``threshold0``.

        Sets ``N_train``, ``cur_index``, ``record_left``, ``record_right``,
        ``_logN_train_sq`` and ``threshold0``.
        """
        self.N_train = len(x_train)
        if self.N_train < self.max_len:
            raise ValueError(
                f"x_train needs at least max_len={self.max_len} items, got {self.N_train}"
            )
        # cur_index = number of observations already folded into the sums.
        self.cur_index = self.N_train

        head = np.stack(
            [self.compute_tensor(self.scale(x_train[k])) for k in range(self.max_len)]
        )
        # record_left[j] = T_0 + ... + T_j
        record_left = np.cumsum(head, axis=0)
        # record_right[j] = T_j + ... + T_{max_len-1}
        record_right = np.cumsum(head[::-1], axis=0)[::-1].copy()

        threshold_train = np.zeros(self.min_lag, dtype=float)
        # The first max_len observations are already in the sums, so the next
        # one to process is index max_len (starting later would drop it).
        for i in range(self.max_len, self.N_train):
            Ti = self.compute_tensor(self.scale(x_train[i]))
            record_right = self._advance_right(record_right, Ti)
            stats = self._lag_statistics(record_left, record_right, i + 1)
            np.maximum(threshold_train, stats, out=threshold_train)
            record_left = self._advance_left(record_left, Ti)

        self.record_left = record_left
        self.record_right = record_right

        # --- base threshold independent of factor ---
        self._logN_train_sq = float(np.log(max(self.N_train, 2)) ** 2)
        self.threshold0 = threshold_train / max(self._logN_train_sq, 1e-12)

    @staticmethod
    def _advance_right(record_right, Ti):
        """Shift right-segment sums one step and add the newest tensor ``Ti`` to each."""
        shifted = np.roll(record_right, -1, axis=0)  # np.roll returns a new array
        shifted[-1] = 0.0
        shifted += Ti
        return shifted

    @staticmethod
    def _advance_left(record_left, Ti):
        """Shift left-segment sums one step; the last slot becomes the running total + ``Ti``."""
        shifted = np.roll(record_left, -1, axis=0)
        shifted[-1] = record_left[-1] + Ti
        return shifted

    def _lag_statistics(self, record_left, record_right, n_total):
        """Return the ``min_lag`` per-lag statistics for the current window.

        ``n_total`` is the number of observations in the sums, including the
        newest one; segment sizes are derived from it so that training and
        online scoring use exactly the same weights.
        """
        n_total = float(n_total)
        stats = np.empty(self.min_lag, dtype=float)
        for j in range(self.min_lag):
            n_right = self.right_coef[j]
            n_left = n_total - n_right
            left = record_left[j] * np.sqrt(n_right / n_left / n_total)
            right = record_right[j] * np.sqrt(n_left / n_right / n_total)
            stats[j] = self._tensor_norm(left - right)
        return stats

    def _tensor_norm(self, diff):
        """Reduce a CUSUM difference tensor to a non-negative scalar.

        For ``dim > 1`` this is the truncated singular-value norm from
        ``poisson_svd``. For ``dim == 1`` the tensor is a vector whose single
        singular value is its Euclidean norm (``float(diff)`` would fail for
        vectors longer than one).
        """
        if self.dim > 1:
            return poisson_svd(self.dim, self.shapes, self.ranks, diff, self.index).compute()
        return float(np.linalg.norm(diff))

    def scale(self, mat):
        """Map ``mat`` from ``[mins, maxs]`` to the unit box; empty input gives ``[]``."""
        if len(mat) == 0:
            return []
        return (mat - self.mins) / (self.maxs - self.mins)

    def compute_tensor(self, mat):
        """Return the feature tensor (shape ``self.shapes``) for ``mat``.

        For every entry ``v`` produced by ``self.polynomial.all_x_multivariate``
        the outer product of the vectors in ``v`` is accumulated. Empty input
        yields an all-zero tensor.
        """
        result = np.zeros(self.shapes, dtype=float)
        if len(mat) == 0:
            return result
        basis_mat = self.polynomial.all_x_multivariate(mat)
        for v in basis_mat:
            result += reduce(np.multiply.outer, v)
        return result

    # ----------------------------
    # Core online step (always advances state)
    # ----------------------------
    def step_score(self, new_data) -> float:
        """
        Advance one time step and return the scalar score:

            score = max_j stat_j / (threshold0[j] * log(cur_index)^2)

        Larger score => easier to alarm. A factor alarms when score > factor.
        A lag whose base threshold is (near) zero contributes ``inf`` for any
        positive statistic and ``0`` otherwise.
        """
        self.cur_index += 1  # now counts the new observation too

        Ti = self.compute_tensor(self.scale(new_data))
        self.record_right = self._advance_right(self.record_right, Ti)
        stats = self._lag_statistics(self.record_left, self.record_right, self.cur_index)

        log_cur_sq = float(np.log(max(self.cur_index, 2)) ** 2)
        eps = 1e-12
        score = 0.0
        for stat, base in zip(stats, self.threshold0):
            stat = float(stat)
            denom = float(base) * log_cur_sq
            if denom <= eps:
                # if training threshold is (near) zero, any positive stat is "huge"
                ratio = float("inf") if stat > 0.0 else 0.0
            else:
                ratio = stat / denom
            score = max(score, ratio)

        # update record_left (always)
        self.record_left = self._advance_left(self.record_left, Ti)
        return score

    # ----------------------------
    # Backward-compatible single-factor detect
    # ----------------------------
    def detect(self, new_data, factor: Optional[float] = None) -> bool:
        """
        Advance one step and return a boolean alarm for a single factor.
        If factor is None, uses self.factor (from __init__).
        """
        f = self.factor if factor is None else float(factor)
        score = self.step_score(new_data)
        return bool(score > f)

    # ----------------------------
    # Multi-factor alarm times
    # ----------------------------
    def detect_alarm_times(
        self,
        data_stream,
        factors: np.ndarray,
        *,
        start_index: Optional[int] = None,
        sentinel: Optional[int] = None,
    ) -> np.ndarray:
        """
        Scan over `data_stream` and return the first alarm time for each factor.

        Parameters
        ----------
        data_stream:
            iterable of new_data arriving AFTER the training prefix used to
            build this object. Iterators/generators are accepted.
        factors:
            threshold multipliers; result[i] is the FIRST alarm time for
            factors[i].
        start_index:
            global time index for data_stream[0]. Default: len(x_train) used
            at init. (So if you did x_train=data[:N_train], pass
            data_stream=data[N_train:], start_index defaults to N_train.)
        sentinel:
            value to return when no alarm occurs. Default:
            start_index + len(data_stream).

        Returns
        -------
        alarms: np.ndarray of same shape as factors (dtype=int), or a plain
        int when `factors` is a scalar.

        The scan stops early once every factor has alarmed, so the detector
        state only advances through the elements actually consumed.
        """
        fac = np.asarray(factors, dtype=float)
        scalar = (fac.ndim == 0)
        fac_flat = fac.reshape(-1)

        if start_index is None:
            start_index = int(self.N_train)  # the natural global index of the first post-train element
        if sentinel is None:
            if not hasattr(data_stream, "__len__"):
                # iterators have no len(); materialize once to get the default sentinel
                data_stream = list(data_stream)
            # "no change" => end of stream index
            sentinel = int(start_index + len(data_stream))

        alarms = np.full(fac_flat.shape[0], int(sentinel), dtype=int)
        active = np.ones(fac_flat.shape[0], dtype=bool)

        for k, new_data in enumerate(data_stream):
            if not np.any(active):
                break
            score = self.step_score(new_data)
            hit = active & (score > fac_flat)
            if np.any(hit):
                alarms[hit] = int(start_index + k)
                active[hit] = False

        out = alarms.reshape(fac.shape)
        return int(out) if scalar else out


class poisson_svd:
    """
    Truncated singular-value norm of a tensor under a fixed matricization.

    The tensor's axes are permuted to ``index[0] + index[1]`` and reshaped so
    that rows enumerate the ``index[0]`` axes. The ``rank`` largest singular
    values of that matrix are stored in ``self.diag``. ``compute`` returns
    their Euclidean norm, i.e. the Frobenius norm of the best rank-``rank``
    approximation of the matrix.
    """

    def __init__(self, dim, shapes, rank, tensor, index):
        """
        Parameters
        ----------
        dim : int
            Number of tensor axes (stored for callers; not used here).
        shapes : sequence of int
            Length of each tensor axis.
        rank : int
            Number of leading singular values to keep.
        tensor : np.ndarray
            Tensor to matricize.
        index : pair of axis sequences
            ``index[0]`` become the row axes and ``index[1]`` the column axes.
            Together they must form a permutation of ``range(tensor.ndim)``
            (``np.transpose`` raises otherwise).
        """
        self.dim = dim
        self.tensor = tensor
        self.shapes = shapes
        self.rank = rank

        row_axes = tuple(index[0])
        col_axes = tuple(index[1])
        n_rows = int(np.prod([shapes[axis] for axis in row_axes], dtype=np.int64))

        matrix = np.transpose(tensor, row_axes + col_axes).reshape(n_rows, -1)
        # Only singular values (sorted descending) are needed; skip U and V.
        self.diag = np.linalg.svd(matrix, compute_uv=False)[:rank]

    def compute(self):
        """Return sqrt(sum of squared kept singular values) as a Python float."""
        return float(np.sqrt(np.sum(self.diag ** 2)))
