%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%% Clear previous work
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
clear; close all; clc;

% Put the local helper folder and the Tensorlab folder on the search path.
addpath('utils');
addpath('../tensorlab/');

% Seed the global random stream so that repeated runs draw the same numbers.
rng(pi);

%% Simulation parameters
M     = 512;                                   % size argument of generate_steering; length of the angle grid
Ms    = ceil(0.3 * M);                         % how many of the M indices are drawn for the sampling set
N     = 1024;                                  % size argument of generate_delay; length of the delay grid
L     = [5, 10, 15, 20, 25, 35, 50, 75, 100];  % swept values: number of angles per trial
Kl    = 3;                                     % number of delays generated for each angle
ratio = 0.1;                                   % fraction of the N indices that is sampled (Ns = ceil(ratio*N))
D     = 512;                                   % number of leading delay-grid points used in the delay dictionary
Ts    = 1.024;                                 % the delay grid spans [0, Ts)
AVG   = 1;                                     % number of independent repetitions of the whole sweep

SNR   = 20;                                    % signal-to-noise ratio in dB (used as 10^(SNR/10))
%%
% Result storage: (sweep point) x (algorithm) x (repetition).
nse  = zeros(numel(L), 4, AVG);
time = zeros(numel(L), 4, AVG);

% Uniform grids; both exclude the right end point of their interval.
angle_grid = linspace(0, 1 - 1/M, M);
delay_grid = linspace(0, Ts - Ts/N, N);

% Dictionaries evaluated on the grids (only the first D delay points are used).
Ha = generate_steering(M, angle_grid);
Hd = generate_delay(N, delay_grid(1:D), Ts);
for avg = 1:AVG
    % Labelled progress message instead of a bare echo of the loop counter.
    fprintf('Run %d of %d\n', avg, AVG);
    for ll = 1:length(L)
        %% generate channel
        % amplitude: unit-variance circular complex Gaussian, one per (delay, angle) pair
        rho =  reshape((1/sqrt(2)*(randn(Kl*L(ll),1) + 1i * randn(Kl*L(ll),1))),[Kl, L(ll)]);

        % generate L angles and Kl*L delays: on grid assumption
        % (angles from the full angle grid, delays from its first D points)
        angle_idx = randsample(M, L(ll)); angles = angle_grid(angle_idx);
        delay_idx = randsample(D, Kl*L(ll)); delays = reshape(delay_grid(delay_idx),[Kl, L(ll)]);

        % Random +/-1 sign mask for the delays.
        % NOTE(review): rand returns values in (0,1), so the comparison with
        % -0.1 is always true. With this threshold every entry is 1, the
        % repair of all-zero columns never triggers and no delay is negated.
        % The code is kept as is because it consumes random numbers and the
        % threshold looks like an experiment switch - confirm the intent.
        binaryMatrix = 1*(rand(Kl, L(ll)) > -0.1);
        allZeroCols = all(binaryMatrix == 0, 1);

        numBadCols = sum(allZeroCols);

        if numBadCols > 0
            % force one randomly chosen entry to 1 in every all-zero column
            colIndices = find(allZeroCols);
            rowIndices = randi(Kl, [1, numBadCols]);
            indicesToFlip = sub2ind([Kl, L(ll)], rowIndices, colIndices);
            binaryMatrix(indicesToFlip) = 1;
        end

        binaryMatrix(binaryMatrix == 0) = -1;

        delays = delays.*binaryMatrix;

        % channel (N x M) = sum over all (delay, angle) pairs of
        % amplitude * delay vector * (steering vector)^H
        channel = zeros(N,M);
        for l = 1:L(ll)
            for kl = 1:Kl
                channel = channel + rho(kl,l) * generate_delay(N,delays(kl,l), Ts)*generate_steering(M,angles(l))';
            end
        end

        % unit-modulus sequence with uniformly random phases, length N
        c_base = exp(1i * 2 * pi * rand(N, 1));
        %% generate received signal
        % 1. generate the pilot and the sampling pattern
        Ns = ceil(ratio*N);
        % generate random selection: instead of generating a binary matrix, we can
        % randomly draw a set of row indices
        Sa = sort(randsample(M, Ms));
        Sd = sort(randsample(N, Ns));
        % entries of the sequence at the sampled indices; the former
        % multiplication by eye(N,N) was an identity operation and is dropped
        c = c_base(Sd,:);

        % effective dictionaries restricted to the sampled rows
        Ha_effect = conj(Ha(Sa,:));
        Hd_effect = diag(c)*Hd(Sd,:);
        A = cell(2,1);A{1} = Hd_effect; A{2} = Ha_effect;

        % 2. generate received signal
        Y_clean = diag(c) * channel(Sd,Sa);

        % noise variance chosen so that (mean signal power)/(noise power) = 10^(SNR/10)
        noise_var = (norm(Y_clean,'fro')^2/numel(Y_clean))/(10^(SNR/10));
        noise = sqrt(noise_var)/sqrt(2) * (randn(size(Y_clean)) + 1i* randn(size(Y_clean)));
        Y = Y_clean + noise;
        %% channel estimation
        % different algorithms
        % HiHTP
        params.M = M; params.N = N; params.D = D; params.Ms = Ms; params.Ns = Ns;
        params.row_indices_M = Sa; params.row_indices_N = Sd; params.maxIt = 200;
        params.c_vec = c;

        result_hihtp = HiHTP(vec(Y), A, L(ll), Kl, params);
        disp('HiHTP finished')
        % HiIHT
        result_hiiht = HiIHT(vec(Y), L(ll), Kl, params);
        disp('HiIHT finished')
        % MSHTP-Seq
        result_mshtp = tenMulReHTP_hi(Y, A, [D,M], [Ns,Ms], [Kl,L(ll)]);
        disp('MSHTP finished')
        %% MSOMP
        % the last argument (0.1) is passed through to tenMulReOMP; its
        % meaning is not visible in this script
        result_msomp = tenMulReOMP(Y, A, [D,M], [Ns,Ms], 0.1);
        disp('MSOMP finished')
        %% evaluation
        % Each result cell holds the coefficient estimate in {1,1} and the
        % value recorded as run time in {2,1}. The full channel is rebuilt as
        % Hd * X * Ha' and compared with the true channel.
        channel_energy = norm(channel, 'fro')^2;   % shared normaliser, computed once

        channel_est_hihtp = Hd * reshape(result_hihtp{1,1},[D M]) * Ha';
        nse(ll,1,avg) = norm(channel - channel_est_hihtp, 'fro')^2/channel_energy;
        time(ll,1,avg) = result_hihtp{2,1};

        channel_est_hiiht = Hd * reshape(result_hiiht{1,1},[D M]) * Ha';
        nse(ll,2,avg) = norm(channel - channel_est_hiiht, 'fro')^2/channel_energy;
        time(ll,2,avg) = result_hiiht{2,1};

        channel_est_mshtp = Hd * result_mshtp{1,1} * Ha';
        nse(ll,3,avg) = norm(channel - channel_est_mshtp, 'fro')^2/channel_energy;
        time(ll,3,avg) = result_mshtp{2,1};

        channel_est_msomp = Hd * result_msomp{1,1} * Ha';
        nse(ll,4,avg) = norm(channel - channel_est_msomp, 'fro')^2/channel_energy;
        time(ll,4,avg) = result_msomp{2,1};

    end
end
%%
% Persist the raw results so the plotting sections can be re-run on their own.
save('result.mat', 'nse', 'time');


%%
load('result.mat');

% One RGB row per curve; only the first four rows are used by the plots below.
all_colors = [ ...
    0,      70/255, 222/255; ...
    0.6350, 0.0780, 0.1840;  ...
    0,      0.4470, 0.7410;  ...
    0.9290, 0.6940, 0.1250;  ...
    1,      0,      1;       ...
    0.4900, 0.1800, 0.5600;  ...
    0.5,    0.5,    0.5;     ...
    0.8500, 0.3250, 0.0980   ...
    ];

% Line/marker specification and display name of each algorithm, in the same
% order as the second dimension of nse and time.
line_type_set = {'-*', '-p', '-o', '-x'};
algo_name     = {'HiHTP', 'HiIHT', 'MSHTP', 'MSOMP'};


% Median and quartiles over the repetitions (third dimension).
errorTotal = squeeze(median(nse, 3));
errorQ1    = squeeze(quantile(nse, 0.25, 3));
errorQ3    = squeeze(quantile(nse, 0.75, 3));

timeTotal = squeeze(median(time, 3));
timeQ1    = squeeze(quantile(time, 0.25, 3));
timeQ3    = squeeze(quantile(time, 0.75, 3));

% Figure geometry and session-wide graphics defaults (root object 0).
fontsizeman     = 20;
figure_position = [100, 100, 800, 400];
set(0, ...
    'DefaultLineLineWidth',  3, ...
    'DefaultLineMarkerSize', fontsizeman, ...
    'DefaultAxesFontSize',   fontsizeman, ...
    'DefaultAxesFontWeight', 'bold');

% Figure 1: median NSE versus L with a shaded first-to-third quartile band.
f_nmse = figure('Position', figure_position);
ax1 = gca;
hold(ax1, 'on');

% Abscissa of the band polygon: left-to-right, then back right-to-left.
band_x = [L, fliplr(L)];
for algo_index = 1:4
    curve_color = all_colors(algo_index, :);
    q1_curve = errorQ1(:, algo_index);
    q3_curve = errorQ3(:, algo_index);
    % lower edge forward, upper edge backward -> closed polygon
    band_y = [q1_curve.', flipud(q3_curve).'];
    fill(ax1, band_x, band_y, curve_color, ...
        'FaceAlpha', 0.15, 'EdgeColor', 'none', 'HandleVisibility', 'off');
    plot(ax1, L, errorTotal(:, algo_index), line_type_set{algo_index}, ...
        'Color', curve_color);
end

% Invisible marker-only proxies (NaN data) that serve as the legend entries.
legend_proxies = gobjects(1, 4);
for algo_index = 1:4
    legend_proxies(algo_index) = plot(ax1, NaN, NaN, line_type_set{algo_index}, ...
        'Color', all_colors(algo_index, :), ...
        'LineStyle', 'none', ...
        'DisplayName', algo_name{algo_index});
end
legend(legend_proxies, 'Location', 'southeast', 'Interpreter', 'latex', 'Box', 'off');
hold(ax1, 'off');


% Axis cosmetics: one tick per swept value, logarithmic error axis.
xticks(ax1, L)
set(ax1, 'YScale', 'log');
ylabel(ax1, 'Channel Estimation NSE');
xlabel(ax1, 'The number of angles');

grid(ax1, 'on');
box(ax1, 'on');
exportgraphics(f_nmse, 's_channel_estimation.pdf', 'ContentType', 'vector');

% Figure 2: median recorded run time versus L with a shaded
% first-to-third quartile band for each algorithm.
f_time = figure('Position', figure_position);
ax3 = gca;
hold(ax3, 'on');
for algo_index = 1:4
    % Q1 is the lower edge and Q3 the upper edge of the band
    % (the two were assigned the wrong way round before).
    lower_bound = timeQ1(:, algo_index);
    upper_bound = timeQ3(:, algo_index);

    % closed polygon: lower edge left-to-right, upper edge right-to-left
    fill_x = [L'; fliplr(L)']';
    fill_y = [lower_bound; fliplr(upper_bound')']';
    fill(ax3, fill_x, fill_y, all_colors(algo_index, :), 'FaceAlpha', 0.15, 'EdgeColor', 'none');
    plot(ax3, L,timeTotal(:, algo_index),line_type_set{algo_index},'Color',all_colors(algo_index, :));
end
hold(ax3, 'off');
% xlim(ax3, [5 30]);
grid(ax3, 'on');
set(ax3, 'yscale', 'log');
% The abscissa is L (the swept number of angles), not the SNR, which is
% fixed in this script; use the same label as the NSE figure.
xlabel(ax3, 'The number of angles');
ylabel(ax3, 'Time');
box(ax3, 'on');
% exportgraphics(f_time,'s_standard_time.pdf','ContentType', 'vector');

%%
% Mean recorded run time over the repetitions, left in the workspace for inspection.
timeAvg = squeeze(mean(time, 3));