% Synthetic low-rank matrix sensing problem shared by every solver below.
%   X_star = U_star*diag(sigma_star)*V_star' is an n-by-n matrix of rank r
%   whose nonzero singular values run linearly from 1 down to 1/kappa, and it
%   is observed through m linear measurements y(k) = <As{k}, X_star>.
% NOTE(review): no RNG seed is set, so every run draws a new problem
% instance; call rng(seed) here if reproducible figures are required.
clear; close all; clc;
n = 20;          % X_star is n-by-n
r = 5;           % true rank of X_star
search_r = 5;    % number of columns of the factors L, R in every solver
kappa = 100;     % ratio of largest to smallest nonzero singular value
m = 10*n*r;      % number of measurements

T = 300;         % maximum number of iterations per solver

% Each solver stops early once its relative error leaves [thresh_low, thresh_up]
% or becomes non-finite; unreached entries of errors_* stay 0.
thresh_up = 1e20; thresh_low = 1e-15;
errors_APGD = zeros(T,1);
errors_ScaledGD = zeros(T,1);
errors_AltGD = zeros(T,1);
errors_GD = zeros(T,1);

% Orthonormal left/right singular vectors taken from random +-1 matrices.
U_seed = sign(rand(n, r) - 0.5);
[U_star, ~, ~] = svds(U_seed, r);
V_seed = sign(rand(n, r) - 0.5);
% The right factor must come from V_seed (not U_seed); otherwise
% V_star == U_star and X_star would always be symmetric.
[V_star, ~, ~] = svds(V_seed, r);

% Gaussian sensing matrices scaled by 1/sqrt(m).
As = cell(m, 1);
for k = 1:m
    As{k} = randn(n, n)/sqrt(m);
end


sigma_star = linspace(1, 1/kappa, r);
L_star = U_star*diag(sqrt(sigma_star));
R_star = V_star*diag(sqrt(sigma_star));
X_star = L_star*R_star';

% y holds the observed measurements; yt is a workspace buffer the solvers
% fill with the measurements predicted by their current iterate.
y = zeros(m, 1);
yt = zeros(m, 1);
for k = 1:m
    y(k) = As{k}(:)'*X_star(:);
end


%% APGD
% Solver labelled 'Ours' in the plot (presumably alternating preconditioned
% gradient descent). Each iteration:
%   1. updates L with the gradient Z*R preconditioned by (R'*R + alpha*I)^-1;
%   2. recomputes the gradient at the new X = L*R';
%   3. updates R with Z'*L preconditioned by (L'*L + alpha*I)^-1.
% The relative error to X_star is logged in errors_APGD before each update.
L = randn(n, search_r)*1e-1;   % small random initialization
R = randn(n, search_r)*1e-1;
alpha = 1e-10;                 % damping added to the Gram matrices before inversion
eta = 1;                       % step size
I_r = eye(search_r);           % loop-invariant identity
for t = 1:T
    X = L*R';
    % Named rel_err rather than "error" so the built-in error() is not shadowed.
    rel_err = norm(X - X_star, 'fro')/norm(X_star, 'fro');
    errors_APGD(t) = rel_err;

    if ~isfinite(rel_err) || rel_err > thresh_up || rel_err < thresh_low
        break;
    end

    % Z = sum_k (<As{k}, X> - y(k)) * As{k}; the predictions are computed once
    % in yt and reused instead of recomputing each inner product.
    for k = 1:m
        yt(k) = As{k}(:)'*X(:);
    end
    Z = zeros(n, n);
    for k = 1:m
        Z = Z + (yt(k) - y(k))*As{k};
    end
    L = L - eta*Z*R/(R'*R + alpha*I_r);

    % Alternating step: the R update uses the gradient at the freshly updated L.
    X = L*R';
    for k = 1:m
        yt(k) = As{k}(:)'*X(:);
    end
    Z = zeros(n, n);
    for k = 1:m
        Z = Z + (yt(k) - y(k))*As{k};
    end
    R = R - eta*Z'*L/(L'*L + alpha*I_r);
end
%% ScaledGD
% Scaled gradient descent: L and R are updated simultaneously from the same
% gradient Z at X = L*R', with each factor's step preconditioned by the other
% factor's damped Gram matrix. The relative error to X_star is logged in
% errors_ScaledGD before each update.
L = randn(n, search_r)*1e-1;   % small random initialization
R = randn(n, search_r)*1e-1;
eta = 0.6;                     % step size
alpha = 1e-10;                 % damping added to the Gram matrices before inversion
I_r = eye(search_r);           % loop-invariant identity
for t = 1:T
    X = L*R';
    % Named rel_err rather than "error" so the built-in error() is not shadowed.
    rel_err = norm(X - X_star, 'fro')/norm(X_star, 'fro');
    errors_ScaledGD(t) = rel_err;

    if ~isfinite(rel_err) || rel_err > thresh_up || rel_err < thresh_low
        break;
    end

    % Z = sum_k (<As{k}, X> - y(k)) * As{k}, reusing the predictions in yt.
    for k = 1:m
        yt(k) = As{k}(:)'*X(:);
    end
    Z = zeros(n, n);
    for k = 1:m
        Z = Z + (yt(k) - y(k))*As{k};
    end

    % Simultaneous step: both updates read the old L and R.
    L_next = L - eta*Z*R/(R'*R + alpha*I_r);
    R = R - eta*Z'*L/(L'*L + alpha*I_r);
    L = L_next;
end
    %% AltGD
% Alternating gradient descent without preconditioning. Each iteration steps
% L along -Z*R, recomputes the gradient at the new X = L*R', then steps R
% along -Z'*L. The relative error to X_star is logged in errors_AltGD before
% each update.
L = randn(n, search_r)*1e-1;   % small random initialization
R = randn(n, search_r)*1e-1;
eta = 1;                       % step size
for t = 1:T
    X = L*R';
    % Named rel_err rather than "error" so the built-in error() is not shadowed.
    rel_err = norm(X - X_star, 'fro')/norm(X_star, 'fro');
    errors_AltGD(t) = rel_err;

    if ~isfinite(rel_err) || rel_err > thresh_up || rel_err < thresh_low
        break;
    end

    % Z = sum_k (<As{k}, X> - y(k)) * As{k}, reusing the predictions in yt.
    for k = 1:m
        yt(k) = As{k}(:)'*X(:);
    end
    Z = zeros(n, n);
    for k = 1:m
        Z = Z + (yt(k) - y(k))*As{k};
    end
    L = L - eta*Z*R;

    % Alternating step: the R update uses the gradient at the freshly updated L.
    X = L*R';
    for k = 1:m
        yt(k) = As{k}(:)'*X(:);
    end
    Z = zeros(n, n);
    for k = 1:m
        Z = Z + (yt(k) - y(k))*As{k};
    end
    R = R - eta*Z'*L;
end
    %% GD
% Plain gradient descent on the factors: L and R are updated simultaneously
% from the same gradient Z at X = L*R', with no preconditioning. The relative
% error to X_star is logged in errors_GD before each update.
L = randn(n, search_r)*1e-1;   % small random initialization
R = randn(n, search_r)*1e-1;
eta = 0.6;                     % step size
for t = 1:T
    X = L*R';
    % Named rel_err rather than "error" so the built-in error() is not shadowed.
    rel_err = norm(X - X_star, 'fro')/norm(X_star, 'fro');
    errors_GD(t) = rel_err;

    if ~isfinite(rel_err) || rel_err > thresh_up || rel_err < thresh_low
        break;
    end

    % Z = sum_k (<As{k}, X> - y(k)) * As{k}, reusing the predictions in yt.
    for k = 1:m
        yt(k) = As{k}(:)'*X(:);
    end
    Z = zeros(n, n);
    for k = 1:m
        Z = Z + (yt(k) - y(k))*As{k};
    end

    % Simultaneous step: both updates read the old L and R.
    L_next = L - eta*Z*R;
    R = R - eta*Z'*L;
    L = L_next;
end




% Overlay the relative-error histories of the four solvers on a log-scale
% y-axis. Entries never reached because of an early break are still 0.
x = 1:T;
curve_errors = {errors_APGD, errors_ScaledGD, errors_AltGD, errors_GD};
curve_specs  = {'-o', '-*', '-+', '-s'};
curve_colors = {[189,30,30]/255, [135 207 164]/255, ...
                [253 185 106]/255, [55 103 149]/255};
marker_every = [10, 10, 50, 50];   % spacing between markers for each curve

curve_handles = gobjects(1, numel(curve_errors));
for c = 1:numel(curve_errors)
    marker_indices = 1:marker_every(c):T;
    curve_handles(c) = plot(x, curve_errors{c}, curve_specs{c}, ...
        'MarkerIndices', marker_indices, 'MarkerSize', 18, ...
        'LineWidth', 2, 'Color', curve_colors{c});
    hold on;
end
[p1, p2, p3, p4] = deal(curve_handles(1), curve_handles(2), ...
                        curve_handles(3), curve_handles(4));

set(gca,'yscale','log');
legend([p1,p2,p3,p4],'Ours','Scaled GD','Alternating GD','GD');
xlabel('Iterations');
ylabel('Relative Error');
ylim([1e-15,100]);
