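% FIT multi-feature transfer simulation
% Model: X = S1 + D*S2 + Ex, i.e. node X encodes stimulus feature S1 with weight 1
% and feature S2 with relative weight D, plus Gaussian noise Ex; Y receives X plus noise.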

clear all; %close all;

rng(1) % For reproducibility: same stimuli, noise and permutations on every run
save_results =  0;% set to 1 to save results file
results_path =  ''; % path to save results
plot_path =  ''; % path to save plots

% Create the plot directory only for a non-empty path: exist('','dir') is false,
% so without the isempty guard mkdir would be called with an empty name.
if ~isempty(plot_path) && ~exist(plot_path,'dir')
    mkdir(plot_path);
end

% Simulation parameters
nTrials_per_stim = 500; % number of trials per stimulus value
simReps = 50; % repetitions of the simulation
nShuff = 10; % number of permutations (used for both FIT permutation tests)

delta_range = 0:0.1:1; % range of D parameter (relative S2 encoding strength in node X)
noise = 1; % standard deviation of the additive Gaussian noise (scales randn)
Xencoding = [-1 1]; % encoding passed to encoding_function for S1 and S2 (presumably one value per stimulus level)

% Define information options
opts = [];
opts.verbose = false;
opts.method = "dr";
opts.bias = 'naive';
opts.btsp = 0;
opts.n_binsX = 3; % number of bins used to discretize X
opts.n_binsY = 3; % number of bins used to discretize Yt and hY
opts.n_binsS = 2; % Number of stimulus values

% NOTE(review): the four settings below appear unused by the simulation and
% plotting code in this script.
shuff_types = {'cond','simple'}; % 'cond' in the shuffling of X at fixed S, 'simple' is the shuffling of S across all trials
null2plot = 'max'; % null hypothesis to use (maximum between 'simple' and 'cond')
n_boot = 500; % number of sample of the null distribution
prctile_plt = 99; % percentile used to determine significance

% Initialize structures: rows = simulation repetitions, columns = delta values
fitS1 = nan(simReps,numel(delta_range));
fitS2 = nan(simReps,numel(delta_range));
fitSall = nan(simReps,numel(delta_range));

% Null distributions: third dimension indexes the permutation
fitS1Sh.simple = nan(simReps,numel(delta_range),nShuff);
fitS1Sh.cond = nan(simReps,numel(delta_range),nShuff);
fitS2Sh.simple = nan(simReps,numel(delta_range),nShuff);
fitS2Sh.cond = nan(simReps,numel(delta_range),nShuff);
diS1 = fitS1; diS1Sh = fitS1Sh; diS2 = fitS1; diS2Sh = fitS1Sh;  % defining di for S1 and S2 just to check that they are equal
diSall = fitS1; % DI for the joint stimulus; preallocated so it does not grow inside the loop
%% Run simulation
% For every repetition and every delta value: simulate S1, S2, X, Yt, hY,
% discretize them, compute DI (TE) and FIT for S1, S2 and the joint stimulus,
% then build null distributions with two permutation schemes:
%   'cond'   - shuffle X across trials that share the same stimulus value
%   'simple' - FIT: shuffle S across all trials; DI: shuffle X across all trials

nTrials = nTrials_per_stim*opts.n_binsS; % Compute number of trials (loop-invariant)

for repIdx = 1:simReps
    disp(['Repetition ',num2str(repIdx)])

    for dIdx = 1:numel(delta_range)
        % simulating two independent stimulus-features S1 and S2
        S1 = randi(opts.n_binsS,[1,nTrials]);
        S2 = randi(opts.n_binsS,[1,nTrials]);

        % X encodes S1 with weight 1 and S2 with weight delta, plus noise;
        % Yt is a noisy copy of X; hY is independent noise passed to
        % compute_FIT_TE as its 4th argument (presumably the past of Y).
        X = encoding_function(S1, Xencoding, 1, 0)+encoding_function(S2, Xencoding, delta_range(dIdx), 0)+noise*randn(1,nTrials);
        Yt = X+noise*randn(1,nTrials);
        hY = noise*randn(1,nTrials);

        % Discretize each signal with edges from eqpop (presumably equal-population bins)
        edgs = eqpop(X, opts.n_binsX);
        [~,bX] = histc(X, edgs);

        edgs = eqpop(Yt, opts.n_binsY);
        [~,bYt] = histc(Yt, edgs);

        edgs = eqpop(hY, opts.n_binsY);
        [~,bhY] = histc(hY, edgs);

        % Joint stimulus: one label per (S1,S2) combination, values 1..n_binsS^2
        Sall = (S1 - 1) .* opts.n_binsS + S2;

        % compute TE and FIT
        [diS1(repIdx,dIdx),~,fitS1(repIdx,dIdx)]=...
            compute_FIT_TE(S1, bX, bYt, bhY);
        [diS2(repIdx,dIdx),~,fitS2(repIdx,dIdx)]=...
            compute_FIT_TE(S2, bX, bYt, bhY);
        [diSall(repIdx,dIdx),~,fitSall(repIdx,dIdx)]=...
            compute_FIT_TE(Sall, bX, bYt, bhY);

        % loop over shufflings
        for shIdx = 1:nShuff

            % ---- S1 ----
            % conditioned shuff (shuffle X at fixed S1) -> stored in 'cond'
            XSh = bX; % start from the unshuffled X so no stale entries survive
            S1val = unique(S1);
            for Ss = 1:numel(S1val)
                idx = (S1 == S1val(Ss));
                tmpX = bX(idx);
                ridx = randperm(sum(idx));
                XSh(1,idx) = tmpX(ridx);
            end

            [diS1Sh.cond(repIdx,dIdx,shIdx),~,fitS1Sh.cond(repIdx,dIdx,shIdx)]=...
                compute_FIT_TE(S1, XSh, bYt, bhY);

            % simple shuff -> stored in 'simple':
            % FIT null shuffles S1 across all trials, DI null shuffles X across all trials
            idx = randperm(nTrials);
            S1sh = S1(idx);
            XSh = bX(idx);

            [~,~,fitS1Sh.simple(repIdx,dIdx,shIdx)]=...
                compute_FIT_TE(S1sh, bX, bYt, bhY);
            diS1Sh.simple(repIdx,dIdx,shIdx)=...
                DI_infToolBox(XSh, bYt, bhY, 'naive', 0);

            % ---- S2 ----
            % conditioned shuff (shuffle X at fixed S2) -> stored in 'cond'
            XSh = bX;
            S2val = unique(S2);
            for Ss = 1:numel(S2val)
                idx = (S2 == S2val(Ss));
                tmpX = bX(idx);
                ridx = randperm(sum(idx));
                XSh(1,idx) = tmpX(ridx);
            end

            [diS2Sh.cond(repIdx,dIdx,shIdx),~,fitS2Sh.cond(repIdx,dIdx,shIdx)]=...
                compute_FIT_TE(S2, XSh, bYt, bhY);

            % simple shuff -> stored in 'simple':
            % FIT null shuffles S2 across all trials, DI null shuffles X across all trials
            idx = randperm(nTrials);
            S2sh = S2(idx);
            XSh = bX(idx);

            [~,~,fitS2Sh.simple(repIdx,dIdx,shIdx)]=...
                compute_FIT_TE(S2sh, bX, bYt, bhY);
            diS2Sh.simple(repIdx,dIdx,shIdx)=...
                DI_infToolBox(XSh, bYt, bhY, 'naive', 0);
        end
    end
end

if  save_results
    % Save the whole workspace into results_path ('' = current folder),
    % with the current date in the file name.
    if ~isempty(results_path) && ~exist(results_path,'dir')
        mkdir(results_path);
    end
    fname = ['multiFeat_transfer_' date '.mat'];
    save(fullfile(results_path, fname))
end

    
%% Plot FIT and TE trends as a function of the relative S2 encoding strength D
% Each curve is the mean across simulation repetitions (rows); the shaded band
% is the standard error of the mean across repetitions.

semAcrossReps = @(M) std(M,[],1)/sqrt(size(M,1)); % SEM over repetitions (rows)

figure('Position',[360,198,245,420])
maxY = 0.5; % shared upper y-limit for both panels
clear h

% Top panel: FIT about S1, S2 and the joint stimulus
subplot(2,1,1)
hold on
fitCurves = {fitS1, fitS2, fitSall};
fitColors = {'b','r','g'};
for k = 1:numel(fitCurves)
    curve = fitCurves{k};
    h(k) = plot(delta_range,mean(curve,1),'color',fitColors{k});
    shadedErrorBar(delta_range,mean(curve,1),semAcrossReps(curve),'LineProps',{'color',fitColors{k}},'patchSaturation',0.2)
end
xlabel('delta S2')
ylabel('[bits]')
ylim([0,maxY])
title('FIT')
legend([h(1) h(2) h(3)], {'FIT_{S1}','FIT_{S2}','FIT_{S}'})

% Bottom panel: TE (DI) computed alongside FIT for S1
subplot(2,1,2)
hold on
plot(delta_range,mean(diS1,1),'color','b')
shadedErrorBar(delta_range,mean(diS1,1),semAcrossReps(diS1),'LineProps',{'color','b'},'patchSaturation',0.2)
xlabel('delta S2')
ylabel('[bits]')
ylim([0,maxY])
title('TE')
