clc;clear;close all;
addpath(genpath('comparison_methods'));

% Experiment: relative reconstruction error of three gradient-descent
% variants for low-tubal-rank tensor recovery (see LRTR below) as a
% function of the fraction of measurements used for training. Every point
% is the MEAN over repeat_time independent trials.

tubal_r=3; %% exact rank 
n1=30;n2=n1;n3=3; 
ite=2000;
repeat_time=20;

k = 5*tubal_r;                % width of the over-parameterised factor
train_ratio = 0.7:0.02:0.99;  % fraction of measurements used for training

% Named err_mean / train_ratio (instead of error / var) so the MATLAB
% built-ins error() and var() are not shadowed in the workspace.
err_mean = zeros(length(train_ratio),3);

for i =1:length(train_ratio)
    fprintf('ratio %d/%d (nu = %.2f)\n', i, length(train_ratio), train_ratio(i));
    for j =1:repeat_time
        err_mean(i,:) = err_mean(i,:) + LRTR(n1,n2,n3,tubal_r,k,ite,train_ratio(i));
    end
end
% Average over trials: the y-axis is a relative error, not a sum of
% repeat_time errors.
err_mean = err_mean / repeat_time;


marker_indices = 1:2:length(train_ratio);
plot_opts = {'MarkerIndices',marker_indices,'MarkerSize',18,'LineWidth',2};

plot(train_ratio,err_mean(:,1),':o',plot_opts{:},'Color',[255,0,0]/255);
hold on;
plot(train_ratio,err_mean(:,2),'-+',plot_opts{:},'Color',[189,30,30]/255);
plot(train_ratio,err_mean(:,3),'-d',plot_opts{:},'Color',[255,190,122]/255);

set(gca,'yscale','log');
set(gca,'xscale','log');
lgd = legend('Optimal error','Small random ini (best)','Small random ini (ES)');
lgd.FontSize = 14;
lgd.FontName = 'Times New Roman';  

% "\ " keeps the space between the words; LaTeX math mode drops plain spaces.
xlabel('$\mathrm{Train\ set/All}\ \nu$','FontName', 'Times New Roman', 'Interpreter', 'latex');
ylabel('Relative Square Error','FontName', 'Times New Roman');

box on
set(gca, 'FontSize', 12,'FontName', 'Times New Roman');         
set(gca, 'LineWidth', 2);       
set(gcf, 'Position', [100, 100, 600, 600]);  







function error=LRTR(n1,n2,n3,tubal_r,k,ite,var,mu)
%LRTR One trial of low-tubal-rank tensor recovery by factored gradient descent.
%   error = LRTR(n1,n2,n3,tubal_r,k,ite,var) draws a ground truth
%   X_star = tprod(U_star, tran(U_star)) with U_star = randn(n1,tubal_r,n3),
%   scales it to unit Frobenius norm, observes it through m Gaussian
%   measurements y = A*X_star(:) + s (noise std 1e-3), and runs
%       F <- F - mu * tprod(reshape(A'*(A*vec(F*F') - y), [n1,n1,n3]), F)
%   for ite iterations, starting from a tiny random factor, in three
%   settings.
%
%   error = LRTR(..., mu) uses step size mu (default 1e-1).
%
%   Inputs
%     n1, n3  - X_star is n1 x n1 x n3.
%     n2      - enters only the measurement count m; X_star is square in
%               its first two modes, so callers presumably pass n2 = n1.
%     tubal_r - number of columns of U_star.
%     k       - number of columns of the over-parameterised factor; also
%               used in the initial scale 1e-10/sqrt(k).
%     ite     - gradient-descent iterations per run.
%     var     - fraction of the m measurements used for training in the
%               early-stopping run (the remainder is the validation set).
%     mu      - optional step size, default 1e-1.
%
%   Output
%     error   - 1x3 row of relative squared errors
%               ||X_star - Xt||_F^2 / ||X_star||_F^2:
%               (1) minimum over iterations, factor with tubal_r columns;
%               (2) minimum over iterations, factor with k columns;
%               (3) k-column factor, trained on the first floor(m*var)
%                   measurements, at the iteration with the smallest
%                   validation residual.
if nargin < 8 || isempty(mu)
    mu = 1e-1; % previously hard-coded step size
end

%% init X_*, A_i and y
m = 5*tubal_r*(n1+n2-tubal_r)*n3;      % number of linear measurements
U_star = randn(n1,tubal_r,n3);
X_star = tprod(U_star,tran(U_star));
X_star = X_star/norm(X_star(:));       % unit Frobenius norm
% X_star = data_gen(n1,n2,n3,kappa,tubal_r);

% One column of A per entry of X_star (n1*n1*n3 entries). The former size
% n1*n2*n3 only matched when n2 == n1; numel yields the identical draw then.
A = normrnd(0,sqrt(1/m),m,numel(X_star));
s = normrnd(0,1e-3,m,1);               % measurement noise
y = A*X_star(:) + s;

%% optimal error: factor with exactly tubal_r columns, best iterate
F0 = randn(n1,tubal_r,n3)/k^0.5*1e-10;
err_exact = lrtr_gd(F0, A, y, [], [], X_star, mu, ite);
error_small_ini_optimal = min(err_exact);

%% Gradient descent with small ini: k-column factor, best iterate
F0 = randn(n1,k,n3)/k^0.5*1e-10;
err_over = lrtr_gd(F0, A, y, [], [], X_star, mu, ite);
error_small_ini = min(err_over);

%% GD with small ini and early stopping on a held-out validation split
train_size = floor(m*var);
% An empty validation set would make every loss_val zero and silently pick
% iteration 1; an empty training set makes GD meaningless. (assert is used
% because the output variable named `error` hides the error() function.)
assert(train_size >= 1 && train_size < m, 'LRTR:badSplit', ...
    'var = %g leaves an empty training or validation set (m = %d).', var, m);
train_idx = 1:train_size;
val_idx = train_size+1:m;

F0 = randn(n1,k,n3)/k^0.5*1e-10;
[err_es, loss_val] = lrtr_gd(F0, A(train_idx,:), y(train_idx), ...
    A(val_idx,:), y(val_idx), X_star, mu, ite);
[~,idx] = min(loss_val);
error_small_ini_val = err_es(idx);

error = [error_small_ini_optimal,error_small_ini,error_small_ini_val];
end


function [rel_err, loss_val] = lrtr_gd(Ft, A_fit, y_fit, A_val, y_val, X_star, mu, ite)
% Run ite factored-GD steps on the measurements (A_fit, y_fit), starting
% from Ft. Before each step, record the relative squared error of
% Xt = tprod(Ft, tran(Ft)) to X_star and, when A_val is non-empty, the
% validation residual norm ||A_val*Xt(:) - y_val||.
[n1,~,n3] = size(X_star);
ref = norm(X_star(:))^2;
rel_err = zeros(ite,1);
loss_val = zeros(ite,1);
for t = 1:ite
    Xt = tprod(Ft,tran(Ft));
    rel_err(t) = norm(X_star(:)-Xt(:))^2/ref;
    if ~isempty(A_val)
        loss_val(t) = norm(A_val*Xt(:) - y_val);
    end
    res = A_fit*Xt(:) - y_fit;
    A_star = reshape(A_fit'*res,[n1,n1,n3]);
    Ft = Ft - mu*tprod(A_star,Ft);
end
end





