clc; clear; close all;
saveFigs = 1; % nonzero => save each figure out as .fig and .svg files

%% set parameters for single neuron regulator experiments

% Activation function: tanh is saturating, non-rectifying, and smooth.
% A non-smooth alternative that clips its argument to [-1,1]
% (saturating, but not rectifying) would be:
%   actFun = @(x) sign(x).*min(abs(x),1);
actFun = @tanh;

% Inside (pre-activation) noise levels, stored as a column vector on the GPU.
numNoiseLevels = 100; % number of inside noise levels to test
maxNoiseLevel  = 2;   % maximum inside noise standard deviation to check
neurNoiseStdvList = linspace(0, maxNoiseLevel, numNoiseLevels+1);
neurNoiseStdvList(1) = []; % drop the zero-noise entry
neurNoiseStdvList = gpuArray(neurNoiseStdvList(:));

% Outside (post-activation) noise: fixed at stdv = 1 without loss of generality.
noiseOutStdv = 1;

% gamma = (simulation time step)/(neural time constant)
gamma = gpuArray(.2);

% Recurrent weights, stored along the 3rd dimension (1 x 1 x numWs).
numWs = 4;                 % number of recurrent weights to check
wMax  = -(1-gamma)/gamma;  % largest-magnitude (negative) recurrent weight without overshoot due to time discretization
w = linspace(0, wMax, numWs+1);
w(1) = []; % drop the zero-weight entry
w = gpuArray(reshape(w, 1, 1, []));

% Setpoints, stored as a row vector (1 x numTargets).
numTargets = 100; % number of setpoints to check
targetList = gpuArray(linspace(-1, 1, numTargets));

% Vectorized neural dynamics operators: both return h_{t+1} given h_t and the biases b.
% The noise-level (rows), setpoint (columns) and weight (pages) axes combine by implicit expansion.
% netFun injects inside noise at every level in neurNoiseStdvList.
netFun = @(h,b) (1-gamma)*h ...
    + gamma*( actFun(w.*h + b + neurNoiseStdvList.*randn(size(b),'gpuArray')) ...
              + noiseOutStdv*randn(size(b),'gpuArray') );
% noiselessNetFun is the same operator with the inside noise set to zero (outside noise is kept).
noiselessNetFun = @(h,b) (1-gamma)*h ...
    + gamma*( actFun(w.*h + b) + noiseOutStdv*randn(size(b),'gpuArray') );

numSteps = 50; % number of time steps in a trial


%% train single neuron regulators
% Optimizes one bias per (inside noise level, setpoint, recurrent weight)
% combination with a hand-written ADAM loop. The cost and its gradient come
% from computeBatchCost, evaluated through dlfeval so that dlgradient can
% trace the computation.

% % set ADAM optimizer parameters
numIter=10000;% number of training steps to take
stepSz=.001;% initial step size
momentumW=.9;% momentum weight
rmsW=.999;% RMS prop weight
epsilon=1e-7;% parameter to prevent divide-by-zero errors
stepSzHalfLife=numIter/2.2;% step size half life for step size decay
decay=.5^(1/stepSzHalfLife);% step size decay factor

% % initialize
% Starting guess b = target*(1-w): the bias that makes the target a fixed
% point of h = w*h + b, i.e. the noise-free optimum when the activation acts
% as the identity at the target (only an approximate starting point for tanh).
b0=targetList.*(1-w).*ones(numNoiseLevels,1);
b=dlarray(b0);
% Optimizer state is allocated with the same size as b (and therefore as the
% gradient), rather than relying on implicit expansion during the first update.
momentumVec=zeros(size(b),'like',b);
rmsVec=momentumVec;
stepSzNom=stepSz;
costList=zeros(numIter,1);% preallocated training-cost history, one entry per iteration

% % run training loop
tic;
for iLoop=1:numIter
    % evaluate cost gradients
    [Cost,dCostdb]=dlfeval(@computeBatchCost,b,netFun,targetList,numSteps);
    costNow=gather(extractdata(Cost));% plain numeric copy of the cost for logging
    costList(iLoop)=costNow;

    % update neural biases with ADAM, per the tensorflow implementation
    momentumVec=momentumW*momentumVec+(1-momentumW)*dCostdb;
    rmsVec=rmsW*rmsVec+(1-rmsW)*(dCostdb.^2);
    stepSzNow=stepSzNom*sqrt(1-rmsW^iLoop)/(1-momentumW^iLoop);% bias-corrected step size
    if iLoop==1
        stepSzNow=stepSzNow/10;% blowup protection for the first update
    end
    deltaB=stepSzNow*momentumVec./(sqrt(rmsVec)+epsilon);
    b=b-deltaB;
    % display progress (uses the gathered numeric cost, not the dlarray)
    if mod(iLoop,10)==0
        disp(['iter = ' num2str(iLoop) '. RMSE = ' num2str(costNow) '. stepSzNow = ' num2str(stepSzNow) '. runtime =' num2str(toc)]);
    end
    % apply step size decay
    stepSzNom=stepSzNom*decay;
end
b=extractdata(b);% strip the dlarray wrapper so later code works on the underlying array



%% compute performance with and without the training noise inside the activation function for all trained regulators
% The trained biases are replicated 100 times along a 4th dimension so that one
% vectorized call runs 100 independent trials; averaging over that dimension
% gives the Monte Carlo MSE estimate. Each trial is 1000x longer than a
% training trial, i.e. 100,000x more time steps in total than a single trial.
bRep=repmat(b,1,1,1,100);

% MSEs with the inside noise level set to its training value
[~,~,finalMSEs]=computeBatchCost(bRep,netFun,targetList,1000*numSteps);
finalMSEs=gather(mean(finalMSEs,4));

% MSEs with the inside noise level set to zero
[~,~,noiselessMSEs]=computeBatchCost(bRep,noiselessNetFun,targetList,1000*numSteps);
noiselessMSEs=gather(mean(noiselessMSEs,4));

% collect things for plotting
% MSEdif is used to find regions of parameter space where a noise preference develops
MSEdif=finalMSEs-noiselessMSEs;
% lowest-MSE setpoint for each inside noise level (searching along the setpoint dimension)
[~,bestTargetInds]=min(finalMSEs,[],2);
bestTargets=targetList(bestTargetInds);
totalRunTime=toc;% elapsed time since the tic issued before training
%% plotting
% One figure per recurrent weight: filled contours of MSEdif over
% (setpoint, inside noise level), with the best setpoint for each noise
% level overlaid as magenta markers.
cLevels=[-10 0 10];% the level at 0 separates negative from positive MSEdif
for k=1:numWs
    figure(k)
    contourf(targetList,neurNoiseStdvList,MSEdif(:,:,k),cLevels);
    colormap(sky(2))
    hold on;
    scatter(bestTargets(:,:,k),neurNoiseStdvList,80,'m','filled');
    title(['W = ' num2str(w(k))])
    xlabel('Setpoint Location');
    ylabel('Inside Noise Level')
    % optionally save out .fig and .svg files
    if saveFigs
        figName=sprintf('singleNeurW%d_nonRectifyingTanh',k);
        saveas(gcf,figName);% no extension given, so saveas picks the format
        saveas(gcf,[figName '.svg']);
    end
end




%% cost function definition for single neuron regulator experiments
function [Cost,dCostdb,MSEspec]=computeBatchCost(b,netFun,targetList,numSteps)
% computeBatchCost  Simulate regulators for numSteps steps and score their tracking error.
%
%   Inputs:
%     b          - biases; may be a dlarray (gradient requested) or a plain array
%     netFun     - function handle, called as netFun(h,b), returning the next activation
%     targetList - setpoints; also used as the initial activation
%     numSteps   - number of time steps to simulate
%
%   Outputs:
%     Cost    - smoothed RMSE: sqrt(mean of MSEspec over all elements + 1e-8)
%     dCostdb - dlgradient of Cost with respect to b when b is a dlarray, otherwise 0
%     MSEspec - time-averaged squared error per element, same size as b

    state=targetList;% start the neural activation at the target value

    % running mean of the squared error over the simulated steps
    MSEspec=zeros(size(b),'like',b);
    for t=1:numSteps
        state=netFun(state,b);
        sqErr=(targetList-state).^2;
        MSEspec=MSEspec*(t-1)/t + sqErr/t;
    end

    % the small constant keeps the square root differentiable at zero error
    Cost=sqrt(mean(MSEspec,"all")+1e-8);

    % a gradient is only available when b is being traced as a dlarray
    dCostdb=0;
    if isdlarray(b)
        dCostdb=dlgradient(Cost,b);
    end

end