%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%% Clear previous work
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Start from a fresh workspace with no open figures and an empty console.
clear; close all; clc;

% Make the project helpers and Tensorlab visible. addpath puts each folder
% at the top of the search path, so the second call ('../tensorlab/') ends
% up ahead of 'utils'; keep the two calls separate and in this order.
addpath('utils')
addpath('../tensorlab/')

% Reuse a running parallel pool; only start a thread-based one if none exists.
if isempty(gcp('nocreate')), parpool('threads'); end

% Fixed seed so every run draws the same dictionaries, signals and noise.
% NOTE(review): rng documents integer seeds; confirm pi is accepted on the
% MATLAB release in use.
rng(pi);

%% generate synthetic data for testing
% ---- problem size ---------------------------------------------------------
I    = 2;                 % number of modes, i.e. number of factor matrices
N    = repmat(80, I, 1);  % N(i): size of the sparse tensor along mode i
Nbar = prod(N);           % total number of unknowns
K    = 15;                % sparsity: number of nonzeros in the whole tensor

% ---- experiment grid ------------------------------------------------------
AVG         = 200;        % Monte-Carlo trials per grid point
M_candidate = 48:4:72;    % per-mode measurement counts that are swept
SNR         = 20;         % single (fixed) measurement SNR, in dB

% ---- result storage: trial x M-candidate x method (6 methods) -----------
% NOTE: the variable `error` shadows MATLAB's error() inside this script.
% It keeps its name because the whole workspace is saved at the end.
error = zeros([AVG, numel(M_candidate), 6]);
time  = zeros([AVG, numel(M_candidate), 6]);
srr   = zeros([AVG, numel(M_candidate), 6]);


% Coefficient passed to Kro-OMP and multi-stage OMP (its meaning is defined
% inside those helpers). Loop-invariant, so it is set once up front.
coef = 0.05;

for avg = 1:AVG

    fprintf('Monte-Carlo trial %d/%d\n', avg, AVG);

    for m = 1:length(M_candidate)

        % initialize dimensions: M(i) measurements along mode i
        M = M_candidate(m)*ones(I,1); Mbar = prod(M);

        % generate Gaussian measuring dictionaries, A{i} is M(i)-by-N(i)
        % NOTE(review): an earlier version computed vecnorm(A{i}) here but never
        % used it; if unit-norm columns were intended, add
        % A{i} = A{i}./vecnorm(A{i}); (this changes the experiment).
        A = cell(1,I);
        for i = 1:I
            A{i} = randn(M(i),N(i));
        end

        % generate the sparse ground-truth tensor
        tX = generate_sparse_tensor(N, K);
        % noiseless measurements via mode-wise products (Tensorlab tmprod)
        tY_ori         = tmprod(tX,A,1:I);
        control.tX     = tX; % for debugging

        % add white Gaussian noise scaled so that signal/noise power = SNR (dB)
        signal_power   = norm(vec(tY_ori),'fro')^2/numel(tY_ori); % average signal power per symbol
        n_var          = signal_power/(10^(SNR/10));
        noise          = sqrt(n_var)*randn(size(tY_ori));
        tY             = tY_ori + noise;

        %% sparse tensor estimation
        %% Benchmark 1: Kro-OMP
        result_kroomp = kron_omp_refined(A,tY,coef);
        disp('Kro-OMP is done!')
        %% multi-stage SBL, sequential
        result_sbl_se = tenMulReSBL(tY,A,N,M);
        disp('Multi-Stage SBL is done!')
        %% multi-stage SBL, parallel
        result_sbl_pl = tenMulReSBL_pl(tY,A,N,M);
        disp('Parallel Multi-Stage SBL is done!')
        %% multi-stage OMP
        result_mtomp = tenMulReOMP(tY,A,N,M,coef);
        disp('Multi-Stage OMP is done!')

        %% Benchmark 2: IHT/HTP on the vectorized model
        % vec(tY) = kron(A{I}, ..., A{2}, A{1}) * vec(tX); for I = 2 this is
        % exactly kron(A{2}, A{1}), but the loop works for any number of modes.
        A_all = A{1};
        for i = 2:I
            A_all = kron(A{i}, A_all);
        end
        y_mea = vec(tY);
        %% IHT
        result_iht = iht(A_all,y_mea,K);
        disp('IHT is done!')
        %% HTP
        result_htp = htp(A_all,y_mea,K);
        disp('HTP is done!')

        %% track results
        % Every result_* is a cell {estimate, runtime}. The third dimension of
        % error/time/srr is indexed in this fixed order:
        %   1 multi-stage OMP, 2 Kro-OMP, 3 SBL (sequential), 4 SBL (parallel),
        %   5 IHT, 6 HTP
        % error = squared normalized estimation error ||x_hat - x||^2/||x||^2,
        % srr   = value returned by recover_rate(x_hat, x).
        % NOTE(review): IHT/HTP estimates are also passed through vec(), which
        % is a no-op for the column vectors they are compared against.
        estimates = {result_mtomp, result_kroomp, result_sbl_se, ...
                     result_sbl_pl, result_iht, result_htp};
        x_true = vec(tX);
        for k = 1:numel(estimates)
            x_hat = vec(estimates{k}{1});
            error(avg,m,k) = (norm(x_hat - x_true,'fro')/norm(x_true,'fro'))^2;
            time(avg,m,k)  = estimates{k}{2};
            srr(avg,m,k)   = recover_rate(x_hat, x_true);
        end
    end
end

% Persist every workspace variable (settings, last-trial data and the
% error/time/srr result arrays) to s_sparse_measurement.mat.
save('s_sparse_measurement.mat');