function [L, kIdx1, theta] = construct_knn_graph(Y_cpu, kn, Normal_cpu, M, varargin)
    % CONSTRUCT_KNN_GRAPH  Build a k-NN graph Laplacian from positions and normals.
    %
    %   [L, kIdx1, theta] = construct_knn_graph(Y_cpu, kn, Normal_cpu, M)
    %   [L, kIdx1, theta] = construct_knn_graph(..., 'theta', [theta1 theta2])
    %
    % Inputs
    %   Y_cpu      - point coordinates; transposed before the KNN search, so
    %                each column is one point (presumably d x M - confirm
    %                against the caller).
    %   kn         - number of neighbours requested from knnsearch. The first
    %                returned column is dropped as the self-match, leaving
    %                kn-1 graph neighbours per point.
    %   Normal_cpu - per-point normals, reshaped to M x 1 x 3 (i.e. M x 3).
    %   M          - number of points (rows of Normal_cpu, size of L).
    %
    % Options (name/value)
    %   'theta'    - two-element numeric [theta1 theta2]: Gaussian bandwidths
    %                for the normal distance and the spatial distance.
    %                When omitted, the medians of the respective distances
    %                are used (floored at eps so they are never zero).
    %
    % Outputs
    %   L      - M x M sparse matrix D - A, where A is the row-normalised
    %            weighted adjacency and D = diag(sum(A, 2)).
    %   kIdx1  - M x kn neighbour indices exactly as returned by knnsearch
    %            (self column included).
    %   theta  - [theta1; theta2] actually used.

    % Step 1: KNN search
    kdtree = KDTreeSearcher(Y_cpu');
    [kIdx, dist] = knnsearch(kdtree, Y_cpu', 'K', kn);
    kIdx1 = kIdx;                  % Keep the untouched KNN indices for the caller
    % Column 1 is taken to be the query point itself (distance 0).
    % NOTE(review): with exactly duplicated points the self-match is not
    % guaranteed to be in column 1 - confirm inputs are de-duplicated.
    kIdx = kIdx(:, 2:kn);
    dist = dist(:, 2:kn);

    % Step 2: Normal similarity computation
    idx_flat = kIdx(:);                                           % (M*(kn-1)) x 1
    normal_neighbors = Normal_cpu(idx_flat, :);                   % (M*(kn-1)) x 3
    normal_neighbors = reshape(normal_neighbors, M, kn-1, 3);     % M x (kn-1) x 3
    normal_self = reshape(Normal_cpu, [M, 1, 3]);                 % M x 1 x 3

    % Euclidean distance between each point's normal and its neighbours' normals
    diff_n = normal_self - normal_neighbors;                      % M x (kn-1) x 3
    dist_n = sqrt(sum(diff_n.^2, 3));                             % M x (kn-1)

    % Step 3: Gaussian bandwidths, with optional 'theta' override.
    % The value is taken from the argument that directly follows the first
    % 'theta' key; this avoids relying on circshift's dimension handling
    % (which differs between MATLAB releases for row vectors) and avoids
    % wrapping around to varargin{1} when the key is the last argument.
    isThetaKey = cellfun(@(a) isequal(a, 'theta'), varargin);
    keyPos = find(isThetaKey, 1, 'first');
    if ~isempty(keyPos)
        if keyPos == numel(varargin)
            error('construct_knn_graph:missingThetaValue', ...
                  'The ''theta'' option must be followed by a value.');
        end
        theta = varargin{keyPos + 1};
        if ~isnumeric(theta) || numel(theta) < 2
            error('construct_knn_graph:invalidTheta', ...
                  'The ''theta'' value must be numeric with at least two elements.');
        end
        theta1 = theta(1);
        theta2 = theta(2);
    else
        % A median of zero (e.g. planar regions where most neighbour normals
        % are identical) would give 0/0 = NaN inside the Gaussian kernel, so
        % the automatic bandwidths are floored at eps.
        theta1 = max(median(dist_n(:)), eps);
        theta2 = max(median(dist(:)), eps);
    end

    % Gaussian kernels on normal distance and spatial distance, summed
    dist_n_ker = exp(-dist_n.^2 / (2 * theta1^2));
    W_geo = exp(-dist.^2 / (2 * theta2^2));
    W = W_geo + dist_n_ker;

    % Step 4: Normalize weights
    % Column 1 (the nearest non-self neighbour) is overwritten with the
    % row-wise maximum of the remaining columns, before and after the
    % additive floor, with a row normalisation after each overwrite.
    W(:, 1) = max(W(:, 2:end), [], 2);
    W = W ./ (sum(W, 2) + eps);               % Row normalization
    W = W + 0.001;                            % Additive floor on every weight
    W(:, 1) = max(W(:, 2:end), [], 2);
    W = W ./ (sum(W, 2) + eps);               % Re-normalize

    % Step 5: Construct sparse adjacency matrix (row i -> its kn-1 neighbours)
    row_idx = repmat((1:M)', [1, kn-1]);
    Adj = sparse(row_idx(:), kIdx(:), double(W(:)), M, M);
    Adj = Adj ./ (sum(Adj, 2) + eps);         % Row normalization

    % Step 6: Construct Laplacian matrix (L = D - A)
    L = spdiags(sum(Adj, 2), 0, M, M) - Adj;

    theta = [theta1; theta2];
end
