%% FIT and TE dependence of stim and noise transmission (Fig.2A)

clear all; %close all;

rng(1) % For reproducibility
results_path =  []; % root folder for saved results ([] = not set)

% Simulation parameters
nTrials_per_stim = 500; % average number of trials per stimulus value (S is drawn uniformly at random)
simReps = 50; % repetitions of the simulation
nShuff = 10; % number of permutations (used for both the conditioned and the simple permutation tests)

w_xy_sig = 0:0.1:1; % range of w_signal parameter (gain of the X signal dimension onto Y)
w_xy_noise = 0:0.1:1; % range of w_noise parameter (gain of the X noise dimension onto Y)
epsY = 2; epsX = 2; % standard deviation of gaussian noise in X_noise and Y
ratio_sig_noise = 0.2; % standard deviation of the multiplicative gaussian noise in X_signal (expressed as fraction of epsX)

% Global params
tparams.simLen = 60; % simulation time, in units of 10ms
tparams.stimWin = [30 35]; % X stimulus encoding window, in units of 10ms
tparams.delays = 4:6; % candidate communication delays, in units of 10ms
tparams.delayMax = 10; % maximum computed delay, in units of 10ms

% Define information options
% NOTE(review): in the visible code these fields (method, bias, btsp) are not
% passed to compute_FIT_TE / DI_infToolBox; only the n_bins* fields are read.
opts = [];
opts.verbose = false;
opts.method = "dr";
opts.bias = 'naive';
opts.btsp = 0;
opts.n_binsX = 3; % X has 2 dimensions and each will be discretized in opts.n_binsX bins --> X will have opts.n_binsX^2 possible outcomes
opts.n_binsY = 3; % number of bins used to discretize Y
opts.n_binsS = 4; % Number of stimulus values

% Draw random delay for each repetition (sampling with replacement)
reps_delays = randsample(tparams.delays,simReps,true);

% Initialize result arrays with NaN so that unfilled entries remain detectable.
% Layout: [repetition x signal weight x noise weight (x shuffle)]
res_size = [simReps,numel(w_xy_sig),numel(w_xy_noise)];
fit = nan(res_size); di = fit; dfi = fit;
fitSh.simple = nan([res_size,nShuff]); diSh.simple = fitSh.simple; dfiSh.simple = fitSh.simple;
fitSh.cond = nan([res_size,nShuff]); diSh.cond = fitSh.cond; dfiSh.cond = fitSh.cond;

%% Run simulation

nTrials = nTrials_per_stim*opts.n_binsS; % total number of trials per simulated condition

for repIdx = 1:simReps
    disp(['Repetition number ',num2str(repIdx)]);

    % Communication delay X -> Y for this repetition (in units of 10ms)
    d = reps_delays(repIdx);
    % First time point at which Y receives stim info from X: start of the
    % X encoding window (tparams.stimWin(1), in units of 10ms) + delay
    t = tparams.stimWin(1)+d;

    for sigIdx = 1:numel(w_xy_sig)
        for noiseIdx = 1:numel(w_xy_noise)

            % Draw the stimulus value for each trial
            S = randi(opts.n_binsS,1,nTrials);

            % simulate neural activity (X noise)
            X_noise = epsX*randn(tparams.simLen,nTrials); % X noise time series

            % Notation slightly confusing: eps is the infinitesimal constant in
            % matlab, epsX is the magnitude of noise in node X

            % simulate X signal: near-zero baseline, equal to S inside the
            % encoding window, then scaled by trial-wise multiplicative noise
            X_sig = eps*epsX*randn(tparams.simLen,nTrials); % infinitesimal noise to avoid issues with binning zeros afterwards
            stimWinIdx = tparams.stimWin(1):tparams.stimWin(2);
            X_sig(stimWinIdx,:) = repmat(S,numel(stimWinIdx),1);
            X_sig = X_sig.*(1+epsX*ratio_sig_noise*randn(1,nTrials)); % Multiplicative noise

            % Time lagged single-trial inputs from the 2 dimensions of X to Y;
            % the first d samples (before any delayed X sample arrives) are noise only
            X2Ysig = [eps*epsX*randn(d,nTrials); w_xy_sig(sigIdx)*X_sig(1:end-d,:)];
            X2Ynoise = [w_xy_noise(noiseIdx)*epsX*randn(d,nTrials); w_xy_noise(noiseIdx)*X_noise(1:end-d,:)];

            % Compute Y
            Y = X2Ysig + X2Ynoise + epsY*randn(tparams.simLen,nTrials); % Y is the sum of X signal and noise dimension plus an internal noise

            % Discretize neural activity: both X dimensions at the emitting
            % time t-d, Y at the receiving time t and at its own past t-d
            edgs = eqpop(X_noise(t-d,:), opts.n_binsX);
            [~,bX_noise] = histc(X_noise(t-d,:), edgs);
            edgs = eqpop(X_sig(t-d,:), opts.n_binsX);
            [~,bX_sig] = histc(X_sig(t-d,:), edgs);

            edgs = eqpop(Y(t,:), opts.n_binsY);
            [~,bYt] = histc(Y(t,:), edgs);
            edgs = eqpop(Y(t-d,:), opts.n_binsY);
            [~,bYpast] = histc(Y(t-d,:), edgs);

            % Joint X symbol from the two discretized dimensions
            % (values 1..n_binsX^2, assuming both bin indices lie in 1..n_binsX)
            bX = (bX_sig - 1) .* opts.n_binsX + bX_noise;

            [di(repIdx,sigIdx,noiseIdx),dfi(repIdx,sigIdx,noiseIdx),fit(repIdx,sigIdx,noiseIdx)]=...
                compute_FIT_TE(S, bX, bYt, bYpast);

            Sval = unique(S); % stimulus values present, for the conditioned shuffle
            XSh = zeros(size(bX)); % preallocated buffer for the shuffled X symbols

            for shIdx = 1:nShuff

                % conditioned shuff (shuffle X across trials sharing the same S)
                for Ss = 1:numel(Sval)
                    idx = (S == Sval(Ss));
                    tmpX = bX(idx);
                    ridx = randperm(sum(idx));
                    XSh(1,idx) = tmpX(ridx);
                end

                [diSh.cond(repIdx,sigIdx,noiseIdx,shIdx),dfiSh.cond(repIdx,sigIdx,noiseIdx,shIdx),fitSh.cond(repIdx,sigIdx,noiseIdx,shIdx)]=...
                    compute_FIT_TE(S, XSh, bYt, bYpast);

                % simple shuff: one random permutation of all trials, applied
                % to S for FIT/DFI (S-shuffle) and to X for TE (X-shuffle)
                idx = randperm(nTrials);
                Ssh = S(idx);
                XSh = bX(idx);

                [~,dfiSh.simple(repIdx,sigIdx,noiseIdx,shIdx),fitSh.simple(repIdx,sigIdx,noiseIdx,shIdx)]=...
                    compute_FIT_TE(Ssh, bX, bYt, bYpast);
                [diSh.simple(repIdx,sigIdx,noiseIdx,shIdx)]=...
                    DI_infToolBox(XSh, bYt, bYpast, 'naive', 0);

            end
        end
    end
end

% Save the whole workspace (parameters + results) to <results_path>/Fig2/.
% fullfile builds a platform-correct path; char() turns an unset ([])
% results_path into '' so the folder is then relative to the current directory.
fname = 'NIPS_Fig2A.mat';
save_dir = fullfile(char(results_path),'Fig2');
if ~exist(save_dir,'dir')
    mkdir(save_dir);
end
save(fullfile(save_dir,fname))

%%
% Statistical testing and plotting parameters
n_boot = 500;                      % number of samples drawn for each null distribution
prctile_plt = 99;                  % percentile of the null used as significance threshold
null2plot = 'max';                 % null hypothesis used for FIT/TE (maximum between the two shuffles)
shuff_types = {'cond','simple'};   % shuffle types whose nulls are pooled
save_plot = 0;                     % set to 1 to write the figures to disk
path_save = [];                    % output folder for figures ([] = not set)

%% Plot panel B (heatmaps)
rng(1) % re-seed: btstrp_shuff draws random bootstrap samples across the shufflings

% Null distributions of the mean over repetitions, one per shuffle type
% (same call order as before: for each type FIT, then TE, then DFI)
for k = 1:numel(shuff_types)
    lab = shuff_types{k};
    pooledFITsh.(lab) = btstrp_shuff(fitSh.(lab),n_boot);
    pooledTEsh.(lab)  = btstrp_shuff(diSh.(lab),n_boot);
    pooledDFIsh.(lab) = btstrp_shuff(dfiSh.(lab),n_boot);
end

% Combine the two shuffle nulls sample-by-sample by stacking them along a 4th dim
stack_nulls = @(s) cat(4,s.simple,s.cond);
pooledFITsh.max = max(stack_nulls(pooledFITsh),[],4);
pooledTEsh.max  = max(stack_nulls(pooledTEsh),[],4);
pooledDFIsh.max = max(stack_nulls(pooledDFIsh),[],4);
% Minimum as well, used only for DFI, which is tested two-tailed
pooledDFIsh.min = min(stack_nulls(pooledDFIsh),[],4);

% Compute pvals
% FIT and TE: fraction of null samples (3rd dim of the pooled null) that are
% >= the measured mean over repetitions, plus 1/(number of null samples) so
% that no p-value is exactly zero. Each measure is normalised by its OWN null.
nullFIT = pooledFITsh.(null2plot);
nullTE = pooledTEsh.(null2plot);
pvalsFIT = mean(squeeze(mean(fit,1)) <= nullFIT,3)+1/size(nullFIT,3);
pvalsTE = mean(squeeze(mean(di,1)) <= nullTE,3)+1/size(nullTE,3);
% For DFI we compute a two-tailed null hypothesis (DFI is significant if it
% is either higher than the max null hypothesis or lower than the min null
% hypothesis): p is the fraction of null samples whose [min,max] range
% contains the measured value.
% NOTE(review): unlike FIT/TE no +1/n term is added here, so pvalsDFI can be
% exactly 0 - confirm this asymmetry is intended.
meanDFI = squeeze(mean(dfi,1));
pvalsDFI = mean(((meanDFI <= pooledDFIsh.max) & (meanDFI >= pooledDFIsh.min)),3);

% Plot matrices
% One heatmap per measure. Each map is squeeze(mean(X,1)) of a
% [rep x signal weight x noise weight] array, i.e. rows = signal weights and
% columns = noise weights. imagesc(x,y,C) maps x to the columns and y to the
% rows of C, so noise goes on the x axis and signal on the y axis.
maps      = {squeeze(mean(fit,1)), squeeze(mean(di,1)), squeeze(mean(dfi,1))};
map_pvals = {pvalsFIT, pvalsTE, pvalsDFI};
map_names = {'FIT', 'TE', 'DFI'};
alpha_plt = (100-prctile_plt)/100; % significance level implied by prctile_plt

fig=figure('Position',[1,237,1270,312]);
for k = 1:numel(maps)
    ax(k)=subplot(1,3,k);
    hold on
    imagesc(w_xy_noise,w_xy_sig,maps{k})
    xlabel('Noise')
    ylabel('Signal')
    title(map_names{k})
    for i = 1:numel(w_xy_sig)
        for j = 1:numel(w_xy_noise)
            % presumably marks cell (noise j, signal i) when p < alpha_plt - confirm in pvalues_plot_threshold
            pvalues_plot_threshold(map_pvals{k}(i,j),w_xy_noise(j),w_xy_sig(i),0,0,12,0,'k',alpha_plt)
        end
    end
    cmap = my_colormap_rb(maps{k});
    colormap(ax(k),cmap)
    colorbar()
    xlim([-0.05,w_xy_noise(end)+0.05])
    ylim([-0.05,w_xy_sig(end)+0.05])
    set(gca,'Ydir','normal')
end

sgtitle('FIT, TE and DFI dependence on signal and noise')

if save_plot
    % Figures go to <path_save>/Fig2/, stamped with today's date (computed once
    % so the .fig/.svg/.png files always share the same stamp). fullfile builds
    % a platform-correct path; char() turns an unset ([]) path_save into ''.
    fig_dir = fullfile(char(path_save),'Fig2');
    if ~exist(fig_dir,'dir')
        mkdir(fig_dir);
    end
    fname = fullfile(fig_dir,['sigNoise_maps_',date]);
    saveas(fig,fname)              % no extension: saveas writes a .fig file
    saveas(fig,[fname,'_','.svg'])
    saveas(fig,[fname,'_','.png'])
end

%% Null distributions for fixed value
% Plot Fig. S4E
% Grid point shown: these are INDICES into w_xy_sig / w_xy_noise, not weight values
w_sig = 1;   % i.e. signal weight w_xy_sig(w_sig)
w_noise = 7; % i.e. noise weight w_xy_noise(w_noise)
prc_lbl = [num2str(prctile_plt),'prc']; % legend label follows the significance percentile
figure()
hold on
histogram(squeeze(pooledFITsh.simple(w_sig,w_noise,:)),18,'FaceColor',[0.7,0.7,0.7])
histogram(squeeze(pooledFITsh.cond(w_sig,w_noise,:)),4,'FaceColor',[0.4, 0.4, 1])
h(1)=xline(prctile(squeeze(pooledFITsh.cond(w_sig,w_noise,:)),prctile_plt),'b','linewidth',2);
h(2)=xline(prctile(squeeze(pooledFITsh.simple(w_sig,w_noise,:)),prctile_plt),'color',[0.4,0.4,0.4],'linewidth',2);
h(3)=xline(mean(fit(:,w_sig,w_noise)),'r','linewidth',2);
lgd=legend([h(1),h(2),h(3)],['Within-trial ',prc_lbl],['S-shuff ',prc_lbl],'Measured FIT');
lgd.FontSize = 12;
title('Null distributions for one point (Fig2A) where fixed-S null. gives a FP')
xlabel('FIT value [bits]')
ylabel('Null hypothesis samples')