import numpy as np
from random_projection import RandomProjection as rp
from regression_solvers import iterative_lasso, Ls_solver
from regression_solvers import iterative_lasso_step_size
import time
#from random_projection import fast_countSketch
from time import process_time
from numba import njit, jit
from scipy import sparse
from scipy.sparse import coo_matrix
#from scipy.sparse import coo_matrix
import datetime
from timeit import default_timer as timer
import random
import cvxpy as cp
import copy

@njit(fastmath=True)
def fast_countSketch(SA, row, col, data, sign, row_map):
    '''Accumulate a CountSketch of a COO-format matrix into ``SA`` in place.

    For every stored entry k (position ``row[k]``, ``col[k]``, value
    ``data[k]``), the value times ``sign[row[k]]`` is added to
    ``SA[row_map[row[k]], col[k]]``. The mutated ``SA`` is also returned.
    '''
    for k in range(len(data)):
        src_row = row[k]
        SA[row_map[src_row], col[k]] += data[k] * sign[src_row]
    return SA

class IHS:
    '''Implementation of the iterative hessian sketching scheme of
    Pilanci and Wainwright (https://arxiv.org/pdf/1411.0347.pdf)

    The instance stores the problem data (``A``, ``b``, ``A^T b``), a reusable
    sketcher and the current iterate ``self.x``. Each ``*_fit_*`` method runs
    a number of sketched iterations and leaves its final iterate in
    ``self.x``. Some methods start from the current ``self.x``; others start
    from zero (see each docstring).
    '''

    def __init__(self, data, targets, sketch_method, sketch_dimension,
                 col_sparsity=1, flag = False):
        '''Store the problem, precompute ``A^T b`` and build the sketcher.

        Parameters
        ----------
        data : dense ``np.ndarray`` or scipy sparse matrix
            Design matrix ``A`` with ``n`` rows and ``d`` columns.
        targets : array
            Right-hand side ``b``.
        sketch_method, sketch_dimension, col_sparsity :
            Forwarded to ``RandomProjection`` to build ``self.sketcher``.
        flag : bool, default False
            Dense ``data`` only. When true, ``A^T b`` is not computed and
            ``self.ATb`` is None. The nuclear-norm fits take their own
            right-hand side ``B`` and never read ``self.ATb``.
        '''
        # optimisation setup
        self.A = data
        self.b = targets

        # Need to deal with sparse type
        if isinstance(self.A, np.ndarray):
            # Always define the attribute, even when the product is skipped.
            self.ATb = None if flag else self.A.T @ self.b
        else:
            self.ATb = sparse.csr_matrix.dot(self.A.T, self.b)

        self.n, self.d = self.A.shape
        self.x = np.zeros((self.d, ))  # starting point of the iteration
        self.sketch_method = sketch_method
        self.sketch_dimension = sketch_dimension
        self.col_sparsity = col_sparsity
        # Build the sketcher once so repeated sketch() calls avoid setup cost.
        self.sketcher = rp(self.A, self.sketch_dimension,
                           self.sketch_method, self.col_sparsity)
        # COO triplets of A. Presumably used by fast_countSketch; TODO confirm.
        self.coo_data = coo_matrix(data)
        self.rows = self.coo_data.row
        self.cols = self.coo_data.col
        self.vals = self.coo_data.data

    ############# OLS (VANILLA) ##########################

    def _ols_gradient(self, step_size=1.0):
        '''Return ``step_size * (A^T b - A^T A x)`` as a 1-D vector of length d.

        A column-shaped target gives ``ATb`` of shape ``(d, 1)``. Its (first)
        column is used so that the result is ``(d,)``; otherwise
        ``(d, 1) - (d,)`` would broadcast to ``(d, d)``. A 1-D ``ATb`` is used
        as-is.
        '''
        atb = np.asarray(self.ATb)
        if atb.ndim == 2:
            atb = atb[:, 0]
        return step_size * (atb - self.A.T @ (self.A @ self.x))

    def ols_fit_new_sketch(self, iterations):
        '''Solve OLS with IHS, drawing a fresh sketch at every iteration.

        Each step solves ``H u = A^T b - A^T A x``, where ``H = S^T S`` and
        ``S = self.sketcher.sketch()``. It then sets ``x <- x + u``. The
        sketcher is assumed to return the sketched matrix ``SA``; confirm
        against RandomProjection. Starts from ``self.x``.

        Returns the final iterate, which is also stored in ``self.x``.
        '''
        for _ in range(iterations):
            sketched = self.sketcher.sketch()
            hessian = sketched.T @ sketched
            self.x = np.linalg.solve(hessian, self._ols_gradient()) + self.x
        return self.x

    def ols_fit_new_sketch_learn(self, iterations):
        '''Solve OLS with IHS using a precomputed (learned) sketch matrix.

        The sketch is read from ``ghg300.npy`` in the working directory. The
        same sketch is applied to ``A`` and reused for every iteration.
        Starts from ``self.x``. No file is read when ``iterations <= 0``.
        '''
        if iterations <= 0:
            return self.x
        # The file name does not depend on the iteration, so load and apply
        # the sketch once instead of on every pass through the loop.
        sketched = np.load("ghg30" + str(0) + ".npy") @ self.A
        hessian = sketched.T @ sketched
        for _ in range(iterations):
            self.x = np.linalg.solve(hessian, self._ols_gradient()) + self.x
        return self.x

    def ols_fit_one_sketch(self, iterations):
        '''Solve OLS with IHS, reusing a single sketch for every iteration.

        This needs a larger sketch than if we generate a fresh sketch
        for every iteration. Starts from ``self.x``.
        '''
        sketched = self.sketcher.sketch()
        hessian = sketched.T @ sketched
        for _ in range(iterations):
            self.x = np.linalg.solve(hessian, self._ols_gradient()) + self.x
        return self.x

    ############# OLS WITH ERROR TRACKING ##########################
    def ols_fit_new_sketch_track_errors(self, iterations):
        '''Run fresh-sketch IHS for OLS and record every iterate.

        Resets ``self.x`` to zero before iterating. Column ``ii`` of
        ``self.sol_errors`` (shape ``(d, iterations)``) holds the iterate
        after step ``ii``. Despite its name, this array stores iterates, not
        errors.

        Returns ``(x, sol_errors)``.
        '''
        self.sol_errors = np.zeros((self.d, iterations))
        self.x = np.zeros((self.d,))
        for ii in range(iterations):
            sketched = self.sketcher.sketch()
            hessian = sketched.T @ sketched
            self.x = np.linalg.solve(hessian, self._ols_gradient()) + self.x
            self.sol_errors[:, ii] = self.x
        return self.x, self.sol_errors

    def ols_fit_one_sketch_track_errors(self, iterations, step_size):
        '''Run single-sketch IHS for OLS, recording iterates and sketch quality.

        One sketch is drawn and reused for every iteration, which needs a
        larger sketch than the fresh-sketch variant. Each step is
        ``x <- x + H^{-1} (step_size * (A^T b - A^T A x))``, starting from
        ``self.x``.

        Side effects:
        - ``self.sol_errors``: the iterate after each step, shape
          ``(d, iterations)``.
        - ``self.frob_error``: relative Frobenius error of ``H`` against
          ``A^T A``.
        - ``self.spectral_error``: relative gap between the largest singular
          values of the sketch and of ``A``.

        Returns ``(x, sol_errors)``.
        '''
        self.sol_errors = np.zeros((self.d, iterations))
        sketched = self.sketcher.sketch()
        hessian = sketched.T @ sketched
        cov_mat = self.A.T @ self.A

        # Only singular values are needed. compute_uv=False avoids forming
        # the full U factor, which is an n x n array for an n x d matrix.
        sketch_sv = np.linalg.svd(sketched, compute_uv=False)
        data_sv = np.linalg.svd(self.A, compute_uv=False)
        self.frob_error = (np.linalg.norm(hessian - cov_mat, ord='fro')
                           / np.linalg.norm(cov_mat, ord='fro'))
        self.spectral_error = np.abs(data_sv[0] - sketch_sv[0]) / data_sv[0]

        for ii in range(iterations):
            step = np.linalg.solve(hessian, self._ols_gradient(step_size))
            self.x = step + self.x
            self.sol_errors[:, ii] = self.x
        return self.x, self.sol_errors

    ############# LASSO ##########################

    def _lasso_sketch_step(self, sketched, x_prev, ell1Bound):
        '''Solve one sketched IHS subproblem for the l1-constrained Lasso.

            minimize   0.5 * ||sketched (x - x_prev)||^2 - <A^T (b - A x_prev), x>
            subject to ||x||_1 <= ell1Bound

        Returns cvxpy's ``x.value`` for the solved problem.
        '''
        x = cp.Variable(sketched.shape[1])
        f = (self.A.T) @ (self.b - self.A @ x_prev)
        objective = cp.Minimize(0.5 * cp.sum_squares(sketched @ (x - x_prev))
                                - cp.sum(cp.multiply(f, x)))
        constraints = [cp.norm(x, 1) <= ell1Bound]
        cp.Problem(objective, constraints).solve()
        return x.value

    def lasso_fit_new_sketch(self, iterations, ell1Bound):
        '''Fit the Lasso with ||x||_1 <= ell1Bound, using fresh sketches.

        All ``iterations`` sketches are drawn up front. Each sketched
        subproblem is then solved with cvxpy, starting from x = 0. Sketch
        generation and solving both fall inside the timed region.

        Returns ``(x, total_time)``, where ``total_time`` is in seconds
        (``time.perf_counter``).
        '''
        start = time.perf_counter()
        tmp = np.zeros(self.A.shape[1],)
        # One extra sketch is drawn and discarded before the loop. It is kept
        # so that the sequence of sketches used below is unchanged.
        # NOTE(review): presumably a deliberate warm-up/stream offset; confirm.
        self.sketcher.sketch()
        sketches = [self.sketcher.sketch() for _ in range(iterations)]

        for sketched in sketches:
            tmp = self._lasso_sketch_step(sketched, tmp, ell1Bound)
            self.x = tmp
        total_time = time.perf_counter() - start
        return self.x, total_time

    def lasso_fit_new_sketch_learn(self, iterations, ell1Bound):
        '''Fit the l1-constrained Lasso with precomputed (learned) sketches.

        Iteration ``ii`` loads the sketch from ``ele//ele_use_72{ii}.npy``,
        applies it to ``A`` and solves the sketched subproblem with cvxpy.
        Starts from x = 0.

        Returns ``(x, elapsed_seconds)``. With ``iterations == 0`` it returns
        the current ``self.x``.
        '''
        tmp = np.zeros(self.A.shape[1], )
        start = time.perf_counter()
        for ii in range(iterations):
            sketched = np.load("ele//ele_use_72" + str(ii) + ".npy") @ self.A
            tmp = self._lasso_sketch_step(sketched, tmp, ell1Bound)
            self.x = tmp
        end = time.perf_counter()
        return self.x, end - start

    @staticmethod
    def _count_sketch_with_heavy_rows(weights, m, n, t):
        '''Build an m x n CountSketch that isolates the t heaviest data rows.

        Data row i corresponds to column i of the sketch. The t rows with the
        largest ``weights`` (ties broken by lower index) each get a private
        sketch row among the first t entries of a random permutation of
        range(m). Every other row is hashed uniformly into the remaining
        permuted rows. Each nonzero is a random +-1 drawn with ``random``.
        '''
        S = np.zeros((m, n))
        heavy = np.zeros(n, dtype=bool)
        # Mark exactly t rows as heavy. A ">= t-th largest" test admits every
        # tied row, which can spill into the hashed buckets or index past m
        # (for example, when many rows have zero weight).
        heavy[np.argsort(-weights, kind="stable")[:t]] = True
        # Use a local generator. Reseeding numpy's global RNG from the wall
        # clock would clobber callers' seeds and give identical permutations
        # to calls made within the same second.
        order = np.random.default_rng().permutation(m)
        pos = 0
        for i in range(n):
            if heavy[i]:
                S[order[pos], i] = 2 * random.randint(0, 1) - 1
                pos += 1
            else:
                row = random.randint(t, m - 1)
                sign = 2 * random.randint(0, 1) - 1
                S[order[row], i] = sign
        return S

    def learnCountSketch(self, A, m, n, t):
        '''Build a CountSketch that isolates the t rows of A with largest leverage.

        Row weights are the Euclidean norms of the rows of U from the thin SVD
        ``A = U diag(s) V^T``. These are the square roots of the leverage
        scores when A has full column rank.

        Parameters
        ----------
        A : dense 2-D array whose first ``n`` rows are sketched (callers in
            this class pass ``n = A.shape[0]``).
        m : number of sketch rows.
        n : number of data rows, i.e. columns of the returned sketch.
        t : number of heavy rows that each get their own sketch row. Must be
            < m whenever there are remaining rows to hash.

        Returns an ``(m, n)`` array with exactly one +-1 per column.
        '''
        u, _, _ = np.linalg.svd(A, full_matrices=False)
        weights = np.linalg.norm(u, axis=1)[:n]
        return self._count_sketch_with_heavy_rows(weights, m, n, t)

    def learnCountSketch_new(self, A, m, n, t):
        '''Build a CountSketch whose heavy rows come from rows flagged in a file.

        ``ghg_position.npy`` in the working directory holds one value per data
        row. Rows with a value >= 10 are candidates and are weighted by the
        Euclidean norm of the matching row of ``A``. All other rows get
        weight 0. The t heaviest rows are then isolated as in
        ``learnCountSketch``. Works for any ``n`` and any number of flagged
        rows.

        Returns an ``(m, n)`` array with exactly one +-1 per column.
        '''
        position = np.asarray(np.load("ghg_position.npy"))
        candidates = position[:n] >= 10
        rows = np.asarray(A)[:n]
        weights = np.zeros(n)
        weights[candidates] = np.linalg.norm(rows[candidates], axis=1)
        return self._count_sketch_with_heavy_rows(weights, m, n, t)

    def lasso_fit_new_sketch_learnCS(self, iterations, ell1Bound):
        '''Fit the l1-constrained Lasso with leverage-based CountSketches.

        Each iteration builds ``learnCountSketch(A, sketch_dimension, n,
        int(0.3 * sketch_dimension))``, applies it to ``A`` and passes the
        result to ``iterative_lasso``, starting from ``self.x``.

        Returns ``(x, elapsed_seconds)``. Also valid for ``iterations == 0``.
        '''
        n_rows = self.A.shape[0]
        n_heavy = int(0.3 * self.sketch_dimension)
        start = time.perf_counter()
        for _ in range(iterations):
            sketched = self.learnCountSketch(self.A, self.sketch_dimension,
                                             n_rows, n_heavy) @ self.A
            self.x = iterative_lasso(sketched, self.ATb, self.A,
                                     self.b, self.x, ell1Bound)
        end = time.perf_counter()
        return self.x, end - start

    def lasso_fit_new_sketch_learnCS_new(self, iterations, ell1Bound):
        '''Fit the l1-constrained Lasso with file-guided CountSketches.

        Each iteration builds ``learnCountSketch_new(A, sketch_dimension, n,
        int(0.3 * sketch_dimension))``, applies it to ``A`` and solves the
        sketched subproblem with cvxpy. Starts from x = 0.

        Returns ``(x, elapsed_seconds)``. Also valid for ``iterations == 0``.
        '''
        tmp = np.zeros(self.A.shape[1], )
        n_rows = self.A.shape[0]
        n_heavy = int(0.3 * self.sketch_dimension)
        start = time.perf_counter()
        for _ in range(iterations):
            sketched = self.learnCountSketch_new(self.A, self.sketch_dimension,
                                                 n_rows, n_heavy) @ self.A
            tmp = self._lasso_sketch_step(sketched, tmp, ell1Bound)
            self.x = tmp
        end = time.perf_counter()
        return self.x, end - start

    def lasso_fit_new_sketch_track_errors(self, ell1Bound, iterations):
        '''Fit the l1-constrained Lasso with fresh sketches, recording iterates.

        Uses ``iterative_lasso`` starting from ``self.x``. Column ``ii`` of
        ``self.sol_errors`` holds the iterate after step ``ii``.
        Returns ``(x, sol_errors)``.
        '''
        print(f'Using X{self.A.shape},y{self.b.shape}')
        print('Using ell1Bound = ', ell1Bound)
        self.sol_errors = np.zeros((self.d, iterations))
        for ii in range(iterations):
            sketched = self.sketcher.sketch()
            self.x = iterative_lasso(sketched, self.ATb, self.A,
                                     self.b, self.x, ell1Bound)
            self.sol_errors[:, ii] = self.x
        return self.x, self.sol_errors

    def lasso_fit_new_sketch_speedup(self, ell1Bound, iterations):
        '''Run and time ``iterations`` fresh-sketch ``iterative_lasso`` steps.

        Starts from ``self.x``. Resets ``self.sol_errors`` to a ``(d, 1)``
        zero array, which is not filled here.
        Returns ``(x, elapsed_seconds)``.
        '''
        self.sol_errors = np.zeros((self.d, 1))
        start_time = timer()
        for _ in range(iterations):
            sketched = self.sketcher.sketch()
            self.x = iterative_lasso(sketched, self.ATb, self.A,
                                     self.b, self.x, ell1Bound)
        t_elapsed = timer() - start_time
        return self.x, t_elapsed

    def lasso_fit_new_sketch_timing(self, ell1Bound, timeBound):
        '''Run fresh-sketch ``iterative_lasso`` steps until ``timeBound`` seconds pass.

        The deadline is checked before each step, so the last step may finish
        after it. Starts from ``self.x`` and resets ``self.sol_errors`` to a
        ``(d, 1)`` zero array.
        Returns ``(x, number_of_steps_run)``.
        '''
        print("RUNNING FOR {} SECONDS".format(timeBound))
        iterations = 0
        self.sol_errors = np.zeros((self.d, 1))
        endTime = datetime.datetime.now() + datetime.timedelta(seconds=timeBound)
        while datetime.datetime.now() < endTime:
            iterations += 1
            sketched = self.sketcher.sketch()
            self.x = iterative_lasso(sketched, self.ATb, self.A,
                                     self.b, self.x, ell1Bound)
        # Report once the deadline has passed.
        print("IHS ran for {} seconds".format(timeBound))
        return self.x, iterations

    def lasso_fit_one_sketch_track_errors(self, ell1Bound, iterations, step_size=1.0):
        '''Fit the Lasso with a single sketch and step size ``step_size``.

        Column ``ii`` of ``self.sol_errors`` holds the iterate after step
        ``ii``. Returns ``(x, sol_errors)``.

        NOTE(review): step_size is passed to iterative_lasso_step_size *and*
        multiplies its result, so the iterate itself is rescaled every step.
        Confirm this is intended.
        '''
        sketched = self.sketcher.sketch()
        self.sol_errors = np.zeros((self.d, iterations))
        for ii in range(iterations):
            self.x = step_size * (iterative_lasso_step_size(sketched, self.ATb, self.A,
                                                            self.b, self.x, ell1Bound, step_size))
            self.sol_errors[:, ii] = self.x
        return self.x, self.sol_errors

    ############# SIMPLEX-CONSTRAINED QP ("SVM") ##########################

    def _simplex_sketch_step(self, sketched, x_prev):
        '''Solve one sketched subproblem over the probability simplex.

            minimize   ||sketched (x - x_prev)||^2 + <A^T A x_prev, x>
            subject to x >= 0, sum(x) == 1

        NOTE(review): unlike the Lasso step there is no 0.5 on the quadratic
        term, which halves the effective step. Confirm this is intended.
        Returns cvxpy's ``x.value``.
        '''
        x = cp.Variable(sketched.shape[1])
        f = (self.A.T) @ self.A @ x_prev
        objective = cp.Minimize(cp.sum_squares(sketched @ (x - x_prev)) + (f.T @ x))
        constraints = [0 <= x, cp.sum(x) == 1]
        cp.Problem(objective, constraints).solve()
        return x.value

    def svm_fit_new_sketch(self, iterations):
        '''Run IHS-style steps for ``min ||A x||^2`` over the simplex, fresh sketches.

        The gradient term ``A^T A x`` matches ``0.5 * ||A x||^2``. The problem
        is presumably an SVM dual, given the method name. Starts from x = 0
        and updates the iterate buffer in place.

        Returns ``(x, total_time)``, where ``total_time`` sums the per-iteration
        sketch+solve time.
        '''
        tmp = np.zeros(self.A.shape[1])
        total_time = 0.0
        for _ in range(iterations):
            start = time.perf_counter()
            sketched = self.sketcher.sketch()
            tmp[:] = self._simplex_sketch_step(sketched, tmp)
            self.x = tmp
            total_time += time.perf_counter() - start
        return self.x, total_time

    def svm_fit_new_sketch_learn(self, iterations):
        '''Run simplex-constrained IHS steps with precomputed sketches.

        Iteration ``ii`` loads the sketch from ``gaussian{ii}.npy``. Starts
        from the current ``self.x``. The timer excludes the file load and
        covers the sketch product and the solve.

        Returns ``(x, total_time)``.
        '''
        tmp = self.x
        total_time = 0.0
        for ii in range(iterations):
            loaded = np.load("gaussian" + str(ii) + ".npy")
            start = time.perf_counter()
            sketched = loaded @ self.A
            tmp = self._simplex_sketch_step(sketched, tmp)
            self.x = tmp
            total_time += time.perf_counter() - start
        return self.x, total_time

    ############# NUCLEAR-NORM CONSTRAINED ##########################

    def _nuclear_sketch_step(self, sketched, x_prev, B):
        '''Solve one sketched subproblem with nuclear/Frobenius constraints.

            minimize   0.5 * ||sketched (X - X_prev)||_F^2 - <A^T (B - A X_prev), X>
            subject to ||X||_* <= 10, ||X||_F <= 100

        Returns cvxpy's ``X.value``.
        '''
        x = cp.Variable((sketched.shape[1], B.shape[1]))
        f = (self.A.T) @ (B - self.A @ x_prev)
        objective = cp.Minimize(0.5 * cp.sum_squares(sketched @ (x - x_prev))
                                - cp.sum(cp.multiply(f, x)))
        constraints = [cp.norm(x, 'nuc') <= 10, cp.norm(x, 'fro') <= 100]
        cp.Problem(objective, constraints).solve()
        return x.value

    def nuclear_fit_new_sketch(self, iterations, cons, B):
        '''Fit a constrained matrix regression ``A X ~ B`` with fresh sketches.

        Starts from X = 0 and solves one ``_nuclear_sketch_step`` per
        iteration. Returns ``(X, elapsed_seconds)``.

        NOTE(review): ``cons`` is accepted but unused, and the nuclear-norm
        bound is hard-coded to 10. Confirm whether ``cons`` was meant to
        replace it.
        '''
        tmp = np.zeros((self.A.shape[1], B.shape[1]))
        start = time.perf_counter()
        for _ in range(iterations):
            # sketch() is called twice and only the second result is used.
            # This is kept so the sequence of sketches drawn is unchanged.
            # NOTE(review): looks like an accidental duplicate; confirm.
            self.sketcher.sketch()
            sketched = self.sketcher.sketch()
            tmp = self._nuclear_sketch_step(sketched, tmp, B)
            self.x = tmp
        end = time.perf_counter()
        return self.x, end - start

    def nuclear_fit_new_sketch_learn(self, iterations, cons, B, sketch_path=""):
        '''Fit a constrained matrix regression ``A X ~ B`` with learned sketches.

        Parameters
        ----------
        iterations : number of sketched steps, starting from X = 0.
        cons : unused (see ``nuclear_fit_new_sketch``).
        B : right-hand side matrix.
        sketch_path : str, default ""
            Path of the ``.npy`` file holding the sketch matrix. Any ``{}`` in
            the path is replaced by the iteration index, so literal braces
            must be doubled. With the default empty path, ``np.load`` raises
            FileNotFoundError until a path is supplied.

        Returns ``(X, elapsed_seconds)``.
        '''
        tmp = np.zeros((self.A.shape[1], B.shape[1]))
        start = time.perf_counter()
        for ii in range(iterations):
            sketched = np.load(sketch_path.format(ii)) @ self.A
            tmp = self._nuclear_sketch_step(sketched, tmp, B)
            self.x = tmp
        end = time.perf_counter()
        return self.x, end - start

    def fastls_fit(self, flag, ii = None):
        '''Run sketch-preconditioned gradient iterations for least squares.

        The sketch is QR-factorised and ``R^{-1}`` is used as a right
        preconditioner, with ``T = R^{-T} A^T A R^{-1}``. Three iterations of
        ``x <- x - eta * T^T (T x - R^{-T} y)`` are run from x = 0, with
        eta = 1 for the first step and 0.2 afterwards. ``flag`` and ``ii`` do
        not affect the result.

        Returns ``(R^{-1} x, s[0] / s[8])``, where ``x`` has shape ``(d, 1)``
        and ``s`` are the singular values of ``A R^{-1}``.

        NOTE(review): ``R^{-T} y`` only type-checks when ``y`` has d rows
        (n == d). A standard preconditioned gradient would use
        ``R^{-T} A^T y``; confirm the intended right-hand side.
        NOTE(review): ``s[0] / s[8]`` requires d >= 9. It is the ratio of the
        largest to the 9th singular value, not the full condition number.
        '''
        sketched = self.sketcher.sketch()
        A = self.A
        d = A.shape[1]
        y = self.b
        _, R = np.linalg.qr(sketched)
        R_inv = np.linalg.inv(R)

        # Only singular values are needed. The full SVD of the n x d matrix
        # would also form an n x n U factor that is never used.
        s = np.linalg.svd(A @ R_inv, compute_uv=False)

        T = (R_inv.T @ A.T @ A @ R_inv)
        x = np.zeros((d, 1))
        for i in range(3):
            eta = 1 if i == 0 else 0.2
            x = x - eta * (T.T @ (T @ x - R_inv.T @ y))
        return R_inv @ x, s[0] / s[8]



