% Script to compute \Delta TE (i.e. the difference between the TE values computed
% at fixed stimulus values), shown in Fig. S11B.


% Paths and parameters settings

clear all; clc;

params.codeVersion = 'NIPS_paper';
% All paths are rooted at the current working directory.
paths.mainPath = pwd; paths.externalPath = pwd; paths.scriptsPath = pwd;
% fullfile inserts the file separator. pwd has no trailing separator, so plain
% concatenation ([pwd 'data']) would yield e.g. '/home/me/projdata' instead of
% '/home/me/proj/data'.
paths.dataPath = fullfile(paths.externalPath, 'data'); % path where data are stored

% Make every function below the scripts folder callable from this script.
addpath(genpath(paths.scriptsPath))

% parameters used to discretize the activity and stimulus
opts.n_bins = 2;
opts.Sbins = 2;

opts.doShuff = 1;
params.null_type = 'conditioned'; % Shufflings not used in this analysis
% NOTE(review): doShuff is set to 1 although the comment above says shufflings
% are not used in this analysis - confirm which one is intended.
opts.nShuff = 1;

% General settings
params.nSubj = 15;
params.nSess = 4;
params.nWorkers = 30;
params.all_trials = 1; % set to 1 to run the information transmission analysis using all trials
params.correrr_trials = 0; % set to 1 to run information transmission analysis separately for correct and error trials
params.corrSubsamp = 'rand'; % type of correct trials subsampling, either 'mid' or 'rand'
params.randSeed = 0;
%params.freqs_lim = [1, 7;8, 10;15, 35;8, 35;40, 75];
params.freqs_lim = [40, 75]; %Hz
params.significant_comm_windows = 0; % choose to only consider time-delay pairs with significant stim/choice info
params.maxDelay = 30; % x16.7ms = 500ms (maxDelay)
params.tMin = -0.6; % seconds
params.tMax = 0.5; % seconds
params.info_type = {'DI','FIT_S','FIT_C'};
params.stim_feature = 'average'; % average, first_sample, running_avg

load(fullfile(paths.dataPath, 'times'))
% Locate the analysis window [tMin, tMax] on the sample grid of 'times'.
% A small tolerance replaces exact '==': stored time stamps can differ from the
% literals above by floating-point rounding. In that case find() returns [] and
% the colon range silently becomes empty.
timeTol = 1e-6; % seconds
[dMin, iMin] = min(abs(times - params.tMin));
[dMax, iMax] = min(abs(times - params.tMax));
if isempty(times) || dMin > timeTol || dMax > timeTol
    error('tMin (%g s) and tMax (%g s) must both be sample times contained in ''times''.', ...
        params.tMin, params.tMax);
end
params.timePoints = iMin:iMax; % sample indices spanning [tMin, tMax]
clear timeTol dMin dMax iMin iMax
load(fullfile(paths.dataPath, 'roinames_pair')) % ROIs in neural data are ordered as left/right/left/right/...
load(fullfile(paths.dataPath, 'frequencies'))
params.totROIs = numel(rois);
params.totTimePoints = numel(times);
rng(params.randSeed)
%% Looping over frequency bands
% For each band: select ROIs, prepare the output folder, skip ROI pairs already
% saved there, then run compute_deltaDI_pair for both directions of every
% remaining pair in parallel.
for bandIdx = 1 % Here we only run frequency band 1 = gamma
    % Create frequency bands names
    start_freq = num2str(params.freqs_lim(bandIdx,1));
    end_freq = num2str(params.freqs_lim(bandIdx,2));
    band_name = ['band_',start_freq,'_',end_freq];
    disp(['----- Computing Information Transfer within ',band_name,' -----'])
    params.freqs_labels{bandIdx} = band_name;

    % params.fixed_rec_delay is never assigned in this script, and 'clear all'
    % at the top removes any earlier value, so an unconditional read errors.
    % Copy the per-band delay only when it has been configured.
    % NOTE(review): if compute_deltaDI_pair relies on params.fixed_rec_delay,
    % define params.fixed_rec_delay.all (one entry per band) in the settings.
    if isfield(params,'fixed_rec_delay') && isfield(params.fixed_rec_delay,'all')
        params.fixed_rec_delay.(band_name) = params.fixed_rec_delay.all(bandIdx);
    else
        warning('params.fixed_rec_delay.all is not defined; no fixed delay set for %s.', band_name);
    end

    % Check existence of raw data for this frequency band
    bandDir = fullfile(paths.dataPath,'Wilming_bands',band_name);
    assert(exist(bandDir,'dir') == 7,['No frequency band defined from ',start_freq,' to ',end_freq,'Hz']);

    band_label = params.freqs_labels{bandIdx};

    %% Selected ROI pairs (symm. fwd vs bwd) and time-delay windows

    selectedROIs.(band_label) = {'V1','V3A','LO3'}; % areas selected for the computation
    %selectedROIs.(band_label) = [selectedROIs.(band_label),'V1','V2','V3','V3A','V6A']; % adding visual ROIs based on TF masked union alpha + beta
    selLabs = selectedROIs.(band_label);
    nSel = numel(selLabs);
    disp([num2str(nSel),' ROIs selected for the computation'])

    % Map each selected ROI name to its position in the full ROI list 'rois'.
    % Preallocated per band so no stale entries survive from a previous band.
    rois_idx = zeros(1, nSel);
    for rIdx = 1:nSel
        hit = find(strcmp(rois, selLabs{rIdx}), 1);
        if isempty(hit)
            error('ROI ''%s'' not found in roinames_pair.', selLabs{rIdx});
        end
        rois_idx(rIdx) = hit;
    end
    roiPairs = nSel*(nSel-1)/2; % number of unordered ROI pairs

    % Should we allow different max delay for each pair?
    params.direction = {'X2Y','Y2X'};
    params.frequency_bands = {'h2h','l2l','h2l','l2h'}; % e.g. h2h = high to high, l2l = low to low, etc...
    % Here we select 1D to compute information transfer within each
    % hemisphere and then average
    params.neural_feature = '1D'; % 1D = average of 1D activity
    params.selected_features = {'left','right'};

    disp(['Info transfer analysis for the following neural features: ',strjoin(params.selected_features,', ')])

    % fullfile adds the separator that plain concatenation with pwd omits.
    paths.savePath = fullfile(paths.mainPath,'Results',params.codeVersion,band_label,params.neural_feature);
    if ~exist(paths.savePath, 'dir')
       mkdir(paths.savePath)
    end

    % We take the full selected time window to compute TE and FIT
    for roiIdx = 1:nSel
        roiLab = matlab.lang.makeValidName(selLabs{roiIdx});
        sigTime.(band_label).(roiLab) = params.timePoints;
    end

%%
    if (params.nWorkers > 1)
        if (isempty(gcp('nocreate')))
            parpool(params.nWorkers);
        else
            poolTemp = gcp;
            if(poolTemp.NumWorkers < params.nWorkers)
                delete(gcp)
                parpool(params.nWorkers);
            end
        end
    end

    % Load computed pairs identities (to avoid running computation twice
    % for same pairs). Files without roiX/roiY are skipped instead of crashing.
    computedPairs.(band_label) = zeros(params.totROIs,params.totROIs);
    matFiles = dir(fullfile(paths.savePath,'*.mat'));
    for k = 1:numel(matFiles)
        matPath = fullfile(paths.savePath,matFiles(k).name);
        tmpROIs = load(matPath,'roiX','roiY');
        if ~isfield(tmpROIs,'roiX') || ~isfield(tmpROIs,'roiY')
            warning('Skipping %s: it does not contain roiX/roiY.', matPath);
            continue
        end
        computedPairs.(band_label)(tmpROIs.roiX,tmpROIs.roiY) = 1;
    end

    % Plain local copies so parfor broadcasts small arrays rather than
    % structs indexed through a dynamic field name.
    doneMat = computedPairs.(band_label);

    % First loop over roi Pairs --> produce mat file for each ROI pair
    parfor pairIdx = 1:roiPairs
        % presumably maps the linear pair index to two positions in selLabs
        % (provide_pair is defined elsewhere) - verify
        [tmproiX,tmproiY] = provide_pair(nSel,pairIdx);

        Xlab = selLabs{tmproiX};
        Ylab = selLabs{tmproiY};
        % Get indexes for the pair of selected rois
        roiX = rois_idx(tmproiX);
        roiY = rois_idx(tmproiY);

        % Compute X to Y
        if ~doneMat(roiX,roiY)
            compute_deltaDI_pair(Xlab,Ylab,roiX,tmproiX,roiY,tmproiY,sigTime,opts,params,paths,band_label)
        end

        % Compute Y to X
        if ~doneMat(roiY,roiX)
            compute_deltaDI_pair(Ylab,Xlab,roiY,tmproiY,roiX,tmproiX,sigTime,opts,params,paths,band_label)
        end
    end
end