% Complete MM Algorithm with Stiefel, Random, Spectral, and Greedy Initializations
%
% Builds four initialisations for a noisy multi-block matching problem,
% refines each with mmPermSyncWithTracking, then prints and plots the results.
% Every objective reported here is trace(P'*W*P) / k_val^2, which is the
% scale mmPermSyncWithTracking returns. Initial and final values are
% therefore directly comparable.

% Load data. NOTE(review): the file is expected to define m (block size),
% k (number of blocks) and Wnoise ((m*k) x (m*k) noisy matching matrix),
% since all three are used below without being defined here.
load house_sync_data

%% Parameters
k_val = 65;        % number of blocks to keep (evenly spaced among the k available)
max_iter = 50;     % maximum number of MM updates
tolerance = 1e-6;  % MM stops when the raw objective changes by less than this

%% Prepare data
dimVector = repmat(m, k, 1);
Wcell = mat2cell(Wnoise, dimVector, dimVector);
% Evenly spaced block indices. If k_val > k, rounding repeats some blocks.
subIdx = round(linspace(1, k, k_val));
currWnoise = cell2mat(Wcell(subIdx, subIdx));
currDimVector = dimVector(subIdx);

% Universe size (number of columns in P) - same as block size m
universe_size = m;

% Objective on the same scale as mmPermSyncWithTracking's returned values.
normObjective = @(P) trace(P' * currWnoise * P) / k_val^2;

%% Try different initializations
fprintf('Testing different initialization methods:\n\n');

% 1. Original Stiefel initialization
[~, U_Stiefel] = SparseStiefelSync(currWnoise, currDimVector, universe_size, 0);
P_init_stiefel = blockwiseAssignment(U_Stiefel, currDimVector);
fprintf('Stiefel init size: [%d, %d]\n', size(P_init_stiefel, 1), size(P_init_stiefel, 2));
fprintf('Stiefel init objective: %.6f\n', normObjective(P_init_stiefel));

% 2. Random Permutation initialization
P_init_random = RandomPerm_Init(currWnoise, currDimVector, universe_size);
fprintf('Random Perm init size: [%d, %d]\n', size(P_init_random, 1), size(P_init_random, 2));
fprintf('Random Perm init objective: %.6f\n', normObjective(P_init_random));

% 3. Spectral initialization
P_init_spectral = Spectral_Init(currWnoise, currDimVector, universe_size);
fprintf('Spectral init size: [%d, %d]\n', size(P_init_spectral, 1), size(P_init_spectral, 2));
fprintf('Spectral init objective: %.6f\n', normObjective(P_init_spectral));

% 4. Greedy Matching initialization
P_init_greedy = GreedyMatching_Init(currWnoise, currDimVector, universe_size);
fprintf('Greedy Matching init size: [%d, %d]\n', size(P_init_greedy, 1), size(P_init_greedy, 2));
fprintf('Greedy Matching init objective: %.6f\n', normObjective(P_init_greedy));

%% Run MM with each initialization
% iter_* is the number of recorded objectives (initial value + one per MM update).
[P_mm_stiefel, obj_stiefel, iter_stiefel] = mmPermSyncWithTracking(currWnoise, currDimVector, tolerance, max_iter, P_init_stiefel, k_val);
[P_mm_random, obj_random, iter_random] = mmPermSyncWithTracking(currWnoise, currDimVector, tolerance, max_iter, P_init_random, k_val);
[P_mm_spectral, obj_spectral, iter_spectral] = mmPermSyncWithTracking(currWnoise, currDimVector, tolerance, max_iter, P_init_spectral, k_val);
[P_mm_greedy, obj_greedy, iter_greedy] = mmPermSyncWithTracking(currWnoise, currDimVector, tolerance, max_iter, P_init_greedy, k_val);

%% Verify output sizes
fprintf('\nOutput sizes:\n');
fprintf('P_mm_stiefel size: [%d, %d]\n', size(P_mm_stiefel, 1), size(P_mm_stiefel, 2));
fprintf('P_mm_random size: [%d, %d]\n', size(P_mm_random, 1), size(P_mm_random, 2));
fprintf('P_mm_spectral size: [%d, %d]\n', size(P_mm_spectral, 1), size(P_mm_spectral, 2));
fprintf('P_mm_greedy size: [%d, %d]\n', size(P_mm_greedy, 1), size(P_mm_greedy, 2));

%% Compare results
% Report MM updates performed, i.e. recorded objectives minus the initial one.
fprintf('\nResults comparison:\n');
fprintf('Stiefel: Final obj = %.6f, Iterations = %d\n', obj_stiefel(end), iter_stiefel - 1);
fprintf('Random Perm: Final obj = %.6f, Iterations = %d\n', obj_random(end), iter_random - 1);
fprintf('Spectral: Final obj = %.6f, Iterations = %d\n', obj_spectral(end), iter_spectral - 1);
fprintf('Greedy Matching: Final obj = %.6f, Iterations = %d\n', obj_greedy(end), iter_greedy - 1);

%% Plot comparison - Four subplots as in original code
% The x-axis starts at 0 so that the first point is the initialisation.
figure('Color', 'w', 'Position', [100 100 800 800]);

subplot(2,2,1);
plot(0:iter_stiefel-1, obj_stiefel, 'b-o', 'LineWidth', 2);
xlabel('Iteration (0 = init)'); ylabel('Objective');
title('MM with Stiefel Init');
grid on;

subplot(2,2,2);
plot(0:iter_random-1, obj_random, 'k-o', 'LineWidth', 2);
xlabel('Iteration (0 = init)'); ylabel('Objective');
title('MM with Random Perm Init');
grid on;

subplot(2,2,3);
plot(0:iter_spectral-1, obj_spectral, 'c-o', 'LineWidth', 2);
xlabel('Iteration (0 = init)'); ylabel('Objective');
title('MM with Spectral Init');
grid on;

subplot(2,2,4);
plot(0:iter_greedy-1, obj_greedy, 'm-o', 'LineWidth', 2);
xlabel('Iteration (0 = init)'); ylabel('Objective');
title('MM with Greedy Matching Init');
grid on;

%% ============ ALL REQUIRED FUNCTIONS ============


%% MM Algorithm with convergence tracking
function [P, objectives, final_iter] = mmPermSyncWithTracking(W, dimVector, tol, maxIter, Uinit, k_val)
    % MMPERMSYNCWITHTRACKING  Minorize-maximize refinement of a block assignment.
    %   Iteratively increases trace(P'*W*P) over block-wise partial permutation
    %   matrices P. The projection step is the greedy blockwiseAssignment, so
    %   each update is an approximate (not exact) maximiser of the surrogate.
    %
    %   Inputs:
    %     W         - m x m matrix, m = sum(dimVector)
    %     dimVector - row-block sizes of W
    %     tol       - stop when |obj_t - obj_{t-1}| < tol (raw, unnormalised trace)
    %     maxIter   - maximum number of MM updates
    %     Uinit     - m x d initial scores/assignment; projected before iterating
    %     k_val     - the returned objectives are divided by k_val^2
    %   Outputs:
    %     P          - final m x d assignment
    %     objectives - final_iter x 1 vector of trace(P'*W*P) / k_val^2;
    %                  entry 1 is the projected initialisation
    %     final_iter - number of recorded objectives (MM updates performed + 1)

    m = sum(dimVector);
    P = blockwiseAssignment(Uinit, dimVector);

    % trace(P'*W*P) equals trace(P'*Ws*P) for the symmetric part Ws, so using
    % Ws leaves the objective unchanged. It also guarantees real eigenvalues
    % and makes 2*M*P the true gradient of trace(P'*M*P).
    Ws = (W + W') / 2;
    % Shift to a PSD matrix. This adds -lambdamin*trace(P'*P), which is
    % constant whenever every row of P is assigned, and makes the objective
    % convex so that its linearisation is a valid minorizer.
    lambdamin = min(eig(Ws));
    M = Ws - lambdamin * eye(m);

    objectives = zeros(maxIter + 1, 1);
    objectives(1) = trace(P' * W * P);
    final_iter = maxIter + 1;

    for t = 1:maxIter
        % Linear surrogate score at the current iterate.
        T = 2 * (M * P);

        % Project the scores themselves. Thresholding them first (T >= 0)
        % discards their magnitudes. When M*P is entrywise nonnegative (e.g. W
        % nonnegative and lambdamin <= 0), it also turns every entry into 1,
        % which collapses each block to the identity assignment.
        Pnew = blockwiseAssignment(T, dimVector);

        currObj = trace(Pnew' * W * Pnew);
        objectives(t + 1) = currObj;
        converged = abs(currObj - objectives(t)) < tol;
        P = Pnew;

        if converged
            final_iter = t + 1;
            break;
        end
    end

    objectives = objectives(1:final_iter) / k_val^2;
end

%% Blockwise assignment function
function Pproj = blockwiseAssignment(Pin, rowDims)
    % BLOCKWISEASSIGNMENT  Greedy row-by-row projection onto block partial permutations.
    %   Pin is split into consecutive row blocks of heights rowDims. Within a
    %   block, rows are visited top to bottom. Each row gets a single 1 in its
    %   highest-scoring column not yet taken by an earlier row of the same
    %   block; ties go to the lowest column index, and NaN ranks first (sort
    %   'descend' order). If a block has more rows than Pin has columns, the
    %   surplus rows stay all-zero.
    %   Output: sum(rowDims) x size(Pin, 2) 0/1 matrix.

    nCols = size(Pin, 2);
    Pproj = zeros(sum(rowDims), nCols);
    lastRow = cumsum(rowDims(:));
    firstRow = lastRow - rowDims(:) + 1;

    for blk = 1:numel(rowDims)
        rowIdx = firstRow(blk):lastRow(blk);
        scores = Pin(rowIdx, :);
        taken = false(1, nCols);

        for j = 1:numel(rowIdx)
            freeCols = find(~taken);
            if isempty(freeCols)
                break;   % every column of this block is already in use
            end
            % Stable descending sort over the free columns only.
            [~, orderIdx] = sort(scores(j, freeCols), 'descend');
            pick = freeCols(orderIdx(1));
            Pproj(rowIdx(j), pick) = 1;
            taken(pick) = true;
        end
    end
end

%% Random Permutation Initialization
function P_random = RandomPerm_Init(~, dimVector, d)
    % RANDOMPERM_INIT  Random block-wise partial permutation matrix.
    %   Inputs:
    %     (1st arg) - weight matrix, ignored; kept so every initialiser shares one signature
    %     dimVector - vector of block dimensions
    %     d         - universe size (number of columns in output)
    %   Output:
    %     P_random  - sum(dimVector) x d 0/1 matrix. Row i of block b gets a 1
    %                 in column c_i, where c_1..c_{dimVector(b)} are the first
    %                 entries of a fresh randperm(d). Columns are therefore
    %                 distinct within a block. Requires dimVector(b) <= d;
    %                 otherwise indexing the permutation errors.

    nRows = sum(dimVector);
    P_random = zeros(nRows, d);
    lastRow = cumsum(dimVector(:));
    firstRow = lastRow - dimVector(:) + 1;

    for blk = 1:numel(dimVector)
        nb = dimVector(blk);
        shuffled = randperm(d);          % one full permutation per block
        cols = shuffled(1:nb);
        rowIdx = firstRow(blk) + (0:nb - 1);
        P_random(sub2ind([nRows, d], rowIdx, cols)) = 1;
    end
end

%% Spectral Initialization
function P_spectral = Spectral_Init(W, dimVector, d)
    % SPECTRAL_INIT  Block assignment from the leading eigenvectors of W.
    %   Inputs:
    %     W         - weight matrix
    %     dimVector - vector of block dimensions
    %     d         - universe size (number of columns in output)
    %   Output:
    %     P_spectral - sum(dimVector) x d matrix produced by blockwiseAssignment
    %
    %   Steps: take the d eigenvectors with largest real eigenvalue part, keep
    %   their real part, and scale each row to unit Euclidean norm (rows with
    %   norm below 1e-10 are left unscaled). The result is then projected block-wise.
    %   NOTE(review): blockwiseAssignment ranks columns within each row, so a
    %   positive per-row scale only matters through rounding ties.

    [eigVecs, ~] = eigs(W, d, 'largestreal');
    embedding = real(eigVecs);

    rowLen = sqrt(sum(embedding.^2, 2));
    nearZero = rowLen < 1e-10;
    rowLen(nearZero) = 1;            % avoid dividing (near-)zero rows
    embedding = embedding ./ rowLen;

    P_spectral = blockwiseAssignment(embedding, dimVector);
end

%% Greedy Matching Initialization
function P_greedy = GreedyMatching_Init(W, dimVector, d)
    % GREEDYMATCHING_INIT  Greedy block-wise assignment from a spectral embedding.
    %   Inputs:
    %     W         - weight matrix, sum(dimVector) x sum(dimVector)
    %     dimVector - vector of block dimensions
    %     d         - universe size (number of columns in output)
    %   Output:
    %     P_greedy  - sum(dimVector) x d 0/1 matrix. Within each block every
    %                 column is used at most once. Every row gets exactly one 1
    %                 when dimVector(b) <= d; otherwise the lowest-priority
    %                 surplus rows of that block are left all-zero, as in
    %                 blockwiseAssignment.
    %
    %   Embedding: the real part of the d eigenvectors of W with largest real
    %   eigenvalue part. Column j is scaled by sqrt(max(real(lambda_j), 0)).
    %   Within a block, rows with the highest maximum score choose first, each
    %   taking its best still-free column.

    m = sum(dimVector);
    n_blocks = numel(dimVector);

    [V, D] = eigs(W, d, 'largestreal');
    V = real(V);

    % Weight by (clipped, real) eigenvalues so dominant directions count more.
    eigenvalues = real(diag(D));
    eigenvalues(eigenvalues < 0) = 0;
    V = V * diag(sqrt(eigenvalues));

    P_greedy = zeros(m, d);
    offset = 0;

    for b = 1:n_blocks
        mi = dimVector(b);
        rows = (offset + 1):(offset + mi);
        V_block = V(rows, :);

        P_block = zeros(mi, d);
        used_cols = false(1, d);

        % Rows whose strongest score is highest pick first.
        max_vals = max(V_block, [], 2);
        [~, row_order] = sort(max_vals, 'descend');

        for idx = 1:mi
            free_cols = find(~used_cols);
            if isempty(free_cols)
                % More rows than columns. Masking used columns with -Inf
                % would make max() return column 1 again and double-assign it.
                break;
            end
            r = row_order(idx);
            [~, j] = max(V_block(r, free_cols));
            col = free_cols(j);
            P_block(r, col) = 1;
            used_cols(col) = true;
        end

        P_greedy(rows, :) = P_block;
        offset = offset + mi;
    end
end

%% SparseStiefelSync function
% This function is a wrapper to apply the method in [1] to permutation 
% synchronisation. If you use this in your work, you are required to cite [1].
%
% [1] F. Bernard, D. Cremers, J. Thunberg. Sparse Quadratic Optimisation over 
% the Stiefel Manifold with Application to Permutation Synchronisation.
% NeurIPS 2021
%
% Input:   W is a matrix of size m x m, which consists of k x k block
%              matrices that represent pairwise matchings, where the 
%              dimension of the block at position (i,j) is 
%              dimVector(i) x dimVector(j)
%          dimVector denotes the size of the blocks of W, where
%              sum(dimVector) = m
%          d is the number of columns of the output Uproj (also called universe
%              size in permutation synchronisation problems)
%          vis (0 or 1) to enable visualisation
% Output:  Wout is a matrix of size m x m and contains cycle-consistent pairwise matchings
%          Uproj is a matrix of size m x d, so that Wout = Uproj*Uproj'
