%% FIT in presence of overlapping time courses

% In this script we test the ability of FIT to determine the direction of
% information flow in the presence of overlapping time courses of feature
% information. In doing so, we also show that FIT is sensitive to the
% time-lagged similarity of the information encoding format.

% Start from an empty workspace. A plain CLEAR is sufficient here: CLEAR ALL
% additionally flushes compiled functions/MEX files from memory, which only
% slows down subsequent calls.
clear; % close all;

rng(1) % For reproducibility
save_results = 0; % set to 1 to save results file
results_path =  'C:\Users\mcelotto\Desktop\Neural Computation\PhD\Lemke\Reach_to_grasp_project\Scripts\Stefan\revisions_260623\figures\new_simulations'; % path to the directory to save results

% Simulation parameters
nTrials_per_stim = 500; % number of trials per stimulus value
simReps = 1000; % repetitions of the simulation
nShuff = 10; % number of permutations (used for both FIT permutation tests)
% Set transfer_scenario = 'XY_transfer' to reproduce Fig. 2E results
transfer_scenario = 'XY_transfer'; % either 'indep_encoding', 'XY_transfer', 'bidir_noise'

% w_xy_noise = 0; % range of w_noise parameter
% SNR levels to sweep. delta is passed to encoding_function as its third
% argument, and the noise standard deviation at each level is delta/SNR.
SNR_range = 0.05:0.05:1;
delta = 1;
noise_range = delta./SNR_range; % standard deviation of the additive Gaussian noise, one value per SNR level

% Encoding patterns handed to encoding_function for each signal
% (X/Y past and present). NOTE(review): presumably entry k sets how
% stimulus value k is encoded — confirm against encoding_function.
f_encoding_hX = [1 2 3 4];
f_encoding_hY = [2 1 3 4];
f_encoding_Xt = [1 2 4 3];
f_encoding_Yt = [1 2 3 4];

% Define information options
% These settings are passed to information() (and opts.nt is added later,
% per computation, inside the simulation loop). Responses are discretized
% by the script itself, hence bin_method_* = 'none'.
opts = struct( ...
    'verbose',      false, ...
    'method',       "dr", ...
    'bias',         'naive', ...
    'btsp',         0, ...
    'n_binsX',      3, ...
    'n_binsY',      3, ...
    'n_binsS',      4, ... Number of stimulus values
    'bin_method_X', 'none', ...
    'bin_method_Y', 'none', ...
    'bin_method_S', 'none');

% Permutation-test settings
% 'cond'   = shuffle X across trials sharing the same S
% 'simple' = shuffle S across all trials
shuff_types = {'cond','simple'};
n_boot = 500;      % number of samples drawn for each null distribution
prctile_plt = 99;  % percentile of the null used as significance threshold

% Initialize structures
% Observed FIT/TE values: one row per repetition, one column per SNR level.
fit_xy = nan(simReps,numel(noise_range)); TE_xy = fit_xy;
fit_yx = nan(simReps,numel(noise_range)); TE_yx = fit_yx;
% Permutation nulls for the 'simple' and 'cond' shuffles; the third
% dimension indexes the shuffle. Every field filled in the simulation loop
% is preallocated here so nothing grows dynamically inside the loop.
fitSh_xy.simple = nan(simReps,numel(noise_range),nShuff); TESh_xy.simple = fitSh_xy.simple;
fitSh_yx.simple = nan(simReps,numel(noise_range),nShuff); TESh_yx.simple = fitSh_yx.simple;
fitSh_xy.cond   = nan(simReps,numel(noise_range),nShuff); TESh_xy.cond   = fitSh_xy.cond;
fitSh_yx.cond   = nan(simReps,numel(noise_range),nShuff); TESh_yx.cond   = fitSh_yx.cond;

% Single-time-point information values, same layout as fit_xy.
info.hX = fit_xy; info.Xt = fit_xy; info.hY = fit_xy; info.Yt = fit_xy;
info.hX_Yt = fit_xy; info.hY_Xt = fit_xy;

%% Plot encoding functions (Fig.2C)
nTrials = nTrials_per_stim(1)*opts.n_binsS; % Compute number of trials
S = randi(opts.n_binsS,1,nTrials);

% One panel per signal, each drawn by encoding_function with its own
% encoding pattern (delta = 1, plotting flag on).
panelEnc    = {f_encoding_hX, f_encoding_hY, f_encoding_Xt, f_encoding_Yt};
panelTitles = {'X_{past}', 'Y_{past}', 'X_{pres}', 'Y_{pres}'};

figure('Position',[4,311,1274,279])
for panelIdx = 1:numel(panelEnc)
    subplot(1,4,panelIdx)
    encoding_function(S, panelEnc{panelIdx}, 1, 1);
    xlim([0.5 4.5])
    xlabel('Feature')
    xticklabels({'1','2','3'})
    title(panelTitles{panelIdx})
end
% Drop plotting temporaries so they do not end up in the saved workspace.
clear panelIdx panelEnc panelTitles

%% Run simulation
% For every SNR level and repetition:
%   1) draw S and simulate past (h*) and present (*t) activity of X and Y,
%   2) discretize each signal,
%   3) compute TE and FIT in both directions,
%   4) build the 'cond' and 'simple' permutation nulls,
%   5) compute single-time-point information values.
% The order of all random draws (randi/randn/randperm) is fixed so results
% are reproducible under rng(1).

% The trial count is the same for every repetition and SNR level.
nTrials = nTrials_per_stim(1)*opts.n_binsS; % Compute number of trials

for noiseIdx = 1:numel(noise_range)
    disp(['Simulation for ',num2str(SNR_range(noiseIdx)),' SNR'])
    for repIdx = 1:simReps
        %disp(['Repetition number ',num2str(repIdx)]);

        % Draw the stimulus value for each trial
        S = randi(opts.n_binsS,1,nTrials);

        % Simulate the four signals
        switch transfer_scenario
            case 'indep_encoding'
                % Y_pres encodes S with the same pattern as X_past but with
                % its own, independent noise (no actual transfer)
                hX = encoding_function(S,f_encoding_hX,delta,0) + noise_range(noiseIdx)*randn(1,nTrials);
                Yt = encoding_function(S,f_encoding_hX,delta,0) + noise_range(noiseIdx)*randn(1,nTrials);
                hY = encoding_function(S,f_encoding_hY,delta,0) + noise_range(noiseIdx)*randn(1,nTrials);
                Xt = encoding_function(S,f_encoding_Xt,delta,0) + noise_range(noiseIdx)*randn(1,nTrials);
            case 'XY_transfer'
                % Scenario used for Fig.2E --> real transmission from X to Y
                % (not from Y to X). Y_pres is an exact copy of X_past, so it
                % carries feature info in the same format as X_past.
                hX = encoding_function(S,f_encoding_hX,delta,0) + noise_range(noiseIdx)*randn(1,nTrials);
                Yt = hX;
                hY = encoding_function(S,f_encoding_hY,delta,0) + noise_range(noiseIdx)*randn(1,nTrials);
                Xt = encoding_function(S,f_encoding_Xt,delta,0) + noise_range(noiseIdx)*randn(1,nTrials);
            case 'bidir_noise'
                % As 'XY_transfer', but Y_past and X_pres share the same
                % noise realization (noiseYX)
                noiseYX = noise_range(noiseIdx)*randn(1,nTrials);
                hX = encoding_function(S,f_encoding_hX,delta,0) + noise_range(noiseIdx)*randn(1,nTrials);
                Yt = hX;
                hY = encoding_function(S,f_encoding_hY,delta,0) + noiseYX;
                Xt = encoding_function(S,f_encoding_Xt,delta,0) + noiseYX;
            otherwise
                error('Unknown transfer_scenario ''%s''.', transfer_scenario);
        end

        % Discretize neural activity using the edges returned by eqpop.
        % histc is kept on purpose: discretize assigns a value equal to the
        % last edge to a different bin than histc does, which could change
        % the bin labels.
        edgs = eqpop(hX, opts.n_binsX);
        [~,bXpast] = histc(hX, edgs);
        edgs = eqpop(Xt, opts.n_binsX);
        [~,bXt] = histc(Xt, edgs);

        edgs = eqpop(hY, opts.n_binsY);
        [~,bYpast] = histc(hY, edgs);
        edgs = eqpop(Yt, opts.n_binsY);
        [~,bYt] = histc(Yt, edgs);

        % Observed TE and FIT; arguments are passed as
        % (S, sender past, receiver present, receiver past)
        [TE_xy(repIdx,noiseIdx),~,fit_xy(repIdx,noiseIdx)]=...
            compute_FIT_TE(S,bXpast,bYt,bYpast);
        [TE_yx(repIdx,noiseIdx),~,fit_yx(repIdx,noiseIdx)]=...
            compute_FIT_TE(S,bYpast,bXt,bXpast);

        Sval = unique(S); % stimulus values present in this repetition
        for shIdx = 1:nShuff

            % Conditioned shuffle: permute X_past (and Y_past) only across
            % trials that share the same stimulus value. Preallocated so the
            % arrays never grow inside the loop; every entry is overwritten.
            XpastSh = zeros(size(bXpast));
            YpastSh = zeros(size(bYpast));
            for Ss = 1:numel(Sval)
                idx = (S == Sval(Ss));
                ridx = randperm(sum(idx));

                tmpX = bXpast(idx);
                XpastSh(1,idx) = tmpX(ridx);
                tmpY = bYpast(idx);
                YpastSh(1,idx) = tmpY(ridx);
            end

            [TESh_xy.cond(repIdx,noiseIdx,shIdx),~,fitSh_xy.cond(repIdx,noiseIdx,shIdx)]=...
                compute_FIT_TE(S,XpastSh,bYt,bYpast);
            [TESh_yx.cond(repIdx,noiseIdx,shIdx),~,fitSh_yx.cond(repIdx,noiseIdx,shIdx)]=...
                compute_FIT_TE(S,YpastSh,bXt,bXpast);

            % Simple shuffle: a single permutation of all trials. For FIT
            % it is applied to S (the neural signals stay aligned); TE is
            % computed by DI_infToolBox without S, so the same permutation
            % is applied to X_past (Y_past) instead.
            idx = randperm(nTrials);
            Ssh = S(idx);
            XpastSh = bXpast(idx);
            YpastSh = bYpast(idx);

            [~,~,fitSh_xy.simple(repIdx,noiseIdx,shIdx)]=...
                compute_FIT_TE(Ssh, bXpast, bYt, bYpast);
            [TESh_xy.simple(repIdx,noiseIdx,shIdx)]=...
                DI_infToolBox(XpastSh, bYt, bYpast, 'naive', 0);

            [~,~,fitSh_yx.simple(repIdx,noiseIdx,shIdx)]=...
                compute_FIT_TE(Ssh, bYpast, bXt, bXpast);
            [TESh_yx.simple(repIdx,noiseIdx,shIdx)]=...
                DI_infToolBox(YpastSh, bXt, bXpast, 'naive', 0);
        end

        % Compute mutual info encoded at each time point (used in previous
        % version of the code to verify that the timecourses of
        % information are overlapping): S vs each signal, plus Y_pres vs
        % X_past and X_pres vs Y_past.
        % NOTE(review): for the last two, a neural signal takes the first
        % buildr slot while opts.n_binsS is still 4 — confirm information()
        % does not depend on n_binsS for these calls.
        [M_xs, nt] = buildr(S,bXpast);
        opts.nt = nt;
        info.hX(repIdx,noiseIdx) = information(M_xs,opts,'I');
        [M_xs, nt] = buildr(S,bXt);
        opts.nt = nt;
        info.Xt(repIdx,noiseIdx) = information(M_xs,opts,'I');
        [M_ys, nt] = buildr(S,bYpast);
        opts.nt = nt;
        info.hY(repIdx,noiseIdx) = information(M_ys,opts,'I');
        [M_ys, nt] = buildr(S,bYt);
        opts.nt = nt;
        info.Yt(repIdx,noiseIdx) = information(M_ys,opts,'I');
        [M_yx, nt] = buildr(bYt,bXpast);
        opts.nt = nt;
        info.hX_Yt(repIdx,noiseIdx) = information(M_yx,opts,'I');
        [M_yx, nt] = buildr(bYpast,bXt);
        opts.nt = nt;
        info.hY_Xt(repIdx,noiseIdx) = information(M_yx,opts,'I');

    end
end

if save_results
    % Save the whole workspace into results_path. The file name is built
    % from the actual simulation parameters so it always matches the
    % saved contents (the previous hard-coded '100Reps' did not match
    % simReps).
    fname = sprintf('NIPS_content_dir_%dBins_%dReps_%s_%dShuff_seed1.mat', ...
        opts.n_binsX, simReps, transfer_scenario, nShuff);
    if ~exist(results_path, 'dir')
        mkdir(results_path);
    end
    save(fullfile(results_path, fname))
end

%% Compute null of mean FIT
% Turn the per-repetition shuffled values into bootstrapped null
% distributions (btstrp_shuff, n_boot samples), once per shuffling type.
for typeIdx = 1:numel(shuff_types)
    shName = shuff_types{typeIdx};
    pooledFIT_XY_sh.(shName) = btstrp_shuff(fitSh_xy.(shName), n_boot);
    pooledFIT_YX_sh.(shName) = btstrp_shuff(fitSh_yx.(shName), n_boot);
    pooledTE_XY_sh.(shName)  = btstrp_shuff(TESh_xy.(shName), n_boot);
    pooledTE_YX_sh.(shName)  = btstrp_shuff(TESh_yx.(shName), n_boot);
end

% Combined null: element-wise maximum of the 'simple' and 'cond' nulls,
% taken over a new 4th dimension.
maxOverShuff = @(nulls) max(cat(4, nulls.simple, nulls.cond), [], 4);
pooledFIT_XY_sh.max = maxOverShuff(pooledFIT_XY_sh);
pooledFIT_YX_sh.max = maxOverShuff(pooledFIT_YX_sh);
pooledTE_XY_sh.max  = maxOverShuff(pooledTE_XY_sh);
pooledTE_YX_sh.max  = maxOverShuff(pooledTE_YX_sh);

% Significance thresholds (percentile along dimension 2): FIT uses the
% combined null, TE uses the simple-shuffle null only.
hi_prc_FIT_xy_null = prctile(pooledFIT_XY_sh.max,   prctile_plt, 2);
hi_prc_FIT_yx_null = prctile(pooledFIT_YX_sh.max,   prctile_plt, 2);
hi_prc_TE_xy_null  = prctile(pooledTE_XY_sh.simple, prctile_plt, 2);
hi_prc_TE_yx_null  = prctile(pooledTE_YX_sh.simple, prctile_plt, 2);

clear typeIdx shName maxOverShuff

%% Plot FIT from X to Y and from Y to X as a function of SNR
% Mean FIT across repetitions (shaded +/- s.e.m.) for both directions.
% Filled markers flag SNR levels where the mean exceeds the corresponding
% hi_prc_FIT_*_null threshold.
cols = distinguishable_colors(4);
xval_range = delta./noise_range; % equals SNR_range

fitByDir  = {fit_xy, fit_yx};
nullByDir = {hi_prc_FIT_xy_null, hi_prc_FIT_yx_null};

figure('Position',[282,268,409,281])
hold on
for dirIdx = 1:numel(fitByDir)
    fitVals = fitByDir{dirIdx};
    % Average explicitly over repetitions (dim 1) so a single repetition
    % does not collapse the row into a scalar.
    meanFit = mean(fitVals,1);
    semFit = std(fitVals,[],1)/sqrt(simReps);
    h(dirIdx) = plot(xval_range,meanFit,'color',cols(dirIdx,:));
    shadedErrorBar(xval_range,meanFit,semFit,'LineProps',{'color',cols(dirIdx,:)},'patchSaturation',0.2)
    sigmask = meanFit > nullByDir{dirIdx}';
    % Markers take the colour of their curve
    scatter(xval_range(sigmask),meanFit(sigmask),12, ...
        'MarkerFaceColor',cols(dirIdx,:),'MarkerEdgeColor',cols(dirIdx,:))
end

ylabel('info')
xlabel('SNR')
ylim([0,0.17])
legend(h(1:2), 'FIT X->Y', 'FIT Y->X');
title('Transfer direct. due to content difference')

