function result = tenMulReOMP(Y, A, N, M, tol_coef, tol_mode1)
%TENMULREOMP Mode-by-mode sparse tensor recovery using block-wise omp_mmv.
%
%   result = tenMulReOMP(Y, A, N, M, tol_coef)
%   result = tenMulReOMP(Y, A, N, M, tol_coef, tol_mode1)
%
%   Modes are processed in the order d = I, I-1, ..., 1 (I = numel(N)).
%   For each mode:
%     1. The current tensor is unfolded along mode d with tens2mat.
%     2. Its columns are split into prod(R(d+1:end)) equal-width,
%        consecutive blocks. R holds the already-pruned sizes of modes d+1..I.
%     3. omp_mmv(A{d}, block, tol) is applied to each block.
%     4. The block solutions are concatenated column-wise.
%     5. Rows whose absolute row sum does not exceed 1e-6 are dropped.
%   The surviving row indices become the mode-d support. The pruned matrix
%   is folded back into a tensor for the next mode.
%
%   Inputs:
%     Y         - measurement tensor. Its size is expected to match M,
%                 because fold() rebuilds the tensor using M.
%     A         - cell array; A{d} is the matrix handed to omp_mmv for mode d.
%     N         - size vector of the tensor to recover (row or column).
%     M         - size vector of Y (row or column).
%     tol_coef  - tolerance passed to omp_mmv for modes I..2.
%     tol_mode1 - (optional) tolerance passed to omp_mmv for mode 1.
%                 Defaults to 0.1, the value this function previously
%                 hard-coded for mode 1.
%                 NOTE(review): confirm whether mode 1 should simply use
%                 tol_coef like the other modes.
%
%   Output: result, a 3x1 cell array
%     result{1} - array of size N. It is zero everywhere except on the grid
%                 support_indices{1} x ... x support_indices{I}, which holds
%                 the recovered coefficients.
%     result{2} - elapsed wall-clock time in seconds.
%     result{3} - I x 1 cell; result{3}{d} lists the retained rows for mode d.

if nargin < 6 || isempty(tol_mode1)
    tol_mode1 = 0.1;
end

% Rows whose absolute row sum is at or below this value are treated as zero.
prune_tol = 1e-6;

I = numel(N);
R = N;                          % sizes of modes after pruning (updated per mode)
T = Y;
support_indices = cell(I, 1);

% Use a private timer handle: a bare tic inside omp_mmv would otherwise
% reset the global stopwatch and corrupt the reported time.
t_start = tic;

for d = I:-1:1
    if d == 1
        tol = tol_mode1;
    else
        tol = tol_coef;
    end

    % Unfold T along mode d (tens2mat is external, presumably Tensorlab).
    T_d = tens2mat(T, d);

    % One column block per combination of the already-processed modes
    % d+1..I (prod of an empty range is 1, so mode I uses a single block).
    num_blocks = prod(R((d+1):end));
    block_size = size(T_d, 2) / num_blocks;

    % Solve the MMV problem for each column block.
    Z_blocks = cell(1, num_blocks);
    for j = 1:num_blocks
        cols = (j-1)*block_size + 1 : j*block_size;
        Z_blocks{j} = omp_mmv(A{d}, T_d(:, cols), tol);
    end

    % Concatenate and keep only rows with non-negligible energy.
    Z_concat = cat(2, Z_blocks{:});
    nz_rows = find(sum(abs(Z_concat), 2) > prune_tol);

    R(d) = numel(nz_rows);
    M(d) = R(d);                % mode d of T now has the pruned size
    support_indices{d} = nz_rows;

    % Fold back to a tensor. Pass M as a column so older fold() versions,
    % which only handled column size vectors, also work.
    T = fold(Z_concat(nz_rows, :), d, M(:));
end

% Scatter the recovered core into a zero tensor of the full size N.
% zeros needs a row size vector; the trailing 1 keeps an order-1 (vector)
% result as an N-by-1 column instead of an N-by-N matrix.
X_hat = zeros([reshape(N, 1, []), 1]);
X_hat(support_indices{:}) = T;

time_total = toc(t_start);

result = {X_hat; time_total; support_indices};
end

% fold: fold a mode-wise unfolded matrix back into a tensor
function T = fold(X_mat, mode, sz)
%FOLD Inverse of the mode-`mode` unfolding.
%   T = fold(X_mat, mode, sz) returns an array of size sz. X_mat has
%   sz(mode) rows. Its columns run over the remaining modes in ascending
%   order, with the lowest remaining mode varying fastest (MATLAB
%   column-major reshape).
%
%   sz may be given as a row or a column vector; it is normalized to a row
%   so the size vector handed to reshape is always well-formed. A trailing
%   singleton dimension is appended so first-order tensors
%   (numel(sz) == 1) are handled too.
sz = reshape(sz, 1, []);
n = numel(sz);

% Dimension layout of X_mat once reshaped: `mode` first, then the others.
order = [mode, 1:mode-1, mode+1:n];

% Reshape into that layout, then undo the permutation to restore the
% natural mode order 1..n.
T = ipermute(reshape(X_mat, [sz(order), 1]), [order, n+1]);
end