import numpy as np
import torch
import time
import os
import math
from tqdm import tqdm
from numba import njit
from scipy import sparse
from scipy.sparse import coo_matrix
#from regression_solvers import lasso_solver_tch
from regression_solvers import lasso_solver, iterative_lasso
import cvxopt as cp
# from hadamard import fastwht
# import pyfht


# Helper functions

@njit(fastmath=True)
def fast_countSketch(SA, row, col, data, sign, row_map):
    '''Accumulate a CountSketch of COO-style triplets into SA (in place).

    For every triplet (row[k], col[k], data[k]) the value is multiplied
    by sign[row[k]] and added to SA[row_map[row[k]], col[k]].
    SA is modified in place and also returned.'''
    for k in range(len(data)):
        src_row = row[k]
        SA[row_map[src_row], col[k]] += data[k] * sign[src_row]
    return SA


def shift_bit_length(x):
    '''Round the int x up to a power of 2.
    A value that already is a power of 2 is returned unchanged.'''
    # (x - 1).bit_length() is the exponent of the smallest power of
    # two that is >= x (for x >= 1).
    exponent = (x - 1).bit_length()
    return 1 << exponent


class RandomProjection:
    '''
    Class based implementation of different
    Random projection methods for RNLA.

    Current examples are taken from [1]
    however there are various sampling methods
    which could also be implemented.

    [1] - https://researcher.watson.ibm.com/researcher/files/us-dpwoodru/wNow.pdf'''

    def __init__(self, data, proj_dim,
                 sketch_type, col_sparsity=1):
        '''
        data - ndarray of the data (or vector) to sketch
        proj_dim - int --> Value to sketch down to
        sketch_type - str determining which sketch to use

        Sketch Type support args:
        1. gaussian
        2. srht
        3. sjlt -- default col_sparsity is 4
        4. countsketch
        '''

        if sketch_type == 'learnSketch':
            '''
            data 是一个3D tensor，第一维代表不同的数据集（共7个），第二维代表不同的数据（每个数据集1000条），第三维代表不同的feature（300个feature，第301个是x）
            '''
            # 数据集的数量
            self.N_train = data.shape[0]
            self.n = data.shape[1]
            self.d = data.shape[2]

            self.proj_dim = proj_dim
            self.sketch_type = sketch_type

            # self.sketch_vector = np.random.choice(self.proj_dim,
            #                                  self.n,
            #                                  replace=True,
            #                                  )
            # self._sign_map = np.random.choice(2, self.n, replace=True) * 2 - 1

            self.sketch_vector = torch.randint(0, self.proj_dim, (1, self.n))

            # self.sketch_values = torch.rand(self.n)
            self.sketch_values = ((torch.randint(2, [1, self.n]).float() - 0.5) * 2).cpu()

            self.A_train = data[:, :, :-1]
            self.B_train = data[:, :, -1]
            # print('self.B_train.shape:')
            # print(self.B_train.shape)
        else:
            # print('sketch_type:', sketch_type)
            # print(sketch_type)
            self.data = data
            self.n, self.d = data.shape
            self.proj_dim = proj_dim
            self.sketch_type = sketch_type
            #print(self.n)
            # print('self.proj_dim:')
            # print(self.proj_dim)
            # print('self.n:')
            # print(self.n)
            # Create an attribute for dense data for later reference.
            # If data is ndarray then just make new reference, otherwise, if input
            # is sparse then make reference for dense data.
            if isinstance(self.data, np.ndarray):
                self.dense_data = self.data
                # print('np.ndarray')
            elif isinstance(self.data, sparse.coo.coo_matrix) or isinstance(self.data, sparse.csr.csr_matrix):
                # print('Converting sparse to dense data.')
                self.dense_data = self.data.toarray()
                # print('sparse.coo.coo_matrix or sparse.csr.csr_matrix')

            # Convert data to sparse data for cross-comparison between sketches
            # and accept sparse inputs
            # Note that the second arg for `isinstance` depends on how the scipy
            # import is written. If it is `import scipy` then we need
            # `scipy.sparse...` but if we have `import scipy.sparse as sparse`
            # then `sparse.coo....` will suffice.

            # LOGIC: if self.data is sparse just make references for later
            # otherwise, convert to sparse data.
            if isinstance(self.data, sparse.coo.coo_matrix):
                self.coo_data = self.data
                self.rows = self.coo_data.row
                self.cols = self.coo_data.col
                self.vals = self.coo_data.data
                # print('sparse.coo.coo_matrix')
            elif isinstance(self.data, sparse.csr.csr_matrix):
                self.coo_data = self.data.tocoo()
                self.rows = self.coo_data.row
                self.cols = self.coo_data.col
                self.vals = self.coo_data.data
                # print('sparse.csr.csr_matrix')
            else:
                self.coo_data = coo_matrix(data)
                self.rows = self.coo_data.row
                self.cols = self.coo_data.col
                self.vals = self.coo_data.data
                # print('others')
            if self.proj_dim > self.n:
                raise Exception(f'Sketching with projection dimension\
                                         $self.proj_dim > $self.n is not supported')

            ## For SRHT do the bit shift here so not timed in call
            ## to sketch
            if self.sketch_type is 'srht':
                # preprocess data so length is correct and repeat for
                # the intermediate function after hadamard transform
                next_power2_data = shift_bit_length(self.n)
                deficit = next_power2_data - self.n
                self.dense_data = np.concatenate((self.dense_data,
                                                  np.zeros((deficit, self.d),
                                                           dtype=self.dense_data.dtype)), axis=0)
                # set the new n for later use although
                # sampling will only ever be from self.n
                # as the power of 2 extension only necessary for
                # the hadamard transform
                self.new_n = self.dense_data.shape[0]

            ######################## SPARSE SKETCHES ##############################

            ### For sparse sketches generate hash functions here so
            # that they aren't timed in the call to sketch.
            # Would it be better to only accept sparse datasets for the sparse
            # sketches?
            elif self.sketch_type is 'sjlt' or 'countSketch':
                self.col_sparsity = col_sparsity
                if self.sketch_type is 'sjlt' and self.col_sparsity == 1:
                    self.col_sparsity = 4

        ######################## LEARNED SKETCH################################
        ## Function dictionary to call later on.

        self.fct_dict = {'gaussian': self.GaussianSketch,
                         # 'srht': self.SRHT,
                         'countSketch': self.CountSketch,
                         'sjlt': self.SparseJLT,
                         'learnSketch': self.learnSketch}

    def GaussianSketch(self):
        '''Compute the Gaussian random transform
        S_ij = G_ij ~ N(0,1) / sqrt(proj_dim)'''
        S = np.random.randn(self.proj_dim, self.n)
        S /= np.sqrt(self.proj_dim)
        return S @ self.dense_data

    '''
    def SRHT(self):
        diag = np.random.choice([1, -1], self.new_n)[:, None]
        # print("diag shape: {}".format(diag.shape))
        # print("input mat shape: {}".format(input_matrix.shape))
        signed_mat = diag * self.dense_data
        # print(signed_mat.shape)
        S = fastwht(signed_mat) * shift_bit_length(self.n)  # shift bit length is normalising factor
        sample = np.random.choice(self.n, self.proj_dim, replace=False)
        # sample.sort()
        # number from num_rows_data universe
        S = S[sample]
        S = (self.proj_dim) ** (-0.5) * S
        return S
    '''

    # def SRHT(self):
    #     '''
    #     Compute the Subsampled Randomized Hadamard
    #     Transform (aka Fast Johnson Lindesntrauss Transform).
    #
    #     Compute:
    #     PHDA where
    #     - DA is a diagonal matrix whose entries are
    #     {±1} chose uar
    #     - H (DA) applies the Hadamard transform on DA
    #     - P (HDA) uniformly subsamples HDA.
    #
    #
    #     Notes:
    #     The pyfht.fht_inplace doesn't seem to work so we
    #     need a little bit of working space (to store the
    #     FFTed versions of the columns) to do the
    #     entire transform'''
    #     diag = np.random.choice([1,-1], self.new_n)[:,None]
    #     signed_data = diag*self.dense_data
    #
    #     # the [:,None] syntax is just to add a 2nd dimension
    #     # so that the columns after hadamard transform can
    #     # be easily appended.
    #     # It is slightly quicker to generate lists, call
    #     # np.array, then reshape than initialising a zero
    #     # array to store the fhted columns
    #
    #     #Y = np.zeros((self.new_n,self.d))
    #     Y = []
    #     # perform the in place fht on each column
    #     for _col in range(self.d):
    #         # Y[:,_col] = pyfht.fht(self.data[:,_col])
    #         Y.append(pyfht.fht(self.dense_data[:,_col]))
    #     Y = np.array(Y)
    #     Y = Y.T
    #     #print(type(Y),Y.shape)
    #     #print(Y)
    #     # number from num_rows_data universe
    #     sample = np.random.choice(self.n, self.proj_dim, replace=False)
    #
    #     S = Y[sample] * (np.sqrt(1/(self.proj_dim)))
    #     return S

    def CountSketch(self):
        '''Compute the CountSketch transform of the data.
        This is just the  SJLT but with column sparsity 1.
        Given its own method to ensure speed is not
        implicated during later testing.'''
        self.SA = np.zeros((self.proj_dim, self.d))
        self._row_map = np.random.choice(self.proj_dim,
                                         self.n,
                                         replace=True)
        self._sign_map = np.random.choice(2, self.n, replace=True) * 2 - 1
        return fast_countSketch(self.SA,
                                self.rows,
                                self.cols,
                                self.vals,
                                self._sign_map,
                                self._row_map)

    def SparseJLT(self):
        '''Compute the SparseJLT of the data
        using the Kane-Nelson construction of
        concatenated CountSketches.

        1. Generate `s` independent countsketches
        each of size m/s x n and concatenate them.

        2. Use initial hash functions as decided above in the
        class definition and then generate new hashes for
        subsequent countsketch calls.'''
        # set the new projection dimension for sjlt
        # this is because the sjlt is an m x n sketch
        # composed of s*(m/s) x n shorter countSketches.
        self.sjlt_proj_dim = self.proj_dim // self.col_sparsity
        self.SA = np.zeros((self.sjlt_proj_dim, self.d))
        self._row_map = np.zeros((self.col_sparsity, self.n))
        self._sign_map = np.zeros((self.col_sparsity, self.n))
        # Generate array whose rows are lists for :
        # 1. row_map
        # 2. sign_map
        # Generate single array self.sjlt_proj_dim
        # to populate for the sketch to which new local
        # sketches will be added.

        for _ in range(self.col_sparsity):
            self._row_map[_, :] = np.random.choice(self.sjlt_proj_dim,
                                                   self.n,
                                                   replace=True)
            self._sign_map[_, :] = np.random.choice(2, self.n, replace=True) * 2 - 1
        self._row_map = self._row_map.astype(int)

        # print('size of SA ', self.SA.shape)
        # print('dType of row map ', self._row_map.dtype)
        local_row_map = self._row_map[0, :]
        local_sign_map = self._sign_map[0, :]
        global_summary = fast_countSketch(self.SA,
                                          self.rows,
                                          self.cols,
                                          self.vals,
                                          local_sign_map,
                                          local_row_map)
        # print('global summary \n', global_summary)
        for sketch_id in range(1, self.col_sparsity):
            # print('sketch_id ', sketch_id+1)
            # print(self.SA)
            local_row_map = self._row_map[sketch_id, :]
            local_sign_map = self._sign_map[sketch_id, :]
            local_summary = fast_countSketch(np.zeros_like(self.SA),
                                             self.rows,
                                             self.cols,
                                             self.vals,
                                             local_sign_map,
                                             local_row_map)
            global_summary = np.concatenate((global_summary,
                                             local_summary), axis=0)
        global_summary *= 1 / np.sqrt(self.col_sparsity)
        return global_summary

    def sketch(self):
        '''
        Perform the transform sketch(A) = SA
        '''
        return self.fct_dict[self.sketch_type]()

    def sketch_data_targets(self, targets):
        '''
        For classic sketching aka sketch-and-solve
        we need to sketch the matrix X where X is the
        matrix which has data|targets appended'''
        X = np.c_[self.data, targets]
        S_Ab = RandomProjection(X, self.proj_dim, self.sketch_type).sketch()
        SA = S_Ab[:, :-1]
        Sb = S_Ab[:, -1]
        # print(SA.shape, Sb.shape)
        return SA, Sb

    def learnSketch(self):
        N_train = self.N_train
        # print(N_train)
        A_train = self.A_train
        B_train = self.B_train
        #print("tt", N_train,)
        #print(A_train)
        #print("ttt", B_train)
        lr = 10e-1
        _iter = N_train
        bs = 1
        device = "cuda:0"
        self.device = device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        m = self.proj_dim
        n = self.n
        d = A_train.shape[2]
        num_exp = 1
        k_sparse = 1
        # N_train: 训练集的总数量 | lr:learning rate | m,n:S矩阵的维数 | bs:batch-size
        # num_exp:实验次数 | _iter:学习/迭代次数
        # change your dir
        rltdir = "./output/baselines/"
        save_dir = os.path.join(rltdir, "learnSketch")

        avg_over_exps = 0
        print_freq = 50  # TODO
        for exp_num in range(num_exp):

            it_save_dir = os.path.join(save_dir, "exp_%d" % exp_num)
            it_print_freq = print_freq
            it_lr = lr

            if not os.path.exists(it_save_dir):
                os.makedirs(it_save_dir)

            test_errs = []
            train_errs = []
            fp_times = []
            bp_times = []

            # Initialize sparsity pattern

            # sketch_vector = torch.from_numpy(self.sketch_vector).cuda()


            # for i in range(n):
            #     S[sketch_vector[0][i], i] = sketch_value[i]

            # S[sketch_vector.type(torch.LongTensor).reshape(-1), torch.arange(n).repeat(
            #     k_sparse)] = sketch_value.reshape(-1)

            sketch_vector = self.sketch_vector.cuda()
            sketch_vector.requires_grad = False
            sketch_value = self.sketch_values.cuda()
            sketch_value.requires_grad = True
            # print(sketch_vector)
            # print(sketch_value)
            # print(S.shape)
            # for i in range(n):
            #    print(S[sketch_vector[0][i], i])
            '''
            S_np = np.random.randn(self.proj_dim, self.n)
            S_np /= np.sqrt(self.proj_dim)
            S = torch.from_numpy(S_np).cuda()
            # print(S)
            S.requires_grad = True
            '''
            for bigstep in tqdm(range(_iter)):


                S = torch.zeros(m, n).to(device)
                # S[sketch_vector.type(torch.LongTensor).reshape(-1), torch.arange(n).repeat(1)] = sketch_value.reshape(-1)
                S[sketch_vector.type(torch.LongTensor).reshape(-1), torch.arange(n).repeat(1)] = sketch_value.reshape(-1)

                if (bigstep % 1000 == 0) and it_lr > 1:
                    it_lr = it_lr * 0.3
                if bigstep > 200:
                    it_print_freq = 200

                # fp_start_time = time.time()

                # 用mini-batch来训练
                batch_rand_ind = np.random.randint(0, high=N_train, size=bs)

                x_opt = lasso_solver(A_train[0], B_train[0], constraint=1)
                '''
                # ATb = torch.tensor(A_train[batch_rand_ind].T).cuda() @ B_train[batch_rand_ind]
                A = torch.tensor(A_train[batch_rand_ind].squeeze(), dtype=torch.float32).cuda()
                # AT = torch.tensor(A_train[batch_rand_ind].squeeze().T).cuda()

                # print(AT.shape)
                b = torch.tensor(B_train[batch_rand_ind], dtype=torch.float32).squeeze().cuda()
                ATb = torch.matmul(A.t(), b).cuda()
                # x = torch.tensor(np.zeros((n,)), dtype=torch.float32).cuda()
                x = torch.tensor(np.zeros((d,)), dtype=torch.float32).cuda()
                lambda_val = torch.tensor([1.],  dtype=torch.float32)
                # print(A_train[batch_rand_ind].squeeze().shape)
                # x_opt = torch.tensor(lasso_solver(A_train[batch_rand_ind].squeeze().astype(np.double), B_train[batch_rand_ind].squeeze().astype(np.double), lambda_val))
                # constraint = lambda_val


                # x = iterative_lasso(S, ATb, A.t(), b, x, lambda_val)
                Q = torch.matmul(S.t(), S)
                '''
                '''
                big_Q = [Q, -Q;
                         -Q, Q]
                '''
                '''

                big_Q = torch.cat((torch.cat((Q, -1.0*Q), dim=1), (torch.cat((-1.0*Q, Q), dim=1))), dim=0)
                # big_Q_1 = torch.cat((Q, -1.0*Q), dim=1)
                # print('big_Q_1.shape:')
                # print(big_Q_1.shape)
                # print('Q.shape:')
                # print(Q.shape)
                # print('big_Q.shape:')
                # print(big_Q.shape)
                print('S.shape:')
                print(S.shape)
                print('Q.shape:')
                print(Q.shape)
                print('x.shape:')
                print(x.shape)
                linear_term_1 = torch.matmul(Q, x)
                linear_term_2 = torch.matmul(S.t(), torch.matmul(S, x))
                # linear_term = torch.matmul(Q, x) + ATb - torch.matmul(S.t(), torch.matmul(S, x))
                linear_term = linear_term_1 + ATb -linear_term_2
                big_c = torch.cat((linear_term, -1.0*linear_term), dim=0)

                # penalty term
                constraint_term = lambda_val * torch.ones((2 * d,))
                big_linear_term = constraint_term - big_c

                # nonnegative constraints
                G = -1.0 * torch.eye(2 * d, dtype=torch.float64)
                h = torch.zeros((2 * d,), dtype=torch.float64)
                res = cp.solvers.qp(big_Q, big_linear_term, G, h)
                x = res['x']

                print('**************************')
                loss = torch.mean(torch.norm(x - x_opt, dim=(1, 2)))
                print('loss:')
                print(loss)
                '''
                # AM = torch.from_numpy(A_train[batch_rand_ind]).to(device)
                # BM = torch.from_numpy(B_train[batch_rand_ind]).to(device)

                # AM = torch.tensor(A_train[batch_rand_ind], dtype=torch.float).cuda()
                # BM = torch.tensor(B_train[batch_rand_ind], dtype=torch.float).t().cuda()
                '''
                AM = torch.tensor(np.float32(A_train[batch_rand_ind])).cuda()
                # print('AM:')
                # print(AM)
                # print('BM:')
                BM = torch.tensor(np.float32(B_train[batch_rand_ind])).t().cuda()
                # print(BM)
                # print('AM.shape:')
                # print(AM.shape)
                # print('BM.shape:')
                # print(BM.shape)


                # if bigstep % it_print_freq == 0 or bigstep == (_iter - 1):
                #     # IPython.embed()
                #     train_err, test_err = save_iteration_regression(S, A_train, B_train, A_test, B_test, it_save_dir,
                #                                                     bigstep)
                #     train_errs.append(train_err)
                #     test_errs.append(test_err)
                #     if bigstep == (_iter - 1):
                #         avg_over_exps += (test_err / num_exp)

                SA = torch.matmul(S, AM).cuda()
                SB = torch.matmul(S, BM).cuda()
                U, Sig, V = torch.svd(SA)

                X = V.matmul(torch.diag_embed(1.0 / Sig)).matmul(U.permute(0, 2, 1)).matmul(SB).cuda()
                # print(X)
                ans = AM.matmul(X).cuda()
                '''
                #print(A_train)
                AM = torch.tensor(A_train[batch_rand_ind].squeeze(), dtype=torch.float32).cuda()
                # # AT = torch.tensor(A_train[batch_rand_ind].squeeze().T).cuda()
                #
                # # print(AT.shape)
                BM = torch.tensor(B_train[batch_rand_ind], dtype=torch.float32).squeeze().cuda()
                print(AM.size())
                print(BM.size())
                # # constraint = torch.tensor([1.])
                # ans = lasso_solver_tch(A, b, constraint=1) # , flag = False
                # # print(ans)
                # # TODO: change the loss function
                # loss = torch.mean(torch.norm(ans - x_opt, dim=(1, 2)))
                # print(loss)
                # print("fp: ", time.time() -fp_start_time)
                # fp_times.append(time.time() - fp_start_time)
                # bp_start_time = time.time()

                #AM = A_train[batch_rand_ind].to(self.device)
                #BM = B_train[batch_rand_ind].to(self.device)

                SA = torch.matmul(S, AM)
                SB = torch.matmul(S, BM)
                U, Sig, V = torch.svd(SA)
                X = V.matmul(torch.diag_embed(1.0 / Sig)).matmul(U.permute(0, 2, 1)).matmul(SB)
                ans = AM.matmul(X)
                loss = torch.mean(torch.norm(ans - BM, dim=(1, 2)))

                loss.backward(retain_graph=True)
                print('sketch_value.grad:')
                print(sketch_value.grad)
                # print("bp: ", time.time() -bp_start_time)
                # bp_times.append(time.time() - bp_start_time)

                # TODO: Maybe don't have to divide by bs: is this similar to lev_score_experiments bug?
                # However, if you change it, then you need to compensate in lr... all old exp will be invalidated
                with torch.no_grad():
                    # sketch_value -= (it_lr / bs) * sketch_value.grad
                    # sketch_value.grad.zero_()
                    sketch_value -= (it_lr / bs) * sketch_value.grad
                    sketch_value.grad.zero_()
                # print('sketch_value:')
                # print(sketch_value)
                # del SA, SB, U, Sig, V, X, ans, loss
                # del AT, b, ATb, x, x_opt
                torch.cuda.empty_cache()

            # print(S_learn)
            #A = torch.tensor(np.float32(A_train[0])).cuda()

            return np.array(S, dtype=np.float)

            # np.save(os.path.join(it_save_dir, "train_errs.npy"), train_errs, allow_pickle=True)
            # np.save(os.path.join(it_save_dir, "test_errs.npy"), test_errs, allow_pickle=True)
            # np.save(os.path.join(it_save_dir, "fp_times.npy"), fp_times, allow_pickle=True)
            # np.save(os.path.join(it_save_dir, "bp_times.npy"), bp_times, allow_pickle=True)


def save_iteration_regression(S, A_train, B_train, A_test, B_test, save_dir, bigstep):
    """
    Evaluate the sketch S on the test and training sets and save a checkpoint.

    The checkpoint `[[S], [train_err, test_err]]` is written with
    `torch.save` to `<save_dir>/it_<bigstep>`; both errors are printed
    and returned as `(train_err, test_err)`.

    Not implemented:
    Mixed matrix evaluation
    """
    checkpoint_path = os.path.join(save_dir, "it_%d" % bigstep)

    # Held-out error first, then the training error.
    test_err = evaluate_to_rule_them_all_regression(A_test, B_test, S)
    train_err = evaluate_to_rule_them_all_regression(A_train, B_train, S)

    checkpoint = [[S], [train_err, test_err]]
    torch.save(checkpoint, checkpoint_path)

    print(train_err, test_err)
    print("Saved iteration: %d" % bigstep)
    return train_err, test_err


def evaluate_to_rule_them_all_regression(A_set, B_set, S, bs=100):
    """
    Average sketched-regression residual over a set of problems.

    BATCHED, but also iterative (i.e. for data=hyper, eval list may be ~3000)

    For every pair (A, B) the sketched least-squares solution
    X = pinv(S A) S B is computed through the SVD of S A, and the
    Frobenius norm ||A X - B|| of the un-sketched residual is accumulated.

    :param A_set: 3D tensor of design matrices; the first dimension
        indexes the problems
    :param B_set: 3D tensor of targets, sliced along its first dimension
        in step with A_set
    :param S: sketch matrix (tensor); moved to the CPU before use
    :param bs: number of problems evaluated per batch (default 100)
    :return: residual norm averaged over the problems, as a Python number
        (0 when A_set is empty)
    :raises ValueError: if bs is not positive

    NOTE(review): singular values are inverted directly (1.0 / Sig), so a
    rank-deficient S A yields inf/nan — confirm inputs are full rank.
    """
    if bs < 1:
        raise ValueError("bs must be a positive integer")

    n = A_set.size()[0]
    # Move the sketch to the CPU once instead of twice per batch.
    S_cpu = S.cpu()
    loss = 0
    for start in range(0, n, bs):
        AM = A_set[start:start + bs]
        BM = B_set[start:start + bs]

        SA = torch.matmul(S_cpu, AM)
        SB = torch.matmul(S_cpu, BM)
        U, Sig, V = torch.svd(SA)
        X = V.matmul(torch.diag_embed(1.0 / Sig)).matmul(U.permute(0, 2, 1)).matmul(SB)
        ans = AM.matmul(X)
        # Divide by the total count so the batch sums add up to a mean.
        it_loss = torch.sum(torch.norm(ans - BM, dim=(1, 2))) / n
        loss += it_loss.item()
    return loss
