% Fixed-point analysis of saved "noise-in" function-approximation networks.
%
% For every .mat file in savedFuncNetMatFiles_<readOutType> whose 13th
% filename character is 'I', this script:
%   1. loads the network and unpacks its weights,
%   2. runs the noiseless network to obtain final hidden states,
%   3. uses ADAM to search for fixed/slow points starting from perturbed
%      final hidden states, once with fixedPtCost (no noise) and once with
%      fixedPtCostWNoiseBias (with noise),
%   4. de-duplicates the candidates, projects them with PCA and plots them,
%   5. compares the output-potent distance between noisy and noiseless
%      fixed points against the measured output bias over noise levels.
%
% Variables used below but not defined in this script are expected to come
% from the loaded .mat file: bestPerfTheta, hDim, inDim, outDim,
% NeuralNoiseStd, funcToApprox, numSteps, actFun, gamma, deltaT, and
% optionally maxNeuralNoiseStd.
% NOTE(review): epsilon (the ADAM denominator offset) is also never defined
% here - presumably it is stored in the .mat file as well; confirm.
clc;clear;close all;
saveFigs=0; % set nonzero to save the final per-network figure
readOutTypeList=['detRO';'intRO']; % char matrix, one readout-type tag per row
readOutTypeListForTitle={', fixed readout time',', random readout time'};
netFun=@FuncNet_noiseIn;
for readOutTypeNum=1:size(readOutTypeList,1)
    readOutType=readOutTypeList(readOutTypeNum,:);
    folderName=['savedFuncNetMatFiles_' readOutType];
    % fullfile keeps the path valid on non-Windows platforms too
    dataFolderStruct=dir(fullfile(folderName,'*.mat'));
    numNets=length(dataFolderStruct);
    for netNum=1:numNets
        MatFileName=dataFolderStruct(netNum).name;
        % only proceed if this is a noise-in net; the length guard avoids an
        % index error on unrelated .mat files with short names
        if numel(MatFileName)>=13 && isequal(MatFileName(13),'I')

            % maxNeuralNoiseStd is optional in the saved files; clear any
            % copy left over from a previously loaded network so the
            % exist() test further down reflects THIS file only
            clearvars maxNeuralNoiseStd
            load(fullfile(folderName,MatFileName));

            Theta=extractdata(bestPerfTheta);
            [Win,Wrec,Bin,Wout,Bout,~] = unpackTheta(Theta,hDim,inDim,outDim);

            %% set optimization params

            h0Std=.01;% perturbation std for initial fixed point guesses
            stepSz0=.07; % adam step size for the noiseless search
            stepSzN0=stepSz0/4;% adam step size for the noisy search
            stepSzHalfLife=1386; % iterations for the nominal step size to halve
            decay=.5^(1/stepSzHalfLife);
            mW=.9; % adam momentum weight
            rmsW=.999;% adam rms-prop weight

            hTols= [7e-5,9e-5];% different tolerances to use as exit conditions for adam
            gradTols= [1e-6, 1e-7,1e-8];
            numNoiseSamps=300;
            hTolsN=[1/10 1/5]*NeuralNoiseStd*sqrt(hDim/numNoiseSamps);
            gradTolsN=gradTols;


            numIter=50000;

            %% make data to test on

            numInSamps=30;
            % identify what kind of network this is from the 5th character
            % of the target function's string form
            funcString=func2str(funcToApprox);
            if isequal(funcString(5),'t')
                aproxFunName='tanh';
                inList=linspace(-4,4,numInSamps);
            else
                aproxFunName='sin';
                inList=linspace(0,2*pi,numInSamps);
            end
            outTargetList=funcToApprox(inList);

            %% get neural and output trajectories (noise level 0)
            [hList,outList,out] = ...
                GetNeuralAndOutputTrajectories_wNoise(Theta,inList,numSteps, ...
                actFun,gamma,0,netFun);
            Theta=gpuArray(Theta);

            %% arrange everything for optimization
            % keep an unperturbed copy of the final hidden states so that
            % both searches below start from independently perturbed copies
            % of the SAME states (perturbations do not accumulate)
            hEnd=gpuArray(hList(:,:,end)); numICs=size(hEnd,2);

            in0=gpuArray(inList);

            gamma=gpuArray(gamma);deltaT=gpuArray(deltaT);

            %% use adam to seek out fixed/slow points starting from perturbed final h's with no noise
            h0=hEnd+h0Std*randn(size(hEnd));
            h=dlarray(h0); hm=zeros(size(h0)); hrms=hm;
            in=in0;
            tolh=hTols(1);
            tolGrad=gradTols(end);
            fixPtCandsH=[];
            fixPtCandsIn=[];
            bigLocMinsH=[];
            bigLocMinsIn=[];

            % magsList(k+1,j) holds dhMags of initial condition j after
            % iteration k while j is still being optimized; NaN afterwards
            magsList=nan(numIter+1,numICs);
            magsList(1,:)=0;
            tic;
            figure()
            stepSzNom=stepSz0;
            for iLoop=1:numIter

                [Cost,dCostdh,dhMags] = dlfeval(@fixedPtCost,Theta,h,in,gamma,actFun,netFun);
                % remove fixedPt candidates from the running and save them
                % (explicit dim 1 so a single-row h is still reduced per column)
                locMinInds=sqrt(sum(dCostdh.^2,1))<tolGrad;
                subTolInds=dhMags<tolh;
                nanInds=isnan(dhMags);
                inProgInds=~(subTolInds|nanInds|locMinInds);
                fixPtCandsH=[fixPtCandsH h(:,subTolInds)];
                fixPtCandsIn=[fixPtCandsIn in(:,subTolInds)];

                % vanishing gradient but dhMags still above tolerance
                bigLocMinInds=locMinInds&(~subTolInds);
                bigLocMinsH=[bigLocMinsH h(:,bigLocMinInds)];
                bigLocMinsIn=[bigLocMinsIn in(:,bigLocMinInds)];

                h=h(:,inProgInds);hm=hm(:,inProgInds); hrms=hrms(:,inProgInds);
                dCostdh=dCostdh(:,inProgInds);
                in=in(:,inProgInds);
                % update the mags list: map the surviving columns back to
                % their original initial-condition indices
                ICsStillInProg=~isnan(magsList(iLoop,:));
                ICsStillInProg(ICsStillInProg)=inProgInds;
                magsList(iLoop+1,ICsStillInProg)=dhMags(inProgInds);

                % update with ADAM, per the tensorflow implementation
                hm=mW*hm+(1-mW)*dCostdh;
                hrms=rmsW*hrms+(1-rmsW)*(dCostdh.^2);

                % bias-corrected step size
                stepSzNow=stepSzNom*sqrt(1-rmsW^iLoop)/(1-mW^iLoop);

                h=h-stepSzNow*hm./(sqrt(hrms)+epsilon);
                if mod(iLoop,10)==0
                    disp(['iter = ' num2str(iLoop) '. ICsConverged  = ' num2str(size(fixPtCandsH,2)+size(bigLocMinsH,2)) '. stepSzNow = ' num2str(stepSzNow) '. runtime =' num2str(toc)]);
                    if mod(iLoop,500)==0
                        plot(log(magsList(:,ICsStillInProg)));
                        drawnow;
                    end
                end
                if all(~inProgInds)
                    disp(['iter = ' num2str(iLoop) '. ICsConverged  = ' num2str(size(fixPtCandsH,2)+size(bigLocMinsH,2)) '. stepSzNow = ' num2str(stepSzNow) '. runtime =' num2str(toc)]);
                    break;
                end
                stepSzNom=decay*stepSzNom;
            end

            %% use adam to seek out fixed/slow points starting from perturbed final h's with noise
            h0=hEnd+h0Std*randn(size(hEnd));
            h=dlarray(h0); hm=zeros(size(h0)); hrms=hm;
            in=in0;
            tolh=hTolsN(1);
            tolGrad=gradTolsN(end);
            fixPtCandsHN=[];
            fixPtCandsInN=[];
            bigLocMinsHN=[];
            bigLocMinsInN=[];


            magsList=nan(numIter+1,numICs);
            magsList(1,:)=0;
            stepSzNom=stepSzN0;
            for iLoop=1:numIter

                [Cost,dCostdh,dhMags] = dlfeval(@fixedPtCostWNoiseBias,Theta,h,in,gamma,actFun,netFun,numNoiseSamps,NeuralNoiseStd);
                % remove fixedPt candidates from the running and save them
                locMinInds=sqrt(sum(dCostdh.^2,1))<tolGrad;
                subTolInds=dhMags<tolh;
                nanInds=isnan(dhMags);
                inProgInds=~(subTolInds|nanInds|locMinInds);
                fixPtCandsHN=[fixPtCandsHN h(:,subTolInds)];
                fixPtCandsInN=[fixPtCandsInN in(:,subTolInds)];

                bigLocMinInds=locMinInds&(~subTolInds);
                bigLocMinsHN=[bigLocMinsHN h(:,bigLocMinInds)];
                bigLocMinsInN=[bigLocMinsInN in(:,bigLocMinInds)];

                h=h(:,inProgInds);hm=hm(:,inProgInds); hrms=hrms(:,inProgInds);
                dCostdh=dCostdh(:,inProgInds);
                in=in(:,inProgInds);
                % update the mags list
                ICsStillInProg=~isnan(magsList(iLoop,:));
                ICsStillInProg(ICsStillInProg)=inProgInds;
                magsList(iLoop+1,ICsStillInProg)=dhMags(inProgInds);

                % update with ADAM, per the tensorflow implementation
                hm=mW*hm+(1-mW)*dCostdh;
                hrms=rmsW*hrms+(1-rmsW)*(dCostdh.^2);

                stepSzNow=stepSzNom*sqrt(1-rmsW^iLoop)/(1-mW^iLoop);

                h=h-stepSzNow*hm./(sqrt(hrms)+epsilon);
                if mod(iLoop,10)==0
                    disp(['iter = ' num2str(iLoop) '. ICsConverged  = ' num2str(size(fixPtCandsHN,2)+size(bigLocMinsHN,2)) '. stepSzNow = ' num2str(stepSzNow) '. runtime =' num2str(toc)]);
                    if mod(iLoop,500)==0
                        plot(log(magsList(:,ICsStillInProg)));
                        drawnow;
                    end
                end
                if all(~inProgInds)
                    break;
                end
                stepSzNom=decay*stepSzNom;
            end
            %%
            runTime=toc; % wall time of both searches together
            % fixPtInfoFileName=[fileName '_fixPtInfo'];
            % save(fixPtInfoFileName, 'fixPtCands*')

            %% discard approximate repeats (not actually necessary if we're only initializing from the ending h's)
            uniquenessTol=.001;

            if isempty(fixPtCandsH) || isempty(fixPtCandsHN)
                error('fixedPtAnalysis:noFixedPoints', ...
                    'No sub-tolerance fixed point candidates found for %s (noiseless: %d, noisy: %d).', ...
                    MatFileName,size(fixPtCandsH,2),size(fixPtCandsHN,2));
            end

            % rows of fixPtCandsH*u are fixed points; after the sort they
            % are ordered by the input value they belong to
            [fixPtCandsHu,uniqueInds]=uniquetol(extractdata(gather(fixPtCandsH')),uniquenessTol,"ByRows",true);
            fixPtCandsInu=fixPtCandsIn(:,uniqueInds);
            [fixPtCandsInu, idx]=sort(fixPtCandsInu);
            fixPtCandsHu=fixPtCandsHu(idx,:);

            [fixPtCandsHNu,uniqueInds]=uniquetol(extractdata(gather(fixPtCandsHN')),uniquenessTol,"ByRows",true);
            fixPtCandsInNu=fixPtCandsInN(:,uniqueInds);
            [fixPtCandsInNu, idx]=sort(fixPtCandsInNu);
            fixPtCandsHNu=fixPtCandsHNu(idx,:);

            % everything below pairs noiseless and noisy fixed points row by
            % row and plots them against inList, so exactly one fixed point
            % per input sample is required in each set; fail with a clear
            % message instead of a size-mismatch error further down
            numFP=size(fixPtCandsHu,1);
            numFPN=size(fixPtCandsHNu,1);
            if numFP~=numInSamps || numFPN~=numInSamps
                error('fixedPtAnalysis:unpairedFixedPoints', ...
                    'Expected %d fixed points per condition for %s but found %d noiseless and %d noisy.', ...
                    numInSamps,MatFileName,numFP,numFPN);
            end

            % joint PCA so both sets share one projection; split the scores
            % by the actual row count of the noiseless set
            [Comps,allScores,Vars]=pca([fixPtCandsHu;fixPtCandsHNu]);
            Scores=allScores(1:numFP,:);
            ScoresN=allScores(numFP+1:end,:);



            %% plot the fps (colour encodes position in the sorted input list)
            figure()
            cmap=colormap(jet(numInSamps));
            scatter3(Scores(:,1),Scores(:,2),Scores(:,3),[],cmap,'x');hold on;
            scatter3(ScoresN(:,1),ScoresN(:,2),ScoresN(:,3),[],cmap,'o');
            legend('0 noise fixed points','preferred noise fixed Points')

            %% measure differences between noisy and non noisy fps
            FPdiff=fixPtCandsHNu-fixPtCandsHu;
            FPdist=vecnorm(FPdiff,2,2);% raw distances in neural space
            outPotFPdist=abs(FPdiff*Wout');% output potent differences

            %% get performance on a range of noises and see that the fixed point distances predict noise preference
            % replicate the inputs along dim 4, one copy per noise sample
            inListRepped=repmat(inList,[1 1 1 numNoiseSamps]);

            if exist('maxNeuralNoiseStd','var')
                NeuralNoiseStd=maxNeuralNoiseStd;
            end
            numNoiseLevels=2;
            neurNoiseList=linspace(0,NeuralNoiseStd,numNoiseLevels)';
            %neurNoiseList=sort([neurNoiseList;NeuralNoiseStd]);

            % preallocate per network so nothing carries over between nets
            meanVarList=zeros(numNoiseLevels,1);
            meanBiasSqrList=meanVarList;
            inSpecVarList=zeros(numInSamps,numNoiseLevels);
            inSpecBiasSqrList=inSpecVarList;
            for k=1:length(neurNoiseList)
                neurNoiseNow=neurNoiseList(k);
                [hList,outList,out] = ...
                    GetNeuralAndOutputTrajectories_wNoise(Theta,inListRepped,numSteps, ...
                    actFun,gamma,neurNoiseNow,netFun);
                % statistics across the noise-sample dimension (dim 4)
                inSpecVar=var(out,[],4);
                inSpecBiasSqr=(outTargetList-mean(out,4)).^2;
                meanVarList(k)=mean(inSpecVar);
                meanBiasSqrList(k,1)=mean(inSpecBiasSqr);
                inSpecVarList(:,k)=inSpecVar;
                inSpecBiasSqrList(:,k)=inSpecBiasSqr;

            end

            %% compare measured output bias with output-potent fixed point shift
            figTitle=[aproxFunName readOutTypeListForTitle{readOutTypeNum}];
            figFileName=[aproxFunName '_' readOutType];
            figure()
            plot(inList,sqrt(inSpecBiasSqrList),'o');hold on;
            plot(inList,outPotFPdist);
            xlabel('input value')
            title(figTitle);
            if saveFigs
                saveas(gcf,figFileName);
                saveas(gcf,[figFileName '.svg']);
            end
        end
    end
end