% Start from a clean session: close any open figures, clear the command
% window and wipe the workspace before loading parameters.
close all;
clc;
clear;
% Load the simulation parameters. AVG (number of trials) must come from here
% because the workspace was just cleared; other names used below (SNRl, K1,
% K2, ...) presumably come from here or from channel_realization.m.
run('Configuration.m');
%% Monte-Carlo simulation
% For each of AVG trials, draw a new channel realization and, at every SNR
% point in SNRl, estimate the channel with three algorithms. Results are
% stored as error(alg, trial, snr) and time(alg, trial, snr) with
% alg = 1 (SVD-KroSBL), 2 (AM-KroSBL), 3 (MSSBL).
% NOTE(review): the variable `error` shadows MATLAB's built-in error()
% function in this workspace; the name is kept because the plotting section
% and irs_channel.mat rely on it.
for avg = 1 : AVG % for AVG trials
    avg % print progress
    % produce different realizations of channel
    run('channel_realization.m')

    % --- SNR-independent quantities ------------------------------------
    % Nothing in this section depends on `snr` and none of it draws random
    % numbers. It is therefore computed once per trial instead of once per
    % SNR point. That avoids rebuilding the large Kronecker matrix H
    % repeatedly, and it changes neither the results nor the random stream.
    x_pilot = X(:,1:K1);
    % different IRS reflection patterns; their number equals the number of overheads (K2)
    IRS = irs_pattern(:,1:K2);
    % The *10 here and the /10 on H_p2 cancel (up to rounding) in H; they
    % only rebalance the scale of the individual factors given to the solvers.
    H_p1 = IRS.'*kr(A_irs_a.',A_irs_d').'*10;
    H_p1 = H_p1(:,1:Res1);
    H_p1_ori = H_p1;
    H_p2 = x_pilot.'*conj(A1)/10;
    H_p2_ori = H_p2;
    H = kron(kron(H_p1,H_p2),A2);
    % noiseless received signals (noise is added per SNR point below)
    y_tilde = vec(y_bar(1:K1*bs_ante,1:K2));
    % true channel for each IRS pattern
    for i = 1:K2
        Htrue(:,:,i) = vec(H2*diag(IRS(:,i))*H1);
    end
    % MSSBL dictionaries, in the order A2, H_p2, H_p1; M holds their row
    % counts (used as the tensor size of y) and N their column counts.
    Dict = cell(3,1);
    Dict{1} = A2;
    Dict{2} = H_p2;
    Dict{3} = H_p1;
    M = [size(Dict{1},1);size(Dict{2},1);size(Dict{3},1)];
    N = [size(Dict{1},2);size(Dict{2},2);size(Dict{3},2)];

    % --- SNR sweep -----------------------------------------------------
    for snr = 1:length(SNRl) % snr
        SNR_10 = SNRl_10(snr);
        noise_var = (signal_power)/SNR_10;

        % complex Gaussian noise: real and imaginary parts each carry
        % noise_var/2, so E|noise|^2 = noise_var per entry
        noise = sqrt(noise_var / 2)*(randn(size(y_tilde))+1i*randn(size(y_tilde)));
        y = y_tilde + noise;

        % SVD-KroSBL
        result_svd = svd_kroSBL3(1e-3,y,H_p1,H_p2,A2,H,Res1,200);
        % metrics compute
        [error_svdsbl, Hre_svdsbl] = ce_error(result_svd{1,2},Res1,K2,IRS,A_irs_a,A_irs_d,A1,A2,H1,H2);
        error(1,avg,snr) = error_svdsbl;
        time(1,avg,snr) = result_svd{2,2};
        disp('SVD-KroSBL is finished!');

        % AM-KroSBL
        result_am = am_kroSBL3(1e-3,y,H_p1,H_p2,A2,H,Res1,200);
        % metrics compute
        [error_am, Hre_am] = ce_error(result_am{1,2},Res1,K2,IRS,A_irs_a,A_irs_d,A1,A2,H1,H2);
        error(2,avg,snr) = error_am;
        time(2,avg,snr) = result_am{2,2};
        disp('AM-KroSBL is finished!');

        % MSSBL: operates on y reshaped into a tensor of size M.'
        tY = reshape(y,M.');
        result_mssbl = tenMulReSBL_hi(tY,Dict, N, M, [1e-1,1e-1,1e-1]);
        % metrics compute
        [error_mssbl, Hre_mssbl] = ce_error(vec(result_mssbl{1}),Res1,K2,IRS,A_irs_a,A_irs_d,A1,A2,H1,H2);
        error(3,avg,snr) = error_mssbl;
        time(3,avg,snr) = result_mssbl{2};
        disp('MSSBL is finished!');
    end
end

save irs_channel.mat

%% Plotting: reload saved results and prepare shared plot settings
load irs_channel.mat

% RGB colour per algorithm; row k is used for algorithm k.
all_colors = [0.9290 0.6940 0.1250
              0      70/255 222/255
              0.6350 0.0780 0.1840
              0      0.4470 0.7410
              1      0      1
              0.4900 0.1800 0.5600
              0.5    0.5    0.5
              0.8500 0.3250 0.0980];

% Line/marker spec and legend label per algorithm, in the same order as the
% first index of `error` / `time`.
line_type_set(1:3) = {'-*', '-p', '-o'};
% line_type_set{4} = '-x';
algo_name(1:3) = {'SVD-KroSBL', 'AM-KroSBL', 'MSSBL'};
% algo_name{4} = 'MSOMP';

% Collapse the trial dimension (dim 2) of error/time. After squeeze and
% transpose, rows index SNR points and columns index algorithms.
collapse_trials = @(stat) squeeze(stat).';
errorTotal = collapse_trials(median(error, 2));
errorQ1    = collapse_trials(quantile(error, 0.25, 2));
errorQ3    = collapse_trials(quantile(error, 0.75, 2));

timeTotal = collapse_trials(median(time, 2));
timeQ1    = collapse_trials(quantile(time, 0.25, 2));
timeQ3    = collapse_trials(quantile(time, 0.75, 2));

fontsizeman = 20;
figure_position = [100, 100, 800, 400];
% Root-level graphics defaults, applied to objects created afterwards.
set(groot, 'DefaultLineLineWidth', 3, ...
    'DefaultLineMarkerSize', fontsizeman, ...
    'DefaultAxesFontSize', fontsizeman, ...
    'DefaultAxesFontWeight', 'bold');

f_nmse = figure('Position', figure_position);
ax1 = gca;
hold(ax1, 'on');

% One curve per algorithm: the line is the median over trials (errorTotal)
% and the shaded band spans the 25th-75th percentiles (errorQ1 .. errorQ3).
num_algos = numel(algo_name);
% x-coordinates of the band polygon: SNR axis forward, then backward.
% SNRl(:) makes this work whether SNRl is a row or a column vector.
band_x = [SNRl(:); flipud(SNRl(:))].';
for algo_index = 1:num_algos
    lower_bound = errorQ1(:, algo_index);
    upper_bound = errorQ3(:, algo_index);
    band_y = [lower_bound; flipud(upper_bound)].';
    fill(ax1, band_x, band_y, all_colors(algo_index, :), 'FaceAlpha', 0.15, 'EdgeColor', 'none', 'HandleVisibility', 'off');
    plot(ax1, SNRl, errorTotal(:, algo_index), line_type_set{algo_index}, 'Color', all_colors(algo_index, :));
end

% Marker-only proxy objects for the legend. They are built from the same
% specs/colours as the curves so the legend cannot drift from the plot.
legend_handles = gobjects(1, num_algos);
for algo_index = 1:num_algos
    legend_handles(algo_index) = plot(ax1, NaN, NaN, line_type_set{algo_index}, ...
        'Color', all_colors(algo_index, :), 'LineStyle', 'none', ...
        'DisplayName', algo_name{algo_index});
end
legend(legend_handles, 'Location', 'southwest', 'Interpreter', 'latex', 'Box', 'off');
hold(ax1, 'off');

% Target ax1 explicitly rather than whatever axes happen to be current.
xticks(ax1, SNRl)
set(ax1, 'YScale', 'log');
ylabel(ax1, 'Channel Estimation NSE');
xlabel(ax1, 'SNR (dB)');

grid(ax1, 'on');
box(ax1, 'on');
exportgraphics(f_nmse, 's_irs_channel_estimation.pdf', 'ContentType', 'vector');

f_time = figure('Position', figure_position);
ax3 = gca;
hold(ax3, 'on');

% One curve per algorithm: the line is the median runtime over trials and
% the shaded band spans the 25th-75th percentiles (timeQ1 .. timeQ3).
num_algos = numel(algo_name);
% x-coordinates of the band polygon: SNR axis forward, then backward.
band_x = [SNRl(:); flipud(SNRl(:))].';
for algo_index = 1:num_algos
    % Q1 is the lower edge of the band and Q3 the upper edge (these were
    % previously assigned the other way round).
    lower_bound = timeQ1(:, algo_index);
    upper_bound = timeQ3(:, algo_index);
    band_y = [lower_bound; flipud(upper_bound)].';
    fill(ax3, band_x, band_y, all_colors(algo_index, :), 'FaceAlpha', 0.15, 'EdgeColor', 'none', 'HandleVisibility', 'off');
    plot(ax3, SNRl, timeTotal(:, algo_index), line_type_set{algo_index}, 'Color', all_colors(algo_index, :));
end
hold(ax3, 'off');
% xlim(ax3, [5 30]);
% Same SNR tick marks as the estimation-error figure.
xticks(ax3, SNRl)
grid(ax3, 'on');
set(ax3, 'yscale', 'log');
xlabel(ax3, 'SNR (dB)');
ylabel(ax3, 'Time');
box(ax3, 'on');
% exportgraphics(f_time,'s_standard_time.pdf','ContentType', 'vector');

%% Mean runtime per algorithm and SNR point
% Average over the trial dimension (dim 2) of time(alg, trial, snr), then
% drop the singleton trial dimension.
timeAvg = mean(time, 2);
timeAvg = squeeze(timeAvg);