%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%% Clear previous work
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Reset the session: empty workspace, no open figures, clean command window.
clear; close all; clc;

% Project helper folder, then the tensorlab toolbox (added last, so it ends
% up first on the search path).
addpath('utils')
addpath('../tensorlab/')

% Reuse an already running parallel pool; otherwise start a thread-based
% pool with default settings.
if isempty(gcp('nocreate')), parpool('threads'); end

% Fixed random seed so every run draws the same synthetic problems.
% NOTE(review): rng documents its seed as a nonnegative integer < 2^32;
% confirm that pi is accepted by the MATLAB release in use.
rng(pi);

%% generate synthetic data for testing
% Parameters of the sweep over the per-mode signal dimension n.
I   = 3;    % the number of factor matrices (= tensor order)
AVG = 50;   % number of Monte Carlo trials for every candidate n
SNR = 20;   % single, fixed measurement SNR in dB used for every trial

n_candidate = 50:10:110;   % per-mode signal dimensions N(i) that are swept
ratio_m = 0.6;             % target measurement ratio prod(M)/prod(N)
ratio_k = 0.4;             % per-mode #nonzeros as a fraction of n (K = ratio_k*n)

% Result containers indexed (trial, n index, method). Method order:
%   1 SBL-seq, 2 SBL-pl, 3 HTP-seq, 4 HTP-pl, 5 IHT-seq, 6 IHT-pl
% NOTE: the variable "error" shadows the builtin error() in this script's
% workspace; its name is kept because it is what the saved .mat file holds.
error = zeros(AVG,length(n_candidate),6);  % squared relative error (NMSE)
time  = zeros(AVG,length(n_candidate),6);  % timing value returned by each method
srr   = zeros(AVG,length(n_candidate),6);  % output of recover_rate per method


%% Monte Carlo sweep over the per-mode dimension n
% For every trial and every candidate n: draw a random problem, run the six
% recovery methods and record error / time / srr. The third index of the
% result arrays follows the method order
%   1 SBL-seq, 2 SBL-pl, 3 HTP-seq, 4 HTP-pl, 5 IHT-seq, 6 IHT-pl
for avg = 1:AVG

    avg  % print the trial counter as a progress indicator (no semicolon on purpose)

    for n_idx = 1:length(n_candidate)

        % measurements per mode, chosen so that prod(M) is roughly
        % ratio_m * prod(N) (ceil rounds the per-mode count up)
        m_candidate = ceil((n_candidate(n_idx)^I * ratio_m )^(1/I));

        % initialize dimensions
        N = n_candidate(n_idx)*ones(I,1); Nbar = prod(N);
        M = m_candidate*ones(I,1);        Mbar = prod(M);
        % sparsity pattern: hierarchical sparse, #nonzeros of each dimension.
        % round() keeps the counts integer for any ratio_k / n combination;
        % for the default grid (0.4 * (50:10:110)) it leaves K unchanged.
        K = round(ratio_k * n_candidate(n_idx) * ones(I,1));

        % generate measuring dictionaries (Gaussian entries, M(i) x N(i))
        A = cell(1,I);
        for i = 1:I
            A{i} = randn(M(i),N(i));
        end

        % generate the hierarchically sparse ground-truth tensor
        tX = generate_hierarchical_sparse_tensor(N, K);
        % noiseless measurements: tX multiplied by A{i} along each mode i
        tY_ori         = tmprod(tX,A,1:I);
        control.tX     = tX; % for debugging
        % additive real Gaussian noise scaled to the target SNR (in dB)
        signal_power   = norm(vec(tY_ori),'fro')^2/numel(tY_ori); % average signal power per entry
        n_var          = (signal_power)/(10^(SNR/10));
        noise          = sqrt(n_var)*randn(size(tY_ori));
        tY             = tY_ori + noise;

        %% sparse vector estimation
        %% case 1: without the knowledge of sparsity level (K is not passed)
        %% using SBL: sequential
        result_mtsbl = tenMulReSBL_hi(tY,A,N,M);
        disp('Multi-Stage SBL-seq is done!')
        %% using SBL: parallel
        result_mtsbl_pl = tenMulReSBL_hi_pl(tY,A,N,M);
        disp('Multi-Stage SBL-pl is done!')
        %% case 2: with the knowledge of sparsity level
        %% using HTP: sequential
        result_mthtp = tenMulReHTP_hi(tY,A,N,M,K);
        disp('Multi-Stage HTP-seq is done!')
        %% using HTP: parallel
        result_mthtp_pl = tenMulReHTP_hi_pl(tY,A,N,M,K);
        disp('Multi-Stage HTP-pl is done!')
        %% using IHT: sequential
        result_mtiht = tenMulReIHT(tY,A,N,M,K);
        disp('Multi-Stage IHT-seq is done!')
        %% using IHT: parallel
        result_mtiht_pl = tenMulReIHT_pl(tY,A,N,M,K);
        disp('Multi-Stage IHT-pl is done!')

        %% record metrics
        % Same order as the third index of error/time/srr. Each result is a
        % cell: {1} is the estimate of tX, {2} is the value stored in time.
        method_results = {result_mtsbl, result_mtsbl_pl, result_mthtp, ...
                          result_mthtp_pl, result_mtiht, result_mtiht_pl};
        tX_vec  = vec(tX);
        tX_norm = norm(tX_vec,'fro');   % hoisted: identical for every method
        for m_idx = 1:numel(method_results)
            x_hat = method_results{m_idx}{1};
            % squared relative error ||x_hat - tX||^2 / ||tX||^2
            error(avg,n_idx,m_idx) = (norm(vec(x_hat - tX),'fro')/tX_norm)^2;
            time(avg,n_idx,m_idx)  = method_results{m_idx}{2};
            srr(avg,n_idx,m_idx)   = recover_rate(vec(x_hat),tX_vec);
        end
    end
end
% drop the metric-loop helpers so they are not written into the .mat file
clear method_results tX_vec tX_norm m_idx x_hat
save hie_s_varying_n.mat
