% calculate the kernel as a function of depth.

function [K0,K1,K2,K3,Kgp] = kernel_deep
% KERNEL_DEEP  Build gated kernels of increasing depth on an MNIST
% even/odd task, evaluate kernel-regression predictors with them, and
% save the results to 'kernel_shape_depth_M10_mnist_randomg2.mat'.
%
% Outputs
%   K0  - Ns-by-Ns input kernel on the test points, sigma^2/N_0 * x_t'*x_t
%         (NOT normalised).
%   K1  - Ns-by-Ns kernel built from the H returned by
%         numeric_saddlepoint_nonlinear, normalised to unit diagonal.
%   K2  - same, built from numeric_saddlepoint_nonlinear_deep_newopt,
%         normalised to unit diagonal.
%   K3  - same, built from numeric_saddlepoint_nonlinear_deep3opt,
%         normalised to unit diagonal.
%   Kgp - 1-by-10 cell; Kgp{l} is the normalised kernel
%         sigma^l * ((gp'*gp)/M0).^l .* K0 for l = 1..10.
%
% A fourth kernel (from numeric_saddlepoint_nonlinear_deep4opt) is also
% evaluated; only its summary statistics are saved, it is not returned.
%
% Requirements visible from the code: 'mnist.mat' on the path containing
% trainX, trainY, testX, testY; normrnd; and the four
% numeric_saddlepoint_nonlinear* solvers.
% NOTE(review): the indexing below assumes trainY/testY are row vectors
% and that each ROW of trainX/testX is one sample -- confirm against the
% actual mnist.mat layout.

    % ---- hyper-parameters ------------------------------------------------
    N_1    = 500;    % forwarded to the saddle-point solvers
    N_0    = 784;    % input dimension (rows of v, kernel normalisation)
    Ns     = 2000;   % number of test points
    M0     = 8;      % number of gating units
    P      = 600;    % number of training points
    seed   = 1;      % forwarded to the deep solvers
    sigma  = 1;
    beta   = inf;    % forwarded to the solvers
    nDepth = 10;     % number of depths for the Kgp family

    % ---- data ------------------------------------------------------------
    % Load into a struct so nothing stored in the MAT-file can silently
    % overwrite the hyper-parameters defined above.
    data   = load('mnist.mat','trainX','trainY','testX','testY');
    trainY = data.trainY;
    testY  = data.testY;          % also written to the output file below
    trainX = double(data.trainX);
    testX  = double(data.testX);

    % Transpose, then centre/scale every column of the transposed array.
    trainX = (trainX'-mean(trainX'))./std(trainX');
    testX  = (testX'-mean(testX'))./std(testX');

    % First P/2 even-labelled and P/2 odd-labelled training samples
    % (Ns/2 each for the test set).
    indtrain  = find(mod(trainY,2)==0);
    indtest   = find(mod(testY,2)==0);
    indtrain1 = find(mod(trainY,2)==1);
    indtest1  = find(mod(testY,2)==1);
    index   = [indtrain(1:P/2),indtrain1(1:P/2)];
    index_t = [indtest(1:Ns/2),indtest1(1:Ns/2)];

    x_0 = trainX(:,index);
    x_t = testX(:,index_t);

    % Targets: +1 for even labels, -1 for odd labels.
    y   = 2*(mod(trainY(index),2)==0)-1;
    y_t = 2*(mod(testY(index_t),2)==0)-1;

    % ---- random gates ----------------------------------------------------
    % Only this seed matters for v: it is the last rng call before the
    % first random draw.
    rng(3,'twister');
    th = 0.;
    v  = normrnd(0,1,M0,N_0);
    g  = (v*x_0>th);      % M0-by-P  training gates
    gp = (v*x_t>th);      % M0-by-Ns test gates
    % (An earlier variant derived the gates from soft k-means instead of
    % random projections; that code path is no longer present.)

    % Ensure every sample has at least one active gate.
    g  = fill_empty_gates(g,  v, x_0);
    gp = fill_empty_gates(gp, v, x_t);

    % ---- masks over test-point pairs --------------------------------------
    offDiag  = ~logical(eye(Ns));
    difMask  = ~logical(testY(index_t)'==testY(index_t));  % different digit label
    sameMask = ~difMask & offDiag;                         % same label, i ~= j

    % ---- depth 0: plain input kernel ---------------------------------------
    K0   = sigma^2/N_0*x_t'*x_t;
    K0g  = sigma^2/N_0*x_0'*x_0;
    K00g = sigma^2/N_0*x_t'*x_0;
    % NOTE(review): bias0/ge0 are computed but neither saved nor returned.
    [bias0, ge0] = kernel_predict(K0, K0g, K00g, y, y_t); %#ok<ASGLU>
    k01 = normalize_kernel(K0);
    sameblock0 = k01(sameMask);
    difblock0  = k01(difMask);

    % ---- Kgp family: gate-overlap kernel raised to the power l -------------
    ovTest  = (gp'*gp)/M0;   % loop-invariant gate overlaps
    ovTrain = (g'*g)/M0;
    ovCross = (gp'*g)/M0;

    Kgp       = cell(1,nDepth);
    sameblock = cell(1,nDepth);
    difblock  = cell(1,nDepth);
    biasgp    = zeros(1,nDepth);
    ge        = zeros(1,nDepth);
    biasacc   = zeros(1,nDepth);
    for l = 1:nDepth
        Kraw = sigma^l*ovTest.^l.*K0;
        Kg   = sigma^l*ovTrain.^l.*K0g;
        Kggp = sigma^l*ovCross.^l.*K00g;
        [biasgp(l), ge(l), biasacc(l)] = kernel_predict(Kraw, Kg, Kggp, y, y_t);

        Kgp{l}       = normalize_kernel(Kraw);
        sameblock{l} = Kgp{l}(sameMask);
        difblock{l}  = Kgp{l}(difMask);
    end

    % ---- K1 ----------------------------------------------------------------
    % The solver's first output is unused; '~' avoids shadowing the
    % built-in error().
    [~,H] = numeric_saddlepoint_nonlinear(N_0,P,M0,g,y,x_0,N_1,sigma,beta);

    K1   = (gp'*H*gp).*K0;
    K1g  = (g'*H*g).*K0g;
    K11g = (gp'*H*g).*K00g;
    [bias1, ge1, bias1acc] = kernel_predict(K1, K1g, K11g, y, y_t);
    K1 = normalize_kernel(K1);
    sameblock1 = K1(sameMask);
    difblock1  = K1(difMask);

    % ---- K4 (saved statistics only; not returned) --------------------------
    [~,~,~,~,H,H_sq,H_cu,H_qu] = numeric_saddlepoint_nonlinear_deep4opt(N_0,P,M0,g,y,x_0,N_1,sigma,seed,beta);

    gpm = gate_expand(gp, H,    gp);    % M0^2 rows
    gc  = gate_expand(gp, H_sq, gpm);   % M0^3 rows
    gq  = gate_expand(gp, H_cu, gc);    % M0^4 rows

    gm  = gate_expand(g, H,    g);
    gc0 = gate_expand(g, H_sq, gm);
    gq0 = gate_expand(g, H_cu, gc0);

    K4   = (gq'*H_qu*gq).*K0;
    K4g  = (gq0'*H_qu*gq0).*K0g;
    K44g = (gq'*H_qu*gq0).*K00g;
    [bias4, ge4, bias4acc] = kernel_predict(K4, K4g, K44g, y, y_t);
    K4 = normalize_kernel(K4);
    sameblock4 = K4(sameMask);
    difblock4  = K4(difMask);

    % ---- K2 ----------------------------------------------------------------
    [~,~,H_sqrt,H_sq] = numeric_saddlepoint_nonlinear_deep_newopt(N_0,P,M0,g,y,x_0,N_1,sigma,seed,beta);

    gpmH = gate_expand(gp, H_sqrt, gp);
    gmH  = gate_expand(g,  H_sqrt, g);

    K2   = (gpmH'*H_sq*gpmH).*K0;
    K2g  = (gmH'*H_sq*gmH).*K0g;
    K22g = (gpmH'*H_sq*gmH).*K00g;
    [bias2, ge2, bias2acc] = kernel_predict(K2, K2g, K22g, y, y_t);
    K2 = normalize_kernel(K2);
    sameblock2 = K2(sameMask);
    difblock2  = K2(difMask);

    % ---- K3 ----------------------------------------------------------------
    [~,~,~,H,H_sq,H_cu] = numeric_saddlepoint_nonlinear_deep3opt(N_0,P,M0,g,y,x_0,N_1,sigma,seed,beta);

    gpm = gate_expand(gp, H,    gp);
    gc  = gate_expand(gp, H_sq, gpm);

    gm  = gate_expand(g, H,    g);
    gc0 = gate_expand(g, H_sq, gm);

    K3   = (gc'*H_cu*gc).*K0;
    K3g  = (gc0'*H_cu*gc0).*K0g;
    K33g = (gc'*H_cu*gc0).*K00g;
    [bias3, ge3, bias3acc] = kernel_predict(K3, K3g, K33g, y, y_t);
    K3 = normalize_kernel(K3);
    sameblock3 = K3(sameMask);
    difblock3  = K3(difMask);

    % ---- persist -----------------------------------------------------------
    % NOTE(review): the file name says "M10" while M0 is 8 here -- confirm
    % the name is intentional before relying on it.
    save('kernel_shape_depth_M10_mnist_randomg2.mat','bias1acc','bias2acc','bias3acc','bias4acc','biasacc','ge4','bias4','biasgp','bias3','bias1','bias2','ge1','ge2','ge3','ge','Kgp','K1','K2','K3','testY','index_t','sameblock4','difblock4','sameblock1','difblock1','sameblock0','difblock0','sameblock2','difblock2','sameblock3','difblock3','sameblock','difblock');

end

function gates = fill_empty_gates(gates, v, x)
% For every column of GATES with no active unit, switch on the unit(s)
% whose projection v*x(:,c) is largest (ties switch on all tied units).
    emptyCols = find(sum(gates,1)==0);
    for k = 1:numel(emptyCols)
        c = emptyCols(k);
        field = v*x(:,c);
        gates(field==max(field), c) = 1;
    end
end

function Kn = normalize_kernel(K)
% Scale K so that its diagonal is 1: Kn(i,j) = K(i,j)/sqrt(K(i,i)*K(j,j)).
    d  = sqrt(diag(K));
    Kn = K./(d*d');
end

function [bias, genErr, errRate] = kernel_predict(Ktest, Ktrain, Kcross, y, y_t)
% Kernel-regression predictor from training targets y (row vector).
%   bias    - mean squared error of the predictor mean against y_t
%   genErr  - bias plus the mean of diag(Ktest - Kcross*pinv(Ktrain)*Kcross')
%   errRate - fraction of test points whose predicted sign does not match
%             sign(y_t) (a prediction of exactly 0 counts as a miss)
    A       = Kcross*pinv(Ktrain);   % computed once, reused for mean and variance
    fmean   = A*y';
    bias    = mean((fmean-y_t').^2);
    fvar    = diag(Ktest - A*Kcross');
    genErr  = bias + mean(fvar);
    errRate = 1-mean((sign(y_t).*sign(fmean')==1));
end

function out = gate_expand(gates, H, feats)
% Per-sample outer product of the gate vector with H*feats, flattened
% column-major: out(:,n) = reshape(gates(:,n) * (H*feats(:,n)).', [], 1).
    out = permute(gates,[1,3,2]).*permute(H*feats,[3,1,2]);
    out = reshape(out, [], size(gates,2));
end