%% Simulations in presence of a third region Z (Fig S7)
% Sender X encodes the stimulus S linearly, a third region Z encodes S
% through nonlinear_encoding, and receiver Y gets time-lagged input from
% both with weights w_xy (X->Y) and w_zy (Z->Y). For every grid point the
% measures returned by compute_FIT_TE_cFIT (di, dfi, fit, cfit) are stored
% together with two kinds of shuffled null values ('cond' and 'simple').

clear all; %close all;

rng(1) % For reproducibility
results_path =  [];

% Simulation parameters
nTrials_per_stim = 500; % trials per stimulus value
simReps = 50;           % number of independent simulation repetitions
nShuff = 10;            % shuffles per repetition and grid point

w_xy = 0:0.1:1; % range of w_xy parameter (weight of delayed X in Y)
w_zy = 0:0.1:1; % range of w_zy parameter (weight of delayed Z in Y)
epsY = 2; epsX = 2; epsZ = 2; % epsY: SD of additive gaussian noise in Y; epsX/epsZ: noise magnitudes of X and Z
ratio_sig_noise = 0.2; % SD of the multiplicative gaussian noise on X and Z, expressed as fraction of epsX/epsZ


% Global params
tparams.simLen = 60; % simulation time, in units of 10ms
tparams.stimWin = [30 35]; % X stimulus encoding window, in units of 10ms
tparams.delays = [4,6]; % candidate communication delays, in units of 10ms (one drawn per repetition)
tparams.delayMax = 10; % maximum computed delay, in units of 10ms

% Define information options
% NOTE(review): within this script only the n_bins* fields are read.
opts = [];
opts.verbose = false;
opts.method = "dr";
opts.bias = 'naive';
opts.btsp = 0;
opts.n_binsX = 3; % number of bins used to discretize X (and Z) at the emitting time point
opts.n_binsY = 3; % number of bins used to discretize Y (present and past)
opts.n_binsS = 4; % Number of stimulus values

% Draw random delay for each repetition
reps_delays = randsample(tparams.delays,simReps,true);

% Initialize structures (all NaN until filled by the simulation)
gridShape = [simReps, numel(w_xy), numel(w_zy)];
nullShape = [gridShape, nShuff];
fit = nan(gridShape); di = fit; dfi = fit; cfit = fit;
fitSh.simple = nan(nullShape); diSh.simple = fitSh.simple; dfiSh.simple = fitSh.simple; cfitSh.simple = fitSh.simple;
fitSh.cond = nan(nullShape); diSh.cond = fitSh.cond; dfiSh.cond = fitSh.cond; cfitSh.cond = fitSh.cond;

%% Run simulation
% For every repetition and every (w_xy, w_zy) grid point: simulate X, Z and
% Y, discretize them at the communication time points, compute the measures
% with compute_FIT_TE_cFIT and build two shuffled nulls ('cond' and 'simple').

nTrials = nTrials_per_stim*opts.n_binsS; % Compute number of trials (loop invariant)
stimBins = tparams.stimWin(1):tparams.stimWin(2); % time bins in which S is encoded

for repIdx = 1:simReps
    disp(['Repetition ',num2str(repIdx)])

    % Delay drawn for this repetition and first receiving time point
    % (first emitting time point + delay); both are fixed across the grid.
    d = reps_delays(repIdx);
    t = tparams.stimWin(1)+d;

    for xIdx = 1:numel(w_xy)
        for zIdx = 1:numel(w_zy)
            S = randi(opts.n_binsS,1,nTrials)-1; % for the nonlinear encoding function we take inputs from 0 to 3

            % Notation slightly confusing: eps is the machine epsilon in
            % matlab (added as near-zero noise to avoid issues with binning
            % zeros afterwards), epsX and epsZ are the magnitude of noise
            % in node X and Z, respectively

            % Simulate sender X: S copied into the encoding window, then a
            % per-trial multiplicative gain 1 + epsX*ratio_sig_noise*N(0,1)
            X = eps*epsX*randn(tparams.simLen,nTrials);
            X(stimBins,:) = repmat(S,numel(stimBins),1);
            X = X.*(1+epsX*ratio_sig_noise*randn(1,nTrials)); % Multiplicative noise

            % Simulate third region Z (Z encodes S nonlinearly), same noise model
            Z = eps*epsZ*randn(tparams.simLen,nTrials);
            Z(stimBins,:) = repmat(nonlinear_encoding(S),numel(stimBins),1);
            Z = Z.*(1+epsZ*ratio_sig_noise*randn(1,nTrials)); % Multiplicative noise

            % Time-lagged single-trial inputs from X and Z to Y: shifted by
            % d time bins, with the first d bins filled by near-zero noise
            X2Y = [eps*epsX*randn(d,nTrials); X(1:end-d,:)];
            Z2Y = [eps*epsZ*randn(d,nTrials); Z(1:end-d,:)];

            % Y is the weighted sum of delayed X and Z plus additive gaussian noise (SD epsY)
            Y = w_xy(xIdx)*X2Y + w_zy(zIdx)*Z2Y + epsY*randn(tparams.simLen,nTrials);

            % Discretize activity: X and Z at the emitting time point t-d,
            % Y at the receiving time point t and at its past t-d
            % (bin edges from eqpop -- presumably equipopulated bins)
            edgs = eqpop(X(t-d,:), opts.n_binsX);
            [~,bX] = histc(X(t-d,:), edgs);
            edgs = eqpop(Z(t-d,:), opts.n_binsX);
            [~,bZ] = histc(Z(t-d,:), edgs);

            edgs = eqpop(Y(t,:), opts.n_binsY);
            [~,bYt] = histc(Y(t,:), edgs);
            edgs = eqpop(Y(t-d,:), opts.n_binsY);
            [~,bYpast] = histc(Y(t-d,:), edgs);

            S = S + 1; % stimulus labels shifted from 0..n_binsS-1 to 1..n_binsS

            [di(repIdx,xIdx,zIdx),dfi(repIdx,xIdx,zIdx), fit(repIdx,xIdx,zIdx), cfit(repIdx,xIdx,zIdx)] = ...
                        compute_FIT_TE_cFIT(S, bX, bYt, bYpast, bZ);

            Sval = unique(S); % stimulus values present (same for every shuffle)
            for shIdx = 1:nShuff

                % conditioned shuff (i.e. shuffling X at fixed values of S);
                % every trial belongs to exactly one S value, so XSh is fully overwritten
                XSh = zeros(size(bX));
                for Ss = 1:numel(Sval)
                    idx = (S == Sval(Ss));
                    tmpX = bX(idx);
                    ridx = randperm(sum(idx));
                    XSh(1,idx) = tmpX(ridx);
                end

                [diSh.cond(repIdx,xIdx,zIdx,shIdx), dfiSh.cond(repIdx,xIdx,zIdx,shIdx), fitSh.cond(repIdx,xIdx,zIdx,shIdx), cfitSh.cond(repIdx,xIdx,zIdx,shIdx)] = ...
                        compute_FIT_TE_cFIT( S, XSh, bYt, bYpast, bZ);

                % simple shuff: FIT/cFIT with S permuted across all trials,
                % DI with X permuted across all trials (same permutation)
                idx = randperm(nTrials);
                Ssh = S(idx);
                XSh = bX(idx);

                [~,~,fitSh.simple(repIdx,xIdx,zIdx,shIdx),cfitSh.simple(repIdx,xIdx,zIdx,shIdx)]=...
                    compute_FIT_TE_cFIT(Ssh, bX, bYt, bYpast, bZ);
                % NOTE(review): dfiSh.simple is never filled here and stays NaN
                diSh.simple(repIdx,xIdx,zIdx,shIdx) = ...
                    DI_infToolBox(XSh, bYt, bYpast, 'naive', 0);

            end
        end
    end
end

% Save the whole workspace (simulation outputs and parameters) to
% <results_path>/confounder/NIPS_FIGS4.mat, creating the folder if needed.
% fullfile builds a portable path; char() turns an empty [] results_path
% into '' so the file lands in ./confounder instead of the drive root.
RSNLab = num2str(ratio_sig_noise);
RSNLab = replace(RSNLab,'.',''); % NOTE(review): not used in fname; still stored by save
fname = 'NIPS_FIGS4.mat';
saveDir = fullfile(char(results_path),'confounder');
if ~exist(saveDir,'dir')
    mkdir(saveDir);
end
save(fullfile(saveDir,fname))

%% Plots
% Build bootstrapped null distributions pooled across repetitions and
% derive a p-value for every (w_xy, w_zy) grid point.
rng(1)
prctile_plt = 99; % percentile of the null used as significance threshold
n_boot = 500;     % number of bootstrap samples in each null distribution
save_plot = 0;    % set to 1 to export the figures
path_save = [];

% Pool nulls across repetitions, for each shuffle type in turn
shuff_types = {'cond','simple'};
null2plot = 'max'; % null hypothesis (element-wise maximum of the two nulls)

for k = 1:numel(shuff_types)
    lab = shuff_types{k};
    pooledFITsh.(lab) = btstrp_shuff(fitSh.(lab),n_boot);
    pooledCFITsh.(lab) = btstrp_shuff(cfitSh.(lab),n_boot);
end

% Combined null: maximum of the 'simple' and 'cond' nulls, taken along a
% fourth (stacking) dimension
stackedFIT = cat(4,pooledFITsh.simple,pooledFITsh.cond);
stackedCFIT = cat(4,pooledCFITsh.simple,pooledCFITsh.cond);
pooledFITsh.max = max(stackedFIT,[],4);
pooledCFITsh.max = max(stackedCFIT,[],4);

% p-value: fraction of null samples (3rd dim) at least as large as the
% repetition-averaged measure, plus 1/(number of null samples).
% NOTE(review): with this correction the value can exceed 1.
nullFIT = pooledFITsh.(null2plot);
nullCFIT = pooledCFITsh.(null2plot);
avgFIT = squeeze(mean(fit,1));
avgCFIT = squeeze(mean(cfit,1));
pvalsFIT = mean(avgFIT <= nullFIT,3) + 1/size(nullFIT,3);
pvalsCFIT = mean(avgCFIT <= nullCFIT,3) + 1/size(nullCFIT,3);

%% Plot heatmaps (FIT and cFIT with stim and noise)
% The repetition-averaged maps have rows indexed by w_xy and columns by
% w_zy, so w_zy goes on the x axis and w_xy on the y axis. Each map gets
% its own colormap built from its own data, and pvalues_plot_threshold is
% called at every grid point with threshold (100-prctile_plt)/100
% (assumed to draw a marker at (x, y) = (w_zy(j), w_xy(i)) -- TODO confirm).
fig=figure('Position',[41,298,766,305]);

maps = {squeeze(mean(fit,1)), squeeze(mean(cfit,1))};
pvalMaps = {pvalsFIT, pvalsCFIT};
mapTitles = {'FIT','cFIT'};
for k = 1:numel(maps)
    ax=subplot(1,2,k);
    hold on
    imagesc(w_zy,w_xy,maps{k}) % x <-> columns (w_zy), y <-> rows (w_xy)
    xlabel('w_{zy}')
    ylabel('w_{xy}')
    title(mapTitles{k})
    cmap = my_colormap_rb(maps{k}); % colormap built from the map actually shown
    colormap(ax,cmap)
    colorbar()
    set(gca,'Ydir','normal')
    for i = 1:numel(w_xy)
        for j = 1:numel(w_zy)
            pvalues_plot_threshold(pvalMaps{k}(i,j),w_zy(j),w_xy(i),0,0,12,0,'k',(100-prctile_plt)/100)
        end
    end
    xlim([-0.05,w_zy(end)+0.05])
    ylim([-0.05,w_xy(end)+0.05])
end

sgtitle('FIT and cFIT dependence on w_{xy} and w_{zy}')

if save_plot
    fbase = fullfile(char(path_save),'SuppFig',['parallelTransf_maps_',date]);
    saveas(fig,[fbase,'.png'])
    saveas(fig,[fbase,'.fig'])
    fig.Renderer = 'painters'; % vector output for the svg export
    saveas(fig,[fbase,'.svg'])
end

%% Plot trends for w_xy = 0
% Repetition-averaged FIT and cFIT as a function of w_zy at w_xy(xyIdx),
% compared with the prctile_plt-th percentile of the bootstrapped nulls.
xyIdx = 1; % index into w_xy; w_xy(1) = 0
fig=figure();
subplot(2,1,1)
hold on
plot(w_zy,squeeze(mean(fit(:,xyIdx,:),1)))
plt_shuff = squeeze(pooledFITsh.(null2plot)(xyIdx,:,:));
plot(w_zy,prctile(plt_shuff,prctile_plt,2),'g--')
plt_shuff = squeeze(pooledFITsh.simple(xyIdx,:,:));
plot(w_zy,prctile(plt_shuff,prctile_plt,2),'r--')
legend({'FIT', ...
    sprintf('%s null, %gth prc',null2plot,prctile_plt), ...
    sprintf('simple null, %gth prc',prctile_plt)},'Location','best')

xlabel('w_{zy}')
title('FIT')
set(gca,'Ydir','normal')

subplot(2,1,2)
hold on
plot(w_zy,squeeze(mean(cfit(:,xyIdx,:),1)))
plt_shuff = squeeze(pooledCFITsh.(null2plot)(xyIdx,:,:));
plot(w_zy,prctile(plt_shuff,prctile_plt,2),'g--')
legend({'cFIT', sprintf('%s null, %gth prc',null2plot,prctile_plt)},'Location','best')

xlabel('w_{zy}')
title('cFIT')
set(gca,'Ydir','normal')

sgtitle(['FIT, and cFIT on w_{zy}; w_{xy} = ',num2str(w_xy(xyIdx))])

if save_plot
    fbase = fullfile(char(path_save),'confounder',['trendsWzy_Wxy0_maps_',date]);
    saveas(fig,[fbase,'.png'])
    saveas(fig,[fbase,'.fig'])
    fig.Renderer = 'painters'; % vector output for the svg export
    saveas(fig,[fbase,'.svg'])
end

%% Nullhyp distributions
% Histograms of the bootstrapped 'simple' and 'cond' FIT nulls at
% w_xy(1) and w_zy(fixed_value), with their prctile_plt-th percentiles and
% the repetition-averaged measured FIT as vertical lines.
fixed_value = 1; % index into w_zy

simpleShuff = squeeze(pooledFITsh.simple(1,fixed_value,:));
condShuff = squeeze(pooledFITsh.cond(1,fixed_value,:));
figure()
hold on
histogram(simpleShuff,10,'FaceColor',[0.7,0.7,0.7])
histogram(condShuff,10,'FaceColor',[0.4,0.4,1])
h = gobjects(1,3);
h(1)=xline(prctile(condShuff,prctile_plt),'b','linewidth',2);
h(2)=xline(prctile(simpleShuff,prctile_plt),'color',[0.4,0.4,0.4],'linewidth',2);
h(3)=xline(mean(fit(:,1,fixed_value)),'r','linewidth',2);
prcLab = sprintf('%gprc',prctile_plt); % legend label follows prctile_plt (e.g. '99prc')
lgd=legend(h,['Within-trial ',prcLab],['S-shuff ',prcLab],'Measured FIT');
lgd.FontSize = 12;