%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%% Clear previous work
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Start from a clean session: empty workspace, no open figures, clear console.
clear
close('all')
clc

% Project helpers and the tensorlab toolbox; both paths are resolved
% relative to the current working folder, so run this script from its folder.
addpath('utils')
addpath('../../tensorlab/')

% Fixed seed so every run draws the same dictionaries, supports and noise.
rng(pi);

%% generate synthetic data for testing
I   = 3;    % number of tensor modes, i.e. number of factor matrices A{i}
AVG = 50;   % number of Monte Carlo trials (averaged over dim 1 of the results)

% Per-mode sizes: N(i) signal entries and M(i) measurements along mode i.
N = repmat(15, I, 1);   Nbar = prod(N);
M = repmat(12, I, 1);   Mbar = prod(M);

% Kronecker-structured support: number of nonzeros along each mode.
K = repmat(4, I, 1);

% SNR grid (dB) swept for every trial.
SNR = 0:5:25;

% Result arrays indexed (trial, SNR level, estimator slot) with 11 slots.
% Note: the variable name `error` shadows MATLAB's built-in error() in this
% script; it is kept because the saved .mat file and the plots use it.
[error, time, srr] = deal(zeros(AVG, length(SNR), 11));

for avg = 1:AVG

    % progress report
    fprintf('avg = %d / %d\n', avg, AVG);

    % Gaussian measurement dictionary per mode: A{i} is M(i) x N(i).
    % NOTE(review): an earlier version also computed vecnorm(A{i}) here and
    % discarded it. If column-normalized dictionaries were intended, add
    % A{i} = A{i}./vecnorm(A{i}); -- confirm against the experiment design.
    A = cell(1,I);
    for i = 1:I
        A{i} = randn(M(i),N(i));
    end

    % Ground-truth sparse tensor with Kronecker-structured support
    % (K(i) nonzeros per mode, per the definition of K).
    tX = generate_kro_supp_sparse_tensor(N, K);
    % Noiseless measurements: tX multiplied by A{i} along every mode i.
    tY_ori     = tmprod(tX,A,1:I);
    control.tX = tX; % for debugging

    % ---- SNR-independent quantities: computed once per trial ------------
    % Explicit dictionary kron(A{I}, ..., A{1}) for the vectorized
    % benchmarks, so that vec(tY_ori) = A_full*vec(tX). Built left-to-right
    % so that for I = 3 it is exactly kron(kron(A{3},A{2}),A{1}).
    A_full = A{I};
    for i = I-1:-1:1
        A_full = kron(A_full, A{i});
    end
    % Dictionary object for the l1 benchmark. It is built outside that
    % benchmark's timed region, just as A_full is not timed for the others.
    D_l1 = sensingDictionary('CustomDictionary', A_full);
    % average signal power per measurement entry
    signal_power = norm(tens2mat(tY_ori,1),'fro')^2/numel(tY_ori);
    tol_coef     = 0.05; % tolerance passed to msOMP and OMP
    x_true       = vec(tX);
    x_norm       = norm(x_true,'fro');
    % Estimator slots along the 3rd dim of error/time/srr. Slots 4 and 5
    % belong to the disabled SVD-/AM-KroSBL benchmarks and stay zero.
    slots = [1, 2, 3, 6, 7, 8, 9, 10, 11];

    for snr = 1:length(SNR)
        % Additive i.i.d. Gaussian noise at SNR(snr) dB.
        n_var = signal_power/(10^(SNR(snr)/10));
        noise = sqrt(n_var)*randn(size(tY_ori));
        tY    = tY_ori + noise;
        y     = vec(tY);

        % sparse vector estimation
        % using msOMP
        result_mtomp = tenMulReOMP_hi(tY,A,N,M,tol_coef);
        disp('Multi-Stage OMP is done!')
        % using msHTP
        result_mthtp = tenMulReHTP_hi(tY,A,N,M,K);
        disp('Multi-Stage HTP is done!')
        % using msIHT
        result_mtiht = tenMulReIHT(tY,A,N,M,K);
        disp('Multi-Stage IHT is done!')
        % using msSBL
        result_mtsbl = tenMulReSBL_hi(tY,A,N,M);
        disp('Multi-Stage SBL is done!')
        % Benchmark 1: SVD-KroSBL (disabled)
        % result_svd = svd_kroSBL3(1e-3,vec(tY), A{3}, A{2}, A{1}, A_full, N(1), 200);
        % disp('SVD-KroSBL is done!')
        % Benchmark 2: AM-KroSBL (disabled)
        % result_am = am_kroSBL3(1e-3,vec(tY), A{3}, A{2}, A{1}, A_full, N(1), 200);
        % disp('AM-KroSBL is done!')
        % Benchmark 3: HTP
        result_htp = htp(A_full, y, prod(K));
        disp('HTP is done!')
        % Benchmark 4: IHT
        result_iht = iht(A_full, y, prod(K));
        disp('iht is done!')
        % Benchmark 5: using sbl
        result_sbl = sbl(A_full, y, 1e-3, 1);
        disp('SBL is done!')
        % Benchmark 6: using omp
        result_omp = omp(A_full, y, tol_coef);
        disp('OMP is done!')
        % Benchmark 7: using l1 (only the solver call is timed)
        result_l1 = cell(2,1);
        tic;
        result_l1{1} = basisPursuit(D_l1, y, maxIterations=50);
        result_l1{2} = toc;
        disp('L1 is done!')

        % ---- metrics: NSE, run time, support recovery rate ---------------
        % Each result is a cell {estimate, run time}; results{k} fills slots(k).
        results = {result_mtomp, result_mthtp, result_mtsbl, result_htp, ...
                   result_mtiht, result_iht, result_sbl, result_omp, result_l1};
        for k = 1:numel(slots)
            x_hat = vec(results{k}{1});
            error(avg,snr,slots(k)) = (norm(x_hat - x_true,'fro')/x_norm)^2;
            time(avg,snr,slots(k))  = results{k}{2};
            srr(avg,snr,slots(k))   = recover_rate(x_hat, x_true);
        end
    end
end

save('compare_trad_kro_sup.mat');   % persist the whole workspace (results + settings)

%%

% Reload saved results so this plotting part can be re-run on its own.
load('compare_trad_kro_sup.mat')

% Colour per estimator slot: row s styles slot s of error/time/srr.
% Rows 4 and 5 belong to the disabled SVD-/AM-KroSBL benchmarks.
all_colors = [ ...
    0.9290, 0.6940, 0.1250    % 1  MSOMP
    0.4660, 0.6740, 0.1880    % 2  MSHTP
    0,      0.4470, 0.7410    % 3  MSSBL
    0.6350, 0.0780, 0.1840    % 4  (disabled)
    0.4900, 0.1800, 0.5600    % 5  (disabled)
    0.6350, 0.0780, 0.1840    % 6  HTP
    0.3010, 0.7450, 0.9330    % 7  MSIHT
    1,      0,      1         % 8  IHT
    0.5,    0.5,    0.5       % 9  SBL
    0.4900, 0.1800, 0.5600    % 10 OMP
    0.8500, 0.3250, 0.0980    % 11 BPDN
    ];

% Per-slot style tables; entries 4 and 5 are left empty.
line_type_set   = cell(1, 11);
legend_type_set = cell(1, 11);
algo_name       = cell(1, 11);

% Line+marker specs, marker-only specs, and display names, slot by slot.
line_type_set([1 2 3 6 7 8 9 10 11])   = {'-x', '-.*', '-o', '-.h', '-.+', '-.>', '-^', '-<', '-v'};
legend_type_set([1 2 3 6 7 8 9 10 11]) = {'x', '*', 'o', 'h', '+', '>', '^', '<', 'v'};
algo_name([1 2 3 6 7 8 9 10 11])       = {'MSOMP', 'MSHTP', 'MSSBL', 'HTP', 'MSIHT', 'IHT', 'SBL', 'OMP', 'BPDN'};

% Summary statistics over the Monte Carlo trials (dim 1). After squeeze each
% is (SNR level) x (estimator slot), provided there is more than one SNR level.
timeAVG    = squeeze(mean(time, 1));

errorTotal = squeeze(median(error, 1));
timeTotal  = squeeze(median(time, 1));
srrTotal   = squeeze(median(srr, 1));

errorQ1 = squeeze(quantile(error, 0.25, 1));
timeQ1  = squeeze(quantile(time, 0.25, 1));
srrQ1   = squeeze(quantile(srr, 0.25, 1));

errorQ3 = squeeze(quantile(error, 0.75, 1));
timeQ3  = squeeze(quantile(time, 0.75, 1));
srrQ3   = squeeze(quantile(srr, 0.75, 1));


fontsizeman = 20;
figure_position = [100, 100, 800, 600];
% Root-level graphics defaults; these persist for the rest of the session.
set(0,'DefaultLineLineWidth',3);
set(0,'DefaultLineMarkerSize',fontsizeman);
set(0,'DefaultAxesFontSize',fontsizeman);
set(0,'DefaultAxesFontWeight','bold');

f_nmse_srr = figure('Position', figure_position);
% tiledlayout(1, 2, 'TileSpacing', 'compact');

ax1 = gca;%nexttile;
hold(ax1, 'on');
% line1 = plot(ax1, SNR,errorTotal(:,1),'-','Color',[0,0,0]);
% line2 = plot(ax1, SNR,errorTotal(:,2),'-.','Color',[0,0,0]);

% Estimator slots to draw; 4 and 5 are the disabled SVD-/AM-KroSBL benchmarks.
plot_slots = [1,2,3,6,7,8,9,10,11];
% x-coordinates of the band polygon: SNR forward, then backward (row vector,
% independent of whether SNR is stored as a row or a column).
fill_x = [SNR(:); flipud(SNR(:))]';
for algo_index = plot_slots
    % Shaded band between the 0.25 and 0.75 quantiles of NSE, then the median curve.
    lower_bound = errorQ1(:, algo_index);
    upper_bound = errorQ3(:, algo_index);
    fill_y = [lower_bound; flipud(upper_bound)]';
    fill(ax1, fill_x, fill_y, all_colors(algo_index, :), 'FaceAlpha', 0.15, 'EdgeColor', 'none');
    plot(ax1, SNR, errorTotal(:,algo_index), line_type_set{algo_index}, 'Color', all_colors(algo_index, :));
end

% Marker-only copies of the median curves; only these handles enter the legend.
h_legend = gobjects(1, numel(plot_slots));
for k = 1:numel(plot_slots)
    s = plot_slots(k);
    h_legend(k) = plot(ax1, SNR, errorTotal(:,s), legend_type_set{s}, ...
        'Color', all_colors(s, :), 'DisplayName', algo_name{s});
end

set(ax1, 'YScale', 'log', 'XLim', [min(SNR) max(SNR)], 'YLim', [1e-4 1e3], 'FontSize', fontsizeman, 'FontWeight', 'bold');
ylabel(ax1, 'NSE', 'FontSize', fontsizeman, 'FontWeight', 'bold');
xlabel(ax1, 'SNR (dB)')
grid(ax1, 'on');
box(ax1, 'on');

% NOTE(review): the disabled run-time panel below loops over 1:10, which
% includes the empty style slots 4 and 5, and labels the run-time axis 'SRR'.
% ax2 = nexttile;
% hold(ax2, 'on');
% for algo_index = 1:10
% 
%     upper_bound = timeQ3(:, algo_index);
%     lower_bound = timeQ1(:, algo_index);
% 
%     fill_x = [SNR'; fliplr(SNR)']';
%     fill_y = [lower_bound; fliplr(upper_bound')']';
%     fill(ax2, fill_x, fill_y, all_colors(algo_index, :), 'FaceAlpha', 0.15, 'EdgeColor', 'none');
%     plot(ax2, SNR,timeTotal(:,algo_index),line_type_set{algo_index},'Color',all_colors(algo_index, :));
% end
% hold(ax2, 'off');
% set(ax2, 'XLim', [0 25], 'FontSize', fontsizeman, 'FontWeight', 'bold', 'Yscale', 'Log');
% ylabel(ax2, 'SRR', 'FontSize', fontsizeman, 'FontWeight', 'bold');
% grid(ax2, 'on');
% box(ax2, 'on');
% xlabel(ax2, 'SNR (dB)', 'FontSize', fontsizeman, 'FontWeight', 'bold');

lgd = legend(h_legend, algo_name(plot_slots), 'Location','northoutside','Orientation','horizontal','Interpreter','LaTex','Box','off','NumColumns',6);
% lgd.Position = [0.15, 0.9, 1, 0.15];

exportgraphics(f_nmse_srr, 's_kro_nmse_com_trad.pdf', 'ContentType', 'vector');