function x_re = iht_mmv(A, y, k, max_iter, tol, step)
%IHT_MMV Iterative hard thresholding for the multiple measurement vector problem.
%   X_RE = IHT_MMV(A, Y, K) seeks an n-by-L matrix X_RE with at most K
%   nonzero rows such that A*X_RE approximates Y. Each iteration takes a
%   gradient step on 0.5*||Y - A*X||_F^2 and then keeps only the K rows
%   with the largest l2 energy (joint row sparsity), zeroing the rest.
%
%   X_RE = IHT_MMV(A, Y, K, MAX_ITER, TOL, STEP) overrides the defaults:
%     MAX_ITER - maximum number of iterations (default 200).
%     TOL      - stop when ||X - X_prev||_F < TOL * ||X||_F (default 1e-6).
%     STEP     - gradient step size (default 1/||A||_2^2). Step sizes above
%                1/||A||_2^2 can make the thresholded iteration oscillate
%                between supports instead of converging; the former default
%                of 1.9/||A||_2^2 can still be requested explicitly.
%   Pass [] for any optional argument to keep its default.
%
%   Inputs:
%     A - m-by-n measurement matrix.
%     y - m-by-L matrix of measurements (one column per measurement vector).
%     k - number of nonzero rows to keep.
%   Output:
%     x_re - n-by-L row-sparse estimate. If A is all zeros, an n-by-L zero
%            matrix is returned.

[~, n] = size(A);
[~, L] = size(y);

if nargin < 4 || isempty(max_iter)
    max_iter = 200;
end
if nargin < 5 || isempty(tol)
    tol = 1e-6;
end

normA = norm(A);  % spectral norm (largest singular value)
if normA == 0
    % A maps every x to zero, so nothing can be recovered; a step of
    % 1/||A||^2 would be Inf and turn the iterates into NaN.
    x_re = zeros(n, L);
    return;
end
if nargin < 6 || isempty(step)
    step = 1 / normA^2;
end

x_hat = zeros(n, L);
for iter = 1:max_iter
    x_prev = x_hat;

    % Gradient step. (A' * residual) is formed first so the scalar scales
    % an n-by-L matrix instead of the n-by-m matrix A'.
    residual = y - A * x_hat;
    x_hat = x_hat + step * (A' * residual);

    % Hard threshold: keep the k rows with the largest energy, zero the rest.
    [~, order] = sort(sum(abs(x_hat).^2, 2), 'descend');
    x_hat(order((k+1):end), :) = 0;

    % Relative-change stopping rule. Bounding the reference norm below by
    % eps avoids 0/0 = NaN when the iterate is all zeros (NaN never
    % satisfies the test, which would force all max_iter iterations).
    if norm(x_hat - x_prev, 'fro') < tol * max(norm(x_hat, 'fro'), eps)
        break;
    end
end
x_re = x_hat;
end