import numpy as np
from collections import defaultdict
from sklearn.utils.extmath import randomized_svd
from scipy.sparse.linalg import lsqr
from numpy.linalg import lstsq
from scipy.sparse import coo_matrix, csr_matrix
from math import sqrt
from scipy.linalg import orth
from alt import alt 
from joblib import Parallel, delayed
from sklearn.linear_model import LinearRegression
from math import sqrt
from fancyimpute import KNN, NuclearNormMinimization, SoftImpute, BiScaler
import tools
from mysoftimpute import MySoftImpute

def masked_train_test(M, omega, omega2):
	"""Densify ``M`` into a training copy and a held-out (evaluation) copy.

	Parameters
	----------
	M : sparse matrix
		Anything exposing ``.toarray()``.
	omega : iterable of (row, col) pairs
		Entries kept in the training copy; every other cell becomes NaN.
	omega2 : iterable of (row, col) pairs
		Entries kept in the held-out copy; every other cell becomes NaN.
		Pairs are assumed to be non-negative, in-range indices.

	Returns
	-------
	(train, test) : tuple of numpy.ndarray
		Two independent float64 copies of ``M.toarray()``, masked as above.

	Raises
	------
	TypeError
		If ``omega2`` is None.
	"""
	if omega2 is None:
		raise TypeError("omega2 (the held-out (row, col) entries to score on) is required")
	# float64 so NaN can be stored even when M has an integer dtype.
	dense = np.asarray(M.toarray(), dtype=np.float64)

	def keep_only(entries):
		# Copy of `dense` with NaN everywhere except at the given (row, col) pairs.
		# reshape(-1, 2) keeps an empty `entries` well-formed (shape (0, 2)).
		idx = np.asarray(list(entries), dtype=np.intp).reshape(-1, 2)
		keep = np.zeros(dense.shape, dtype=bool)
		keep[idx[:, 0], idx[:, 1]] = True
		out = dense.copy()
		out[~keep] = np.nan
		return out

	return keep_only(omega), keep_only(omega2)


def knn(M, omega, k=3, omega2=None, method='alt'):
	"""Impute ``M`` with fancyimpute's ``KNN`` and print the error on ``omega2``.

	Parameters
	----------
	M : sparse matrix
		Source matrix; only its ``omega`` entries are shown to the imputer.
	omega : iterable of (row, col) pairs
		Observed (training) entries.
	k : int
		Neighbour count passed as ``KNN(k=k)``.
	omega2 : iterable of (row, col) pairs
		Held-out entries the error is measured on (required despite the None default).
	method : str
		Unused; kept for signature compatibility.

	Raises
	------
	TypeError
		If ``omega2`` is None.
	"""
	M_train, M_test = masked_train_test(M, omega, omega2)
	M_filled = KNN(k=k).fit_transform(M_train)
	# M_test is NaN outside omega2, so nanmean averages over the held-out cells only.
	knn_mse = np.nanmean((M_filled - M_test) ** 2)
	# NOTE: the printed value is sqrt(MSE), i.e. an RMSE, despite the "MSE" label.
	print("knnImpute MSE: %f" % sqrt(knn_mse))


def nuclear(M, omega, k=3, omega2=None, method='alt'):
	"""Impute ``M`` with fancyimpute's ``NuclearNormMinimization`` and print the error on ``omega2``.

	Parameters
	----------
	M : sparse matrix
		Source matrix; only its ``omega`` entries are shown to the solver.
	omega : iterable of (row, col) pairs
		Observed (training) entries.
	k : int
		Unused; kept for signature parity with ``knn``.
	omega2 : iterable of (row, col) pairs
		Held-out entries the error is measured on (required despite the None default).
	method : str
		Unused; kept for signature compatibility.

	Raises
	------
	TypeError
		If ``omega2`` is None.
	"""
	M_train, M_test = masked_train_test(M, omega, omega2)
	M_filled = NuclearNormMinimization().fit_transform(M_train)
	# M_test is NaN outside omega2, so nanmean averages over the held-out cells only.
	mse = np.nanmean((M_filled - M_test) ** 2)
	# NOTE: the printed value is sqrt(MSE), i.e. an RMSE, despite the "MSE" label.
	print("nuclearImpute MSE: %f" % sqrt(mse))


def soft(M, omega, rng=None, k=3, omega2=None, err_type='rmse', to_avoid=None):
	"""Complete ``M`` with MySoftImpute, optionally tuning the shrinkage to reach rank ``k``.

	When ``k`` is not None the shrinkage value is bracketed and then bisected
	until the fitted rank equals ``k`` or the bracket is narrower than 1e-9.
	During that search, the first fit producing each rank (except ranks in
	``to_avoid``) has its errors recorded. The final fit always records its
	errors, overwriting any earlier entry for the same rank.

	Parameters
	----------
	M : sparse matrix
		Source matrix; ``M.toarray()`` is the dense data.
	omega : sequence of (row, col) pairs
		Training entries; every other cell is set to NaN before fitting.
	rng : optional
		Forwarded unchanged to the error function.
	k : int or None
		Target rank. None means a single fit with MySoftImpute's default shrinkage.
	omega2 : sequence of (row, col) pairs or None
		Held-out entries for the second error. When None, that error is None.
	err_type : str
		``'rmse'`` selects ``tools.rmse``; any other value selects ``tools.rse``.
	to_avoid : container or None
		Ranks whose errors are not recorded during the search.

	Returns
	-------
	(M_filled, V, errors)
		The final reconstruction, ``np.eye(M.shape[1])``, and a dict mapping
		rank -> (train_error, heldout_error_or_None).
	"""
	compute_error = tools.rmse if err_type == 'rmse' else tools.rse
	# None instead of a shared mutable `{}` default; only membership is tested.
	if to_avoid is None:
		to_avoid = ()
	errors = {}
	V = np.eye(M.shape[1])

	# (rows, cols) index tuples, used for fancy indexing and by the error functions.
	omegaT = np.transpose(omega)
	omegaT = (omegaT[0], omegaT[1])
	omega2T = None
	# `is not None`: `!= None` is elementwise (and ambiguous in `if`) for ndarrays.
	if omega2 is not None:
		omega2T = np.transpose(omega2)
		omega2T = (omega2T[0], omega2T[1])

	# Hide every cell outside omega (vectorised instead of a per-cell Python loop).
	M_full = M.toarray().astype(np.float64)
	observed = np.zeros(M_full.shape, dtype=bool)
	observed[omegaT] = True
	M_full[~observed] = np.nan

	def error_pair(M_filled):
		# (train error, held-out error); the held-out error is None without omega2.
		train_err = compute_error(M_filled, V, M, omega, omegaT=omegaT, rng=rng)
		if omega2 is None:
			return train_err, None
		return train_err, compute_error(M_filled, V, M, omega2, omegaT=omega2T, rng=rng)

	def soft_get(alpha):
		# Fit at shrinkage `alpha`, record errors for a not-yet-seen rank, return the rank.
		msi = MySoftImpute(shrinkage_value=alpha, verbose=False)
		msi.fit_transform(M_full)
		if msi.rank not in errors and msi.rank not in to_avoid:
			errors[msi.rank] = error_pair(msi.X_reconstruction)
		print('at rank:', msi.rank)
		return msi.rank

	if k is not None:
		alpha_l = alpha_r = 1.0
		# The search treats rank as non-increasing in shrinkage (assumption of this
		# bisection). Widen the bracket until rank(alpha_l) >= k >= rank(alpha_r).
		# NOTE(review): the first loop never terminates if no shrinkage reaches
		# rank k (e.g. k above the attainable rank) -- consider validating k.
		while soft_get(alpha_l) < k:
			alpha_l /= 2.0
		while soft_get(alpha_r) > k:
			alpha_r *= 2.0
		while alpha_r - alpha_l > 1e-9:
			alpha_m = (alpha_l + alpha_r) / 2.0
			rank_m = soft_get(alpha_m)
			if rank_m < k:
				alpha_r = alpha_m
			elif rank_m > k:
				alpha_l = alpha_m
			else:
				alpha_l = alpha_r = alpha_m
		msi = MySoftImpute(shrinkage_value=alpha_l, verbose=False)
	else:
		msi = MySoftImpute(verbose=True)
	msi.fit_transform(M_full)
	M_filled = msi.X_reconstruction
	errors[msi.rank] = error_pair(M_filled)
	return M_filled, V, errors
