%% Clear previous work
% Start from an empty workspace, no open figures and a clean command window.
clear; close all; clc;

% Project helper functions and the Tensorlab toolbox (kept as two separate
% addpath calls, in this order).
addpath('utils')
addpath('../tensorlab/')

% Only start a thread-based parallel pool (default settings) when no pool
% is currently running.
if isempty(gcp('nocreate'))
    parpool('threads');
end

% fix random seed for reproducibility
rng(pi);

%% generate synthetic data for testing
% ---- experiment settings ----
i_candidate = 1:4;  % system orders I swept (number of factor matrices)
AVG = 50;           % Monte Carlo trials per system order
SNR = 20;           % measurement SNR in dB (a single value in this experiment)

n_candidate = 50;   % length of every mode of the unknown sparse tensor
ratio_m = 0.6;      % target ratio prod(M)/prod(N) of measurements to unknowns
ratio_k = 0.4;      % fraction of nonzero indices along each mode

% ---- result storage: trial x system order x algorithm slot (6 variants) ----
error = zeros(AVG, numel(i_candidate), 6);  % normalized squared error (NSE)
time  = zeros(size(error));                 % run time reported per algorithm
srr   = zeros(size(error));                 % value returned by recover_rate

for avg = 1:AVG

    avg % progress indicator (deliberately left unsuppressed)

    for i_idx = 1:length(i_candidate)
        I = i_candidate(i_idx);

        % per-mode number of measurements, chosen so that m^I is roughly
        % ratio_m * n^I
        m_candidate = ceil((n_candidate^I * ratio_m )^(1/I));

        % initialize dimensions
        N = n_candidate*ones(I,1);Nbar = prod(N);
        M = m_candidate*ones(I,1);Mbar = prod(M);
        % sparsity pattern
        K = ratio_k * n_candidate * ones(I,1); % hierarchical sparse, #nonzeros of each dimension

        % generate measuring dictionaries: one M(i)-by-N(i) Gaussian matrix per mode
        A = cell(1,I);
        for i = 1:I
            A{i} = randn(M(i),N(i));
        end

        % generate the sparse ground truth and its noiseless measurements
        if I == 1
            % sparse vector with K randomly placed Gaussian entries
            tX = zeros(N, 1);
            supp = randsample(N, K);
            tX(supp) = randn(K,1);
            tY_ori = A{1} * tX;
        else
            tX = generate_kro_supp_sparse_tensor(N, K);
            tY_ori = tmprod(tX,A,1:I);
        end
        control.tX     = tX; % for debugging

        % add white Gaussian noise at the prescribed SNR (dB); shared by
        % both branches above, and drawn after tX exactly as before
        signal_power   = norm(vec(tY_ori),'fro')^2/numel(tY_ori); % average signal power per symbol
        n_var          = (signal_power)/(10^(SNR/10));
        noise          = sqrt(n_var)*randn(size(tY_ori));
        tY             = tY_ori + noise;

        % Run the recovery algorithms. Each result is a cell whose first
        % entry is the estimate and whose second entry is the run time.
        % Slot order matches the third index of error/time/srr:
        % 1 SBL-seq, 2 SBL-pl, 3 HTP-seq, 4 HTP-pl, 5 IHT-seq, 6 IHT-pl.
        if I == 1
            % only the sequential baselines exist for I == 1; the parallel
            % slots are left empty and their metrics stay zero
            % using SBL
            result_mtsbl = cell(2,1);
            tic;
            result_mtsbl{1} = MSBL_ori(A{1},tY,1e-3,1);
            result_mtsbl{2} = toc;
            disp('Multi-Stage SBL-seq is done!')
            % with the knowledge of sparsity level: HTP
            result_mthtp = cell(2,1);
            tic;
            result_mthtp{1} = mmv_htp(A{1},tY,K);
            result_mthtp{2} = toc;
            disp('Multi-Stage HTP-seq is done!')
            % with the knowledge of sparsity level: IHT
            result_mtiht = cell(2,1);
            tic;
            result_mtiht{1} = iht_mmv(A{1},tY,K);
            result_mtiht{2} = toc;
            disp('Multi-Stage IHT-seq is done!')

            results = {result_mtsbl, [], result_mthtp, [], result_mtiht, []};
        else
            % using SBL: sequential
            result_mtsbl = tenMulReSBL_hi(tY,A,N,M);
            disp('Multi-Stage SBL-seq is done!')
            % using SBL: parallel
            result_mtsbl_pl = tenMulReSBL_hi_pl(tY,A,N,M);
            disp('Multi-Stage SBL-pl is done!')
            % with the knowledge of sparsity level: HTP sequential / parallel
            result_mthtp = tenMulReHTP_hi(tY,A,N,M,K);
            disp('Multi-Stage HTP-seq is done!')
            result_mthtp_pl = tenMulReHTP_hi_pl(tY,A,N,M,K);
            disp('Multi-Stage HTP-pl is done!')
            % with the knowledge of sparsity level: IHT sequential / parallel
            result_mtiht = tenMulReIHT(tY,A,N,M,K);
            disp('Multi-Stage IHT-seq is done!')
            result_mtiht_pl = tenMulReIHT_pl(tY,A,N,M,K);
            disp('Multi-Stage IHT-pl is done!')

            results = {result_mtsbl, result_mtsbl_pl, result_mthtp, ...
                       result_mthtp_pl, result_mtiht, result_mtiht_pl};
        end

        % metrics per algorithm slot: NSE, run time and recover_rate output
        for algo = 1:numel(results)
            if isempty(results{algo})
                continue; % algorithm not run for this system order
            end
            error(avg,i_idx,algo) = (norm(vec(results{algo}{1} - tX), 'fro')/norm(vec(tX),'fro'))^2;
            time(avg,i_idx,algo)  = results{algo}{2};
            srr(avg,i_idx,algo)   = recover_rate(vec(results{algo}{1}),vec(tX));
        end
        % drop the temporary so the workspace saved below does not hold a
        % second copy of the estimates
        clear results
    end
end
save hie_s_varying_i.mat

%% plot
% The plotting stage starts from a cleared workspace and reloads the
% workspace written by the simulation stage.
clc
clear
load('hie_s_varying_i.mat');

% Plot styling, one entry per algorithm slot 1..6 (the third index of the
% stored error/time arrays): RGB colours, marker specs and display names.
all_colors = [0      0.4470 0.7410
              0.9290 0.6940 0.1250
              1      0      1
              0.4900 0.1800 0.5600
              0.5    0.5    0.5
              0.8500 0.3250 0.0980];
% bare markers (used for legend entries) and marker + solid line (curves)
legend_type_set = {'o', 'x', '>', '<', '^', 'v'};
line_type_set = strcat('-', legend_type_set);
algo_name = {'MSSBL-Seq', 'MSSBL-Pl', 'MSHTP-Seq', 'MSHTP-Pl', 'MSIHT-Seq', 'MSIHT-Pl'};

% Summary statistics across the Monte Carlo runs (first dimension of
% error/time): median and first/third quartiles, arranged as a
% [number of system orders x number of algorithm slots] table. reshape is
% used instead of squeeze so the orientation stays the same even when only
% a single system order is present.
med_tbl = @(X) reshape(median(X, 1), size(X, 2), size(X, 3));
qnt_tbl = @(X, p) reshape(quantile(X, p, 1), size(X, 2), size(X, 3));

errorTotal = med_tbl(error);
errorQ1 = qnt_tbl(error, 0.25);
errorQ3 = qnt_tbl(error, 0.75);

timeTotal = med_tbl(time);
timeQ1 = qnt_tbl(time, 0.25);
timeQ3 = qnt_tbl(time, 0.75);

% The I = 1 case was originally simulated in a separate run, saved as
% hie_s_varying_i_1.mat, after a main run that covered only I = 2, 3, 4.
% Those rows are prepended only when the main result file does not already
% contain I = 1. The simulation section now sweeps i_candidate = 1:4
% directly, and merging again would duplicate the I = 1 row (five rows for
% four system orders). Results from the two workflows can differ but
% should follow similar trends.
if ~ismember(1, i_candidate) && isfile('hie_s_varying_i_1.mat')
    extra = load('hie_s_varying_i_1.mat', 'error', 'time', 'i_candidate');

    errorTotal = [med_tbl(extra.error); errorTotal];
    errorQ1 = [qnt_tbl(extra.error, 0.25); errorQ1];
    errorQ3 = [qnt_tbl(extra.error, 0.75); errorQ3];

    timeTotal = [med_tbl(extra.time); timeTotal];
    timeQ1 = [qnt_tbl(extra.time, 0.25); timeQ1];
    timeQ3 = [qnt_tbl(extra.time, 0.75); timeQ3];

    if isfield(extra, 'i_candidate')
        extra_orders = extra.i_candidate;
    else
        % assume the separate run covered I = 1 only, as its file name suggests
        extra_orders = 1;
    end
    i_candidate = [reshape(extra_orders, 1, []), reshape(i_candidate, 1, [])];
end

% x-axis values shared by both tiles: one entry per row of the tables above
dimension = reshape(i_candidate, 1, []);

% Shared figure geometry plus root-level defaults applied to every line and
% axes created afterwards.
fontsizeman = 20;
figure_position = [100, 100, 1200, 400];
set(groot, ...
    'DefaultLineLineWidth', 3, ...
    'DefaultLineMarkerSize', fontsizeman, ...
    'DefaultAxesFontSize', fontsizeman, ...
    'DefaultAxesFontWeight', 'bold');

f_nmse = figure('Position', figure_position);

% one figure with two side-by-side tiles: NSE (left) and run time (right)
tiledlayout(1, 2, 'TileSpacing', 'compact');

ax1 = nexttile;
hold(ax1, 'on');

% x coordinates of the shaded band: forward along the system orders and
% back again, so fill() closes the Q1..Q3 polygon
x_row = reshape(dimension, 1, []);
band_x = [x_row, fliplr(x_row)];
for algo_index = [1,3,5]
    % slots 1, 3, 5 only: inter-quartile band with the median curve on top
    lower_bound = errorQ1(:, algo_index);
    upper_bound = errorQ3(:, algo_index);
    band_y = [lower_bound.', fliplr(upper_bound.')];
    fill(ax1, band_x, band_y, all_colors(algo_index, :), 'FaceAlpha', 0.15, 'EdgeColor', 'none');
    plot(ax1, dimension, errorTotal(:,algo_index), line_type_set{algo_index}, 'Color', all_colors(algo_index, :));
end
% h1 = plot(ax1, dimension,errorTotal(:,1),legend_type_set{1},'Color',all_colors(1, :),'Display',algo_name{1});
% h3 = plot(ax1, dimension,errorTotal(:,3),legend_type_set{3},'Color',all_colors(3, :),'Display',algo_name{3});
% h5 = plot(ax1, dimension,errorTotal(:,5),legend_type_set{5},'Color',all_colors(5, :),'Display',algo_name{5});
% legend(ax1, [h1 h3 h5],{algo_name{1},algo_name{3},algo_name{5}},'Location','southwest','Interpreter','LaTex','NumColumns',1,'Box','off');
grid(ax1, 'on');
set(ax1, 'YScale', 'log');
box(ax1, 'on');
% ticks at the plotted system orders (derived from the data, not hard-coded)
tick_positions = dimension;
tick_labels = arrayfun(@num2str, dimension, 'UniformOutput', false);
xticks(ax1, tick_positions);
xticklabels(ax1, tick_labels);
if min(dimension) < max(dimension)
    xlim(ax1, [min(dimension), max(dimension)]);
end
set(ax1,'FontSize',fontsizeman, 'FontWeight', 'bold');
ylabel(ax1, 'NSE', 'FontSize', fontsizeman, 'FontWeight', 'bold');
xlabel(ax1, 'System Order $I$', 'FontSize', fontsizeman, 'FontWeight', 'bold', 'Interpreter', 'latex');

% exportgraphics(f_nmse,'dim_nmse_I.pdf','ContentType', 'vector');

% f_time = figure('Position', figure_position);
ax2 = nexttile;
hold(ax2, 'on');

% The parallel (-Pl, even slots) variants are only run when I > 1 in the
% simulation section, so they are drawn only at those system orders; the
% sequential (odd slots) variants are drawn at every order.
all_rows = true(size(dimension));
pl_rows = dimension > 1;
% shaded Q1..Q3 band: forward along x on the lower edge, back on the upper
draw_band = @(x, lo, hi, c) fill(ax2, [x(:).', fliplr(x(:).')], [lo(:).', fliplr(hi(:).')], ...
    c, 'FaceAlpha', 0.15, 'EdgeColor', 'none');

for algo_index = [1,3,5,2,4,6]
    if mod(algo_index, 2) == 0
        rows = pl_rows;
    else
        rows = all_rows;
    end
    draw_band(dimension(rows), timeQ1(rows, algo_index), timeQ3(rows, algo_index), all_colors(algo_index, :));
    plot(ax2, dimension(rows), timeTotal(rows, algo_index), line_type_set{algo_index}, 'Color', all_colors(algo_index, :));
end

% marker-only copies of each median curve, used as the legend entries
h = gobjects(1, 6);
for algo_index = 1:6
    if mod(algo_index, 2) == 0
        rows = pl_rows;
    else
        rows = all_rows;
    end
    h(algo_index) = plot(ax2, dimension(rows), timeTotal(rows, algo_index), legend_type_set{algo_index}, ...
        'Color', all_colors(algo_index, :), 'DisplayName', algo_name{algo_index});
end
lgd = legend(ax2, h, algo_name(1:6), 'Interpreter','LaTex','NumColumns',6,'Box','off');
% lgd.Position = [0.1, 0.94, 0.8, 0.05];
lgd.Layout.Tile = 'north';  % one shared legend above both tiles
grid(ax2, 'on');
set(ax2, 'YScale', 'log');
box(ax2, 'on');
xticks(ax2, dimension);
xticklabels(ax2, arrayfun(@num2str, dimension, 'UniformOutput', false));
if min(dimension) < max(dimension)
    xlim(ax2, [min(dimension), max(dimension)]);
end
ylim(ax2, [1e-3 1e2]);
yticks(ax2, [1e-3,1e-2,1e-1,1e0,1e1,1e2]);
set(ax2,'FontSize',fontsizeman, 'FontWeight', 'bold');
ylabel(ax2, 'Time', 'FontSize', fontsizeman, 'FontWeight', 'bold');
xlabel(ax2, 'System Order $I$', 'FontSize', fontsizeman, 'FontWeight', 'bold', 'Interpreter', 'latex');

exportgraphics(f_nmse,'dim_I.pdf','ContentType', 'vector');