function x_re = iht_mmv(A, y, k, max_iter, tol)
%IHT_MMV Iterative hard thresholding for multiple measurement vectors.
%   X = IHT_MMV(A, Y, K) seeks a row-sparse X (at most K nonzero rows)
%   with A*X close to Y. Each iteration takes a gradient step on
%   0.5*||Y - A*X||_F^2 and then zeroes every row except the K rows with
%   the largest energy (sum of squared magnitudes across columns), so all
%   columns share one support.
%
%   X = IHT_MMV(A, Y, K, MAX_ITER, TOL) overrides the iteration cap
%   (default 200) and the relative-change stopping tolerance (default 1e-3).
%   Pass [] for either to keep its default.
%
%   Inputs:
%     A        - measurement matrix; its column count n sets the rows of X.
%     Y        - measurements; its column count L sets the columns of X.
%     K        - number of rows of X kept nonzero after each iteration.
%     MAX_ITER - (optional) maximum number of iterations.
%     TOL      - (optional) stop once ||X_t - X_{t-1}||_F < TOL*||X_t||_F,
%                or as soon as an iteration leaves X unchanged.
%
%   Output:
%     X_RE     - n-by-L estimate with at most K nonzero rows.

    if nargin < 4 || isempty(max_iter)
        max_iter = 200;
    end
    if nargin < 5 || isempty(tol)
        tol = 1e-3;
    end

    [~, n] = size(A);
    [~, L] = size(y);

    x_hat = zeros(n, L);

    % ||A||_2^2 is the Lipschitz constant of the least-squares gradient;
    % 1.9/||A||_2^2 is just under the 2/||A||_2^2 stability limit.
    A_norm_sq = norm(A)^2;
    if A_norm_sq == 0
        % All-zero A: the gradient is identically zero, and 1.9/0 = Inf
        % would turn the estimate into NaN. The zero estimate is final.
        x_re = x_hat;
        return;
    end
    step = 1.9 / A_norm_sq;

    for iter = 1:max_iter
        x_prev = x_hat;

        % Gradient step; scale the n-by-L gradient instead of A'.
        residual = y - A * x_hat;
        x_hat = x_hat + step * (A' * residual);

        % Joint hard threshold: keep the k rows of largest energy.
        [~, indices] = sort(sum(abs(x_hat).^2, 2), 'descend');
        x_hat(indices((k+1):end), :) = 0;

        % Relative-change test. An unchanged iterate stops immediately, so an
        % all-zero (or empty) estimate does not evaluate 0/0 = NaN and spin
        % through every iteration.
        delta = norm(x_hat - x_prev, 'fro');
        if delta == 0 || delta < tol * norm(x_hat, 'fro')
            break;
        end
    end

    x_re = x_hat;
end