import numpy as np
from scipy.sparse import csc_matrix

# Sentinel stored in a response matrix to mark a missing entry; entries equal
# to this value are masked out before any computation.
INVALID_RESPONSE = -99999

def construct_markov_chain_accelerated(X, lambd=0.1):
    """Build the row-stochastic transition matrix of the pairwise-comparison Markov chain.

    X: np.array of size (m, n) where missing entries have value INVALID_RESPONSE
    lambd: float, regularization added to every entry (i, j) for which the rounded
        count is non-zero in at least one direction, (i, j) or (j, i)

    Returns a tuple (M, d):
    M: (m, m) array; off-diagonal entries are the regularized counts divided by the
        row normalizer, and each diagonal entry is set so that its row sums to 1
    d: array of length m holding the row normalizers, i.e. max(row sum, 1)
    """
    n_rows, _ = X.shape
    responses = np.ma.masked_equal(X, INVALID_RESPONSE, copy=False)
    complements = 1. - responses

    # counts[i, j] = sum over columns l of X[i, l] * (1 - X[j, l]); np.ma.dot
    # treats masked entries as 0, so only columns observed in both rows count.
    counts = np.ma.dot(responses, complements.T)
    np.fill_diagonal(counts, 0)
    counts = np.round(counts)

    # Regularize only the pairs that are connected in at least one direction.
    linked = np.logical_or((counts != 0), (counts.T != 0))
    chain = np.where(linked, counts + lambd, counts)

    # Normalize row by row; `row` is a view, so the in-place ops update `chain`.
    normalizers = []
    for idx, row in enumerate(chain):
        scale = max(np.sum(row), 1)
        normalizers.append(scale)
        row /= scale
        # The diagonal is still 0 here, so this is the leftover self-loop mass.
        row[idx] = 1. - np.sum(row)

    return chain, np.array(normalizers)
    

def spectral_estimate(X, max_iters=10000, lambd=1, eps=1e-6):
    """Estimate the hidden parameters according to the Rasch model, either for the tests' difficulties
    or the students' abilities. We follow the convention in Girth https://eribean.github.io/girth/docs/quickstart/quickstart/
    the response matrix X has shape (m, n) where m is the number of items and n is the number of users.
    The algorithm returns the item estimates.

    X: np.array of size (m, n) where missing entries have value INVALID_RESPONSE
    max_iters: int, maximum number of iterations to compute the stationary distribution of the Markov chain
    lambd: float, regularization parameter
    eps: tolerance for convergence checking

    Returns:
    beta: np.array of size (m,), the log of the stationary distribution divided by the
        row normalizers of the Markov chain, shifted to have zero mean

    """
    M, d = construct_markov_chain_accelerated(X, lambd=lambd)
    M = csc_matrix(M)

    # One state per row of the transition matrix.
    m = M.shape[0]
    pi = np.ones((m,))

    # Power iteration: repeatedly apply the chain to the row vector pi and
    # renormalize until successive iterates differ by less than eps.
    for _ in range(max_iters):
        pi_next = pi @ M
        pi_next /= np.sum(pi_next)
        converged = np.linalg.norm(pi_next - pi) < eps
        pi = pi_next
        if converged:
            break

    # Undo the per-row normalization before moving to the log scale.
    pi = pi / d
    beta = np.log(pi)
    # Center the estimates so they have zero mean.
    beta = beta - np.mean(beta)
    return beta