% permuted mnist

function [] = mnist_gd_relu(i,seed2)
% MNIST_GD_RELU  Even/odd MNIST on permuted pixels: kernel predictions vs.
% a two-layer ReLU network trained by gradient descent.
%
%   mnist_gd_relu(i, seed2)
%
%   Inputs
%     i      index into Mall = [5,8,10:10:100]; selects the number M of
%            random gating units.
%     seed2  RNG seed for the initial weights of the ReLU network; it is
%            also embedded in the name of the output .mat file.
%
%   Requirements
%     - mnist.mat on the path, providing trainX, trainY, testX, testY.
%     - numeric_saddlepoint_nonlinear (defined elsewhere) on the path.
%     - normrnd and fsolve (toolbox functions).
%
%   What the function does
%     1. Standardises the data, picks the first P/2 even-digit and P/2
%        odd-digit training samples (N_s/2 + N_s/2 test samples) and
%        applies a random pixel permutation.
%     2. Builds random binary gates g = (v*x > th) and evaluates the
%        gated-kernel predictor using the matrix H returned by
%        numeric_saddlepoint_nonlinear.
%     3. Builds the ReLU (arccos) kernels, solves a scalar equation for h
%        with fsolve and evaluates the corresponding kernel predictor.
%     4. Trains a two-layer ReLU network with gradient descent on the
%        summed squared training error until it drops below 0.1 or
%        iteration0 steps have been taken.
%     5. Saves all results to
%        ../compare_mnist_relu_global_theory/compare_mnist_generalization_context_reluseed<seed2>.mat

% Load into a struct so that no variables appear implicitly in the workspace.
data = load('mnist.mat');

beta = 1e4;                 % 1/beta is the ridge added to the kernels
Mall = [5,8,10:10:100];
M = Mall(i);                % number of gating units
n = 1;                      % number of pixel permutations
sigma = 0.5;
N = 1000;                   % hidden-layer width
iteration0 = 3e6;           % maximum number of gradient-descent steps
eta = 0.002;                % learning rate
N_0 = 784;                  % input dimension
P = 3000;                   % training samples per permutation
N_s = 1000;                 % test samples per permutation
th = 0;                     % gating threshold

rng(2,'twister');
trainX = double(data.trainX);
testX = double(data.testX);
trainY = data.trainY;
testY = data.testY;

% After the transpose every column is standardised with its own mean and
% standard deviation (implicit expansion of the 1-by-#columns statistics).
trainX = (trainX'-mean(trainX'))./std(trainX');
testX = (testX'-mean(testX'))./std(testX');

% First P/2 (N_s/2) even-labelled samples followed by the first P/2
% (N_s/2) odd-labelled ones.
indtrain = find(mod(trainY,2)==0);
indtest = find(mod(testY,2)==0);
indtrain1 = find(mod(trainY,2)==1);
indtest1 = find(mod(testY,2)==1);
index = [indtrain(1:P/2),indtrain1(1:P/2)];
index_t = [indtest(1:N_s/2),indtest1(1:N_s/2)];

x_00 = trainX(:,index);
x_t0 = testX(:,index_t);

% Logical labels: true for even digits.
y = (mod(trainY(index),2)==0);
y_t = (mod(testY(index_t),2)==0);

% Apply the same random row (pixel) permutation to train and test inputs.
ind = zeros(n, N_0);
x_0 = zeros(N_0, n*P);
x_t = zeros(N_0, n*N_s);
for k = 1:n
    ind(k,:) = randperm(N_0);
    x_0(:,((k-1)*P+1):(k*P)) = x_00(ind(k,:),:);
    x_t(:,((k-1)*N_s+1):(k*N_s)) = x_t0(ind(k,:),:);
end

% Random binary gates from Gaussian projections of the inputs.
rng(1,'twister');
v = normrnd(0,1,M,N_0);
g = (v*x_0>th);
gp = (v*x_t>th);

% Disabled alternative: gates from soft k-means.
% iteration = 1000;
% beta0 = 0.1;
% [g,xc] = softkmeans(M, x_0, beta0, iteration, 1);
% g = g(1:M,:);
% %g = (g>1/N);
%
% for i0 = 1:M
%    d(i0,:) = sum((x_t - xc(:, i0)).^2)';
%    expd(i0,:) = exp(-beta0*d(i0,:));
% end
% gp = expd./sum(expd);
% gp = gp(1:M,:);

% Replicate the labels once per permutation and map {0,1} -> {-1,+1}.
y = kron(ones(1,n),double(y));
y_t = kron(ones(1,n),double(y_t));
y = 2*y-1;
y_t = 2*y_t-1;

% H multiplies the gates as g'*H*g below, so it has to be M-by-M.
[~,H] = numeric_saddlepoint_nonlinear(N_0,P*n,M,g,y,x_0,N,sigma,beta);

% ---- Gated-kernel predictor --------------------------------------------
K_1 = sigma^2/N_0*(x_0'*x_0);
K_1p = sigma^2/N_0*(x_t'*x_0);
k_1p = sigma^2/N_0*(x_t'*x_t);

ridge = 1/beta*eye(n*P,n*P);

K2 = (g'*H*g).*K_1 + ridge;
K2p = (gp'*H*g).*K_1p;
k2p = (gp'*H*gp).*k_1p;

% The pseudo-inverse of the (n*P)-by-(n*P) kernel is expensive; compute it once.
K2_inv = pinv(K2);
%k = 1/N_1*(y*pinv(K2)*y');
fp = K2p*K2_inv*y';
fvar_th = diag(k2p - K2p*K2_inv*K2p');
bias_th = sum((fp-y_t').^2)/N_s;

ge = sum((fp-y_t').^2)/N_s + mean(fvar_th);
% For y_t = +1 the summand is 1 - erfc(.)/2, for y_t = -1 it is erfc(.)/2.
% NOTE(review): presumably the expected sign-error rate for a Gaussian
% predictor with mean fp and variance fvar_th -- confirm.
gb_th = mean((y_t'+1)/2+(-y_t').*erfc(-fp./sqrt(fvar_th*2))/2);

% ---- ReLU (arccos) kernel predictor ------------------------------------
k0 = local_arccos_kernel(x_0, x_0, sigma, N_0);     % train / train
k = local_arccos_kernel(x_t, x_t, sigma, N_0);      % test  / test
k_p = local_arccos_kernel(x_t, x_0, sigma, N_0);    % test  / train

% b =  N_0/N*sigma^(-4)*y*pinv(k0)*y';
% h = (-(1-alpha)+sqrt((1-alpha)^2+4*b))/2/b;
% finite T version
h0 = normrnd(0,1,1,1);
fun = @(x) local_h_residual(x, k0, y, N, sigma, beta);
h = abs(fsolve(fun,h0));

Kr_inv = pinv(k0/h + ridge);
fp_relu = k_p/h*Kr_inv*y';
fvar_relu = diag(k/h - k_p/h*Kr_inv*k_p'/h);

ge_relu = sum((fp_relu-y_t').^2)/N_s + mean(fvar_relu);
gb_relu = mean((y_t'+1)/2+(-y_t').*erfc(-fp_relu./sqrt(fvar_relu*2))/2);

biasr_th = sum((fp_relu-y_t').^2)/N_s;
fvarr_th = mean(fvar_relu);

% ---- Gradient descent on the two-layer ReLU network --------------------
rng(seed2,'twister');
wr = normrnd(0, 1, N_0+1, N);
w_ini = wr(1:N_0, :);       % input-to-hidden weights, N_0-by-N
a = wr(N_0+1, :);           % hidden-to-output weights, 1-by-N

x_0T = x_0';                % loop-invariant transpose

for j0 = 1:iteration0

    % Disabled: forward pass of the gated linear network.
    %    w_1 = reshape(w(1:N_0, :), N_0, N);
    %    w_2 = reshape(w(N_0+1:end, :), M, N)';
    %    x_1 = 1/sqrt(N_0)*w_1'*x_0;
    %    x_1t = 1/sqrt(N_0)*w_1'*x_t;
    %    f = (1/sqrt(N*M)*diag(g'*w_2'*x_1))';
    %    f_t = (1/sqrt(N*M)*diag(gp'*w_2'*x_1t))';
    %    E(j0) = sum((y - f).^2);

    % Hidden pre-activations are needed several times; compute them once.
    pre = 1/sqrt(N_0)*w_ini'*x_0;
    act = pre > 0;
    pre_t = 1/sqrt(N_0)*w_ini'*x_t;

    fr = 1/sqrt(N)*a*(pre.*act);
    fr_t = 1/sqrt(N)*a*(pre_t.*(pre_t>0));
    resid = y - fr;
    Er = sum(resid.^2);

    % Disabled: updates of the gated linear network.
    %    w_2 = w_2 + (eta1*1/sqrt(N*M)*g*diag((y-f))*x_1')';
    %    w_1 = w_1 + (eta1/sqrt(N*M*N_0)*(w_2*g)*diag(y-f)*x_0')';

    a = a + eta*resid*(1/sqrt(N)*pre.*act)';
    % NOTE(review): the input-weight step uses the readout weights that
    % were just updated above, not the ones that produced fr -- confirm
    % that this sequential update is intended.
    % Row/column scaling via implicit expansion gives the same values as
    % multiplying by diag(a') and diag(y - fr) without building the dense
    % N-by-N and P-by-P diagonal matrices on every step.
    w_ini = w_ini + (((eta*1/sqrt(N*N_0)*a').*act)*(resid'.*x_0T))';

    %    bias = mean((f_t - y_t).^2);
    bias_r = mean((fr_t - y_t).^2);
    if mod(j0,1e3) == 1
        %    disp(['train',num2str(E(end))]);
        disp(['trainr',num2str(Er)]);
    end

    % Er is the training error measured before this step's update.
    if (Er<0.1)
        break
    end

end

% Create the output directory if necessary so that a long run is not lost
% at the very last statement.
out_dir = '../compare_mnist_relu_global_theory';
if ~exist(out_dir,'dir')
    mkdir(out_dir);
end
% The file name is tagged with the seed argument of this function.
save(['../compare_mnist_relu_global_theory/compare_mnist_generalization_context_relu','seed',num2str(seed2),'.mat'],'bias_r','fr_t','biasr_th','fvarr_th','j0','gb_th','gb_relu','ge_relu','ge','bias_th','fp_relu','fvar_relu','fp','y_t','fvar_th','H','g','x_0','x_t')

end

function K = local_arccos_kernel(xa, xb, sigma, N_0)
% Kernel between the columns of xa and the columns of xb:
%   K = sigma^2/(2*pi*N_0) * |xa_i|*|xb_j| * (sin(t) + (pi - t)*cos(t)),
% where t is the angle between column i of xa and column j of xb.
na = sqrt(sum(xa.^2));
nb = sqrt(sum(xb.^2));
c = xa'*xb./(na'*nb);
% real() discards the imaginary part acos returns when rounding pushes a
% cosine slightly outside [-1, 1].
ang = real(acos(c));
K = 1/2/pi*sigma^2/N_0*na'*nb.*(sin(ang)+(pi-ang).*cos(ang));
end

function r = local_h_residual(x, k0, y, N, sigma, beta)
% Residual of the scalar equation that fsolve solves for h:
%   x - sigma^-2 + (1/N)*y*A^+*A^+*k0*y' - (1/N)*trace(A^+*k0),
% with A = k0/x + I/beta. The pseudo-inverse A^+ is computed once per
% evaluation instead of three times.
A_inv = pinv(k0/x + 1/beta*eye(size(k0,1)));
r = x-sigma^(-2)+1/N*y*A_inv*A_inv*k0*y'-1/N*trace(A_inv*k0);
end