%% FIT and TE dependence of stim and noise transmission (Fig.S3A, additive noise)
% Simulates a sender X with a stimulus-encoding "signal" dimension and a
% stimulus-independent "noise" dimension, both projected onto a receiver Y
% (weights w_xy_sig / w_xy_noise) after a per-repetition random delay, and
% measures how FIT, TE (di) and DFI depend on the two weights.

clear all; %close all;

rng(1) % For reproducibility
results_path =  [];

% Simulation parameters
nTrials_per_stim = 500;
simReps = 50; % repetitions of the simulation
nShuff = 10; % number of permutations

w_xy_sig = 0:0.1:1; % range of w_sig parameter
w_xy_noise = 0:0.1:1; % range of w_noise parameter
epsY = 2; epsX = 2; % standard deviation of gaussian noise in X_noise and Y
ratio_sig_noise = 0.2; % standard deviation of gaussian noise in X_signal (expressed as fraction of epsX)

% Global params
tparams.simLen = 60; % simulation time, in units of 10ms
tparams.stimWin = [30 35]; % X stimulus encoding window, in units of 10ms
tparams.delays = [4:6]; % communication delays, in units of 10ms
tparams.delayMax = 10; % maximum computed delay, in units of 10ms

% Define information options
opts = [];
opts.verbose = false;
opts.method = "dr";
opts.bias = 'naive';
opts.btsp = 0;
opts.n_binsX = 3; % X has 2 dimensions and each will be discretized in opts.n_binsX bins --> X will have opts.n_binsX^2 possible outcomes
opts.n_binsY = 3; 
opts.n_binsS = 4; % Number of stimulus values
% Legacy option names, kept for reference:
% Opt.nt = nt;
% Opt.method = 'dr';
% Opt.bias   = bias;
% Opt.trperm = 0;

% Draw random delay for each repetition (sampled with replacement from tparams.delays)
reps_delays = randsample(tparams.delays,simReps,true);

% Initialize structures: observed values are [simReps x nSig x nNoise];
% shuffled surrogates carry an extra nShuff dimension.
fit = nan(simReps,numel(w_xy_sig),numel(w_xy_noise)); di = fit; dfi = fit;
fits = fit; fity = fit; % not written anywhere below in this script
fitSh.simple = nan(simReps,numel(w_xy_sig),numel(w_xy_noise),nShuff); diSh.simple = fitSh.simple; dfiSh.simple = fitSh.simple;
fitSh.cond = nan(simReps,numel(w_xy_sig),numel(w_xy_noise),nShuff); diSh.cond = fitSh.cond; dfiSh.cond = fitSh.cond;

%% Run simulation
% For each repetition (with its own random X->Y delay) and each pair of
% signal/noise weights: simulate S, X (signal and noise dimensions) and Y,
% discretize them, and compute TE (di), DFI and FIT on the real data and on
% nShuff shuffled surrogates (conditioned and simple shuffles).

nTrials = nTrials_per_stim*opts.n_binsS; % number of trials (identical for every condition)

for repIdx = 1:simReps
    disp(['Repetition number ',num2str(repIdx)]);

    % Communication delay of this repetition, and the first time bin at which
    % Y can receive stimulus information from X: the bin of stimulus onset
    % (tparams.stimWin(1)) plus the delay. X is therefore read at bin t-d.
    d = reps_delays(repIdx);
    t = tparams.stimWin(1)+d;

    for sigIdx = 1:numel(w_xy_sig)
        for noiseIdx = 1:numel(w_xy_noise)

            % Draw the stimulus value for each trial
            S = randi(opts.n_binsS,1,nTrials);

            % simulate neural activity (X noise dimension time series)
            X_noise = epsX*randn(tparams.simLen,nTrials);

            % Notation slightly confusing: eps is MATLAB's tiny floating-point
            % constant, epsX is the magnitude of noise in node X

            % simulate X signal: infinitesimal noise everywhere (to avoid issues
            % with binning zeros afterwards), the stimulus value inside the
            % encoding window
            X_sig = eps*epsX*randn(tparams.simLen,nTrials);
            X_sig(tparams.stimWin(1):tparams.stimWin(2),:) = repmat(S,numel(tparams.stimWin(1):tparams.stimWin(2)),1);
            % Additive noise: one draw per trial (1 x nTrials), added to every time bin
            X_sig = X_sig+epsX*ratio_sig_noise*randn(1,nTrials);

            % Time lagged single-trial inputs from the 2 dimensions of X to Y
            % (the first d bins, which have no lagged X input, are padded with noise)
            X2Ysig = [eps*epsX*randn(d,nTrials); w_xy_sig(sigIdx)*X_sig(1:end-d,:)];
            X2Ynoise = [w_xy_noise(noiseIdx)*epsX*randn(d,nTrials); w_xy_noise(noiseIdx)*X_noise(1:end-d,:)];

            % Y is the sum of X signal and noise dimension inputs plus an internal noise
            Y = X2Ysig + X2Ynoise + epsY*randn(tparams.simLen,nTrials);

            % Discretize neural activity (eqpop: project helper returning bin
            % edges, presumably equipopulated; histc maps values to bin indices)
            edgs = eqpop(X_noise(t-d,:), opts.n_binsX);
            [~,bX_noise] = histc(X_noise(t-d,:), edgs);
            edgs = eqpop(X_sig(t-d,:), opts.n_binsX);
            [~,bX_sig] = histc(X_sig(t-d,:), edgs);

            edgs = eqpop(Y(t,:), opts.n_binsY);
            [~,bYt] = histc(Y(t,:), edgs);
            edgs = eqpop(Y(t-d,:), opts.n_binsY);
            [~,bYpast] = histc(Y(t-d,:), edgs);

            % Joint X response: one index per (signal bin, noise bin) pair
            % (1..opts.n_binsX^2, assuming each bin index lies in 1..opts.n_binsX)
            bX = (bX_sig - 1) .* opts.n_binsX + bX_noise;

            [di(repIdx,sigIdx,noiseIdx),dfi(repIdx,sigIdx,noiseIdx),fit(repIdx,sigIdx,noiseIdx)]=...
                compute_FIT_TE(S, bX, bYt, bYpast);

            Sval = unique(S); % stimulus values (loop-invariant across shuffles)
            XSh = zeros(size(bX)); % preallocated; every trial is overwritten by the conditioned shuffle
            for shIdx = 1:nShuff

                % conditioned shuff: permute X only among trials with the same
                % S (destroys within-trial correlations between X and Y at fixed S)
                for Ss = 1:numel(Sval)
                    idx = (S == Sval(Ss));
                    tmpX = bX(idx);
                    ridx = randperm(sum(idx));
                    XSh(idx) = tmpX(ridx);
                end

                [diSh.cond(repIdx,sigIdx,noiseIdx,shIdx),dfiSh.cond(repIdx,sigIdx,noiseIdx,shIdx),fitSh.cond(repIdx,sigIdx,noiseIdx,shIdx)]=...
                    compute_FIT_TE(S, XSh, bYt, bYpast);

                % simple shuff: one random permutation of all trials. FIT and
                % DFI are computed with the permuted S against the original X;
                % TE is computed with the permuted X.
                idx = randperm(nTrials);
                Ssh = S(idx);
                XSh = bX(idx);

                [~,dfiSh.simple(repIdx,sigIdx,noiseIdx,shIdx),fitSh.simple(repIdx,sigIdx,noiseIdx,shIdx)]=...
                    compute_FIT_TE(Ssh, bX, bYt, bYpast);
                [diSh.simple(repIdx,sigIdx,noiseIdx,shIdx)]=...
                    DI_infToolBox(XSh, bYt, bYpast, 'naive', 0);

            end
        end
    end
end

% Save the whole workspace. fullfile builds a platform-independent path
% (an empty results_path yields a folder relative to the current directory),
% and the target folder is created if it does not exist yet.
fname = 'NIPS_FigS3A.mat';
save_dir = fullfile(char(results_path),'FigS3');
if ~exist(save_dir,'dir')
    mkdir(save_dir);
end
save(fullfile(save_dir,fname))

%%
save_plot = 0;
shuff_types = {'cond','simple'};
n_boot = 500; % number of samples of the bootstrapped null distribution
null2plot = 'max'; % null hypothesis (elementwise maximum of the 'cond' and 'simple' nulls)
path_save = [];

%% Plot panel B (heatmaps)
rng(0) % resetting seed because of random bootstrap across shufflings 

prctile_plt = 99; % significance percentile (p-value threshold = (100-prctile_plt)/100)

% Compute null distributions of the mean FIT, TE and DFI for each shuffle type.
% btstrp_shuff is a project helper; its output is assumed to stack the
% bootstrap samples along dimension 3 (as used below) — confirm.
for shLab = shuff_types
    pooledFITsh.(shLab{1}) = btstrp_shuff(fitSh.(shLab{1}),n_boot);
    pooledTEsh.(shLab{1}) = btstrp_shuff(diSh.(shLab{1}),n_boot);
    pooledDFIsh.(shLab{1}) = btstrp_shuff(dfiSh.(shLab{1}),n_boot);
end
% Combined null: elementwise maximum of the two shuffle nulls
pooledFITsh.max = max(cat(4,pooledFITsh.simple,pooledFITsh.cond),[],4);
pooledTEsh.max = max(cat(4,pooledTEsh.simple,pooledTEsh.cond),[],4);
pooledDFIsh.max = max(cat(4,pooledDFIsh.simple,pooledDFIsh.cond),[],4);

% Compute pvals per (signal, noise) cell: fraction of null samples at or above
% the observed mean, plus 1/N where N is the size of that measure's own null,
% capped at 1 so the result stays a valid probability.
perm_pval = @(obs,nullDist) min(mean(obs <= nullDist,3) + 1/size(nullDist,3), 1);
pvalsFIT = perm_pval(squeeze(mean(fit,1)), pooledFITsh.(null2plot));
pvalsTE = perm_pval(squeeze(mean(di,1)), pooledTEsh.(null2plot));
pvalsDFI = perm_pval(squeeze(mean(dfi,1)), pooledDFIsh.(null2plot));

% Plot matrices: one heatmap per measure (mean over repetitions). Rows of each
% matrix index w_xy_sig and columns index w_xy_noise, so the x axis is the
% noise weight and the y axis the signal weight. Significance markers use the
% p-value threshold (100-prctile_plt)/100.
panel_data = {squeeze(mean(fit,1)), squeeze(mean(di,1)), squeeze(mean(dfi,1))};
panel_pvals = {pvalsFIT, pvalsTE, pvalsDFI};
panel_titles = {'FIT','TE','DFI'};
alpha_plt = (100-prctile_plt)/100;

fig=figure('Position',[1,237,1270,312]);
for k = 1:numel(panel_data)
    ax(k)=subplot(1,3,k);
    hold on
    % imagesc(x,y,C): x spans the columns of C (noise), y spans its rows (signal)
    imagesc(w_xy_noise,w_xy_sig,panel_data{k})
    xlabel('Noise')
    ylabel('Signal')
    title(panel_titles{k})
    for i = 1:numel(w_xy_sig)
        for j = 1:numel(w_xy_noise)
            % marker at (x = noise weight of column j, y = signal weight of row i)
            pvalues_plot_threshold(panel_pvals{k}(i,j),w_xy_noise(j),w_xy_sig(i),0,0,12,0,'k',alpha_plt)
        end
    end
    cmap = my_colormap_rb(panel_data{k});
    colormap(ax(k),cmap)
    colorbar()
    xlim([-0.05,w_xy_noise(end)+0.05])
    ylim([-0.05,w_xy_sig(end)+0.05])
    set(gca,'Ydir','normal')
end

sgtitle('FIT, TE and DFI dependence on signal and noise')

if save_plot
    % saveLab is not assigned anywhere in this script; default it so that
    % enabling save_plot does not fail with an undefined-variable error.
    if ~exist('saveLab','var')
        saveLab = 'additive';
    end
    % Portable path; create the output folder if needed
    fig_dir = fullfile(char(path_save),'FigS3');
    if ~exist(fig_dir,'dir')
        mkdir(fig_dir);
    end
    % Dated name without extension (saveas then writes a .fig), plus .svg and .png copies
    fname = fullfile(fig_dir,['sigNoise_maps_',saveLab,'_',date]);
    saveas(fig,fname)
    fname = fullfile(fig_dir,['sigNoise_maps_',saveLab,'_','.svg']);
    saveas(fig,fname)
    fname = fullfile(fig_dir,['sigNoise_maps_',saveLab,'_','.png']);
    saveas(fig,fname)
end
% Parameters for a second simulation run (same grid and noise levels as the
% first run; the commented lines are alternative grids).
w_xy_sig = 0:0.1:1;
w_xy_noise = 0:0.1:1;
% w_xy_sig = 0:0.01:0.2;
% w_xy_noise = 0:0.01:0.2;
%w_xy_noise = 0.5;
epsY = 2; epsX = 2; % standard deviation of gaussian noise in X_noise and Y
ratio_sig_noise = 0.2; % std of the additive noise in the signal dimension, as a fraction of epsX

% Draw random delay for each repetition. No rng() call here, so these draws
% depend on the RNG state left by the preceding code.
%reps_delays = tparams.delays*ones(simReps,1);
reps_delays = randsample(tparams.delays,simReps,true);

% Initialize structures (same layout as the first run; overwrites its results):
% observed values are [simReps x nSig x nNoise], surrogates add an nShuff dimension.
fit = nan(simReps,numel(w_xy_sig),numel(w_xy_noise)); di = fit; dfi = fit;
fits = fit; fity = fit; % not written anywhere below in this script
fitSh.simple = nan(simReps,numel(w_xy_sig),numel(w_xy_noise),nShuff); diSh.simple = fitSh.simple; dfiSh.simple = fitSh.simple;
fitSh.cond = nan(simReps,numel(w_xy_sig),numel(w_xy_noise),nShuff); diSh.cond = fitSh.cond; dfiSh.cond = fitSh.cond;

%% Run simulation
% Second run: for each repetition (with its own random X->Y delay) and each
% pair of signal/noise weights, simulate S, X and Y, discretize them, and
% compute TE (di), DFI and FIT on the real data and on nShuff surrogates.

nTrials = nTrials_per_stim*opts.n_binsS; % number of trials (identical for every condition)

for repIdx = 1:simReps
    disp(['Repetition number ',num2str(repIdx)]);

    % Communication delay of this repetition, and the first time bin at which
    % Y can receive stimulus information from X: the bin of stimulus onset
    % (tparams.stimWin(1)) plus the delay. X is therefore read at bin t-d.
    d = reps_delays(repIdx);
    t = tparams.stimWin(1)+d;

    for sigIdx = 1:numel(w_xy_sig)
        for noiseIdx = 1:numel(w_xy_noise)

            % Draw the stimulus value for each trial
            S = randi(opts.n_binsS,1,nTrials);

            % simulate neural activity (X noise dimension time series)
            X_noise = epsX*randn(tparams.simLen,nTrials);

            % Notation slightly confusing: eps is MATLAB's tiny floating-point
            % constant, epsX is the magnitude of noise in node X

            % simulate X signal: infinitesimal noise everywhere (to avoid issues
            % with binning zeros afterwards), the stimulus value inside the
            % encoding window
            X_sig = eps*epsX*randn(tparams.simLen,nTrials);
            X_sig(tparams.stimWin(1):tparams.stimWin(2),:) = repmat(S,numel(tparams.stimWin(1):tparams.stimWin(2)),1);
            % Additive noise: one draw per trial (1 x nTrials), added to every time bin
            X_sig = X_sig+epsX*ratio_sig_noise*randn(1,nTrials);

            % Time lagged single-trial inputs from the 2 dimensions of X to Y
            % (the first d bins, which have no lagged X input, are padded with noise)
            X2Ysig = [eps*epsX*randn(d,nTrials); w_xy_sig(sigIdx)*X_sig(1:end-d,:)];
            X2Ynoise = [w_xy_noise(noiseIdx)*epsX*randn(d,nTrials); w_xy_noise(noiseIdx)*X_noise(1:end-d,:)];

            % Y is the sum of X signal and noise dimension inputs plus an internal noise
            Y = X2Ysig + X2Ynoise + epsY*randn(tparams.simLen,nTrials);

            % Discretize neural activity (eqpop: project helper returning bin
            % edges, presumably equipopulated; histc maps values to bin indices)
            edgs = eqpop(X_noise(t-d,:), opts.n_binsX);
            [~,bX_noise] = histc(X_noise(t-d,:), edgs);
            edgs = eqpop(X_sig(t-d,:), opts.n_binsX);
            [~,bX_sig] = histc(X_sig(t-d,:), edgs);

            edgs = eqpop(Y(t,:), opts.n_binsY);
            [~,bYt] = histc(Y(t,:), edgs);
            edgs = eqpop(Y(t-d,:), opts.n_binsY);
            [~,bYpast] = histc(Y(t-d,:), edgs);

            % Joint X response: one index per (signal bin, noise bin) pair
            % (1..opts.n_binsX^2, assuming each bin index lies in 1..opts.n_binsX)
            bX = (bX_sig - 1) .* opts.n_binsX + bX_noise;

            [di(repIdx,sigIdx,noiseIdx),dfi(repIdx,sigIdx,noiseIdx),fit(repIdx,sigIdx,noiseIdx)]=...
                compute_FIT_TE(S, bX, bYt, bYpast, 0);

            Sval = unique(S); % stimulus values (loop-invariant across shuffles)
            XSh = zeros(size(bX)); % preallocated; every trial is overwritten by the conditioned shuffle
            for shIdx = 1:nShuff

                % conditioned shuff: permute X only among trials with the same
                % S (destroys within-trial correlations between X and Y at fixed S)
                for Ss = 1:numel(Sval)
                    idx = (S == Sval(Ss));
                    tmpX = bX(idx);
                    ridx = randperm(sum(idx));
                    XSh(idx) = tmpX(ridx);
                end

                [diSh.cond(repIdx,sigIdx,noiseIdx,shIdx),dfiSh.cond(repIdx,sigIdx,noiseIdx,shIdx),fitSh.cond(repIdx,sigIdx,noiseIdx,shIdx)]=...
                    compute_FIT_TE(S, XSh, bYt, bYpast, 0);

                % simple shuff: one random permutation of all trials. FIT and
                % DFI are computed with the permuted S against the original X;
                % TE is computed with the permuted X.
                idx = randperm(nTrials);
                Ssh = S(idx);
                XSh = bX(idx);

                [~,dfiSh.simple(repIdx,sigIdx,noiseIdx,shIdx),fitSh.simple(repIdx,sigIdx,noiseIdx,shIdx)]=...
                    compute_FIT_TE(Ssh, bX, bYt, bYpast, 0);
                [diSh.simple(repIdx,sigIdx,noiseIdx,shIdx)]=...
                    DI_infToolBox(XSh, bYt, bYpast, 'naive', 0);

            end
        end
    end
end

% Save the whole workspace under a descriptive, dated file name.
RSNLab = num2str(ratio_sig_noise);
RSNLab = replace(RSNLab,'.',''); % drop the decimal point, e.g. 0.2 -> '02'
fname = ['NIPS_multipleDelays_additive_seed1_',num2str(simReps),'Reps_sigNoise',RSNLab,'_',num2str(nShuff),'bothShuff_',num2str(nTrials),'Trials',date,'.mat'];
% Portable path (an empty results_path yields a folder relative to the
% current directory); create the folder if it does not exist yet.
save_dir = fullfile(char(results_path),'sigNoiseTimeGT');
if ~exist(save_dir,'dir')
    mkdir(save_dir);
end
save(fullfile(save_dir,fname))
%% Plot

fixed_val = 1; % value to fix when plotting trends with other param (not used below in this script)
prctile_plt = 99; % significance percentile (p-value threshold = (100-prctile_plt)/100)
save_plot = 0;
shuff_types = {'cond','simple'};
n_boot = 500; % number of samples of the bootstrapped null distribution
null2plot = 'max'; % null hypothesis distribution to use for the significance test (maximum of the two above)
path_save = [];


%% Plot panel B (heatmaps)
rng(0) % resetting seed because of random bootstrap across shufflings 

% Compute null distributions of the mean FIT, DI and DFI for each shuffle type.
% btstrp_shuff is a project helper; its output is assumed to stack the
% bootstrap samples along dimension 3 (as used below) — confirm.
for shLab = shuff_types
    pooledFITsh.(shLab{1}) = btstrp_shuff(fitSh.(shLab{1}),n_boot);
    pooledDIsh.(shLab{1}) = btstrp_shuff(diSh.(shLab{1}),n_boot);
    pooledDFIsh.(shLab{1}) = btstrp_shuff(dfiSh.(shLab{1}),n_boot);
end
% Combined null: elementwise maximum of the two shuffle nulls
pooledFITsh.max = max(cat(4,pooledFITsh.simple,pooledFITsh.cond),[],4);
pooledDIsh.max = max(cat(4,pooledDIsh.simple,pooledDIsh.cond),[],4);
pooledDFIsh.max = max(cat(4,pooledDFIsh.simple,pooledDFIsh.cond),[],4);

% p-value per (signal, noise) cell: fraction of null samples at or above the
% observed mean, plus 1/N where N is the size of that measure's own null,
% capped at 1 so the result stays a valid probability.
perm_pval = @(obs,nullDist) min(mean(obs <= nullDist,3) + 1/size(nullDist,3), 1);
pvalsFIT = perm_pval(squeeze(mean(fit,1)), pooledFITsh.(null2plot));
pvalsDI = perm_pval(squeeze(mean(di,1)), pooledDIsh.(null2plot));
pvalsDFI = perm_pval(squeeze(mean(dfi,1)), pooledDFIsh.(null2plot));

% One heatmap per measure (mean over repetitions). Rows of each matrix index
% w_xy_sig and columns index w_xy_noise, so the x axis is the noise weight and
% the y axis the signal weight. Significance markers use the p-value threshold
% (100-prctile_plt)/100.
panel_data = {squeeze(mean(fit,1)), squeeze(mean(di,1)), squeeze(mean(dfi,1))};
panel_pvals = {pvalsFIT, pvalsDI, pvalsDFI};
panel_titles = {'FIT','DI','DFI'};
alpha_plt = (100-prctile_plt)/100;

fig=figure('Position',[1,237,1270,312]);
for k = 1:numel(panel_data)
    ax(k)=subplot(1,3,k);
    hold on
    % imagesc(x,y,C): x spans the columns of C (noise), y spans its rows (signal)
    imagesc(w_xy_noise,w_xy_sig,panel_data{k})
    xlabel('Noise')
    ylabel('Signal')
    title(panel_titles{k})
    for i = 1:numel(w_xy_sig)
        for j = 1:numel(w_xy_noise)
            % marker at (x = noise weight of column j, y = signal weight of row i)
            pvalues_plot_threshold(panel_pvals{k}(i,j),w_xy_noise(j),w_xy_sig(i),0,0,12,0,'k',alpha_plt)
        end
    end
    cmap = my_colormap_rb(panel_data{k});
    colormap(ax(k),cmap)
    colorbar()
    xlim([-0.05,w_xy_noise(end)+0.05])
    ylim([-0.05,w_xy_sig(end)+0.05])
    set(gca,'Ydir','normal')
end

sgtitle('FIT, DI and DFI dependence on signal and noise')

if save_plot
    % saveLab is not assigned anywhere in this script; default it so that
    % enabling save_plot does not fail with an undefined-variable error.
    if ~exist('saveLab','var')
        saveLab = 'additive';
    end
    % Portable path; create the output folder if needed
    fig_dir = fullfile(char(path_save),'sigNoiseTimeGT');
    if ~exist(fig_dir,'dir')
        mkdir(fig_dir);
    end
    % Dated name without extension (saveas then writes a .fig), plus .svg and .png copies
    fname = fullfile(fig_dir,['sigNoise_maps_',saveLab,'_',date]);
    saveas(fig,fname)
    fname = fullfile(fig_dir,['sigNoise_maps_',saveLab,'_','.svg']);
    saveas(fig,fname)
    fname = fullfile(fig_dir,['sigNoise_maps_',saveLab,'_','.png']);
    saveas(fig,fname)
end
