import numpy as np
from collections import defaultdict
from sklearn.utils.extmath import randomized_svd
from scipy.sparse.linalg import lsqr
from numpy.linalg import lstsq
from scipy.sparse import coo_matrix, csr_matrix
from math import sqrt
from scipy.linalg import orth
from alt import alt 
from joblib import Parallel, delayed
from sklearn.linear_model import LinearRegression
from math import sqrt
import tools

def _pair_index(pairs):
	"""Split an (N, 2) collection of (row, col) pairs into a ``(rows, cols)`` index tuple."""
	transposed = np.transpose(pairs)
	return (transposed[0], transposed[1])


def _compute_UV(U, V, omegaT, rng=None):
	"""Evaluate ``U @ V.T`` only at the observed positions ``omegaT``.

	Parameters
	----------
	U, V : ndarray
		Factor matrices; column ``a`` of each forms the a-th rank-one term.
	omegaT : tuple (rows, cols)
		Index arrays of the positions at which the product is evaluated.
	rng : (lo, hi) or None
		If given, predictions above ``hi`` are set to ``hi``, then
		predictions below ``lo`` are set to ``lo``.

	Returns
	-------
	UV : csr_matrix of shape (U.shape[0], V.shape[0])
		The (optionally clipped) predictions at ``omegaT``; zero elsewhere.
	szs : tuple of float
		For each column ``a``: ``||U[rows, a]|| * ||V[cols, a]||``, a size
		measure of the a-th rank-one term restricted to the observed entries.
	"""
	rows, cols = omegaT

	def column_term(a):
		u_a = U[rows, a]
		v_a = V[cols, a]
		return u_a * v_a, np.linalg.norm(u_a) * np.linalg.norm(v_a)

	# One job per rank-one component, capped at 12 workers.
	terms = Parallel(n_jobs=min(12, U.shape[1]))(
		delayed(column_term)(a) for a in range(U.shape[1]))
	products, szs = zip(*terms)
	# sum() starts from 0, so ``ret`` is a fresh array and the in-place
	# clipping below never writes into U or V.
	ret = sum(products)
	if rng is not None:
		ret[ret > rng[1]] = rng[1]
		ret[ret < rng[0]] = rng[0]
	return csr_matrix((ret, omegaT), shape=(U.shape[0], V.shape[0])), szs


def local_newnew(M, omega, k = 3, omega2=None, method='alt', rng=None, err_type = 'rmse', iter_lim = None, warm_start = None):
	"""Greedy rank-one matrix completion capped at rank ``k``.

	Each iteration:

	1. Takes the leading singular vectors of the residual ``R`` (``M`` minus
	   the current prediction, restricted to the observed entries).
	2. Once ``i >= k``, drops the column whose size measure (see
	   ``_compute_UV``) is smallest, so the rank stays at ``k``.
	3. Appends the new singular vectors as a column of ``U`` and ``V``.
	4. Refits the factors with ``tools.step`` (``method='alt'``: V on even
	   iterations, U on odd ones) or ``tools.optimize`` (``method='opt'``).

	The loop stops the first time the training error increases once
	``i >= k``. If that never happens, the loop does not terminate.

	Parameters
	----------
	M : matrix
		Observed data. It must support ``M[a, b]``, ``M.T``, ``.shape`` and
		subtraction of a ``csr_matrix``. Presumably a scipy sparse matrix;
		TODO confirm against callers.
	omega : sequence of (row, col) pairs
		Observed (training) entries.
	k : int
		Target rank. Must be >= 1 unless ``warm_start`` is given.
	omega2 : sequence of (row, col) pairs or None
		Validation entries. If None, the validation error is recorded as None.
	method : {'alt', 'opt'}
		Refitting strategy (see above).
	rng : (lo, hi) or None
		Clip range applied to predictions.
	err_type : str
		``'rmse'`` selects ``tools.rmse``; anything else selects ``tools.rse``.
	iter_lim :
		Forwarded to ``tools.step`` / ``tools.optimize``.
	warm_start : tuple (i, U, V, UV, szs) or None
		State captured by a previous call; iteration resumes after ``i``.

	Returns
	-------
	U_prev, V_prev : ndarray
		Factors from before the final (error-increasing) iteration.
	errors : dict
		Maps the number of columns to ``(train_error, val_error)``. Once the
		rank reaches ``k``, later iterations overwrite the entry for ``k``.
	warm_start : tuple or None
		State captured at iteration ``k - 1``. If that iteration was not
		reached in this call, this is the ``warm_start`` argument unchanged.

	Raises
	------
	ValueError
		If ``method`` is not 'alt'/'opt', or ``k < 1`` without a warm start.
	"""
	if method not in ('alt', 'opt'):
		raise ValueError("method must be 'alt' or 'opt', got %r" % (method,))
	if warm_start is None and k < 1:
		# Without a warm start, the rank-reduction step would need ``szs``
		# before any iteration has computed it.
		raise ValueError("k must be at least 1, got %r" % (k,))

	U = np.zeros((M.shape[0], 0))
	V = np.zeros((M.shape[1], 0))

	# mask[col] -> observed rows in that column; mask2[row] -> observed cols in that row.
	mask = defaultdict(list)
	mask2 = defaultdict(list)
	for (r, c) in omega:
		mask[c].append(r)
		mask2[r].append(c)
	omegaT = _pair_index(omega)
	omega2T = _pair_index(omega2) if omega2 is not None else None

	compute_error = tools.rmse if err_type == 'rmse' else tools.rse

	errors = {}
	i = -1
	if warm_start is not None:
		i, U, V, UV, szs = warm_start
		train_error = compute_error(U, V, M, omega, rng=rng, omegaT=omegaT, UV=UV)
		R = M - UV
	else:
		empty_UV = csr_matrix((U.shape[0], V.shape[0]))
		train_error = compute_error(U, V, M, omega, rng=rng, omegaT=omegaT, UV=empty_UV)
		# With no model yet, the residual is just M restricted to omega.
		R = csr_matrix(([M[a, b] for (a, b) in omega], omegaT), shape=M.shape)

	while True:
		i += 1
		U_prev = U.copy()
		V_prev = V.copy()
		train_error_prev = train_error
		print('computing sv')
		x, _, yT = randomized_svd(R, 1)

		print('reducing rank')
		if i >= k:
			# At the target rank: drop the weakest component to make room.
			j = np.argmin(szs)
			U = U[:, (np.arange(U.shape[1]) != j)]
			V = V[:, (np.arange(V.shape[1]) != j)]

		print('optimizing')
		U = np.concatenate((U, x), axis=1)
		V = np.concatenate((V, yT.T), axis=1)
		if method == 'alt':
			# Alternate which factor is refit; the other is held fixed.
			if i % 2 == 0:
				V = tools.step(U, M, mask, iter_lim = iter_lim)
			else:
				U = tools.step(V, M.T, mask2, iter_lim = iter_lim)
		else:  # method == 'opt'
			U = tools.optimize(U, V, M, omega, iter_lim = iter_lim)

		print('computing UV')
		UV, szs = _compute_UV(U, V, omegaT, rng=rng)

		if i == k - 1:
			# First time the target rank is reached: snapshot for resuming.
			warm_start = (k - 1, U, V, UV, szs)

		print('computing R')
		R = M - UV

		print('computing error')
		train_error = compute_error(U, V, M, omega, rng=rng, omegaT=omegaT, UV=UV)
		if train_error_prev < train_error and i >= k:
			break
		if omega2 is not None:
			val_error = compute_error(U, V, M, omega2, rng=rng, omegaT=omega2T)
		else:
			val_error = None
		print(U.shape[1], 'error', train_error, val_error)
		errors[U.shape[1]] = (train_error, val_error)
	return U_prev, V_prev, errors, warm_start
