%% MNIST
% Load MNIST, shuffle the training examples, shift the labels up by one
% (presumably 0..9 -> 1..10, matching K = 10), and run the distributed
% experiments whose error curves are plotted afterwards.
X      = loadMNISTImages("train-images-idx3-ubyte");
Y      = loadMNISTLabels("train-labels-idx1-ubyte");
test_X = loadMNISTImages("t10k-images-idx3-ubyte");
test_Y = loadMNISTLabels("t10k-labels-idx1-ubyte");

% Random permutation of the training rows, applied jointly to features and
% labels. NOTE(review): this assumes loadMNISTImages returns one example per
% row; some versions return one image per column (784 x N) -- confirm.
N      = size(X, 1);
index  = randperm(N, N);
X      = X(index, :);
Y      = Y(index) + 1;
test_Y = test_Y + 1;

K = 10;    % class count passed to every run below
p = 784;   % presumably the feature dimension (28 x 28); not read elsewhere in the visible script

% Each split is a 1x2 cell: {features, labels}.
train = {X, Y};
test  = {test_X, test_Y};

T      = 10;                          % passed to model_run / median_run (presumably the iteration count)
values = linspace(0.005, 0.1, 5);     % candidate values handed to each run -- presumably a tuning grid

% Keep the call order below unchanged: these helpers may consume the RNG.
data_0 = data_with_Byzantine(train, test, 20, 0, K);   % no Byzantine machines (alpha = 0)
data_1 = data_with_Byzantine(train, test, 20, 2, K);   % presumably 2 of 20 Byzantine (alpha = 0.1)
[error_iter0, r_error_iter0] = model_run(data_0, T, K, values);
r_error_iter1 = median_run(data_1, T, K, values);
cen_error     = msda_cen(train, test, K, values);


% Plot test error versus iteration for the three distributed runs, plus a
% horizontal reference line for the centralized C-MSDA baseline.
X1 = 0:10;
n_pts = numel(X1);

% Take the first n_pts entries of each error trajectory and force each into
% a column. Linear indexing of a matrix (or a row vector) returns a row, in
% which case plain horizontal concatenation would produce one 1x(3*n_pts)
% row and plot() would draw a single series instead of three.
YMatrix1 = [reshape(error_iter0(1:n_pts), [], 1), ...
            reshape(r_error_iter0(1:n_pts), [], 1), ...
            reshape(r_error_iter1(1:n_pts), [], 1)];

figure1 = figure;
axes1 = axes('Parent',figure1);
hold(axes1,'on');
plot1 = plot(X1,YMatrix1,'MarkerSize',8,'LineWidth',2,'Parent',axes1);
set(plot1(1),'DisplayName','Mean-DSLDA,    \alpha = 0','Marker','x',...
    'Color',[0.466666666666667 0.674509803921569 0.188235294117647]);
set(plot1(2),'DisplayName','Median-DSLDA, \alpha = 0','Marker','^',...
    'Color',[0.635294117647059 0.07843137254902 0.184313725490196]);
set(plot1(3),'DisplayName','Median-DSLDA, \alpha = 0.1','Marker','+',...
    'Color',[0 0 1]);

% Centralized baseline spans the same x-range as the iteration curves.
% NOTE(review): it shares the Mean-DSLDA colour and differs only by the
% absence of a marker.
line(X1([1 end]), [cen_error, cen_error],'DisplayName','C-MSDA','Parent',axes1,'LineWidth',2,...
    'LineStyle','-',...
    'Color',[0.466666666666667 0.674509803921569 0.188235294117647]);

ylabel({'Test Error'},'FontSize',20);
xlabel({'Iteration'},'FontSize',20);
box(axes1,'on');
legend1 = legend(axes1,'show');
set(legend1,...
    'Position',[0.573660714285712 0.785119188036872 0.304464285714286 0.11547619047619],...
    'FontSize',20);

% Repeat the Byzantine-free (alpha = 0) MNIST experiment on independent
% reshuffles of the training set. Each run stores one column:
%   error_iter   - first output of model_run (plotted as Mean-DSLDA)
%   error_iter_1 - second output of model_run (plotted as Median-DSLDA)
% The arguments mirror the main experiment's calls, including K and values.
n_runs = 10;
error_iter = [];     % reset so stale workspace values cannot leak in
error_iter_1 = [];
for i = 1:n_runs
    index = randperm(N, N);
    X = X(index, :);
    Y = Y(index);
    train{1,1} = X;
    train{1,2} = Y;
    data = data_with_Byzantine(train, test, 20, 0, K);
    [error_iter(:, i), error_iter_1(:, i)] = model_run(data, T, K, values);
end

% Summaries can only be read once the loop has filled the accumulators.
% Column n_runs is the final run's error trajectory.
% NOTE(review): if the intent is the error at a fixed iteration across all
% runs, index a row instead (e.g. error_iter(end, :)) -- confirm.
mean_error_1 = error_iter(:, n_runs);
median_error_1 = error_iter_1(:, n_runs);
median_error = error_iter_1(:, n_runs);   % NOTE(review): duplicate of median_error_1


%% ISOLET
% Load the ISOLET train/test CSVs; columns 1:617 are used as features and
% column 618 as the class label. importdata returns a struct with a .data
% field when the file contains header text, and a plain numeric matrix
% otherwise, so unwrap only when needed. Assigning the result explicitly
% avoids depending on the implicit `ans` variable. It also avoids echoing the
% whole dataset to the console.
data = importdata('isolet.csv');
if isstruct(data)
    data = data.data;
end
X = data(:, 1:617);
Y = data(:, 618);

test_data = importdata('isolet_test.csv');
if isstruct(test_data)
    test_data = test_data.data;
end
test_X = test_data(:, 1:617);
test_Y = test_data(:, 618);

% Shuffle the training rows, applying the same permutation to labels.
N = size(X, 1);
index = randperm(N, N);
X = X(index, :);
Y = Y(index);

values = linspace(0.01, 0.1, 5);   % candidate values handed to each run
K = 26;                            % class count passed to every run below

% Each split is a 1x2 cell: {features, labels}.
train{1,1} = X;
train{1,2} = Y;
test{1,1} = test_X;
test{1,2} = test_Y;

% Presumably 10 machines with 0 or 1 Byzantine ones (alpha = 0 / 0.1);
% the second argument of the run functions matches T's role above.
data = data_with_Byzantine(train, test, 10, 0, K);
[error_1, error_2] = model_run(data, 20, K, values);
data_1 = data_with_Byzantine(train, test, 10, 1, K);
error_3 = median_run(data_1, 20, K, values);
cen_error = msda_cen(train, test, K, values);
