% Start from a clean console, an empty workspace and no open figures, then
% start the stopwatch that is read back with toc for the progress messages
% and the final runtime.
clc;
clear;
close all;
tic;


%% Set task params
% Target function the network is trained to reproduce (swap in the
% commented alternative to train on tanh instead).
funcToApprox = @(x) sin(x);
%funcToApprox = @(x) tanh(x);

deltaT   = 0.02;  % time step; combined with tau below to form gamma = deltaT/tau
numSteps = 100;   % presumably the number of simulation steps per trial (used here only to size readOutRng)

% [first last] step index from which each sample's read-out step is drawn
% with randi in the training loop: the last 30% of the trial.
readOutRng = round(numSteps*[0.7 1]);

inDim  = 1;  % number of input channels
outDim = 1;  % number of output channels


%% set training params
% --- optimisation schedule ---
numIterations  = 10000;  % number of optimization iterations
batchSz        = 32;     % samples per batch; per the original note, a batch is a fresh random selection (with replacement) rather than one fixed route
stepSzHalfLife = 2500;   % iterations over which the nominal step size halves
stepSzDecay    = 0.5^(1/stepSzHalfLife);  % per-iteration multiplier that yields that half-life

% --- Adam hyper-parameters (the values are the tensorflow defaults) ---
stepSz    = 0.001;  % base step size             (tensorflow default -> .001)
momentumW = 0.9;    % first-moment decay weight  (tensorflow default -> .9)
rmsW      = 0.999;  % second-moment decay weight (tensorflow default -> .999)
epsilon   = 1e-7;   % denominator offset         (tensorflow default -> 1e-7)

% --- regularisation and initialisation scale ---
svRegWeight = 0;    % weight on the squared max-singular-value penalty; a non-zero value makes training more stable
g           = 0.8;  % gain applied to the random initial weights (from paper)


%% setting up the neural network parameters
% Size, time constant, activation function and noise level of the recurrent
% network. actFun, gamma and NeuralNoiseStd are handed to the cost function
% in the training loop; hDim sizes the parameter matrices built below.
hDim = 100; % number of hidden units
tau=.1; % time constant, in the same units as deltaT
gamma = deltaT/tau; % time step relative to tau. NOTE: this variable shadows MATLAB's builtin gamma function for the rest of the script
softPlusScl=10; % sharpness of the softplus used in actFun
satScl=1.5; % only referenced by the commented-out saturating activation below
satVal=8; % only referenced by the commented-out saturating activation below
% Alternative activation: the same softplus minus a second softplus centred at satVal.
%actFun = @(x) log(1+exp(softPlusScl*x))/softPlusScl-log(1+exp(satScl*(x-satVal)))/satScl;
% Active activation: scaled softplus, log(1+exp(s*x))/s.
% NOTE(review): in double precision exp() overflows to Inf once
% softPlusScl*x exceeds about 709.8 (x > ~71 with softPlusScl = 10), which
% makes the value Inf -- confirm inputs stay below that range.
actFun = @(x) log(1+exp(softPlusScl*x))/softPlusScl;
NeuralNoiseStd = .1; % noise standard deviation passed to the cost function




%%
% initialize network/training from scratch. some are gaussian normal and
% some are zeros
%
% All trainable quantities are packed side by side into one hDim-row matrix
% (Theta0) so the optimizer can treat them as a single array.
% NOTE(review): Wrec0 is built with diag(), so only its diagonal is non-zero
% at initialisation -- confirm a diagonal (rather than full random)
% recurrent matrix is intended.
Wrec0 = g*diag(randn(1,hDim))/sqrt(hDim); % the matrix that multiplies the current neural state for recurrent dynamics (hDim x hDim, diagonal)
Win0 = g*randn(hDim,inDim)/sqrt(hDim); % the matrix that multiplies the current rule input for recurrent dynamics (hDim x inDim)
Bin0 = zeros(hDim,1); % bias term for recurrent dynamics
Wout0 = g*randn(outDim,hDim)/sqrt(hDim); % matrix that multiplies neural state to produce the output (outDim x hDim)
Bout0 = Bin0; % bias term for output production, stored as a full hDim x 1 column of zeros so it fits in Theta0; presumably only the first outDim entries are used -- TODO confirm against the cost function
h00 = zeros(hDim,1); % initial neural state for every run


% Column layout of Theta0 (hDim rows):
%   [ Wrec0 (hDim cols) | Win0 (inDim) | Bin0 (1) | Wout0' (outDim) | Bout0 (1) | h00 (1) ]
Theta0 = [Wrec0 Win0 Bin0 Wout0' Bout0 h00]; % all the unknowns, initial values
Theta=dlarray(Theta0);% wrap as a dlarray so dlfeval can differentiate with respect to it
momentumVec = zeros(size(Theta)); % Adam first-moment (momentum) estimate, starts at zero
rmsVec=momentumVec; % Adam second-moment (running mean of squared gradients) estimate, starts at zero
% Histories that are grown inside the training loop.
ThetaList=[]; % snapshots of Theta, concatenated along the 3rd dimension
maxSVList=[]; % per-iteration sqrt of the penalty term returned by the cost function
batchCostList=[]; % per-iteration batch cost
RMSEList=[]; % per-iteration RMSE

%% actual neural network training
% Adam training loop. Every iteration:
%   1. draws a fresh random batch of inputs, targets and read-out steps,
%   2. evaluates the batch cost and its gradient w.r.t. Theta through dlfeval,
%   3. snapshots Theta and the optimizer state if this batch has the lowest
%      cost / lowest RMSE seen so far,
%   4. applies one Adam step with a geometrically decaying nominal step size.
% The loop stops early if the cost becomes NaN or the gradient contains a
% non-finite entry, so that Theta and the Adam state left in the workspace
% (and later saved) are the last finite ones.
disp('---- beginning optimization -----');

stepSzNom=stepSz; % nominal step size; decayed by stepSzDecay after every iteration
figure()
for iLoop = 1:numIterations

    % generate a training set
    inList=rand(inDim,batchSz)*2*pi; % use this for sin
    %inList=rand(inDim,batchSz)*8 - 4; % use this for tanh
    outTargetList=funcToApprox(inList);
    readOutStepNumList=randi(readOutRng,[1 batchSz]); % one read-out step per sample, uniform over readOutRng

    % get cost and gradient
    [Cost,dCostdTheta,RMSE,svPenalty] = ...
    dlfeval(@computeBatchCostFuncNet_noiseOut,Theta,inList,outTargetList, ...
    readOutStepNumList,actFun,gamma,NeuralNoiseStd,svRegWeight);
    batchCostList=[batchCostList; Cost];
    maxSVList=[maxSVList;sqrt(svPenalty)];
    RMSEList=[RMSEList;RMSE];



    if isnan(Cost)%||norm(dCostdTheta,"fro")<eps
        disp('things went Nan');
        break
    end

    % Snapshot the parameters that produced the lowest batch cost so far.
    % Theta has received iLoop-1 updates at this point, hence the index.
    if all(Cost<=batchCostList)
        bestCostTheta=Theta;
        bestCostIter=iLoop-1;
        bestCostMomentumVec=momentumVec;
        besCosttMomentumVec=bestCostMomentumVec; % misspelled legacy name, kept so code that loads the saved workspace by this name keeps working
        bestCostRmsVec=rmsVec;
    end


    % Same snapshot, keyed on the lowest RMSE so far.
    if all(RMSE<=RMSEList)
        bestPerfTheta=Theta;
        bestPerfIter=iLoop-1;
        bestPerfMomentumVec=momentumVec;
        bestPerfRmsVec=rmsVec;
    end

    % A non-finite gradient (possible even when Cost itself is not NaN, e.g.
    % after an overflow inside the activation function) would poison
    % momentumVec, rmsVec and Theta in the update below. Stop before that
    % happens so the state that gets saved is still usable.
    if ~all(isfinite(dCostdTheta),'all')
        disp('things went Nan');
        break
    end

    % update network with ADAM, per the tensorflow implementation: the bias
    % correction of both moment estimates is folded into the step size.
    momentumVec=momentumW*momentumVec+(1-momentumW)*dCostdTheta;
    rmsVec=rmsW*rmsVec+(1-rmsW)*(dCostdTheta.^2);
    stepSzNow=stepSzNom*sqrt(1-rmsW^iLoop)/(1-momentumW^iLoop);

    if iLoop<10
        stepSzNow=stepSzNow/10;% blowup protection: 10x smaller steps for the first 9 updates
    end

    Theta=Theta-stepSzNow*momentumVec./(sqrt(rmsVec)+epsilon);

    % Progress report and live RMSE plot every 10 iterations; keep a
    % snapshot of Theta every 100 iterations.
    if mod(iLoop,10)==0
        disp(['iter = ' num2str(iLoop) '. cost = ' num2str(Cost) '. RMSE = ' num2str(RMSE) '. maxSV = ' num2str(maxSVList(end)) '. stepSzNow = ' num2str(stepSzNow) '. runtime =' num2str(toc)]);
        plot(RMSEList);
        legend('RMSE');drawnow;
        if mod(iLoop,100)==0
            ThetaList=cat(3,ThetaList,extractdata(Theta));
        end
    end
    % decay the LR
    stepSzNom=stepSzDecay*stepSzNom;

end
%%
% Post-processing: turn the cost history from a dlarray into a plain numeric
% array, record the time elapsed since the tic at the top of the script, and
% write the entire workspace to a MAT-file.
batchCostList = extractdata(batchCostList);
runTime = toc;

% The file name spells out the run configuration by hand; it is NOT derived
% from the variables above, so update it whenever the parameters change.
fileName = 'simple_noiseOut_funcsin_stepSz1en3_stepSzHL25e2_nStdv1en1_batchSz32_numSteps100_hDim100_1e4batches';
save(fileName);

