import numpy as np
from scipy import sparse
from sklearn.utils.extmath import randomized_svd
import fbpca
from greedy_rpca import greedy_rpca
from rpca_other import R_pca

# Threshold on the relative Frobenius-norm residual; converged() reports
# success once norm(Z, 'fro') / d_norm falls below this value.
TOL=1e-9
# NOTE(review): not referenced by any code visible in this file — confirm
# whether it is still needed (pcp() takes its own maxiter argument).
MAX_ITERS=3

def converged(Z, d_norm):
	"""Report whether the residual ``Z`` is small relative to ``d_norm``.

	The Frobenius norm of ``Z`` divided by ``d_norm`` is printed on every call
	and then compared against the module-level ``TOL``.

	Returns:
		True when the relative residual is strictly below ``TOL``.
	"""
	relative_residual = np.linalg.norm(Z, ord='fro') / d_norm
	print('error: ', relative_residual)
	return relative_residual < TOL

def shrink(M, tau):
	"""Soft-threshold ``M`` element-wise by ``tau``.

	Every entry is moved towards zero by ``tau``; entries whose magnitude does
	not exceed ``tau`` become zero.
	"""
	excess = np.abs(M) - tau
	# Keep only the part of each magnitude that exceeds the threshold.
	magnitude = np.where(excess > 0, excess, 0)
	return np.sign(M) * magnitude

def _svd(M, rank):
	"""Run ``fbpca.pca`` on ``M`` with ``raw=True``, asking for ``rank`` components.

	The requested rank is capped at the smaller dimension of ``M``. Callers in
	this file unpack the result as a ``(u, s, v)`` triple.
	"""
	capped_rank = min(rank, np.min(M.shape))
	return fbpca.pca(M, k=capped_rank, raw=True)

def norm_op(M):
	"""Return the first singular value from a rank-1 ``_svd`` of ``M``.

	``pcp`` uses this value as the operator norm of its input.
	"""
	_, singular_values, _ = _svd(M, 1)
	return singular_values[0]

def svd_reconstruct(M, rank, min_sv):
	"""Rebuild ``M`` from a truncated SVD with singular values shrunk by ``min_sv``.

	Args:
		M: matrix to approximate.
		rank: number of singular triplets requested from ``_svd``.
		min_sv: amount subtracted from every singular value.

	Returns:
		``(approx, kept)``: the reconstruction from the leading ``kept``
		triplets, and ``kept`` itself, the count of singular values that are
		still positive after the subtraction.
	"""
	left, sigma, right = _svd(M, rank)
	# Shrink in place so the array keeps the dtype _svd produced.
	sigma -= min_sv
	kept = (sigma > 0).sum()
	approx = left[:, :kept] @ np.diag(sigma[:kept]) @ right[:kept]
	return approx, kept

def pcp(X, maxiter=10, k=10, example_col=140): # refactored
	"""Decompose ``X`` into a low-rank part ``L`` and a sparse part ``S``.

	Each iteration soft-thresholds the residual to update ``S``, rebuilds ``L``
	from a truncated SVD whose singular values are shrunk by ``1/mu``, and then
	updates the multiplier ``Y`` and the penalty ``mu``. Progress is printed on
	every iteration.

	Args:
		X: 2-D dense array. Inputs with fewer rows than columns are transposed
			internally and the results are transposed back before returning.
		maxiter: maximum number of iterations.
		k: scale factor applied to the initial penalty ``mu`` and to its
			growth rate ``rho``.
		example_col: column of ``X`` whose ``S`` and ``L`` values are recorded
			after every iteration. Nothing is recorded when the index is out
			of range for ``X``.

	Returns:
		``(L, S, examples)``. ``L`` and ``S`` have the shape of ``X``;
		``examples`` is the flat list ``[S_col, L_col, S_col, L_col, ...]``
		holding one pair per completed iteration.
	"""
	m, n = X.shape
	if not np.any(X):
		# An all-zero (or empty) matrix is trivially 0 + 0. Returning here also
		# avoids dividing by a zero operator norm further down.
		return np.zeros(X.shape), np.zeros(X.shape), []

	trans = m < n
	if trans:
		X = X.T
		m, n = X.shape

	# The recorded example is always a column of the caller's matrix, which is
	# a row of the working matrix once that has been transposed.
	n_input_cols = m if trans else n
	record = -n_input_cols <= example_col < n_input_cols

	lamda = 0.001 #1/np.sqrt(m)
	op_norm = norm_op(X)
	# NOTE(review): for a 2-D array np.linalg.norm(X, np.inf) is the maximum
	# absolute row sum, not the largest absolute entry — confirm which of the
	# two this initialisation is meant to use.
	Y = X / max(op_norm, np.linalg.norm(X, np.inf) / lamda)
	mu = k*1.25/op_norm
	mu_bar = mu * 1e7
	rho = k * 1.5

	d_norm = np.linalg.norm(X, 'fro')
	L = np.zeros_like(X)
	S = np.zeros_like(X)  # defined up front so maxiter=0 still returns a pair
	sv = 1

	examples = []

	for _ in range(maxiter):
		print("rank sv:", sv)
		X2 = X + Y/mu

		# update estimate of Sparse Matrix by "shrinking/truncating": original - low-rank
		S = shrink(X2 - L, lamda/mu)

		# update estimate of Low-rank Matrix by doing truncated SVD of rank sv & reconstructing.
		# count of singular values > 1/mu is returned as svp
		L, svp = svd_reconstruct(X2 - S, sv, 1/mu)

		# If svp < sv, enough singular values are already being computed, so
		# ask for just one more than were kept. Otherwise grow the request by
		# 5% of n, and by at least 1 so small matrices can still grow.
		sv = svp + (1 if svp < sv else max(1, round(0.05*n)))

		# residual
		Z = X - L - S
		Y += mu*Z
		# Grow the penalty, but never beyond mu_bar.
		mu = min(mu*rho, mu_bar)

		if record:
			if trans: examples.extend([S[example_col, :], L[example_col, :]])
			else: examples.extend([S[:, example_col], L[:, example_col]])

		if converged(Z, d_norm): break

	if trans:
		L = L.T
		S = S.T
	return L, S, examples

def rpca(M, loss = 'huber'):
	"""Split ``M`` into a ``greedy_rpca`` fit and its residual.

	Args:
		M: input matrix.
		loss: ``'huber'`` selects the Huber loss; any other value selects the
			``'l2'`` loss with ``iter_lim=1``.

	Returns:
		``(UV, S, examples)``: the fit and the residual ``M - UV``, each as
		absolute values clipped to ``[0, 255]``, plus the ``examples`` value
		handed back by ``greedy_rpca``.
	"""
	if loss == 'huber':
		options = {'loss': 'huber'}
	else:
		options = {'loss': 'l2', 'iter_lim': 1}
	fit, examples = greedy_rpca(M, k=3, **options)

	# The residual is taken from the raw fit, before any clipping.
	residual = M - fit
	low_rank = np.clip(np.abs(fit), 0, 255)
	sparse_part = np.clip(np.abs(residual), 0, 255)
	return low_rank, sparse_part, examples

def just_pca(M):
	"""Split ``M`` into a 3-component ``randomized_svd`` reconstruction and its residual.

	Returns:
		``(UV, S)``: the reconstruction and the residual ``M - UV``, each as
		absolute values clipped to ``[0, 255]``.
	"""
	left, singular_values, right_t = randomized_svd(M, n_components=3)
	approx = left @ np.diag(singular_values) @ right_t

	# The residual is taken from the raw reconstruction, before any clipping.
	residual = M - approx
	low_rank = np.clip(np.abs(approx), 0, 255)
	sparse_part = np.clip(np.abs(residual), 0, 255)
	return low_rank, sparse_part

def pcp2(M):
	"""Decompose ``M`` with ``R_pca`` and post-process the two parts it returns.

	Returns:
		``(UV, S)``: the two matrices from ``R_pca(M).fit()``, each as absolute
		values clipped to ``[0, 255]``.
	"""
	solver = R_pca(M)
	first_part, second_part = solver.fit()
	clipped_first = np.clip(np.abs(first_part), 0, 255)
	clipped_second = np.clip(np.abs(second_part), 0, 255)
	return clipped_first, clipped_second
