%% pitprops data analysis
clear all

% Load the data and keep every column except the first as a numeric matrix.
% NOTE(review): the first column is presumably a row-label column — confirm
% against pitprops.csv.
pitprops = readtable("pitprops.csv");
true_M = pitprops{:, 2:end};

%% check eigenvalues of original covariance matrix
% Eigenvalues of true_M ordered from largest to smallest; plotted, then printed.
eig_sorted = sort(eig(true_M), "desc");
plot(eig_sorted)
disp(eig_sorted)

%% run algorithms on original covariance matrix
d = size(true_M, 1);   % number of variables (true_M is square; eig above requires it)

% Ground-truth support, its size, and its 0/1 indicator as a d-by-1 column.
true_support = [1,2,7,8,9,10];
s = numel(true_support);
true_idx = zeros(d, 1);
true_idx(true_support) = 1;

%%% SDP
% NOTE(review): 0.4 is presumably the SDP penalty parameter — confirm in sdp_optim.
M_hat = sdp_optim(true_M, 0.4, d);
% A variable is selected when its diagonal entry in M_hat exceeds 1e-4.
sdp_selected = diag(M_hat) > 1e-4;
num_select = sum(sdp_selected);
% Exact recovery: the selected set matches true_idx at every position.
exact_recovery = double(all(sdp_selected == true_idx));
disp([num_select, exact_recovery])

%%% Thresholding
% TSPCA returns two loading estimates. NOTE(review): u and v are presumably
% the DT and IT thresholding estimates (per the _dt/_it suffixes) — confirm.
[u, v] = TSPCA(true_M, d, 0, 1.1);
ts_thr = sqrt(0.085);   % loadings with |value| above this count as selected

% Flatten with (:) so that a row-vector estimate is compared element-wise
% with the column vector true_idx. Otherwise implicit expansion would build
% a d-by-d comparison and turn the exact-recovery flag into a vector.
dt_supp = abs(u(:)) > ts_thr;
it_supp = abs(v(:)) > ts_thr;

num_select_dt = sum(dt_supp);
exact_recovery_dt = double(isequal(dt_supp, true_idx));
num_select_it = sum(it_supp);
exact_recovery_it = double(isequal(it_supp, true_idx));
disp([num_select_dt, exact_recovery_dt])
disp([num_select_it, exact_recovery_it])

%%% A-ManPG
% Initialise both iterates from the leading right singular vector of true_M.
[~,~,Vu] = svds(true_M, 1);
phi_init = Vu;
option_manpg_alt.X0 = phi_init;
option_manpg_alt.Y0 = phi_init;
option_manpg_alt.maxiter =1e4;
option_manpg_alt.tol =1e-10;
option_manpg_alt.n = 1;     % presumably the number of components — confirm in spca_amanpg
option_manpg_alt.d = d;
option_manpg_alt.mu = 1.5;
option_manpg_alt.type = 1;
option_manpg_alt.lambda = 1;

% Only the sixth output (Y_alt) of spca_amanpg is used below.
[~, ~ , ~, ~, ~, Y_alt]= spca_amanpg(true_M, option_manpg_alt);

amanpg_thr = sqrt(1e-4);   % loadings with |value| above this count as selected
% Flatten with (:) so that a row-vector Y_alt cannot broadcast against the
% column vector true_idx and yield a vector instead of a 0/1 flag.
amanpg_supp = abs(Y_alt(:)) > amanpg_thr;
num_select_amanpg = sum(amanpg_supp);
exact_recovery_amanpg = double(isequal(amanpg_supp, true_idx));
disp([num_select_amanpg, exact_recovery_amanpg])

%%% GPM
[~, supp_est] = GPM(true_M, 1);
% supp_est is used as a 0/1 indicator over the d variables (summed and
% compared with true_idx). Flatten with (:) so that a row-vector result is
% compared element-wise rather than broadcast against the column true_idx.
supp_col = supp_est(:);
num_select_gpm = sum(supp_col);
% isequal compares values regardless of class (logical vs double).
exact_recovery_gpm = double(isequal(supp_col, true_idx));
disp([num_select_gpm, exact_recovery_gpm])


%% run algorithms on incomplete matrices
clear('all')

% Run the incomplete-matrix comparison. The meaning of the three positional
% arguments is defined inside compare_algorithms (not visible here).
compare_algorithms(100, 0.1, 100);

%% check matrix completion accuracy
clear('all')

res = readtable("compare_algorithms_result.csv");

% NOTE(review): column 7 is presumably the per-run matrix-completion error —
% confirm against compare_algorithms. Display its mean and the fraction of
% runs with a value below 0.25.
completion_col = res{:,7};
mean(completion_col)
mean(completion_col < 0.25)

%% plotting
clear all

res = readtable("compare_algorithms_result.csv");

% Group runs by column 4 (presumably the x-axis ratio — confirm) and average
% every result column from 8 onward within each group.
tbl = table(res{:,4}, res{:,8:end});
stats = grpstats(tbl, "Var1");
er_rates = stats.mean_Var2;   % one row per group; column k <-> res column 7+k

% Methods reported with a single column per group.
sdp_selected_er_rates          = er_rates(:,1);
sdp_complete_selected_er_rates = er_rates(:,2);
gpm_selected_er_rates          = er_rates(:,625);
gpm_complete_selected_er_rates = er_rates(:,626);

% Methods spanning a block of columns (presumably one per tuning setting):
% keep the best, i.e. row-wise maximum, rate within each block.
% Columns 603:624 are not used here.
sdp_er_rates             = max(er_rates(:,3:42),    [], 2);
sdp_complete_er_rates    = max(er_rates(:,43:82),   [], 2);
amanpg_er_rates          = max(er_rates(:,83:102),  [], 2);
amanpg_complete_er_rates = max(er_rates(:,103:122), [], 2);
dt_er_rates              = max(er_rates(:,123:137), [], 2);
dt_complete_er_rates     = max(er_rates(:,138:152), [], 2);
it_er_rates              = max(er_rates(:,153:377), [], 2);
it_complete_er_rates     = max(er_rates(:,378:602), [], 2);

% Tuning figure: SDP exact-recovery rate with rho selected by C_rho (thick
% line) versus fixed rho values, against the x-axis ratio.
%
% Create a dedicated figure. A bare tiledlayout reuses whatever figure is
% current (e.g. the eigenvalue plot, or the comparison figure from a previous
% run of this section), replacing its content, and the Position call below
% would then resize that figure.
f = figure;
t = tiledlayout(f, 'flow', 'TileSpacing', 'compact');

ratio_grid = 0.1:0.2:2.1;   % one x value per group (row) of er_rates
assert(numel(ratio_grid) == size(er_rates, 1), ...
    'Expected %d groups in er_rates but found %d.', ...
    numel(ratio_grid), size(er_rates, 1));

ax = nexttile(t);
hold(ax, 'on');
plot(ax, ratio_grid, sdp_selected_er_rates, '-_', 'LineWidth', 1.5);
% Columns 5:4:42 give 10 curves, one per fixed rho listed in the legend.
plot(ax, ratio_grid, er_rates(:,5:4:42), '-_');
hold(ax, 'off');
ylim(ax, [0, 1]); xlim(ax, [0, 2.2]);
xticks(ax, 0:0.2:2.2);
xlabel(ax, {'$\frac{\psi(\mathcal{G}_{J,J})}{\phi(\mathcal{G}_{J,J})}$'}, 'FontSize', 20,'Interpreter','latex');
ylabel(ax, 'Exact Recovery Rate', 'FontSize', 16);
lgd = legend(ax, {'Selected $\rho$ by $C_\rho$', ...
                '$\rho=0.1$', '$\rho=0.2$', ...
                '$\rho=0.3$', '$\rho=0.4$', ...
                '$\rho=0.5$', '$\rho=0.6$', ...
                '$\rho=0.7$', '$\rho=0.8$', ...
                '$\rho=0.9$', '$\rho=1$'},...
    'Interpreter', 'latex', 'FontSize', 10);
lgd.Layout.Tile = 'east';
box(ax, 'on');

set(f, 'Position',  [100, 100, 500, 350])

exportgraphics(f, 'tuning.png', 'Resolution', 300)


% Comparison figure: exact-recovery rate of each method without (solid) and
% with (dashed) matrix completion, against the x-axis ratio.
ratio_grid = 0.1:0.2:2.1;
% Each row: {y-values, line spec, RGB colour}, in legend order.
curve_specs = {
    sdp_selected_er_rates,          '-o',  [0.1 0.5 0.8]
    sdp_complete_selected_er_rates, '--o', [0.3 0.7 0.8]
    amanpg_er_rates,                '-^',  [0.9 0.2 0.2]
    amanpg_complete_er_rates,       '--^', [0.9 0.6 0.2]
    it_er_rates,                    '-x',  [0.6 0.2 0.8]
    it_complete_er_rates,           '--x', [0.9 0.2 0.9]
    gpm_selected_er_rates,          '-s',  [0.1 0.8 0.1]
    gpm_complete_selected_er_rates, '--s', [0.4 0.9 0.4]
};

figure; hold on;
for k = 1:size(curve_specs, 1)
    plot(ratio_grid, curve_specs{k,1}, curve_specs{k,2}, ...
        'LineWidth', 1.5, 'Color', curve_specs{k,3});
end
hold off;
ylim([0, 1]); xlim([0, 2.2]);
xticks(0:0.2:2.2);
xlabel({'$\frac{\psi(\mathcal{G}_{J,J})}{\phi(\mathcal{G}_{J,J})}$'}, 'FontSize', 20,'Interpreter','latex');
ylabel('Exact Recovery Rate', 'FontSize', 16);
legend({'SDP', 'SDP with completion', 'A-ManPG', 'A-ManPG with completion', 'ITSPCA', 'ITSPCA with completion', 'GPM', 'GPM with completion'}, 'FontSize', 10);
box on;

f = gcf;
set(f, 'Position',  [100, 100, 400, 400])

exportgraphics(f, 'pitprops_comparison.png', 'Resolution', 300)