% Start from an empty workspace with no open figures.
clear all;
close all;
rng(77777,'twister');% fixed seed so that runs can be reproduced

%% Network dimensions and iteration counts
pmax=60000;% number of patterns visited in each epoch
pvisul=10000;% evaluate and plot every pvisul patterns
sizes=[784,256,32,10];% layer widths, input layer first
[num_layers,biases_0,weights_0]=network_(sizes,pmax);% initial parameters, copied by each optimizer section
N=num_layers-1;% layer count handed to feedforward/backprop_sym/validation


%% Training data
patterns=loadMNISTImages('data/train-images.idx3-ubyte');% MNIST training images
labels=loadMNISTLabels('data/train-labels.idx1-ubyte');% MNIST training labels

% Relabel class 0 as 10 so every label is a positive integer, then expand the
% labels into indicator columns; the transpose stores one pattern per column.
labels(labels==0)=10;
labels=dummyvar(labels)';

% Per-pattern normalization: subtract each column's mean, then divide each
% column by its Euclidean norm.
patterns=patterns-mean(patterns);
stds=repmat(sqrt(sum(patterns.^2,1)),size(patterns,1),1);
patterns=patterns./stds;
% Affine rescaling of the indicator targets.
labels=(labels-0.1)/sqrt(0.9);
%patterns=(patterns-mean(patterns));%mean removal bis

%% Validation data (pre-processed exactly like the training data)
patterns_val=loadMNISTImages('data/t10k-images.idx3-ubyte');% MNIST validation images
labels_val=loadMNISTLabels('data/t10k-labels.idx1-ubyte');% MNIST validation labels

labels_val(labels_val==0)=10;
labels_val=dummyvar(labels_val)';

patterns_val=patterns_val-mean(patterns_val);
stds_val=repmat(sqrt(sum(patterns_val.^2,1)),size(patterns_val,1),1);
patterns_val=patterns_val./stds_val;
labels_val=(labels_val-0.1)/sqrt(0.9);

%% Training parameters

epochs=100;% number of passes over the training patterns

% Shuffle table: column i is the visiting order used during epoch i.
% Row idx(k) of that column receives pattern index k.
shuffle_tbl=zeros(pmax,epochs);
for i=1:epochs
    idx=randperm(pmax);
    shuffle_tbl(idx,i)=(1:pmax)';
end

% Dropout rates (both zero, i.e. dropout is not used)
drp_rate=0.0;% rate handed to feedforward
drp_rate_input=0.0;% rate applied to the input pattern

%%
%% %% SGD MOMENTUM
%%

% Monitoring buffers: one row per checkpoint inside an epoch, one column per
% epoch. Entries stay NaN until the corresponding checkpoint is reached.
mval=nan(floor(pmax/pvisul),epochs);% validation accuracy (fraction correct)
mloss=nan(floor(pmax/pvisul),epochs);% training accuracy (fraction correct)
l2val=nan(floor(pmax/pvisul),epochs);% mean squared L2 misfit, validation set
l2loss=nan(floor(pmax/pvisul),epochs);% mean squared L2 misfit, training set
temps=zeros(floor(pmax/pvisul),epochs);% toc reading taken at each checkpoint


mu=0.01;% learning rate
lambda=0.01;% regularization factor handed to backprop_sym
beta=0.9;% momentum factor

% Every optimizer starts from the same initial network.
weights=weights_0;
biases=biases_0;

J=zeros(pmax,epochs);% per-pattern loss 0.5*|eN|^2
Jx=zeros(pmax,epochs);% per-pattern 0.5*(sum of squared activations)/N


[Gh,gh]=grd_init(sizes);% momentum accumulators for weights (Gh) and biases (gh)
for epoch=1:epochs
    
    for p=1:pmax
        
        % ---- periodic evaluation and plotting ----
        if mod(p,pvisul)==0
            chk=floor(p/pvisul);% row of the monitoring buffers
            temps(chk,epoch)=toc;
            [mval(chk,epoch),l2v]=validation(patterns_val,labels_val,biases,weights,N);
            [mloss(chk,epoch),l2]=validation(patterns,labels,biases,weights,N);
            l2loss(chk,epoch)=mean(l2.^2);
            l2val(chk,epoch)=mean(l2v.^2);
            
            % Smoothed instantaneous loss over all patterns processed so far.
            nseen=(epoch-1)*pmax+p-1;
            subplot(3,1,1)
            plot(linspace(0,epoch,nseen),movmean(J(1:nseen),64,'omitnan'));
            xlabel('Epoch','fontweight','bold');
            ylabel('Cost L^2','fontweight','bold');
            title('MNIST','fontweight','bold');
            drawnow
            
            % Validation (red) and training (blue) accuracy in percent.
            subplot(3,1,2)
            hold all
            xs=linspace(1,epochs,length(mval(:)));
            plot(xs,100*mval(:),'r-o');
            axis([1 epochs 0 100]);
            plot(xs,100*mloss(:),'b-o');
            axis([1 epochs 0 100]);
            xlabel('Epoch','fontweight','bold');
            ylabel('Validation/Accuracy (%)','fontweight','bold');
            hold off
            drawnow
            
            % Distribution of the per-pattern training misfit norms.
            subplot(3,1,3)
            histogram(l2);
            drawnow;
            
            disp (strcat('Epoch: ',int2str(epoch),'/',int2str(epochs),...
                ' Pattern: ', int2str(p),'/',int2str(pmax),...
                ' Val: ', num2str(100*max(mval(:)),4), '%'));
        end
        
        % ---- one stochastic step ----
        tic;
        sp=shuffle_tbl(p,epoch);% index of the pattern drawn for this step
        x0=dropout(patterns(:,sp),drp_rate_input);
        [xx,dFa,d2Fa]=feedforward(x0,biases,weights,N,drp_rate);
        dp=labels(:,sp);% target
        eN=dp-xx{N,1};% output misfit
        
        % Squared norms of the input and of the layer outputs 1..N-1.
        xrms=x0'*x0;
        for i=1:N-1
            xrms=xrms+xx{i,1}'*xx{i,1};
        end
        J(p,epoch)=0.5*(eN'*eN);
        Jx(p,epoch)=0.5*(xrms)/N;
        
        [bb,b0]=backprop_sym(eN,weights,dFa,N,lambda,xx,x0);% regularized backprop
        [GGq,ggq]=grd_sym(bb,xx,dFa,x0,N);% per-layer weight/bias gradients
        
        % Exponential averaging of the gradients followed by the parameter
        % step; layers are independent, so both happen in one sweep.
        for i=1:N
            Gh{i,1}=beta*Gh{i,1}+(1-beta)*GGq{i,1};
            gh{i,1}=beta*gh{i,1}+(1-beta)*ggq{i,1};
            weights{i,1}=weights{i,1}-mu*Gh{i,1};
            biases{i,1}=biases{i,1}-mu*gh{i,1};
        end
        
    end
    
end
% save result
save ('results/mnist_SGDMomentum.mat','weights','biases','J','Jx','mval','mloss','l2val','l2loss','temps');




%%
%% ADAM
%%

% Monitoring buffers: one row per checkpoint inside an epoch, one column per
% epoch. Entries stay NaN until the corresponding checkpoint is reached.
mval=zeros(floor(pmax/pvisul),epochs);mval(:)=nan;% validation accuracy (fraction correct)
mloss=zeros(floor(pmax/pvisul),epochs);mloss(:)=nan;% training accuracy (fraction correct)
l2val=zeros(floor(pmax/pvisul),epochs);l2val(:)=nan;% mean squared L2 misfit, validation set
l2loss=zeros(floor(pmax/pvisul),epochs);l2loss(:)=nan;% mean squared L2 misfit, training set
temps=zeros(floor(pmax/pvisul),epochs);% toc reading taken at each checkpoint


mu=0.001;% step size
lambda=0.01;% regularization factor handed to backprop_sym
beta1=0.9;beta2=0.99;% decay rates of the first / second moment averages
epsilon=1e-8;% denominator guard (named epsilon so the built-in eps is not shadowed)
[Gh,gh] = grd_init(sizes);% RAW first-moment averages (weights, biases)
[Gh2,gh2] = grd_init(sizes);% RAW second-moment averages (weights, biases)


weights=weights_0;biases=biases_0;%init network

J=zeros(pmax,epochs);% per-pattern loss 0.5*|eN|^2
Jx=zeros(pmax,epochs);% per-pattern 0.5*(sum of squared activations)/N



for epoch=1:epochs
    
    for p=1:pmax
        pf=p+(epoch-1)*pmax;% global step counter used by the bias correction
        [NN,nn] = grd_init(sizes);% update buffers; every layer is overwritten below
        
        % ---- periodic evaluation and plotting ----
        if mod(p,pvisul)==0
            temps(floor(p/pvisul),epoch)=toc;
            [mval(floor((p/pvisul)),epoch),l2v]=validation(patterns_val,labels_val,biases,weights,N);
            [mloss(floor((p/pvisul)),epoch),l2]=validation(patterns,labels,biases,weights,N);
            l2loss(floor((p/pvisul)),epoch)=mean(l2.^2);
            l2val(floor((p/pvisul)),epoch)=mean(l2v.^2);
            
            subplot(3,1,1)
            plot(linspace(0,epoch,(epoch-1)*pmax+p-1),movmean(J(1:(epoch-1)*pmax+p-1),64,'omitnan'));
            xlabel('Epoch','fontweight','bold');
            ylabel('Cost L^2','fontweight','bold');
            title('MNIST','fontweight','bold');
            drawnow
            subplot(3,1,2)
            hold all
            plot(linspace(1,epochs,length(mval(:))),100*mval(:),'r-o');axis([1 epochs 0 100]);
            plot(linspace(1,epochs,length(mloss(:))),100*mloss(:),'b-o');axis([1 epochs 0 100]);
            xlabel('Epoch','fontweight','bold');
            ylabel('Validation/Accuracy (%)','fontweight','bold');
            hold off
            drawnow
            subplot(3,1,3)
            histogram(l2);drawnow;
            
            disp (strcat('Epoch: ',int2str(epoch),'/',int2str(epochs),...
                ' Pattern: ', int2str(p),'/',int2str(pmax),...
                ' Val: ', num2str(100*max(mval(:)),4), '%'));
            
            
        end
        
        % ---- one stochastic step ----
        tic;
        sp=shuffle_tbl(p,epoch);% index of the pattern drawn for this step
        x0=dropout(patterns(:,sp),drp_rate_input);%
        [xx,dFa,d2Fa] = feedforward (x0, biases, weights,N,drp_rate);%forward
        dp=labels(:,sp);% label
        eN=dp-xx{N,1};% misfit
        
        
        % Squared norms of the input and of the layer outputs 1..N-1.
        xrms=x0'*x0;
        for i=1:N-1
            xrms=xrms+xx{i,1}'*xx{i,1};
        end
        J(p,epoch)=0.5*(eN'*eN);%
        Jx(p,epoch)=0.5*(xrms)/N;%
        
        [bb,b0] = backprop_sym (eN, weights, dFa, N,lambda,xx,x0);% regularized
        
        [GGq,ggq] = grd_sym (bb,xx,dFa,x0,N);% per-layer weight/bias gradients
        
        % Bias-correction denominators for this step. They are applied to
        % temporaries only: writing the corrected value back into the running
        % average would re-apply (and compound) the correction at every step.
        c1=1-beta1^pf;
        c2=1-beta2^pf;
        
        for i=1:N
            
            gq=ggq{i,1};Gq=GGq{i,1};
            
            % raw exponential averages of the gradient and of its square
            gh{i,1}=beta1*gh{i,1}+(1-beta1)*gq;
            Gh{i,1}=beta1*Gh{i,1}+(1-beta1)*Gq;
            
            gh2{i,1}=beta2*gh2{i,1}+(1-beta2)*gq.^2;
            Gh2{i,1}=beta2*Gh2{i,1}+(1-beta2)*Gq.^2;
            
            % bias-corrected step direction
            NN{i,1}=(Gh{i,1}/c1)./(sqrt(Gh2{i,1}/c2)+epsilon);
            nn{i,1}=(gh{i,1}/c1)./(sqrt(gh2{i,1}/c2)+epsilon);
        end
        
        
        
        
        for i=1:N %update weights
            weights{i,1}=weights{i,1}-mu*NN{i,1};
            biases{i,1}=biases{i,1}-mu*nn{i,1};
            
        end
        
    end
    
    
    
end

save ('results/mnist_ADAM.mat','weights','biases','J','Jx','mval','mloss','l2val','l2loss','temps');



%%
%%
%% ADAHESSIAN
%%

% Monitoring buffers: one row per checkpoint inside an epoch, one column per
% epoch. Entries stay NaN until the corresponding checkpoint is reached.
mval=zeros(floor(pmax/pvisul),epochs);mval(:)=nan;% validation accuracy (fraction correct)
mloss=zeros(floor(pmax/pvisul),epochs);mloss(:)=nan;% training accuracy (fraction correct)
l2val=zeros(floor(pmax/pvisul),epochs);l2val(:)=nan;% mean squared L2 misfit, validation set
l2loss=zeros(floor(pmax/pvisul),epochs);l2loss(:)=nan;% mean squared L2 misfit, training set
temps=zeros(floor(pmax/pvisul),epochs);% toc reading taken at each checkpoint

lmb0=0;lmb=zeros(N,1);% changed later
mu=0.001;% step size

lambda=0.00;% regularization factor handed to backprop_sym and Hsv
beta1=0.9;beta2=0.99;% decay rates of the first / second moment averages
epsilon=1e-8;% denominator guard (named epsilon so the built-in eps is not shadowed)
[Gh,gh] = grd_init(sizes);% RAW first-moment averages of the gradient
[Dh2,dh2] = grd_init(sizes);% RAW averages of the squared curvature estimate
[dims,dim] = dimsINIT(sizes);% per-layer parameter counts and their total


weights=weights_0;biases=biases_0;%init network

J=zeros(pmax,epochs);% per-pattern loss 0.5*|eN|^2
Jx=zeros(pmax,epochs);% per-pattern 0.5*(sum of squared activations)/N



for epoch=1:epochs
    
    for p=1:pmax
        pf=p+(epoch-1)*pmax;% global step counter used by the bias correction
        [NN,nn] = grd_init(sizes);% update buffers; every layer is overwritten below
        
        % ---- periodic evaluation and plotting ----
        if mod(p,pvisul)==0
            
            temps(floor(p/pvisul),epoch)=toc;
            
            [mval(floor((p/pvisul)),epoch),l2v]=validation(patterns_val,labels_val,biases,weights,N);
            [mloss(floor((p/pvisul)),epoch),l2]=validation(patterns,labels,biases,weights,N);
            l2loss(floor((p/pvisul)),epoch)=mean(l2.^2);
            l2val(floor((p/pvisul)),epoch)=mean(l2v.^2);
            
            subplot(3,1,1)
            plot(linspace(0,epoch,(epoch-1)*pmax+p-1),movmean(J(1:(epoch-1)*pmax+p-1),64,'omitnan'));
            xlabel('Epoch','fontweight','bold');
            ylabel('Cost L^2','fontweight','bold');
            title('MNIST','fontweight','bold');
            drawnow
            subplot(3,1,2)
            hold all
            plot(linspace(1,epochs,length(mval(:))),100*mval(:),'r-o');axis([1 epochs 0 100]);
            plot(linspace(1,epochs,length(mloss(:))),100*mloss(:),'b-o');axis([1 epochs 0 100]);
            xlabel('Epoch','fontweight','bold');
            ylabel('Validation/Accuracy (%)','fontweight','bold');
            hold off
            drawnow
            subplot(3,1,3)
            histogram(l2);drawnow;
            
            disp (strcat('Epoch: ',int2str(epoch),'/',int2str(epochs),...
                ' Pattern: ', int2str(p),'/',int2str(pmax),...
                ' Val: ', num2str(100*max(mval(:)),4), '%'));
            
            
        end
        
        % ---- one stochastic step ----
        tic;
        sp=shuffle_tbl(p,epoch);% index of the pattern drawn for this step
        x0=dropout(patterns(:,sp),drp_rate_input);%
        [xx,dFa,d2Fa] = feedforward (x0, biases, weights,N,drp_rate);%forward
        dp=labels(:,sp);% label
        eN=dp-xx{N,1};% misfit
        
        
        
       
        % Squared norms of the input and of the layer outputs 1..N-1.
        xrms=x0'*x0;
        for i=1:N-1
            xrms=xrms+xx{i,1}'*xx{i,1};
        end
        J(p,epoch)=0.5*(eN'*eN);
        Jx(p,epoch)=0.5*(xrms)/N;
        
        [bb,b0] = backprop_sym (eN, weights, dFa, N,lambda,xx,x0);% regularized
        
        [GGq,ggq] = grd_sym (bb,xx,dFa,x0,N);% per-layer weight/bias gradients
        
        % Curvature estimate: draw a random +/-1 probe v, let Hsv act on it,
        % and multiply the flattened result element-wise by v again.
        % NOTE(review): Hsv is presumably a Hessian-vector product - confirm in its source.
        v=radeMacher(dim); [VV,vv] = dissemble(v,dims,N,sizes);
        %[DDq,ddq] = Hv(weights,N,dFa,VV,vv,xx,x0,lambda*ones(N,1),bb);
        [DDq,ddq] = Hsv(weights,N,dFa,VV,vv,xx,x0,lambda,bb);
        [d] = assemble(DDq,ddq);
        [DDq,ddq] = dissemble(d.*v,dims,N,sizes);
        
        % Bias-correction denominators for this step. They are applied to
        % temporaries only: writing the corrected value back into the running
        % average would re-apply (and compound) the correction at every step.
        c1=1-beta1^pf;
        c2=1-beta2^pf;
        
        for i=1:N
            
            gq=ggq{i,1};Gq=GGq{i,1};
            dq=ddq{i,1};Dq=DDq{i,1};
            
            % raw exponential averages of the gradient and squared curvature
            gh{i,1}=beta1*gh{i,1}+(1-beta1)*gq;
            Gh{i,1}=beta1*Gh{i,1}+(1-beta1)*Gq;
            
            dh2{i,1}=beta2*dh2{i,1}+(1-beta2)*dq.^2;
            Dh2{i,1}=beta2*Dh2{i,1}+(1-beta2)*Dq.^2;
            
            % bias-corrected step direction
            NN{i,1}=(Gh{i,1}/c1)./(sqrt(Dh2{i,1}/c2)+epsilon);
            nn{i,1}=(gh{i,1}/c1)./(sqrt(dh2{i,1}/c2)+epsilon);
        end
        
        
        
        
        for i=1:N %update weights
            weights{i,1}=weights{i,1}-mu*NN{i,1};
            biases{i,1}=biases{i,1}-mu*nn{i,1};
            
        end
        
    end
    
    
    
end

save ('results/mnist_ADAHESSIAN.mat','weights','biases','J','Jx','mval','mloss','l2val','l2loss','temps');



%%
%%
%% %% ESN PRIME
%%


% Monitoring buffers: one row per checkpoint inside an epoch, one column per
% epoch. Entries stay NaN until the corresponding checkpoint is reached.
mval=zeros(floor(pmax/pvisul),epochs);mval(:)=nan;% validation accuracy (fraction correct)
mloss=zeros(floor(pmax/pvisul),epochs);mloss(:)=nan;% training accuracy (fraction correct)
l2val=zeros(floor(pmax/pvisul),epochs);l2val(:)=nan;% mean squared L2 misfit, validation set
l2loss=zeros(floor(pmax/pvisul),epochs);l2loss(:)=nan;% mean squared L2 misfit, training set
temps=zeros(floor(pmax/pvisul),epochs);% toc reading taken at each checkpoint


mu=0.01;% step size
lambda=0.01;% initial value only; overwritten with the running error at every step

weights=weights_0;biases=biases_0;%init network

J=zeros(pmax,epochs);% per-pattern loss 0.5*|eN|^2
Jx=zeros(pmax,epochs);% per-pattern 0.5*(sum of squared activations)/N

beta=0.99;% forgetting factor of all running averages below
err=1;% running average of the squared misfit norm

% Running activation averages, initialized from a forward pass on a zero input.
x0h=0*patterns(:,1);%
[xxh,dFa] = feedforward (x0h, biases, weights,N,drp_rate);%forward
dph=xxh{N,1};
zzh=xxh;z0h=x0h;% frozen copies of the initial averages (zzh is read when updating V)
[bbh,b0h] = backprop_sym (xxh{N,1}, weights, dFa, N,0,xxh,x0h); % unregularized


% Per-layer running statistics:
%   Y{i} average of w*(y*y'), X{i} average of w*(y*x'), V{i} average of
%   w*(x-zzh)*(x-zzh)'. M_ and V are maintained but not read by the update.
M_=cell(N,1);Y=cell(N,1);X=cell(N,1);V=cell(N,1);
for i=1:N
    M_{i,1}=eye(sizes(i+1))/beta;
    Y{i,1}=zeros(sizes(i+1));X{i,1}=zeros(sizes(i+1),sizes(i));
    V{i,1}=zeros(sizes(i),sizes(i));
    
end


for epoch=1:epochs
    
    for p=1:pmax
        
        [NN,nn] = grd_init(sizes);% update buffers; every layer is overwritten below
        
        % ---- periodic evaluation and plotting ----
        if mod(p,pvisul)==0
            temps(floor(p/pvisul),epoch)=toc;
            [mval(floor((p/pvisul)),epoch),l2v]=validation(patterns_val,labels_val,biases,weights,N);
            [mloss(floor((p/pvisul)),epoch),l2]=validation(patterns,labels,biases,weights,N);
            l2loss(floor((p/pvisul)),epoch)=mean(l2.^2);
            l2val(floor((p/pvisul)),epoch)=mean(l2v.^2);
            
            subplot(3,1,1)
            plot(linspace(0,epoch,(epoch-1)*pmax+p-1),movmean(J(1:(epoch-1)*pmax+p-1),64,'omitnan'));
            xlabel('Epoch','fontweight','bold');
            ylabel('Cost L^2','fontweight','bold');
            title('MNIST','fontweight','bold');
            drawnow
            subplot(3,1,2)
            hold all
            plot(linspace(1,epochs,length(mval(:))),100*mval(:),'r-o');axis([1 epochs 95 100]);
            plot(linspace(1,epochs,length(mloss(:))),100*mloss(:),'b-o');axis([1 epochs 95 100]);
            xlabel('Epoch','fontweight','bold');
            ylabel('Validation/Accuracy (%)','fontweight','bold');
            hold off
            drawnow
            subplot(3,1,3)
            histogram(l2);drawnow;
            
            disp (strcat('Epoch: ',int2str(epoch),'/',int2str(epochs),...
                ' Pattern: ', int2str(p),'/',int2str(pmax),...
                ' Val: ', num2str(100*max(mval(:)),4), '%'));
            
        end
        
        % ---- one stochastic step ----
        tic
        sp=shuffle_tbl(p,epoch);% index of the pattern drawn for this step
        x0=dropout(patterns(:,sp),drp_rate_input);%
        [xx,dFa,d2Fa] = feedforward (x0, biases, weights,N,drp_rate);%forward
        
        dp=labels(:,sp);% label
        eN=dp-xx{N,1};% misfit
        err_=err;% previous running error, needed to re-weight the old averages
        w=nrm2(eN);% squared misfit norm (nrm2 returns x'*x)
        err=err*beta+(1-beta)*w;
        
        
        
        lambda=err;% the regularization factor follows the running error
        [bb,b0] = backprop_sym (eN, weights, dFa, N,lambda,xx,x0); % regularized with lambda=err
        
        
        xrms=x0'*x0;
        % Error-weighted running average of the input.
        x0h=(beta*x0h*err_+(1-beta)*x0*w)/err;

        
        % Accumulate squared activation norms (all N layers here) and update
        % the error-weighted running averages of the layer outputs.
        for i=1:N
            xrms=xrms+xx{i,1}'*xx{i,1};
            xxh{i,1}=(beta*xxh{i,1}*err_+(1-beta)*xx{i,1}*w)/err;            
        end
        J(p,epoch)=0.5*(eN'*eN);%
        Jx(p,epoch)=0.5*(xrms)/N;%
        

        % Layer 1: its input is the pattern x0 itself.
        % (The former "n=Nh{i,1}" read an undefined variable and its result
        % was never used, so it has been removed.)
        i=1;
        y=dFa{i,1}.*bb{i,1};ny=nrm2(y);
        x=-lambda*x0;
        Y{i,1}=(beta*Y{i,1}+(1-beta)*(y*y')*w);
        X{i,1}=(beta*X{i,1}+(1-beta)*(y*x')*w);

        % Weight step: solve (w*ny*I + Y) * NN = X; bias step: NN applied to
        % the running input average.
        NN{i,1}=(w*ny*eye(length(y))+Y{i,1})\X{i,1};        
        nn{i,1}=NN{i,1}*x0h;
        
        % Layers 2..N: the input of layer i is the output of layer i-1.
        for i=2:N
            y=dFa{i,1}.*bb{i,1};ny=nrm2(y);
            x=-lambda*xx{i-1,1};
            Y{i,1}=(beta*Y{i,1}+(1-beta)*(y*y')*w);%w
            X{i,1}=(beta*X{i,1}+(1-beta)*(y*x')*w);%w
            V{i,1}=beta*V{i,1}+(1-beta)*((x-zzh{i-1,1})*(x-zzh{i-1,1})')*w;
            NN{i,1}=(w*ny*eye(length(y))+Y{i,1})\X{i,1};
            nn{i,1}=NN{i,1}*xxh{i-1,1};
            
        end
        
        
        
        
        
        for i=1:N %update weights
            weights{i,1}=weights{i,1}-(mu)*NN{i,1};
            biases{i,1}=biases{i,1}-(mu)*nn{i,1};
            
        end
        
        
        
    end
    
    
    
end

save ('results/mnist_ESN_PRIME.mat','weights','biases','J','Jx','mval','mloss','l2val','l2loss','temps');

%%
%% AUXILIARY FUNCTIONS
%%
%%

function [v_] = dropout(v,drp_rate)
% DROPOUT  Zero entries of v at random.
%   v_ = dropout(v,drp_rate) multiplies v element-wise by a random 0/1 column
%   mask of length(v) entries, where each mask entry is 0 with probability
%   drp_rate and 1 with probability 1-drp_rate. Surviving entries are not
%   rescaled.
keepProb=1-drp_rate;
mask=randsrc(length(v),1,[0,1;drp_rate,keepProb]);
v_=v.*mask;
end



function [mc,l2]=validation(data_val,labels_val,biases,weights,N)
% VALIDATION  Evaluate the network on a labelled data set.
%   data_val   : one pattern per column
%   labels_val : one target vector per column
%   biases, weights, N : forwarded unchanged to feedforward (dropout rate 0)
% Returns
%   mc : fraction of patterns whose largest output (after softmax) sits at
%        the same index as the largest entry of the (softmaxed) target
%   l2 : column vector holding the Euclidean norm of target minus output
%        for every pattern

nSamples=size(data_val,2);
hits=zeros(nSamples,1);
l2=zeros(nSamples,1);
for k=1:nSamples
    out=feedforward(data_val(:,k),biases,weights,N,0);% forward pass without dropout
    target=labels_val(:,k);
    
    resid=target-out{N,1};
    l2(k)=sqrt(resid'*resid);
    
    [~,predicted]=max(softmax(out{N,1}));
    [~,expected]=max(softmax(target));
    hits(k)=(predicted==expected);
end

mc=mean(hits);

end

function [sqnorm] = nrm2(x)
% NRM2  Return x'*x.
%   For a column vector this is the SQUARED Euclidean norm (no square root
%   is taken, despite the name).
sqnorm = x' * x;
end


%% Functions for ADAHESSIAN


function [dims,dim] = dimsINIT(sizes)
% DIMSINIT  Parameter counts for a layered network.
%   sizes : layer widths, input layer first
% Returns
%   dims : column vector of length 2*(numel(sizes)-1); the first half holds
%          the weight count sizes(k)*sizes(k+1) of each layer, the second
%          half the bias count sizes(k+1) of each layer
%   dim  : total number of parameters (sum of all entries of dims)
nLayers=length(sizes)-1;
dims=zeros(2*nLayers,1);
dim=0;
for k=1:nLayers
    nWeights=sizes(k)*sizes(k+1);
    nBiases=sizes(k+1);
    dims(k)=nWeights;
    dims(nLayers+k)=nBiases;
    dim=dim+nWeights+nBiases;
end
end

function [grd] = assemble(GG,gg)
% ASSEMBLE  Flatten per-layer cell arrays into one column vector.
%   Every cell of GG is unrolled column-major and stacked in cell order,
%   followed by every cell of gg unrolled and stacked the same way.
unroll=@(A) A(:);
wCols=cellfun(unroll,GG,'UniformOutput',false);
bCols=cellfun(unroll,gg,'UniformOutput',false);
grd=vertcat(wCols{:},bCols{:});
end

function [GG,gg] = dissemble(grd,dims,N,sizes)
% DISSEMBLE  Split a flat parameter vector into per-layer pieces.
%   grd   : flat vector laid out as [all weight blocks; all bias blocks]
%   dims  : block lengths, N weight blocks followed by N bias blocks
%   N     : number of layers
%   sizes : layer widths; weight block k is reshaped to sizes(k+1)-by-sizes(k)
% Returns N-by-1 cell arrays GG (weight matrices) and gg (bias slices).
GG=cell(N,1);
gg=cell(N,1);
blockEnd=cumsum(dims);% index of the last element of every block
for k=1:N
    % the first weight block starts at element 1, later ones right after
    % the end of the preceding block
    if k==1
        wFirst=1;
    else
        wFirst=blockEnd(k-1)+1;
    end
    GG{k,1}=reshape(grd(wFirst:blockEnd(k)),[sizes(k+1) sizes(k)]);
    
    % bias block k is block number N+k in the flat layout
    bFirst=blockEnd(N+k-1)+1;
    gg{k,1}=grd(bFirst:blockEnd(N+k));
end
end

function [v] = radeMacher(dim)
% RADEMACHER  Random sign vector.
%   v = radeMacher(dim) returns a dim-by-1 column whose entries are -1 or +1,
%   each drawn with probability 0.5.
symbols=[-1,1];
probs=[0.5,0.5];
v=randsrc(dim,1,[symbols;probs]);
end