%-----------subfigure: error vs. dimension d---------------%
% For each dimension d in `dim`, run `num_runs` random trials: draw a
% ground-truth (W_star, b_star), generate num_samples samples of which each
% is poisoned with probability (1-p), run every estimator in `algo_names`
% through main(), and record normalized estimation errors. Raw errors are
% saved to FileError; figures are written to the Figures/ paths below.
clear all;
home;

% Output paths (extension is added by save/savefig/saveas).
FileError = 'Figures/dim_vs_Error';
FileErrorb = 'Figures/dim_vs_Errorb';
FileErrorW = 'Figures/dim_vs_ErrorW';
FileKLdiv = 'Figures/dim_vs_KLdiv';   % not referenced below
FileLeg = 'Figures/Legend';

% save/savefig/saveas fail when the output folder does not exist, so make
% sure it is there before the (long) experiment runs.
out_dir = fileparts(FileError);
if ~isempty(out_dir) && ~exist(out_dir, 'dir')
    mkdir(out_dir);
end

num_samples = 20000;   % samples per run (clean + poisoned)
p = 0.95;              % probability that a sample is drawn from the clean model
dim = [5,10,15,20,100,200,300,400,500];   % dimensions swept (alternatives: 100:100:500, 5:5:20)
kappa = 1;             % only stored in the results file; not otherwise used here
k = 5;                 % not referenced below
num_runs = 100;        % independent random trials per dimension
% Estimators passed to main(). The LAST TWO entries must be the oracle
% baselines: the main loop runs them on the clean samples only.
algo_names = {'SGD w/o Filter','SGD with Median','SGD with Trimmed Mean','GD w/o Filter','GD with Median','GD with Trimmed Mean','Oracle SGD','Oracle GD'};
% Plot marker for each algorithm, indexed by its position in algo_names.
marker_list = ['+','x', 's', '<','+','x', 's', '<', 'o', '*','d','p','o', '*'];%['+','x', 's', '<', 'o', '*'];

% Poisoned samples are max(0, N(mu, sigma^2)) entrywise.
mu=5;
sigma=1;

var_len = numel(dim);
algo_len = length(algo_names);
% Error arrays indexed as (dimension index, run, algorithm).
err_b = zeros(var_len, num_runs,algo_len);
err_b2 = zeros(var_len, num_runs,algo_len);
norm_b = zeros(var_len, num_runs,algo_len);     % not filled below
err_sigma = zeros(var_len, num_runs,algo_len);
sqrt_kl = zeros(var_len, num_runs,algo_len); % upper bound on TV distance (not filled below)

for i = 1:var_len
    d = dim(i);
    for j = 1:num_runs
        disp (['Varying dim.: d = ' num2str(d) ', iteration = ' num2str(j) '/' num2str(num_runs)] )

        % Ground truth: nonnegative bias b_star and a d-by-d weight matrix.
        % W_star is replaced by the left singular vectors of a Gaussian
        % matrix, i.e. an orthogonal matrix, so sigma_star = W_star*W_star'
        % is the identity up to rounding.
        b_star = max(0, randn(d,1));
        W_star = randn(d,d)/sqrt(d);
        [U,~,~] = svd(W_star);
        W_star = U(:,1:d);

        % Quantities that depend only on the ground truth: computed once per
        % run instead of once per algorithm.
        sigma_star = W_star*W_star';
        W_fro = norm(W_star,'fro');
        b_star_norm = norm(b_star);

        % Label each sample 1 (clean, prob. p) or 2 (poisoned, prob. 1-p).
        % NOTE: randsrc is a Communications Toolbox function and normrnd a
        % Statistics and Machine Learning Toolbox function.
        indices = randsrc(1,num_samples,[1 2; p (1-p)]);
        selected_samples = zeros(d,num_samples);

        % Clean samples: max(0, W_star*z + b_star) with z ~ N(0, I_d).
        sample_indx  = find(indices==1);
        Z = randn(d,length(sample_indx));
        good_samples =  max(0, W_star*Z + b_star);
        selected_samples(:,sample_indx) = good_samples;

        % Poisoned samples: entrywise max(0, N(mu, sigma^2)).
        poison_sample_indx  = find(indices==2);
        Z_poison = normrnd(mu,sigma,[d, length(poison_sample_indx)]);
        selected_samples(:,poison_sample_indx) = max(0, Z_poison);

        for iter_algo = 1:algo_len
            if iter_algo <= algo_len - 2
                % Non-oracle estimators see the contaminated data set.
                [sigma_hat, b_hat] = main(selected_samples,algo_names{iter_algo});
            elseif iter_algo == algo_len - 1
                % Oracle baselines see only the clean samples.
                [sigma_hat, b_hat] = main(good_samples,"Oracle SGD");
            else
                [sigma_hat, b_hat] = main(good_samples,"Oracle GD");
            end

            % Errors normalized by ||W_star||_F^2 (sigma) and ||W_star||_F
            % (bias); err_b2 is the bias error relative to ||b_star||, falling
            % back to the absolute error when b_star is all zeros.
            b_err = norm(b_hat - b_star);
            err_sigma(i, j, iter_algo) = norm(sigma_hat - sigma_star, 'fro')/(W_fro^2);
            err_b(i, j, iter_algo) = b_err/W_fro;
            if b_star_norm ~= 0
                err_b2(i, j, iter_algo) = b_err/b_star_norm;
            else
                err_b2(i, j, iter_algo) = b_err;
            end
        end

    end
end

% Aggregate over runs (dimension 2); each result is var_len x 1 x algo_len.
err_b_mean = mean(err_b, 2);
err_b_mean2 = mean(err_b2, 2);      % relative bias error (not plotted in this script)
err_b_std = std(err_b, 0, 2);       % not plotted in this script
err_sigma_mean = mean(err_sigma, 2);
err_sigma_std = std(err_sigma, 0, 2);   % not plotted in this script

% Save raw per-run errors together with the settings needed to interpret
% them: algo_names labels the third dimension of err_*, and num_runs/mu/sigma
% describe the trial count and poisoning distribution.
save(FileError,'dim', 'err_b', 'err_b2','err_sigma', 'num_samples','p','kappa', ...
    'algo_names','num_runs','mu','sigma')

legend_cell = algo_names;   % labels reused by the stand-alone legend figure


% Figure: mean bias-vector error vs. dimension, one curve per algorithm.
figure;
figprop;
for a = 1:algo_len
    bias_curve = err_b_mean(:, :, a);
    plot(dim, bias_curve, marker_list(a), 'LineStyle', '-');
    if a == 1
        % Keep later curves on the same axes (first plot replaces as usual).
        hold on;
    end
end
hold off
xlabel('Dimension (d)');
ylabel('Error in bias vector');
ylim([0 0.75]);
axis square;
savefig(FileErrorb);
saveas(gca, FileErrorb, 'epsc');



% Figure: mean weight-matrix (sigma) error vs. dimension, one curve per algorithm.
figure;
figprop;
for a = 1:algo_len
    w_curve = err_sigma_mean(:, :, a);
    plot(dim, w_curve, marker_list(a), 'LineStyle', '-');
    if a == 1
        % Keep later curves on the same axes (first plot replaces as usual).
        hold on;
    end
end
hold off
xlabel('Dimension (d)');
ylabel('Error in weight matrix');
ylim([0 16]);
axis square;
savefig(FileErrorW);
saveas(gca, FileErrorW, 'epsc');


% Stand-alone legend figure: plot one all-NaN (hence invisible) line per
% algorithm, using the same markers and plotting order as the figures
% above, so only the legend entries are visible; then shrink the figure
% to the size of the legend and save it.
figure;
hold on
x = 1:10;
% y .* x is all NaN, so the lines below draw nothing.
y = NaN;

for iter_algo = 1: algo_len
    plot(x, y .* x, [marker_list(iter_algo)],'LineStyle','-','DisplayName', legend_cell{iter_algo});
end
% Start from a large figure placed at the origin; presumably this gives
% the horizontal legend room to lay out before measuring it.
set(gcf,'Position',[0,0,1024,1024]);
legend_handle = legend('Orientation','horizontal');
% Resize the figure to the legend's extent. Assumes the legend Position is
% in normalized figure units and the figure Position in pixels (MATLAB
% defaults) -- then scaling the legend's width/height by the figure's
% width/height gives the legend size in pixels; the [0,0,1,1] mask zeroes
% the left/bottom offset.
set(gcf,'Position',(get(legend_handle,'Position')...
    .*[0, 0, 1, 1].*get(gcf,'Position')));
% Make the legend fill the resized figure.
set(legend_handle,'Position',[0,0,1,1]);
% Shift the figure 500 px right and 400 px up; size unchanged.
set(gcf, 'Position', get(gcf,'Position') + [500, 400, 0, 0]);
savefig(FileLeg); saveas(gca, FileLeg, 'epsc');
