clear; close all; clc;

% Step sizes (eta) to sweep; each method gets one curve per step size.
etas = [0.5, 0.6, 0.7, 0.8];
num_eta = numel(etas);

% Build the LaTeX legend labels: entries 1..num_eta are for our alternating
% method, entries num_eta+1..2*num_eta for the ScaledGD baseline.
% '%g' prints the value as written (e.g. 0.25 -> '0.25'); '%.0g' would round
% to one significant digit and mislabel such step sizes.
legend_labels = cell(1, 2*num_eta);
for i = 1:num_eta
    eta_str = num2str(etas(i), '%g');
    legend_labels{i} = ['$\mathrm{Ours},\ \eta = ', eta_str, '$'];
    legend_labels{i + num_eta} = ['$\mathrm{ScaledGD}(\lambda),\ \eta = ', eta_str, '$'];
end

ite = 500;          % iterations per run
repeat_time = 1;    % number of runs averaged per step size

% Column i holds the relative-error trajectory for step size etas(i),
% averaged over repeat_time runs.
error_APGD = zeros(ite, num_eta);
error_ScaledGD = zeros(ite, num_eta);
for i = 1:num_eta
    for j = 1:repeat_time
        % Both solvers draw their random numbers in the same order, so
        % seeding identically before each call gives them the same problem
        % instance and initialization (a paired comparison). It also makes
        % the figure reproducible.
        rng(j);
        error_APGD(:, i) = error_APGD(:, i) + AltScaledGD(etas(i), ite);
        rng(j);
        error_ScaledGD(:, i) = error_ScaledGD(:, i) + ScaledGD(etas(i), ite);
    end
end
error_APGD = error_APGD / repeat_time;
error_ScaledGD = error_ScaledGD / repeat_time;

% One color/marker per step size (cycled if there are more step sizes than
% table rows); solid lines for ours, dashed lines for ScaledGD.
colors = [189  30  30;
          135 207 164;
          253 185 106;
           55 103 149] / 255;
markers = {'o', '*', '+', 's'};

x = 1:ite;
handles = gobjects(1, 2*num_eta);
for i = 1:num_eta
    c = mod(i - 1, size(colors, 1)) + 1;
    mk = markers{mod(i - 1, numel(markers)) + 1};
    handles(i) = plot(x, error_APGD(:, i), ['-', mk], ...
        'MarkerIndices', 1:30:ite, 'MarkerSize', 18, ...
        'LineWidth', 2, 'Color', colors(c, :));
    hold on;
end
for i = 1:num_eta
    c = mod(i - 1, size(colors, 1)) + 1;
    mk = markers{mod(i - 1, numel(markers)) + 1};
    handles(i + num_eta) = plot(x, error_ScaledGD(:, i), ['--', mk], ...
        'MarkerIndices', 1:50:ite, 'MarkerSize', 18, ...
        'LineWidth', 2, 'Color', colors(c, :));
    hold on;
end

set(gca, 'yscale', 'log');
lgd = legend(handles, legend_labels, 'Interpreter', 'latex');
lgd.FontSize = 18;
xlabel('Iterations');
ylabel('Relative Error');
ylim([1e-15, 100]);





%% AltScaledGD (ours)
function errors_APGD=AltScaledGD(eta,T)
% AltScaledGD  Alternating scaled gradient descent for low-rank matrix sensing.
%   errors_APGD = AltScaledGD(eta, T) generates a random problem and runs the
%   solver on it.
%
%   Problem:
%     - Ground truth X_star = L_star*R_star' is n-by-n with rank r and
%       singular values linspace(1, 1/kappa, r).
%     - There are m linear measurements y(k) = <A_k, X_star>.
%
%   Solver (T iterations on factors L, R with search_r >= r columns,
%   started from a small random initialization):
%       L <- L - eta * (Z*R)  / (R'*R + alpha*I)
%       R <- R - eta * (Z'*L) / (L'*L + alpha*I)   (uses the updated L)
%   Here Z = sum_k (<A_k, L*R'> - y(k)) * A_k is the gradient with respect
%   to X of 0.5*sum_k (<A_k, X> - y(k))^2, evaluated at X = L*R'.
%
%   Inputs:
%     eta - step size
%     T   - maximum number of iterations
%   Output:
%     errors_APGD - T-by-1 vector. Entry t is
%                   ||L*R' - X_star||_F / ||X_star||_F measured before update t.
%                   Iteration stops early when this value is non-finite,
%                   > 1e5, or <= 1e-15; the remaining entries stay 0.
    alpha = 1e-10;      % small damping added to the search_r-by-search_r Gram matrices
    n = 20; r = 5;
    search_r = 10; kappa = 100;
    m = 10*n*r;         % number of measurements
    scale = 0.1;
    errors_APGD = zeros(T, 1);

    % Small random initialization of the over-parameterized factors.
    L = randn(n, search_r)*scale/(3*sqrt(n + search_r));
    R = randn(n, search_r)*scale/(3*sqrt(n + search_r));

    % Orthonormal bases for the ground truth's column and row spaces.
    U_seed = sign(rand(n, r) - 0.5);
    [U_star, ~, ~] = svds(U_seed, r);
    V_seed = sign(rand(n, r) - 0.5);
    % Must use V_seed: calling svds on U_seed here would make V_star equal
    % U_star and leave V_seed unused.
    [V_star, ~, ~] = svds(V_seed, r);

    % Column k of A is vec(A_k), with i.i.d. N(0, 1/m) entries. A single
    % randn(n*n, m) call consumes the random stream in the same order as m
    % successive randn(n, n) draws.
    A = randn(n*n, m)/sqrt(m);

    sigma_star = linspace(1, 1/kappa, r);
    L_star = U_star*diag(sqrt(sigma_star));
    R_star = V_star*diag(sqrt(sigma_star));
    X_star = L_star*R_star';
    y = A'*X_star(:);                   % y(k) = <A_k, X_star>
    norm_X_star = norm(X_star, 'fro');
    damping = alpha*eye(search_r);

    % Gradient of 0.5*||A'*vec(X) - y||^2 with respect to X.
    grad_X = @(X) reshape(A*(A'*X(:) - y), n, n);

    for t = 1:T
        X = L*R';
        rel_err = norm(X - X_star, 'fro')/norm_X_star;
        errors_APGD(t) = rel_err;
        if ~isfinite(rel_err) || rel_err > 1e5 || rel_err <= 1e-15
            break;
        end

        % Update L using the gradient at the current (L, R).
        Z = grad_X(X);
        L = L - eta*Z*R/(R'*R + damping);

        % Alternating step: recompute the gradient with the new L before
        % updating R.
        Z = grad_X(L*R');
        R = R - eta*Z'*L/(L'*L + damping);
    end
end

function errors_ScaledGD=ScaledGD(eta,T)
% ScaledGD  Scaled gradient descent baseline for low-rank matrix sensing.
%   errors_ScaledGD = ScaledGD(eta, T) generates a random problem and runs the
%   solver on it.
%
%   Problem:
%     - Ground truth X_star = L_star*R_star' is n-by-n with rank r and
%       singular values linspace(1, 1/kappa, r).
%     - There are m linear measurements y(k) = <A_k, X_star>.
%
%   Solver (T iterations on factors L, R with search_r >= r columns,
%   started from a small random initialization). Both factors are updated
%   simultaneously from the same (L, R):
%       L <- L - eta * (Z*R)  / (R'*R + alpha*I)
%       R <- R - eta * (Z'*L) / (L'*L + alpha*I)
%   Here Z = sum_k (<A_k, L*R'> - y(k)) * A_k.
%
%   Inputs:
%     eta - step size
%     T   - maximum number of iterations
%   Output:
%     errors_ScaledGD - T-by-1 vector. Entry t is
%                       ||L*R' - X_star||_F / ||X_star||_F measured before
%                       update t. Iteration stops early when this value is
%                       non-finite, > 1e10, or <= 1e-15; the remaining
%                       entries stay 0.
    alpha = 1e-10;      % small damping added to the search_r-by-search_r Gram matrices
    n = 20; r = 5;
    search_r = 10; kappa = 100;
    m = 10*n*r;         % number of measurements
    scale = 0.1;
    errors_ScaledGD = zeros(T, 1);

    % Small random initialization of the over-parameterized factors.
    L = randn(n, search_r)*scale/(3*sqrt(n + search_r));
    R = randn(n, search_r)*scale/(3*sqrt(n + search_r));

    % Orthonormal bases for the ground truth's column and row spaces.
    U_seed = sign(rand(n, r) - 0.5);
    [U_star, ~, ~] = svds(U_seed, r);
    V_seed = sign(rand(n, r) - 0.5);
    % Must use V_seed: calling svds on U_seed here would make V_star equal
    % U_star and leave V_seed unused.
    [V_star, ~, ~] = svds(V_seed, r);

    % Column k of A is vec(A_k), with i.i.d. N(0, 1/m) entries. A single
    % randn(n*n, m) call consumes the random stream in the same order as m
    % successive randn(n, n) draws.
    A = randn(n*n, m)/sqrt(m);

    sigma_star = linspace(1, 1/kappa, r);
    L_star = U_star*diag(sqrt(sigma_star));
    R_star = V_star*diag(sqrt(sigma_star));
    X_star = L_star*R_star';
    y = A'*X_star(:);                   % y(k) = <A_k, X_star>
    norm_X_star = norm(X_star, 'fro');
    damping = alpha*eye(search_r);

    for t = 1:T
        X = L*R';
        rel_err = norm(X - X_star, 'fro')/norm_X_star;
        errors_ScaledGD(t) = rel_err;
        if ~isfinite(rel_err) || rel_err > 1e10 || rel_err <= 1e-15
            break;
        end

        % Gradient of 0.5*||A'*vec(X) - y||^2 with respect to X.
        Z = reshape(A*(A'*X(:) - y), n, n);

        % Simultaneous update: both new factors are computed from the
        % current (L, R) before either one is overwritten.
        L_next = L - eta*Z*R/(R'*R + damping);
        R_next = R - eta*Z'*L/(L'*L + damping);
        L = L_next;
        R = R_next;
    end
end
