import numpy as np
from collections import defaultdict
from sklearn.utils.extmath import randomized_svd
from scipy.sparse.linalg import lsqr
from numpy.linalg import lstsq
from scipy.sparse import coo_matrix, csr_matrix
from math import sqrt
from scipy.linalg import orth
from joblib import Parallel, delayed
from sklearn.linear_model import LinearRegression
from math import sqrt
import tools

def _reconstruct(U, V, omegaT, rng=None):
	"""Evaluate ``U @ V.T`` on the index pairs in ``omegaT`` and return it as a dense matrix.

	Entries are accumulated one latent column at a time, ``0 + U[:,0]*V[:,0] +
	U[:,1]*V[:,1] + ...``, which is the same summation order as calling ``sum`` over
	the per-column products. No worker processes are involved.

	Args:
		U: (m, r) left factor.
		V: (n, r) right factor.
		omegaT: pair ``(rows, cols)`` of index arrays. It must enumerate every entry
			of the m x n grid in row-major order, because the flat result is reshaped
			to ``(U.shape[0], V.shape[0])``.
		rng: optional ``(low, high)`` pair (tuple, list or array). Entries above
			``high`` are set to ``high``, then entries below ``low`` are set to ``low``.

	Returns:
		The (m, n) reconstruction.
	"""
	rows, cols = omegaT
	ret = np.zeros(len(rows), dtype=np.result_type(U.dtype, V.dtype))
	for a in range(U.shape[1]):
		ret = ret + U[rows, a] * V[cols, a]
	# `is not None` rather than `!= None`: with an ndarray `rng`, `!=` is elementwise
	# and `if` would raise "truth value of an array is ambiguous".
	if rng is not None:
		ret[ret > rng[1]] = rng[1]
		ret[ret < rng[0]] = rng[0]
	return ret.reshape(U.shape[0], V.shape[0])


def greedy_rpca(M, k = 3, method='alt', rng=None, err_type = 'rmse', iter_lim = None, delta = 20, loss='huber', example_col=140):
	"""Greedily build a rank-``k`` factorisation ``U @ V.T`` of ``M`` (robust-PCA style).

	Each iteration does three things:

	1. Take the leading singular pair of the current residual gradient ``R``.
	2. Append it as a new column of ``U`` and ``V``, then hand the factors to
	   ``tools.step`` / ``tools.optimize``. The exact update is defined in ``tools``.
	3. Recompute the reconstruction and ``R`` from it.

	Args:
		M: (m, n) 2-D array to approximate.
		k: number of greedy rank-one components to add.
		method: ``'alt'`` or ``'opt'``.
			``'alt'`` passes the factors to ``tools.step``: V is updated on even
			iterations and U (against ``M.T``) on odd iterations.
			``'opt'`` updates U with ``tools.optimize``.
		rng: optional ``(low, high)`` clamp applied to the reconstruction.
		err_type: currently unused; kept for backward compatibility.
		iter_lim: forwarded to ``tools.step`` / ``tools.optimize``.
		delta: clipping threshold used when ``loss == 'huber'``.
		loss: ``'huber'`` clips the residual ``M - X`` to ``[-delta, delta]``.
			Any other value uses the plain residual ``M - X``.
		example_col: column index whose residual and reconstruction are appended
			to ``examples`` after every iteration. ``None``, or an index outside
			``M``'s columns, disables recording. Previously this was hard-coded to
			140 and crashed on narrower matrices.

	Returns:
		``(UV, examples)``.
		``UV`` is the final (m, n) reconstruction; it is all zeros when ``k == 0``.
		``examples`` is a flat list alternating ``S[:, example_col]`` and
		``UV[:, example_col]`` per iteration, where ``S = M - UV``.

	Raises:
		ValueError: if ``method`` is not ``'alt'`` or ``'opt'``.
	"""
	# Fail fast. Otherwise no factor column is ever added and the error surfaces
	# much later, far from its cause.
	if method not in ('alt', 'opt'):
		raise ValueError("method must be 'alt' or 'opt', got %r" % (method,))

	examples = []
	U = np.zeros((M.shape[0], 0))
	V = np.zeros((M.shape[1], 0))

	# omega / mask / mask2 enumerate every entry of M; they are what the tools
	# helpers below receive.
	omega = [(i, j) for i in range(M.shape[0]) for j in range(M.shape[1])]
	mask = defaultdict(list)   # column j -> row indices i, ascending
	mask2 = defaultdict(list)  # row i -> column indices j, ascending
	for (i, j) in omega:
		mask[j].append(i)
		mask2[i].append(j)
	omegaT = np.transpose(omega)
	omegaT = (omegaT[0], omegaT[1])

	def grad_huber(X):
		return np.clip(M - X, -delta, delta)

	def grad_l2(X):
		return M - X

	grad = grad_huber if loss == 'huber' else grad_l2
	R = grad(np.zeros(M.shape))

	n_cols = M.shape[1]
	record_examples = example_col is not None and -n_cols <= example_col < n_cols

	# Bound up front so that k == 0 returns an empty reconstruction instead of
	# raising UnboundLocalError at the final print.
	UV = np.zeros(M.shape)
	for i in range(k):
		print('computing sv')
		x, _, yT = randomized_svd(R, 1)

		print('optimizing')
		# Both methods first grow the factors by the new singular pair.
		U = np.concatenate((U, x), axis=1)
		V = np.concatenate((V, yT.T), axis=1)
		if method == 'opt':
			U = tools.optimize(U, V, M, omega, iter_lim = iter_lim)
		elif i % 2 == 0:
			V = tools.step(U, M, mask, iter_lim = iter_lim, V=V)
		else:
			U = tools.step(V, M.T, mask2, iter_lim = iter_lim, V=U)

		print('computing R')
		UV = _reconstruct(U, V, omegaT, rng=rng)
		R = grad(UV)

		if record_examples:
			S = M - UV
			examples.extend([S[:, example_col], UV[:, example_col]])
	print(UV, UV.shape)
	return UV, examples
