% Script to run the bidirectional transfer scenario of Fig.2 supp.1 

% Start from an empty workspace. Note that the bare `save` further down
% writes every workspace variable to disk, so anything defined in this
% section ends up in the results file.
clear all
%clf
%Global simulation parameters
% Upper bound of the integer stimulus label drawn for each trial (labels
% are drawn from 0..numOfStimuli, see `stimulus` below).
numOfStimuli = 3; 
% Number of simulated time steps (rows of every signal/activity matrix).
simulationLength = 230;
%The stimulus is modelled as a gaussian bump in the activity
% (normal pdf with sigma = 2 sampled on -3:0.15:3, nominally 41 samples,
% shifted so that its minimum is exactly 0).
stimulusEffect = normpdf(-3:0.15:3, 0, 2) - min(normpdf(-3:0.15:3, 0, 2));
% Gain applied to the bump when it is added to signal2 (Y) / signal1 (X).
stimulusStrengthY = 4.2;
stimulusStrengthX = 4.2;
% NOTE(review): not referenced anywhere else in the visible script.
noiseLambda = 2;
% Largest sender-to-receiver delay scanned by the analysis (delay = 1:maxDelay).
maxDelay = 45;
% Number of bins requested from eqpop when discretising activities.
bins = 3;
% NOTE(review): not referenced anywhere else in the visible script.
Nstrenght = 5;
%The number of parallel processes you can start in your cluster/PC
numWorkers = 4;
% doShuffle ~= 0 enables the surrogate computation; numOfShuffling is the
% number of surrogates per (time, delay) pair.
doShuffle = 1;
numOfShuffling = 100;
% List of trial counts to analyse; the simulation is run once with the
% last entry and smaller counts are taken as subsets of those trials.
numbersOfTrials = 2^18;%, 2^15, 2^16, 2^17 ];
numberOfTrials = numbersOfTrials(length(numbersOfTrials));
% NOTE(review): not referenced anywhere else in the visible script.
computeQErdfi = 0;

% Lags, in time steps: X->Y coupling, Y->X coupling, and the offset of the
% stimulus bump in Y relative to stimulusOnsetTime1.
delayXY = 10;
delayYX = 15;
delaySY = 60;

% Coupling weights: YXw scales the Y->X input, XYw scales the X->Y input.
YXw = 0.7;
XYw = 0.7;

% First time step at which the bump is added to signal1.
stimulusOnsetTime1 = 100;
timeStart = 50; % Time from which we start computing DI and FIT (to avoid the transient phase at the beginning of the dynamics)
% stimulusOnsetTime2 / stimulusEndTime2 are not referenced again in the
% visible part of this script; the Y bump is placed with delaySY instead.
stimulusOnsetTime2 = stimulusOnsetTime1 + delayXY;
stimulusEndTime1 = stimulusOnsetTime1 + length(stimulusEffect) - 1;
stimulusEndTime2 = stimulusOnsetTime2 + length(stimulusEffect) - 1;

% Additive noise, uniform in [-noiseAmplitude, +noiseAmplitude], drawn
% independently for every time step and trial. Each of these matrices is
% simulationLength x numberOfTrials doubles (roughly 0.5 GB with the
% values above), as are the signal/activity matrices allocated below.
noiseAmplitude_x = 4;
noiseAmplitude_y = 12;
N_x = -noiseAmplitude_x+2*noiseAmplitude_x*rand(simulationLength,numberOfTrials);
N_ya = -noiseAmplitude_y+2*noiseAmplitude_y*rand(simulationLength,numberOfTrials);
N_yb = -noiseAmplitude_y+2*noiseAmplitude_y*rand(simulationLength,numberOfTrials);

% Constant rate of both nodes before any stimulus bump is added.
baseline = 1;

% signal1 / signal2: rate (time x trial) of node X / node Y; they start
% flat and receive the stimulus bump in the next section.
signal1 = baseline*ones(simulationLength, numberOfTrials);
signal2 = baseline*ones(simulationLength, numberOfTrials);
% One integer label per trial, drawn from 0..numOfStimuli inclusive, i.e.
% numOfStimuli+1 distinct values. The label multiplies the bump, so label 0
% adds no bump at all.
stimulus = randi([0 numOfStimuli], numberOfTrials,1);

% Simulated activities (time x trial): activity1 is node X; activity2a and
% activity2b are the two Y components and activity2 is their average.
activity1 = zeros(simulationLength, numberOfTrials);
activity2 = zeros(simulationLength, numberOfTrials);
activity2a = zeros(simulationLength, numberOfTrials);
activity2b = zeros(simulationLength, numberOfTrials);

%Build the Poisson rates of the two nodes and simulate their activity.
%
%Node X (signal1) receives the gaussian stimulus bump starting at
%stimulusOnsetTime1; node Y (signal2) receives the same bump delaySY time
%steps later. In both cases the bump is scaled by the per-trial stimulus
%label, so label 0 leaves the rate flat.
%
%Y is split in two components: Y+ (activity2a) and Y- (activity2b). Y- is
%the mirror image of Y+ around signal2baseline (plus its own noise), so
%their average, activity2 = (signal2baseline + N_yb)/2, no longer depends
%on Y+ and therefore not on the stimulus. Only this average is fed back to
%X, which means the Y->X transfer carries stimulus-unrelated information.
signal1(stimulusOnsetTime1:stimulusEndTime1, :) = signal1(stimulusOnsetTime1:stimulusEndTime1, :) + stimulusEffect' * stimulus' * stimulusStrengthX;
signal1 = signal1 .* 2;
signal1(signal1<0)= 0;   % rates must be non-negative for poissrnd

signal2(stimulusOnsetTime1+delaySY:stimulusEndTime1+delaySY, :) = signal2(stimulusOnsetTime1+delaySY:stimulusEndTime1+delaySY, :) + stimulusEffect' * stimulus' * stimulusStrengthY;
signal2 = signal2 .* 2;
signal2(signal2<0)= 0;   % rates must be non-negative for poissrnd

% Mirror axis used to build Y- from Y+. It only depends on the first row of
% signal2 and on the coupling weights, so it is computed once, not per step.
signal2baseline = 2*signal2(1,:)*(1+XYw)./(1-XYw*YXw);

% Generate simulated poissonian activity
% First sample: initialize X(0) so that dX/dt = 0 at t = 0 (the Poisson
% draw is scaled by the closed-loop gain of the two coupled nodes).
activity1(1,:) = poissrnd(signal1(1,:))*(1+YXw)./(1-XYw*YXw) + N_x(1,:);
activity2(1,:) = poissrnd(signal2(1,:))*(1+XYw)./(1-XYw*YXw) + N_ya(1,:) + N_yb(1,:);
activity2a(1,:) = activity2(1,:);
activity2b(1,:) = activity2(1,:);

for t = 2:simulationLength
    % Delayed cross inputs. While the delayed sample does not exist yet
    % (t <= delay) the first sample is used instead. Each direction is
    % clamped independently, so the result is correct for any ordering of
    % delayXY and delayYX (the previous branching assumed
    % delayXY <= delayYX; for that case the indices are unchanged).
    YtoX = activity2(max(t-delayYX, 1),:)*YXw;
    XtoY = activity1(max(t-delayXY, 1),:)*XYw;

    % The Poisson draws are kept in the same order (X first, then Y+) so
    % that the random stream is consumed exactly as before.
    activity1(t,:) = poissrnd(signal1(t,:)) + YtoX + N_x(t, :);
    activity2a(t,:) = poissrnd(signal2(t,:)) + XtoY + N_ya(t,:);
    activity2b(t,:) = signal2baseline - activity2a(t,:) + N_yb(t,:);
    activity2(t,:) = (activity2a(t,:)+activity2b(t,:))/2;
end


% We take the inter-trial average (mean over the trial dimension) of every
% simulated activity; the results are column vectors indexed by time step.
final_activity1 = mean(activity1,2);
final_activity2 = mean(activity2,2);
final_activity2a = mean(activity2a,2);
final_activity2b = mean(activity2b,2);

% Quick-look figure of the trial-averaged traces over the whole simulation.
figure
plot(final_activity1)
hold on
%yyaxis right
plot(final_activity2a)
plot(final_activity2b)
title(['Delay fwd = ', num2str(delayXY), ' bwd = ', num2str(delayYX)], 'fontsize', 16)
xlabel('time', 'fontsize', 14)
% Three curves are drawn (X, then the two Y components), so three labels
% are needed; with only two labels the last curve was left out of the legend.
legend('X', 'Y+', 'Y-')

%% Preallocation of all the different InfoTheory quantities:

    % Naming legend:
    % FIT: Feature-related Information Transfer (the one proposed here).
    % DFI / dfi: Directed Feature Information (Ince, 2015).
    % di: Directed Information (Massey, 1990).
    % *_YtoX: the same quantity computed in the Y -> X direction (names
    %         without this suffix are X -> Y).
    % *Sh / *_Sh: the shuffled (surrogate) version of * .
    % (The prefixes r / red / ItQe mentioned in earlier versions of this
    % legend are not used by any variable in the visible script.)
    
    % Measured values: one row per analysed time step (timeStart+1 ..
    % simulationLength), one column per delay (1 .. maxDelay).
    FIT = zeros(simulationLength-timeStart, maxDelay);
    FIT_YtoX = FIT; dfi = FIT; di = FIT;
    dfi_YtoX = FIT; di_YtoX = FIT;
    
    % Surrogate values: same layout with the shuffle index as third
    % dimension. The names must match the ones written by the analysis loop
    % and read by the plotting code (FIT_YtoX_Sh, dfi_YtoX_Sh, di_YtoX_Sh);
    % otherwise the preallocation is dead and the arrays are grown element
    % by element inside the loop.
    FITSh = zeros(simulationLength-timeStart, maxDelay, numOfShuffling);
    dfiSh = FITSh; diSh = FITSh;
    FIT_YtoX_Sh = FITSh; dfi_YtoX_Sh = FITSh; di_YtoX_Sh = FITSh;

    % Compute the information theory quantities

    % Make sure a parallel pool with at least numWorkers workers is up
    % (only when more than one worker was requested).
    if (numWorkers > 1)
        if (~isempty(gcp('nocreate')))
            % A pool already exists: keep it unless it is too small, in
            % which case it is shut down and replaced.
            poolTemp = gcp;
            if(poolTemp.NumWorkers < numWorkers)
                delete(gcp)
                parpool(numWorkers);
            end
        else
            % No pool yet: start one of the requested size.
            parpool(numWorkers);
        end
    end

    % Silence all warnings on the client and on every worker.
    pctRunOnAll warning('off','all')

    % Start the stopwatch for the analysis below.
    tic

    % Main analysis loop. For every requested trial count, every analysed
    % time step and every delay, the present and delayed past of X and Y
    % are discretised and DI / DFI / FIT are computed in both directions,
    % followed (optionally) by their shuffled surrogates.
    % NOTE: results are indexed by (time, delay[, shuffle]) only, so when
    % numbersOfTrials has several entries each pass overwrites the previous
    % one and only the last trial count is kept.
    for n = 1:length(numbersOfTrials)
        % Evenly spaced subset of numbersOfTrials(n) simulated trials.
        trialIds = 1:(numberOfTrials / numbersOfTrials(n)):numberOfTrials;
        a1 = activity1(:, trialIds);
        % Stimulus labels shifted to start at 1 instead of 0.
        stim = (stimulus(trialIds) + 1)';
        % Distinct stimulus values; constant over time and delay.
        Sval = unique(stim);

        for time = timeStart+1:simulationLength
            for delay = 1:maxDelay
                % Skip delays that would reach before the first sample.
                if (time - delay) > 0
                    % Present and past of X, each discretised into `bins`
                    % bins using the edges returned by eqpop.
                    X = a1(time, :);
                    edgs = eqpop(X, bins);
                    [~,X] = histc(X, edgs);
                    
                    hX = a1(time-delay, :);
                    edgs = eqpop(hX, bins);
                    [~,hX] = histc(hX, edgs);

                    % Present of Y: Y+ and Y- are binned separately and
                    % merged into one joint symbol. The same trial subset
                    % as for X is used, so all vectors have equal length
                    % (previously all trials were taken here, which only
                    % matched X when no subsampling was requested).
                    yNow = activity2a(time, trialIds);
                    edgs = eqpop(yNow, bins);
                    [~,a2a] = histc(yNow, edgs);
                    yNow = activity2b(time, trialIds);
                    edgs = eqpop(yNow, bins);
                    [~,a2b] = histc(yNow, edgs);
                    Y = (a2a - 1) .* bins + a2b;

                    % Past of Y, built the same way.
                    yPast = activity2a(time-delay, trialIds);
                    edgs = eqpop(yPast, bins);
                    [~,ha2a] = histc(yPast, edgs);
                    yPast = activity2b(time-delay, trialIds);
                    edgs = eqpop(yPast, bins);
                    [~,ha2b] = histc(yPast, edgs);
                    hY = (ha2a - 1) .* bins + ha2b;

                    % Force row orientation.
                    if (size(X, 2) < size(X,1))
                        X = X';
                        Y = Y';
                        hX = hX';
                        hY = hY';
                    end

                    % compute DI, DFI, FIT (X -> Y, then Y -> X)
                     [di(time-timeStart, delay), dfi(time-timeStart, delay), FIT(time-timeStart, delay)] = compute_FIT_TE(stim, hX, Y, hY);
                     [di_YtoX(time-timeStart, delay), dfi_YtoX(time-timeStart, delay), FIT_YtoX(time-timeStart, delay)] = compute_FIT_TE(stim, hY, X, hX);

                    % compute shuffled values
                    if (doShuffle)
                        for s = 1:numOfShuffling
                            % shufflings for FIT and DFI: the sender past
                            % is permuted across trials separately within
                            % each stimulus value, so its relation to the
                            % stimulus is kept while its trial-by-trial
                            % pairing with the receiver is destroyed.
                            XSh = zeros(1,numel(X));YSh = XSh;
                            for Ss = 1:numel(Sval)
                                idx = (stim == Sval(Ss)); % select trials where stim = Ss
                                Xx = hX(idx); % take X values on those trials
                                ridx = randperm(sum(idx)); % generate random idxs
                                XSh(1, idx) = Xx(ridx); % assign to XSh, on idxs where stim = Ss, reshuffled values of X at stim = Ss
                                
                                Yy = hY(idx); % take Y values on those trials
                                YSh(1, idx) = Yy(ridx); % assign to YSh, on idxs where stim = Ss, reshuffled values of Y at stim = Ss
                            end
                            [~, dfiSh(time-timeStart, delay, s), FITSh(time-timeStart, delay, s)] = compute_FIT_TE(stim, XSh, Y, hY);
                            [~, dfi_YtoX_Sh(time-timeStart, delay,s), FIT_YtoX_Sh(time-timeStart, delay,s)] = compute_FIT_TE(stim, YSh, X, hX);
                        end
                        
                        for s = 1:numOfShuffling
                            % shufflings for DI: the sender past is
                            % permuted across all trials, regardless of
                            % the stimulus.
                            index = randperm(numel(stim));
                            XSh = hX(index);
                            YSh = hY(index);
                            diSh(time-timeStart, delay, s) = DI_infToolBox(XSh, Y, hY, 'naive', 0);
                            di_YtoX_Sh(time-timeStart, delay, s) = DI_infToolBox(YSh, X, hX, 'naive', 0);
                        end
                            
                    end
                end
            end
           % fprintf('%d\n', time);
        end
        % Progress: index of the trial count that was just completed.
        fprintf('%d\n', n);
    end

    % Dump the whole workspace (simulation and results) to FigS5_results.mat.
    save FigS5_results
    
%% Plot   

% Average every measure over the delay dimension; each result is a column
% vector with one entry per analysed time step.
avgDI = mean(di,2);
avgDFI = mean(dfi,2);
avgFIT = mean(FIT,2);

avgDI_YtoX = mean(di_YtoX,2);
avgDFI_YtoX = mean(dfi_YtoX,2);
avgFIT_YtoX = mean(FIT_YtoX,2);

% Number of analysed time steps and first index shown on the x axis.
simTimeSteps = numel(avgDI);
minTime = 1;

% Percentiles of the surrogate distribution used as significance bands.
prc_plot = 99;
prc_plotDFI = 99.5; % for DFI we do two-tailed test since it can be either positive or negative
% Index window (in analysed-time-step units) used for the fwd/bwd comparison.
boxPlot = struct('tMin', 51, 'tMax', 100);

figure
subplot(3,3,2)
% ACTIVITY PLOT
% Only the last simTimeSteps samples of the trial-averaged activity are
% shown, so that the x axis matches the information-measure panels.
simTimeStepsActiv = size(final_activity1,1);
timeDiff = simTimeStepsActiv-simTimeSteps;
shownSamples = timeDiff+1:simTimeStepsActiv;
shownAxis = 1:simTimeStepsActiv-timeDiff;
plot(shownAxis, final_activity1(shownSamples))
hold on
plot(shownAxis, final_activity2a(shownSamples))
plot(shownAxis, final_activity2b(shownSamples))
legend('Node X','Node Y+', 'Node Y-','Location','northeast')
xlabel('Time (ms)')
ylabel('Mean Activity')


subplot(3,3,4)
% DI X-->Y PLOT
hold on
% Delay-averaged DI of each surrogate: one row per shuffle, one column per
% analysed time step.
nSurrogates = size(squeeze(diSh),3);
avgDiSh = zeros(nSurrogates, size(diSh,1));
for iSurr = 1:nSurrogates
    avgDiSh(iSurr,:) = mean(diSh(:,:,iSurr),2);
end
x = 1:simTimeSteps;
f = avgDiSh;
% Lower / upper percentile of the surrogates at every time step.
minf = prctile(f,100-prc_plot);
maxf = prctile(f,prc_plot);

% Grey band between the two percentiles, measured curve on top.
fill([x, x(end:-1:1)], [minf, maxf(end:-1:1)], [0.8, 0.8, 0.8], 'EdgeColor','None')
plot(avgDI(:),'LineWidth',1,'Color','k')

% Mark the time steps where the measured DI exceeds the upper percentile.
red = reshape(avgDI(1:simTimeSteps),1,[]) > maxf(1:simTimeSteps);
significantSteps = find(red);
plot(significantSteps,avgDI(significantSteps),'ko','MarkerFaceColor','k','MarkerSize',4)
xlim([minTime simTimeSteps])
% Common y range for the forward and backward DI panels (0 always included).
rangePoints = [avgDI(minTime:simTimeSteps); avgDI_YtoX(minTime:simTimeSteps);0];
maxY = max(rangePoints);
minY = min(rangePoints);
ylim(1.25*[minY maxY])
xlabel('Time (ms)')
ylabel('[bits]')
title('DI_{X \rightarrow Y}')
set(gca,'FontSize',16);

subplot(335)
% DFI X-->Y PLOT
hold on
% Fixed dark-green used for the measured DFI curve and its markers. The
% green channel used to be simTimeSteps/255, which tied the colour to the
% simulation length and is an invalid RGB value once simTimeSteps > 255;
% 180 is the value that expression produced for the default settings.
dfiLineColor = [22, 180, 51]/255;

% Delay-averaged DFI of each surrogate: one row per shuffle, one column per
% analysed time step (sized from dfiSh itself rather than from diSh).
avgDfiSh = permute(mean(dfiSh,2), [3 1 2]);
x = 1:simTimeSteps;
f = avgDfiSh;
% Two-tailed band: percentiles are taken across shuffles (dimension 1)
% explicitly, so a single surrogate row is not collapsed to a scalar.
minf = prctile(f,(100-prc_plotDFI),1);
maxf = prctile(f,prc_plotDFI,1);

fill([x fliplr(x)], [minf fliplr(maxf)], [152, 245, 160]/255,'EdgeColor','None') 
plot(avgDFI(:),'LineWidth',1,'Color',dfiLineColor)

% DFI can be negative, so a time step is flagged when the measured value
% falls outside the band on either side.
avgDFIrow = reshape(avgDFI(1:simTimeSteps),1,[]);
red = (avgDFIrow > maxf) | (avgDFIrow < minf);

plot(find(red),avgDFI(red),'go','MarkerFaceColor',dfiLineColor,'MarkerSize',6)
xlim([minTime simTimeSteps])
% Common y range for the forward and backward DFI panels (0 always included).
maxY = max([avgDFI(minTime:simTimeSteps); avgDFI_YtoX(minTime:simTimeSteps);0]);
minY = min([avgDFI(minTime:simTimeSteps); avgDFI_YtoX(minTime:simTimeSteps);0]);
ylim([1.25*minY 1.25*maxY])
xlabel('Time (ms)')
title('DFI_{X \rightarrow Y}')
set(gca,'FontSize',16);

subplot(3,3,6)
% FIT X-->Y PLOT
hold on
% Delay-averaged FIT of each surrogate: one row per shuffle, one column per
% analysed time step.
nSurrogates = size(squeeze(diSh),3);
for iSurr = 1:nSurrogates
    avgFITSh(iSurr,:) = mean(FITSh(:,:,iSurr),2);
end

f = avgFITSh;
minf = prctile(f,100-prc_plot);
maxf = prctile(f,prc_plot);

% One-sided band: the shaded area spans from the upper percentile of the
% surrogates down to zero.
fill([x, x(end:-1:1)], [maxf, zeros(size(x))], [245, 171, 152]/255,'EdgeColor','None') 
plot(avgFIT(:),'LineWidth',1,'Color','r')

% Mark the time steps where the measured FIT exceeds the upper percentile.
red = reshape(avgFIT(1:simTimeSteps),1,[]) > maxf(1:simTimeSteps);
significantSteps = find(red);
plot(significantSteps,avgFIT(significantSteps),'ro','MarkerFaceColor','r','MarkerSize',6)

% Common y range for the forward and backward FIT panels (0 always included).
rangePoints = [avgFIT(minTime:simTimeSteps); avgFIT_YtoX(minTime:simTimeSteps);0];
maxY = max(rangePoints);
minY = min(rangePoints);
ylim(1.25*[minY maxY])
xlim([minTime simTimeSteps])
xlabel('Time (ms)')
title('FIT_{X \rightarrow Y}')
set(gca,'FontSize',16);



subplot(3,3,7)
% DI Y-->X PLOT
hold on
% Delay-averaged backward DI of each surrogate: one row per shuffle, one
% column per analysed time step.
nSurrogates = size(squeeze(diSh),3);
for iSurr = 1:nSurrogates
    avgDI_YtoX_Sh(iSurr,:) = mean(di_YtoX_Sh(:,:,iSurr),2);
end

f = avgDI_YtoX_Sh;
% Lower / upper percentile of the surrogates at every time step.
minf = prctile(f,100-prc_plot);
maxf = prctile(f,prc_plot);

% Grey band between the two percentiles, measured curve on top.
fill([x, x(end:-1:1)], [minf, maxf(end:-1:1)], [0.8, 0.8, 0.8], 'EdgeColor','None')
plot(avgDI_YtoX(:),'LineWidth',1,'Color','k')

% Mark the time steps where the measured DI exceeds the upper percentile.
red = reshape(avgDI_YtoX(1:simTimeSteps),1,[]) > maxf(1:simTimeSteps);
significantSteps = find(red);
plot(significantSteps,avgDI_YtoX(significantSteps),'ko','MarkerFaceColor','k','MarkerSize',4)

% Same y range as the forward DI panel (0 always included).
rangePoints = [avgDI(minTime:simTimeSteps); avgDI_YtoX(minTime:simTimeSteps);0];
maxY = max(rangePoints);
minY = min(rangePoints);
ylim(1.25*[minY maxY])
xlim([minTime simTimeSteps])
xlabel('Time (ms)')
title('DI_{Y \rightarrow X}')
set(gca,'FontSize',16);


subplot(338)
% DFI Y-->X PLOT
hold on
% Fixed dark-green used for the measured DFI curve and its markers. The
% green channel used to be simTimeSteps/255, which tied the colour to the
% simulation length and is an invalid RGB value once simTimeSteps > 255;
% 180 is the value that expression produced for the default settings.
dfiLineColor = [22, 180, 51]/255;

% Delay-averaged backward DFI of each surrogate: one row per shuffle, one
% column per analysed time step (sized from dfi_YtoX_Sh itself rather than
% from diSh).
avgDFI_YtoX_Sh = permute(mean(dfi_YtoX_Sh,2), [3 1 2]);
x = 1:simTimeSteps;

f = avgDFI_YtoX_Sh;
% Two-tailed band: percentiles are taken across shuffles (dimension 1)
% explicitly, so a single surrogate row is not collapsed to a scalar.
minf = prctile(f,100-prc_plotDFI,1);
maxf = prctile(f,prc_plotDFI,1);

fill([x fliplr(x)], [minf fliplr(maxf)], [152, 245, 160]/255, 'EdgeColor','None')
plot(avgDFI_YtoX(:),'LineWidth',1,'Color',dfiLineColor)

% DFI can be negative, so a time step is flagged when the measured value
% falls outside the band on either side.
avgDFIrow = reshape(avgDFI_YtoX(1:simTimeSteps),1,[]);
red = (avgDFIrow > maxf) | (avgDFIrow < minf);

plot(find(red),avgDFI_YtoX(red),'go','MarkerFaceColor',dfiLineColor,'MarkerSize',6)
% Same y range as the forward DFI panel (0 always included).
maxY = max([avgDFI(minTime:simTimeSteps); avgDFI_YtoX(minTime:simTimeSteps);0]);
minY = min([avgDFI(minTime:simTimeSteps); avgDFI_YtoX(minTime:simTimeSteps);0]);
ylim([1.25*minY 1.25*maxY])
xlim([minTime simTimeSteps])
xlabel('Time (ms)')
title('DFI_{Y \rightarrow X}')
set(gca,'FontSize',16);


subplot(3,3,9)
% FIT Y-->X PLOT
hold on
% Delay-averaged backward FIT of each surrogate: one row per shuffle, one
% column per analysed time step.
nSurrogates = size(squeeze(diSh),3);
for iSurr = 1:nSurrogates
    avgFIT_YtoX_Sh(iSurr,:) = mean(FIT_YtoX_Sh(:,:,iSurr),2);
end

f = avgFIT_YtoX_Sh;
minf = prctile(f,100-prc_plot);
maxf = prctile(f,prc_plot);

% One-sided band: the shaded area spans from the upper percentile of the
% surrogates down to zero.
fill([x, x(end:-1:1)], [maxf, zeros(size(x))], [245, 171, 152]/255,'EdgeColor','None') 
plot(avgFIT_YtoX(:),'LineWidth',1,'Color','r')

% Mark the time steps where the measured FIT exceeds the upper percentile.
red = reshape(avgFIT_YtoX(1:simTimeSteps),1,[]) > maxf(1:simTimeSteps);
significantSteps = find(red);
plot(significantSteps,avgFIT_YtoX(significantSteps),'ro','MarkerFaceColor','r','MarkerSize',6)

% Same y range as the forward FIT panel (0 always included).
rangePoints = [avgFIT(minTime:simTimeSteps); avgFIT_YtoX(minTime:simTimeSteps);0];
maxY = max(rangePoints);
minY = min(rangePoints);
ylim(1.25*[minY maxY])
xlim([minTime simTimeSteps])
xlabel('Time (ms)')
title('FIT_{Y \rightarrow X}')
set(gca,'FontSize',16);


% DI boxplots fwd vs bwd
% Delay-averaged DI inside the window boxPlot.tMin..boxPlot.tMax, forward
% (X -> Y) in the first column and backward (Y -> X) in the second.
DIpoints=[avgDI(boxPlot.tMin:boxPlot.tMax),avgDI_YtoX(boxPlot.tMin:boxPlot.tMax)];
orig_DI_fwd = mean(avgDI(boxPlot.tMin:boxPlot.tMax));
orig_DI_bwd = mean(avgDI_YtoX(boxPlot.tMin:boxPlot.tMax));

% Permutation surrogates: the pooled forward/backward points are shuffled
% and the mean of a random group of the same size as one direction is
% taken. The group size is half of the pooled points
% (boxPlot.tMax - boxPlot.tMin + 1 per direction); the window start index
% boxPlot.tMin was previously used as the count, which took one point too
% many with the default window and is unrelated to the group size in general.
nPerDirection = numel(DIpoints)/2;
surrogate_values = zeros(1,1000);
for i = 1:1000
    ridx = randperm(numel(DIpoints));
    surrDI = DIpoints(ridx);
    surrogate_values(i) = mean(surrDI(1:nPerDirection));
end
