%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%% Clear previous work
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
clear; close('all'); clc;

% Put project helpers and the Tensorlab folder on the search path. Each
% addpath call inserts its folder at the top, so '../../tensorlab/' ends up
% ahead of 'utils'.
addpath('utils');
addpath('../../tensorlab/');
% fixed seed so repeated runs reproduce the same random draws
rng(pi);

%%
% Problem dimensions and Monte Carlo settings.
I   = 3;        % number of modes, i.e. of factor (measurement) matrices
SNR = 0:5:25;   % SNR levels in dB
AVG = 50;       % Monte Carlo trials per SNR level

% Per-mode sizes: A{i} is later drawn as M(i) x N(i).
N = repmat(15, I, 1);
Nbar = prod(N);
M = repmat(12, I, 1);
Mbar = prod(M);

K = 8; % s-sparse, #nonzeros in total

% Metric arrays indexed (trial, SNR level, algorithm) for the 9 algorithms.
% NOTE: the variable name `error` shadows MATLAB's built-in error() in this
% script's workspace; it is kept because it is stored in the saved .mat file.
[error, time, srr] = deal(zeros(AVG, numel(SNR), 9));

%%
for avg = 1:AVG

    fprintf('Monte Carlo trial %d/%d\n', avg, AVG); % progress

    % Random Gaussian measurement matrix per mode; A{i} is M(i) x N(i).
    % Columns are used as drawn (they are not normalized).
    A = cell(1,I);
    for i = 1:I
        A{i} = randn(M(i),N(i));
    end

    % Random sparse ground-truth tensor (K nonzeros, see the definition of K).
    tX = generate_sparse_tensor(N, K);
    % Noiseless measurements: tX multiplied by A{i} along every mode i.
    tY_ori     = tmprod(tX,A,1:I);
    control.tX = tX; % for debugging

    % Trial-level quantities that do not depend on the SNR level.
    x_true       = vec(tX);
    signal_power = norm(vec(tY_ori))^2/numel(tY_ori); % average signal power per entry
    tol_coef     = 0.05;

    for snr = 1:length(SNR)
        % Additive white Gaussian noise scaled to the target SNR (in dB).
        n_var = signal_power/(10^(SNR(snr)/10));
        noise = sqrt(n_var)*randn(size(tY_ori));
        tY    = tY_ori + noise;
        y_mea = vec(tY);

        % --- sparse estimation; the call order is kept fixed so that any
        % --- randomness consumed inside the solvers stays reproducible.
        % using mssbl: sequential
        result_sbl_se = tenMulReSBL(tY,A,N,M);
        disp('Multi-Stage SBL is done!')
        % using msomp
        result_msomp = tenMulReOMP(tY, A, N, M, tol_coef);
        disp('Multi-Stage OMP is done!')
        % Benchmarks operating on the vectorized measurements
        % using IHT
        result_iht = iht(A,y_mea,K);
        disp('IHT is done!')
        % using HTP
        result_htp = htp(A,y_mea,K);
        disp('HTP is done!')
        % using OMP
        result_omp = omp(A, y_mea, tol_coef);
        disp('OMP is done!')
        % using sbl
        result_sbl = sbl(A, y_mea, 1e-3, 1);
        disp('SBL is done!')
        % using l1: build the full Kronecker dictionary
        % kron(A{I}, ..., kron(A{2}, A{1})) for any number of modes I.
        % Dictionary construction is included in the timed region.
        tic;
        A_all = A{1};
        for i = 2:I
            A_all = kron(A{i}, A_all);
        end
        D_l1 = sensingDictionary('CustomDictionary', A_all);
        result_l1 = cell(2,1);
        result_l1{1} = basisPursuit(D_l1, y_mea, maxIterations=50);
        result_l1{2} = toc;
        disp('L1 is done!')
        % using msiht: sequential
        result_msiht = tenMulReIHT(tY,A,N,M,K*ones(I,1));
        disp('Multi-Stage IHT is done!')
        % using mshtp: sequential
        result_mshtp = tenMulReHTP_hi(tY,A,N,M,K*ones(I,1));
        disp('Multi-Stage HTP is done!')

        % --- track results. The column order must match algo_name in the
        % --- plotting section: each result is {estimate, elapsed time}.
        results = {result_sbl_se, result_msomp, result_iht, result_htp, ...
                   result_omp, result_sbl, result_l1, result_msiht, result_mshtp};
        for k = 1:numel(results)
            % vec() the estimate so tensor- and vector-shaped outputs are
            % compared element-wise with the column vec(tX) (a row-vector
            % estimate would otherwise broadcast into a matrix).
            x_hat = vec(results{k}{1});
            error(avg,snr,k) = (norm(x_hat - x_true)/norm(x_true))^2; % NSE
            time(avg,snr,k)  = results{k}{2};
            srr(avg,snr,k)   = recover_rate(x_hat, x_true);
        end
    end
end

% Persist every workspace variable (metric arrays, settings and the last
% trial's data) to the results file.
save('compare_trad_s_sparse.mat');

%%
% Reload the saved workspace, so the plotting section below can also be run
% on its own from a previously saved file.
load('compare_trad_s_sparse.mat');

% Per-algorithm plot styling. Row/entry k belongs to algorithm k, in the
% same order as the third dimension of error/time/srr.
all_colors = [
    0,      0.4470, 0.7410;   % MSSBL
    0.9290, 0.6940, 0.1250;   % MSOMP
    1,      0,      1;        % IHT
    0.6350, 0.0780, 0.1840;   % HTP
    0.4900, 0.1800, 0.5600;   % OMP
    0.5,    0.5,    0.5;      % SBL
    0.8500, 0.3250, 0.0980;   % BPDN
    0.3010, 0.7450, 0.9330;   % MSIHT
    0.4660, 0.6740, 0.1880;   % MSHTP
];

% Line style + marker for the median curves.
line_type_set   = {'-o', '-x', '-.>', '-.h', '-<', '-^', '-v', '-.+', '-.*'};
% Marker-only specifiers (same markers as above) for the legend entries.
legend_type_set = {'o', 'x', '>', 'h', '<', '^', 'v', '+', '*'};

% Display names, in result-column order.
algo_name = {'MSSBL', 'MSOMP', 'IHT', 'HTP', 'OMP', 'SBL', 'BPDN', 'MSIHT', 'MSHTP'};

% Per-SNR summary statistics across the Monte Carlo trials.
% Each raw array is AVG x numel(SNR) x nAlgo. Reducing over dimension 1
% leaves a leading singleton dimension; shiftdim(.,1) moves it to the end,
% which always yields a numel(SNR) x nAlgo matrix. squeeze would instead
% return an nAlgo x 1 column when SNR has a single entry, and that breaks
% the (:, algo_index) indexing used by the plots.
acrossTrials = @(stat) shiftdim(stat, 1);

timeAVG = acrossTrials(mean(time, 1));

errorTotal = acrossTrials(median(error, 1));
errorQ1 = acrossTrials(quantile(error, 0.25, 1));
errorQ3 = acrossTrials(quantile(error, 0.75, 1));

% Time statistics (currently used only by the commented-out time plot).
timeTotal = acrossTrials(median(time, 1));
timeQ1 = acrossTrials(quantile(time, 0.25, 1));
timeQ3 = acrossTrials(quantile(time, 0.75, 1));

srrTotal = acrossTrials(median(srr, 1));
srrQ1 = acrossTrials(quantile(srr, 0.25, 1));
srrQ3 = acrossTrials(quantile(srr, 0.75, 1));


fontsizeman = 20;
figure_position = [100, 100, 800, 600];
% Root defaults apply to every line/axes created afterwards.
set(0,'DefaultLineLineWidth',3);
set(0,'DefaultLineMarkerSize',fontsizeman);
set(0,'DefaultAxesFontSize',fontsizeman);
set(0,'DefaultAxesFontWeight','bold');

n_algo = numel(algo_name);

f_nmse_srr = figure('Position', figure_position);
% tiledlayout(1, 2, 'TileSpacing', 'compact');

% With no explicit tiledlayout, nexttile creates a single-tile layout.
ax1 = nexttile;
hold(ax1, 'on');

% Inter-quartile band (Q1..Q3) plus median curve for each algorithm. The
% band polygon runs left-to-right along Q1 and back right-to-left along Q3.
fill_x = [SNR, fliplr(SNR)];
for algo_index = 1:n_algo
    lower_bound = errorQ1(:, algo_index);
    upper_bound = errorQ3(:, algo_index);
    fill_y = [lower_bound; flipud(upper_bound)].';
    fill(ax1, fill_x, fill_y, all_colors(algo_index, :), 'FaceAlpha', 0.15, 'EdgeColor', 'none');
    plot(ax1, SNR, errorTotal(:,algo_index), line_type_set{algo_index}, 'Color', all_colors(algo_index, :));
end

% Marker-only copies of the median curves; only these handles are passed to
% the legend, so the legend shows markers without line styles.
h = gobjects(1, n_algo);
for algo_index = 1:n_algo
    h(algo_index) = plot(ax1, SNR, errorTotal(:,algo_index), legend_type_set{algo_index}, ...
        'Color', all_colors(algo_index, :), 'DisplayName', algo_name{algo_index});
end

% XLim matches the simulated range SNR = 0:5:25.
set(ax1, 'YScale', 'log', 'XLim', [0 25], 'YLim', [1e-5 1e3], 'FontSize', fontsizeman, 'FontWeight', 'bold');
ylabel(ax1, 'NSE', 'FontSize', fontsizeman, 'FontWeight', 'bold');
xlabel(ax1, 'SNR (dB)')
yticks(ax1, [1e-5 1e-4 1e-3 1e-2 1e-1 1e0 1e1 1e2])
grid(ax1, 'on');
box(ax1, 'on');

% ax2 = nexttile;
% hold(ax2, 'on');
% for algo_index = 1:n_algo
% 
%     upper_bound = timeQ3(:, algo_index);
%     lower_bound = timeQ1(:, algo_index);
% 
%     fill_x = [SNR'; fliplr(SNR)']';
%     fill_y = [lower_bound; fliplr(upper_bound')']';
%     fill(ax2, fill_x, fill_y, all_colors(algo_index, :), 'FaceAlpha', 0.15, 'EdgeColor', 'none');
%     plot(ax2, SNR,timeTotal(:,algo_index),line_type_set{algo_index},'Color',all_colors(algo_index, :));
% end
% hold(ax2, 'off');
% set(ax2, 'XLim', [0 25], 'FontSize', fontsizeman, 'FontWeight', 'bold', 'Yscale', 'Log');
% ylabel(ax2, 'Time', 'FontSize', fontsizeman, 'FontWeight', 'bold');
% grid(ax2, 'on');
% box(ax2, 'on');
% xlabel(ax2, 'SNR (dB)', 'FontSize', fontsizeman, 'FontWeight', 'bold');

lgd = legend(h, algo_name, 'Location','northoutside','Orientation','horizontal','Interpreter','LaTex','Box','off','NumColumns',5);
% lgd.Position = [0.1, 0.9, 1, 0.15];

exportgraphics(f_nmse_srr, 's_stan_nmse_com_trad.pdf', 'ContentType', 'vector');
% %%
% tiledlayout(1, 3, 'TileSpacing', 'compact');
% ax1 = nexttile;
% imagesc(abs(result_sbl_se{1}))
% ax2 = nexttile;
% imagesc(abs(reshape(result_htp{1},[40,40])))
% ax3 = nexttile;
% imagesc(abs(reshape(result_sbl{1},[40,40])))
