% Start from a clean session and start the wall-clock timer.
clc;
clear;
close all;
tic;


%% Task parameters
uDim      = 2;              % input size: 1D tangential velocity + 1D heading angle in [-pi,pi]
zDim      = 2;              % output size: 2D translation
numSteps  = 100;            % passed to the batch generator (presumably time steps per trajectory -- confirm)
dsMu      = 0.1;            % average tangential velocity
dsStdv    = dsMu*0.1;       % standard deviation of the Gaussian tangential velocity
dheadStdv = 20*pi/180;      % standard deviation of the zero-centered rotational velocity (20 degrees, in radians)
deltaT    = 0.02;           % time step; divided by tau to form gamma
roomL     = 2.5;            % passed to the batch generator (presumably the side length of the square room -- confirm)


%% Training parameters
numIterations  = 10000;     % number of optimization iterations
batchSz        = 32;        % batch size handed to the batch generator
actRegWeight   = 0;         % passed to computeBatchCost (presumably the activity-regularization weight -- confirm)
wRegWeight     = 0;         % passed to computeBatchCost (presumably the weight-regularization weight -- confirm)
stepSz         = 0.001;     % initial Adam step size; tensorflow default -> .001
momentumW      = 0.9;       % Adam first-moment decay; tensorflow default -> .9
rmsW           = 0.999;     % Adam second-moment decay; tensorflow default -> .999
epsilon        = 1e-7;      % Adam denominator offset; tensorflow default -> 1e-7
stepSzHalfLife = 5000;      % number of iterations over which the nominal step size halves
stepSzDecay    = 0.5^(1/stepSzHalfLife);   % per-iteration factor, so stepSzDecay^stepSzHalfLife = 0.5


%% Neural network parameters
hDim           = 100;               % number of hidden units (size of the neural state)
tau            = 0.2;               % time constant; consistent with Cueva, making dt = tau/10
gamma          = deltaT/tau;        % ratio of the time step to the time constant
actFun         = @(x) x.*(x>0);     % rectifier: keeps positive entries, zeros out the rest
NeuralNoiseStd = 0.1;               % passed to computeBatchCost (presumably the std of noise injected into the units -- confirm)





%% Initial parameter values
% Everything is initialized from scratch: the weight matrices are drawn from
% a standard normal and divided by sqrt(hDim); the biases and the initial
% state start at zero.
Wrec0 = diag(randn(1, hDim)) / sqrt(hDim);  % multiplies the current neural state in the recurrent dynamics (diagonal at initialization)
Win0  = randn(hDim, uDim) / sqrt(hDim);     % multiplies the current input in the recurrent dynamics
Bin0  = zeros(hDim, 1);                     % bias term of the recurrent dynamics
Wout0 = randn(zDim, hDim) / sqrt(hDim);     % multiplies the neural state to produce the output
Bout0 = zeros(hDim, 1);                     % output bias; conceptually only 2D, but stored as a full hDim-by-1 column so it fits in Theta
h00   = zeros(hDim, 1);                     % initial neural state for every run

% Pack all unknowns side by side into one hDim-row matrix:
% [Wrec | Win | Bin | Wout' | Bout | h0]
Theta0 = horzcat(Wrec0, Win0, Bin0, Wout0', Bout0, h00);
Theta  = dlarray(Theta0);           % dlarray copy used with dlfeval during training

% Adam state, both starting at zero
momentumVec = zeros(size(Theta));   % running average of the gradient
rmsVec      = zeros(size(Theta));   % running average of the squared gradient

% Logs filled in during training
ThetaList     = [];                 % periodic snapshots of Theta, stacked along dim 3
batchCostList = [];                 % cost of every iteration
RMSEList      = [];                 % RMSE of every iteration

% Open a figure with a placeholder line; the training loop later overwrites
% the line's YData with the RMSE history.
figure;
H = plot(NaN);
title('RMSE');

%% actual neural network training
% Each iteration: draw a fresh batch, evaluate cost/gradient/RMSE with
% dlfeval, snapshot the best parameters seen so far, then take one Adam step
% with an exponentially decaying nominal step size.
disp('---- beginning optimization -----');

% nominal (pre-bias-correction) step size; multiplied by stepSzDecay after
% every iteration
stepSzNom=stepSz;
for iLoop = 1:numIterations

    % generate a training batch (targets and inputs)
    [zTargetList,uList] = trainBatchGen_unicycle_cartOut_headIn_square(batchSz,numSteps,dsMu,dsStdv,dheadStdv,roomL);


    % get cost, its gradient with respect to Theta, and the RMSE
    [Cost,dCostdTheta,RMSE] = ...
        dlfeval(@computeBatchCost,Theta,uList,zTargetList, ...
        actFun,gamma,NeuralNoiseStd,actRegWeight,wRegWeight);
    % log plain numeric copies; the last entries are reused for display below
    batchCostList=[batchCostList; extractdata(Cost)];
    RMSEList=[RMSEList;extractdata(RMSE)];


    % stop as soon as the cost is NaN (the NaN is already logged above)
    if isnan(Cost)%||norm(dCostdTheta,"fro")<eps
        disp('things went Nan');
        break
    end

    % Snapshot the parameters and Adam state whenever this iteration's cost
    % is the lowest logged so far. Theta has not been updated yet at this
    % point, so it is the result of iLoop-1 updates.
    if all(Cost<=batchCostList)
        bestCostTheta=Theta;
        bestCostIter=iLoop-1;
        bestCostMomentumVec=momentumVec;
        % legacy misspelled name, kept so existing code that loads the saved
        % workspace and reads this variable keeps working
        besCosttMomentumVec=bestCostMomentumVec;
        bestCostRmsVec=rmsVec;
    end


    % Same snapshot, keyed on the lowest RMSE logged so far.
    if all(RMSE<=RMSEList)
        bestPerfTheta=Theta;
        bestPerfIter=iLoop-1;
        bestPerfMomentumVec=momentumVec;
        bestPerfRmsVec=rmsVec;
    end

    % update network with ADAM, per the tensorflow implementation: the bias
    % correction is folded into the step size rather than applied to the
    % moment estimates
    momentumVec=momentumW*momentumVec+(1-momentumW)*dCostdTheta;
    rmsVec=rmsW*rmsVec+(1-rmsW)*(dCostdTheta.^2);
    stepSzNow=stepSzNom*sqrt(1-rmsW^iLoop)/(1-momentumW^iLoop);

    if iLoop<10
        stepSzNow=stepSzNow/10;% blowup protection: 10x smaller step for the first 9 updates
    end

    Theta=Theta-stepSzNow*momentumVec./(sqrt(rmsVec)+epsilon);

    % every 10 iterations: print progress and refresh the RMSE plot
    if mod(iLoop,10)==0
        % format the numeric values already extracted into the logs instead
        % of passing dlarray objects to num2str
        disp(['iter = ' num2str(iLoop) '. cost = ' num2str(batchCostList(end)) '. RMSE = ' num2str(RMSEList(end)) '. stepSzNow = ' num2str(stepSzNow) '. runtime =' num2str(toc)]);
        H.YData=RMSEList;
        drawnow;
        % every 500 iterations: keep a snapshot of the updated Theta
        if mod(iLoop,500)==0
            ThetaList=cat(3,ThetaList,extractdata(Theta));
        end
    end
    % decay the LR
    stepSzNom=stepSzDecay*stepSzNom;

end
%%

% Total wall-clock time since the tic at the top of the script.
runTime=toc;

% Output file name; the tags are written by hand to mirror the run settings.
% NOTE(review): the trailing 'iter15e3' tag does not match
% numIterations = 10000 set in the training params -- confirm which is intended.
fileName='squareRoomL25en1_1e2steps_dsMu1en1_dsStdv1en2_dheadStdv20_actReg0_wReg0_nStdv1en1_stepSz1en3_HL5e3_iter15e3';
% save is called without a variable list, so every workspace variable is
% written to the file.
save(fileName)

