%% Delta TE simulation (Fig. S14)

clear all; %close all;

rng(1) % for reproducibility

results_path =  [];

% Simulation parameters
nTrials_per_stim = 500; % nominal trials per stimulus value (total trials = nTrials_per_stim*opts.n_binsS;
                        % S is drawn uniformly at random, so per-stimulus counts vary)
simReps = 50; % repetitions of the simulation
nShuff = 2; % shuffles per condition; the shuffled values are stored but not used by the
            % significance test in this script (two-tailed t-test vs zero mean)

w_xy_sig = 0:0.1:1;   % grid of X_signal -> Y coupling weights
w_xy_noise = 0:0.1:1; % grid of X_noise -> Y coupling weights
epsY = 2; epsX = 2; % standard deviation of gaussian noise in X_noise and Y
ratio_sig_noise = 0.2; % standard deviation of gaussian noise in X_signal (expressed as fraction of epsX)

% Time parameters
tparams.simLen = 60; % simulation time, in units of 10ms
tparams.stimWin = [30 35]; % X stimulus encoding window, in units of 10ms
tparams.delays = 4:6; % candidate communication delays, in units of 10ms
tparams.delayMax = 10; % maximum computed delay, in units of 10ms

% Define information options
% NOTE(review): only the n_bins* fields are read by the code below; the other
% fields are not passed to any call in this script -- kept for reference.
opts = [];
opts.verbose = false;
opts.method = "dr";
opts.bias = 'naive';
opts.btsp = 0;
opts.n_binsX = 3; % X has 2 dimensions and each will be discretized in opts.n_binsX bins --> X will have opts.n_binsX^2 possible outcomes
opts.n_binsY = 3; 
opts.n_binsS = 2; % Number of stimulus values (2 to define delta TE)

% Draw random delay for each repetition (sampling with replacement)
reps_delays = randsample(tparams.delays,simReps,true);

% Preallocate outputs, indexed (repetition, signal weight, noise weight[, shuffle]).
% fits/fity are allocated but not filled by the code in this script.
fit = nan(simReps,numel(w_xy_sig),numel(w_xy_noise)); di = fit; dfi = fit; fits = fit; fity = fit;
fitSh.simple = nan(simReps,numel(w_xy_sig),numel(w_xy_noise),nShuff); diSh.simple = fitSh.simple; dfiSh.simple = fitSh.simple;
fitSh.cond = nan(simReps,numel(w_xy_sig),numel(w_xy_noise),nShuff); diSh.cond = fitSh.cond; dfiSh.cond = fitSh.cond;
% DI computed within each stimulus value: (stimulus, repetition, signal weight, noise weight)
di_fixedS = nan(opts.n_binsS,simReps,numel(w_xy_sig),numel(w_xy_noise));

% Main simulation loop. For each repetition (with its own random delay d) and
% each (signal weight, noise weight) pair it does the following:
%   - simulate X (signal + noise dimensions) and Y;
%   - discretize X at the stimulus onset and Y at onset+d and at onset;
%   - compute DI/DFI/FIT with compute_FIT_TE and DI within each stimulus value;
%   - compute nShuff conditioned and simple shuffled nulls.
nTrials = nTrials_per_stim*opts.n_binsS; % loop-invariant total number of trials
winIdx = tparams.stimWin(1):tparams.stimWin(2); % time bins where X_signal encodes S

for repIdx = 1:simReps
    disp(['Repetition number ',num2str(repIdx)]);
    d = reps_delays(repIdx); % X -> Y communication delay for this repetition
    % First time point at which Y receives stim info from X: first emitting time point + delay
    t = tparams.stimWin(1)+d;
    for sigIdx = 1:numel(w_xy_sig)
        for noiseIdx = 1:numel(w_xy_noise)

            % Stimulus label of each trial, drawn uniformly from 1..opts.n_binsS
            S = randi(opts.n_binsS,1,nTrials);

            X_noise = epsX*randn(tparams.simLen,nTrials); % X noise time series

            % Notation slightly confusing: eps is the infinitesimal constant in
            % matlab, epsX is the magnitude of noise in node X
            X_sig = eps*epsX*randn(tparams.simLen,nTrials); % X signal time series (infinitesimal noise to avoid issues with binning zeros afterwards)
            X_sig(winIdx,:) = repmat(S,numel(winIdx),1);
            % Additive noise: one draw per trial (1 x nTrials), added to every time bin of that trial
            X_sig = X_sig+epsX*ratio_sig_noise*randn(1,nTrials);

            % Time lagged single-trial inputs from the 2 dimensions of X
            % (the first d time bins are padded with fresh noise)
            X2Ysig = [eps*epsX*randn(d,nTrials); w_xy_sig(sigIdx)*X_sig(1:end-d,:)];
            X2Ynoise = [w_xy_noise(noiseIdx)*epsX*randn(d,nTrials); w_xy_noise(noiseIdx)*X_noise(1:end-d,:)];

            Y = X2Ysig + X2Ynoise + epsY*randn(tparams.simLen,nTrials); % Y is the sum of X signal and noise dimension plus an internal noise

            % Discretize neural activity: X at emission time t-d, Y at reception time t and at t-d
            edgs = eqpop(X_noise(t-d,:), opts.n_binsX);
            [~,bX_noise] = histc(X_noise(t-d,:), edgs);
            edgs = eqpop(X_sig(t-d,:), opts.n_binsX);
            [~,bX_sig] = histc(X_sig(t-d,:), edgs);

            edgs = eqpop(Y(t,:), opts.n_binsY);
            [~,bYt] = histc(Y(t,:), edgs);
            edgs = eqpop(Y(t-d,:), opts.n_binsY);
            [~,bYpast] = histc(Y(t-d,:), edgs);

            % Joint 2D X response mapped to a single symbol
            % (1..opts.n_binsX^2, assuming histc returns bins in 1..opts.n_binsX)
            bX = (bX_sig - 1) .* opts.n_binsX + bX_noise;

            [di(repIdx,sigIdx,noiseIdx),dfi(repIdx,sigIdx,noiseIdx),fit(repIdx,sigIdx,noiseIdx)]=...
                compute_FIT_TE(S, bX, bYt, bYpast);

            % DI within each stimulus value, stored at the stimulus value's own index
            sval = unique(S);
            for sS = 1:numel(sval)
                sel_idxs = (S == sval(sS));
                di_fixedS(sval(sS),repIdx,sigIdx,noiseIdx)=DI_infToolBox(bX(sel_idxs), bYt(sel_idxs), bYpast(sel_idxs), 'naive', 0);
            end

            for shIdx = 1:nShuff
                % Conditioned null: permute X within each stimulus class
                XSh = zeros(size(bX));
                for sS = 1:numel(sval)
                    idx = (S == sval(sS));
                    tmpX = bX(idx);
                    XSh(idx) = tmpX(randperm(sum(idx)));
                end

                [diSh.cond(repIdx,sigIdx,noiseIdx,shIdx),dfiSh.cond(repIdx,sigIdx,noiseIdx,shIdx),fitSh.cond(repIdx,sigIdx,noiseIdx,shIdx)]=...
                    compute_FIT_TE(S, XSh, bYt, bYpast);

                % Simple null: one trial permutation, applied to S (for DFI/FIT) and to X (for DI)
                idx = randperm(nTrials);
                Ssh = S(idx);
                XSh = bX(idx);

                [~,dfiSh.simple(repIdx,sigIdx,noiseIdx,shIdx),fitSh.simple(repIdx,sigIdx,noiseIdx,shIdx)]=...
                    compute_FIT_TE(Ssh, bX, bYt, bYpast);
                [diSh.simple(repIdx,sigIdx,noiseIdx,shIdx)]=...
                    DI_infToolBox(XSh, bYt, bYpast, 'naive', 0);
            end
        end
    end
end

% Signal/noise ratio label (e.g. 0.2 -> '02'). It is not used by the code
% shown here, but it is kept so that it is stored in the saved workspace.
RSNLab = num2str(ratio_sig_noise);
RSNLab = replace(RSNLab,'.','');

% Save the whole workspace under <results_path>/deltaTE/.
% fullfile uses the platform file separator. An empty results_path gives a
% path relative to the current folder instead of the drive root.
fname = 'NIPS_FigS12.mat'; % NOTE(review): script header says Fig. S14 -- confirm intended name
out_dir = 'deltaTE';
if ~isempty(results_path)
    out_dir = fullfile(results_path, out_dir);
end
if ~exist(out_dir, 'dir')
    mkdir(out_dir);
end
save(fullfile(out_dir, fname))

%% Plots heatmaps dependency on w_sig and w_noise
path_save = [];
save_plot = 0;
prctile_plt = 99; % percentile used to determine significance
alpha_plt = (100-prctile_plt)/100; % p-value threshold passed to pvalues_plot_threshold

nSig = numel(w_xy_sig); nNoise = numel(w_xy_noise);

% deltaTE(rep, sigIdx, noiseIdx): DI within stimulus 2 minus DI within stimulus 1.
% reshape (not squeeze) keeps all three dimensions even when simReps == 1.
deltaTE = reshape(di_fixedS(2,:,:,:)-di_fixedS(1,:,:,:), simReps, nSig, nNoise);
meanDTE = reshape(mean(deltaTE,1), nSig, nNoise); % rows: signal weight, columns: noise weight

% One-sample t-test of deltaTE against zero mean across repetitions (ttest defaults)
pvalsDTE = nan(nSig,nNoise);
for noiseIdx = 1:nNoise
    [~,pvalsDTE(:,noiseIdx)] = ttest(deltaTE(:,:,noiseIdx));
end

fig=figure('Position',[360,324,391,294]);
hold on
% imagesc(x,y,C): x spans the columns of C (noise weights), y spans its rows (signal weights)
imagesc(w_xy_noise,w_xy_sig,meanDTE)
cmap = my_colormap_rb(meanDTE);
colormap(cmap)
colorbar()
xlabel('Noise')
ylabel('Signal')
title('deltaDI')
% Per-cell significance annotation at (noise, signal) coordinates
% (pvalues_plot_threshold is defined elsewhere; presumably marks p < alpha_plt)
for sigIdx = 1:nSig
    for noiseIdx = 1:nNoise
        pvalues_plot_threshold(pvalsDTE(sigIdx,noiseIdx),w_xy_noise(noiseIdx),w_xy_sig(sigIdx),0,0,12,0,'k',alpha_plt)
    end
end
xlim([w_xy_noise(1)-0.05,w_xy_noise(end)+0.05])
ylim([w_xy_sig(1)-0.05,w_xy_sig(end)+0.05])
set(gca,'Ydir','normal')

if save_plot
    fig_dir = 'suppFig';
    if ~isempty(path_save)
        fig_dir = fullfile(path_save, fig_dir);
    end
    fbase = fullfile(fig_dir, ['deltaDI_map_',date]);
    saveas(fig,fbase)            % no extension: saved as a MATLAB .fig file
    saveas(fig,[fbase,'.svg'])
    saveas(fig,[fbase,'.png'])
end