clear; close all; clc;

% ---- Problem size -------------------------------------------------------
n        = 20;          % the ground-truth matrix is n-by-n
r        = 5;           % number of columns of the true factors
search_r = 10;          % number of columns of the factors being optimized
kappa    = 100;         % ratio between largest and smallest true singular value
m        = 10*n*r;      % number of linear measurements

% ---- Iteration control --------------------------------------------------
T          = 100;       % iterations per method
thresh_up  = 1e20;      % stop once the relative error exceeds this value
thresh_low = 1e-15;     % stop once the relative error falls below this value

% ---- Experiment sweep ---------------------------------------------------
repeat_times = 1;                   % number of repeated trials
noise_rates  = [1e-2, 1e-3, 1e-4];  % three different noise rates

% One line/marker style per noise rate
markers = {'-o', '-s', '-d'};

% Accumulators for the per-iteration error of each method
% (rows: iteration, columns: noise rate)
avg_errors_APGD     = zeros(T, numel(noise_rates));
avg_errors_ScaledGD = zeros(T, numel(noise_rates));
avg_errors_AltGD    = zeros(T, numel(noise_rates));

% Main experiment: for every noise rate and every repeated trial, draw a
% random low-rank sensing problem and run three factorized gradient methods
% on it, accumulating the relative Frobenius error of each iterate into the
% avg_errors_* matrices (row = iteration, column = noise rate).
for noise_idx = 1:length(noise_rates)
    noise_rate = noise_rates(noise_idx);

    for repeat = 1:repeat_times
        %% Ground truth and measurements
        % Orthonormal factors from the SVD of random +/-1 matrices.
        U_seed = sign(rand(n, r) - 0.5);
        [U_star, ~, ~] = svds(U_seed, r);
        V_seed = sign(rand(n, r) - 0.5);
        % Use V_seed here: taking the SVD of U_seed a second time would make
        % V_star identical to U_star and the ground truth always symmetric.
        [V_star, ~, ~] = svds(V_seed, r);

        % Stack the vectorized sensing matrices as the rows of A_mat so that
        % A_mat*X(:) yields all m inner products <A_k, X> at once.
        A_mat = zeros(m, n*n);
        for k = 1:m
            A_k = randn(n, n)/sqrt(m);
            A_mat(k, :) = A_k(:)';
        end

        % Singular values spaced linearly from 1 down to 1/kappa.
        sigma_star = linspace(1, 1/kappa, r);
        L_star = U_star*diag(sqrt(sigma_star));
        R_star = V_star*diag(sqrt(sigma_star));
        X_star = L_star*R_star';
        norm_X_star = norm(X_star, 'fro');

        % Noisy measurements; noise_rate*randn is N(0, noise_rate^2) noise
        % and needs no toolbox function.
        y = A_mat*X_star(:) + noise_rate*randn(m, 1);

        % Z = sum_k (<A_k, X> - y_k) * A_k, returned as an n-by-n matrix.
        grad_fn = @(X) reshape(A_mat'*(A_mat*X(:) - y), n, n);

        %% APGD (plotted as 'Ours'): alternating preconditioned updates
        L = randn(n, search_r)*1e-1;
        R = randn(n, search_r)*1e-1;
        alpha = 1e-10;   % damping added to the Gram matrices before solving
        eta = 1;
        for t = 1:T
            rel_err = norm(L*R' - X_star, 'fro')/norm_X_star;
            avg_errors_APGD(t, noise_idx) = avg_errors_APGD(t, noise_idx) + rel_err;

            if ~isfinite(rel_err) || rel_err > thresh_up || rel_err < thresh_low
                % Carry the last error forward so the remaining iterations
                % do not contribute zeros to the average.
                avg_errors_APGD(t+1:T, noise_idx) = avg_errors_APGD(t+1:T, noise_idx) + rel_err;
                break;
            end

            % Update L first, then recompute the residual and update R
            % using the already-updated L.
            Z = grad_fn(L*R');
            L = L - eta*Z*R/(R'*R + alpha*eye(search_r));

            Z = grad_fn(L*R');
            R = R - eta*Z'*L/(L'*L + alpha*eye(search_r));
        end

        %% ScaledGD: simultaneous preconditioned updates, fixed damping
        L = randn(n, search_r)*1e-1;
        R = randn(n, search_r)*1e-1;
        eta = 0.5;
        alpha = 1e-10;
        for t = 1:T
            rel_err = norm(L*R' - X_star, 'fro')/norm_X_star;
            avg_errors_ScaledGD(t, noise_idx) = avg_errors_ScaledGD(t, noise_idx) + rel_err;

            if ~isfinite(rel_err) || rel_err > thresh_up || rel_err < thresh_low
                avg_errors_ScaledGD(t+1:T, noise_idx) = avg_errors_ScaledGD(t+1:T, noise_idx) + rel_err;
                break;
            end

            Z = grad_fn(L*R');

            % Both factors are updated from the same (old) iterate.
            L_plus = L - eta*Z*R/(R'*R + alpha*eye(search_r));
            R_plus = R - eta*Z'*L/(L'*L + alpha*eye(search_r));

            L = L_plus;
            R = R_plus;
        end

        %% AltGD (plotted as 'PrecGD'): spectral init, decaying damping
        % Initialize from the top search_r singular triplets of sum_k y_k*A_k.
        Y = reshape(A_mat'*y, n, n);
        [U0, Sigma0, V0] = svds(Y, search_r);
        L = U0*sqrt(Sigma0);
        R = V0*sqrt(Sigma0);
        eta = 0.5;
        for t = 1:T
            rel_err = norm(L*R' - X_star, 'fro')/norm_X_star;
            avg_errors_AltGD(t, noise_idx) = avg_errors_AltGD(t, noise_idx) + rel_err;

            if ~isfinite(rel_err) || rel_err > thresh_up || rel_err < thresh_low
                avg_errors_AltGD(t+1:T, noise_idx) = avg_errors_AltGD(t+1:T, noise_idx) + rel_err;
                break;
            end

            Z = grad_fn(L*R');

            % Damping decays geometrically with the iteration count.
            alpha = 0.01*0.95^t;
            L_plus = L - eta*Z*R/(R'*R + alpha*eye(search_r));
            R_plus = R - eta*Z'*L/(L'*L + alpha*eye(search_r));

            L = L_plus;
            R = R_plus;
        end
    end
end

% Turn the accumulated errors into averages over the repeated trials
avg_errors_APGD = avg_errors_APGD / repeat_times;
avg_errors_ScaledGD = avg_errors_ScaledGD / repeat_times;
avg_errors_AltGD = avg_errors_AltGD / repeat_times;

% Plot the error curves: one color per method, one marker per noise rate
marker_indices = 1:10:T;   % draw a marker on every 10th iteration only
x = 1:T;
figure;
colors = {[189,30,30]/255, [135,207,164]/255, [253,185,106]/255}; % one RGB color per method

% Preallocate the line-handle arrays used to build the legend
p1 = gobjects(1, length(noise_rates));
p2 = gobjects(1, length(noise_rates));
p3 = gobjects(1, length(noise_rates));

% APGD plots
for noise_idx = 1:length(noise_rates)
    p1(noise_idx) = plot(x, avg_errors_APGD(:, noise_idx), markers{noise_idx}, 'MarkerIndices', marker_indices, 'MarkerSize', 8, 'LineWidth', 2, 'Color', colors{1});
    hold on;
end

% ScaledGD plots
for noise_idx = 1:length(noise_rates)
    p2(noise_idx) = plot(x, avg_errors_ScaledGD(:, noise_idx), markers{noise_idx}, 'MarkerIndices', marker_indices, 'MarkerSize', 8, 'LineWidth', 2, 'Color', colors{2});
end

% AltGD plots
for noise_idx = 1:length(noise_rates)
    p3(noise_idx) = plot(x, avg_errors_AltGD(:, noise_idx), markers{noise_idx}, 'MarkerIndices', marker_indices, 'MarkerSize', 8, 'LineWidth', 2, 'Color', colors{3});
end

set(gca, 'yscale', 'log');
% Legend entries are grouped by noise rate. The third method carries the
% same name ('PrecGD') at every noise level so one curve family is not
% shown under two different labels.
legend([p1(1), p2(1), p3(1), p1(2), p2(2), p3(2), p1(3), p2(3), p3(3)], ...
    'Ours (1e-2)', 'Scaled GD (1e-2)', 'PrecGD (1e-2)', ...
    'Ours (1e-3)', 'Scaled GD (1e-3)', 'PrecGD (1e-3)', ...
    'Ours (1e-4)', 'Scaled GD (1e-4)', 'PrecGD (1e-4)');
xlabel('Iterations');
ylabel('Relative Error');
title('Comparison of Methods with Different Noise Rates');
