% Reset the command window, the workspace and any open figures.
clc
clear
close all

% Start the wall-clock timer; toc is read during and after training.
tic

% ---- network parameters ----
hDim        = 256;              % hidden dimension (Wrec0 is hDim-by-hDim)
uDim        = 9;                % input dimension  (Win0 is hDim-by-uDim)
zDim        = 3;                % output dimension (Wout0 is zDim-by-hDim)
gamma       = 0.2;              % from paper; also sets the input-noise scaling below
softPlusScl = 10;               % NOTE(review): not referenced elsewhere in this script
actFun      = @(x) x.*(x>0);    % rectifier; the original paper tried many activation functions

% ---- training parameters ----
NeuralNoiseStd = 0.1;                   % noise std handed to the batch cost function
InputNoiseStd  = 0.01*sqrt(2/gamma);    % from paper
L2regScl       = 1e-6;                  % from paper -> 1e-6
g              = 0.8;                   % from paper; gain on the random initial weights

% ---- ADAM settings ----
stepSz    = 0.001;  % from paper, tensorflow default -> .001
momentumW = 0.9;    % from paper, tensorflow default -> .9   (first-moment decay)
rmsW      = 0.999;  % from paper, tensorflow default -> .999 (second-moment decay)
epsilon   = 1e-7;   % tensorflow default -> 1e-7

% ---- schedule ----
numBatches     = 15000;
batchSz        = 64;                        % from paper -> 64
stepSzHalfLife = numBatches/2;              % batches after which the nominal step size has halved
decayW         = 0.5^(1/stepSzHalfLife);    % per-batch multiplicative step-size decay



% Build the initial network parameters and optimizer state from scratch.
%
% Random weights are standard-normal draws scaled by g/sqrt(hDim); the
% recurrent matrix starts out diagonal and both bias vectors start at zero.
% The randn calls are kept in this order so the random stream is consumed
% the same way on every run with a given seed.
Wrec0 = g*diag(randn(1,hDim))/sqrt(hDim);   % hDim-by-hDim, diagonal
Win0  = g*randn(hDim,uDim)/sqrt(hDim);      % hDim-by-uDim
Bin0  = zeros(hDim,1);                      % hDim-by-1
Wout0 = g*randn(zDim,hDim)/sqrt(hDim);      % zDim-by-hDim
% Bout0 is an hDim-by-1 zero column (same shape as Bin0) so that it can be
% concatenated into Theta below.
% NOTE(review): presumably only zDim of its entries act as the output bias -
% confirm against cogNet_BatchCost.
Bout0 = Bin0;

% All parameters are packed column-wise into a single dlarray of size
% hDim-by-(hDim+uDim+1+zDim+1): [Wrec | Win | Bin | Wout' | Bout].
Theta0 = dlarray([Wrec0, Win0, Bin0, Wout0', Bout0]);
Theta  = Theta0;

% ADAM first- and second-moment accumulators, both starting at zero.
momentumVec = zeros(size(Theta));
rmsVec      = zeros(size(Theta));

% One cost history per task; rows of [batch index, cost] are appended
% during training.
batchCostList = cell(1,6);

% Live cost plot: one placeholder line per task, filled in while training.
% The title is set before "hold on" so the later plot calls do not reset it.
figure;
title('batch cost history');
hold on;
for taskNum = 1:6
    H(taskNum) = plot(nan, nan);
end

hold off;
xlabel('batch');
legend({'delayPro','delayAnti','memPro','memAnti','reactPro','reactAnti'});

%% training
% Mini-batch training loop. Each iteration:
%   1. draws a fresh batch for one task from TrainBatchGen,
%   2. evaluates cost and gradient w.r.t. Theta through dlfeval,
%   3. records the cost in that task's history,
%   4. tracks the parameters with the lowest RMS of the most recent
%      per-task costs (only after the first 100 batches),
%   5. applies an ADAM update (TensorFlow formulation) and decays the
%      nominal step size.
% Training stops early if the cost becomes NaN.
%
% Reads from the workspace: Theta, momentumVec, rmsVec, batchCostList, H
% and the hyper-parameters defined at the top of the script.
% Leaves in the workspace: the updated Theta, the optimizer state, the
% cost histories and bestTheta / bestIter / bestRmsCost.

bestRmsCost=inf;
% Fallback values so that bestTheta/bestIter always exist (and get saved)
% even if training aborts before any "best" candidate has been recorded.
bestTheta=Theta;
bestIter=0;
stepSzNom=stepSz;
for iLoop=1:numBatches

    % get new demo batch
    [uList,zTargetList,maskList,taskNum] = TrainBatchGen(batchSz,InputNoiseStd);

    % get cost and gradient
    [Cost,dCostdTheta] = dlfeval(@cogNet_BatchCost,Theta,uList,zTargetList,maskList,actFun,gamma,NeuralNoiseStd,L2regScl);
    Cost=extractdata(Cost);
    batchCostList{taskNum}=[batchCostList{taskNum}; iLoop Cost];
    if isnan(Cost)
        % Theta is left untouched by this batch; make the early stop visible
        % instead of ending silently.
        warning('Cost is NaN at batch %d; stopping training early.',iLoop);
        break
    end

    % Track the best parameters seen so far. The score is the RMS of the
    % latest cost of every task, so it is only defined once every task has
    % been sampled at least once (indexing x(end,2) on an empty history
    % would raise an error).
    if iLoop>100 && ~any(cellfun(@isempty,batchCostList))
        rmsCost=sqrt(mean(cellfun(@(x) x(end,2),batchCostList).^2));
        if rmsCost<bestRmsCost
            bestRmsCost=rmsCost;
            bestTheta=Theta;     % parameters before this batch's update...
            bestIter=iLoop-1;    % ...i.e. after iLoop-1 updates
        end
    end

    % update network with ADAM, per the tensorflow implementation:
    % bias correction is folded into the step size and epsilon is added
    % to the uncorrected sqrt of the second moment
    momentumVec=momentumW*momentumVec+(1-momentumW)*dCostdTheta;
    rmsVec=rmsW*rmsVec+(1-rmsW)*(dCostdTheta.^2);
    stepSzNow=stepSzNom*sqrt(1-rmsW^iLoop)/(1-momentumW^iLoop);
    Theta=Theta-stepSzNow*momentumVec./(sqrt(rmsVec)+epsilon);

    % progress report and live plot refresh every 10th batch from batch 100 on
    if iLoop>=100 && mod(iLoop,10)==0
        disp(['iter = ' num2str(iLoop) '. Cost = ' num2str(Cost) '. runTime = ' num2str(toc)]);
        % A dedicated loop index is used so the current batch's taskNum is
        % not overwritten; tasks without any history yet are skipped.
        for iTask=1:6
            if ~isempty(batchCostList{iTask})
                set(H(iTask),'XData',batchCostList{iTask}(:,1), ...
                             'YData',batchCostList{iTask}(:,2));
            end
        end
        drawnow;
    end
    % apply learning rate decay
    stepSzNom=stepSzNom*decayW;
end


% Total elapsed time since the tic at the top of the script.
runtime = toc;

% Write every variable currently in the workspace to a MAT-file whose name
% encodes the run settings (command-syntax form of save).
save Results_neurNoiseStd1en1_Mem2x_15e3batches_Relu_hDim256


