% Main script to compute Transfer Entropy (TE) and Feature-specific Information
% Transfer (FIT) about S and C in the Wilming MEG dataset
% (i.e. Fig. 3 and Fig. S6)
%
% Note: this script has been tested on a machine running Ubuntu 18.04,
% using MATLAB R2019a

% Throughout the script we may refer to TE as 'DI' (Directed Information)
clear all; clc;

% Paths and parameter settings
params.codeVersion = 'NIPS_paper'; % label used to create the directory where results are saved
% All paths default to the current folder
paths.mainPath = pwd;
paths.externalPath = pwd;
paths.scriptsPath = pwd;
% fullfile inserts the path separator: pwd has no trailing '/', so plain
% concatenation ([paths.mainPath 'data']) would yield e.g. '/home/user/projdata'
paths.dataPath = fullfile(paths.mainPath, 'data'); % path where data are stored

addpath(genpath(paths.scriptsPath))

% ---- Information-analysis parameters ----
% Run this script twice: once with params.null_type = 'conditioned' and once
% with 'simple S'. The plotting script takes the element-wise maximum of the two.
%   'conditioned' : shuffle X at fixed S
%   'simple S'    : shuffle S across all trials
params.null_type = 'conditioned';

% ---- General settings ----
params.nSubj = 15;
params.nSess = 4;
params.nWorkers = 30;
params.all_trials = 1;            % 1 = run the information-transmission analysis on all trials
params.correrr_trials = 1;        % 1 = run it separately for correct and error trials
params.corrSubsamp = 'rand';      % subsampling scheme for correct trials
params.randSeed = 0;              % RNG seed, for reproducibility
params.freqs_lim = [40, 75];      % band limits in Hz, one band per row
params.maxDelay = 30;             % in samples: 30 x 16.7 ms = 500 ms
params.tMin = -0.6;               % analysis window start (s)
params.tMax = 0.5;                % analysis window end (s)
params.info_type = {'DI','FIT_S','FIT_C','DFI_S','DFI_C'};
params.stim_feature = 'average';  % one of: average, first_sample, running_avg

% ---- Discretization of activity and stimulus ----
opts.n_bins = 2;
opts.Sbins = 2;
opts.doShuff = 1;  % 1 = compute permutations
opts.nShuff = 10;  % permutations per subject; combined across subjects in the plot
                   % script to obtain 500 samples of the average-information null

% Load the time axis (`times`), the ROI names (`rois`) and the frequencies file
load(fullfile(paths.dataPath, 'times'))
% Nearest-sample lookup of the analysis window. Exact float equality
% (times == params.tMin) silently produced an empty window whenever the stored
% time axis differed from tMin/tMax by rounding error.
[~, tMinIdx] = min(abs(times - params.tMin));
[~, tMaxIdx] = min(abs(times - params.tMax));
assert(tMinIdx <= tMaxIdx, 'Empty analysis window: params.tMin must precede params.tMax');
params.timePoints = tMinIdx:tMaxIdx; % sample indices covering [params.tMin, params.tMax]
load(fullfile(paths.dataPath, 'roinames_pair')) % ROIs in neural data are ordered as left/right/left/right/...
load(fullfile(paths.dataPath, 'frequencies'))
params.totROIs = numel(rois);
params.totTimePoints = numel(times);

rng(params.randSeed)
%% Loop over frequency bands
% For each band: select ROIs, prepare the output folder, then compute information
% transfer in both directions for every ROI pair that has no saved result yet.
for bandIdx = 1:size(params.freqs_lim,1) % In this paper we only used the [40,75]Hz band (gamma)
    % Build the band label, e.g. 'band_40_75'
    start_freq = num2str(params.freqs_lim(bandIdx,1));
    end_freq = num2str(params.freqs_lim(bandIdx,2));
    band_name = ['band_',start_freq,'_',end_freq];
    disp(['----- Computing Information Transfer within ',band_name,' -----'])
    params.freqs_labels{bandIdx} = band_name;

    % params.fixed_rec_delay.all is not assigned anywhere in this script, so an
    % unconditional read raises "Reference to non-existent field". Copy the
    % per-band value only when it has been provided.
    if isfield(params,'fixed_rec_delay') && isfield(params.fixed_rec_delay,'all')
        params.fixed_rec_delay.(band_name) = params.fixed_rec_delay.all(bandIdx);
    else
        warning('infoTransf:noFixedRecDelay', ...
            'params.fixed_rec_delay.all is not defined; fixed_rec_delay not set for %s', band_name);
    end

    % Check existence of raw data for this frequency band
    bandDataDir = fullfile(paths.dataPath,'Wilming_bands',band_name);
    assert(exist(bandDataDir,'dir') == 7,['No frequency band defined from ',start_freq,' to ',end_freq,'Hz']);

    band_label = params.freqs_labels{bandIdx};

    %% Selected ROIs (both directions, X->Y and Y->X, are computed for every pair)
    % V1: main broadcaster of visual information; V3A: region carrying top stim
    % info in DLS visual; LO3: one of the top 2 regions carrying stim info in the
    % MT+ complex (the other one, V3CD, is not included here).
    selectedROIs.(band_label) = {'V1','V3A','LO3'};
    % Plain local copy: avoids dynamic-field indexing inside parfor below
    bandROIs = selectedROIs.(band_label);
    nSelROIs = numel(bandROIs);

    disp([num2str(nSelROIs),' ROIs selected for the computation'])

    % Map each selected ROI name to its index in the neural data (`rois`)
    rois_idx = zeros(1,nSelROIs);
    for rIdx = 1:nSelROIs
        matchIdx = find(strcmp(rois,bandROIs{rIdx}));
        assert(numel(matchIdx) == 1,['ROI ',bandROIs{rIdx},' not found (or not unique) in rois']);
        rois_idx(rIdx) = matchIdx;
    end
    roiPairs = nSelROIs*(nSelROIs-1)/2; % number of unordered ROI pairs

    % Should we allow different max delay for each pair?
    params.direction = {'X2Y','Y2X'};
    params.frequency_bands = {'h2h','l2l','h2l','l2h'}; % e.g. h2h = high to high, l2l = low to low, etc...
    % Here we select 1D to compute information transfer within each
    % hemisphere and then average
    params.neural_feature = '1D'; % 1D = average of 1D activity
    params.selected_features = {'left','right'};

    disp(['Info transfer analysis for the following neural features: ',strjoin(params.selected_features,', ')])

    % If the save path does not exist, create it (fullfile adds the separator
    % that plain concatenation with pwd was missing)
    paths.savePath = fullfile(paths.mainPath,'Results',params.codeVersion,band_label,params.neural_feature);
    if ~exist(paths.savePath,'dir')
        mkdir(paths.savePath)
    end

    % We take the full selected time window to compute TE and FIT
    for roiIdx = 1:nSelROIs
        roiLab = matlab.lang.makeValidName(bandROIs{roiIdx});
        sigTime.(band_label).(roiLab) = params.timePoints;
    end

    %% Start (or enlarge) the parallel pool
    if params.nWorkers > 1
        pool = gcp('nocreate');
        if isempty(pool)
            parpool(params.nWorkers);
        elseif pool.NumWorkers < params.nWorkers
            delete(pool)
            parpool(params.nWorkers);
        end
    end

    % Mark directed pairs already saved in savePath (to avoid recomputing them).
    % Files lacking roiX/roiY are skipped instead of aborting the scan.
    computedPairs.(band_label) = zeros(params.totROIs,params.totROIs);
    matFiles = dir(fullfile(paths.savePath,'*.mat'));
    for k = 1:numel(matFiles)
        tmpROIs = load(fullfile(paths.savePath,matFiles(k).name),'roiX','roiY');
        if isfield(tmpROIs,'roiX') && isfield(tmpROIs,'roiY')
            computedPairs.(band_label)(tmpROIs.roiX,tmpROIs.roiY) = 1;
        end
    end
    bandComputed = computedPairs.(band_label); % broadcast copy for parfor

    % NOTE(review): rng(params.randSeed) on the client does not seed the worker
    % streams, so any randomness inside compute_infoTransf_pair is presumably not
    % reproducible across runs -- confirm and, if needed, seed per pair on workers.

    % Loop over ROI pairs --> each call produces the results for one directed pair
    parfor pairIdx = 1:roiPairs
        [tmproiX,tmproiY] = provide_pair(nSelROIs,pairIdx);

        Xlab = bandROIs{tmproiX};
        Ylab = bandROIs{tmproiY};
        % Indices of the pair in the neural data
        roiX = rois_idx(tmproiX);
        roiY = rois_idx(tmproiY);

        % Compute X to Y
        if ~bandComputed(roiX,roiY)
            compute_infoTransf_pair(Xlab,Ylab,roiX,tmproiX,roiY,tmproiY,sigTime,opts,params,paths,band_label)
        end

        % Compute Y to X
        if ~bandComputed(roiY,roiX)
            compute_infoTransf_pair(Ylab,Xlab,roiY,tmproiY,roiX,tmproiX,sigTime,opts,params,paths,band_label)
        end
    end
end