clc; clear; close all;
% Helper routines used below (tprod/tran are not defined in this file;
% presumably they live under this folder -- confirm).
addpath(genpath('comparison_methods'));

%% Experiment configuration
tubal_r = 3;             % exact tubal rank of the ground-truth tensor
k = 3*tubal_r;           % over-estimated tubal rank used by the factor
n1 = 30; n2 = n1; n3 = 3;
ite = 10000;             % gradient-descent iterations per trial
repeat_time = 1;         % number of independent trials to average over

%% Run the trials and average the per-iteration curves
error_GD = zeros(ite,1);
loss_GD = zeros(ite,1);
for trial = 1:repeat_time
    [err_trial, loss_trial] = LRTR(n1,n2,n3,tubal_r,k,ite);
    error_GD = error_GD + err_trial;
    loss_GD = loss_GD + loss_trial;
end
error_GD = error_GD ./ repeat_time;
loss_GD = loss_GD ./ repeat_time;



%% Plot the averaged RSE and validation-loss curves
x = 1:ite;
% Roughly ten evenly spaced markers. The integer step keeps MarkerIndices
% a vector of positive integers even when ite is not a multiple of 10;
% for ite = 10000 this yields the same indices as 1:ite/10:ite.
marker_indices = 1:max(1,floor(ite/10)):ite;

hold on;

p1=plot(x,error_GD,'-s','MarkerIndices',marker_indices,'MarkerSize',18,'LineWidth',2,'Color',[189,30,30]/255);
p2=plot(x,loss_GD,'-+','MarkerIndices',marker_indices,'MarkerSize',18,'LineWidth',2,'Color',[55 103 149]/255);

% Early-stopping point: the iteration with the smallest validation loss.
[~, idx] = min(loss_GD);
x_min = x(idx);
y_min = loss_GD(idx);

plot(x_min, y_min, 'bo', 'MarkerSize', 16, 'MarkerFaceColor', 'r');

text(x_min, y_min, sprintf('Min validation loss: (%.0f, %.4f)', x_min, y_min), ...
     'VerticalAlignment', 'top', 'HorizontalAlignment', 'left', ...
     'FontSize', 16);

% RSE of the iterate chosen by early stopping (same iteration idx as the
% minimum validation loss). This is generally NOT the minimum RSE over all
% iterations, so the annotation says what is actually plotted.
y_rse = error_GD(idx);

plot(x_min, y_rse, 'ro', 'MarkerSize', 16, 'MarkerFaceColor', 'r');

text(x_min, y_rse, sprintf('RSE at early stop: (%.0f, %.4f)', x_min, y_rse), ...
     'VerticalAlignment', 'top', 'HorizontalAlignment', 'left', ...
     'FontSize', 16);

set(gca,'yscale','log');
% ylim([1e-5,1]);

% Bind the legend explicitly to the two curves so the annotation markers
% plotted above never receive legend entries.
lgd = legend([p1 p2], 'relative square error','validation loss');
lgd.FontSize = 14;
lgd.FontName = 'Times New Roman';

xlabel('Iterations','FontName', 'Times New Roman');
ylabel('Relative Error','FontName', 'Times New Roman');
% ylim([1e-10,10]);
% xlim([0,3e2])
box on
set(gca, 'FontSize', 12,'FontName', 'Times New Roman');
set(gca, 'LineWidth', 2);

set(gcf, 'Position', [100, 100, 600, 600]);





function [error_GD,loss_val]=LRTR(n1,n2,n3,tubal_r,k,ite)
%LRTR Low-tubal-rank tensor recovery by factored gradient descent from a
%small random initialization, recording a held-out validation loss so that
%early stopping can be applied afterwards.
%
%   [error_GD, loss_val] = LRTR(n1, n2, n3, tubal_r, k, ite)
%
%   Inputs
%     n1, n2, n3 : tensor dimensions. The iterate X = tprod(F, tran(F)) is
%                  n1 x n1 x n3, so n2 must equal n1.
%     tubal_r    : rank of the ground-truth factor U_star (n1 x tubal_r x n3).
%     k          : rank of the learned factor F (n1 x k x n3); may exceed
%                  tubal_r (over-parameterization).
%     ite        : number of gradient-descent iterations.
%
%   Outputs
%     error_GD   : ite x 1, ||X_star - X_t||_F^2 / ||X_star||_F^2 at each
%                  iteration (evaluated before the update).
%     loss_val   : ite x 1, ||A_validate * X_t(:) - y_validate||_2 at each
%                  iteration (unsquared, unnormalized validation residual).
%
%   Requires tprod and tran on the path (presumably the t-product and
%   tensor transpose -- not defined in this file).

if n1 ~= n2
    error('LRTR:dimMismatch', ...
        'LRTR requires n1 == n2 because X = F*F^T is n1 x n1 x n3 (got n1=%d, n2=%d).', n1, n2);
end

%% Ground truth X_star, Gaussian measurement matrix A and noisy measurements y
m = 1*tubal_r*(n1+n2-tubal_r)*n3;        % number of linear measurements
U_star = randn(n1,tubal_r,n3);
X_star = tprod(U_star,tran(U_star));
X_star = X_star/norm(X_star(:));         % unit Frobenius norm

A = normrnd(0,sqrt(1/m),m,n1*n2*n3);     % each row acts on vec(X)
s = normrnd(0,1e-3,m,1);                 % additive measurement noise
y = A*X_star(:) + s;

%% Gradient descent with small initialization and early-stopping bookkeeping
mu = 1e-2;                               % step size

% The first 90% of the measurements drive the gradient; the rest are held
% out to compute the validation loss.
train_size = floor(m*0.9);
A_train    = A(1:train_size,:);
A_validate = A(train_size+1:m,:);
y_train    = y(1:train_size);
y_validate = y(train_size+1:m);

% Tiny random initialization (overall scale 1e-10).
Ft = randn(n1,k,n3)/k^0.5*1e-10;

% Norms are taken on vectorized arrays: norm(X,'fro') on a 3-D array is
% rejected by older MATLAB releases. The denominator is loop-invariant.
err_den = norm(X_star(:))^2;
error_GD = zeros(ite,1);
loss_val = zeros(ite,1);
for i = 1:ite
    Xt = tprod(Ft,tran(Ft));
    error_GD(i) = norm(X_star(:) - Xt(:))^2/err_den;
    loss_val(i) = norm(A_validate*Xt(:) - y_validate);

    % Back-project the training residual into tensor form.
    res = A_train*Xt(:) - y_train;
    A_star = reshape(A_train'*res,[n1,n2,n3]);
    % NOTE(review): the exact gradient of 0.5*||A(F*F^T) - y||^2 w.r.t. F is
    % (A_star + tran(A_star)) * F; this update uses A_star * F only.
    % Confirm this simplification is intended.
    Ft = Ft - mu*tprod(A_star,Ft);
end

end





