% Compute directed graphs from information transmission measures
% Fixed bug on emitter frequency

function [connStrenghts,sessConnStrenghts,subjConnStrengths,staticNetworks,sessStaticNetworks,staticNetworks_vis,sessConnStrenghtsSh, pooledTimeDelayMaps, pooledTimeDelayMapsSh, sessTimeDelayMaps, sessTimeDelayMapsSh] = ...
    compute_connectivity_visNet_v2(infoTransf_maps,infoTransfSh_maps,sigTimes,computedROIs,sel_time_window,times,computedPairs,computedROIsGroups,rois,params,doShuff)
% COMPUTE_CONNECTIVITY_VISNET_V2 Build directed ROI networks from
% information-transfer time-delay maps (v1 + left/right hemisphere).
%
% For each measure in params.info_type and the band params.band_label, every
% computed ROI pair (emitter X -> receiver Y) is reduced to one connection
% strength per subject/session:
%   1. keep the receiver's significant time points inside sel_time_window;
%   2. zero the delays whose emission would precede the emitter's first
%      significant time point inside the window;
%   3. pool over delays (params.delay_type) and then over time
%      (params.info_time_pooling).
%
% Inputs
%   infoTransf_maps    infoTransf_maps.(pair).(band).(feature).(info) is a
%                      4-D array indexed here as (subject, session, receiver
%                      time position, delay). The time position indexes into
%                      the receiver's entry of sigTimes. Feature fields read:
%                      left/right ('1D'), joint ('2D'), optim ('optim').
%   infoTransfSh_maps  same layout plus a 5th (shuffle) dimension; only read
%                      when doShuff is true.
%   sigTimes           significant time indices (into times) per ROI, keyed
%                      by matlab.lang.makeValidName(roi). Assumed sorted
%                      ascending (first/last elements are used as window
%                      bounds) - TODO confirm with caller.
%   computedROIs       cell array of ROI labels; defines the network nodes.
%   sel_time_window    [start end]; both values must be elements of times.
%   times              time axis.
%   computedPairs      entries == 1 mark computed pairs; row = emitter ROI,
%                      column = receiver ROI (indices into rois).
%   computedROIsGroups group id per computed ROI; group 1 is the visual group.
%   rois               cell array of ROI labels indexed by computedPairs.
%   params             band_label, info_type (cell), nSubj, nSess, nShuff,
%                      minSelDelay, maxSelDelay, neural_feature
%                      ('1D'|'2D'|'optim'), delay_type ('mean'|'maximum'),
%                      info_time_pooling ('mean'|'sum'|'maximum'),
%                      zeros_with_nan.
%   doShuff            logical; also process the shuffled maps.
%
% Outputs (structs indexed .(info).(band) unless stated otherwise)
%   connStrenghts          nROI x nROI. NaN-omitting mean over all
%                          subject-sessions. NaN where the pair had no usable
%                          points in the window.
%   sessConnStrenghts      (nSubj*nSess) x nROI x nROI;
%                          row = (subj-1)*nSess + sess.
%   subjConnStrengths      nSubj x nROI x nROI. NaN-omitting mean over sessions.
%   staticNetworks         digraph of connStrenghts with NaN edges removed.
%   sessStaticNetworks     cell indexed by subject, holding per-session
%                          digraphs (see NOTE below).
%   staticNetworks_vis     subgraph of staticNetworks on the visual ROIs.
%   sessConnStrenghtsSh    as sessConnStrenghts plus a trailing shuffle dim.
%   pooledTimeDelayMaps    nROI x nROI x nWindowTimes x nDelays. Mean over
%                          subjects and sessions. The time axis spans
%                          sel_time_window; rows without receiver data stay 0.
%   pooledTimeDelayMapsSh  as above plus a trailing shuffle dim.
%   sessTimeDelayMaps      (nSubj*nSess) x nROI x nROI x nWindowTimes x
%                          nDelays, after delay masking.
%   sessTimeDelayMapsSh    as above plus a trailing shuffle dim.

% Validate options up front: an unknown value would otherwise silently leave
% zeros, or reuse the signal selected for a previous pair.
if ~any(strcmp(params.neural_feature, {'1D','2D','optim'}))
    error('compute_connectivity_visNet_v2:badOption', ...
        'Unknown params.neural_feature: %s', char(params.neural_feature));
end
if ~any(strcmp(params.delay_type, {'mean','maximum'}))
    error('compute_connectivity_visNet_v2:badOption', ...
        'Unknown params.delay_type: %s', char(params.delay_type));
end
% Time pooling always runs along dim 1 of a column (or nT x nShuff) input.
switch params.info_time_pooling
    case 'mean'
        poolTime = @(x) mean(x, 1, 'omitnan');
    case 'sum'
        poolTime = @(x) sum(x, 1, 'omitnan');
    case 'maximum'
        poolTime = @(x) max(x, [], 1, 'omitnan');
    otherwise
        error('compute_connectivity_visNet_v2:badOption', ...
            'Unknown params.info_time_pooling: %s', char(params.info_time_pooling));
end

% Selected window as indices into times (exact match required)
startIdx = find(times == sel_time_window(1), 1);
endIdx = find(times == sel_time_window(2), 1);
if isempty(startIdx) || isempty(endIdx)
    error('compute_connectivity_visNet_v2:badWindow', ...
        'sel_time_window bounds must be exact elements of times.');
end
time_idxs = [startIdx, endIdx];
windowIdxs = time_idxs(1):time_idxs(2);
nWindowTimes = numel(windowIdxs);

delayIdxs = params.minSelDelay:params.maxSelDelay;
nDelays = numel(delayIdxs);
nSubjSess = params.nSubj*params.nSess;
nROIs = numel(computedROIs);

computedIdxs = find(computedPairs==1);
% NOTE(review): assumes infoTransf_maps has one field per entry of computedIdxs
ROIpairs = length(fieldnames(infoTransf_maps));

vis_idxs = (computedROIsGroups==1); % ROIs belonging to the visual group
freqLab = params.band_label;

% Make every output defined even when params.info_type is empty
connStrenghts = []; sessConnStrenghts = []; subjConnStrengths = [];
staticNetworks = []; sessStaticNetworks = []; staticNetworks_vis = [];
sessConnStrenghtsSh = []; pooledTimeDelayMaps = []; pooledTimeDelayMapsSh = [];
sessTimeDelayMaps = []; sessTimeDelayMapsSh = [];

for infoIdx = 1:numel(params.info_type)
    infoLab = params.info_type{infoIdx};

    connStrenghts.(infoLab).(freqLab) = zeros(nROIs,nROIs);
    sessConnStrenghts.(infoLab).(freqLab) = zeros(nSubjSess,nROIs,nROIs);
    sessConnStrenghtsSh.(infoLab).(freqLab) = zeros(nSubjSess,nROIs,nROIs,params.nShuff);
    subjConnStrengths.(infoLab).(freqLab) = zeros(params.nSubj,nROIs,nROIs);

    pooledTimeDelayMaps.(infoLab).(freqLab) = zeros(nROIs,nROIs,nWindowTimes,nDelays);
    pooledTimeDelayMapsSh.(infoLab).(freqLab) = zeros(nROIs,nROIs,nWindowTimes,nDelays,params.nShuff);
    sessTimeDelayMaps.(infoLab).(freqLab) = zeros(nSubjSess,nROIs,nROIs,nWindowTimes,nDelays);
    sessTimeDelayMapsSh.(infoLab).(freqLab) = zeros(nSubjSess,nROIs,nROIs,nWindowTimes,nDelays,params.nShuff);

    for pairIdx = 1:ROIpairs
        % Emitter X is the row, receiver Y the column of computedPairs
        [tmproiX, tmproiY] = ind2sub(size(computedPairs), computedIdxs(pairIdx));
        Xlab = rois{tmproiX};
        Ylab = rois{tmproiY};

        % Node indices relative to computedROIs
        relatXIdx = find(strcmp(computedROIs,Xlab));
        relatYIdx = find(strcmp(computedROIs,Ylab));

        pairLab = matlab.lang.makeValidName(['pair_',Xlab,'_',Ylab]);

        receiverTpoints = sigTimes.(matlab.lang.makeValidName(Ylab));
        emitterTpoints = sigTimes.(matlab.lang.makeValidName(Xlab));

        % Usable pair: receiver and emitter both have significant points in
        % the window, and emission starts before the last reception
        hasData = ~isempty(intersect(receiverTpoints,windowIdxs)) && ...
            ~isempty(intersect(emitterTpoints,windowIdxs)) && ...
            emitterTpoints(1) < receiverTpoints(end);

        if hasData
            % Time is always w.r.t. the receiver. These are positions within
            % receiverTpoints / emitterTpoints of the first and last
            % significant points inside the window (inclusive bounds).
            time_window = [find(receiverTpoints >= time_idxs(1), 1, 'first'), ...
                           find(receiverTpoints <= time_idxs(2), 1, 'last')];
            startEmitter = find(emitterTpoints >= time_idxs(1), 1, 'first');

            tSel = time_window(1):time_window(2);
            nT = numel(tSel);
            % Rows of the window-sized output time axis that these points occupy
            timeRows = receiverTpoints(tSel) - time_idxs(1) + 1;
            % Offset used to mask delays so that only information emitted at or
            % after the emitter's first in-window point is kept
            tmpMinDelay = receiverTpoints(time_window(1)) - emitterTpoints(startEmitter);

            freqMaps = infoTransf_maps.(pairLab).(freqLab);
            featNames = fieldnames(freqMaps);
            % Some features (e.g. corr and err) may miss a given measure
            featHasInfo = cellfun(@(f) isfield(freqMaps.(f), infoLab), featNames);
            for featLab = featNames(featHasInfo)'
                assert(ndims(freqMaps.(featLab{1}).(infoLab))==4, 'InfoTransfMaps have wrong number of dimansions')
            end

            if any(featHasInfo)
                switch params.neural_feature
                    case '1D'
                        hemAvgSignal = 0.5*(freqMaps.left.(infoLab) + freqMaps.right.(infoLab));
                    case '2D'
                        hemAvgSignal = freqMaps.joint.(infoLab);
                    case 'optim'
                        hemAvgSignal = freqMaps.optim.(infoLab);
                end
                % No squeeze: keep (subj, sess, time, delay) regardless of
                % which dimensions happen to be singleton
                selMap = hemAvgSignal(:,:,tSel,delayIdxs);
                pooledTimeDelayMaps.(infoLab).(freqLab)(relatXIdx,relatYIdx,timeRows,:) = mean(mean(selMap,1),2); % mean across subjects and sessions

                if doShuff
                    freqMapsSh = infoTransfSh_maps.(pairLab).(freqLab);
                    switch params.neural_feature
                        case '1D'
                            hemAvgSignalSh = 0.5*(freqMapsSh.left.(infoLab)(:,:,:,:,1:params.nShuff) + freqMapsSh.right.(infoLab)(:,:,:,:,1:params.nShuff));
                        case '2D'
                            hemAvgSignalSh = freqMapsSh.joint.(infoLab)(:,:,:,:,1:params.nShuff);
                        case 'optim'
                            hemAvgSignalSh = freqMapsSh.optim.(infoLab)(:,:,:,:,1:params.nShuff);
                    end
                    selMapSh = hemAvgSignalSh(:,:,tSel,delayIdxs,:);
                    pooledTimeDelayMapsSh.(infoLab).(freqLab)(relatXIdx,relatYIdx,timeRows,:,:) = mean(mean(selMapSh,1),2); % mean across subjects and sessions
                end

                for subjIdx = 1:params.nSubj
                    for sessIdx = 1:params.nSess
                        subSessIdx = (subjIdx-1)*params.nSess+sessIdx;

                        % Explicit (time x delay [x shuffle]) shapes
                        sessMap = reshape(selMap(subjIdx,sessIdx,:,:), nT, nDelays);
                        if doShuff
                            sessMapSh = reshape(selMapSh(subjIdx,sessIdx,:,:,:), nT, nDelays, []);
                        end

                        % Keep only time-delay points where both the emitted
                        % and the received information lie in the window
                        for tmpT = 1:nT
                            firstMasked = tmpMinDelay + tmpT;
                            % can be < 1 if the receiver carried information before the emitter
                            if firstMasked >= 1
                                sessMap(tmpT,firstMasked:end) = 0;
                                if doShuff
                                    sessMapSh(tmpT,firstMasked:end,:) = 0;
                                end
                            end
                        end

                        sessTimeDelayMaps.(infoLab).(freqLab)(subSessIdx,relatXIdx,relatYIdx,timeRows,:) = sessMap;
                        if doShuff
                            sessTimeDelayMapsSh.(infoLab).(freqLab)(subSessIdx,relatXIdx,relatYIdx,timeRows,:,:) = sessMapSh;
                        end

                        % Reduce the time-delay map: delays first, then time
                        if strcmp(params.delay_type,'mean')
                            % Delay averaging is delegated to avgAcrossDelays_v2 (defined elsewhere)
                            delayAvg = avgAcrossDelays_v2(sessMap,tmpMinDelay);
                            sessVal = poolTime(delayAvg(:));
                            if doShuff
                                for shIdx = 1:params.nShuff
                                    delayAvgSh = avgAcrossDelays_v2(sessMapSh(:,:,shIdx),tmpMinDelay);
                                    sessConnStrenghtsSh.(infoLab).(freqLab)(subSessIdx,relatXIdx,relatYIdx,shIdx) = poolTime(delayAvgSh(:));
                                end
                            end
                        else % 'maximum': best delay at each receiver time point
                            sessVal = poolTime(max(sessMap,[],2));
                            if doShuff
                                sessConnStrenghtsSh.(infoLab).(freqLab)(subSessIdx,relatXIdx,relatYIdx,:) = poolTime(reshape(max(sessMapSh,[],2), nT, []));
                            end
                        end

                        if params.zeros_with_nan && sessVal == 0
                            sessVal = nan;
                        end
                        sessConnStrenghts.(infoLab).(freqLab)(subSessIdx,relatXIdx,relatYIdx) = sessVal;
                    end
                    sessRows = (subjIdx-1)*params.nSess + (1:params.nSess);
                    subjConnStrengths.(infoLab).(freqLab)(subjIdx,relatXIdx,relatYIdx) = mean(sessConnStrenghts.(infoLab).(freqLab)(sessRows,relatXIdx,relatYIdx),1,'omitnan');
                end
                % Pooled-subjects connectivity strength
                connStrenghts.(infoLab).(freqLab)(relatXIdx,relatYIdx) = mean(sessConnStrenghts.(infoLab).(freqLab)(:,relatXIdx,relatYIdx),1,'omitnan');
            end
        else
            % No usable data for this pair in the selected window: mark as NaN
            sessConnStrenghts.(infoLab).(freqLab)(:,relatXIdx,relatYIdx) = nan;
            if doShuff
                sessConnStrenghtsSh.(infoLab).(freqLab)(:,relatXIdx,relatYIdx,:) = nan;
            end
            subjConnStrengths.(infoLab).(freqLab)(:,relatXIdx,relatYIdx) = nan;
            connStrenghts.(infoLab).(freqLab)(relatXIdx,relatYIdx) = nan;
        end
    end

    % Remove NaN links (no suitable time-delay pair for information transfer).
    % Whether these should instead be zero is still an open question: either
    % choice may bias downstream statistics.
    pooledNet = digraph(connStrenghts.(infoLab).(freqLab),computedROIs);
    staticNetworks.(infoLab).(freqLab) = rmedge(pooledNet,find(isnan(pooledNet.Edges.Weight)));

    for subjIdx = 1:params.nSubj
        for sessIdx = 1:params.nSess
            subSessIdx = (subjIdx-1)*params.nSess+sessIdx;
            sessNet = digraph(squeeze(sessConnStrenghts.(infoLab).(freqLab)(subSessIdx,:,:)),computedROIs);
            % NOTE(review): indexed by subject only, so each session overwrites
            % the previous one and only the last session per subject is kept.
            % The indexing is left unchanged for caller compatibility.
            sessStaticNetworks{subjIdx}.(infoLab).(freqLab) = rmedge(sessNet,find(isnan(sessNet.Edges.Weight)));
        end
    end

    % Visual sub-network
    staticNetworks_vis.(infoLab).(freqLab) = subgraph(staticNetworks.(infoLab).(freqLab),vis_idxs);
end

