clc;clear;close all;
addpath(genpath('comparison_methods'));

% Sweep the over-parameterized factor width R (= k) and plot the relative
% squared recovery error of factored GD for five initialization strategies,
% averaged over repeat_time independent trials (one LRTR call per trial).

tubal_r=3; %% exact rank
n1=30;n2=n1;n3=3;
ite=5000;
repeat_time=20;

k = tubal_r:3:n1;   % over-parameterized widths R to sweep
noise_var = 1e-3;   % passed to LRTR as its `var` argument

% err_mean(i,c): mean error of method c at width k(i).
% (Not named `error`, which would shadow MATLAB's built-in error function.)
err_mean = zeros(length(k),5);
for i = 1:length(k)
    for j = 1:repeat_time
        err_mean(i,:) = err_mean(i,:) + LRTR(n1,n2,n3,tubal_r,k(i),ite,noise_var);
    end
end
% The loop above accumulates a sum over trials; turn it into the mean.
err_mean = err_mean / repeat_time;

x = k;
line_specs  = {':o','-*','-s','-+','-d'};
line_colors = [255,0,0; 55,103,149; 150,195,125; 189,30,30; 255,190,122]/255;
for c = 1:numel(line_specs)
    plot(x,err_mean(:,c),line_specs{c},'MarkerSize',18,'LineWidth',2,'Color',line_colors(c,:));
    hold on
end

set(gca,'yscale','log');
lgd = legend('Optimal error','Spectral ini','Large random ini','Small random ini (best)','Small random ini (ES)');
lgd.FontSize = 14;
lgd.FontName = 'Times New Roman';

% With the LaTeX interpreter, \mathrm is only valid inside $...$; plain text
% outside math mode already renders upright, so no \mathrm is needed.
xlabel('over rank $R$','FontName', 'Times New Roman', 'Interpreter', 'latex');
ylabel('Relative Square Error','FontName', 'Times New Roman');

box on
set(gca, 'FontSize', 12,'FontName', 'Times New Roman');
set(gca, 'LineWidth', 2);
set(gcf, 'Position', [100, 100, 600, 600]);







function errs=LRTR(n1,n2,n3,tubal_r,k,ite,var)
%LRTR  One random trial of tubal low-rank tensor recovery by factored GD.
%   errs = LRTR(n1,n2,n3,tubal_r,k,ite,var) builds a ground truth
%   X_* = U*U^T (t-product, U of size n1 x tubal_r x n3), scaled to unit
%   Frobenius norm. It takes m = 5*tubal_r*(n1+n2-tubal_r)*n3 linear
%   measurements y = A*vec(X_*) + noise with A having i.i.d. N(0,1/m) entries,
%   and runs ite gradient steps on a factor F (X_t = F*F^T) from several
%   initializations.
%
%   Inputs
%     n1,n2,n3 : tensor dimensions; n2 must equal n1 (vec(X_*) has
%                n1*n1*n3 entries while A has n1*n2*n3 columns)
%     tubal_r  : tubal rank of the ground truth
%     k        : tubal width of the over-parameterized factor
%     ite      : number of GD iterations per run
%     var      : currently unused -- the measurement noise is exponential
%                with rate 1000 (mean 1e-3). Kept for call compatibility.
%
%   Output
%     errs : 1x5 row of relative squared errors ||X_t-X_*||_F^2/||X_*||_F^2
%       1) tubal_r-column factor, tiny init, best over iterations
%       2) spectral init with k columns, best over iterations
%       3) moderate random init (scale 5/sqrt(n1), step 5e-4), best
%       4) tiny random init (scale 1e-10), best over iterations
%       5) tiny random init with early stopping: error at the iteration
%          with the smallest residual on a held-out 5% of the measurements
%
%   Relies on tprod, tran and tsvd_r from the path (presumably the
%   comparison_methods folder -- confirm).

if n2 ~= n1
    error('LRTR:dims', 'LRTR requires n2 == n1 because X_* = U*U^T is n1 x n1 x n3.');
end

%% ground truth X_*, measurement operator A and observations y
m = 5*tubal_r*(n1+n2-tubal_r)*n3;
U_star = randn(n1,tubal_r,n3);
X_star = tprod(U_star,tran(U_star));
X_star = X_star/norm(X_star(:));

noise = generate_exponential_vector(m,1000);   % exponential noise, mean 1e-3
A = normrnd(0,sqrt(1/m),m,n1*n2*n3);           % entries N(0, 1/m)
% Gaussian alternative (not used by the experiment):
%   noise = normrnd(0,var,m,1);
y = A*X_star(:) + noise;

mu = 1e-1;      % step size for every run except the moderate-init one
tiny = 1e-10;   % scale of the "small" random initializations

%% reference: exact-rank factor (tubal_r columns), tiny init
err_optimal = min(factored_gd(randn(n1,tubal_r,n3)*tiny, A, y, X_star, mu, ite));

%% spectral initialization with k columns
Y = reshape(A'*y,[n1,n2,n3]);
[U,S,~] = tsvd_r(Y,k);
S = ifft(fft(S,[],3).^0.5,[],3);   % elementwise sqrt of S in the Fourier domain
F0 = tprod(U(:,1:k,:),S(1:k,1:k,:));
err_spectral = min(factored_gd(F0, A, y, X_star, mu, ite));

%% moderate random initialization (needs a much smaller step size)
F0 = randn(n1,k,n3)/n1^0.5*5;
err_moderate = min(factored_gd(F0, A, y, X_star, 5e-4, ite));

%% tiny random initialization, best error over all iterations
err_small = min(factored_gd(randn(n1,k,n3)*tiny, A, y, X_star, mu, ite));

%% tiny random initialization with validation-based early stopping
train_size = floor(m*0.95);
tr = 1:train_size;
va = train_size+1:m;
[err_traj, loss_val] = factored_gd(randn(n1,k,n3)*tiny, A(tr,:), y(tr), ...
                                   X_star, mu, ite, A(va,:), y(va));
[~,idx] = min(loss_val);
err_small_es = err_traj(idx);

errs = [err_optimal, err_spectral, err_moderate, err_small, err_small_es];
end

function [err_traj, val_loss] = factored_gd(Ft, A, y, X_star, mu, ite, A_val, y_val)
%FACTORED_GD  Run ite steps of F <- F - mu * tprod(reshape(A'*(A*vec(F*F^T)-y)), F).
%   err_traj(i) is ||X_i - X_*||^2 / ||X_*||^2 for the iterate X_i = F*F^T
%   evaluated BEFORE the i-th update. When A_val and y_val are given,
%   val_loss(i) is ||A_val*vec(X_i) - y_val|| at the same point; otherwise
%   val_loss is [].
track_val = nargin >= 8;
dims = size(X_star);
ref_sq = norm(X_star(:))^2;   % loop-invariant denominator
err_traj = zeros(ite,1);
if track_val
    val_loss = zeros(ite,1);
else
    val_loss = [];
end
for it = 1:ite
    Xt = tprod(Ft,tran(Ft));
    err_traj(it) = norm(X_star(:)-Xt(:))^2 / ref_sq;
    if track_val
        val_loss(it) = norm(A_val*Xt(:) - y_val);
    end
    res = A*Xt(:) - y;
    Ft = Ft - mu * tprod(reshape(A'*res,dims),Ft);
end
end

function x = generate_exponential_vector(m, lambda)
%GENERATE_EXPONENTIAL_VECTOR  Column vector of i.i.d. exponential samples.
%   x = GENERATE_EXPONENTIAL_VECTOR(m, lambda) returns an m-by-1 vector drawn
%   from the exponential distribution with rate lambda (mean 1/lambda).
%   x = GENERATE_EXPONENTIAL_VECTOR(m) uses rate 1.
%   Uses exprnd (Statistics and Machine Learning Toolbox), which is
%   parameterized by the mean, hence the reciprocal below.

if nargin >= 2
    mean_val = 1 / lambda;
else
    mean_val = 1;
end
x = exprnd(mean_val, m, 1);
end



