function [tform, sigma2, loglikelihood, iter] = GraphCPD_opti(source, target, varargin)
% GraphCPD_opti - Rigid registration of a source point cloud onto a target
% point cloud using an EM (CPD-style) scheme on 9-D graph features, with the
% pose updated by Newton iterations on SE(3).
%
% Each point is described by a 9-D feature: 3-D position, 3-D normal and three
% graph-signal magnitudes computed with a k-NN graph operator. Every target
% point gets a 9x9 matrix (L_sigma3) built from its neighbourhood. That matrix
% is used as the precision of the point's Gaussian component. Candidate
% correspondences are restricted to the k_match targets nearest in
% graph-feature space.
% Requires a GPU (gpuArray, parallel.gpu.CUDAKernel) plus the helpers
% findPointNormals and construct_knn_graph, and the kernel files
% acc_kernel.ptx / acc_kernel.cu.
%
% Inputs:
%   source: source point cloud (pointCloud object, or an N x 3 array when
%           'dataType' is 'array')
%   target: target point cloud (pointCloud object, or an M x 3 array)
%   varargin: optional name/value pairs:
%     'maxIter'               max EM iterations (default 50; see NOTE at the loop)
%     'tolerance'             relative objective-change stop criterion (default 1e-3)
%     'truncationThreshold'   stored as parm.truncate_threshold (default 0.19);
%                             not used anywhere in this function
%     'optimizationIter'      Newton iterations per EM step (default 2)
%     'optimizationTolerance' Newton gradient-norm tolerance (default 1e-3)
%     'neighbours'            neighbour count passed to findPointNormals (default 10)
%     'dataType'              'array' or 'pointCloud' (default 'pointCloud')
%     'xform2center'          the string 'true' centres both clouds first
%     'sigma2'                initial variance (default: mean squared coordinate
%                             difference over all source/target pairs)
%
% Outputs:
%   tform: estimated transformation; a 4x4 homogeneous matrix [R t; 0 0 0 1]
%          when 'dataType' is 'array', otherwise a rigid3d object
%   sigma2: final estimated variance
%   loglikelihood: final value of the EM objective (as computed by Shrink_step)
%   iter: number of EM iterations performed
%   (The probability matrix P is computed internally but is NOT returned.)
 
% --------------------------- Parameter Parsing ---------------------------
% Parse optional input parameters with default values.
% Idiom used below: flag marks the varargin cell that equals the option name.
% circshift(flag,1) moves that mark one cell to the right, onto the option's
% value. Names are matched exactly (isequal, so case-sensitive).
 
flag = cellfun(@isequal, varargin, repmat({'maxIter'}, size(varargin)));
if any(flag)
    idx = circshift(flag,1);
    parm.maxIter = varargin{idx};
else
    parm.maxIter = 50;
end

flag = cellfun(@isequal, varargin, repmat({'tolerance'}, size(varargin)));
if any(flag)
    idx = circshift(flag,1);
    parm.tolerance = varargin{idx};
else
    parm.tolerance = 1e-3;
end
 
% NOTE(review): parm.truncate_threshold is parsed but never read in this function.
flag = cellfun(@isequal, varargin, repmat({'truncationThreshold'}, size(varargin)));
if any(flag)
    idx = circshift(flag,1);
    parm.truncate_threshold = varargin{idx};
else
    parm.truncate_threshold = 0.19;
end

% Max Newton iterations per EM step (passed to NewtonSE3).
flag = cellfun(@isequal, varargin, repmat({'optimizationIter'}, size(varargin)));
if any(flag)
    idx = circshift(flag,1);
    parm.opti_maxIter = varargin{idx};
else
    parm.opti_maxIter = 2;
end

% Gradient-norm tolerance for NewtonSE3.
flag = cellfun(@isequal, varargin, repmat({'optimizationTolerance'}, size(varargin)));
if any(flag)
    idx = circshift(flag,1);
    parm.opti_tolerance = varargin{idx};
else
    parm.opti_tolerance = 1e-3;
end

% Neighbour count used by findPointNormals.
flag = cellfun(@isequal, varargin, repmat({'neighbours'}, size(varargin)));
if any(flag)
    idx = circshift(flag,1);
    parm.neighbours = varargin{idx};
else
    parm.neighbours = 10;
end

% --------------------------- Data Loading -------------------------------
% 'array' inputs are used as given (rows are points; they are not moved to
% the GPU here). pointCloud inputs (the default) have their Location copied
% to the GPU. parm.output selects the type of tform returned at the end.
flag = cellfun(@isequal, varargin, repmat({'dataType'}, size(varargin)));
if any(flag)
    idx = circshift(flag,1);
    if strcmp(varargin{idx}, 'array')
        parm.output = 1;
        X = source;
        Y = target;
    else
        if strcmp(varargin{idx}, 'pointCloud')
            parm.output = 0;
            X = gpuArray(source.Location); % source points
            Y = gpuArray(target.Location); % target points
        else
            error('Input data type unsupported.');
        end
    end
else
    parm.output = 0;
    X = gpuArray(source.Location); % source points
    Y = gpuArray(target.Location); % target points
end

% ---------------------- Normal Estimation --------------------------------
% Per-point normals for target (Normal_cpu) and source (Normal_cpu_x).
% Only the first output of findPointNormals is used.
N = size(X,1);
M = size(Y,1);
[Normal_cpu] = findPointNormals(gather(Y), parm.neighbours);
[Normal_cpu_x] = findPointNormals(gather(X), parm.neighbours);

% ------------------ Centroid Transformation ------------------------------
% Translate both clouds to their own centroids if enabled. Only the string
% 'true' enables this; a logical true does not pass strcmp. The means are
% kept (on the CPU) so the final translation can be mapped back.
flag = cellfun(@isequal, varargin, repmat({'xform2center'}, size(varargin)));
parm.mean_xform = 0;
if any(flag)
    idx = circshift(flag,1);
    if strcmp(varargin{idx}, 'true')
        parm.mean_xform = 1;
        xmean = mean(X);
        ymean = mean(Y);
        X = X - xmean;
        Y = Y - ymean;
        xmean = gather(xmean);
        ymean = gather(ymean);
    end
end
 
% Prepare data matrices (transposed for calculations): from here on X is
% 3 x N and Y is 3 x M (one column per point), with CPU copies.
Y = Y';
Y_cpu = gather(Y);
X = X';
X_cpu = gather(X);

% Scale normals by the median absolute coordinate of both clouds, so the
% normal channels have a magnitude comparable to the position channels
% (presumably findPointNormals returns unit normals - confirm).
maxXY = median([ (abs(X_cpu(:))); (abs(Y_cpu(:)))]);
Normal_cpu =  Normal_cpu.*maxXY;
Normal_cpu_x = Normal_cpu_x.*maxXY;

% ----------------- Construct Target Point Features -----------------------
% kn-NN graph on the target. theta is returned here and reused for the
% source graph below, so both graphs are built with the same parameter.
kn = 8;
[LY, kIdx1, theta] = construct_knn_graph(Y_cpu, kn, Normal_cpu, M); 
LY = sparse(LY);
Y_cpu = double(Y_cpu);
Normal_cpu = double(Normal_cpu);

% -----------------Compute multi-scale geometric features-----------------
% Magnitudes of graph-filtered signals:
%   y_high_1 = |LY*y|, y_high_2 = |LY*normal|, y_high_3 = |(LY + LY^2)*y|.
yhi_1 = LY* (Y_cpu');
yhi_2 = yhi_1 + LY*yhi_1;
y_high_1 = sqrt(sum(yhi_1.^2, 2));  % First-order feature (positions)
y_high_2 = sqrt(sum((LY* (Normal_cpu)).^2, 2)); % Feature from filtered normals
y_high_3 = sqrt(sum(yhi_2.^2, 2));  % Two-hop feature (positions)

g_y = [y_high_1, y_high_2, y_high_3];
% graph signal: Y_long is 9 x M = [position; normal; graph features]
Y_long = [Y_cpu', Normal_cpu, g_y]';
  
Y_long_gpu = single(gpuArray(Y_long));
kIdx1_gpu = gpuArray(kIdx1);
epsilon = 1e-8;

% ----------------- Construct Local Covariance Matrices -------------------
% Gather the 9-D features of each target point's kn neighbours.
% Assumes kIdx1 is M x kn (one row of neighbour indices per target point),
% so that the transposed column-major order groups each point's neighbours.
Y_kn_all = (Y_long_gpu(:, kIdx1_gpu')); % [9 x kn*M]
Y_kn_all = reshape(Y_kn_all, 9, kn, M); % [9 x kn x M]

% Compute pairwise squared distances between the 9 feature channels (rows of
% A), each channel seen as a kn-vector over the neighbourhood:
% ||a_i - a_j||^2 = ||a_i||^2 + ||a_j||^2 - 2*a_i'*a_j.
A = Y_kn_all; % [9 x kn x M]
A_t = permute(A, [2 1 3]); % [kn x 9 x M]
G = pagemtimes(A, A_t); % [9 x 9 x M]
norms = sum(A.^2, 2); % [9 x 1 x M]
norm_i = reshape(norms, [9 1 M]); % [9 x 1 x M]
norm_j = reshape(norms, [1 9 M]); % [1 x 9 x M]
sq_dist = norm_i + norm_j - 2*G; % [9 x 9 x M]

% Compute similarity and normalized Laplacian.
% The kernel bandwidth is the value at the 20% position (index
% ceil(0.2 * 81*M)) of all sorted squared distances.
sq_vec = gather(abs(sq_dist(:)));
ss = sort(sq_vec);
radius = ss(ceil(2*M*81/10));
W = exp(-sq_dist./(2*radius )); % [9 x 9 x M]
I9 = eye(9, 'gpuArray');
W = W - reshape(I9, 9, 9, 1); % Remove diagonal (self-similarity exp(0) = 1)

% Normalized symmetric Laplacian
D_inv_sqrt = 1./sqrt(sum(abs(W), 2) + epsilon); % [9 x 1 x M]
D1 = D_inv_sqrt;
D2 = permute(D_inv_sqrt, [2 1 3]); % [1 x 9 x M]
norm_W = D1.*W.*D2; % [9 x 9 x M]

% Final per-point 9x9 precision: normalized Laplacian plus 0.1*I. The shift
% keeps it positive definite (normalized-Laplacian eigenvalues lie in [0,2]),
% which the sqrt(det(...)) below relies on.
L_sigma3_gpu = (1.1*I9 - norm_W); % Final Laplacian
L_sigma3_gpu = single(L_sigma3_gpu);
L_sigma3 = gather(L_sigma3_gpu);

% ----------------- Construct Source Point Features -----------------------
% Same construction as for the target, reusing the target's theta.
[LX] = construct_knn_graph(X_cpu, kn, Normal_cpu_x, N, 'theta',theta);
LX = sparse(LX); 
X_cpu = double(X_cpu);
Normal_cpu_x = double(Normal_cpu_x);

% Compute multi-scale geometric features for source
xhi_1 = LX* (X_cpu');
xhi_2 = xhi_1 + LX*xhi_1;
x_high_1 = sqrt(sum(xhi_1.^2, 2));  % First-order feature (positions)
x_high_2 = sqrt(sum((LX* (Normal_cpu_x)).^2, 2)); % Feature from filtered normals
x_high_3 = sqrt(sum((xhi_2).^2, 2));  % Two-hop feature (positions)

g_x = ([x_high_1, x_high_2, x_high_3]);
X_long = [X_cpu', Normal_cpu_x, g_x]';   % 9 x N = [position; normal; graph features]

% ----------------- Initial Correspondence Estimation ---------------------
% For every source point, keep its k_match nearest target points in the
% 3-D graph-feature space. After the transpose, minIdx is k_match x N.
k_match = 50;  %KITTI
kdtree = KDTreeSearcher(g_y);
[minIdx, ~] = knnsearch(kdtree, g_x, 'K', k_match);
minIdx = gpuArray(minIdx');  
 
% weidu: dimensionality used in the Gaussian normaliser (2*pi*sigma2)^(weidu/2).
weidu = 3;
% Calculate volume of the target's axis-aligned bounding box (outlier density 1/V)
V = (max(Y(1,:))-min(Y(1,:))) * (max(Y(2,:))-min(Y(2,:))) * (max(Y(3,:))-min(Y(3,:)));

% Initialize sigma2 if not provided: mean squared coordinate difference over
% all N*M source/target pairs and the 3 coordinates.
flag = cellfun(@isequal, varargin, repmat({'sigma2'}, size(varargin)));
if any(flag)
    idx = circshift(flag,1);
    sigma2 = varargin{idx};
else
    sigma2 = 0;
    for i = 1:3
        sigma2 = sigma2 + sum(sum((X(i,:)'-Y(i,:)).^2));
    end
    sigma2 = sigma2/(3*M*N);
end

% ----------------- Pre-calculations --------------------------------------
% Skew-symmetric generators (E_k*x = e_k x x) for SE(3) optimization;
% E_k*X is precomputed once because X never changes.
E1 = [0 0 0; 0 0 -1; 0 1 0];
E2 = [0 0 1; 0 0 0; -1 0 0];
E3 = [0 -1 0; 1 0 0; 0 0 0];
E1X_cpu = gather(E1*X);
E2X_cpu = gather(E2*X);
E3X_cpu = gather(E3*X);

% ----------------- Weight Calculations -----------------------------------
%% sqrt(det(L_m)) for each target point's 9x9 precision
det_vec = zeros(M,1);
for i = 1:M
    det_vec(i) = (sqrt(det(L_sigma3(:,:,i))));  
end

% Outlier weight estimation:
%   wn = V/M * sum_m (2*pi*sigma2)^(-weidu/2) * sqrt(det(L_m))   (scalar)
%   F_matrix(m,n) = (1-wn)/wn * sqrt(det(L_m))                   (M x N)
% NOTE(review): these use the initial sigma2 only; they are not recomputed
% inside the EM loop.
wn = V/M*reshape(ones(M,1),1,[])*single((2*pi*sigma2)^(-weidu/2).* det_vec); 
F_matrix = ((1-wn)./wn)*(ones(1,N).*det_vec);


% Precompute per-target-point terms for the E-step from the 3x3 position
% block of L_m: flattened block, y_m'*L_m and y_m'*L_m*y_m.
invSigma_flatten_const = zeros(M,9); % M x 9
y_invSigma_const = zeros(M,3); % M x 3
y_invSigma_y_const = zeros(M,1);
 
for m = 1:M
    invSigma_const_m = L_sigma3(1:3,1:3,m);
    invSigma_flatten_const(m,:) = reshape(invSigma_const_m,1,[]);
    y_invSigma_const(m,:) = Y_cpu(:,m)'*invSigma_const_m;
    y_invSigma_y_const(m) = Y_cpu(:,m)'*invSigma_const_m*Y_cpu(:,m);
end
 
invSigma_flatten_const = gpuArray(single(invSigma_flatten_const)); 
y_invSigma_const = gpuArray(single(y_invSigma_const)); 
y_invSigma_y_const = gpuArray(single(y_invSigma_y_const)); 



% Candidate pairs: n_indices repeats each source index k_match times, so that
% (m_indices(j), n_indices(j)) lists every (target, source) candidate pair in
% the column-major order of minIdx (k_match x N).
n_indices = gpuArray(single(repelem(1:N,k_match)')); 
m_indices =   minIdx(:) ;
linear_indices = sub2ind([M,N], m_indices, n_indices); 
F_matrix_index = F_matrix(linear_indices) ;
F_matrix_index = F_matrix_index(:);

% CUDA launch configuration: enough threads for all N*k_match candidate
% pairs (presumably one thread per pair - see acc_kernel.cu).
threads_per_block = 256;  
num_blocks = ceil(numel(m_indices) / threads_per_block);  

% Load the scatter/accumulate kernel (entry point 'accumArraySpecific').
kernel  = parallel.gpu.CUDAKernel('acc_kernel.ptx', 'acc_kernel.cu','accumArraySpecific');
kernel.ThreadBlockSize = threads_per_block;
kernel.GridSize = num_blocks;
% Zero M x N template passed to the kernel on every E-step.
output_gpu = gpuArray.zeros(M, N, 'single');


% --------------------------- EM Process ----------------------------------
% NOTE(review): re-creates Y_long_gpu in Y_long's precision (double),
% replacing the single-precision copy made earlier.
Y_long_gpu = gpuArray(Y_long);

iter = 0;
loglikelihood = 0;
R = eye(3); % Initial rotation
t = [0; 0; 0]; % Initial translation
% NOTE(review): identical to the n_indices computed before the kernel setup.
n_indices = gpuArray(single(repelem(1:N,k_match)'));  % (N*k)x1

% NOTE(review): 'iter <= parm.maxIter' with iter starting at 0 allows up to
% parm.maxIter+1 EM iterations.
while iter <= parm.maxIter
    loglikelihood_prev = loglikelihood;

    % ----------------------------- E-step --------------------------------
    % C: uniform-outlier term (2*pi*sigma2)^(weidu/2)/V; c: Gaussian exponent scale.
    C = (2*pi*sigma2)^(weidu/2)*(1/V);
    c = -1/(2*sigma2);

    % Apply the current pose to the 9-D source features: rotate positions
    % (rows 1-3) and normals (rows 4-6), leave graph features (rows 7-9)
    % unchanged, and translate positions only.
    R_big = blkdiag(R, R, eye(3,3));
    RX_big = gpuArray(R_big*X_long);
    t_big = [t; zeros(3,1); zeros(3,1)];

    [P , M_0 , M_1, M_2 ] = E_step_optimized(Y_long_gpu, RX_big, t_big, F_matrix_index, M, N, C, c, ...
        invSigma_flatten_const, y_invSigma_const, y_invSigma_y_const, L_sigma3_gpu, m_indices, n_indices, k_match, kernel, output_gpu);

    % Pose update (the M-step for R and t): Newton iterations on SE(3).
    [R, t] = NewtonSE3(R, t, M_0, M_1, X_cpu, N, E1X_cpu, E2X_cpu, E3X_cpu, ...
                      parm.opti_maxIter, parm.opti_tolerance);

    % ----------------------------- M-step --------------------------------
    % (R, t were updated above; sigma2 is re-estimated by Shrink_step below.)
    iter = iter + 1;
    
    % ----------------- Convergence Checking ------------------------------
    % Stop on small relative change of the objective, or when it is ~0.
    [loglikelihood, sigma2] = Shrink_step(R, t, X_cpu, P, M_0, M_1, M_2, N,weidu);
    if abs(loglikelihood-loglikelihood_prev)/loglikelihood < parm.tolerance || loglikelihood < 1e-5
        break
    end
end

% ----------------- Final Transformation ----------------------------------
% Undo the centring: x -> R*(x - xmean') + t + ymean'.
if parm.mean_xform == 1
    t = t + ymean' - (R*xmean');
end

% Return appropriate transform type based on output flag. rigid3d uses
% the row-vector (post-multiply) convention, hence R' and t'.
if parm.output == 1
    tform = [R, t; 0 0 0 1];
else
    tform = rigid3d(R', t');
end
end

% ======================== Utility Functions =============================
function [P, M_0, M_1, M_2] = E_step_optimized(Y_long, RX_big, t_big, F_matrix_index, M, N, C, c, ...
    invSigma_flatten_const, y_invSigma_const, y_invSigma_y_const, L_sigma3, m_indices, n_indices, k, kernel, output_gpu )
% E_step_optimized - E-step restricted to the k candidate targets of each source point.
%
% For every candidate pair (m_indices(j), n_indices(j)) this function does the following:
%   1. forms the 9-D feature difference d between transformed source and target;
%   2. evaluates the quadratic form d' * L_sigma3(:,:,m) * d;
%   3. turns it into an unnormalised weight F_matrix_index(j) * exp(c * q);
%   4. normalises the weights per source point (k candidates per column),
%      adding the outlier constant C to the denominator;
%   5. scatters the k x N weights into a dense matrix with the CUDA kernel
%      (output_gpu is the M x N template allocated by the caller);
%   6. accumulates the sufficient statistics used by the pose update.
%
% Outputs:
%   P   : responsibilities returned by the kernel (presumably M x N - confirm
%         against acc_kernel.cu)
%   M_0 : 3x3xN, page n = sum_m P(m,n) * reshape(invSigma_flatten_const(m,:), 3, 3)
%   M_1 : N x 3, P' * y_invSigma_const
%   M_2 : N x 1, P' * y_invSigma_y_const
%   All three statistics are gathered to host memory.

    % Transformed 9-D source features, then expanded to one column per pair.
    srcFeat = RX_big + t_big;              % 9 x N
    pairSrc = srcFeat(:, n_indices);       % 9 x (N*k)
    pairTgt = Y_long(:, m_indices);        % 9 x (N*k)

    diffs = pairSrc - pairTgt;
    % Normals are sign-ambiguous: per component keep the smaller of |a-b| and |a+b|.
    diffs(4:6, :) = min(abs(pairSrc(4:6, :) - pairTgt(4:6, :)), ...
                        abs(pairSrc(4:6, :) + pairTgt(4:6, :)));

    % Quadratic form per pair: sum_ij L(i,j) * d(i) * d(j), with the outer
    % product d*d' formed by implicit expansion against the 9x9 pages.
    colDiffs = reshape(diffs, 9, 1, []);
    rowDiffs = reshape(diffs, 1, 9, []);
    quadForms = squeeze(sum(sum(L_sigma3(:, :, m_indices) .* colDiffs .* rowDiffs, 1), 2));

    % Unnormalised weights laid out as k candidates x N source points,
    % normalised per source point with the outlier term C.
    weights = reshape(F_matrix_index .* exp(c* quadForms), [k, N]);
    weights = weights ./ (C + sum(weights, 1));

    % Scatter the candidate weights into the dense responsibility matrix.
    P = feval(kernel, m_indices, weights, output_gpu, M, N, k);

    % Sufficient statistics for NewtonSE3 / Shrink_step.
    M_0 = gather(reshape((P' * invSigma_flatten_const)', 3, 3, N));
    M_1 = gather(P' * y_invSigma_const);
    M_2 = gather(P' * y_invSigma_y_const);
end
 
function [g_gradient, H] = GradientSE3(R, t, M_0, M_1, X, N, E1X, E2X, E3X)
% GradientSE3 - Gradient and Gauss-Newton Hessian of the EM objective on SE(3).
%
% Objective (same quantity Shrink_step evaluates), with g x_n = R*x_n + t:
%   f(g) = sum_n (g x_n)' M_0(:,:,n) (g x_n) - 2 * M_1(n,:) (g x_n) + const
%
% The pose is perturbed on the right, g <- g * expm(hat(xi)), as in NewtonSE3.
% xi(1:3) is the rotation part (generators E1..E3) and xi(4:6) the
% translation part. The Jacobian of g x_n with respect to xi is therefore:
%   d(g x_n)/d xi_k = R * E_k * x_n   (k = 1..3)
%   d(g x_n)/d xi_k = R(:, k-3)       (k = 4..6)
% These are direction vectors (homogeneous coordinate 0), so the translation
% t must NOT be added to them. The previous version added t to the rotation
% columns, which was wrong whenever t ~= 0.
% The gradient formula assumes every M_0(:,:,n) is symmetric.
%
% Inputs:
%   R, t     : current rotation (3x3) and translation (3x1)
%   M_0      : 3x3xN weighted precision sums per source point
%   M_1      : N x 3 weighted (target' * precision) sums per source point
%   X        : 3 x N source points;  N: number of source points
%   E1X..E3X : E_k * X for the skew-symmetric generators (3 x N each)
% Outputs:
%   g_gradient : 6 x 1 gradient of f
%   H          : 6 x 6 Gauss-Newton Hessian, 2 * sum_n J_n' M_0(:,:,n) J_n

    M_1_flatten = reshape(M_1', 1, 3*N);
    gX_flatten = reshape(R*X + t, 3*N, 1);     % current transformed points, stacked

    % Jacobian directions for the rotation generators: R*E_k*x_n (no t).
    g_E1_X = reshape(R*E1X, 1, 3, N); % 1*3*N
    g_E2_X = reshape(R*E2X, 1, 3, N); % 1*3*N
    g_E3_X = reshape(R*E3X, 1, 3, N); % 1*3*N
    % Jacobian directions for the translation: columns of R, shared by all points.
    g_E4_X = reshape(R(:,1), 1, 3);
    g_E5_X = reshape(R(:,2), 1, 3);
    g_E6_X = reshape(R(:,3), 1, 3);

    % Stacked 3N x 6 Jacobian; rows 3n-2:3n belong to point n.
    J = [reshape(g_E1_X, 3*N, 1), reshape(g_E2_X, 3*N, 1), reshape(g_E3_X, 3*N, 1), ...
         repmat(R, N, 1)];

    % Row k holds [J_1(:,k)' M_0(:,:,1), ..., J_N(:,k)' M_0(:,:,N)] (6 x 3N).
    JtM0 = [reshape(pagemtimes(g_E1_X, M_0), 1, 3*N); ...
            reshape(pagemtimes(g_E2_X, M_0), 1, 3*N); ...
            reshape(pagemtimes(g_E3_X, M_0), 1, 3*N); ...
            reshape(pagemtimes(g_E4_X, M_0), 1, 3*N); ...
            reshape(pagemtimes(g_E5_X, M_0), 1, 3*N); ...
            reshape(pagemtimes(g_E6_X, M_0), 1, 3*N)];

    % Final gradient and Gauss-Newton Hessian
    g_gradient = 2.*(JtM0*gX_flatten - (M_1_flatten*J)');
    H = 2.*(JtM0*J);
end

function [R, t] = NewtonSE3(R, t, M_0, M_1, X, N, E1X, E2X, E3X, maxIter, tolerance)
% NewtonSE3 - Newton (Gauss-Newton) optimisation of the pose on the SE(3) manifold.
%
% Starting from (R, t), each iteration does two things:
%   - solves ((H + H')/2) * xi = -gradient, with gradient and Hessian from GradientSE3;
%   - applies the step on the right: [R t; 0 0 0 1] <- [R t; 0 0 0 1] * expm(hat(xi)).
% In the step, xi(1:3) is the rotation part and xi(4:6) the translation part.
%
% Inputs:
%   R, t      : initial rotation (3x3) and translation (3x1)
%   M_0, M_1  : E-step sufficient statistics (3x3xN and N x 3)
%   X         : 3 x N source points;  N: number of source points
%   E1X..E3X  : E_k * X for the skew-symmetric generators
%   maxIter   : maximum number of Newton iterations
%   tolerance : stop once norm(gradient) <= tolerance
% Outputs:
%   R, t      : updated rotation and translation
%
% The Newton system is solved directly with mldivide rather than by forming
% the inverse. If the system is singular (e.g. all responsibilities are
% zero), the step is non-finite. The current pose is then kept instead of
% being overwritten with NaN/Inf.

    for newtonIter = 1:maxIter
        % Calculate gradient and Hessian
        [g_gradient, H] = GradientSE3(R, t, M_0, M_1, X, N, E1X, E2X, E3X);

        % Check convergence
        if norm(g_gradient) <= tolerance
            break
        end

        % Newton step on the symmetrised Hessian.
        x_opti = -(((H + H')/2) \ g_gradient);
        if ~all(isfinite(x_opti))
            break
        end

        % hat(xi): skew(xi(1:3)) in the upper-left block, xi(4:6) in column 4.
        X_opti = [ 0,          -x_opti(3),  x_opti(2), x_opti(4); ...
                   x_opti(3),   0,         -x_opti(1), x_opti(5); ...
                  -x_opti(2),   x_opti(1),  0,         x_opti(6); ...
                   0,           0,          0,         0];
        g = [R t; 0 0 0 1]*expm(X_opti);
        R = g(1:3,1:3);
        t = g(1:3,4);
    end
end

function [loglikelihood, sigma2] = Shrink_step(R, t, X, P, M_0, M_1, M_2, N,weidu)
% Shrink_step - Evaluate the EM objective at the current pose and re-estimate sigma2.
%
% With y_n = R*x_n + t, the returned "loglikelihood" is the quadratic energy
%   sum_n y_n' M_0(:,:,n) y_n  -  2 * sum_n M_1(n,:) y_n  +  sum(M_2)
% and sigma2 = loglikelihood / (weidu * sum of all entries of P), gathered to the host.
%
% Inputs: R (3x3), t (3x1), X (3 x N source points), P (responsibilities),
%         M_0 (3x3xN), M_1 (N x 3), M_2, N, weidu (dimension used for sigma2).

    moved = R*X + t;                       % 3 x N transformed source points
    asCols = reshape(moved, 3, 1, N);
    asRows = reshape(moved, 1, 3, N);

    quadTerm  = sum(pagemtimes(asRows, pagemtimes(M_0, asCols)));
    crossTerm = reshape(moved, 3*N, 1)' * reshape(M_1', 3*N, 1);
    loglikelihood = quadTerm - 2*crossTerm + sum(M_2);

    % Variance update normalised by the total responsibility mass.
    totalMass = sum(sum(P));
    sigma2 = gather(loglikelihood/(weidu*totalMass));
end