clc; clear; close all;
saveFig = 1; % set to 1 to save the plot

% Load the trained network results (provides bestTheta, gamma, batchCostList,
% NeuralNoiseStd, InputNoiseStd, actFun, ...) and move parameters to the GPU.
load('Results_neurNoiseStd1en1_Mem2x_15e3batches_Relu_hDim256.mat');
gamma = gpuArray(gamma);
Theta = gpuArray(extractdata(bestTheta));

numReps = 50; % Monte Carlo repetitions per test noise level
batchSz = 50; % trials drawn per task batch

% Plot every stored cost curve: column 1 on x, column 2 on y.
% NOTE(review): presumably each cell of batchCostList is an
% [iteration, cost] matrix from training — confirm against the training script.
figure();
hold on;
for curveIdx = 1:numel(batchCostList)
    costCurve = batchCostList{curveIdx};
    plot(costCurve(:,1), costCurve(:,2));
end
hold off;
%% pick a range of noises to test
% Evenly spaced levels from 0 to 1.5x the training noise, plus the exact
% training noise level itself, sorted ascending (numNoiseLevels values total).
numNoiseLevels = 15;
neurNoiseList = sort(gpuArray([linspace(0, 1.5*NeuralNoiseStd, numNoiseLevels-1), NeuralNoiseStd]));
% Place the noise levels along dim 5 and replicate them along dim 4:
%   4th dim = Monte Carlo repetitions, 5th dim = test noise levels
neurNoiseListRepped = repmat(reshape(neurNoiseList, 1, 1, 1, 1, []), 1, 1, 1, numReps);

%% get output trajectories and accumulate per-task error statistics
% For each task, draw a fresh batch, run the network at every test noise
% level in neurNoiseListRepped, and store the masked output variance (across
% the 4th dim) and the squared bias of the 4th-dim mean against the target.
numTasks=6; % task indices 1..numTasks are passed to TrainBatchGen
% rows = tasks, 5th dim = test noise levels (same layout as neurNoiseListRepped)
meanVarList=zeros(numTasks,1,1,1,numNoiseLevels);
meanBiasSqrList=meanVarList;
tic;
for taskNum=1:numTasks
    
    % get new demo batch
    [uList,zTargetList,maskList,~] = TrainBatchGen(batchSz,InputNoiseStd,taskNum);
    zTargetList=gpuArray(zTargetList);
    uList=gpuArray(uList);
    maskList=gpuArray(maskList);
    % NOTE(review): assumes zList carries the Monte Carlo repetitions along
    % dim 4 and the noise levels along dim 5, mirroring neurNoiseListRepped;
    % confirm in getOutputTrajectories.
    [zList] = ...
        getOutputTrajectories(Theta,uList,actFun,gamma,neurNoiseListRepped);
    % variance across dim 4 (var with [] weight uses N-1 normalization),
    % masked, summed over dims 1-3, then divided by the batch size
    meanVarList(taskNum,:,:,:,:)=sum(maskList.*var(zList,[],4),[1 2 3])/batchSz;
    % squared deviation of the dim-4 mean output from the target,
    % masked, summed over dims 1-3, then divided by the batch size
    meanBiasSqrList(taskNum,:,:,:,:)=sum(maskList.*(zTargetList-mean(zList,4)).^2,[1 2 3])/batchSz;
    disp(['num tasks tested = ' num2str(taskNum) '. runTime = ' num2str(toc)])
end

% Average each statistic over tasks: move the noise-level dim (5) to the
% rows and the task dim (1) to the columns, then take the row mean, giving
% one column entry per test noise level.
taskMean = @(stat) mean(permute(stat,[5 1 2 3 4]),2);
meanBiasSqrList = taskMean(meanBiasSqrList);
meanVarList = taskMean(meanVarList);
% mean squared error = squared bias + variance
MSEList = meanVarList + meanBiasSqrList;
colors = colororder(); % color order of the current figure, reused below
%colors=repelem(colors(1:3,:),6, 1);

%% plot RMSE and its bias / variance components vs. test noise level
hFig = figure();
% one curve per column: sqrt(MSE), sqrt(squared bias), sqrt(variance)
plot(neurNoiseList,[sqrt(MSEList) sqrt(meanBiasSqrList) sqrt(meanVarList)]);
hold on;
colororder(colors);
% black vertical reference line at the noise level used in training
xline(NeuralNoiseStd,'k');
hold off;
%legend('RMSE','biasRMSEComp','varRMSEComp','training noise level')
xlabel('noise standard deviation');
ylabel('root-mean-square error component');
figFileName='errorBiasandVarianceForMultiCogNetworkN1en1hdim256';
if saveFig
    % save this specific figure (not whichever one is current) as .fig and .svg
    saveas(hFig,figFileName,'fig');
    saveas(hFig,figFileName,'svg');
end