%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%% Monte-Carlo SNR sweep: recovery of a Kronecker-support sparse tensor.
%%%% For every trial and every SNR point, nine solvers are run on the same
%%%% noisy measurements and their error / run time / recover_rate value are
%%%% stored. The whole workspace is saved to kro_sup_snr.mat at the end.
%%%%
%%%% Clear previous work
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
clear;
close all;
clc;

addpath('utils')
addpath('../tensorlab/')
% fix random seed for reproducibility
rng(pi);

%% generate synthetic data for testing
I   = 3;    % the number of factor matrices
            % NOTE: A_full and the KroSBL benchmark calls below index
            % A{1..3} explicitly, so this script only works for I = 3.
AVG = 200;  % number of Monte-Carlo trials
% initialize dimensions
N = 18*ones(I,1);Nbar = prod(N);
M = 15*ones(I,1);Mbar = prod(M);
% sparsity pattern
K = 4*ones(I,1); % Kronecker support sparse, #nonzeros of each dimension
% SNR condition
SNR = 3:2:25; % range of SNR values (in dB) for testing
% tolerance coefficient passed to tenMulReOMP_hi (constant over all runs)
tol_coef = 0.07;

% Result arrays, indexed (trial, SNR index, solver). Solver index:
%   1 Multi-Stage OMP   2 Multi-Stage HTP   3 Multi-Stage SBL
%   4 SVD-KroSBL        5 AM-KroSBL         6 HTP
%   7 Multi-Stage IHT   8 IHT               9 SHTP
% NOTE(review): the variable 'error' shadows MATLAB's built-in error()
% for the rest of the script. The name is kept because it is written to
% kro_sup_snr.mat and is presumably read by other scripts - confirm
% before renaming.
error = zeros(AVG,numel(SNR),9);
time  = zeros(AVG,numel(SNR),9);
srr   = zeros(AVG,numel(SNR),9);


for avg = 1:AVG

    avg  % no semicolon: echoes the current trial index to the console

    % generate measuring dictionaries
    A = cell(1,I);
    for i = 1:I
        A{i} = randn(M(i),N(i));
        % NOTE(review): col_norm is computed but never applied, so the
        % dictionary columns are NOT normalized. Left as-is because
        % normalizing would change the experiment's results - confirm intent.
        col_norm = vecnorm(A{i});
    end
    % Full Kronecker dictionary used by the vectorized benchmarks. It only
    % depends on A, so it is built once per trial instead of once per SNR.
    A_full = kron(kron(A{3},A{2}),A{1});

    % generate sparse vector x
    tX = generate_kro_supp_sparse_tensor(N, K);
    % generate noiseless measurements
    tY_ori         = tmprod(tX,A,1:I);
    control.tX     = tX; % for debugging
    % average signal power per symbol (independent of the SNR point)
    signal_power   = norm(tens2mat(tY_ori,1),'fro')^2/numel(tY_ori);

    for snr = 1:numel(SNR)
        % add white Gaussian noise scaled to the requested SNR (dB)
        n_var          = (signal_power)/(10^(SNR(snr)/10));
        noise          = sqrt(n_var)*randn(size(tY_ori));
        tY             = tY_ori + noise;

        %% sparse vector estimation
        %% using OMP
        result_mtomp = tenMulReOMP_hi(tY,A,N,M,tol_coef);
        disp('Multi-Stage OMP is done!')
        %% using HTP
        result_mthtp = tenMulReHTP_hi(tY,A,N,M,K);
        disp('Multi-Stage HTP is done!')
        %% using IHT
        result_mtiht = tenMulReIHT(tY,A,N,M,K);
        disp('Multi-Stage IHT is done!')
        %% using SBL
        result_mtsbl = tenMulReSBL_hi(tY,A,N,M);
        disp('Multi-Stage SBL is done!')
        %% Benchmark 1: SVD-KroSBL
        result_svd = svd_kroSBL3(1e-3,vec(tY), A{3}, A{2}, A{1}, A_full, N(1), 200);
        disp('SVD-KroSBL is done!')
        %% Benchmark 2: AM-KroSBL
        result_am = am_kroSBL3(1e-3,vec(tY), A{3}, A{2}, A{1}, A_full, N(1), 200);
        disp('AM-KroSBL is done!')
        %% Benchmark 3: HTP
        result_htp = htp(A_full, vec(tY), prod(K));
        disp('HTP is done!')
        %% Benchmark 4: IHT
        result_iht = iht(A_full, vec(tY), prod(K));
        disp('iht is done!')
        %% using shtp
        result_shtp = shtp(A_full, vec(tY), N, K);
        disp('SHTP is done!')

        %% collect metrics
        % error: squared relative l2 error ||x_hat - x||^2 / ||x||^2
        % time : second output entry of each solver's result cell
        % srr  : value returned by recover_rate(x_hat, x)
        error(avg,snr,1) = (norm(vec(result_mtomp{1} - tX), 'fro')/norm(vec(tX),'fro'))^2;
        time(avg,snr,1) = result_mtomp{2};
        srr(avg,snr,1) = recover_rate(vec(result_mtomp{1}),vec(tX));

        error(avg,snr,2) = (norm(vec(result_mthtp{1} - tX), 'fro')/norm(vec(tX),'fro'))^2;
        time(avg,snr,2) = result_mthtp{2};
        srr(avg,snr,2) = recover_rate(vec(result_mthtp{1}),vec(tX));

        error(avg,snr,3) = (norm(vec(result_mtsbl{1} - tX), 'fro')/norm(vec(tX),'fro'))^2;
        time(avg,snr,3) = result_mtsbl{2};
        srr(avg,snr,3) = recover_rate(vec(result_mtsbl{1}),vec(tX));

        error(avg,snr,4) = (norm(result_svd{1,2} - vec(tX), 'fro')/norm(vec(tX),'fro'))^2;
        time(avg,snr,4) = result_svd{2,2};
        srr(avg,snr,4) = recover_rate(result_svd{1,2},vec(tX));

        error(avg,snr,5) = (norm(result_am{1,2} - vec(tX), 'fro')/norm(vec(tX),'fro'))^2;
        time(avg,snr,5) = result_am{2,2};
        srr(avg,snr,5) = recover_rate(result_am{1,2},vec(tX));

        error(avg,snr,6) = (norm(result_htp{1} - vec(tX), 'fro')/norm(vec(tX),'fro'))^2;
        time(avg,snr,6) = result_htp{2};
        srr(avg,snr,6) = recover_rate(vec(result_htp{1}),vec(tX));

        error(avg,snr,7) = (norm(vec(result_mtiht{1} - (tX)), 'fro')/norm(vec(tX),'fro'))^2;
        time(avg,snr,7) = result_mtiht{2};
        srr(avg,snr,7) = recover_rate(vec(result_mtiht{1}),vec(tX));

        error(avg,snr,8) = (norm((result_iht{1} - vec(tX)), 'fro')/norm(vec(tX),'fro'))^2;
        time(avg,snr,8) = result_iht{2};
        srr(avg,snr,8) = recover_rate(vec(result_iht{1}),vec(tX));

        error(avg,snr,9) = (norm(result_shtp{1,2} - vec(tX), 'fro')/norm(vec(tX),'fro'))^2;
        time(avg,snr,9) = result_shtp{2,2};
        srr(avg,snr,9) = recover_rate(result_shtp{1,2},vec(tX));
    end
end

% saves every workspace variable (results and the last trial's data)
save kro_sup_snr.mat