import numpy as np
from collections import defaultdict
from sklearn.utils.extmath import randomized_svd
from scipy.sparse.linalg import lsqr
from numpy.linalg import lstsq
from scipy.sparse import coo_matrix, csr_matrix
from math import sqrt
from scipy.linalg import orth
from alt import alt 
from joblib import Parallel, delayed
from sklearn.linear_model import LinearRegression
from math import sqrt
import tools

def _predict_observed(U, V, omegaT, rng=None):
	"""Return the prediction ``U[a] . V[b]`` for every observed pair ``(a, b)``.

	Parameters
	----------
	U, V : array-like, 2-D
		Factor matrices with the same number of columns.
	omegaT : tuple of two index arrays
		``(rows, cols)`` of the observed entries.
	rng : (low, high) or None
		If given, predictions are clipped into ``[low, high]``.

	Returns
	-------
	numpy.ndarray
		1-D array with one prediction per observed pair, in ``omegaT`` order.
	"""
	rows, cols = omegaT
	# Row-wise dot products of the gathered factor rows. This replaces a
	# joblib fan-out over columns whose per-call worker overhead outweighed
	# the trivial per-column work.
	pred = np.sum(np.asarray(U)[rows] * np.asarray(V)[cols], axis=1)
	if rng is not None:
		pred = np.clip(pred, rng[0], rng[1])
	return pred


def greedy(M, omega, k=3, omega2=None, method='alt', rng=None, err_type='rmse', iter_lim=None):
	"""Greedy rank-``k`` matrix completion, adding one rank-one component per step.

	Each step takes the leading singular pair of the current residual on the
	observed training entries and appends it as a new column of ``U`` and
	``V``. It then refits the factors with a helper from ``tools``.

	Parameters
	----------
	M : 2-D matrix
		Data matrix supporting ``M[a, b]`` indexing and ``M.shape``
		(presumably a dense array or scipy sparse matrix; confirm with callers).
		It may also hold validation entries. Only the ``omega`` entries are
		used to build the residual.
	omega : sequence of (row, col) pairs
		Observed training entries. Must be re-iterable.
	k : int
		Number of greedy rank-one steps (columns appended to ``U``/``V``).
	omega2 : sequence of (row, col) pairs or None
		Validation entries. When None, no validation error is computed.
	method : {'alt', 'opt'}
		``'alt'`` alternates between refitting V (even steps, via
		``tools.step`` on ``M``) and U (odd steps, via ``tools.step`` on
		``M.T``). ``'opt'`` refits U via ``tools.optimize``.
	rng : (low, high) or None
		Optional clipping range applied to the predictions.
	err_type : str
		``'rmse'`` selects ``tools.rmse``. Any other value selects ``tools.rse``.
	iter_lim : int or None
		Forwarded to ``tools.step`` / ``tools.optimize``.

	Returns
	-------
	U, V : matrices
		Factor matrices, as produced by the last refit.
	errors : dict
		Maps step number ``1..k`` to ``(train_error, val_error)``.
		``val_error`` is None when ``omega2`` is None.

	Raises
	------
	ValueError
		If ``method`` is neither ``'alt'`` nor ``'opt'``.
	"""
	if method not in ('alt', 'opt'):
		raise ValueError("method must be 'alt' or 'opt', got %r" % (method,))

	U = np.zeros((M.shape[0], 0))
	V = np.zeros((M.shape[1], 0))

	# mask: column j -> observed rows (passed to tools.step when refitting V).
	# mask2: row i -> observed columns (passed when refitting U on M.T).
	mask = defaultdict(list)
	mask2 = defaultdict(list)
	for (i, j) in omega:
		mask[j].append(i)
		mask2[i].append(j)

	omegaT = np.transpose(omega)
	omegaT = (omegaT[0], omegaT[1])
	omega2T = None
	if omega2 is not None:
		omega2T = np.transpose(omega2)
		omega2T = (omega2T[0], omega2T[1])

	# Observed training values, gathered once. np.asarray(...).ravel() turns
	# both ndarray results and the 1xN matrix that scipy sparse fancy
	# indexing returns into a flat vector.
	observed = np.asarray(M[omegaT]).ravel()
	# The residual is kept sparse and restricted to omega throughout.
	R = csr_matrix((observed, omegaT), shape=M.shape)

	compute_error = tools.rmse if err_type == 'rmse' else tools.rse

	errors = {}
	for i in range(k):
		print('computing sv')
		# Leading singular pair of the residual. The singular value is
		# dropped; the refit below presumably absorbs the scale (confirm in tools).
		x, _, yT = randomized_svd(R, 1)

		print('optimizing')
		U = np.concatenate((U, x), axis=1)
		V = np.concatenate((V, yT.T), axis=1)
		if method == 'alt':
			if i % 2 == 0:
				V = tools.step(U, M, mask, iter_lim=iter_lim, V=V)
			else:
				U = tools.step(V, M.T, mask2, iter_lim=iter_lim, V=U)
		else:  # method == 'opt'
			U = tools.optimize(U, V, M, omega, iter_lim=iter_lim)

		print('computing R')
		pred = _predict_observed(U, V, omegaT, rng=rng)
		UV = csr_matrix((pred, omegaT), shape=(U.shape[0], V.shape[0]))
		# Residual on the training entries only. Computing M - UV would also
		# carry every entry of M outside omega (e.g. validation entries) into
		# the next SVD.
		R = csr_matrix((observed - pred, omegaT), shape=M.shape)

		print('computing error')
		train_error = compute_error(U, V, M, omega, rng=rng, omegaT=omegaT, UV=UV)
		val_error = None
		if omega2T is not None:
			val_error = compute_error(U, V, M, omega2, rng=rng, omegaT=omega2T)
		print(i+1, 'error', train_error, val_error)
		errors[i+1] = (train_error, val_error)
	return U, V, errors
