import numpy as np
import scipy
import random
from timeit import default_timer as timer
from scipy.sparse import coo_array
from utils.movieLensResidualComp import *


def _srpca_error_term(M, X_true, MissingIndCoord, DistribProps, residErrorType, eps):
    """Return the error metric SRPCA_func reports for the estimate ``M``.

    For ``residErrorType == 'total'`` this is the mean squared error of ``M``
    against ``X_true`` over every entry; any other value delegates to
    ``movieLensResidualComputation`` with ``upper_lim=5`` and ``lower_lim=1``.
    """
    if residErrorType == 'total':
        return np.mean((M - X_true) ** 2)
    return movieLensResidualComputation(M, X_true, MissingIndCoord, DistribProps,
                                        eps=eps, upper_lim=5, lower_lim=1)


def _srpca_print_progress(epoch, observed_error, error_term, residErrorType, start_timer):
    """Print one SRPCA progress line (iteration, observed error, reported error, elapsed time)."""
    print(f'Iteration #{epoch}, observed error: {observed_error}',
          f'total error: {error_term}' if residErrorType == 'total' else f'hidden error: {error_term}',
          f'time elapsed: {timer()-start_timer}')


def SRPCA_func(X, X_true, OmegaCoord, MissingIndCoord, DistribProps=None,
               i_PCA=1, alpha=1,
               eps=1e-5, eps_tol=1e-4,
               max_iter=1000,
               optional_smoothing=False,
               residErrorType='total',  # 'total' or 'hiddenForMovieLens'
               verbose=False, seed=50):
    """Complete a partially observed matrix with an iterative rank-``i_PCA`` model.

    The missing entries are first filled with standard-normal samples. The
    initial loadings ``R`` are the top ``i_PCA`` eigenvectors of ``M.T @ M``.
    Each iteration then:

    1. optionally median-filters the estimate,
    2. resets observed entries to ``alpha * X``,
    3. recomputes the components ``P = M @ R.T``,
    4. refits every column of ``R`` by least squares over that column's
       observed rows only,
    5. sets ``M = P @ R``.

    Iteration stops once the observed-entry MSE improves by no more than
    ``eps_tol``, or after ``max_iter`` iterations.

    Args:
        X: 2-D array of shape (d1, d2). Only entries at ``OmegaCoord`` are used
            as observations. Integer input is promoted to float64.
        X_true: reference matrix used only for the reported error.
        OmegaCoord: ``(rows, cols)`` index arrays of the observed entries.
        MissingIndCoord: ``(rows, cols)`` index arrays of the missing entries.
        DistribProps: forwarded to ``movieLensResidualComputation``.
            Defaults to ``[None, None]``.
        i_PCA: number of components (rank of the model). Must be >= 1.
        alpha: scale applied to the observed entries at the start of each
            iteration.
        eps: forwarded to ``movieLensResidualComputation``.
        eps_tol: minimum per-iteration decrease of the observed MSE required
            to keep iterating.
        max_iter: maximum number of iterations. With 0, the initial PCA
            reconstruction is returned.
        optional_smoothing: apply ``scipy.signal.medfilt2d`` to the estimate at
            the start of each iteration.
        residErrorType: ``'total'`` (MSE against ``X_true`` over all entries)
            or any other value to use ``movieLensResidualComputation``.
        verbose: print a banner and per-iteration progress.
        seed: seeds the global ``random`` and ``np.random`` generators (a
            process-wide side effect) before the random fill.

    Returns:
        ``(M, tme, error_term)``. ``M`` is the completed matrix with observed
        entries reset to ``X``. ``tme`` is the seconds spent from the
        eigendecomposition through the last iteration. ``error_term`` is the
        reported error of the returned ``M``.

    Raises:
        ValueError: if ``i_PCA < 1``.
        numpy.linalg.LinAlgError: if a column's normal equations are singular
            (e.g. the column has fewer observed entries than ``i_PCA``).
    """
    if DistribProps is None:
        DistribProps = [None, None]
    if i_PCA < 1:
        raise ValueError(f'i_PCA must be a positive integer, got {i_PCA}')

    d1, d2 = X.shape

    if verbose:
        print('#######################################')
        print('############## SRPCA ##################')
        print('#######################################')

    missingIndMatrixRow, missingIndMatrixCol = MissingIndCoord
    rmN = len(missingIndMatrixRow)
    OmegaRow, OmegaCol = (np.asarray(a) for a in OmegaCoord)

    # Group observed row indices by column with one stable sort. This replaces
    # per-column dense d1 x d1 diagonal weight matrices: P.T @ diag(w) only
    # keeps the observed rows, so selecting those rows is equivalent.
    order = np.argsort(OmegaCol, kind='stable')
    bounds = np.searchsorted(OmegaCol[order], np.arange(d2 + 1))
    observed_rows = [OmegaRow[order[bounds[c]:bounds[c + 1]]] for c in range(d2)]

    # Work on a floating-point copy. An integer copy would silently truncate
    # the random normal fill below.
    M = X.copy() if np.issubdtype(X.dtype, np.inexact) else X.astype(np.float64)
    random.seed(seed)
    np.random.seed(seed)
    M[missingIndMatrixRow, missingIndMatrixCol] = np.random.normal(size=rmN)

    start_timer = timer()

    # M.T @ M is symmetric: eigh guarantees real eigenpairs (eig may return a
    # complex dtype). It sorts eigenvalues ascending, so reverse for the top ones.
    _, eigvecs = np.linalg.eigh(M.T @ M)
    R = eigvecs[:, ::-1][:, :i_PCA].T.copy()
    M_o = (M @ R.T) @ R  # initial low-rank reconstruction

    X_obs = X[OmegaRow, OmegaCol]
    alpha_X_obs = alpha * X_obs
    observed_error = np.mean((M_o[OmegaRow, OmegaCol] - X_obs) ** 2)

    if optional_smoothing:
        from scipy.signal import medfilt2d  # local: `import scipy` alone may not expose scipy.signal

    epoch = -1  # stays -1 when no iteration runs (max_iter <= 0)
    for epoch in range(max_iter):
        if optional_smoothing:
            M = medfilt2d(M)  # other smoothing methods can be substituted here

        M[OmegaRow, OmegaCol] = alpha_X_obs  # correct observed entries
        P = M @ R.T  # update principal components
        for col, rows in enumerate(observed_rows):
            # Least squares on this column's observed rows (weighted normal equations).
            P_obs = P[rows]
            R[:, col] = np.linalg.solve(P_obs.T @ P_obs, P_obs.T @ X[rows, col])

        M = P @ R
        prev_observed_error = observed_error
        observed_error = np.mean((M[OmegaRow, OmegaCol] - X_obs) ** 2)

        if prev_observed_error - observed_error <= eps_tol:
            break
        if verbose:
            error_term = _srpca_error_term(M, X_true, MissingIndCoord, DistribProps,
                                           residErrorType, eps)
            _srpca_print_progress(epoch, observed_error, error_term, residErrorType,
                                  start_timer)

    if epoch < 0:
        M = M_o  # no iterations ran: return the initial PCA reconstruction

    tme = timer() - start_timer
    M[OmegaRow, OmegaCol] = X_obs
    error_term = _srpca_error_term(M, X_true, MissingIndCoord, DistribProps,
                                   residErrorType, eps)
    _srpca_print_progress(epoch, observed_error, error_term, residErrorType, start_timer)

    return M, tme, error_term


