import warnings

import numpy as np
import numpy.linalg as la
import scipy
import scipy.sparse
import scipy.sparse.linalg
from sklearn.utils.extmath import row_norms, safe_sparse_dot

from optmethods.loss.loss_oracle import Oracle
from optmethods.loss.utils import safe_sparse_add, safe_sparse_multiply, safe_sparse_norm, safe_sparse_inner_prod


class LinearRegression(Oracle):
    """
    Linear regression oracle that returns loss values and gradients.

    The objective implemented here is

        f(x) = 1/2 * ||A x - b||^2 + l2/2 * ||x||^2,

    i.e. the squared residual is summed (not averaged) over the ``n`` rows of
    ``A``. The attributes ``l2``, ``_smoothness`` and the method ``is_equal``
    are presumably provided by the ``Oracle`` base class (not visible here).

    Parameters
    ----------
    A : dense array or scipy.sparse matrix of shape (n, dim)
        Design matrix.
    b : array-like with n entries
        Targets; stored as a flat 1-D array.
    store_mat_vec_prod : bool
        If True, cache the last product ``A @ x`` so that repeated calls with
        the same ``x`` (e.g. value followed by gradient) reuse it.
    """
    def __init__(self, A, b, store_mat_vec_prod=True, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.A = A
        # Flatten so that column-vector targets of shape (n, 1) do not
        # broadcast against the 1-D product A @ x into an (n, n) array.
        self.b = np.asarray(b).ravel()
        self.n, self.dim = A.shape
        self.store_mat_vec_prod = store_mat_vec_prod
        # Cache of the last evaluated point and its product A @ x_last.
        self.x_last = 0.
        self._mat_vec_prod = np.zeros(self.n)

    def _value(self, x):
        """Return 1/2 * ||A x - b||^2 + l2/2 * ||x||^2."""
        residual = self.mat_vec_product(x) - self.b
        regularization = self.l2 / 2 * safe_sparse_norm(x)**2
        return 0.5 * safe_sparse_norm(residual)**2 + regularization

    def gradient(self, x):
        """Return the gradient A^T (A x - b) + l2 * x."""
        residual = self.mat_vec_product(x) - self.b
        return self.A.T @ residual + self.l2 * x

    def mat_vec_product(self, x):
        """
        Return ``A @ x`` as a dense 1-D array.

        When caching is enabled and ``x`` equals the last point (according to
        ``self.is_equal``), the stored product is returned without recomputing.
        NOTE: the cached array itself is returned, so callers must not modify
        it in place.
        """
        if self.store_mat_vec_prod and self.is_equal(x, self.x_last):
            return self._mat_vec_prod
        Ax = self.A @ x
        if scipy.sparse.issparse(Ax):
            # Sparse A times sparse x yields a sparse result; densify it.
            Ax = Ax.toarray()
        Ax = Ax.ravel()
        if self.store_mat_vec_prod:
            self._mat_vec_prod = Ax
            self.x_last = x.copy()
        return Ax

    @property
    def smoothness(self):
        """
        Lipschitz constant of the gradient.

        The Hessian of the objective is ``A^T A + l2 * I``, so its largest
        eigenvalue is ``sigma_max(A)**2 + l2``. The result is cached in
        ``self._smoothness``. For very large matrices the squared Frobenius
        norm, which upper-bounds ``sigma_max(A)**2``, is used instead of a
        truncated SVD.
        """
        if self._smoothness is not None:
            return self._smoothness
        if self.dim > 20000 and self.n > 20000:
            warnings.warn("The matrix is too large to estimate the smoothness constant, so Frobenius estimate is used instead.")
            sing_val_max_sq = self._frobenius_norm_sq()
        elif min(self.n, self.dim) < 2:
            # svds requires k < min(A.shape); for a single row or column the
            # Frobenius norm coincides with the spectral norm.
            sing_val_max_sq = self._frobenius_norm_sq()
        else:
            sing_val_max = scipy.sparse.linalg.svds(self.A, k=1, return_singular_vectors=False)[0]
            sing_val_max_sq = sing_val_max**2
        self._smoothness = sing_val_max_sq + self.l2
        return self._smoothness

    def _frobenius_norm_sq(self):
        """Return ||A||_F^2 for a dense or scipy.sparse ``A``."""
        if scipy.sparse.issparse(self.A):
            return scipy.sparse.linalg.norm(self.A, ord='fro')**2
        return la.norm(self.A, ord='fro')**2

    def max_smoothness(self):
        """
        Return max_i ||a_i||^2 + l2 over the rows a_i of A.

        The Hessian of the per-row term 1/2 (a_i^T x - b_i)^2 is a_i a_i^T,
        whose largest eigenvalue is ||a_i||^2.
        """
        max_squared_sum = row_norms(self.A, squared=True).max()
        return max_squared_sum + self.l2

    def average_smoothness(self):
        """Return mean_i ||a_i||^2 + l2 over the rows a_i of A."""
        ave_squared_sum = row_norms(self.A, squared=True).mean()
        return ave_squared_sum + self.l2

    @staticmethod
    def norm(x, ord=None):
        """Norm of a dense or sparse vector (delegates to safe_sparse_norm)."""
        return safe_sparse_norm(x, ord=ord)

    @staticmethod
    def inner_prod(x, y):
        """Inner product of dense or sparse vectors (delegates to safe_sparse_inner_prod)."""
        return safe_sparse_inner_prod(x, y)
    


