%MNIST
% Read the MNIST training and test splits from the IDX files in the
% current folder.
X      = loadMNISTImages("train-images-idx3-ubyte");
Y      = loadMNISTLabels("train-labels-idx1-ubyte");
test_X = loadMNISTImages("t10k-images-idx3-ubyte");
test_Y = loadMNISTLabels("t10k-labels-idx1-ubyte");

% Shift every label up by one so classes are numbered from 1 (K = 10 below).
Y      = Y + 1;
test_Y = test_Y + 1;

% Randomly permute the training samples, keeping X rows and Y entries aligned.
% NOTE(review): size(X, 1) is used as the sample count, so this assumes
% loadMNISTImages returns one sample per row -- confirm.
N     = size(X, 1);
index = randperm(N, N);
X     = X(index, :);
Y     = Y(index);

K = 10;   % number of classes
p = 784;  % feature dimension


% Store each split as a 1x2 cell {features, labels}, the form passed to
% data_with_Byzantine below.
train(1, 1:2) = {X, Y};
test(1, 1:2)  = {test_X, test_Y};

% MNIST experiment with 50 as the third argument of data_with_Byzantine and
% 0, 5, 10 as the fourth (presumably the number of Byzantine machines --
% TODO confirm against data_with_Byzantine).
data_0 = data_with_Byzantine(train, test, 50, 0);
data_1 = data_with_Byzantine(train, test, 50, 5);
data_2 = data_with_Byzantine(train, test, 50, 10);
% Setting with 0: model_run returns two error sequences.
[nb_error, b_error_0] = model_run(data_0, 20);
% Settings with 5 and 10 are evaluated with median_run only.
b_error_1 = median_run(data_1, 20);
b_error_2 = median_run(data_2, 20);
% save with only a file name writes every workspace variable (including the
% training/test data and the three data sets above) to resu.mat.
save('resu.mat');


% Single run with no Byzantine machines (fourth argument 0); plot both
% model_run error sequences against iterations 0..20.
% horzcat assumes model_run returns column vectors of equal length.
data = data_with_Byzantine(train, test, 50, 0);
[error_3, error_4] = model_run(data, 20);
x = 0:20;
sys_1 = horzcat(error_3, error_4);
plot(x, sys_1);

% Repeated MNIST experiment with 20 machines, comparing 2 Byzantine machines
% (alpha = 0.1, i.e. 2 of 20) against none (alpha = 0), then plotting the
% Mean-/Median-DSLDA test-error curves.
%
% Both settings are run BEFORE the figure is drawn and kept in separate
% matrices. Previously, every plotted series came from the alpha = 0.1 runs,
% so the "alpha = 0" and "alpha = 0.1" median curves were the same data. The
% alpha = 0 runs also executed after plotting and overwrote the alpha = 0.1
% results.
T = 20;           % iterations passed to model_run
n_rep = 10;       % repetitions per setting
n_machines = 20;  % third argument of data_with_Byzantine

% alpha = 0.1: fourth argument 2. Columns of the matrices index repetitions.
% Reset first so results left over from an earlier run of this script
% cannot leak in or cause a size mismatch.
error_iter = [];
error_iter_1 = [];
for i = 1:n_rep
    % Reshuffle the training set each repetition (presumably so
    % data_with_Byzantine partitions it differently).
    index = randperm(N, N);
    X = X(index, :);
    Y = Y(index);
    train{1,1} = X;
    train{1,2} = Y;
    data = data_with_Byzantine(train, test, n_machines, 2);
    [error_iter(:, i), error_iter_1(:, i)] = model_run(data, T);
end

% alpha = 0: no Byzantine machines, stored separately from the runs above.
error_iter_0 = [];
error_iter_0_1 = [];
for i = 1:n_rep
    index = randperm(N, N);
    X = X(index, :);
    Y = Y(index);
    train{1,1} = X;
    train{1,2} = Y;
    data = data_with_Byzantine(train, test, n_machines, 0);
    [error_iter_0(:, i), error_iter_0_1(:, i)] = model_run(data, T);
end

% Curves for the figure come from one repetition (column rep), as before.
% Following the original naming, the first model_run output is used as the
% Mean-DSLDA curve and the second as the Median-DSLDA curve.
rep = 1;
mean_error     = error_iter_0(:, rep);    % Mean-DSLDA,   alpha = 0
median_error   = error_iter_0_1(:, rep);  % Median-DSLDA, alpha = 0
mean_error_1   = error_iter(:, rep);      % Mean-DSLDA,   alpha = 0.1
median_error_1 = error_iter_1(:, rep);    % Median-DSLDA, alpha = 0.1

X1 = 0:T;
YMatrix1 = [mean_error median_error median_error_1];

% Create the figure.
figure1 = figure;

% Create the axes.
axes1 = axes('Parent',figure1);
hold(axes1,'on');

% Matrix input to plot draws one line per column of YMatrix1.
plot1 = plot(X1,YMatrix1,'MarkerSize',8,'LineWidth',2,'Parent',axes1);
set(plot1(1),'DisplayName','Mean-DSLDA,    \alpha = 0','Marker','x',...
    'Color',[0.466666666666667 0.674509803921569 0.188235294117647]);
set(plot1(2),'DisplayName','Median-DSLDA, \alpha = 0','Marker','^',...
    'Color',[0.635294117647059 0.07843137254902 0.184313725490196]);
set(plot1(3),'DisplayName','Median-DSLDA, \alpha = 0.1','Marker','+',...
    'Color',[0 0 1]);

% Create the y-axis label.
ylabel({'Test Error'},'FontSize',15);
% Create the x-axis label.
xlabel({'Iteration'},'FontSize',15);

box(axes1,'on');
% Create the legend.
legend1 = legend(axes1,'show');
set(legend1,...
    'Position',[0.573660714285712 0.785119188036872 0.304464285714286 0.11547619047619],...
    'FontSize',12);




% ISOLET: read the train/test CSV files. Columns 1..617 become the features
% and column 618 the label. importdata returns a struct with a .data field
% when the file has a text header, and a plain numeric matrix otherwise, so
% both forms are handled. The result is assigned explicitly instead of
% relying on the implicit `ans` variable; the semicolons also stop the whole
% import from being printed.
isolet_raw = importdata('isolet.csv');
if isstruct(isolet_raw)
    data = isolet_raw.data;
else
    data = isolet_raw;
end
X = data(:, 1:617);
Y = data(:, 618);

isolet_raw = importdata('isolet_test.csv');
if isstruct(isolet_raw)
    test_data = isolet_raw.data;
else
    test_data = isolet_raw;
end
test_X = test_data(:, 1:617);
test_Y = test_data(:, 618);

% Run model_run on ISOLET. The arguments are 10 (third argument of
% data_with_Byzantine), 0 Byzantine machines, 26 (presumably the number of
% ISOLET classes -- TODO confirm), and a grid of 10 values in [0.01, 1]
% whose meaning is not visible here.
values = linspace(0.01, 1, 10);

% Shuffle the training rows, keeping features and labels aligned.
N = size(X, 1);
index = randperm(N, N);
X = X(index, :);
Y = Y(index);

% Repack both splits as {features, labels}.
train(1, 1:2) = {X, Y};
test(1, 1:2)  = {test_X, test_Y};

data = data_with_Byzantine(train, test, 10, 0);
[nb_error, nb_merror] = model_run(data, 20, 26, values);