import numpy as np
from sklearn.utils.extmath import randomized_svd
from sklearn.utils import check_array

from fancyimpute import SoftImpute


class MySoftImpute(SoftImpute):
	"""SoftImpute variant that keeps the last SVD step's results on the instance.

	After ``solve`` returns, ``self.rank`` holds the rank reported by the last
	``_svd_step`` call and ``self.X_reconstruction`` holds the last (clipped)
	reconstruction. Both are ``None`` when no iteration was run.
	"""

	def solve(self, X, missing_mask):
		"""Iteratively overwrite the missing entries of ``X`` with SVD reconstructions.

		Parameters
		----------
		X : array-like
			Data matrix; validated with ``check_array`` (non-finite values are
			allowed through).
		missing_mask : boolean mask
			Used to index ``X``; marks the entries that get overwritten.

		Returns
		-------
		The validated array with the masked entries replaced by the values of
		the most recent reconstruction.

		Side effects
		------------
		Sets ``self.rank`` and ``self.X_reconstruction``.
		"""
		# NOTE(review): recent scikit-learn releases rename this keyword to
		# `ensure_all_finite` -- confirm against the pinned version.
		# NOTE(review): check_array is called without copy=True, so the array
		# updated below may be the caller's own object -- confirm with callers.
		X_filled = check_array(X, force_all_finite=False)

		max_singular_value = self._max_singular_value(X_filled)
		if self.verbose:
			print("[SoftImpute] Max Singular Value of X_init = %f" % (
				max_singular_value))

		# Truthiness check kept on purpose: a shrinkage_value of None *or* 0
		# falls back to the heuristic below.
		if self.shrinkage_value:
			shrinkage_value = self.shrinkage_value
		else:
			# totally hackish heuristic: keep only components
			# with at least 1/50th the max singular value
			shrinkage_value = max_singular_value / 50.0

		# Defined up front so that max_iters <= 0 (loop body never runs) does
		# not leave these names unbound.
		rank = None
		X_reconstruction = None
		n_iters = 0

		for n_iters in range(1, self.max_iters + 1):
			X_reconstruction, rank = self._svd_step(
				X_filled,
				shrinkage_value,
				max_rank=self.max_rank)
			X_reconstruction = self.clip(X_reconstruction)

			# Convergence must be evaluated against the previous fill, i.e.
			# before the missing entries are overwritten just below.
			converged = self._converged(
				X_old=X_filled,
				X_new=X_reconstruction,
				missing_mask=missing_mask)
			X_filled[missing_mask] = X_reconstruction[missing_mask]
			if converged:
				break

		if self.verbose:
			print("[SoftImpute] Stopped after iteration %d for lambda=%f" % (
				n_iters,
				shrinkage_value))

		self.rank = rank
		self.X_reconstruction = X_reconstruction
		return X_filled
