% Accelerated degree-profile matching algorithm, version 4: version 3 plus
% estimation of the degree-profile distances that were not computed.
% A and B are the (adjacency) matrices to be matched.
% Returns the permutation matrix P (output P_dp) such that P*A*P' matches B.
function [P_dp, D, D_hat, i, isPerm, msg, per, pool_count] = match_fast4(A, B)
    % MATCH_FAST4 Accelerated degree-profile matching (version 3 + estimation).
    %
    %   [P_dp, D, D_hat, i, isPerm, msg, per, pool_count] = match_fast4(A, B)
    %
    %   Inputs:
    %     A, B       - n-by-n matrices to be matched. Degrees are column sums
    %                  and neighbours are the nonzero entries of each row.
    %                  NOTE(review): this presumes symmetric (undirected)
    %                  adjacency matrices - confirm with callers.
    %
    %   Outputs:
    %     P_dp       - n-by-n sparse matrix with P_dp(c, r) = 1 for each
    %                  selected pair (vertex r of A, vertex c of B), so that
    %                  P_dp*A*P_dp' is matched to B. All zeros if no set of
    %                  n candidate pairs was ever formed.
    %     D          - n-by-n degree-profile distances. On an early (perfect
    %                  match) return only evaluated entries are filled; if
    %                  the search runs to the end, unevaluated entries are
    %                  estimated by linear_fitting.
    %     D_hat      - n-by-n absolute degree differences |deg_A(r) - deg_B(c)|.
    %     i          - threshold delta at which a perfect matching was found,
    %                  otherwise max(D_hat(:)).
    %     isPerm     - result of check_permutation_matrix for the last
    %                  candidate set (false if none was checked).
    %     msg        - message from check_permutation_matrix, '是完美匹配' on
    %                  success, or '' if nothing was checked.
    %     per        - t1 / (t1 + t2): share of the timed work spent building
    %                  degree profiles (t1) versus searching (t2).
    %     pool_count - number of (r, c) pairs placed in the candidate pool.

    n = size(A, 1);
    deg1 = sum(A);
    deg2 = sum(B);
    % D_hat(r, c) = |deg1(r) - deg2(c)| via implicit expansion (column - row).
    D_hat = abs(deg1' - deg2);

    tic
    [N_deg1, F_deg1] = degree_profiles(A, deg1);
    [N_deg2, F_deg2] = degree_profiles(B, deg2);
    t1 = toc;

    % Define every output up front: for small n the pool may stop growing
    % before it holds n pairs, in which case no candidate set is ever checked.
    P_dp = sparse(n, n);
    isPerm = false;
    msg = '';

    D = zeros(n, n);
    is_computed = false(n, n);  % distance evaluated by dwass_discrete2
    is_visited = false(n, n);   % pair already considered; never pooled twice

    % Candidate pool of (row, col, distance) triples. Each pair is visited at
    % most once, so n*n slots always suffice.
    vals_pool = zeros(n*n, 1);
    rows_pool = zeros(n*n, 1);
    cols_pool = zeros(n*n, 1);
    pool_count = 0;

    % Initially every row and column is "missing" from the current best set.
    missing_rows = true(n, 1);
    missing_cols = true(n, 1);

    max_diff = max(D_hat(:));
    pool_limit = n^2 / 5;   % stop growing the pool past a fifth of all pairs

    tic
    for delta = 0:max_diff
        if pool_count >= pool_limit
            break;   % later thresholds would add nothing
        end

        % Unvisited pairs whose degree difference is within the threshold.
        mask_candidate = (D_hat <= delta) & ~is_visited;
        if pool_count > 0
            % Restrict to pairs touching a row or column that the current
            % best n pairs do not cover.
            mask_candidate = mask_candidate & (missing_rows | missing_cols');
        end

        [rows_cand, cols_cand] = find(mask_candidate);
        num_new = 0;
        for k = 1:numel(rows_cand)
            r = rows_cand(k);
            c = cols_cand(k);
            is_visited(r, c) = true;

            if deg1(r) > 0 && deg2(c) > 0
                val = dwass_discrete2(N_deg1{r}, N_deg2{c}, F_deg1{r}, F_deg2{c});
                D(r, c) = val;
                is_computed(r, c) = true;
            elseif deg1(r) == 0 && deg2(c) == 0
                % Two isolated vertices have identical (empty) profiles.
                val = 0;
            else
                % An isolated vertex cannot match a non-isolated one. Keeping
                % the pair out of the pool stops a placeholder 0 from ranking
                % it as a best match.
                continue;
            end

            pool_count = pool_count + 1;
            rows_pool(pool_count) = r;
            cols_pool(pool_count) = c;
            vals_pool(pool_count) = val;
            num_new = num_new + 1;
        end

        % Re-rank only when the pool changed; otherwise the outcome is the
        % same as in the previous iteration (sort is stable).
        if num_new > 0 && pool_count >= n
            current_vals = vals_pool(1:pool_count);
            [~, sort_idx] = sort(current_vals, 'ascend');
            top_n = sort_idx(1:n);

            best_rows = rows_pool(top_n);
            best_cols = cols_pool(top_n);
            D_mem = current_vals(top_n);

            % P_dp(c, r) = 1 maps vertex r of A onto vertex c of B.
            P_dp = sparse(best_cols, best_rows, ones(n, 1), n, n);
            [isPerm, msg] = check_permutation_matrix(P_dp);

            if isPerm == 1
                isPerfect = check_perfect_matching(best_rows, best_cols, D_mem, ...
                    N_deg1, N_deg2, F_deg1, F_deg2);
                if isPerfect == 1
                    msg = '是完美匹配';
                    i = delta;
                    t2 = toc;
                    per = t1 / (t1 + t2);
                    % D is returned without estimating unevaluated entries.
                    return
                end
            end

            missing_rows(:) = true;
            missing_cols(:) = true;
            missing_rows(best_rows) = false;
            missing_cols(best_cols) = false;
        end
    end

    t2 = toc;
    i = max_diff;
    per = t1 / (t1 + t2);

    % No perfect matching found: estimate the unevaluated entries of D.
    D = linear_fitting(D, D_hat, is_computed);
end


function [N_deg, F_deg] = degree_profiles(M, deg)
    % DEGREE_PROFILES Neighbour-degree distribution of every vertex.
    %   For each row k of M, N_deg{k} holds the distinct (ascending) values
    %   of deg over the columns where M(k, :) is nonzero, and F_deg{k} the
    %   count of each of those values divided by deg(k).
    n = size(M, 1);
    [deg_sort, ind] = sort(deg);
    N_deg = cell(n, 1);
    F_deg = cell(n, 1);
    for k = 1:n
        nbr_deg = deg_sort(logical(M(k, ind)));
        [nbr_unique, ~, nbr_idx] = unique(nbr_deg);
        N_deg{k} = nbr_unique;
        F_deg{k} = accumarray(nbr_idx, 1) / deg(k);
    end
end




function D_out = linear_fitting(D_in, D_hat, is_computed)
    % LINEAR_FITTING Estimate the unevaluated entries of a distance matrix.
    %
    %   D_out = linear_fitting(D_in, D_hat, is_computed)
    %
    %   Entries where is_computed is true are kept as-is. Every other entry
    %   is replaced by slope * D_hat + intercept, clipped below at 0, where:
    %     - if all evaluated entries share one D_hat value, slope and
    %       intercept both equal the mean evaluated distance;
    %     - otherwise the line passes through the mean distance at the
    %       smallest and at the largest evaluated D_hat value; a negative
    %       slope is replaced by a flat line at the overall mean.
    %   If nothing was evaluated, D_in is returned unchanged.
    D_out = D_in;

    known_vals = D_in(is_computed);
    if isempty(known_vals)
        return;
    end
    known_diffs = D_hat(is_computed);

    lo = min(known_diffs);
    hi = max(known_diffs);
    overall_mean = mean(known_vals);

    if lo == hi
        % Single difference level: scale the mean with the degree difference.
        slope = overall_mean;
        intercept = overall_mean;
    else
        mean_lo = mean(known_vals(known_diffs == lo));
        mean_hi = mean(known_vals(known_diffs == hi));
        slope = (mean_hi - mean_lo) / double(hi - lo);
        intercept = mean_lo - slope * double(lo);
        if slope < 0
            % Distances should not shrink as degrees diverge; fall back flat.
            slope = 0;
            intercept = overall_mean;
        end
    end

    to_fill = ~is_computed;
    if ~any(to_fill(:))
        return;
    end

    estimate = slope * double(D_hat(to_fill)) + intercept;
    estimate(estimate < 0) = 0;   % distances are non-negative
    D_out(to_fill) = estimate;
end