% permuted mnist

function [] = mnist_gd_pretrained(i,seed1,seed2)
% MNIST_GD_PRETRAINED  Even/odd permuted-MNIST experiment with soft-k-means gating.
%
%   mnist_gd_pretrained(i, seed1, seed2)
%
%   What the function does, in order:
%     1. Loads mnist.mat and standardizes every image (zero mean, unit std
%        over its pixels).
%     2. Takes the first P/2 even-digit and first P/2 odd-digit training
%        images (and N_s/2 of each from the test set); labels are +1 for
%        even digits and -1 for odd digits.
%     3. Applies a pixel permutation drawn after rng(2,'twister').
%     4. Obtains gating values g and centroids xc from softkmeans, and
%        evaluates the corresponding soft assignment gp on the test images.
%     5. Calls numeric_saddlepoint_nonlinear to get H, then computes the
%        kernel predictor mean/variance and the derived error measures
%        (fp, fvar_th, bias_th, ge, gb_th).
%     6. Computes the same quantities for a kernel of the form
%        (sin(theta) + (pi - theta).*cos(theta)) (the "*_relu" outputs),
%        with a scalar h found by fsolve.
%     7. Runs full-batch gradient descent on the two weight matrices until
%        the summed squared training error drops below 0.1 or iteration0
%        steps have been taken.
%     8. Saves the results to
%        ../compare_mnist_relu_global_theory/compare_mnist_generalization_context_M<M>seed1<seed1>seed2<seed2>.mat
%
%   Inputs:
%     i     - index into Mall = [5,8,10:10:100]; selects the number of gates M.
%     seed1 - forwarded unchanged to softkmeans.
%     seed2 - seed passed to rng(...,'twister') right before the initial
%             weights are drawn.
%
%   Requires on the path: mnist.mat (variables trainX, trainY, testX, testY),
%   softkmeans, numeric_saddlepoint_nonlinear, normrnd and fsolve.

    % ----- experiment constants ------------------------------------------
    beta       = 1e4;               % enters the kernels only as the 1/beta diagonal term
    Mall       = [5,8,10:10:100];
    if ~isscalar(i) || ~ismember(i, 1:numel(Mall))
        error('mnist_gd_pretrained:badIndex', ...
              'i must be an integer between 1 and %d.', numel(Mall));
    end
    M          = Mall(i);           % number of gating units
    n          = 1;                 % number of pixel permutations (contexts)
    sigma      = 0.5;               % std of the initial weights / prior scale
    N          = 1000;              % hidden-layer width
    iteration0 = 3e6;               % maximum number of gradient steps
    eta1       = 0.01;              % gradient-descent step size
    N_0        = 784;               % input dimension (pixels per image)
    P          = 1000;              % training samples per context
    N_s        = 1000;              % test samples per context

    % Fail early (before the long computation) if the output location
    % cannot be created.
    outDir = '../compare_mnist_relu_global_theory';
    if exist(outDir, 'dir') ~= 7
        mkdir(outDir);
    end
    outFile = [outDir,'/compare_mnist_generalization_context_','M',num2str(M), ...
               'seed1',num2str(seed1),'seed2',num2str(seed2),'.mat'];

    rng(2,'twister');               % fixes the pixel permutation(s) below

    % ----- data ----------------------------------------------------------
    % Load into a struct so no variables are injected into the workspace.
    data   = load('mnist.mat');
    % Labels are forced to row vectors because the index lists below are
    % concatenated horizontally.
    trainY = data.trainY(:)';
    testY  = data.testY(:)';

    % After the transpose each column is one image; standardize per image.
    % (assumes trainX/testX are stored as samples-by-pixels -- TODO confirm)
    Xtr    = double(data.trainX)';
    Xte    = double(data.testX)';
    trainX = (Xtr - mean(Xtr,1))./std(Xtr,0,1);
    testX  = (Xte - mean(Xte,1))./std(Xte,0,1);

    indtrain  = find(mod(trainY,2)==0);
    indtest   = find(mod(testY,2)==0);
    indtrain1 = find(mod(trainY,2)==1);
    indtest1  = find(mod(testY,2)==1);
    index     = [indtrain(1:P/2),indtrain1(1:P/2)];
    index_t   = [indtest(1:N_s/2),indtest1(1:N_s/2)];

    x_00 = trainX(:,index);
    x_t0 = testX(:,index_t);

    % Build one pixel-permuted copy of the data per context.
    ind = zeros(n, N_0);
    x_0 = zeros(N_0, n*P);
    x_t = zeros(N_0, n*N_s);
    for k = 1:n
        ind(k,:) = randperm(N_0);
        x_0(:,((k-1)*P+1):(k*P))     = double(x_00(ind(k,:),:));
        x_t(:,((k-1)*N_s+1):(k*N_s)) = double(x_t0(ind(k,:),:));
    end

    % Labels: +1 for even digits, -1 for odd digits, repeated per context.
    y   = 2*repmat(double(mod(trainY(index),2)==0),  1, n) - 1;
    y_t = 2*repmat(double(mod(testY(index_t),2)==0), 1, n) - 1;

    % ----- gating --------------------------------------------------------
    iteration = 1000;
    beta0     = 0.1;
    [g,xc] = softkmeans(M, x_0, beta0, iteration, seed1);
    g = g(1:M,:);

    % Soft assignment of every test image to the M centroids:
    % softmax over -beta0 * squared Euclidean distance.
    d = zeros(M, size(x_t,2));
    for i0 = 1:M
        d(i0,:) = sum((x_t - xc(:, i0)).^2, 1);
    end
    % Subtracting the per-sample minimum distance leaves the softmax
    % unchanged mathematically but keeps the denominator >= 1, so the ratio
    % cannot become 0/0 when all exponentials would otherwise underflow.
    expd = exp(-beta0*(d - min(d,[],1)));
    gp   = expd./sum(expd,1);

    % ----- gated-kernel theory -------------------------------------------
    [~,H] = numeric_saddlepoint_nonlinear(N_0,P*n,M,g,y,x_0,N,sigma,beta);

    K_1  = sigma^2/N_0*(x_0'*x_0);      % train/train
    K_1p = sigma^2/N_0*(x_t'*x_0);      % test/train
    k_1p = sigma^2/N_0*(x_t'*x_t);      % test/test

    K2  = (g'*H*g).*K_1 + 1/beta*eye(n*P,n*P);
    K2p = (gp'*H*g).*K_1p;
    k2p = (gp'*H*gp).*k_1p;

    K2p_K2inv = K2p*pinv(K2);           % pseudo-inverse computed once, used twice
    fp        = K2p_K2inv*y';
    fvar_th   = diag(k2p - K2p_K2inv*K2p');
    bias_th   = sum((fp-y_t').^2)/N_s;
    ge        = sum((fp-y_t').^2)/N_s + mean(fvar_th);
    gb_th     = mean((y_t'+1)/2+(-y_t').*erfc(-fp./sqrt(fvar_th*2))/2);

    % ----- "relu" kernel baseline ----------------------------------------
    k0  = relu_kernel(x_0, x_0, sigma, N_0);    % train/train
    k   = relu_kernel(x_t, x_t, sigma, N_0);    % test/test
    k_p = relu_kernel(x_t, x_0, sigma, N_0);    % test/train

    % Scalar h solves the finite-temperature equation implemented in
    % saddle_residual. The starting point is a standard-normal draw from
    % whatever generator state is current at this point.
    h0  = normrnd(0,1,1,1);
    fun = @(x) saddle_residual(x, k0, y, N, sigma, beta);
    h   = abs(fsolve(fun,h0));

    kp_reg    = k_p/h*pinv(k0/h + 1/beta*eye(n*P,n*P));
    fp_relu   = kp_reg*y';
    fvar_relu = diag(k/h - kp_reg*k_p'/h);

    ge_relu = sum((fp_relu-y_t').^2)/N_s + mean(fvar_relu);
    gb_relu = mean((y_t'+1)/2+(-y_t').*erfc(-fp_relu./sqrt(fvar_relu*2))/2);

    % Saved as all zeros; kept only so the output file has the same
    % variables as before.
    f_var = zeros(1,N_s);

    % ----- gradient descent ----------------------------------------------
    rng(seed2,'twister');
    w   = normrnd(0, sigma, N_0+M, N);
    w_1 = w(1:N_0, :);                  % N_0-by-N first-layer weights
    w_2 = w(N_0+1:end, :)';             % N-by-M gated read-out weights

    out_scale = 1/sqrt(N*M);
    for j0 = 1:iteration0
        x_1 = 1/sqrt(N_0)*w_1'*x_0;                 % hidden activations, N-by-(n*P)
        % Network output per sample: equals
        % diag(g'*w_2'*x_1)'/sqrt(N*M) without forming the full square
        % product.
        f   = out_scale*sum((w_2*g).*x_1, 1);
        res = y - f;
        E   = sum(res.^2);                          % summed squared training error

        if mod(j0,1e3) == 1
            disp(['train',num2str(E)]);
        end

        % Stop on convergence or when the step budget is used up. The
        % weights are left at the values E was measured on, which are the
        % ones the test prediction below is evaluated with.
        if E < 0.1 || j0 == iteration0
            break
        end

        % Scaling the columns by the residual replaces multiplication with
        % diag(y-f).
        w_2 = w_2 + eta1*out_scale*(x_1.*res)*g';
        % NOTE(review): as in the original code, this step uses the w_2
        % that was just updated, not the value from the forward pass.
        w_1 = w_1 + eta1/sqrt(N*M*N_0)*(x_0.*res)*(w_2*g)';
    end

    % Test prediction and its mean squared error for the final weights.
    x_1t = 1/sqrt(N_0)*w_1'*x_t;
    f_t  = out_scale*sum((w_2*gp).*x_1t, 1);
    bias = mean((f_t - y_t).^2);

    save(outFile,'j0','f_t','gb_th','gb_relu','f_var','bias','ge_relu','ge','bias_th','fp_relu','fvar_relu','fp','y_t','fvar_th','H','g','x_0','x_t')

end

function K = relu_kernel(a, b, sigma, N_0)
% RELU_KERNEL  Kernel between the columns of a and the columns of b:
%   sigma^2/(2*pi*N_0) * |a_i|*|b_j| * (sin(theta) + (pi - theta).*cos(theta)),
%   where theta is the angle between a_i and b_j.
    na    = sqrt(sum(a.^2));            % column norms of a (row vector)
    nb    = sqrt(sum(b.^2));            % column norms of b (row vector)
    % real() discards the imaginary part acos returns when rounding pushes
    % a cosine slightly outside [-1, 1].
    theta = real(acos(a'*b./(na'*nb)));
    K     = 1/2/pi*sigma^2/N_0*na'*nb.*(sin(theta)+(pi-theta).*cos(theta));
end

function r = saddle_residual(x, k0, y, N, sigma, beta)
% SADDLE_RESIDUAL  Residual of the scalar equation solved for h:
%   x - sigma^-2 + (1/N)*y*A*A*k0*y' - (1/N)*trace(A*k0),
%   with A = pinv(k0/x + I/beta).
    A = pinv(k0/x + 1/beta*eye(size(k0)));   % computed once per evaluation
    r = x-sigma^(-2)+1/N*y*A*A*k0*y'-1/N*trace(A*k0);
end