% Start from a clean session: clear the command window, wipe the workspace,
% close any open figures, and start the wall-clock timer read later by toc.
clc;
clear;
close all;
tic;

%% hard-coded T-maze parameters
deltaT         = 0.02;
distBetweenPts = 1;

% down time at the start of a run when we're just fixating, as if we've just
% finished a route ending at the start point of the route we're about to do
downTimeNumSteps = 20;

% [min max] ranges (in steps) handed to the trial generator, and the fixed
% number of steps per jaunt
pauseNumStepsRng = [30 50];
startNumStepsRng = [30 50];
jauntNumSteps    = 20;

% Trial length: down time + longest start period + three longest pauses
% + three jaunts.
totalNumSteps = downTimeNumSteps + max(startNumStepsRng) ...
    + 3*max(pauseNumStepsRng) + 3*jauntNumSteps;

% Per-step drift std chosen so that jauntNumSteps independent steps add up to
% the desired drift variance per jaunt (10% of the point spacing, squared).
desiredDriftVarPerJaunt = (.1*distBetweenPts)^2;
DriftStd = sqrt(desiredDriftVarPerJaunt/jauntNumSteps);

%% set training params
numIterations = 6000;   % number of optimization iterations
batchSz       = 36;

% --- cost-function weights -------------------------------------------------
flexConWeight   = 0;
svRegWeight     = 1e-3; % weight on max SV squared penalty- makes training more stable
signalRegWeight = 0;    % weight on signalling penalty, discourages long range, strong influence
senseRegWeight  = 0;    % weight on sensory penalty, discourages long range strong influence
actRegWeight    = 0;    % weight on activation penalty, discourages non zero firing rates

g = 0.8;                % from paper

% number of most recent per-TP costs kept in the rolling buffer
tpCostBuffLength = 10;

% --- Adam settings ---------------------------------------------------------
stepSz    = 0.004;      % from paper, tensorflow default->.001
momentumW = 0.9;        % from paper, tensorflow default->.9
rmsW      = 0.999;      % from paper, tensorflow default->.999
epsilon   = 1e-7;       % tensorflow default->1e-7

% exponential step-size decay: the nominal step halves every stepSzHalfLife
% iterations
stepSzHalfLife = numIterations/4;
stepSzDecay    = .5^(1/stepSzHalfLife);
%% hard coded set of points
% One row per maze point; the two columns are the point's coordinates.
% Built as a single literal so the size is fixed up front rather than grown
% one row at a time.
pMatrix = [1,  0;
           1,  1;
           1, -1;
           0,  0;
           0,  1;
           0, -1];

% connectivity matrix
% Use the row count explicitly: length() returns the LARGEST dimension, which
% only coincides with the number of points while there are at least as many
% points as coordinate columns.
numPoints = size(pMatrix,1);
cMatrix = zeros(numPoints);

% undirected edges, entered once in the upper triangle
cMatrix(1,2) = 1;
cMatrix(1,3) = 1;
cMatrix(1,4) = 1;
cMatrix(4,5) = 1;
cMatrix(4,6) = 1;

% mirror to the lower triangle so the matrix is symmetric
cMatrix = (cMatrix+cMatrix');


% Paths between extreme points of the backbone, one per row, as pMatrix row
% indices. Optimal paths between any two points can be determined as subsets
% of these paths. Three-point paths are padded with NaN to the row width.
majorPaths = [ ...
    3 1 4 5;   ...
    2 1 4 5;   ...
    6 4 1 2;   ...
    6 4 1 3;   ...
    3 1 2 NaN; ...
    6 4 5 NaN];


%% paths from one point to another as ordered lists of pMatrix row indices for [start point, via points, end point]
% Preallocate the numPoints-by-numPoints cell array instead of growing it
% inside the nested loop; every cell is filled below.
optPathPt2Pt = cell(numPoints);
for iPoint = 1:numPoints
    for jPoint = 1:numPoints
        if iPoint == jPoint
            % staying put: the "path" is the point listed as both start and end
            optPathPt2Pt{iPoint,jPoint} = [iPoint jPoint];
        else
            % route looked up by computeMinPath from the major backbone paths
            optPathPt2Pt{iPoint,jPoint} = computeMinPath(iPoint,jPoint,majorPaths);
        end
    end
end

%% make a test set (just to get tpIDs and sizes of things since we're generating batches on the fly)
% A single-trial set (first argument 1) is generated only to learn the array
% sizes and the set of distinct TP indicator patterns.
% NOTE(review): this section calls trainSetGen_randPauseBuffer while the
% training loop calls trainBatchGen_randPauseBuffer with the same argument
% list -- confirm both functions exist and that the difference is intended.
[zTargetList, uList, fixationList, TPindicatorList] = trainSetGen_randPauseBuffer( ...
    1, optPathPt2Pt, pMatrix, ...
    downTimeNumSteps, pauseNumStepsRng, startNumStepsRng, ...
    jauntNumSteps, totalNumSteps);

% Flatten the indicators to one 5-element row per entry, then keep the
% distinct rows. tpIDs stores them as columns, so numTPs is its column count.
% NOTE(review): the leading dimension 5 is hard-coded -- presumably the TP
% indicator length produced by the generator; verify.
TPindicatorListResh = reshape(TPindicatorList,5,[],1)';
tpIDs  = unique(TPindicatorListResh,"rows")';
numTPs = size(tpIDs,2);

%% setting up the neural network parameters
% set non spatial network parameters
desiredNeuralDim = 600;             % desired number of neurons (will end up with a bit more)

% input/output dimensionalities, read off the first dimension of the probe set
uDim = size(uList,1);               % dimensionality of rule inputs, same as number of routes
zDim = size(zTargetList,1);         % dimensionality of paths, basically 2D

% leak rate per simulation step: gamma = dt/tau
tau   = .1;
gamma = deltaT/tau;

% activation function: softplus with sharpness softPlusScl
%actFun = @(x) x.*(x>0); % relu activation % consider changing to rbf's
softPlusScl = 10;
%actFun = @(x) softPlus(x,softPlusScl,100); % softplus activation with thresholding for numerical stability %
actFun = @(x) log(1+exp(softPlusScl*x))/softPlusScl;

% per-step neural noise std derived from the desired noise variance rate
desiredNeuralNoiseVarRate = .016;   % .004 would be consistent with neuralNoiseStd= .1, deltaT=.1, tau=.5 / gamma=.2 (gamma=dt/tau)
NeuralNoiseStd = tau*sqrt(desiredNeuralNoiseVarRate/deltaT); % added every step % manoj: consider making this multiplicative?

fixationDim = size(fixationList,1);

% make coordinates/distance matrices for neurons on a square grid within a 3D ball
% Edge length (in grid points) of the cube [-1,1]^3 that bounds the unit ball.
% The cube-to-ball volume ratio is 6/pi, hence the scaling; one extra point
% per edge is added on top.
boundingCubeEdgeNumNeurs = ceil((desiredNeuralDim*6/pi)^(1/3)) + 1;
linGrid = linspace(-1,1,boundingCubeEdgeNumNeurs);
[X,Y,Z] = meshgrid(linGrid,linGrid,linGrid);

% one grid point per row (x,y,z), rows in the linear-index order of the
% meshgrid arrays
boundingCube = [X(:) Y(:) Z(:)];

% keep only the grid points strictly inside the unit ball
origDists = vecnorm(boundingCube,2,2);
ballPoses = origDists<1;
hDim = sum(ballPoses);                       % num neurons
neuralCords = boundingCube(ballPoses,:);

% distance of each kept neuron from the origin
origDists = origDists(ballPoses);            % will be used to penalize input weights and output signals (firingrate*|outputweight|)

% pairwise neuron-to-neuron distances
neurDistMat = squareform(pdist(neuralCords)); % will be used to penalize recurrent signals



%%
% Initialize network/training from scratch. Weight matrices are Gaussian,
% scaled by g/sqrt(hDim); biases and initial states are zeros.
% (The order of the randn calls below is kept fixed so the random stream is
% consumed in the same sequence.)
Wrec0      = g*diag(randn(1,hDim))/sqrt(hDim);    % recurrent weights; starts as a random DIAGONAL matrix
Win0       = g*randn(hDim,uDim)/sqrt(hDim);       % multiplies the current rule input for recurrent dynamics
Bin0       = zeros(hDim,1);                       % bias term for recurrent dynamics
Wout0      = g*randn(zDim,hDim)/sqrt(hDim);       % multiplies neural state to produce the output
Bout0      = Bin0;                                % output bias ... should be just 2D, but stored as a full hDim column so it fits in Theta
h00        = zeros(hDim,numPoints);               % initial neural state for every point
Wfeedback0 = g*randn(hDim,zDim)/sqrt(hDim);       % hDim-by-zDim feedback weights (presumably applied to the output z -- confirm in computeBatchCostTmaze)
Wfixation0 = g*randn(hDim,fixationDim)/sqrt(hDim);% multiplies fixation input for recurrent dynamics

% All unknowns packed side by side as columns of one hDim-row matrix.
% Wout0 is stored transposed so that it has hDim rows like the others.
Theta0 = [Wrec0 Win0 Bin0 Wout0' Bout0 Wfeedback0 Wfixation0 h00];

% Wrap in a dlarray so dlfeval can trace gradients. No gpuArray is created
% here, so the data stays wherever Theta0 lives.
Theta = dlarray(Theta0);

% Adam state: first- and second-moment accumulators, both zero to start
momentumVec = zeros(size(Theta));
rmsVec      = momentumVec;

% histories appended to on every iteration (ThetaList: periodic snapshots)
ThetaList          = [];
rmsFlexPenaltyList = [];
maxSVList          = [];
batchCostList      = [];
tpRMSEList         = [];
RMSEList           = [];

% rolling buffer of the last tpCostBuffLength costs per TP, one column per
% TP; starts at inf so its max stays inf until every slot has been overwritten
tpCostBuff       = inf(tpCostBuffLength,numTPs);
minMaxTPCostBuff = inf;
maxTPCostBuff    = inf;

% nominal step size, decayed during training
stepSzNom = stepSz;
%% actual neural network training
% Hand-rolled Adam loop. Each iteration: draw a fresh batch, evaluate cost and
% gradient through dlfeval, log diagnostics, snapshot the best parameters seen
% so far, then take one Adam step with an exponentially decaying step size.
disp('---- beginning optimization -----');


% figure reused by the progress plot drawn every 10 iterations
figure()
for iLoop = 1:numIterations

    % generate a training set
    [zTargetList, uList, fixationList, TPindicatorList] = ...
        trainBatchGen_randPauseBuffer(batchSz,optPathPt2Pt,pMatrix,...
        downTimeNumSteps,pauseNumStepsRng,startNumStepsRng,...
        jauntNumSteps,totalNumSteps);

    % get cost and gradient
    [Cost,dCostdTheta,RMSE,RMSE_TPsplit,flexPenalty,svPenalty]...
        = dlfeval(@computeBatchCostTmaze,Theta,uList,zTargetList,...
        fixationList, TPindicatorList, actFun,gamma,NeuralNoiseStd,...
        signalRegWeight,senseRegWeight,actRegWeight,svRegWeight,...
        flexConWeight,deltaT,tpIDs,DriftStd,neurDistMat,origDists);
    % Only columns with a non-NaN per-TP RMSE are updated (NaN presumably
    % marks TPs absent from this batch -- confirm in computeBatchCostTmaze).
    % For those columns the rolling buffer drops its oldest row and appends
    % the newest value at the bottom.
    tpsInBatch=~isnan(RMSE_TPsplit);
    tpCostBuff(:,tpsInBatch)=...
        [tpCostBuff(2:end,tpsInBatch);RMSE_TPsplit(tpsInBatch)];
    % worst entry anywhere in the buffer; remains inf until every column has
    % been updated tpCostBuffLength times, because the buffer starts as inf
    maxTPCostBuff=max(tpCostBuff(:));
    % append this iteration's diagnostics to the histories (penalties are
    % logged as square roots)
    batchCostList=[batchCostList; Cost];
    tpRMSEList=[tpRMSEList; RMSE_TPsplit];
    rmsFlexPenaltyList=[rmsFlexPenaltyList;sqrt(flexPenalty)];
    maxSVList=[maxSVList;sqrt(svPenalty)];
    RMSEList=[RMSEList;RMSE];



    % stop training as soon as the cost is NaN (checked before any snapshot
    % or parameter update for this iteration)
    if isnan(Cost)%||norm(dCostdTheta,"fro")<eps
        disp('things went Nan');
        break
    end

    % Snapshot when this batch cost is the lowest so far (Cost is already in
    % batchCostList, so the test includes itself). Theta has not been updated
    % yet this iteration, so it is the parameter set that produced Cost, i.e.
    % the result of iLoop-1 updates.
    % NOTE(review): 'besCosttMomentumVec' is misspelled (cf.
    % bestPerfMomentumVec) but is left as is because the whole workspace is
    % written by save() at the end and downstream code may load this name.
    if all(Cost<=batchCostList)
        bestCostTheta=Theta;
        bestCostIter=iLoop-1;
        besCosttMomentumVec=momentumVec;
        bestCostRmsVec=rmsVec;
    end


    % Snapshot when the worst buffered per-TP cost reaches a new minimum.
    % This cannot trigger while maxTPCostBuff is still inf.
    if maxTPCostBuff<minMaxTPCostBuff
        minMaxTPCostBuff=maxTPCostBuff;
        bestPerfTheta=Theta;
        bestPerfIter=iLoop-1;
        bestPerfMomentumVec=momentumVec;
        bestPerfRmsVec=rmsVec;
    end

    % update network with ADAM, per the tensorflow implementation:
    % exponential moving averages of the gradient and squared gradient, with
    % both bias corrections folded into the effective step size
    momentumVec=momentumW*momentumVec+(1-momentumW)*dCostdTheta;
    rmsVec=rmsW*rmsVec+(1-rmsW)*(dCostdTheta.^2);
    stepSzNow=stepSzNom*sqrt(1-rmsW^iLoop)/(1-momentumW^iLoop);

    if iLoop<10
        stepSzNow=stepSzNow/10;%blowup protection: 10x smaller step for the first 9 updates (iLoop<10)
    end

    Theta=Theta-stepSzNow*momentumVec./(sqrt(rmsVec)+epsilon);

    % every 10 iterations: print progress and redraw the diagnostic curves;
    % every 100 iterations: also store a plain-array copy of Theta as a new
    % page along the 3rd dimension of ThetaList
    if mod(iLoop,10)==0
        disp(['iter = ' num2str(iLoop) '. cost = ' num2str(Cost) '. maxCbuff = ' num2str(maxTPCostBuff) '. maxSV = ' num2str(maxSVList(end)) '. rmsFlexPen = ' num2str(rmsFlexPenaltyList(end)) '. runtime =' num2str(toc)]);
        plot([RMSEList maxSVList rmsFlexPenaltyList]);
        legend('RMSE','maxSV','flexPen');drawnow;
        if mod(iLoop,100)==0
            ThetaList=cat(3,ThetaList,extractdata(Theta));
        end
    end
    % apply learning rate decay (nominal step is multiplied by stepSzDecay
    % after every iteration)
    stepSzNom=stepSzNom*stepSzDecay;

end
%%
% Convert the logged dlarray histories to plain numeric arrays and record the
% total elapsed time since the tic at the top of the script.
batchCostList = extractdata(batchCostList);
tpRMSEList    = extractdata(tpRMSEList);
runTime       = toc;

% Write the entire workspace to a MAT-file. The name is a hand-written tag of
% the run settings and must be edited by hand if the parameters change.
fileName = 'stepSz4en3_nVarRt16en3_svReg1en3_flexCon0_batchSz36_hDim600__StpSzHL15e2_6e3batches';
save(fileName);

