function result = tenMulReIHT(Y, A, N, M, K)
%TENMULREIHT Mode-by-mode recovery of a tensor with row-support pruning.
%   result = tenMulReIHT(Y, A, N, M, K) processes the modes of Y from the
%   last to the first. For each mode d the current tensor is unfolded,
%   split into column blocks, each block is passed to iht_mmv together
%   with A{d} and K(d), and only the rows that are non-zero in at least
%   one block are kept before folding back.
%
%   Inputs
%     Y - measurement tensor (assumed to have size M - TODO confirm).
%     A - cell array; A{d} is handed to iht_mmv for mode d (presumably the
%         sensing matrix of that mode - verify against iht_mmv).
%     N - vector with the size of the recovered tensor, one entry per
%         mode. Row or column orientation is accepted.
%     M - vector with the current size of Y per mode (row or column).
%     K - vector; K(d) is handed to iht_mmv for mode d (presumably the
%         sparsity level - verify against iht_mmv).
%
%   Output
%     result - 3-by-1 cell array:
%       result{1}  recovered tensor of size N, zero outside the support
%       result{2}  elapsed time in seconds
%       result{3}  I-by-1 cell; entry d lists the retained row indices of
%                  mode d
%
%   Depends on tens2mat and iht_mmv, which are not defined in this file
%   (Tensorlab ships a function called tens2mat - confirm that is the one
%   intended), and on the local helper fold.

    I = numel(N);

    % Work with column vectors internally so that the size vector handed to
    % fold has a known orientation regardless of how the caller passed it.
    N = N(:);
    M = M(:);

    R = N;                          % R(d) becomes the pruned size of mode d
    T = Y;
    support_indices = cell(I, 1);

    % Use a timer handle: a bare tic/toc pair would be reset by any tic
    % executed inside iht_mmv (or anything it calls).
    t_start = tic;
    for d = I:-1:1
        % unfold T along mode-d
        T_d = tens2mat(T, d);

        % One block per index combination of the already-processed modes
        % d+1..I. This relies on tens2mat placing those modes slowest in
        % the column ordering - NOTE(review): confirm against tens2mat.
        num_blocks = prod(R((d+1):end));
        block_size = size(T_d, 2) / num_blocks;

        % solve CS for each block sequentially
        Z_blocks = cell(1, num_blocks);
        for j = 1:num_blocks
            cols = (j-1)*block_size + 1 : j*block_size;
            Z_blocks{j} = iht_mmv(A{d}, T_d(:, cols), K(d));
        end

        % concatenate and keep only rows that are non-zero somewhere
        Z_concat = cat(2, Z_blocks{:});
        nz_rows = find(sum(abs(Z_concat), 2) > 0);
        R(d) = numel(nz_rows);
        M(d) = R(d);                % mode d of T now has the pruned size
        support_indices{d} = nz_rows;

        % fold back to tensor
        T = fold(Z_concat(nz_rows, :), d, M);
    end

    % Scatter the pruned core into a zero tensor of the full size. The
    % trailing 1 keeps a scalar N from being read as an N-by-N matrix and
    % is otherwise ignored by zeros.
    X_hat = zeros([N.', 1]);
    X_hat(support_indices{:}) = T;

    time_total = toc(t_start);

    result = cell(3,1);
    result{1} = X_hat;  % recovered sparse tensor
    result{2} = time_total;  % total computation time
    result{3} = support_indices;  % support indices for non-zero elements
end

% fold: fold matrix back to tensor
function T = fold(X_mat, mode, sz)
%FOLD Rebuild a tensor of size sz from its mode-`mode` unfolding.
%   T = fold(X_mat, mode, sz) expects X_mat to have sz(mode) rows and the
%   remaining modes laid out along the columns in ascending mode order
%   (lowest mode varying fastest). sz may be a row or a column vector.

    sz = sz(:).';                   % accept either orientation
    n_modes = numel(sz);

    % Modes that run along the columns of X_mat, in ascending order.
    other_modes = [1:mode-1, mode+1:n_modes];

    % Inverse of the permutation [mode, other_modes]: moves the leading
    % dimension back to position `mode`.
    perm_order = [2:mode, 1, mode+1:n_modes];

    % The extra trailing singleton keeps reshape/permute valid when sz has
    % a single entry; it does not change the result for higher orders.
    unfolded_shape = [sz(mode), sz(other_modes), 1];
    T = permute(reshape(X_mat, unfolded_shape), [perm_order, n_modes + 1]);
end