# Find a basis in the feature space.
# This implementation relies on https://www.esann.org/sites/default/files/proceedings/legacy/es2002-15.pdf
# "An Formation of a Basis in a Kernel Induced Feature Space" by Talbot and Cawley, 2002.

import torch

def find_basis_points(X, kernel_func, max_basis_vectors, num_candidates, error_thres = 0.01, num_eval_cand=200):
    """ Apply the greedy algorithms to find a suitable basis in the feature space. 
        :param X: input data set
        :param num_candidates: How many candidates for basis vectors should be sampled in each greedy iteration
        :param error_thres: Tolerated reconstruction error in range [0, 1] (0.01 means 1% error is tolerated.)
    """
    basis_point_idx_list = torch.tensor([], dtype=torch.long)
    candidate_binary = torch.ones(len(X), dtype=torch.long)

    while True:
        candidate_idx = torch.nonzero(candidate_binary).flatten() # Set C 
        #print("Shape of candidate set C", len(candidate_idx))
        if len(candidate_idx) == 0:
            #print("Basis found. No more points to consider.")
            break
        r_cand = torch.randperm(torch.sum(candidate_binary).item())[:num_candidates] # Candidate seletion set.
        r_cand_id = candidate_idx[r_cand]
        eval_cand = torch.randperm(torch.sum(candidate_binary).item())[:num_eval_cand] # Evaluation set
        eval_cand_id = candidate_idx[eval_cand]
        #print("Candidate IDs:", r_cand_id)
        JS = torch.zeros(len(r_cand), len(eval_cand)) # Matrix of reconstruction error adding point in i for evaluation point j
        
        for i in range(len(r_cand)):
            basis_point_idx_list_ext = torch.cat((basis_point_idx_list, r_cand_id[i].reshape(1)))
            Xselect = X[basis_point_idx_list_ext]
            # If the same data point is in the train set twice, the kernel matrix can become singular.
            # This will make it impossible to invert it and compute the reconstruction error.
            if torch.abs(torch.det(kernel_func(Xselect, Xselect))) < 1e-7:
                JS[i, :] = 0
                # Eliminate this point from candidate set.
                candidate_binary[r_cand_id[i]] = 0
            else:
                #print(basis_point_idx_list_ext)
                JS[i, :] = compute_reconstruction_errors(X, basis_point_idx_list_ext, eval_cand_id, kernel_func)
                #print(JS[i,:].mean())
        
        val, idx = torch.max(JS.mean(dim=1), dim=0) # Find point with the lowest reconstruction error
        # This case occurs if all points considered make the kernel matrix singular.
        if val.item() == 0.0:
            continue
        #print(f"Adding point {r_cand_id[idx.item()]} with (neg.) reconstruction error {val.item()}")
        # Exclude points with a low error threshold.
        dsi = 1-JS[idx, :] # errors on test set.
        #print(dsi.shape)
        # # Set all points with errors smaller then error_thres to 0.
        candidate_binary[eval_cand_id] = (dsi > error_thres).long() 
        candidate_binary[r_cand_id[idx.item()]] = 0 # remove from candidate list.
        basis_point_idx_list = torch.cat((basis_point_idx_list, r_cand_id[idx].reshape(1)))
        if len(basis_point_idx_list) == max_basis_vectors:
            print("Maximum number of basis vectors reached.")
            break
        if val > 1-error_thres:
            print(f"Basis found. Below error threshold")
            break
    print(f"N={len(basis_point_idx_list)}, Reconstruction error {1-val}.")
    return basis_point_idx_list

def compute_reconstruction_errors(X, basis_point_list, test_point_list, kernel_func):
    """
        Compute the per-point terms of the Monte-Carlo estimate of J(S) over test_point_list.

        For test point i the returned value is k_si^T K_ss^{-1} k_si / k_ii, i.e. the fraction of the
        squared feature-space norm of point i that is captured by the span of the basis points.
        Despite the function name, larger values are better: the reconstruction error is
        ``1 - result`` (this is how find_basis_points uses it).

        basis_point_list: Index tensor with points already in basis. Length (N)
        test_point_list: Index tensor with points for which the reconstruction should be estimated. Length (L)
        kernel_func: callable returning the kernel matrix between two point sets.
        Returns a tensor of length L. Test points with k_ii <= 0 (for a positive semi-definite kernel:
        a zero feature vector) are reported as fully reconstructed (1.0) instead of NaN from 0/0.
    """
    X_test = X[test_point_list]
    K_si = kernel_func(X[basis_point_list], X_test)  # (N, L)
    K_ii = torch.diag(kernel_func(X_test, X_test))  # (L); builds the full (L, L) matrix, only the diagonal is used
    K_ssinv = compute_inverse(X, basis_point_list, kernel_func)  # (N, N)
    captured = torch.sum(K_si * K_ssinv.matmul(K_si), dim=0)  # (N, L) * (N, L) -> (L)
    ratio = captured / K_ii
    # A zero feature vector is trivially represented by any basis. Returning 1.0 avoids NaN, which
    # would otherwise poison the mean scores (and torch.max) in the caller.
    return torch.where(K_ii > 0, ratio, torch.ones_like(ratio))

def compute_inverse(X, basis_point_list, kernel_func):
    """ Return the inverse of the kernel matrix of the selected basis points (non optimized).

        :param X: input data set.
        :param basis_point_list: index tensor selecting the basis points from X.
        :param kernel_func: callable returning the kernel matrix between two point sets.
        :return: ``torch.inverse(kernel_func(X[basis_point_list], X[basis_point_list]))``.
    """
    basis_points = X[basis_point_list]
    return torch.inverse(kernel_func(basis_points, basis_points))

