# This file contains implementations of various selection algorithms, including Matching Pursuit (MP), Greedy LR, Determinantal Point Processes (DPP), and mutual-information-based methods. It also includes Numba-accelerated versions for performance.

import numpy as np
import math
import numba
from numba import jit
import time
from sklearn.utils.extmath import randomized_svd

class DPPModel():
    """Greedy MAP selection for a determinantal point process (DPP).

    Attributes:
        kernel_matrix: square kernel matrix, one row/column per item.
        max_length: maximum number of items to select.
    """

    def __init__(self, kernel_matrix, max_length):
        self.kernel_matrix = kernel_matrix
        self.max_length = max_length

    def compute_selection(self, epsilon=1E-10):
        """Our proposed fast implementation of the greedy algorithm.

        Each step picks the item with the largest remaining marginal gain
        ``di2s`` and updates all gains with an incremental Cholesky row.

        :param epsilon: small positive scalar; selection stops early once the
            best remaining gain drops below it.
        :return: list of selected item indices, in selection order (at most
            ``self.max_length`` items; empty when ``max_length <= 0`` or the
            kernel is empty).
        """
        item_size = self.kernel_matrix.shape[0]
        if self.max_length <= 0 or item_size == 0:
            return []
        cis = np.zeros((self.max_length, item_size))
        di2s = np.copy(np.diag(self.kernel_matrix))
        # Integer kernels would reject the float updates and -inf below.
        if not np.issubdtype(di2s.dtype, np.floating):
            di2s = di2s.astype(np.float64)
        selected_item = np.argmax(di2s)
        selected_items = [selected_item]
        while len(selected_items) < self.max_length:
            k = len(selected_items) - 1
            ci_optimal = cis[:k, selected_item]
            di_optimal = math.sqrt(di2s[selected_item])
            elements = self.kernel_matrix[selected_item, :]
            # New Cholesky row for the item just selected.
            eis = (elements - np.dot(ci_optimal, cis[:k, :])) / di_optimal
            cis[k, :] = eis
            di2s -= np.square(eis)
            # Exclude the selected item from future argmax calls.
            di2s[selected_item] = -np.inf
            selected_item = np.argmax(di2s)
            if di2s[selected_item] < epsilon:
                break
            selected_items.append(selected_item)
        return selected_items

class BFModel():
    """Rank items by the whitened reward ``L^{-1} r``.

    ``L`` is the Cholesky factor of ``kernel_matrix + jitter * I``; the
    ``max_length`` items with the largest whitened reward are returned.
    """

    def __init__(self, kernel_matrix, reward_vec, max_length) -> None:
        self.kernel_matrix = kernel_matrix
        self.reward_vec = reward_vec
        self.max_length = max_length

    def compute_selection(self, jitter=1e-5):
        """Return the indices of the top ``max_length`` whitened rewards.

        Args:
            jitter: diagonal regularizer added before the Cholesky
                factorization so that near-singular kernels still factor.

        Returns:
            np.ndarray of item indices in descending whitened-reward order.
            For a 2-D reward matrix the ranking uses its first column.
        """
        n = self.kernel_matrix.shape[0]
        L = np.linalg.cholesky(self.kernel_matrix + jitter * np.eye(n))
        # Solving L x = r is equivalent to inv(L) @ r but avoids forming the inverse.
        obj = np.linalg.solve(L, self.reward_vec)
        if obj.ndim == 1:
            return np.argsort(-obj)[:self.max_length]
        return np.argsort(-obj, axis=0)[:self.max_length][:, 0]

@jit
def setdiff(item_size, selected_items):
    """Return the indices in ``range(item_size)`` that are absent from ``selected_items``.

    Uses a boolean membership mask, so the cost is O(item_size + len(selected_items))
    instead of a nested scan. Out-of-range entries in ``selected_items`` are ignored.

    Returns:
        int64 array of the remaining indices in ascending order.
    """
    is_selected = np.zeros(item_size, dtype=np.bool_)
    for j in range(selected_items.shape[0]):
        s = selected_items[j]
        if s >= 0 and s < item_size:
            is_selected[s] = True

    result = np.empty(item_size, dtype=np.int64)
    count = 0
    for i in range(item_size):
        if not is_selected[i]:
            result[count] = i
            count += 1
    # Only the filled prefix is meaningful.
    return result[:count]

@jit 
def compute_selection_numba(
        selected_items,
        max_length, 
        kernel_matrix, 
        reward_vec, 
        cis,
        i_row,
        h_i,
        item_size, 
        mask,
        h,
        obj_mask,
        dis,
        t_v):
    """Fast greedy algorithm to maximize the mutual information

    Array-argument version of the greedy selection. Each loop iteration:
      1. performs a rank-1 update of ``cis`` and ``dis`` for the current pick;
      2. fills row ``k`` of ``t_v``;
      3. scores every candidate and writes the argmax into ``selected_items[k + 1]``.

    All work buffers come from the caller, and several are written in place:
    ``selected_items``, ``cis``, ``i_row``, ``h_i``, ``mask``, ``h``,
    ``obj_mask`` and ``t_v``. ``dis`` is rebound to a new array on every
    iteration, so the caller's ``dis`` array is not updated.

    Buffer shapes below are inferred from the indexing (NOTE(review): confirm
    against the caller):
        selected_items: integer index buffer whose entry 0 holds the first
            pick. Entry ``k + 1`` is written for every ``k < max_length``, so
            it presumably needs at least ``max_length + 1`` slots.
        i_row: 2-D (indexed as ``i_row[:, k]``); ``i_row[:, k + 1]`` requires
            ``k + 1 < i_row.shape[1]``.
        reward_vec: 2-D (indexed as ``reward_vec[items, :]``).
        obj_mask: entries are set to -inf for selected items.

    Returns:
        tuple: ``(selected_items, t_v)``, the same buffers that were passed in.
    """
    k = 0
    selected_item = selected_items[0]
    while k < max_length: 
        cur_select_items = selected_items[:(k+1)]
        elements = kernel_matrix[selected_item, :]
        ci_optimal = cis[:k, selected_item]
        # The 1e-6 offset keeps the divisor away from zero.
        di_optimal = dis[selected_item] + 1e-6

        # Rank-1 update: new row of cis for the item just selected.
        eis = ((elements - np.dot(ci_optimal, cis[:k, :])) / di_optimal).reshape(-1, 1)
        i_row[:, k] = 1

        # Row k of t_v for the item just selected.
        if k == 0: 
            t_v[k, :] = i_row / di_optimal
        else:
            t_v[k, :] = (i_row - np.dot(ci_optimal.T, t_v[:k, :])) / di_optimal
        
        
        # h_i is written here but not read again inside this function.
        h_i[k, :] = t_v[k:k+1, :(k+1)].dot(reward_vec[cur_select_items,:])
        cis[k, :] = eis.reshape(-1,)
        # NOTE(review): sqrt of a negative difference would yield NaN here.
        dis = ((dis ** 2) - np.square(eis)) ** (1/2) + 1e-6
        i_row[:, k] = 0
        
        v_i = cis[:(k+1), :]
        t_v_prev_i = t_v[:(k+1), :]
        
        # Tentative next row of t_v for every candidate (one row per item).
        i_row[:, (k+1)] = 1
        t_v_i = (i_row - np.dot(v_i.T, t_v_prev_i)) / dis
        i_row[:, (k+1)] = 0
        
        # Items not yet selected get mask 0; mask is never reset to 1.
        not_selected = setdiff(item_size, cur_select_items)
        mask[not_selected] = 0
        h[k, :] = (np.dot(t_v_i, reward_vec) * mask).reshape(-1, )

        cur_t_cis = t_v_i[:, :(k+1)]
        cur_t_dis = t_v_i[:, (k+1)]

        obj_mask[selected_item] = -np.inf
        # NOTE(review): cur_t_dis is 1-D while reward_vec is 2-D. If reward_vec
        # is (n, 1), this product broadcasts to (n, n) rather than scaling each
        # candidate by its own cur_t_dis entry; confirm the intended shape.
        obj = (cur_t_cis.dot(reward_vec[selected_items[:(k+1)]]) + cur_t_dis * reward_vec) ** 2

        # Shift scores to be strictly positive so the -inf mask excludes selected items.
        obj = (obj - obj.min() + 1e-6) * obj_mask
        # Picks the argmax of column 0 of obj. NOTE(review): confirm numba
        # accepts np.argmax with an axis argument in this compilation mode.
        selected_item = np.argmax(obj, axis=0)[0]
        #selected_item = 0
        selected_items[k+1] = int(selected_item)
        k += 1 

    return selected_items, t_v

@numba.njit
def compute_selection_v3_numba(kernel_matrix, reward_vec, max_length):
    """
    Numba-compiled greedy mutual-information selection using rank-1 updates.

    The first item maximizes ``reward / (diag + eps)``. After each pick the
    routine:
      1. performs a rank-1 update of ``cis`` (row ``k - 1``) and ``dis``;
      2. fills row ``k - 1`` of ``t_v``;
      3. builds a tentative next row ``t_v_i[j, :]`` for every candidate ``j``;
      4. scores candidate ``j`` as the squared inner product of that row with
         the rewards of the chosen items plus ``j`` itself.

    Args:
        kernel_matrix: square (n, n) kernel; its dtype is used for all work arrays.
        reward_vec: per-item rewards; flattened with ``ravel``, so (n,) and
            (n, 1) are both accepted.
        max_length: number of items to select. It is clamped to ``n``, because
            more distinct picks than items would index ``t_v_i`` out of bounds.

    Returns:
        (selected_items, t_v): an int64 array of indices in selection order,
        and the (max_length, n) ``t_v`` work matrix (rows 0..max_length-2 are
        filled). Empty arrays are returned when ``max_length <= 0`` or ``n == 0``.
    """
    n = kernel_matrix.shape[0]
    if max_length <= 0 or n == 0:
        return np.zeros(0, dtype=np.int64), np.empty((0, 0), dtype=kernel_matrix.dtype)
    if max_length > n:
        max_length = n

    # Manually extract the diagonal instead of np.diagonal().
    dis = np.zeros(n, dtype=kernel_matrix.dtype)
    for i in range(n):
        dis[i] = kernel_matrix[i, i]

    # Work arrays for the incremental updates.
    cis = np.zeros((max_length, n), dtype=kernel_matrix.dtype)
    t_v = np.zeros((max_length, n), dtype=kernel_matrix.dtype)
    t_v_i = np.zeros((n, n), dtype=kernel_matrix.dtype)
    i_row = np.zeros(n, dtype=kernel_matrix.dtype)
    obj_mask = np.ones(n, dtype=kernel_matrix.dtype)   # -inf marks selected items

    reward_vec = reward_vec.ravel()  # ensure 1D
    eps = 1e-6

    # 1) Pick the first item by the simple ratio reward / diag.
    first_round_obj = reward_vec / (dis + eps)
    first_item = np.argmax(first_round_obj)
    selected_items = np.zeros(max_length, dtype=np.int64)
    selected_items[0] = first_item
    obj_mask[first_item] = -np.inf
    num_selected = 1

    # 2) Repeatedly select items until we have max_length picks.
    for k in range(1, max_length):
        prev_item = selected_items[k - 1]
        di_optimal = dis[prev_item] + eps

        # Rank-1 update:
        # cis[k-1, j] = (K[prev, j] - sum_{p<k-1} cis[p, prev] * cis[p, j]) / di_optimal
        # (the inner loop is empty when k == 1).
        for j in range(n):
            acc = 0.0
            for p in range(k - 1):
                acc += cis[p, prev_item] * cis[p, j]
            cis[k - 1, j] = (kernel_matrix[prev_item, j] - acc) / di_optimal

        # dis[j] = sqrt(dis[j]^2 - e_j^2) + eps
        for j in range(n):
            e_j = cis[k - 1, j]
            dis[j] = np.sqrt(dis[j] * dis[j] - e_j * e_j) + eps

        # Row k-1 of t_v. i_row is the unit vector for *position* k-1 in the
        # selection order (not the item index prev_item).
        i_row[:] = 0
        i_row[k - 1] = 1.0
        for j in range(n):
            acc = 0.0
            for p in range(k - 1):
                acc += cis[p, prev_item] * t_v[p, j]
            t_v[k - 1, j] = (i_row[j] - acc) / di_optimal

        # Tentative next t_v row for each candidate j, over positions 0..k:
        # t_v_i[j, p] = (delta(p, k) - sum_m cis[m, j] * t_v[m, p]) / dis[j]
        # TODO (original author): flagged as suspicious; verify t_v_i against
        # DMIModel.compute_selection.
        for j in range(n):
            for p in range(k + 1):
                acc = 0.0
                for m in range(k + 1):
                    acc += cis[m, j] * t_v[m, p]
                if p == k:
                    t_v_i[j, p] = (1 - acc) / dis[j]
                else:
                    t_v_i[j, p] = (0 - acc) / dis[j]

        # Score each unselected candidate j.
        obj = np.full(n, -np.inf, dtype=kernel_matrix.dtype)
        for j in range(n):
            if obj_mask[j] <= 0:       # -inf => already chosen
                continue
            sum_rewards_chosen = 0.0
            for p in range(k):
                sum_rewards_chosen += t_v_i[j, p] * reward_vec[selected_items[p]]
            val = sum_rewards_chosen + t_v_i[j, k] * reward_vec[j]
            obj[j] = val * val

        next_item = np.argmax(obj)
        selected_items[k] = next_item
        obj_mask[next_item] = -np.inf
        num_selected += 1

        if num_selected == max_length:
            break

    return selected_items[:num_selected], t_v

class OrthogonalMatchingPursuitAug():
    """Greedy matching-pursuit selection on an augmented embedding.

    Every embedding row is extended with a constant feature 1 and scaled by
    1/sqrt(2) before item-to-item inner products are taken.

    Returns:
        selected items (from ``compute_selection``)
    """

    def __init__(self, reward_mat, emb_mat):
        self.reward_mat = reward_mat
        self.emb_mat = emb_mat

    def compute_selection(self, max_length, debug=False):
        """Greedily select items from the reward matrix.

        Initially M is a float copy of the reward matrix. In each round:
            1. The row with the largest value in column 0 of M is selected.
            2. The index is recorded and removed from the selection mask.
            3. Every still-selectable row of M is reduced by the selected row
               of M scaled by its augmented-embedding inner product with the
               selected item.
            4. The selected row of M is set to -inf so it cannot be re-picked.

        Parameters:
            max_length (int): Maximum number of items to select. It is clamped to
                the number of items, because once every row is -inf further rounds
                would re-pick an item and produce NaNs.
            debug (bool, optional): If True, prints the selected item and the
                updated matrix M after each round. Defaults to False.

        Returns:
            list[int]: Selected item indices in selection order.
        """
        num_items = self.reward_mat.shape[0]
        selected_items = []
        select_mask = np.ones((num_items, ), dtype=np.int64)
        M = np.copy(self.reward_mat)
        # Integer rewards cannot hold the float updates or -inf below.
        if not np.issubdtype(M.dtype, np.floating):
            M = M.astype(np.float64)
        F = np.copy(self.emb_mat)
        F = np.concatenate((F, np.ones((F.shape[0], 1))), axis=1) / (2 ** (1/2))
        for _ in range(min(max_length, num_items)):
            selected_item = np.argmax(M, axis=0)[0]
            select_mask[selected_item] = 0
            selected_items.append(selected_item)
            M -= (M[selected_item] * (F.dot(F[selected_item, :].reshape(-1, 1)))) * select_mask.reshape(-1, 1)
            M[selected_item] = -np.inf
            if debug:
                print("selected item", selected_item)
                print("M", M)

        return selected_items

class OrthogonalMatchingPursuit():
    """Compute MP for a given reward matrix and embedding matrix of candidates.

    The reward matrix can be preprocessed at construction time (see
    ``process_reward_method``). The selection methods then greedily pick items
    and deflate the remaining rewards by embedding inner products.

    Returns:
        selected items (from the ``compute_selection*`` methods)
    """

    def __init__(
            self,
            reward_mat,
            emb_mat,
            process_reward_args: dict = None
        ):
        self.reward_mat = \
            self.process_reward_method(reward_mat, process_reward_args)
        self.emb_mat = emb_mat

    def process_reward_method(self, reward_mat, process_reward_args):
        """Optionally reduce the reward matrix before selection.

        Args:
            reward_mat: 2-D reward matrix, one row per item.
            process_reward_args: ``None`` (no preprocessing) or a dict whose
                ``"method"`` key selects one of:
                  * ``"flat"``: replace each row by its mean, giving shape (n, 1);
                  * ``"pca"``: project with ``reduce_dim`` singular vectors from
                    ``randomized_svd`` (``"reduce_dim"`` must be > 0);
                  * any other value: return ``reward_mat`` unchanged.

        Returns:
            The (possibly reduced) reward matrix.
        """
        method = None if process_reward_args is None else process_reward_args.get("method")
        if method == "flat":
            return reward_mat.mean(axis=1, keepdims=True)
        if method == "pca":
            reduce_dim = process_reward_args["reduce_dim"]
            assert reduce_dim > 0, "reduce_dim should be larger than 0"

            U, _, _ = randomized_svd(
                reward_mat,
                n_components=reduce_dim,
                random_state=0,
                n_iter=2,
            )
            U = U[:, :reduce_dim]
            # NOTE(review): U has one row per *row* of reward_mat, so this
            # product only works for square reward matrices. Projecting onto the
            # right singular vectors (reward_mat @ Vt.T) may have been intended.
            return reward_mat.dot(U)
        return reward_mat

    @staticmethod
    def _as_float(arr):
        """Return a copy of ``arr``, promoted to float64 unless it is already floating-point."""
        arr = np.copy(arr)
        if not np.issubdtype(arr.dtype, np.floating):
            arr = arr.astype(np.float64)
        return arr

    def _resolve_max_length(self, max_length):
        """Use the explicit ``max_length``, falling back to a ``self.max_length`` attribute."""
        if max_length is None:
            max_length = getattr(self, "max_length", None)
        if max_length is None:
            raise ValueError("max_length must be provided")
        return max_length

    def compute_selection(self, max_length, debug=False):
        """Greedily select items from the reward matrix.

        Initially M is a float copy of the reward matrix. In each round:
            1. The row of M with the largest squared L2 norm is selected.
            2. The index is recorded and removed from the selection mask.
            3. Every still-selectable row of M is reduced by the selected row
               of M scaled by its embedding inner product with the selected item.
            4. The selected row of M is reset to zero.

        Parameters:
            max_length (int): Maximum number of items to select. It is clamped to
                the number of items, so no item is returned twice.
            debug (bool, optional): If True, prints the selected item and the
                updated matrix M after each round. Defaults to False.

        Returns:
            list[int]: Selected item indices in selection order.
        """
        num_items = self.reward_mat.shape[0]
        selected_items = []
        select_mask = np.ones((num_items, ), dtype=np.int64)
        M = self._as_float(self.reward_mat)
        F = np.copy(self.emb_mat)
        for _ in range(min(max_length, num_items)):
            selected_item = np.argmax(np.square(M).sum(-1, keepdims=True), axis=0)[0]
            select_mask[selected_item] = 0
            selected_items.append(selected_item)
            M -= (M[selected_item] * (F.dot(F[selected_item, :].reshape(-1, 1)))) * select_mask.reshape(-1, 1)
            M[selected_item] = 0
            if debug:
                print("selected item", selected_item)
                print("M", M)

        return selected_items

    def compute_selection_multi_dim_query(self, max_length=None):
        """Select items by the masked squared column sums of M.

        After each pick, the selected row of M is zeroed, and M is reduced by
        the scalar ``M[:, s] . (E @ E[s])``.

        Args:
            max_length: number of items to select. It defaults to a
                ``self.max_length`` attribute when one exists, and is clamped to
                the number of items.

        Returns:
            list of selected item indices.
        """
        max_length = self._resolve_max_length(max_length)
        num_items = self.reward_mat.shape[0]
        selected_items = []
        select_mask = np.ones((num_items, 1), dtype=np.int64)
        M = self._as_float(self.reward_mat)
        for _ in range(min(max_length, num_items)):
            obj = (np.ones((1, num_items), dtype=np.int64).dot(M)).reshape(-1, 1) * select_mask
            selected_item = np.argmax(obj ** 2, axis=0)[0]
            select_mask[selected_item] = 0
            M[selected_item] = 0
            selected_items.append(selected_item)
            # NOTE(review): this subtracts one scalar (a full dot product) from
            # every entry of M. The numba variant applies a per-row correction
            # instead; confirm which is intended.
            M -= M[:, selected_item].dot(self.emb_mat.dot(self.emb_mat[selected_item, :])).reshape(-1, 1)

        return selected_items

    def compute_selection_multi_dim_self_compress(self, max_length=None):
        """Like ``compute_selection_multi_dim_query``, but deflates with the original rewards.

        Args:
            max_length: number of items to select. It defaults to a
                ``self.max_length`` attribute when one exists, and is clamped to
                the number of items.

        Returns:
            list of selected item indices.
        """
        max_length = self._resolve_max_length(max_length)
        num_items = self.reward_mat.shape[0]
        selected_items = []
        select_mask = np.ones((num_items, 1), dtype=np.int64)
        M = self._as_float(self.reward_mat)
        for _ in range(min(max_length, num_items)):
            obj = (np.ones((1, num_items), dtype=np.int64).dot(M)).reshape(-1, 1) * select_mask
            selected_item = np.argmax(obj ** 2, axis=0)[0]
            select_mask[selected_item] = 0
            M[selected_item] = 0
            selected_items.append(selected_item)
            M -= self.reward_mat[:, selected_item].dot(self.emb_mat.dot(self.emb_mat[selected_item, :])).reshape(-1, 1)

        return selected_items

class OrthogonalMatchingPursuitSelfProj(OrthogonalMatchingPursuit):
    """Matching-pursuit variant that also deflates the embedding matrix after each pick.

    Unlike the parent class, the reward matrix is used as given (no
    ``process_reward_args`` preprocessing).
    """

    def __init__(self, reward_mat, emb_mat):
        self.reward_mat = reward_mat
        self.emb_mat = emb_mat

    @staticmethod
    def _safe_row_norms(F):
        """Row-wise L2 norms of F (as a column), with zero norms replaced by 1.

        A fully deflated embedding row has norm 0. Dividing by it would give
        0/0 = NaN and poison M and F for the rest of the run.
        """
        norms = np.linalg.norm(F, axis=1, keepdims=True)
        return np.where(norms > 0, norms, 1.0)

    def compute_selection(self, max_length, debug=False):
        """Greedily select items by the largest |M| entry in column 0.

        After each pick:
          * every still-selectable row of M is reduced by the selected reward
            row, scaled by embedding similarity and divided by the row norm;
          * F is reduced by its projection onto the selected embedding row;
          * the selected row of M is zeroed.

        Args:
            max_length: maximum number of items to select, clamped to the
                number of items.
            debug: if True, print the selected item and M after each round.

        Returns:
            list of selected item indices.
        """
        num_items = self.reward_mat.shape[0]
        selected_items = []
        select_mask = np.ones((num_items, ), dtype=np.int64)
        M = np.copy(self.reward_mat)
        if not np.issubdtype(M.dtype, np.floating):
            M = M.astype(np.float64)
        F = np.copy(self.emb_mat)
        if not np.issubdtype(F.dtype, np.floating):
            F = F.astype(np.float64)
        for _ in range(min(max_length, num_items)):
            selected_item = np.argmax(np.abs(M), axis=0)[0]
            select_mask[selected_item] = 0
            selected_items.append(selected_item)

            # Both updates use F as it was before this round's deflation.
            similarity = F.dot(F[selected_item, :].reshape(-1, 1))
            norms = self._safe_row_norms(F)
            M -= (M[selected_item] * similarity) * select_mask.reshape(-1, 1) / norms
            F -= F * similarity / norms
            M[selected_item] = 0
            if debug:
                print("selected item", selected_item)
                print("M", M)

        return selected_items

@jit
def compute_selection_multi_dim_query_numba(reward_mat, emb_mat, max_length):
    """
    Numba-accelerated greedy selection on masked squared column sums.

    Like ``compute_selection_multi_dim_self_compress``, the deflation uses the
    original ``reward_mat`` rather than M. Unlike that method, the correction
    is applied per row: row j of M is reduced by
    ``reward_mat[j, s] * <emb_mat[j], emb_mat[s]>``.
    NOTE(review): confirm which deflation is intended.

    Parameters
    ----------
    reward_mat : np.ndarray, shape (N, N)
        A square matrix of rewards (float or int).
    emb_mat : np.ndarray, shape (N, D)
        Embeddings for each of the N items (float).
    max_length : int
        Number of items to select. It is clamped to [0, N]; once every column is
        masked, no valid index remains to pick.

    Returns
    -------
    selected_items : np.ndarray, shape (min(max_length, N),)
        Indices of the selected items.
    """
    N = reward_mat.shape[0]
    if max_length > N:
        max_length = N
    if max_length < 0:
        max_length = 0
    selected_items = np.empty(max_length, dtype=np.int64)

    # 1 = still selectable, 0 = already selected.
    select_mask = np.ones(N, dtype=np.int64)

    # Work on a copy so reward_mat stays intact for the deflation step.
    M = reward_mat.copy()

    for i in range(max_length):
        # Column sums of M.
        sum_of_cols = np.zeros(N, dtype=M.dtype)
        for row in range(N):
            for col in range(N):
                sum_of_cols[col] += M[row, col]

        # Selectable column with the largest squared column sum.
        best_val = -1e20
        best_idx = -1
        for col in range(N):
            if select_mask[col] == 1:
                val = sum_of_cols[col] * sum_of_cols[col]
                if val > best_val:
                    best_val = val
                    best_idx = col

        selected_item = best_idx
        selected_items[i] = selected_item
        select_mask[selected_item] = 0

        # Zero out the selected row before deflating.
        for c in range(N):
            M[selected_item, c] = 0

        # Row-wise deflation by reward_mat[j, s] * <emb_j, emb_s>.
        for j in range(N):
            dot_val = 0.0
            for d in range(emb_mat.shape[1]):
                dot_val += emb_mat[j, d] * emb_mat[selected_item, d]
            correction = reward_mat[j, selected_item] * dot_val
            for k in range(N):
                M[j, k] -= correction

    return selected_items

class DMIModel(object):
    """
    Selection model to maximize the mutual information btw 1-d Z_Q and n-d Z_F.

    Supply either a precomputed ``kernel_matrix`` or an embedding matrix
    ``emb_mat`` (from which block kernels ``E E^T`` are built), but not both.
    """
    def __init__(self, kernel_matrix=None, reward_vec=None, emb_mat=None):
        assert not (kernel_matrix is not None and emb_mat is not None), "Only one of kernel_matrix and emb_mat should be provided"

        # Always define emb_mat so block_compute_selection can tell which input was given.
        self.emb_mat = emb_mat
        self.kernel_matrix = kernel_matrix
        self.reward_vec = reward_vec
        self.item_size = kernel_matrix.shape[0] if kernel_matrix is not None else emb_mat.shape[0]

    def block_compute_selection(self, block_num=10, max_length=5, numba=False):
        """Compute the selection independently on contiguous item blocks.

        Items are split into ``block_num`` contiguous blocks of
        ``item_size // block_num`` items; the last block absorbs the
        remainder. ``max_length`` is split as evenly as possible, with the
        remainder going to the first blocks. Quota that a block cannot use
        (because it has too few items) carries over to the next block.

        Keyword Arguments:
            block_num -- number of blocks to partition into (default: {10})
            max_length -- total number of items to select (default: {5})
            numba -- use compute_selection_v3_numba instead of compute_selection

        Returns:
            (selected_items, t_v): global item indices after block aggregation,
            and a list with one ``t_v`` matrix per non-empty block.
        """
        base_len, extra = divmod(max_length, block_num)
        block_item_size = self.item_size // block_num
        selected_items = []
        t_v = []
        carry = 0

        for i in range(block_num):
            start = i * block_item_size
            end = self.item_size if i == block_num - 1 else (i + 1) * block_item_size
            quota = base_len + (1 if i < extra else 0) + carry
            block_len = min(quota, end - start)
            carry = quota - block_len

            if start == end:
                # Empty block (fewer items than blocks): its quota carries over.
                continue

            reward_vec = self.reward_vec[start:end, :]
            if self.emb_mat is not None:
                embed_mat = self.emb_mat[start:end, :]
                kernel_matrix = embed_mat.dot(embed_mat.T)
            else:
                kernel_matrix = self.kernel_matrix[start:end, start:end]
            if numba:
                block_selected_items, block_t_v = compute_selection_v3_numba(
                    kernel_matrix, reward_vec, block_len)
            else:
                block_selected_items, block_t_v = self.compute_selection(
                    kernel_matrix, reward_vec, block_len)
            # Convert block-local indices to global item indices.
            selected_items += [s + start for s in block_selected_items]
            t_v.append(block_t_v)

        return selected_items, t_v

    def compute_selection(self, kernel_matrix, reward_vec, max_length):
        """Fast greedy algorithm to maximize the mutual information.

        Args:
            kernel_matrix: square (n, n) kernel.
            reward_vec: rewards as a column, presumably shape (n, 1); the
                first-round ``1 / dis * reward_vec`` assumes a column.
            max_length: number of items to select. It is clamped to ``n``,
                because ``i_row[:, k + 1]`` needs ``k + 1 < n``.

        Returns:
            (selected_items, t_v): a list of selected indices and the (n, n)
            ``t_v`` matrix, or ``([], None)`` when nothing can be selected.
        """
        item_size = kernel_matrix.shape[0]
        max_length = min(max_length, item_size)
        if max_length <= 0:
            return [], None
        cis = np.zeros((max_length, item_size))
        dis = np.copy(np.diag(kernel_matrix)).reshape(-1, 1)
        t_v = np.zeros((item_size, item_size))
        i_row = np.zeros((1, item_size))
        obj_mask = np.ones((item_size, ))

        first_round_obj = 1 / dis * reward_vec
        selected_item = np.argmax(first_round_obj, axis=0)[0]
        selected_items = [selected_item]
        while len(selected_items) <= max_length:
            elements = kernel_matrix[selected_item, :]
            k = len(selected_items) - 1
            ci_optimal = cis[:k, selected_item] if k > 0 else np.zeros_like(cis[:k, selected_item])
            # The 1e-6 offset keeps the divisor away from zero.
            di_optimal = dis[selected_item] + 1e-6

            # Rank-1 update for the item just selected.
            if k > 0:
                eis = ((elements - np.dot(ci_optimal, cis[:k, :])) / di_optimal).reshape(-1, 1)
            else:
                eis = (elements / di_optimal).reshape(-1, 1)

            # Row k of t_v.
            i_row[:, k] = 1
            if k == 0:
                t_v[k, :] = i_row / di_optimal
            else:
                t_v[k, :] = (i_row - np.dot(ci_optimal.T, t_v[:k, :])) / di_optimal

            if len(selected_items) == max_length:
                break

            cis[k, :] = eis.squeeze()
            dis = ((dis ** 2) - np.square(eis)) ** (1/2) + 1e-6
            i_row[:, k] = 0
            v_i = cis[:(k+1), :]
            t_v_prev_i = t_v[:(k+1), :]

            # Tentative next t_v row for every candidate (one row per item).
            i_row[:, (k+1)] = 1
            t_v_i = (i_row - np.dot(v_i.T, t_v_prev_i)) / dis
            i_row[:, (k+1)] = 0
            cur_t_cis = t_v_i[:, :(k+1)]
            cur_t_dis = t_v_i[:, (k+1)]
            obj_mask[selected_item] = -np.inf
            obj = ((cur_t_cis.dot(reward_vec[selected_items])).squeeze() + cur_t_dis * reward_vec.squeeze()) ** 2
            # Shift scores to be strictly positive so the -inf mask excludes selected items.
            obj = (obj - obj.min() + 1e-6) * obj_mask
            selected_item = np.argmax(obj, axis=0)
            selected_items.append(selected_item)

        return selected_items, t_v
    

def dpp_sw(kernel_matrix, window_size, max_length, epsilon=1E-10):
    """
    Sliding window version of the greedy algorithm.

    Only the most recent ``window_size`` selections condition the marginal
    gains; older Cholesky rows are rotated out with Givens-style updates.

    :param kernel_matrix: 2-d array
    :param window_size: positive int
    :param max_length: positive int; ``<= 0`` returns an empty list
    :param epsilon: small positive scalar; selection stops once the best
        remaining gain drops below it
    :return: list of selected item indices in selection order
    """
    item_size = kernel_matrix.shape[0]
    if max_length <= 0 or item_size == 0:
        return []
    v = np.zeros((max_length, max_length))
    cis = np.zeros((max_length, item_size))
    di2s = np.copy(np.diag(kernel_matrix))
    # Integer kernels would reject the float updates and -inf below.
    if not np.issubdtype(di2s.dtype, np.floating):
        di2s = di2s.astype(np.float64)
    selected_items = list()
    selected_item = np.argmax(di2s)
    selected_items.append(selected_item)
    window_left_index = 0
    while len(selected_items) < max_length:
        k = len(selected_items) - 1
        ci_optimal = cis[window_left_index:k, selected_item]
        di_optimal = math.sqrt(di2s[selected_item])
        v[k, window_left_index:k] = ci_optimal
        v[k, k] = di_optimal
        elements = kernel_matrix[selected_item, :]
        eis = (elements - np.dot(ci_optimal, cis[window_left_index:k, :])) / di_optimal
        cis[k, :] = eis
        di2s -= np.square(eis)
        if len(selected_items) >= window_size:
            # Slide the window: rotate the oldest row out of the factorization.
            window_left_index += 1
            for ind in range(window_left_index, k + 1):
                t = math.sqrt(v[ind, ind] ** 2 + v[ind, window_left_index - 1] ** 2)
                c = t / v[ind, ind]
                s = v[ind, window_left_index - 1] / v[ind, ind]
                v[ind, ind] = t
                v[ind + 1:k + 1, ind] += s * v[ind + 1:k + 1, window_left_index - 1]
                v[ind + 1:k + 1, ind] /= c
                v[ind + 1:k + 1, window_left_index - 1] *= c
                v[ind + 1:k + 1, window_left_index - 1] -= s * v[ind + 1:k + 1, ind]
                cis[ind, :] += s * cis[window_left_index - 1, :]
                cis[ind, :] /= c
                cis[window_left_index - 1, :] *= c
                cis[window_left_index - 1, :] -= s * cis[ind, :]
            # Restore the gain contribution of the row that left the window.
            di2s += np.square(cis[window_left_index - 1, :])
        di2s[selected_item] = -np.inf
        selected_item = np.argmax(di2s)
        if di2s[selected_item] < epsilon:
            break
        selected_items.append(selected_item)
    return selected_items

class DMIModelV2(object):
    """
    Selection model to maximize the mutual information btw 1-d Z_Q and n-d Z_F.

    Stateful variant: the first item is chosen in ``__init__``, and all work
    buffers are kept on the instance.
    """
    def __init__(self, kernel_matrix, reward_vec, max_length):
        self.kernel_matrix = kernel_matrix
        self.reward_vec = reward_vec            # presumably a column, shape (n, 1)
        self.max_length = max_length
        self.item_size = kernel_matrix.shape[0]
        self.cis = np.zeros((max_length, self.item_size))
        self.dis = np.copy(np.diag(kernel_matrix)).reshape(-1, 1)
        self.t_cis = np.zeros((max_length, self.item_size))
        self.t_dis = np.zeros((max_length, 1))
        self.selected_items = list()
        first_round_obj = 1 / self.dis * self.reward_vec
        self.selected_item = np.argmax(first_round_obj, axis=0)[0]
        self.selected_items.append(self.selected_item)
        self.t_v = np.zeros((self.item_size, self.item_size))
        self.i_row = np.zeros((1, self.item_size))
        self.h = np.zeros((max_length, self.item_size))
        self.mask = np.ones((self.item_size, 1))
        self.obj_mask = np.ones((self.item_size, 1))
        self.h_i = np.zeros((max_length, 1))

    def compute_selection(self):
        """Fast greedy algorithm to maximize the mutual information.

        The number of picks is ``min(max_length, item_size)``, because the
        tentative-row update indexes ``i_row[:, k + 1]``.

        Returns:
            (selected_items, t_v): a list of selected indices and the
            (n, n) ``t_v`` matrix. When ``max_length <= 0`` the list is empty.
        """
        if self.max_length <= 0:
            return [], self.t_v
        limit = min(self.max_length, self.item_size)
        while len(self.selected_items) <= limit:
            # Rank-1 update for the current pick.
            elements = self.kernel_matrix[self.selected_item, :]
            k = len(self.selected_items) - 1
            ci_optimal = self.cis[:k, self.selected_item]
            # The 1e-6 offset keeps the divisor away from zero.
            di_optimal = self.dis[self.selected_item] + 1e-6

            # [n, 1]
            eis = ((elements - np.dot(ci_optimal, self.cis[:k, :])) / di_optimal).reshape(-1, 1)
            self.i_row[:, k] = 1

            # Row k of tilde V for the selected item.
            if k == 0:
                self.t_v[k, :] = self.i_row / di_optimal
            else:
                self.t_v[k, :] = (self.i_row - np.dot(ci_optimal.T, self.t_v[:k, :])) / di_optimal

            if len(self.selected_items) == limit:
                break

            # h_i = \tilde{v}_i^T \tilde{r}_i
            self.h_i[k, :] = self.t_v[[k], :(k+1)].dot(self.reward_vec[self.selected_items])
            self.cis[k, :] = eis.squeeze()
            self.dis = ((self.dis ** 2) - np.square(eis)) ** (1/2) + 1e-6
            self.i_row[:, k] = 0

            # Tentative tilde V row for round k+1, one row per candidate: [n, n].
            v_i = self.cis[:(k+1), :]
            t_v_prev_i = self.t_v[:(k+1), :]
            self.i_row[:, (k+1)] = 1
            t_v_i = (self.i_row - np.dot(v_i.T, t_v_prev_i)) / self.dis
            self.i_row[:, (k+1)] = 0

            # Items not yet selected get mask 0.
            not_selected = np.setdiff1d(np.arange(self.item_size), self.selected_items)
            self.mask[not_selected] = 0
            self.h[k, :] = (np.dot(t_v_i, self.reward_vec) * self.mask).squeeze()

            cur_t_cis = t_v_i[:, :(k+1)]
            cur_t_dis = t_v_i[:, (k+1)]

            # Objective per candidate, [n, 1]. cur_t_dis is reshaped to a column
            # so that each candidate is scaled by its own entry; a 1-D vector
            # times an (n, 1) reward would broadcast to (n, n).
            self.obj_mask[self.selected_item] = -np.inf
            obj = (cur_t_cis.dot(self.reward_vec[self.selected_items]) +
                   cur_t_dis.reshape(-1, 1) * self.reward_vec) ** 2

            # Shift scores to be strictly positive so the -inf mask excludes selected items.
            obj = (obj - obj.min() + 1e-6) * self.obj_mask
            selected_item = np.argmax(obj, axis=0)[0]

            self.selected_item = selected_item
            self.selected_items.append(selected_item)

        return self.selected_items, self.t_v

if __name__ == '__main__':
    # Smoke test: OrthogonalMatchingPursuit with PCA-reduced rewards.
    # The dense N x N float64 reward matrix needs 8 * N**2 bytes, so N is kept
    # modest enough to run on a workstation.
    np.random.seed(42)
    N = 2000
    D = 128
    max_len = 3

    reward = np.random.rand(N, N).astype(np.float64)
    emb = np.random.rand(N, D).astype(np.float64)

    process_reward_args = {
        "method": "pca",
        "reduce_dim": 4
    }
    start_time = time.time()
    omp = OrthogonalMatchingPursuit(
        reward, emb, process_reward_args=process_reward_args)
    selection = omp.compute_selection(max_len)
    print("selection", selection, "elapsed", time.time() - start_time)
