clc; clear; close all;
addpath(genpath('comparison_methods'));

%% Experiment: recovery error vs. number of measurements for several GD initialisations
% For each sampling rate, LRTR is run repeat_time times and its five error
% measures are averaged. Results are kept in err_mean / noise_std rather
% than `error` / `var` so the MATLAB built-ins of those names stay usable.

tubal_r = 3;                 %% exact rank
n1 = 30; n2 = n1; n3 = 3;    % LRTR builds an n1 x n1 x n3 ground truth, so n2 == n1
ite = 20000;                 % gradient-descent iterations per run
repeat_time = 20;            % independent trials averaged per sampling rate

k = 3*tubal_r;               % width of the over-parameterized factor
noise_std = 1e-3;            % noise level handed to LRTR (used there as a standard deviation)

% LRTR draws m = rate * tubal_r*(n1+n2-tubal_r)*n3 measurements.
rate = 1:0.5:5;
err_mean = zeros(numel(rate), 5);
for i = 1:numel(rate)
    for j = 1:repeat_time
        err_mean(i,:) = err_mean(i,:) + LRTR(n1,n2,n3,tubal_r,k,ite,noise_std,rate(i));
    end
end
% Report the mean over trials (the curves are labelled as an error, not a sum).
err_mean = err_mean / repeat_time;

%% Plot one curve per method
markers = {':o', '-*', '-s', '-+', '-d'};
colors  = [255   0   0;
            55 103 149;
           150 195 125;
           189  30  30;
           255 190 122] / 255;
for c = 1:size(err_mean, 2)
    plot(rate, err_mean(:,c), markers{c}, 'MarkerSize', 18, 'LineWidth', 2, 'Color', colors(c,:));
    hold on
end

set(gca,'yscale','log');
lgd = legend('Optimal error','Spectral ini','Large random ini','Small random ini (best)','Small random ini (ES)');
lgd.FontSize = 14;
lgd.FontName = 'Times New Roman';

% The swept quantity is the oversampling factor of the measurement count,
% not the factor width (over rank k is fixed above).
xlabel('Sampling rate $m/[r(n_1+n_2-r)n_3]$','FontName', 'Times New Roman', 'Interpreter', 'latex');
ylabel('Relative Square Error','FontName', 'Times New Roman');

box on
set(gca, 'FontSize', 12,'FontName', 'Times New Roman');
set(gca, 'LineWidth', 2);
set(gcf, 'Position', [100, 100, 600, 600]);







function err = LRTR(n1,n2,n3,tubal_r,k,ite,var,rate,mu)
%LRTR One random trial of low-tubal-rank tensor recovery by factored GD.
%   err = LRTR(n1,n2,n3,tubal_r,k,ite,var,rate) draws a random ground truth
%   X_* = tprod(U, tran(U)) with U of size n1 x tubal_r x n3, normalised to
%   unit norm, and m = round(rate*tubal_r*(n1+n2-tubal_r)*n3) noisy linear
%   measurements y = A*X_*(:) + noise. A has i.i.d. N(0, 1/m) entries.
%   It then runs ite gradient-descent steps on a factor F from several
%   initialisations. For each run it tracks the relative squared error
%   ||X_* - F*F'||^2 / ||X_*||^2 and returns a 1x5 row:
%     err(1) exact-width (tubal_r) factor, small random init, best iterate
%     err(2) width-k factor, spectral init, best iterate
%     err(3) width-k factor, large (x10) random init, best iterate
%     err(4) width-k factor, small (x1e-10) random init, best iterate
%     err(5) width-k factor, small random init, iterate picked by early
%            stopping on a 5% hold-out split of the measurements
%
%   Inputs:
%     n1, n2, n3 - tensor size; X_* is n1 x n1 x n3, so n2 must equal n1
%     tubal_r    - tubal rank of X_*
%     k          - width of the over-parameterized factor
%     ite        - GD iterations per run
%     var        - noise STANDARD DEVIATION (it scales randn; it is not a variance)
%     rate       - oversampling factor that sets the measurement count m
%     mu         - (optional) GD step size, default 1e-2
%
%   NOTE(review): tprod / tran / tsvd_r come from comparison_methods and are
%   presumably the t-product, tensor transpose and truncated t-SVD. Confirm there.

if nargin < 9 || isempty(mu)
    mu = 1e-2;
end
if n2 ~= n1
    error('LRTR:nonSquare', ...
        'LRTR recovers a symmetric n1 x n1 x n3 tensor, so n2 (%d) must equal n1 (%d).', n2, n1);
end

%% init X_*, A and y
% rate can be fractional (e.g. 1.5), and a non-integer m is not a valid array size.
m = round(rate*tubal_r*(n1+n2-tubal_r)*n3);
U_star = randn(n1,tubal_r,n3);
X_star = tprod(U_star,tran(U_star));
X_star = X_star/norm(X_star(:));

A = sqrt(1/m) * randn(m, n1*n1*n3);   % same draw as normrnd(0,sqrt(1/m),...)
y = A*X_star(:) + var*randn(m,1);     % additive Gaussian noise with std = var

%% 1) exact-width factor, small random init (reference "optimal" error)
F0 = randn(n1,tubal_r,n3)/k^0.5*1e-10;
err_opt = min(run_gd(F0, A, y, X_star, mu, ite));

%% 2) spectral init from the back-projection A'*y
Y = reshape(A'*y, [n1,n1,n3]);
[U,S,~] = tsvd_r(Y,k);
% Element-wise square root of S taken in the Fourier domain along mode 3.
S = ifft(fft(S,[],3).^0.5, [], 3);
F0 = tprod(U(:,1:k,:), S(1:k,1:k,:));
err_spec = min(run_gd(F0, A, y, X_star, mu, ite));

%% 3) small random init, best iterate
F0 = randn(n1,k,n3)/k^0.5*1e-10;
err_small = min(run_gd(F0, A, y, X_star, mu, ite));

%% 4) large random init, best iterate
F0 = randn(n1,k,n3)/k^0.5*10;
err_large = min(run_gd(F0, A, y, X_star, mu, ite));

%% 5) small random init with early stopping on a hold-out split
train_size = floor(m*0.95);   % always < m, so the validation set is non-empty
tr = 1:train_size;
va = train_size+1:m;
F0 = randn(n1,k,n3)/k^0.5*1e-10;
[err_hist, loss_val] = run_gd(F0, A(tr,:), y(tr), X_star, mu, ite, A(va,:), y(va));
[~, idx] = min(loss_val);      % iterate with the smallest validation residual
err_es = err_hist(idx);

err = [err_opt, err_spec, err_large, err_small, err_es];
end


function [err_hist, loss_val] = run_gd(Ft, A, y, X_star, mu, ite, A_val, y_val)
%RUN_GD Run ite factored gradient-descent steps starting from factor Ft.
%   Each step forms Xt = tprod(Ft, tran(Ft)), records
%   err_hist(i) = ||X_star - Xt||^2 / ||X_star||^2, and then updates
%   Ft <- Ft - mu * tprod(reshape(A'*(A*Xt(:) - y)), Ft).
%   If A_val and y_val are supplied, loss_val(i) = ||A_val*Xt(:) - y_val||
%   is evaluated at the same iterate. Otherwise loss_val stays all zeros.
[n1, ~, n3] = size(Ft);
ref_sq = norm(X_star(:))^2;   % loop-invariant normaliser
track_val = nargin >= 8;
err_hist = zeros(ite,1);
loss_val = zeros(ite,1);
for i = 1:ite
    Xt = tprod(Ft,tran(Ft));
    err_hist(i) = norm(X_star(:) - Xt(:))^2 / ref_sq;
    if track_val
        loss_val(i) = norm(A_val*Xt(:) - y_val);
    end
    grad = reshape(A'*(A*Xt(:) - y), [n1,n1,n3]);
    Ft = Ft - mu*tprod(grad,Ft);
end
end





