% True parameters of the simulation.
%   sigma      - p-by-p covariance with entries 0.5^|row - col|
%   beta       - p-by-K coefficients; column k is 1.6 at rows 2k-1 and 2k, zero elsewhere
%   mu         - p-by-K, mu(:, k) = sigma * beta(:, k)
%   theta_true - p-by-(K-1), column j = beta(:, j+1) - beta(:, 1)
p = 600;
K = 5;

% Entry (r, c) is 0.5^|r - c|; built in one expression via implicit expansion.
sigma = 0.5 .^ abs((1:p)' - (1:p));

beta = zeros(p, K);
mu = zeros(p, K);
for k = 1:K
    beta([2*k - 1, 2*k], k) = 1.6;      % two non-zero coordinates per class
    mu(:, k) = sigma * beta(:, k);
end

% Differences of every class coefficient vector against class 1.
theta_true = beta(:, 2:K) - beta(:, 1);

% Number of Monte Carlo replicates.
n = 200;


%% the effect of alpha
% For each replicate, draw one train/valid/test split. Build five
% contaminated copies of it with data_with_Byzantine (last argument 0, 5, 10,
% 15, 20 -- presumably the number of Byzantine machines; confirm in
% data_with_Byzantine). Run median_valid_msda on each copy.
% NOTE(review): the literal 5 passed to generate_data_multi and
% data_with_Byzantine may be meant to equal K -- confirm before changing K.
values = linspace(0.5, 2, 10);   % 10 evenly spaced values in [0.5, 2]; presumably the alpha grid -- confirm in median_valid_msda
T = 5;                           % passed through to median_valid_msda; its meaning is defined there
tic;
for i = 1:n
    % All five data sets in this replicate are built from the same split.
    [train, valid, test] = generate_data_multi(20000, p, 5);
    data0 = data_with_Byzantine(train, valid, test, 5, 100, 0);
    data1 = data_with_Byzantine(train, valid, test, 5, 100, 5);
    data2 = data_with_Byzantine(train, valid, test, 5, 100, 10);
    data3 = data_with_Byzantine(train, valid, test, 5, 100, 15);
    data4 = data_with_Byzantine(train, valid, test, 5, 100, 20);

    % Column i of each md_* array stores this replicate's results. The
    % arrays are not preallocated, so they grow on every iteration.
    [md_error_iter_0(:, i), md_ell2_0(:, i), md_F_score_0(:, i)] = median_valid_msda(data0, K, T, values, theta_true);
    [md_error_iter_1(:, i), md_ell2_1(:, i), md_F_score_1(:, i)] = median_valid_msda(data1, K, T, values, theta_true);
    [md_error_iter_2(:, i), md_ell2_2(:, i), md_F_score_2(:, i)] = median_valid_msda(data2, K, T, values, theta_true);
    [md_error_iter_3(:, i), md_ell2_3(:, i), md_F_score_3(:, i)] = median_valid_msda(data3, K, T, values, theta_true);
    [md_error_iter_4(:, i), md_ell2_4(:, i), md_F_score_4(:, i)] = median_valid_msda(data4, K, T, values, theta_true);
    i   % echo the replicate index as a progress indicator
end
toc;   % report elapsed time since tic


% Mean and standard deviation, along dimension 1, of the c_* results.
% NOTE(review): c_error, c_ell_2 and c_F_score are not created anywhere in
% this script -- presumably left over from another experiment or expected to
% be loaded beforehand. Without this guard, the first reference throws
% "Undefined variable" and the script stops before save() is reached, losing
% the Monte Carlo results.
if exist('c_error', 'var') && exist('c_ell_2', 'var') && exist('c_F_score', 'var')
    c_error_avg = mean(c_error, 1);
    c_ell2_avg = mean(c_ell_2, 1);
    c_F_score_avg = mean(c_F_score, 1);

    % std(X, 0, 1): sample SD (N-1 normalisation) along dimension 1, matching
    % mean(X, 1) above. In std(X, 1), the 1 is the normalisation weight
    % (divide by N), not the dimension.
    c_error_sd = std(c_error, 0, 1);
    c_ell2_sd = std(c_ell_2, 0, 1);
    c_F_score_sd = std(c_F_score, 0, 1);
else
    warning('c_error / c_ell_2 / c_F_score are not defined; skipping their summary statistics.');
end

% Summaries over the Monte Carlo replicates. Each md_* array holds one
% replicate per column (filled as md_*(:, i) in the loop), so statistics are
% taken along dimension 2. Suffixes _0 .. _4 correspond to the data sets built
% with data_with_Byzantine last argument 0, 5, 10, 15, 20. The combined
% matrices hold one setting per column, in that order (all five settings).

% --- Means along dimension 2 ---
md_error_iter_0_avg = mean(md_error_iter_0, 2);
md_ell2_0_avg = mean(md_ell2_0, 2);
md_F_score_0_avg = mean(md_F_score_0, 2);

md_error_iter_1_avg = mean(md_error_iter_1, 2);
md_ell2_1_avg = mean(md_ell2_1, 2);
md_F_score_1_avg = mean(md_F_score_1, 2);

md_error_iter_2_avg = mean(md_error_iter_2, 2);
md_ell2_2_avg = mean(md_ell2_2, 2);
md_F_score_2_avg = mean(md_F_score_2, 2);

md_error_iter_3_avg = mean(md_error_iter_3, 2);
md_ell2_3_avg = mean(md_ell2_3, 2);
md_F_score_3_avg = mean(md_F_score_3, 2);

% The setting-4 results are computed in the loop, so they are summarised too.
md_error_iter_4_avg = mean(md_error_iter_4, 2);
md_ell2_4_avg = mean(md_ell2_4, 2);
md_F_score_4_avg = mean(md_F_score_4, 2);

error_iter_avg = [md_error_iter_0_avg md_error_iter_1_avg md_error_iter_2_avg md_error_iter_3_avg md_error_iter_4_avg];
ell2_iter_avg = [md_ell2_0_avg md_ell2_1_avg md_ell2_2_avg md_ell2_3_avg md_ell2_4_avg];
Fscore_iter_avg = [md_F_score_0_avg md_F_score_1_avg md_F_score_2_avg md_F_score_3_avg md_F_score_4_avg];

% --- Sample standard deviations (N-1 normalisation) along dimension 2 ---
% std(X, 0, 2) states the dimension explicitly. Unlike std(X')', it also
% stays along dimension 2 when X has a single column.
md_error_iter_0_sd = std(md_error_iter_0, 0, 2);
md_ell2_0_sd = std(md_ell2_0, 0, 2);
md_F_score_0_sd = std(md_F_score_0, 0, 2);

md_error_iter_1_sd = std(md_error_iter_1, 0, 2);
md_ell2_1_sd = std(md_ell2_1, 0, 2);
md_F_score_1_sd = std(md_F_score_1, 0, 2);

md_error_iter_2_sd = std(md_error_iter_2, 0, 2);
md_ell2_2_sd = std(md_ell2_2, 0, 2);
md_F_score_2_sd = std(md_F_score_2, 0, 2);

md_error_iter_3_sd = std(md_error_iter_3, 0, 2);
md_ell2_3_sd = std(md_ell2_3, 0, 2);
md_F_score_3_sd = std(md_F_score_3, 0, 2);

md_error_iter_4_sd = std(md_error_iter_4, 0, 2);
md_ell2_4_sd = std(md_ell2_4, 0, 2);
md_F_score_4_sd = std(md_F_score_4, 0, 2);

error_iter_sd = [md_error_iter_0_sd md_error_iter_1_sd md_error_iter_2_sd md_error_iter_3_sd md_error_iter_4_sd];
ell2_iter_sd = [md_ell2_0_sd md_ell2_1_sd md_ell2_2_sd md_ell2_3_sd md_ell2_4_sd];
Fscore_iter_sd = [md_F_score_0_sd md_F_score_1_sd md_F_score_2_sd md_F_score_3_sd md_F_score_4_sd];

% Save every workspace variable (no variable list given) to alpha_effect.mat.
save('alpha_effect.mat');