%MNIST
% Load the raw MNIST training and test sets.
% NOTE(review): the rest of the script treats each ROW of X / test_X as one
% image (N x pixels); confirm loadMNISTImages returns that orientation.
X = loadMNISTImages("train-images-idx3-ubyte");
Y = loadMNISTLabels("train-labels-idx1-ubyte");
test_X = loadMNISTImages("t10k-images-idx3-ubyte");
test_Y = loadMNISTLabels("t10k-labels-idx1-ubyte");

% Randomly reorder the training examples. No rng seed is set here, so the
% order depends on the current global random-number state.
N = size(X, 1);
index = randperm(N, N);
X = X(index, :);
% Apply the same permutation to the labels and shift them up by one so the
% classes are numbered from 1 (MATLAB-style indices).
Y = Y(index) + 1;
test_Y = 1 + test_Y;
K = 10;            % number of classes (labels were shifted to start at 1 above)
p = size(X, 2);    % number of raw pixel features (784 for MNIST)

% Per-pixel count of training images in which that pixel is exactly zero,
% stored as a p x 1 column (vectorised form of a per-column loop).
zero_ind = sum(X == 0, 1)';

% Drop near-constant pixels: those that are nonzero in fewer than
% min_nonzero training images. With N = 60000 this is exactly the rule
% "zero in more than 55000 images", but it is now expressed relative to N
% instead of silently assuming the training-set size. The same columns are
% removed from test_X so train and test features stay aligned.
min_nonzero = 5000;
drop_cols = zero_ind > N - min_nonzero;
X(:, drop_cols) = [];
test_X(:, drop_cols) = [];

p = size(X, 2);    % number of retained features

% Split the (already shuffled) training set: the first n_train rows are used
% for training, and every remaining row is used for validation. Using `end`
% rather than a literal 60000 keeps this correct if N changes.
n_train = 50000;
train_X = X(1:n_train, :);
train_Y = Y(1:n_train);
valid_X = X(n_train+1:end, :);
valid_Y = Y(n_train+1:end);

% Pack each split as a 1x2 cell {features, labels}; these cells are what is
% passed to data_with_Byzantine.
train = {train_X, train_Y};
valid = {valid_X, valid_Y};
test  = {test_X, test_Y};

%[train, valid, test] = generate_data_multi(10000, 500, 5);
% Wrap the three splits with data_with_Byzantine.
% NOTE(review): the meaning of the positional arguments (10, 10, 0) is not
% visible here — document them against data_with_Byzantine's definition.
data = data_with_Byzantine(train, valid, test, 10, 10, 0);
% Candidate tuning values: 10 evenly spaced points from 0.01 to 1.
values = linspace(0.01, 1, 10);
% Initial estimate from valid_initial, which returns theta_initial,
% error_initial and lambda_initial. Presumably it picks lambda from `values`
% using the validation split — TODO confirm; the role of 0.2 is not shown here.
[theta_initial, error_initial, lambda_initial] = valid_initial(data, K, values, 0.2);

% IndexM(g, :) = [g, g+p, ..., g+(K-2)*p]: the column-major linear indices of
% row g of a p-by-(K-1) array, i.e. the indexes of elements in each group.
% Built in one step with implicit expansion (p x 1 column + 1 x (K-1) row).
IndexM = (1:p)' + (0:K-2) * p;

% Sigma_10: block-diagonal matrix holding K-1 copies of Sigma_1 on the
% diagonal and zero blocks everywhere else.
% NOTE(review): Sigma_1 is not assigned anywhere above in this script; it
% must already exist in the workspace.
[m, n] = size(Sigma_1);
B = zeros(m, n);                 % zero off-diagonal block
P = repmat({B}, K-1, K-1);       % (K-1) x (K-1) grid of zero blocks
% Diagonal cells of a (K-1) x (K-1) grid are K apart in linear (column-major)
% order: 1, K+1, 2K+1, ...
P(1:K:end) = {Sigma_1};
Sigma_10 = cell2mat(P);

% NOTE(review): Sigma, Sigma_1 and delta_hat are not assigned in this script;
% they must already exist in the workspace.
delta = (Sigma_1 - Sigma) * theta_initial + delta_hat;
% Column-stacked copy of delta, so delta must have p*(K-1) elements.
% This has to come AFTER delta is assigned. Reading delta earlier would fail
% on a fresh workspace, or silently use a stale value from a previous run.
delta_0 = reshape(delta, p*(K-1), 1);
% Fixed penalty weight for the CVX fit below.
% NOTE(review): lambda_initial returned by valid_initial is not used here —
% confirm the hard-coded 0.9 is intended.
lambda = 0.9;
cvx_begin
    % w(:, i) becomes column i+1 of theta_choose (column 1 is fixed at zero).
    variable w(p, K-1)    
    % ws(i) = 0.5*w_i'*Sigma_1*w_i - delta_hat_i'*w_i for each column i.
    % CVX accepts this quadratic form only if Sigma_1 is positive semidefinite.
    % NOTE(review): the objective uses delta_hat, not the delta computed just
    % before this block — confirm which one is intended.
    expression ws(K-1)
    for i = 1:K-1
        ws(i) = 0.5 * w(:, i)' * Sigma_1 * w(:, i) - delta_hat(:, i)' * w(:, i);
    end
    % norms(w, 1, 2) is the L1 norm of each row of w, so this penalty equals
    % the plain entrywise L1 norm sum(abs(w(:))).
    % NOTE(review): if a group penalty over the rows of w (the groups listed
    % in IndexM) was intended, norms(w, 2, 2) gives the group-lasso penalty.
    minimize(sum(ws) + lambda * sum(norms(w, 1, 2)))
cvx_end
% After cvx_end, w holds the numerical solution; it overwrites the earlier
% theta_initial from valid_initial.
theta_initial = w;
%theta_update = reshape(w, p, K-1);
% Prepend an all-zero column for class 1 so theta_choose has K columns.
theta_choose = [zeros(p, 1) w];

% Score every test image against each of the K classes. Column 1 of
% theta_choose is the all-zero class-1 column.
% NOTE(review): assumes pi_hat is a 1 x K row so that log(pi_hat) broadcasts
% across the rows of the score matrix — confirm.
tpred_value = test_X * theta_choose + log(pi_hat);
% Predicted label = column of the largest score in each row. Taking the max
% along dim 2 avoids transposing the score matrix twice. Ties still resolve
% to the lowest class index.
[~, pred_Y] = max(tpred_value, [], 2);
% `index` previously held the training-set permutation. It is reassigned to
% the predictions, as before, for any downstream code that reads it.
index = pred_Y;
% Test misclassification rate.
error_cvx = 1 - mean(pred_Y == test_Y);