%% Script used to run simulations in Fig.2A
% Used for rebuttal figure

% Study how FIT and TE estimates depend on the number of trials and on the
% number of discretization bins, computed at individual time points.

clear all; %close all;

rng(1) % For reproducibility
save_results = 0; % set to 1 to save results file
results_path =  ''; % path to the directory to save results

% Simulation parameters
% nTrials_per_stim holds log2 exponents: condition k uses 2^k trials per
% stimulus value (6..11 -> 64..2048 trials per stimulus).
nTrials_per_stim = [6,7,8,9,10,11];
simReps = 50; % repetitions of the simulation
nShuff = 2; % number of permutations (used for the FIT and TE permutation tests)
nXtrap = 20; % number of extrapolations for bias correction

noise = 0.5; % standard deviation of the gaussian noise added to X, Yt and hY


% Define information options
% NOTE(review): in this script only opts.n_binsS/n_binsX/n_binsY are read;
% the remaining fields appear unused here - confirm before relying on them.
opts = [];
opts.verbose = false;
opts.method = "dr";
opts.bias = 'naive';
opts.btsp = 0;
number_of_bins = [2 3 4 8];
opts.n_binsS = 2; % Number of stimulus values

% Initialize result arrays: (n. bin settings) x (repetitions) x (trial-count conditions)
simSize = [numel(number_of_bins), simReps, numel(nTrials_per_stim)];
fit = nan(simSize); di = fit; dfi = fit;
% Permutation results carry an extra trailing dimension over shuffles
fitSh.simple = nan([simSize, nShuff]); diSh.simple = fitSh.simple; dfiSh.simple = fitSh.simple;
fitLe = nan(simSize); diLe = fitLe;
fitQe = nan(simSize); diQe = fitQe;

%% Run simulation
% For every bin setting, trial-count condition and repetition: simulate
% S -> X -> Yt, discretize, compute FIT/TE (plug-in and extrapolated
% estimates), then build a permutation null.

for bIdx = 1:numel(number_of_bins)
    disp(['Simulation for ',num2str(number_of_bins(bIdx)),' bins'])
    opts.n_binsX = number_of_bins(bIdx);
    opts.n_binsY = number_of_bins(bIdx);
    for trialIdx = 1:numel(nTrials_per_stim)
        % nTrials_per_stim stores log2 exponents -> 2^k trials per stimulus.
        % The total trial count does not depend on the repetition, so it is
        % computed once per condition.
        trialsPerStim = 2^nTrials_per_stim(trialIdx);
        nTrials = trialsPerStim*opts.n_binsS; % total number of trials
        disp(['Simulation for ',num2str(trialsPerStim),' trials per stim'])
        for repIdx = 1:simReps
            %disp(['Repetition number ',num2str(repIdx)]);

            % Draw the stimulus value for each trial
            S = randi(opts.n_binsS,1,nTrials);

            % Simulated signals (noise = std of the additive gaussian noise):
            % X carries S, Yt is driven by X, hY (Y's "past") is pure noise.
            X = S + noise*randn(1,nTrials);
            Yt = X + noise*randn(1,nTrials);
            hY = noise*randn(1,nTrials);

            % Discretize neural activity using the edges returned by eqpop
            % (presumably equally-populated bins - confirm in eqpop)
            edgs = eqpop(X, opts.n_binsX);
            [~,bX] = histc(X, edgs);

            edgs = eqpop(Yt, opts.n_binsY);
            [~,bYt] = histc(Yt, edgs);
            edgs = eqpop(hY, opts.n_binsY);
            [~,bYpast] = histc(hY, edgs);

            % Outputs: TE/DI, DFI, FIT (plug-in) followed by the quadratic-
            % and linear-extrapolation variants. NOTE(review): meaning of the
            % literal 1 argument is defined in compute_FIT_TE_Qe - confirm.
            [di(bIdx,repIdx,trialIdx),dfi(bIdx,repIdx,trialIdx),fit(bIdx,repIdx,trialIdx), diQe(bIdx,repIdx,trialIdx),...
                diLe(bIdx,repIdx,trialIdx), fitQe(bIdx,repIdx,trialIdx), fitLe(bIdx,repIdx,trialIdx)]=...
                compute_FIT_TE_Qe(S, bX, bYt, bYpast,1,nXtrap);

            for shIdx = 1:nShuff

                % Simple shuffle across all trials: one permutation is used to
                % shuffle S (FIT null, X kept intact) and, separately, the
                % discretized X (TE/DI null, S not involved).
                idx = randperm(nTrials);
                Ssh = S(idx);
                XSh = bX(idx);

                [~,dfiSh.simple(bIdx,repIdx,trialIdx,shIdx),fitSh.simple(bIdx,repIdx,trialIdx,shIdx)]=...
                    compute_FIT_TE(Ssh, bX, bYt, bYpast);
                [diSh.simple(bIdx,repIdx,trialIdx,shIdx)]=...
                    DI_infToolBox(XSh, bYt, bYpast, 'naive', 0);

            end
        end
    end
end

if save_results
    % Save the whole workspace as bias_FigS6_<date>.mat inside results_path.
    % With results_path = '' fullfile returns the bare file name, i.e. the
    % current folder (the previous behaviour).
    fname = ['bias_FigS6_' date '.mat'];
    save(fullfile(results_path, fname))
end

%% Plot FIT and TE as a function of number of trials for different number of bins
% One subplot per bin setting: mean +/- SEM across repetitions of each
% estimator versus the total number of trials, plus reference lines.
rng(1)
x_vals = (2.^(nTrials_per_stim))*opts.n_binsS; % total number of trials per condition
minX = 128;

% Series drawn in this order: {data, color, legend-handle index}
plotSeries = {fitQe, 'r', 2; ...
              fit,   'b', 1; ...
              diQe,  'g', 4; ...
              di,    'k', 3};

figure('Position',[282,129,808,420])
for bIdx = 1:numel(number_of_bins)
    subplot(1,numel(number_of_bins),bIdx)
    hold on
    for sIdx = 1:size(plotSeries,1)
        % reshape (not squeeze) keeps a simReps x nConditions matrix even
        % when simReps == 1, so the dim-1 statistics stay per-condition
        vals = reshape(plotSeries{sIdx,1}(bIdx,:,:), simReps, []);
        mu = mean(vals,1);
        sem = std(vals,[],1)/sqrt(simReps);
        lineCol = plotSeries{sIdx,2};
        h(plotSeries{sIdx,3}) = plot(x_vals,mu,'color',lineCol,'linewidth',2);
        shadedErrorBar(x_vals,mu,sem,'LineProps',{'color',lineCol},'patchSaturation',0.1)
    end

    ylim([0,0.75])
    ylabel('[bits]')
    xlim([minX,x_vals(end)])
    % Reference dataset sizes
    h(5)=xline(400,'--','color',[139,69,19]/255,'linewidth',2);
    h(6)=xline(700,'--','color',[0.5,0.5,0.5],'linewidth',2);
    h(7)=xline(1000,'--','color',[182, 96, 205]/255,'linewidth',2);
    h(8)=xline(2000,'--','color',[226, 186, 31]/255,'linewidth',2);

    if bIdx == 1
        legend([h(1) h(2) h(3) h(4)],{'FIT plug in','FIT QE','TE plug in', 'TE QE'},'Location','SouthEast')
    elseif bIdx == 2
        legend([h(5) h(6) h(7) h(8)],{'MEG','Spiking activity','EEG', 'Simulations'},'Location','SouthEast')
    end
    xlabel('number of trials')
    title([num2str(number_of_bins(bIdx)) ' bins'])

end
sgtitle(['FIT and TE dependence on dataset size, 1D sender, noise=',num2str(noise)])

