% Start from a clean slate: clear the command window, wipe the workspace,
% and close every open figure.
clc;
clear;
close all;

saveFig = 0; % set to 1 to save the plot

%% load in data
% The remainder of the script uses variables that are not defined in this
% file (e.g. bestPerfTheta, batchCostList); presumably they all come from
% this .mat file -- TODO confirm.
load('stepSz4en3_nVarRt16en3_svReg1en3_flexCon0_batchSz36_hDim600_StpSzHL15e2_6e3batches.mat');
% NOTE(review): extractdata suggests bestPerfTheta is stored as a dlarray.
Theta = extractdata(bestPerfTheta);

%% plot cost history

figure;
plot(batchCostList);

%% make data to test on

% Build the test set with the random-pause-buffer generator. Every argument
% after the leading 1 must already exist in the workspace (presumably
% provided by the loaded .mat file -- TODO confirm).
[zTargetList, uList, fixationList, TPindicatorList] = ...
    trainSetGen_randPauseBuffer( ...
        1, optPathPt2Pt, pMatrix, ...
        downTimeNumSteps, pauseNumStepsRng, startNumStepsRng, ...
        jauntNumSteps, totalNumSteps);
% Alternative generator (one-hot rule variant), kept for reference:
% [zTargetList, uList, fixationList, TPindicatorList] = ...
%     trainSetGen_randPauseBuffer_oneHotRule(1,optPathPt2Pt,pMatrix,...
%     downTimeNumSteps,pauseNumStepsRng,startNumStepsRng,...
%     jauntNumSteps,totalNumSteps);


% Repetition count passed to the noisy simulation for each noise level.
numReps = gpuArray(50);

%% put everything on the gpu
zTargetList     = gpuArray(zTargetList);
uList           = gpuArray(uList);
fixationList    = gpuArray(fixationList);
TPindicatorList = gpuArray(TPindicatorList);
gamma           = gpuArray(gamma);
deltaT          = gpuArray(deltaT);
Theta           = gpuArray(Theta);

%% pick a range of noise levels to test
numNoiseLevels = 10;
% numNoiseLevels-1 evenly spaced values on [0, 1.5*NeuralNoiseStd], with
% NeuralNoiseStd itself appended, then sorted into ascending order.
neurNoiseList = sort([ ...
    gpuArray(linspace(0, 1.5*NeuralNoiseStd, numNoiseLevels-1)), ...
    NeuralNoiseStd]);


%% get output neural and output trajectories
% For each noise level, run the repeated noisy simulation and decompose the
% output error into a variance term and a squared-bias term.
%   meanVarList / meanBiasSqrList   : one scalar per noise level (mean over
%                                     all remaining entries).
%   tSpecVarList / tSpecBiasSqrList : per-noise-level values indexed by
%                                     dims 2 and 3 of zTargetList.
meanVarList=zeros(numNoiseLevels,1);
meanBiasSqrList=meanVarList;
tSpecVarList=zeros([numNoiseLevels size(zTargetList,[2 3])]);
tSpecBiasSqrList=tSpecVarList;
tic;
for k=1:numel(neurNoiseList)

    neurNoiseNow=neurNoiseList(k);
    [~,zList] = ...
        GetNeuralAndOutputTrajectories_wNoise_repped(Theta,uList,fixationList,actFun,deltaT,gamma,zTargetList(:,:,1),TPindicatorList,neurNoiseNow,numReps);

    % Variance and mean are taken along dim 4 of zList (presumably the
    % repetition axis -- TODO confirm), then summed over dim 1.
    % The summation dimension is given explicitly: without it, sum()
    % reduces along the first NON-SINGLETON dimension, so a zList with a
    % single row would be summed along dim 2 instead and would no longer
    % match the [size(zTargetList,2) size(zTargetList,3)] slots below.
    tSpecVar=sum(var(zList,[],4),1);
    tSpecBiasSqr=sum((zTargetList-mean(zList,4)).^2,1);

    meanVarList(k)=mean(tSpecVar,"all");
    meanBiasSqrList(k)=mean(tSpecBiasSqr,"all");
    tSpecVarList(k,:,:)=tSpecVar;
    tSpecBiasSqrList(k,:,:)=tSpecBiasSqr;

    % Progress report; toc is the time elapsed since the tic before the loop.
    disp(['num noise levels tested = ' num2str(k) '. runTime = ' num2str(toc)])
end


% Total mean squared error = squared-bias part + variance part.
MSEList = meanBiasSqrList + meanVarList;

figure;
% One curve per column, all on a square-root scale:
% total error, squared-bias part, variance part.
plot(neurNoiseList, sqrt([MSEList, meanBiasSqrList, meanVarList]));
hold on;
% Vertical marker at NeuralNoiseStd (the 'training noise level' entry of
% the disabled legend below).
xline(NeuralNoiseStd);
%legend('RMSE','biasRMSEComp','varRMSEComp','training noise level')
xlabel('noise standard deviation');

figFileName = 'errorBiasandVarianceForMazeNetwork';
if saveFig
    % First call passes the bare name (no extension); second writes an SVG.
    saveas(gcf, figFileName);
    saveas(gcf, [figFileName, '.svg']);
end