clc; close all; clear;

%% Task definition
% target function the network is trained to reproduce
funcToApprox = @(x) sin(x);

%% Network hyper-parameters
actFun    = @(x) x.*(x>0);           % ReLU: passes positive entries, zeros the rest
noiseStdv = 1e-1;                    % handed to computeBatchCost (presumably a noise std-dev; confirm there)
wRegFun   = @(x) weightDecayPen(x);  % regularization penalty function handed to computeBatchCost
numLayers = 3;                       % number of weight layers
inDim     = 1;                       % input dimensionality
outDim    = 1;                       % output dimensionality
hDim      = 100;                     % hidden-layer width

%% Optimizer / training hyper-parameters
numIterations  = 2e4;                 % number of training iterations
stepSz         = 0.001;               % initial step size (tensorflow default -> .001)
momentumW      = 0.9;                 % decay for the first-moment average (tensorflow default -> .9)
rmsW           = 0.999;               % decay for the second-moment average (tensorflow default -> .999)
epsilon        = 1e-7;                % passed to adamUpdate (tensorflow default -> 1e-7)
wRegWeight     = 1e-3;                % handed to computeBatchCost; also gates the wRegPen plot
stepSzHalfLife = .5*numIterations;    % iterations after which the nominal step size has halved
stepSzDecay    = .5^(1/stepSzHalfLife); % per-iteration multiplicative step-size decay
batchSz        = 36;                  % how many training examples in every batch


%% initialize training
g=1e-3; % scale applied to the random initial weights
% Build the per-layer parameter struct array (W = weights, B = biases):
%   layer 1               : hDim   x inDim
%   layers 2..numLayers-1 : hDim   x hDim
%   layer numLayers       : outDim x hDim
Theta0(1).W=g*randn(hDim,inDim); Theta0(1).B=zeros(hDim,1);
for lay=2:numLayers-1
    Theta0(lay).W=g*randn(hDim); Theta0(lay).B=zeros(hDim,1);
end
% The output layer is sized by outDim (it was previously sized by inDim,
% which only worked because inDim == outDim).
Theta0(numLayers).W=g*randn(outDim,hDim); Theta0(numLayers).B=zeros(outDim,1);
% structArray2DL / structZerosLike are project helpers; presumably they wrap
% the parameters as dlarrays and build a zero-valued copy -- confirm there.
Theta=structArray2DL(Theta0);
momentumVec = structZerosLike(Theta); % Adam first-moment accumulator
rmsVec=momentumVec;                   % Adam second-moment accumulator, same initial value


% initialize plots
% per-iteration histories; the training loop appends one entry per iteration
RMSEList      = [];
batchCostList = [];
wRegPenList   = [];

% the regularization-penalty figure only exists when the penalty weight is nonzero
if wRegWeight
    figure();
    H.wRegPen = plot(nan);
    title('wRegPen');
end

% figure tracking the RMSE history
figure();
H.reconRMSE = plot(nan);
title('recon RMSE');
%% train network
% Each iteration: draw a random batch, evaluate cost and gradient through
% dlfeval, record histories, snapshot the best-so-far parameters, take one
% Adam step, and decay the nominal step size. The loop aborts as soon as the
% cost, the gradient or the parameters become non-finite / NaN.

disp('---- beginning optimization -----');

stepSzNom=stepSz;   % nominal step size, decayed geometrically every iteration
bestCostSoFar=inf;  % running minimum of the batch cost (avoids rescanning the whole history)
tic;
for iLoop = 1:numIterations

    % generate a training batch
    inList=rand(inDim,batchSz)*2*pi; % use this for sin 
    %inList=rand(inDim,batchSz)*8 - 4; % use this for tanh
    outTargList=funcToApprox(inList);


    % get cost and gradient
    [Cost,dCostdTheta,RMSE,wRegPen] ...
    = dlfeval(@computeBatchCost,Theta,inList,outTargList,actFun,wRegFun,wRegWeight,noiseStdv);

    % pull plain numeric scalars out once; they are reused for the history,
    % the sanity checks and the progress printout
    costNow=extractdata(gather(Cost));
    rmseNow=extractdata(gather(RMSE));
    batchCostList=[batchCostList;costNow];
    RMSEList=[RMSEList; rmseNow];

    if wRegWeight
        wRegPenList=[wRegPenList; extractdata(gather(wRegPen))];
    end


    if structAnyIsNan(dCostdTheta)&&isfinite(costNow)
        disp('grad went nan before cost');
        break
    end


    if ~isfinite(costNow)
        disp('things went Nan');
        break
    end

    % Snapshot the parameters that produced the lowest batch cost so far.
    % Every earlier cost is finite (the loop breaks otherwise), so comparing
    % against the running minimum is equivalent to comparing against the
    % full history, ties included.
    if costNow<=bestCostSoFar
        bestCostSoFar=costNow;
        bestPerfTheta=Theta;
        bestPerfIter=iLoop-1; % Theta at this point has received iLoop-1 updates
        bestPerfMomentumVec=momentumVec;
        bestPerfRmsVec=rmsVec;
    end

    % update network with ADAM, per the tensorflow implementation
    momentumVec=structExpMoveMean(momentumVec,dCostdTheta,momentumW,1);
    rmsVec=structExpMoveMean(rmsVec,dCostdTheta,rmsW,2);
    % bias-corrected step size for the current iteration
    stepSzNow=stepSzNom*sqrt(1-rmsW^iLoop)/(1-momentumW^iLoop);

    Theta = adamUpdate(Theta,momentumVec,rmsVec,stepSzNow,epsilon);
    if structAnyIsNan(Theta)
        disp('Theta went nan first');
        break;
    end

    % progress report and plot refresh every 10 iterations
    if mod(iLoop,10)==0
        disp(['iter = ' num2str(iLoop) '. reconRMSE = ' num2str(rmseNow) '. stepSzNow = ' num2str(stepSzNow) '. runtime =' num2str(toc)]);
        if wRegWeight
            H.wRegPen.YData=wRegPenList;
        end

        H.reconRMSE.YData=RMSEList;
        drawnow;
        % keep a rolling backup of the optimizer state every 100 iterations
        if mod(iLoop,100)==0
            ThetaBackup=Theta;
            rmsVecBackup=rmsVec;
            momentumVecBackup=momentumVec;
            backupIter=iLoop;

        end

    end
    % decay the LR
    stepSzNom=stepSzDecay*stepSzNom;

end
%%
runTime=toc;

% Saves the entire workspace.
% NOTE(review): the file name hard-codes the target function, noise level,
% iteration count and hidden width; update it by hand when those change.
save('sin_nStdv1en1_iter2e4_hdim100')