import numpy as np
from functools import reduce

from polynomial_optimized import generate_basis_mat
from utility import poisson_svd

"""
Matrix-based online change detector.

High-level (paper ↔ code)
-------------------------
• Rolling accumulators over a lag grid:
    record_left[j]  ≈ prefix/past sum   L[j]
    record_right[j] ≈ suffix/recent sum R[j]
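• CUSUM-type tensor per lag j, with n_L and n_R the sample counts of
  record_left[j] and record_right[j] (left_coef[j], right_coef[j] in the code)
  and n = n_L + n_R:
    D_j = sqrt(n_L * n_R / n) * (record_left[j] / n_L - record_right[j] / n_R)
  i.e. a scaled difference between the past mean and the recent mean.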
• Restricted SVD score (when dim > 1):
    score_j = poisson_svd(dim, shapes, rank, D_j, index).compute()   # scalar
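• Time-varying threshold, with c the outer threshold factor (1.5):
    threshold_j(t) = c * (training_max_j / sqrt(log N_train)) * sqrt(log t)
  where training_max_j is the largest score_j observed during calibration.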

Notes
-----
• Basis: tensor-product of univariate orthonormal polynomials evaluated by
  generate_basis_mat(m, dim, 1); per-time tensor = sum of outer products.
• Coordinate split: `index = [groupA_axes, groupB_axes]` used to matricize
  the tensor before SVD (row/col grouping).
"""


class matrix_detection:
    """Online change detector built on polynomial-basis tensor statistics.

    Every observation (a possibly empty set of points) is mapped to an
    order-``dim`` tensor ``T_t`` by summing outer products of univariate
    basis evaluations. With ``W = max_len`` and ``t`` the 0-based index of
    the newest observation, the detector keeps, for every lag ``k``:

    * ``record_left[k]``  = T_0 + ... + T_{t-W+k}       (n_L = t-W+k+1 terms)
    * ``record_right[k]`` = T_{t-W+k+1} + ... + T_t     (n_R = W-k terms)

    For lags ``j = 0 .. min_lag-1`` the statistic is
    ``sqrt(n_L*n_R/n) * (record_left[j]/n_L - record_right[j]/n_R)``, with
    ``n = n_L + n_R = t + 1``. It is reduced to a scalar by ``poisson_svd``
    (dim > 1) or by its Euclidean norm (dim == 1). An alarm is raised when
    the scalar exceeds the per-lag training maximum times
    ``threshold_factor * sqrt(log n / log N_train)``.
    """

    def __init__(self, dim, m, max_len, min_lag, rank, x_train, index,
                 threshold_factor=1.5):
        """Build the rolling sums from ``x_train`` and calibrate thresholds.

        Parameters
        ----------
        dim : int
            Number of coordinates per point (order of the tensors).
        m : int
            Number of univariate basis functions per coordinate; tensors
            have shape ``(m,) * dim``.
        max_len : int
            Size W of the rolling window grid; must be >= 1.
        min_lag : int
            Number of lags ``0 .. min_lag-1`` that are scored; must not
            exceed ``max_len``.
        rank : int
            Rank passed to ``poisson_svd`` for the restricted SVD score.
        x_train : sequence
            Training realizations; ``x_train[i]`` is a (possibly empty)
            collection of points, expected to lie in ``[0, 1]^dim``. Must
            contain more than ``max_len`` realizations.
        index : sequence
            Coordinate split ``[groupA_axes, groupB_axes]`` passed to
            ``poisson_svd`` to matricize the tensor.
        threshold_factor : float, optional
            Outer scaling of the calibrated training maxima (default 1.5).

        Raises
        ------
        ValueError
            If ``max_len < 1``, ``min_lag > max_len`` or
            ``len(x_train) <= max_len``.
        """
        if max_len < 1:
            raise ValueError("max_len must be at least 1")
        if min_lag > max_len:
            raise ValueError("min_lag must not exceed max_len")
        if len(x_train) <= max_len:
            # Need at least one calibration step, and log(N_train) > 0.
            raise ValueError(
                "x_train must contain more than max_len realizations")

        # ---------------------------
        # Hyperparameters & metadata
        # ---------------------------
        self.dim = dim
        self.m = m
        self.shapes = [m for __ in range(dim)]   # tensor shape M x ... x M (order = dim)
        self.ranks = rank
        self.min_lag = min_lag                   # lags evaluated: 0 .. min_lag-1
        self.max_len = max_len                   # size W of the rolling window grid
        # Number of observations consumed so far (training + online).
        self.cur_index = len(x_train)
        self.N_train = len(x_train)
        # right_coef[j] = W - j: number of terms summed in record_right[j].
        self.right_coef = np.arange(max_len, 0, -1)
        self.index = index                       # coordinate split (row/col groups for SVD)
        self.threshold_factor = threshold_factor
        self.polynomial = generate_basis_mat(self.m, dim, 1)  # basis evaluator

        # Fixed scaling domain [0, 1]^dim (consistent with the data generator).
        self.mins = np.zeros(dim, dtype=int)
        self.maxs = np.ones(dim, dtype=int)

        # Seed the rolling sums with the first W training tensors (each
        # tensor computed once):
        #   record_left[k]  = T_0 + ... + T_k
        #   record_right[k] = T_k + ... + T_{W-1}
        first_block = np.array([self.compute_tensor(self.scale(x_train[t]))
                                for t in range(max_len)])
        record_left = np.cumsum(first_block, axis=0)
        record_right = np.cumsum(first_block[::-1], axis=0)[::-1].copy()

        # Threshold calibration: track the largest score per lag over the
        # remaining training data, starting with x_train[max_len], the first
        # observation not yet contained in the seeded sums.
        threshold = [0 for __ in range(min_lag)]
        for i in range(max_len, self.N_train):
            Ti = self.compute_tensor(self.scale(x_train[i]))
            record_right = self._push_right(record_right, Ti)
            # i + 1 observations (0..i) are now split between left and right.
            scores = self._lag_scores(record_left, record_right, i + 1)
            for j, score in enumerate(scores):
                threshold[j] = max(threshold[j], score)
            record_left = self._push_left(record_left, Ti)

        # Persist rolling buffers; the online threshold at n observations is
        # self.threshold * sqrt(log n).
        self.record_left = record_left
        self.record_right = record_right
        self.threshold = (np.array(threshold) * threshold_factor
                          / np.sqrt(np.log(self.N_train)))

    def detect(self, new_data):
        """Consume one new realization and report whether an alarm is raised.

        Parameters
        ----------
        new_data : array-like
            Points observed at the new time step (may be empty).

        Returns
        -------
        bool
            True if the score at some lag j exceeds its time-varying
            threshold, False otherwise. The rolling buffers and
            ``cur_index`` are advanced in both cases, so the detector can
            keep running after an alarm.
        """
        self.cur_index += 1
        n_seen = self.cur_index                  # observations incl. this one
        Ti = self.compute_tensor(self.scale(new_data))

        self.record_right = self._push_right(self.record_right, Ti)

        # Time-adapted threshold with sqrt(log n) scaling.
        cur_threshold = self.threshold * np.sqrt(np.log(n_seen))

        # Lazily scored; any() stops at the first lag that crosses.
        scores = self._lag_scores(self.record_left, self.record_right, n_seen)
        alarm = any(score > thr for score, thr in zip(scores, cur_threshold))

        # Advance the prefix sums only after scoring, which reads them.
        self.record_left = self._push_left(self.record_left, Ti)
        return alarm

    def _lag_scores(self, record_left, record_right, n_seen):
        """Yield the scalar CUSUM score for each lag ``0 .. min_lag-1``.

        ``n_seen`` is the total number of observations summed across
        ``record_left[j] + record_right[j]``.
        """
        # left_coef[j] = number of terms in record_left[j].
        left_coef = np.arange(n_seen - self.max_len, n_seen)
        for j in range(self.min_lag):
            n_left = left_coef[j]
            n_right = self.right_coef[j]
            left = record_left[j] * np.sqrt(n_right / n_left / n_seen)
            right = record_right[j] * np.sqrt(n_left / n_right / n_seen)
            yield self._score(left - right)

    def _score(self, diff):
        """Reduce a difference tensor to a scalar score."""
        if self.dim > 1:
            return poisson_svd(self.dim, self.shapes, self.ranks, diff,
                               self.index).compute()
        # dim == 1: diff is a length-m vector. NOTE(review): its largest
        # singular value equals its Euclidean norm; assumed consistent with
        # poisson_svd's scalar for dim > 1 — confirm.
        return np.linalg.norm(diff)

    @staticmethod
    def _push_right(record_right, Ti):
        """Slide the suffix sums one step and add ``Ti`` to every suffix."""
        rolled = np.roll(record_right, -1, axis=0)   # drop the oldest suffix
        rolled[-1] = 0.0                             # new length-1 suffix
        rolled += Ti
        return rolled

    @staticmethod
    def _push_left(record_left, Ti):
        """Slide the prefix sums one step, appending the longest prefix + ``Ti``."""
        longest = record_left[-1] + Ti
        rolled = np.roll(record_left, -1, axis=0)    # drop the shortest prefix
        rolled[-1] = longest
        return rolled

    def scale(self, mat):
        """Map coordinates to [0, 1]^d using the fixed bounds.

        If ``mat`` is empty (no points), return [] to skip contributions.
        """
        if len(mat) == 0:
            return []
        return (mat - self.mins) / (self.maxs - self.mins)

    def compute_tensor(self, mat):
        """Aggregate basis outer products over all points in ``mat``.

        Returns an order-``dim`` tensor of shape ``self.shapes``; each point
        contributes the outer product of its per-coordinate basis vectors.
        An empty ``mat`` yields the zero tensor.
        """
        result = np.zeros(self.shapes)
        if len(mat) == 0:
            return result
        # Presumably shaped (N_points, dim, m) — per the original comment.
        basis_mat = self.polynomial.all_x_multivariate(mat)
        for v in basis_mat:
            result += reduce(np.multiply.outer, v)
        return result
