function X = omp_mmv(D, Y, tol_coef, max_iter)
%OMP_MMV Simultaneous orthogonal matching pursuit for multiple measurement vectors.
%   X = OMP_MMV(D, Y, TOL_COEF) greedily selects columns (atoms) of the
%   dictionary D and returns a row-sparse coefficient matrix X with
%   D*X ~= Y. All columns of Y share the same support (selected rows of X).
%   Iteration stops when norm(Y - D*X,'fro') <= TOL_COEF*norm(Y,'fro').
%
%   X = OMP_MMV(D, Y, TOL_COEF, MAX_ITER) additionally caps the number of
%   selected atoms at MAX_ITER (default: min(size(D))). Beyond that many
%   atoms the selection is necessarily linearly dependent.
%
%   Inputs:
%     D        - m-by-n dictionary. Columns need not be normalized; they
%                are normalized internally and X is rescaled back.
%                All-zero columns are never selected.
%     Y        - m-by-L matrix of measurement vectors.
%     tol_coef - relative residual tolerance (scalar).
%     max_iter - (optional) maximum number of atoms to select.
%
%   Output:
%     X        - n-by-L coefficients with respect to the ORIGINAL D.
%
%   Selection stops early with a warning (id 'omp_mmv:nearlyDependent')
%   when the next atom is nearly linearly dependent on the selected ones.
%   In that case the last well-conditioned solution is returned.

    [m, n] = size(D);
    L = size(Y, 2);
    X = zeros(n, L);

    if nargin < 4 || isempty(max_iter)
        max_iter = min(m, n);
    end

    tol = tol_coef * norm(Y, 'fro');

    % Normalize atoms to unit norm. Zero columns are left as zero (divisor
    % forced to 1) and excluded from selection, avoiding 0/0 NaNs.
    col_norm = vecnorm(D);
    is_zero_col = (col_norm == 0);
    col_norm(is_zero_col) = 1;
    D = D ./ col_norm;

    R = Y;
    idx = zeros(1, 0);   % selected atom indices, in selection order
    Xs = [];             % coefficients of the selected atoms (numel(idx)-by-L)

    if norm(R, 'fro') == 0
        return
    end

    while norm(R, 'fro') > tol && numel(idx) < max_iter
        % Joint correlation across all measurement vectors: squared row
        % norms of D'*R (n-by-1).
        score = sum(abs(D' * R).^2, 2);

        % Never re-pick an atom (the residual is orthogonal to them in exact
        % arithmetic) or an all-zero column.
        score(idx) = -Inf;
        score(is_zero_col) = -Inf;

        [best, j] = max(score);
        % No remaining atom correlates with the residual: adding one
        % cannot reduce it, so stop. (~(best > 0) also catches NaN.)
        if isempty(best) || ~(best > 0)
            break;
        end

        d_new = D(:, j);

        % Maintain Z_inv = inv(Ds'*Ds) incrementally via the Schur
        % complement, and DtY = Ds'*Y, so Xs = Z_inv*DtY is the
        % least-squares solution on the selected atoms.
        if isempty(idx)
            Z_inv = 1;               % unit-norm atom: Ds'*Ds = 1
            DtY = d_new' * Y;
        else
            u = D(:, idx)' * d_new;
            zinvu = Z_inv * u;
            % Schur complement; real-valued in exact arithmetic even for
            % complex D (Hermitian Gram matrix).
            alpha = real(1 - u' * zinvu);

            if abs(alpha) < 1e-4
                warning('omp_mmv:nearlyDependent', ...
                    'New atom is nearly linearly dependent.');
                % idx/Xs still describe the previous consistent solution.
                break;
            end

            Z_inv = [Z_inv + zinvu*zinvu'/alpha, -zinvu/alpha;
                     -zinvu'/alpha,              1/alpha];
            DtY = [DtY; d_new' * Y];
        end

        % Commit the atom only after the update succeeded, so idx and Xs
        % always have matching sizes.
        idx = [idx, j];
        Xs = Z_inv * DtY;
        % Xs = (D(:,idx)' * D(:,idx)) \ (D(:,idx)' * Y); % debugging reference

        R = Y - D(:, idx) * Xs;
    end

    % Scatter the selected coefficients into the full X and undo the
    % column normalization so that (original D) * X ~= Y.
    if ~isempty(idx)
        X(idx, :) = Xs;
        X = X ./ col_norm.';
    end
end
