import numpy as np
from random_projection import RandomProjection as rp
from regression_solvers import iterative_lasso, Ls_solver
from regression_solvers import iterative_lasso_step_size
from time import process_time
from numba import njit, jit
from scipy import sparse
from scipy.sparse import coo_matrix
import datetime
from timeit import default_timer as timer
import random
import cvxpy as cp
import copy

class IHS:
    '''Implementation of the iterative hessian sketching scheme of
    Pilanci and Wainwright (https://arxiv.org/pdf/1411.0347.pdf)

    The solvers keep their current iterate in ``self.x`` and return it.
    Methods whose name ends in ``_learn`` replace the random sketch for
    some iterations with sketch matrices loaded from ``.npy`` files in the
    working directory.
    '''

    def __init__(self, data, targets, sketch_method, sketch_dimension,
                 col_sparsity=1, flag=False):
        '''Store the problem data and build the random sketching operator.

        Parameters
        ----------
        data : np.ndarray or scipy sparse matrix
            Design matrix ``A``; its shape is unpacked as ``(n, d)``.
        targets : array
            Right-hand side ``b``.
        sketch_method, sketch_dimension, col_sparsity :
            Forwarded to ``RandomProjection`` to construct ``self.sketcher``.
        flag : bool
            Dense ``data`` only. When true, ``A.T @ b`` is not precomputed
            and ``self.ATb`` is set to None. Methods that read ``self.ATb``
            (the OLS and lasso solvers) cannot be used in that case.
        '''
        # optimisation setup
        self.A = data
        self.b = targets

        if isinstance(self.A, np.ndarray):
            # NOTE(review): presumably `flag` lets callers pass targets whose
            # shape is incompatible with A.T; confirm with callers.
            self.ATb = None if flag else self.A.T @ self.b
        else:
            # Sparse design matrix: use the sparse product.
            self.ATb = sparse.csr_matrix.dot(self.A.T, self.b)

        self.n, self.d = self.A.shape
        self.x = np.zeros((self.d, ))  # initialise the starting point.
        self.sketch_method = sketch_method
        self.sketch_dimension = sketch_dimension
        self.col_sparsity = col_sparsity
        # initialise the sketch to avoid the repeated costs
        self.sketcher = rp(self.A, self.sketch_dimension,
                           self.sketch_method, self.col_sparsity)
        self.coo_data = coo_matrix(data)
        self.rows = self.coo_data.row
        self.cols = self.coo_data.col
        self.vals = self.coo_data.data

    def _ols_gradient(self, ATb):
        '''Return ``ATb - A^T (A x)`` evaluated at the current ``self.x``.'''
        return ATb - self.A.T @ (self.A @ self.x)

    ############# OLS (VANILLA) ##########################
    def ols_fit_new_sketch(self, iterations):
        '''Solve the ordinary least squares problem iteratively using ihs
            generating a fresh sketch at every iteration.

            Each step solves ``(S^T S) u = A^T b - A^T A x`` and sets
            ``x <- x + u``, where ``S`` is the sketch returned by
            ``self.sketcher.sketch()``. Returns ``self.x``.'''
        for _ in range(iterations):
            _sketch = self.sketcher.sketch()
            H = _sketch.T @ _sketch
            self.x = np.linalg.solve(H, self._ols_gradient(self.ATb)) + self.x
        return self.x

    def ols_fit_new_sketch_learn(self, iterations):
        '''Solve the ordinary least squares problem iteratively using ihs
            with a learned sketch read from ``ghg300.npy``.

            The loaded matrix is multiplied by ``A`` to form the sketched
            matrix. The same sketch is used at every iteration, so it is
            loaded (and the Hessian formed) once. It is loaded lazily, so
            ``iterations <= 0`` does not touch the file. Returns ``self.x``.'''
        H = None
        for _ in range(iterations):
            if H is None:
                SA = np.load("ghg30" + str(0) + ".npy") @ self.A
                H = SA.T @ SA
            self.x = np.linalg.solve(H, self._ols_gradient(self.ATb)) + self.x
        return self.x

    def ols_fit_one_sketch(self, iterations):
        '''Solve the ordinary least squares problem iteratively using ihs
            with a single sketch drawn once and reused at every iteration.

            This needs a larger sketch than if we generate a fresh sketch
            for every iteration. Returns ``self.x``.'''
        _sketch = self.sketcher.sketch()
        H = _sketch.T @ _sketch
        for _ in range(iterations):
            self.x = np.linalg.solve(H, self._ols_gradient(self.ATb)) + self.x
        return self.x

    ############# OLS WITH ERROR TRACKING ##########################
    def ols_fit_new_sketch_track_errors(self, iterations):
        '''Solve the ordinary least squares problem iteratively using ihs
            generating a fresh sketch at every iteration and recording the
            iterate after every iteration.

            The iterate is reset to zero first.

            Returns
            -------
            (x, sol_errors)
                The final iterate, and a ``(d, iterations)`` array whose
                column ``ii`` is the iterate after iteration ``ii``. This
                array is also stored in ``self.sol_errors``.
            '''
        self.sol_errors = np.zeros((self.d, iterations))
        self.x = np.zeros((self.d,))
        # NOTE(review): `.T[0]` assumes ATb is 2-D (d x 1), i.e. the targets
        # were a column vector. The other OLS methods use ATb directly, which
        # suits a 1-D ATb instead. Confirm which shape callers pass.
        ATb = self.ATb.T[0]
        for ii in range(iterations):
            _sketch = self.sketcher.sketch()
            H = _sketch.T @ _sketch
            self.x = np.linalg.solve(H, self._ols_gradient(ATb)) + self.x
            self.sol_errors[:, ii] = self.x
        return self.x, self.sol_errors

    def ols_fit_one_sketch_track_errors(self, iterations, step_size):
        '''Solve the ordinary least squares problem iteratively using ihs
            with a single sketch reused at every iteration. The right-hand
            side of each step is scaled by ``step_size``.

            This needs a larger sketch than if we generate a fresh sketch
            for every iteration.

            Side effects: sets the following attributes.

            * ``self.frob_error``: relative Frobenius error of ``S^T S``
              against ``A^T A``.
            * ``self.spectral_error``: relative error between the top
              singular values of ``S`` and ``A``.
            * ``self.sol_errors``: the per-iteration iterates.

            Returns ``(x, sol_errors)``.
            '''
        self.sol_errors = np.zeros((self.d, iterations))
        _sketch = self.sketcher.sketch()
        H = _sketch.T @ _sketch
        cov_mat = self.A.T @ self.A

        # Frobenius and pointwise spectral guarantee. Only the singular
        # values are needed; compute_uv=False avoids building the (n x n) U
        # factor of a tall A.
        Ss = np.linalg.svd(_sketch, compute_uv=False)
        SigmaA = np.linalg.svd(self.A, compute_uv=False)
        self.frob_error = np.linalg.norm(H - cov_mat, ord='fro') / np.linalg.norm(cov_mat, ord='fro')
        self.spectral_error = np.abs(SigmaA[0] - Ss[0]) / SigmaA[0]

        for ii in range(iterations):
            grad_term = step_size * self._ols_gradient(self.ATb)
            self.x = np.linalg.solve(H, grad_term) + self.x
            self.sol_errors[:, ii] = self.x
        return self.x, self.sol_errors

    ############# LASSO ##########################

    def lasso_fit_new_sketch(self, iterations, ell1Bound):
        '''
            fit the lasso model with the ell1Bound constraint, drawing a
            fresh sketch for each call to ``iterative_lasso``.'''
        for _ in range(iterations):
            _sketch = self.sketcher.sketch()
            self.x = iterative_lasso(_sketch, self.ATb, self.A,
                                     self.b, self.x, ell1Bound)
        return self.x

    def lasso_fit_new_sketch_learn(self, iterations, ell1Bound):
        '''
            fit the lasso model with the ell1Bound constraint.

            For the first ``len(pos)`` iterations, the sketch is a learned
            matrix loaded from ``u_ske//co45<k>.npy`` and multiplied by ``A``.
            Afterwards, random sketches are used.'''
        pos = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        for ii in range(iterations):
            # `ii < len(pos)`: the previous `ii <= 10` indexed past the end
            # of `pos` (IndexError at ii == 10).
            if ii < len(pos):
                _sketch = np.load("u_ske//co45" + str(pos[ii]) + ".npy") @ self.A
            else:
                _sketch = self.sketcher.sketch()
            self.x = iterative_lasso(_sketch, self.ATb, self.A,
                                     self.b, self.x, ell1Bound)
        return self.x

    def lasso_fit_new_sketch_track_errors(self, ell1Bound, iterations):
        '''
            fit the lasso model with the ell1bound constraint
            and return the iterates for error check.

            Returns ``(x, sol_errors)``, where column ``ii`` of the
            ``(d, iterations)`` array is the iterate after iteration ``ii``.'''
        print(f'Using X{self.A.shape},y{self.b.shape}')
        print('Using ell1Bound = ', ell1Bound)
        self.sol_errors = np.zeros((self.d, iterations))
        for ii in range(iterations):
            _sketch = self.sketcher.sketch()
            self.x = iterative_lasso(_sketch, self.ATb, self.A,
                                     self.b, self.x, ell1Bound)
            self.sol_errors[:, ii] = self.x
        return self.x, self.sol_errors

    def lasso_fit_new_sketch_speedup(self, ell1Bound, iterations):
        '''
            fit the lasso model with the ell1bound constraint for a fixed
            number of iterations and measure the wall-clock time taken.

            Returns ``(x, t_elapsed)``, with ``t_elapsed`` in seconds as
            measured by ``timeit.default_timer``.
            '''
        self.sol_errors = np.zeros((self.d, 1))
        start_time = timer()
        for _ in range(iterations):
            _sketch = self.sketcher.sketch()
            self.x = iterative_lasso(_sketch, self.ATb, self.A,
                                     self.b, self.x, ell1Bound)
        t_elapsed = timer() - start_time
        return self.x, t_elapsed

    def lasso_fit_new_sketch_timing(self, ell1Bound, timeBound):
        '''
            fit the lasso model with the ell1bound constraint,
            iterating until ``timeBound`` seconds of wall-clock time have
            passed.

            Returns ``(x, iterations)``, where ``iterations`` is the number
            of sketched steps completed.
            '''
        print("RUNNING FOR {} SECONDS".format(timeBound))
        iterations = 0
        self.sol_errors = np.zeros((self.d, 1))
        endTime = datetime.datetime.now() + datetime.timedelta(seconds=timeBound)
        while datetime.datetime.now() < endTime:
            iterations += 1
            _sketch = self.sketcher.sketch()
            self.x = iterative_lasso(_sketch, self.ATb, self.A,
                                     self.b, self.x, ell1Bound)
        # Previously placed after `break` and therefore never printed.
        print("IHS ran for {} seconds".format(timeBound))
        return self.x, iterations

    def lasso_fit_one_sketch_track_errors(self, ell1Bound, iterations, step_size=1.0):
        '''
            Fit the Lasso model with a single sketch and step size equal to
            step_size to descend to optimum.

            Returns ``(x, sol_errors)``, with the per-iteration iterates as
            columns of ``sol_errors``.
            '''
        _sketch = self.sketcher.sketch()
        self.sol_errors = np.zeros((self.d, iterations))
        for ii in range(iterations):
            # NOTE(review): step_size is passed to iterative_lasso_step_size
            # AND multiplies its result. Unless step_size == 1.0, this scales
            # the whole iterate. Confirm this is intended.
            self.x = step_size * (iterative_lasso_step_size(_sketch, self.ATb, self.A,
                                                            self.b, self.x, ell1Bound, step_size))
            self.sol_errors[:, ii] = self.x
        return self.x, self.sol_errors

    def svm_fit_new_sketch(self, iterations):
        '''
            Run sketched iterations for a simplex-constrained quadratic.

            Starting from zero, each iteration draws a fresh sketch ``S`` and
            solves, with cvxpy,

                min_x ||S (x - x_t)||^2 + (A^T A x_t)^T x
                s.t.  x >= 0, sum(x) == 1.

            NOTE(review): this appears to target min ||A x||^2 over the
            probability simplex; confirm against callers.

            Returns the final iterate, also stored in ``self.x``.'''
        tmp = np.zeros(self.A.shape[1])

        for _ in range(iterations):
            _sketch = self.sketcher.sketch()

            d = _sketch.shape[1]
            x = cp.Variable(d)

            # A^T (A x_t) without forming the d x d Gram matrix.
            f = self.A.T @ (self.A @ tmp)
            objective = cp.Minimize(cp.sum_squares(_sketch @ (x - tmp)) + (f.T @ x))
            constraints = [0 <= x, cp.sum(x) == 1]
            prob = cp.Problem(objective, constraints)
            prob.solve()

            # Copy in place; after the first iteration self.x aliases tmp.
            tmp[:] = x.value[:tmp.shape[0]]
            self.x = tmp

        return self.x

    def svm_fit_new_sketch_learn(self, iterations):
        '''
            Same subproblem as ``svm_fit_new_sketch``, with two differences.

            * It starts from the current ``self.x`` rather than from zero.
            * For ``ii <= 10``, the sketch is a learned matrix loaded from
              ``1number<ii>.npy`` and multiplied by ``A``.

            Returns the final iterate, also stored in ``self.x``.'''
        tmp = self.x
        for ii in range(iterations):
            if ii <= 10:
                # NOTE(review): this random sketch is discarded immediately.
                # It is kept so the sketcher's state advances exactly as
                # before; drop it if that does not matter.
                _sketch = self.sketcher.sketch()
                _sketch = np.load("1number" + str(ii) + ".npy") @ self.A
            else:
                _sketch = self.sketcher.sketch()
            d = _sketch.shape[1]
            x = cp.Variable(d)
            # A^T (A x_t) without forming the d x d Gram matrix.
            f = self.A.T @ (self.A @ tmp)
            objective = cp.Minimize(cp.sum_squares(_sketch @ (x - tmp)) + (f.T @ x))
            constraints = [0 <= x, cp.sum(x) == 1]
            prob = cp.Problem(objective, constraints)
            prob.solve()

            tmp = x.value
            self.x = tmp
        return self.x

    def nuclear_fit_new_sketch(self, iterations, cons, B):
        '''
            Sketched iterations for a nuclear-norm-constrained matrix
            regression with right-hand side ``B``.

            Starting from zero, each iteration draws a fresh sketch ``S`` and
            solves, with cvxpy,

                min_X 0.5 ||S (X - X_t)||_F^2 - <A^T (B - A X_t), X>
                s.t.  ||X||_* <= 10, ||X||_F <= 100.

            NOTE(review): ``cons`` is currently unused; the bounds above are
            hard-coded.

            Returns the final iterate, also stored in ``self.x``.'''
        tmp = np.zeros((self.A.shape[1], B.shape[1]))
        for _ in range(iterations):
            _sketch = self.sketcher.sketch()
            d = _sketch.shape[1]
            x = cp.Variable((d, B.shape[1]))
            f = (self.A.T) @ (B - self.A @ tmp)
            objective = cp.Minimize(0.5 * cp.sum_squares(_sketch @ (x - tmp)) - cp.sum(cp.multiply(f, x)))
            constraints = [cp.norm(x, 'nuc') <= 10, cp.norm(x, 'fro') <= 100]
            prob = cp.Problem(objective, constraints)
            prob.solve()

            tmp = x.value
            self.x = tmp

        return self.x

    def nuclear_fit_new_sketch_learn(self, iterations, cons, B):
        '''
            Like ``nuclear_fit_new_sketch``, with two differences.

            * Only the nuclear-norm constraint (``<= 10``) is used.
            * For ``ii <= 10``, the sketch is a learned matrix loaded from
              ``gas<ii>.npy`` and multiplied by ``A``.

            NOTE(review): ``cons`` is currently unused.

            Returns the final iterate, also stored in ``self.x``.'''
        tmp = np.zeros((self.A.shape[1], B.shape[1]))
        for ii in range(iterations):
            if ii <= 10:
                # NOTE(review): discarded random sketch, kept so the
                # sketcher's state advances exactly as before.
                _sketch = self.sketcher.sketch()
                _sketch = np.load("gas" + str(ii) + ".npy") @ self.A
            else:
                _sketch = self.sketcher.sketch()

            d = _sketch.shape[1]
            x = cp.Variable((d, B.shape[1]))
            f = (self.A.T) @ (B - self.A @ tmp)
            objective = cp.Minimize(0.5 * cp.sum_squares(_sketch @ (x - tmp)) - cp.sum(cp.multiply(f, x)))
            constraints = [cp.norm(x, 'nuc') <= 10]
            prob = cp.Problem(objective, constraints)
            prob.solve()

            tmp = x.value
            self.x = tmp

        return self.x

    def fastls_fit(self, flag, ii=None):
        '''Approximate least squares with a sketch-based preconditioner.

        The QR factor ``R`` of a fresh sketch is inverted and used as a right
        preconditioner. A few gradient steps are then taken from zero: step
        size 1.0 first, then 0.2. Only one step is taken when
        ``ii == flag``; otherwise three are taken.

        Returns the final iterate mapped back through the inverted factor
        (``inv(R) @ x``).
        '''
        _sketch = self.sketcher.sketch()

        A = self.A
        d = A.shape[1]
        y = self.b
        # Only the triangular factor is needed for preconditioning.
        _, R = np.linalg.qr(_sketch)
        R = np.linalg.inv(R)

        T = (R.T @ A.T @ A @ R)

        x = np.zeros((d, 1))
        num_steps = 1 if ii == flag else 3
        for i in range(num_steps):
            eta = 1.0 if i == 0 else 0.2
            # NOTE(review): `R.T @ y` needs len(y) == d. Confirm the shape
            # of self.b expected by this solver.
            x = x - eta * (T.T @ (T @ x - R.T @ y))

        return R @ x


