%% Temporal localization of FIT and TE (Fig.S3B, additive noise)
% Simulates a two-node network in which node X (two dimensions: a
% stimulus-driven signal dimension and a noise dimension) drives node Y with
% a communication delay drawn at random per repetition. DI (TE), DFI and FIT
% are then computed at every time point and candidate delay, together with
% two kinds of shuffled null values ('simple' and 'cond').

clear all; %close all;

rng(0) % For reproducibility
results_path =  []; % path to save results

% Simulation parameters
nTrials_per_stim = 500; % trials per stimulus value
simReps = 50;           % number of independent simulated repetitions
nShuff = 10;            % shuffles per (repetition, time, delay) and shuffle type

w_xy_sig = 0.5; % stimulus transfer strength
w_xy_noise = 1; % noise transfer strength
epsY = 2; epsX = 2; % noise magnitudes in nodes Y and X
ratio_sig_noise = 0.2; % ratio of noise in the signal dimension

% Global params
tparams.simLen = 60; % simulation time, in units of 10ms
tparams.stimWin = [30 35]; % X stimulus encoding window, in units of 10ms
tparams.delays = 4:6; % communication delays, in units of 10ms
tparams.delayMax = 10; % maximum computed delay, in units of 10ms

% Define information options
opts = [];
opts.verbose = false;
opts.method = "dr";
opts.bias = 'naive';
opts.btsp = 0;
opts.n_binsX = 3; % X has 2 dimensions and each will be discretized in opts.n_binsX bins --> X will have opts.n_binsX^2 possible outcomes
opts.n_binsY = 3; 
opts.n_binsS = 4; % Number of stimulus values

% Draw random delay for each repetition (sampling with replacement)
reps_delays = randsample(tparams.delays,simReps,true);

% Initialize structures (NaN marks entries that are never computed).
% Observed measures: repetition x time x delay.
% NOTE(review): sfit/yfit (and their shuffled versions) are allocated but not
% filled anywhere in this script.
fit = nan(simReps,tparams.simLen,tparams.delayMax); di = fit; dfi = fit; sfit = fit; yfit = fit;
% Shuffled measures: repetition x time x delay x shuffle, one field per shuffle type.
fitSh.simple = nan(simReps,tparams.simLen,tparams.delayMax,nShuff); diSh.simple = fitSh.simple; dfiSh.simple = fitSh.simple; sfitSh.simple = fitSh.simple; yfitSh.simple = fitSh.simple;
fitSh.cond = nan(simReps,tparams.simLen,tparams.delayMax,nShuff); diSh.cond = fitSh.cond; dfiSh.cond = fitSh.cond; sfitSh.cond = fitSh.cond; yfitSh.cond = fitSh.cond;
%% Run simulation
% For each repetition: draw the stimuli, generate X (signal + noise
% dimensions) and Y (delayed copy of X plus internal noise), then for every
% time t and delay d < t bin the data and compute DI, DFI and FIT together
% with their conditioned ('cond') and simple ('simple') shuffled values.

nTrials = nTrials_per_stim*opts.n_binsS; % Total number of trials (same in every repetition)

for repIdx = 1:simReps
    disp(['Repetition ',num2str(repIdx)])
    
    delay = reps_delays(repIdx); % communication delay of this repetition
    
    S = randi(opts.n_binsS,1,nTrials); % stimulus value of each trial
    Sval = unique(S); % stimulus values present; used by the conditioned shuffle
    
    X_noise = epsX*randn(tparams.simLen,nTrials); % X noise time series
        
    % Notation slightly confusing: eps is the infinitesimal constant in
    % matlab, epsX is the magnitude of noise in node X
    
    % X signal time series (infinitesimal noise to avoid issues with binning zeros afterwards)
    X_sig = eps*epsX*randn(tparams.simLen,nTrials);
    stimBins = tparams.stimWin(1):tparams.stimWin(2);
    X_sig(stimBins,:) = repelem(S,numel(stimBins),1); % stimulus value inside the encoding window
    % Additive noise. randn(1,nTrials) is broadcast over time, so each trial
    % receives one constant offset. NOTE(review): confirm a per-trial (not
    % per-time-point) noise term is intended.
    X_sig = X_sig+epsX*ratio_sig_noise*randn(1,nTrials);
    
    % Time lagged single-trial inputs from the 2 dimensions of X (the first
    % `delay` samples have no past X and are filled with noise)
    X2Ysig = [eps*epsX*randn(delay,nTrials); w_xy_sig*X_sig(1:end-delay,:)];
    X2Ynoise = [w_xy_noise*epsX*randn(delay,nTrials); w_xy_noise*X_noise(1:end-delay,:)];
    
    Y = X2Ysig + X2Ynoise + epsY*randn(tparams.simLen,nTrials); % Y is the sum of X signal and noise dimension plus an internal noise
             
    % loop over time and delays; only delays with a valid past sample (t-d >= 1)
    for t = 1:tparams.simLen
        for d = 1:min(tparams.delayMax,t-1)
            
            % Discretize each variable at the relevant time point
            edgs = eqpop(X_noise(t-d,:), opts.n_binsX);
            [~,bX_noise] = histc(X_noise(t-d,:), edgs);
            edgs = eqpop(X_sig(t-d,:), opts.n_binsX);
            [~,bX_sig] = histc(X_sig(t-d,:), edgs);

            edgs = eqpop(Y(t,:), opts.n_binsY);
            [~,bYt] = histc(Y(t,:), edgs);
            edgs = eqpop(Y(t-d,:), opts.n_binsY);
            [~,bYpast] = histc(Y(t-d,:), edgs);

            % Combine the two binned X dimensions into a single symbol
            % (n_binsX^2 outcomes when both bin indices lie in 1..n_binsX)
            bX = (bX_sig - 1) .* opts.n_binsX + bX_noise;

            % compute TE and FIT
            [di(repIdx,t,d),dfi(repIdx,t,d),fit(repIdx,t,d)]=...
                compute_FIT_TE(S, bX, bYt, bYpast);

            % loop over shufflings
            for shIdx = 1:nShuff
                
                % Conditioned shuffle: permute X only among trials with the
                % same stimulus, destroying within-trial X-Y correlations at fixed S
                XShCond = zeros(size(bX));
                for Ss = 1:numel(Sval)
                    idx = (S == Sval(Ss));
                    tmpX = bX(idx);
                    XShCond(idx) = tmpX(randperm(sum(idx)));
                end

                [diSh.cond(repIdx,t,d,shIdx),dfiSh.cond(repIdx,t,d,shIdx),fitSh.cond(repIdx,t,d,shIdx)]=...
                    compute_FIT_TE(S, XShCond, bYt, bYpast);

                % Simple shuffle: one permutation across all trials. FIT/DFI
                % use the permuted stimulus labels; DI uses the permuted X.
                permIdx = randperm(nTrials);
                SShSimple = S(permIdx);
                XShSimple = bX(permIdx);

                [~,dfiSh.simple(repIdx,t,d,shIdx),fitSh.simple(repIdx,t,d,shIdx)]=...
                    compute_FIT_TE(SShSimple, bX, bYt, bYpast);
                diSh.simple(repIdx,t,d,shIdx) = ...
                    DI_infToolBox(XShSimple, bYt, bYpast, 'naive', 0);
            
            end
        end
    end
end

% Save the whole workspace (settings and results) to <results_path>/timeSim.
% fullfile builds a platform-independent path; char() turns an empty []
% results_path into '' so the folder is then relative to the current directory.
resultsDir = fullfile(char(results_path),'timeSim');
if ~isfolder(resultsDir)
    mkdir(resultsDir);
end
fname = 'FigS3B.mat';
save(fullfile(resultsDir,fname))
%% Plot
% Plot/significance settings, then bootstrap each shuffled measure into a
% null distribution. The 'max' null is the element-wise maximum of the
% 'simple' and 'cond' nulls.
rng(0)                  % reseed so the bootstrapped nulls are reproducible
save_plot = 0;          % set to 1 to export the figure
path_save = [];         % root folder for exported figures

shuff_types = {'cond','simple'};  % shuffle types bootstrapped below (order matters for the RNG)
n_boot = 500;           % number of sample of the null distribution
null2plot = 'max';      % null hypothesis (maximum between the two)
prctile_plt = 99;       % percentile of the null used as threshold

RSNLab = strrep(num2str(ratio_sig_noise),'.','');
saveLab = sprintf('%sReps_%s%sShuff_%sRSN', num2str(simReps), num2str(nShuff), null2plot, RSNLab);

for k = 1:numel(shuff_types)
    shName = shuff_types{k};
    pooledFITsh.(shName) = btstrp_shuff(fitSh.(shName), n_boot);
    pooledTEsh.(shName)  = btstrp_shuff(diSh.(shName), n_boot);
    pooledDFIsh.(shName) = btstrp_shuff(dfiSh.(shName), n_boot);
end

% Element-wise maximum across the two shuffle types
pooledFITsh.max = max(cat(4, pooledFITsh.simple, pooledFITsh.cond), [], 4);
pooledTEsh.max  = max(cat(4, pooledTEsh.simple,  pooledTEsh.cond),  [], 4);
pooledDFIsh.max = max(cat(4, pooledDFIsh.simple, pooledDFIsh.cond), [], 4);
    
%% Plot trends
% One subplot per measure (FIT, DI, DFI). Each subplot shows:
%   - the mean across repetitions, averaged over the simulated delay range,
%     with SEM shading;
%   - the prctile_plt-th percentile of the selected null (red dashed line);
%   - the stimulus-reception window (grey patch);
%   - markers where the mean exceeds the threshold.
minPlotTime = tparams.delayMax; % so that we don't average over non computed nan values
tplot = (minPlotTime:tparams.simLen)-minPlotTime;
t_rec_stim = (tparams.stimWin(1)+tparams.delays(1):tparams.stimWin(2)+tparams.delays(end))-minPlotTime;
delayRange = tparams.delays(1):tparams.delays(end); % delays used in the simulation

% Measures to plot, their null distributions and subplot titles
measures  = {fit, di, dfi};
nullDists = {pooledFITsh.(null2plot), pooledTEsh.(null2plot), pooledDFIsh.(null2plot)};
titles    = {'FIT', 'DI', 'DFI'};

fig=figure();
for m = 1:numel(measures)
    subplot(numel(measures),1,m)
    hold on
    vals = measures{m}(:,minPlotTime:end,delayRange);
    meanVal = squeeze(mean(mean(vals,1),3));                       % mean over reps, then delays
    semVal = squeeze(mean(std(vals,[],1),3))/sqrt(simReps);        % SEM over reps, averaged over delays
    % Threshold: percentile over the 3rd (bootstrap) dimension of the null,
    % averaged over the same delays. Assumes btstrp_shuff returns
    % time x delay x bootstrap -- TODO confirm.
    nullTopPrc = prctile(nullDists{m},prctile_plt,3);
    meanNull_topPrc = mean(nullTopPrc(minPlotTime:end,delayRange),2);
    sigmask = (meanVal>meanNull_topPrc');

    plot(tplot,meanVal)
    shadedErrorBar(tplot,meanVal,semVal)
    plot(tplot,meanNull_topPrc,'r--')
    xlabel('time')
    ylabel('bits')
    title(titles{m})
    set(gca,'Ydir','normal')
    tmpY = ylim;
    % Shade the stimulus-reception window over the full y-range. The baseline
    % is the lower y-limit so negative values (e.g. of DFI) are covered too.
    ha = area([t_rec_stim(1) t_rec_stim(end)], [tmpY(2) tmpY(2)], tmpY(1), 'FaceColor', [0.7,0.7,0.7]);
    ha.FaceAlpha = 0.3;
    scatter(tplot(sigmask),meanVal(sigmask),12,'filled')
end

sgtitle('FIT, TE and DFI time profiles')

if save_plot
    % Platform-independent output folder; created if missing
    saveDir = fullfile(char(path_save),'suppFig');
    if ~isfolder(saveDir)
        mkdir(saveDir);
    end
    fname = fullfile(saveDir,['timeProfiles_',date]);
    saveas(fig,fname)
    saveas(fig,[fname,'.svg'])
    saveas(fig,[fname,'.png'])
end