from polynomial_optimized import generate_basis_mat
import numpy as np


class full_tensor:
    """
    Precompute univariate basis values and build tensorized summaries on demand.

    Purpose
    -------
    Given points X ∈ R^{N×d} (each row is a point), this class caches
    univariate basis evaluations for each coordinate and supports fast
    accumulation of the order-d tensor of basis outer-products over any
    sub-interval [start, end).

    Usage in simulations
    --------------------
    The detector represents each realization by summing outer-products of
    basis values across dimensions. `compute(start, end)` returns that sum
    over the requested slice.

    Attributes
    ----------
    all_coefficient_mat : (N, dim, m) ndarray
        For each row i and dimension k, the length-m vector of univariate
        basis values [φ_0(x_{ik}), …, φ_{m-1}(x_{ik})].
    dim : int
        Ambient dimension d.
    """
    def __init__(self, m, dim, data):
        # Vectorized basis evaluation for all points and dimensions.
        # The trailing argument 1 is, per the original author's note, the
        # α=1 (default) scaling of generate_basis_mat — confirm against
        # polynomial_optimized if the exact scaling matters.
        self.all_coefficient_mat = generate_basis_mat(m, dim, 1).all_x_multivariate(data)  # (N, dim, m)
        self.dim = dim

    def compute(self, start, end):
        """
        Accumulate the order-d tensor over rows in [start, end).

        Parameters
        ----------
        start, end : int
            Row slice bounds, applied as ``all_coefficient_mat[start:end]``.
            An empty slice yields an all-zero tensor.

        Returns
        -------
        T : (m, …, m) ndarray, order = d
            T[a1, …, ad] = Σ_{i=start}^{end-1} Π_{k=1}^d φ_{a_k}(X_{i,k})

        Notes
        -----
        Implemented with `np.einsum` in its integer-sublist form. For d=3 it
        is equivalent to:
            einsum('ia,ib,ic->abc', B0, B1, B2)
        where Bk = basis values for dimension k on the slice. Integer labels
        are used because letter subscripts built as chr(97 + k) collide with
        the row label 'i' once d >= 9. numpy.einsum still caps the number of
        operands/labels, so d is bounded; the dense m**d output is the
        practical limit long before that.
        """
        B = self.all_coefficient_mat[start:end]  # (n_rows, d, m)
        d = self.dim
        # Label 0 = row index (absent from the output, hence summed over);
        # label k + 1 = basis index of dimension k.
        operands = []
        for k in range(d):
            operands.append(B[:, k, :])
            operands.append([0, k + 1])
        return np.einsum(*operands, list(range(1, d + 1)))


class poisson_svd:
    """
    Score a tensor by the leading singular values of one of its matricizations.

    The tensor's axes are split into a row group (``index[0]``) and a column
    group (``index[1]``). The axes are permuted so that the row group comes
    first, the permuted tensor is flattened into a 2-D matrix whose row count
    is the product of ``shapes[k]`` over the row-group axes, and the top
    ``rank`` singular values of that matrix are stored in ``diag``.
    ``compute()`` folds them into a single scalar.

    Parameters
    ----------
    dim : int
        Ambient dimension d. It is stored but not used in the computation.
    shapes : list[int]
        Per-axis basis sizes (usually all equal to m). Only the row-group
        entries are read, to size the matrix rows. These are presumably equal
        to ``tensor.shape``; that is not checked.
    rank : int
        Number of leading singular values to keep. If it exceeds the smaller
        matrix dimension, plain slicing keeps fewer values.
    tensor : ndarray
        Order-d tensor to analyze.
    index : sequence of two axis lists
        index[0] = row-group axes, index[1] = column-group axes. The two are
        joined with ``+``, so they should be lists (or tuples) of the same
        type. Together they must form a permutation of the tensor's axes,
        otherwise ``ndarray.transpose`` raises.

    Attributes
    ----------
    diag : ndarray
        Leading singular values of the matricized tensor, in the descending
        order returned by numpy.linalg.svd.
    subspaces : list
        Initialized to an empty list. This class never populates it.
    """
    def __init__(self, dim, shapes, rank, tensor, index):
        self.dim = dim
        self.tensor = tensor
        self.shapes = shapes
        self.subspaces = []

        # Number of matrix rows = product of the row-group basis sizes.
        n_rows = 1
        for axis in index[0]:
            n_rows *= shapes[axis]

        # Row-group axes first, then column-group axes, then flatten to 2-D.
        perm = tuple(index[0] + index[1])
        matrix = tensor.transpose(perm).reshape(n_rows, -1)

        # Only the singular values are needed; keep the `rank` largest.
        _, singular_values, _ = np.linalg.svd(matrix, full_matrices=False)
        self.diag = singular_values[:rank]

    def compute(self):
        """
        Return ``sqrt(sum(prod(diag * diag)))`` as the restricted-SVD score.

        ``np.sum`` of the scalar product is a no-op, and singular values are
        non-negative. Mathematically the score therefore equals
        ``prod(diag)`` (1.0 when ``diag`` is empty). The squares are multiplied
        before the square root is taken, so the intermediate can overflow to
        inf or underflow to 0 in cases where ``prod(diag)`` itself would be
        finite and non-zero. The formula is kept unchanged on purpose.
        """
        squared = self.diag * self.diag
        product_of_squares = np.prod(squared)
        return np.sqrt(np.sum(product_of_squares))
