function result = kron_omp_refined(D, Y, tolerance_ratio)
%KRON_OMP_REFINED Kronecker-structured Orthogonal Matching Pursuit.
%   RESULT = KRON_OMP_REFINED(D, Y, TOLERANCE_RATIO) greedily approximates
%   the N-way array Y as a sparse combination of separable atoms
%   kron(D{N}(:,iN), ..., D{1}(:,i1)), i.e. Y(:) ~= kron(D{N},...,D{1}) * X(:)
%   for a sparse coefficient tensor X. Each iteration adds the atom most
%   correlated with the residual, then re-solves least squares on the whole
%   support via a rank-one update of the inverse Gram matrix.
%
%   Inputs
%     D               - cell array of N dictionaries. D{n} must have size(Y, n)
%                       rows (required by the mode-n products below).
%     Y               - observed N-way array.
%     tolerance_ratio - stop once norm(residual) < tolerance_ratio * norm(Y(:)).
%
%   Output (3x1 cell)
%     result{1} - coefficient tensor X of size [size(D{1},2), ..., size(D{N},2)]
%                 (a column vector when N == 1), scaled for the ORIGINAL,
%                 un-normalized dictionaries.
%     result{2} - time of the pursuit loop in seconds (tic/toc).
%     result{3} - K-by-N matrix; row i holds the per-mode atom indices of the
%                 i-th accepted atom, in selection order (K may be 0).
%
%   Pursuit stops when any of these happens:
%     - the residual tolerance is reached;
%     - the residual is orthogonal to every atom;
%     - the next atom is numerically dependent on the support (Schur
%       complement ~ 0);
%     - prod(size(Y)) atoms have been selected.
%
%   Requires TMPROD (presumably from Tensorlab - confirm on the path) and the
%   local helpers TRANSP and KHATRI_RAO_VEC.

N = numel(D);
tensor_dims = size(Y);
y_vec = Y(:);               % column-major vectorization (base MATLAB has no VEC)
norm_y = norm(y_vec);
max_iter = prod(tensor_dims);

% Scale every dictionary column to unit 2-norm so correlations are comparable
% and the Gram matrix has a unit diagonal. Columns with norm < eps are left
% unscaled to avoid division by zero. The inverse norms are kept so the
% coefficients can be mapped back to the original dictionaries at the end.
col_norms_inv_vecs = cell(1, N);
for i = 1:N
    col_norms = vecnorm(D{i});
    col_norms(col_norms < eps) = 1;
    D{i} = D{i} ./ col_norms;
    col_norms_inv_vecs{i} = col_norms.^-1;
end

residual_tol_abs = tolerance_ratio * norm_y;

% The transposed dictionaries do not change during the pursuit: build them once.
Dt = transp(D);

R = Y;                          % residual tensor
support_indices = zeros(max_iter, N);
W = cell(1, N);                 % W{n}: mode-n factors of the accepted atoms
KR = [];                        % columns: vectorized accepted atoms
Z_inv = [];                     % inverse of the Gram matrix KR' * KR
y_proj = [];                    % KR' * y_vec
coeffs = [];                    % least-squares coefficients on the support
num_atoms = 0;                  % number of atoms accepted so far

tic;

for k = 1:max_iter

    % Select the separable atom most correlated with the current residual.
    C = abs(tmprod(R, Dt, 1:N));
    [max_corr, linind] = max(C(:));
    if max_corr == 0
        % Residual is orthogonal to every atom (e.g. Y is all zeros):
        % no atom can reduce it, so adding one would only add a zero entry.
        break;
    end

    new_idx_cell = cell(1, N);
    [new_idx_cell{:}] = ind2sub(size(C), linind);
    new_idx = cell2mat(new_idx_cell);

    w_new = cell(1, N);
    for n = 1:N
        w_new{n} = D{n}(:, new_idx(n));
    end
    kr_new_col = khatri_rao_vec(w_new);

    % Grow inv(Z), Z = KR' * KR, by one row/column using block inversion.
    if num_atoms == 0
        Z_inv = 1;              % a single unit-norm atom has Gram matrix 1
        y_proj = kr_new_col' * y_vec;
    else
        % b = KR' * kr_new_col, computed mode-wise: the inner product of two
        % Kronecker atoms is the product of their per-mode inner products.
        b = ones(num_atoms, 1);
        for n = 1:N
            b = b .* (W{n}' * w_new{n});
        end

        schur_comp = 1 - b' * Z_inv * b;
        if abs(schur_comp) < 1e-12
            % The candidate is (numerically) in the span of the support,
            % e.g. an already-selected atom was picked again: stop without
            % accepting it.
            break;
        end

        d_vec = -Z_inv * b;
        c_inv = 1 / schur_comp;
        Z_inv = [Z_inv + c_inv * (d_vec * d_vec'), c_inv * d_vec;
                 c_inv * d_vec',                   c_inv];

        y_proj = [y_proj; kr_new_col' * y_vec];
    end

    % Accept the atom; only now is it recorded in the support.
    num_atoms = num_atoms + 1;
    support_indices(num_atoms, :) = new_idx;
    for n = 1:N
        W{n} = [W{n}, w_new{n}];
    end
    KR = [KR, kr_new_col];

    % Least-squares coefficients on the current support.
    coeffs = Z_inv * y_proj;

    % Residual (kept as a tensor for the next mode-n correlation step).
    R_vec = y_vec - KR * coeffs;
    R = reshape(R_vec, tensor_dims);

    if norm(R_vec) < residual_tol_abs
        break;
    end
end

time_total = toc;

% coeffs is only recomputed after an acceptance, so it already holds
% exactly num_atoms entries.
support_indices = support_indices(1:num_atoms, :);

% Undo the column normalization: the atom built from normalized columns equals
% the original atom times the product of the per-mode inverse column norms.
final_coeffs = coeffs;
for i = 1:num_atoms
    norm_prod = 1;
    for n = 1:N
        norm_prod = norm_prod * col_norms_inv_vecs{n}(support_indices(i, n));
    end
    final_coeffs(i) = final_coeffs(i) * norm_prod;
end

% Scatter the coefficients into a dense tensor indexed by atom indices.
dims = cellfun(@(d) size(d, 2), col_norms_inv_vecs);
X = zeros([dims, 1]);   % trailing 1: for N == 1 this is a column, not dims-by-dims
for i = 1:num_atoms
    subs = num2cell(support_indices(i, :));
    X(subs{:}) = final_coeffs(i);
end

result = cell(3,1);
result{1} = X;                  % recovered sparse tensor
result{2} = time_total;         % total computation time
result{3} = support_indices;    % support indices for non-zero elements
end

function K = khatri_rao_vec(A)
%KHATRI_RAO_VEC Column-wise Khatri-Rao product of the factors in cell array A.
%   K = KHATRI_RAO_VEC(A) folds KR over A from first to last factor, starting
%   from a row of ones. Column j of K is therefore
%   kron(A{end}(:,j), ..., A{2}(:,j), A{1}(:,j)).
%   All factors must have the same number of columns (enforced inside KR).
num_factors = length(A);
K = ones(1, size(A{1}, 2));     % neutral element of the Kronecker fold
idx = 1;
while idx <= num_factors
    K = kr(K, A{idx});
    idx = idx + 1;
end
end

function K = kr(A, B)
%KR Column-wise Khatri-Rao product: K(:,j) = kron(B(:,j), A(:,j)).
%   A is m-by-k and B is n-by-k; K is (m*n)-by-k. kron(b, a) equals the
%   column-major vectorization of the outer product a .* b.', so each column
%   is built that way.
[rowsA, colsA] = size(A);
[rowsB, colsB] = size(B);
assert(colsA == colsB, 'Both factors must have same #columns');
K = zeros(rowsA * rowsB, colsA);
for col = 1:colsA
    outer = A(:, col) .* B(:, col).';   % non-conjugating transpose, like kron
    K(:, col) = outer(:);
end
end

function [D] = transp(D)
%TRANSP Conjugate-transpose every matrix stored in a cell array.
%   D = TRANSP(D) returns a cell array of the same shape as D, with each
%   D{n} replaced by D{n}' (complex conjugate transpose).
%   NUMEL is used rather than SIZE(D,2) so that column cell arrays (N-by-1)
%   are fully processed too, not just their first element.
for n = 1:numel(D)
    D{n} = D{n}';
end
end