function result = tenMulReIHT(Y, A, N, M, K, tol)
%TENMULREIHT Mode-by-mode sparse recovery of a tensor using iht_mmv.
%   RESULT = TENMULREIHT(Y, A, N, M, K) processes the modes of Y from the
%   last to the first. For each mode d it:
%     1. unfolds the current tensor along mode d (tens2mat);
%     2. calls IHT_MMV(A{d}, T_d, K(d)) and takes the first cell of its output;
%     3. keeps only the rows whose absolute row sum exceeds TOL;
%     4. folds the pruned matrix back into a tensor.
%   The surviving values are then written into an all-zero array of size N
%   at the kept row indices of every mode.
%
%   RESULT = TENMULREIHT(Y, A, N, M, K, TOL) uses TOL as the row-pruning
%   threshold. The default is 1e-6.
%
%   Inputs (inferred from usage -- TODO confirm against callers):
%     Y   - input tensor; presumably size(Y, d) == M(d) for every mode d
%     A   - cell array; A{d} is the mode-d matrix handed to iht_mmv
%           (presumably M(d)-by-N(d))
%     N   - per-mode sizes of the recovered tensor (row or column vector)
%     M   - per-mode sizes of Y (row or column vector)
%     K   - per-mode values passed to iht_mmv (presumably sparsity levels)
%     TOL - optional pruning threshold on the absolute row sum
%
%   Output (3x1 cell array):
%     RESULT{1} - recovered tensor of size N, zero outside the kept indices
%     RESULT{2} - elapsed wall-clock time in seconds
%     RESULT{3} - Ix1 cell; RESULT{3}{d} holds the kept row indices of mode d
%
%   Depends on external helpers tens2mat (presumably Tensorlab) and iht_mmv.

    if nargin < 6 || isempty(tol)
        tol = 1e-6;
    end

    I = numel(N);
    % fold() builds its size vector as [sz(mode), sz(others).'], which only
    % works for a column size vector, so normalise M to a column here.
    M = M(:);
    T = Y;
    support_indices = cell(I, 1);

    t_start = tic;  % private timer: does not disturb a caller's tic/toc
    for d = I:-1:1
        % Mode-d unfolding: rows index mode d.
        T_d = tens2mat(T, d);
        out = iht_mmv(A{d}, T_d, K(d));
        Z = out{1};

        % Keep rows with non-negligible mass. Working on the dense matrix
        % avoids sparse(), which rejects single-precision input.
        nz_rows = find(sum(abs(Z), 2) > tol);
        support_indices{d} = nz_rows;
        M(d) = numel(nz_rows);  % size of mode d after pruning

        % Fold back: modes > d are already pruned, modes < d still have
        % their original sizes in M.
        T = fold(Z(nz_rows, :), d, M);
    end

    % Write the pruned core into a zero tensor of the full size N.
    X_hat = zeros(N(:).');
    X_hat(support_indices{:}) = T;

    time_total = toc(t_start);

    result = {X_hat; time_total; support_indices};
end

% fold: inverse of a mode-n unfolding -- reshape a matrix back into a tensor of size sz
function T = fold(X_mat, mode, sz)
%FOLD Inverse of a mode-MODE unfolding.
%   T = FOLD(X_MAT, MODE, SZ) reshapes the matrix X_MAT into a tensor of
%   size SZ. The rows of X_MAT index mode MODE. The columns run over the
%   remaining modes in increasing order, with the first remaining mode
%   varying fastest (column-major).
%   SZ may be a row or a column vector.
    sz = sz(:).';                                % accept row or column size vectors
    order = [mode, 1:mode-1, mode+1:numel(sz)];  % mode order of the unfolded layout
    % Reshape into the permuted layout, then undo the permutation.
    T = ipermute(reshape(X_mat, sz(order)), order);
end