import numpy as np
import torch
import time
import os
import math
from tqdm import tqdm
from numba import njit
from scipy import sparse
from scipy.sparse import coo_matrix
#from regression_solvers import lasso_solver_tch
from regression_solvers import lasso_solver, iterative_lasso
import cvxopt as cp
# from hadamard import fastwht
# import pyfht


# Helper functions

@njit(fastmath=True)
def fast_countSketch(SA, row, col, data, sign, row_map):
    '''Accumulate a CountSketch of COO-style data into SA.

    Every stored entry (row[k], col[k], data[k]) is multiplied by the
    sign of its source row and added to the bucket row chosen by
    row_map for that source row, in the same column.

    SA is updated in place and is also returned.'''
    for k in range(len(data)):
        src_row = row[k]
        SA[row_map[src_row], col[k]] += sign[src_row] * data[k]
    return SA


def shift_bit_length(x):
    '''Round the int x up to a power of 2.

    A value that already is a power of 2 is returned unchanged.'''
    # (x - 1).bit_length() is the exponent of the smallest power of two
    # that is >= x.
    exponent = (x - 1).bit_length()
    return 2 ** exponent


class RandomProjection:
    '''
    Class based implementation of different
    Random projection methods for RNLA.

    Current examples are taken from [1]
    however there are various sampling methods
    which could also be implemented.

    [1] - https://researcher.watson.ibm.com/researcher/files/us-dpwoodru/wNow.pdf'''

    def __init__(self, data, proj_dim,
                 sketch_type, col_sparsity=1):
        '''
        data - ndarray of the data (or vector) to sketch
        proj_dim - int --> Value to sketch down to
        sketch_type - str determining which sketch to use

        Sketch Type support args:
        1. gaussian
        2. srht
        3. sjlt -- default col_sparsity is 4
        4. countsketch
        '''

        if sketch_type == 'learnSketch':
            '''
            data 是一个3D tensor，第一维代表不同的数据集（共7个），第二维代表不同的数据（每个数据集1000条），第三维代表不同的feature（300个feature，第301个是x）
            '''
            # 数据集的数量
            self.N_train = data.shape[0]
            self.n = data.shape[1]
            self.d = data.shape[2]

            self.proj_dim = proj_dim
            self.sketch_type = sketch_type

            # self.sketch_vector = np.random.choice(self.proj_dim,
            #                                  self.n,
            #                                  replace=True,
            #                                  )
            # self._sign_map = np.random.choice(2, self.n, replace=True) * 2 - 1

            self.sketch_vector = torch.randint(0, self.proj_dim, (1, self.n))

            # self.sketch_values = torch.rand(self.n)
            self.sketch_values = ((torch.randint(2, [1, self.n]).float() - 0.5) * 2).cpu()

            self.A_train = data[:, :, :-1]
            self.B_train = data[:, :, -1]
            # print('self.B_train.shape:')
            # print(self.B_train.shape)
        else:
            # print('sketch_type:', sketch_type)
            # print(sketch_type)
            self.data = data
            self.n, self.d = data.shape
            self.proj_dim = proj_dim
            self.sketch_type = sketch_type
            #print(self.n)
            # print('self.proj_dim:')
            # print(self.proj_dim)
            # print('self.n:')
            # print(self.n)
            # Create an attribute for dense data for later reference.
            # If data is ndarray then just make new reference, otherwise, if input
            # is sparse then make reference for dense data.
            if isinstance(self.data, np.ndarray):
                self.dense_data = self.data
                # print('np.ndarray')
            elif isinstance(self.data, sparse.coo.coo_matrix) or isinstance(self.data, sparse.csr.csr_matrix):
                # print('Converting sparse to dense data.')
                self.dense_data = self.data.toarray()
                # print('sparse.coo.coo_matrix or sparse.csr.csr_matrix')

            # Convert data to sparse data for cross-comparison between sketches
            # and accept sparse inputs
            # Note that the second arg for `isinstance` depends on how the scipy
            # import is written. If it is `import scipy` then we need
            # `scipy.sparse...` but if we have `import scipy.sparse as sparse`
            # then `sparse.coo....` will suffice.

            # LOGIC: if self.data is sparse just make references for later
            # otherwise, convert to sparse data.
            if isinstance(self.data, sparse.coo.coo_matrix):
                self.coo_data = self.data
                self.rows = self.coo_data.row
                self.cols = self.coo_data.col
                self.vals = self.coo_data.data
                # print('sparse.coo.coo_matrix')
            elif isinstance(self.data, sparse.csr.csr_matrix):
                self.coo_data = self.data.tocoo()
                self.rows = self.coo_data.row
                self.cols = self.coo_data.col
                self.vals = self.coo_data.data
                # print('sparse.csr.csr_matrix')
            else:
                self.coo_data = coo_matrix(data)
                self.rows = self.coo_data.row
                self.cols = self.coo_data.col
                self.vals = self.coo_data.data
                # print('others')
            if self.proj_dim > self.n:
                raise Exception(f'Sketching with projection dimension\
                                         $self.proj_dim > $self.n is not supported')

            ## For SRHT do the bit shift here so not timed in call
            ## to sketch
            if self.sketch_type is 'srht':
                # preprocess data so length is correct and repeat for
                # the intermediate function after hadamard transform
                next_power2_data = shift_bit_length(self.n)
                deficit = next_power2_data - self.n
                self.dense_data = np.concatenate((self.dense_data,
                                                  np.zeros((deficit, self.d),
                                                           dtype=self.dense_data.dtype)), axis=0)
                # set the new n for later use although
                # sampling will only ever be from self.n
                # as the power of 2 extension only necessary for
                # the hadamard transform
                self.new_n = self.dense_data.shape[0]

            ######################## SPARSE SKETCHES ##############################

            ### For sparse sketches generate hash functions here so
            # that they aren't timed in the call to sketch.
            # Would it be better to only accept sparse datasets for the sparse
            # sketches?
            elif self.sketch_type is 'sjlt' or 'countSketch':
                self.col_sparsity = col_sparsity
                if self.sketch_type is 'sjlt' and self.col_sparsity == 1:
                    self.col_sparsity = 4

        ######################## LEARNED SKETCH################################
        ## Function dictionary to call later on.

        self.fct_dict = {'gaussian': self.GaussianSketch,
                         # 'srht': self.SRHT,
                         'countSketch': self.CountSketch,
                         'sjlt': self.SparseJLT,
                         'learnSketch': self.learnSketch}

    def GaussianSketch(self):
        '''Apply a dense Gaussian sketch to the dense data.

        The sketch matrix has proj_dim x n i.i.d. N(0,1) entries scaled
        by 1 / sqrt(proj_dim); the product S @ dense_data is returned.'''
        scale = np.sqrt(self.proj_dim)
        gaussian = np.random.randn(self.proj_dim, self.n)
        return (gaussian / scale) @ self.dense_data

    '''
    def SRHT(self):
        diag = np.random.choice([1, -1], self.new_n)[:, None]
        # print("diag shape: {}".format(diag.shape))
        # print("input mat shape: {}".format(input_matrix.shape))
        signed_mat = diag * self.dense_data
        # print(signed_mat.shape)
        S = fastwht(signed_mat) * shift_bit_length(self.n)  # shift bit length is normalising factor
        sample = np.random.choice(self.n, self.proj_dim, replace=False)
        # sample.sort()
        # number from num_rows_data universe
        S = S[sample]
        S = (self.proj_dim) ** (-0.5) * S
        return S
    '''

    # def SRHT(self):
    #     '''
    #     Compute the Subsampled Randomized Hadamard
    #     Transform (aka Fast Johnson-Lindenstrauss Transform).
    #
    #     Compute:
    #     PHDA where
    #     - D is a diagonal matrix whose entries are
    #     {±1} chosen uniformly at random
    #     - H (DA) applies the Hadamard transform to DA
    #     - P (HDA) uniformly subsamples the rows of HDA.
    #
    #
    #     Notes:
    #     The pyfht.fht_inplace doesn't seem to work so we
    #     need a little bit of working space (to store the
    #     FFTed versions of the columns) to do the
    #     entire transform'''
    #     diag = np.random.choice([1,-1], self.new_n)[:,None]
    #     signed_data = diag*self.dense_data
    #
    #     # the [:,None] syntax is just to add a 2nd dimension
    #     # so that the columns after hadamard transform can
    #     # be easily appended.
    #     # It is slightly quicker to generate lists, call
    #     # np.array, then reshape than initialising a zero
    #     # array to store the fhted columns
    #
    #     #Y = np.zeros((self.new_n,self.d))
    #     Y = []
    #     # perform the in place fht on each column
    #     for _col in range(self.d):
    #         # Y[:,_col] = pyfht.fht(self.data[:,_col])
    #         Y.append(pyfht.fht(self.dense_data[:,_col]))
    #     Y = np.array(Y)
    #     Y = Y.T
    #     #print(type(Y),Y.shape)
    #     #print(Y)
    #     # number from num_rows_data universe
    #     sample = np.random.choice(self.n, self.proj_dim, replace=False)
    #
    #     S = Y[sample] * (np.sqrt(1/(self.proj_dim)))
    #     return S

    def CountSketch(self):
        '''Return the CountSketch of the data.

        Equivalent to the SJLT with a single non-zero per column; kept
        as a separate method so that timing comparisons are not
        affected by the SJLT bookkeeping.

        Stores the output buffer and the freshly drawn hash functions
        on the instance as SA, _row_map and _sign_map.'''
        n_rows = self.n
        self.SA = np.zeros((self.proj_dim, self.d))
        # Target bucket of every input row, drawn before the signs.
        self._row_map = np.random.choice(self.proj_dim, n_rows, replace=True)
        # Random +/-1 sign of every input row.
        self._sign_map = 2 * np.random.choice(2, n_rows, replace=True) - 1
        return fast_countSketch(self.SA, self.rows, self.cols, self.vals,
                                self._sign_map, self._row_map)

    def SparseJLT(self):
        '''Return the Sparse Johnson-Lindenstrauss Transform of the data.

        Built as `col_sparsity` independent CountSketches, each with
        proj_dim // col_sparsity rows, stacked on top of each other and
        scaled by 1 / sqrt(col_sparsity).

        The hash functions of all sub-sketches are drawn here and kept
        on the instance as _row_map and _sign_map (one row per
        sub-sketch); SA holds the unscaled first sub-sketch whenever
        more than one sub-sketch is used.'''
        s = self.col_sparsity
        self.sjlt_proj_dim = self.proj_dim // s
        self.SA = np.zeros((self.sjlt_proj_dim, self.d))

        # Row k of each table belongs to sub-sketch k.
        self._row_map = np.zeros((s, self.n))
        self._sign_map = np.zeros((s, self.n))
        for k in range(s):
            self._row_map[k, :] = np.random.choice(self.sjlt_proj_dim,
                                                   self.n,
                                                   replace=True)
            self._sign_map[k, :] = np.random.choice(2, self.n, replace=True) * 2 - 1
        # Bucket ids are used as indices, so they must be integers.
        self._row_map = self._row_map.astype(int)

        # The first sub-sketch is accumulated directly into self.SA,
        # every further one into a fresh buffer of the same shape.
        pieces = [fast_countSketch(self.SA,
                                   self.rows, self.cols, self.vals,
                                   self._sign_map[0, :],
                                   self._row_map[0, :])]
        for k in range(1, s):
            pieces.append(fast_countSketch(np.zeros_like(self.SA),
                                           self.rows, self.cols, self.vals,
                                           self._sign_map[k, :],
                                           self._row_map[k, :]))

        if len(pieces) == 1:
            stacked = pieces[0]
        else:
            stacked = np.concatenate(pieces, axis=0)
        stacked *= 1 / np.sqrt(s)
        return stacked

    def sketch(self):
        '''
        Compute sketch(A) = SA with the method registered in fct_dict
        for this instance's sketch_type.
        '''
        transform = self.fct_dict[self.sketch_type]
        return transform()

    def sketch_data_targets(self, targets):
        '''
        For classic sketching aka sketch-and-solve
        we need to sketch the matrix X where X is the
        matrix which has data|targets appended.

        targets - vector (or column) with one entry per row of the data

        Returns (SA, Sb): the sketched data columns and the sketched
        target column, both produced by one and the same sketch.
        '''
        # np.c_ cannot stack a scipy sparse matrix, so densify it first.
        data = self.data.toarray() if sparse.issparse(self.data) else self.data
        X = np.c_[data, targets]
        # Forward this instance's column sparsity so that the sketch of
        # data|targets is configured like the sketch of the data alone.
        col_sparsity = getattr(self, 'col_sparsity', 1)
        S_Ab = RandomProjection(X, self.proj_dim, self.sketch_type,
                                col_sparsity).sketch()
        SA = S_Ab[:, :-1]
        Sb = S_Ab[:, -1]
        return SA, Sb

    def learnSketch(self):
        '''
        Learn the non-zero values of a sparse sketch S by stochastic
        gradient descent and return S.

        S has shape (proj_dim, n) with one non-zero per column: the row
        positions are fixed by self.sketch_vector, the values start at
        self.sketch_values and are the only trained parameters.

        Every step draws one training problem (A, b), solves the
        sketched least-squares problem x = pinv(S A) S b through an SVD
        and descends on the residual norm ||A x - b||. One step is taken
        per training set (N_train steps in total).

        Runs on "cuda:0" when CUDA is available and on the CPU
        otherwise. Creates ./output/baselines/learnSketch/exp_0 as a
        side effect.

        Returns the learned sketch as a float64 ndarray; the instance's
        sketch_values are left untouched.

        NOTE(review): the per-step `lasso_solver` reference solve of the
        earlier revision was dropped because its result was never used.
        '''
        N_train = self.N_train
        A_train = self.A_train
        B_train = self.B_train
        lr = 10e-1
        bs = 1  # mini-batch size
        m = self.proj_dim
        n = self.n

        self.device = device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

        save_dir = os.path.join("./output/baselines/", "learnSketch", "exp_0")
        os.makedirs(save_dir, exist_ok=True)

        # Fixed sparsity pattern: S[rows[j], j] is the only non-zero of
        # column j.
        rows = self.sketch_vector.reshape(-1).long().to(device)
        cols = torch.arange(n, device=device)

        # Trainable values; cloned so that this is a leaf tensor on the
        # target device.
        sketch_value = self.sketch_values.detach().clone().to(device)
        sketch_value.requires_grad_(True)

        def build_sketch():
            # Scatter the current values into a dense (m, n) matrix; the
            # indexed assignment keeps the graph back to sketch_value.
            S = torch.zeros(m, n, device=device)
            S[rows, cols] = sketch_value.reshape(-1)
            return S

        it_lr = lr
        for bigstep in tqdm(range(N_train)):
            # Only decays when the rate exceeds 1, which the initial
            # rate of 1.0 never does.
            if (bigstep % 1000 == 0) and it_lr > 1:
                it_lr = it_lr * 0.3

            S = build_sketch()

            # Draw the mini-batch and keep the leading batch axis so the
            # batched linear algebra below sees 3-D operands.
            batch_rand_ind = np.random.randint(0, high=N_train, size=bs).tolist()
            AM = torch.as_tensor(A_train[batch_rand_ind],
                                 dtype=torch.float32, device=device)
            BM = torch.as_tensor(B_train[batch_rand_ind],
                                 dtype=torch.float32, device=device).unsqueeze(-1)

            # Sketch-and-solve: x = V diag(1 / sigma) U^T (S b).
            SA = torch.matmul(S, AM)
            SB = torch.matmul(S, BM)
            U, Sig, V = torch.svd(SA)
            X = V.matmul(torch.diag_embed(1.0 / Sig)).matmul(U.permute(0, 2, 1)).matmul(SB)
            ans = AM.matmul(X)
            loss = torch.mean(torch.norm(ans - BM, dim=(1, 2)))

            loss.backward()

            # TODO: Maybe don't have to divide by bs. However, if you
            # change it, then you need to compensate in lr... all old
            # exp will be invalidated
            with torch.no_grad():
                sketch_value -= (it_lr / bs) * sketch_value.grad
                sketch_value.grad.zero_()

            if device.type == "cuda":
                torch.cuda.empty_cache()

        # Rebuild S so the returned matrix reflects the final update.
        return build_sketch().detach().cpu().numpy().astype(np.float64)


def save_iteration_regression(S, A_train, B_train, A_test, B_test, save_dir, bigstep):
    """
    Evaluate the sketch S on the test and train sets, save S together
    with both errors via torch.save to <save_dir>/it_<bigstep>, print
    the errors and return them as (train_err, test_err).

    Not implemented:
    Mixed matrix evaluation
    """
    checkpoint_path = os.path.join(save_dir, "it_%d" % bigstep)

    test_err = evaluate_to_rule_them_all_regression(A_test, B_test, S)
    train_err = evaluate_to_rule_them_all_regression(A_train, B_train, S)
    errors = [train_err, test_err]

    torch.save([[S], errors], checkpoint_path)

    print(*errors)
    print("Saved iteration: %d" % bigstep)
    return train_err, test_err


def evaluate_to_rule_them_all_regression(A_set, B_set, S):
    """
    Average sketched-regression residual over a set of problems.

    The set is processed in batches of 100. For each problem the
    sketched least-squares solution X = pinv(S A) S B is computed
    through an SVD and the norm of A X - B is accumulated.

    :param A_set: 3D tensor, one design matrix per leading index
    :param B_set: 3D tensor of matching targets
    :param S: sketch matrix (moved to the CPU before use)
    :return: sum of the residual norms divided by the number of
             problems, as a Python float (0 for an empty set)
    """
    n = A_set.size(0)
    bs = 100
    loss = 0
    for start in range(0, n, bs):
        stop = min(n, start + bs)
        AM = A_set[start:stop]
        BM = B_set[start:stop]

        sketch = S.cpu()
        SA = torch.matmul(sketch, AM)
        SB = torch.matmul(sketch, BM)
        U, Sig, V = torch.svd(SA)
        # X = V diag(1 / sigma) U^T (S B)
        pinv_SA = V.matmul(torch.diag_embed(1.0 / Sig)).matmul(U.permute(0, 2, 1))
        X = pinv_SA.matmul(SB)
        residual = AM.matmul(X) - BM
        it_loss = torch.sum(torch.norm(residual, dim=(1, 2))) / n
        loss += it_loss.item()
    return loss
