%% History-Independent Analysis of Higher-Order Coordination
clear all;
clc;

%% Hyperparameters
alpha = 1e-3;  % Learning rate of the gradient-ascent updates
sigma = 0.001; % Size (significance level) of the test
Nthr = 1;      % A pattern is kept only if it occurs more than Nthr times

W = 10;        % Window length, in time bins
beta = 0.975;  % Forgetting factor of the exponentially weighted estimates

% This loads a saved simulated ensemble. The rest of the script expects it to
% provide the spike matrix n: one row per neuron, one column per time bin.
load('SimData.mat');
T = size(n,2);
L = size(n,1);
L_star = 2^L - 1;  % number of non-empty subsets (spiking patterns) of L neurons

% Number of complete windows. T/W is non-integer when W does not divide T,
% which would break zeros(M,K) and friends below; trailing bins that do not
% fill a whole window are ignored instead.
K = floor(T/W);
if K < 1
    error('Recording has %d bins, fewer than one window of W = %d bins.', T, W);
end
if mod(T, W) ~= 0
    warning('T = %d is not a multiple of W = %d; the last %d bins are ignored.', ...
        T, W, mod(T, W));
end

%% Identify reliable interactions
% Every non-empty subset of the L neurons is encoded as an integer idx in
% 1..L_star whose bit i (least significant bit = bit 1) marks neuron i.
% For each subset, build the indicator over time bins of "all members spike
% and no non-member spikes", and keep the subset when that happens more than
% Nthr times. X_star collects the kept indicators (one column per pattern),
% incl_idx the corresponding subset codes.
incl_idx = [];
X_star = [];
all_ids = 1:L_star;
% skip(id) is set once some subset of id is known never to co-fire; then id
% cannot co-fire either and would never be included, so it is not evaluated.
% (A logical mask is used because MATLAB's for-loop iterates over a copy of
% its range: deleting entries from the range inside the loop skips nothing.)
skip = false(1, L_star);
for idx = all_ids
    if skip(idx)
        continue;
    end
    bi_idx = de2bi(idx,L);
    idx0 = find(bi_idx==0); idx1 = find(bi_idx==1);

    co_fire = prod(n(idx1,:),1);
    if sum(co_fire) == 0
        % Mark every superset of idx (ids containing all of idx's bits).
        skip(bitand(all_ids, idx) == idx) = true;
    else
        % Pattern occurs only where its members co-fire and no other neuron spikes.
        tmp = max(0, co_fire-double(sum(n(idx0,:),1)>0))';
        if sum(tmp)>Nthr
            X_star = [X_star tmp];
            incl_idx = [incl_idx idx];
        end
    end
end

if isempty(incl_idx)
    error('No spiking pattern occurs more than Nthr = %g times; nothing to model.', Nthr);
end

% Order (number of participating neurons) of each included pattern.
ord_idx = sum( de2bi(incl_idx, L) ,2);

n_star = X_star';
ng = sum(n_star, 1);

%% History-Independent Model
M = length(incl_idx);

% Gradient ascent on theta (each step adds alpha*grad). The pattern
% probabilities are lambda = exp(theta) / (1 + sum(exp(theta))); the gradient
% form suggests an exponentially weighted multinomial log-likelihood
% (NOTE(review): confirm against the accompanying derivation).
n_gd_iter = 2000; % fixed number of gradient steps per window
theta_star = zeros(M,K);
lambdastar_est = zeros(M,K);
lambdagstar_est = zeros(1,K);

% Xb_star(k,:): fraction of bins in window k in which each included pattern occurs.
Xb_star = zeros(K, M);
for k=1:K
    Xb_star(k,:) = sum( X_star((k-1)*W+1:k*W,:), 1)/W;
end

% X holds sum_{j<=k} beta^(k-j) * Xb_star(j,:)', and w_eff the matching
% weight sum_{j<=k} beta^(k-j). Updating w_eff recursively gives
% (1-beta^k)/(1-beta) for beta < 1 and stays finite (= k) for beta = 1,
% where the closed form divides by zero.
X = zeros(M, 1);
w_eff = 0;
for k=1:K
    X = beta*X + Xb_star(k,:)';
    w_eff = beta*w_eff + 1;
    for gditer=1:n_gd_iter
        e_theta = exp(theta_star(:,k));
        grad = W * ( X - w_eff * e_theta / (1 + sum(e_theta)) );
        theta_star(:,k) = theta_star(:,k) + alpha*grad;
    end

    e_theta = exp(theta_star(:,k));
    lambdastar_est(:,k) = e_theta / (1 + sum(e_theta));
    lambdagstar_est(k) = sum(lambdastar_est(:,k));
end

%% Statistical Inference
% Membership matrix: row m, column l is 1 when neuron l belongs to included
% pattern m (bit l of incl_idx(m)).
bi_incl_idx = de2bi(incl_idx, L);

%%% CIF of spiking events
% lambda_est(l,k): estimated probability that neuron l spikes in window k,
% i.e. the summed estimated probabilities of all included patterns that
% contain neuron l.
lambda_est = zeros(L,K);
for l=1:L
    members = find( bi_incl_idx(:,l) );
    % Sum along dim 1 explicitly: with a single member the selection is a 1xK
    % row, and a plain sum() would collapse it to one scalar for all windows.
    lambda_est(l,:) = sum( lambdastar_est(members, :), 1 );
end

%%% Odds of simultaneous spiking events (independent neurons)
% odds(m,k) = prod_l lambda_l / prod_l (1 - lambda_l) over the neurons l of
% pattern m, per window k. Element-wise ./ is required: '/' between two 1xK
% rows is a least-squares matrix division that returns a single scalar.
% NOTE(review): the original header said "log-odds" but the value computed
% is the odds itself; confirm which one SynchTest_dynamic expects.
odds = zeros(M, K);
for m=1:M
    k_tmp = find( bi_incl_idx(m,:) );
    odds(m,:) = prod( lambda_est(k_tmp,:), 1 ) ./ prod( 1 - lambda_est(k_tmp,:), 1 );
end

ord = 1:L;
Jstat_ord = cell(numel(ord),1);
h_ord = cell(numel(ord),1);
Ms = zeros(numel(ord),1);
Devs = cell(numel(ord), 1);
nus = cell(numel(ord), 1);

%%% Test for Significant rth-Order Coordinated Spiking
% For every order r >= 2 that has at least one included pattern, call
% SynchTest_dynamic on the order-r patterns and store, per window, the test
% decision h and a signed J-statistic (power minus size, i.e. Youden's J,
% zeroed where the test does not reject).
for r = ord
    ordRidx = sort(find(ord_idx == r), 'ascend');
    if isempty(ordRidx) || r==1
        % Nothing to test at this order: store zero placeholders.
        Jstat_ord{r} = zeros(K,1);
        h_ord{r} = zeros(K,1);
        Ms(r) = M - length(ordRidx);
        Devs{r} = zeros(K,1);
        nus{r} = zeros(K,1);
        continue;
    end

    % NOTE(review): Dev is presumably a deviance statistic with Md degrees of
    % freedom and nu its noncentrality per window; confirm in SynchTest_dynamic.
    [Dev, Md, nu, gamma] = SynchTest_dynamic(theta_star, X_star, W, odds, ordRidx, alpha, beta);

    % Reject where Dev exceeds the (1 - sigma) chi-square quantile.
    h = and(Md>0, (1 - sigma) < chi2cdf(Dev, Md));

    % J = (1 - ncx2cdf(crit, Md, nu)) - sigma = power - size. With Md == 0
    % chi2inv/ncx2cdf return NaN and h.*NaN stays NaN, so J is left at 0.
    Jstat = zeros(K,1);
    if Md > 0
        crit = chi2inv(1 - sigma, Md); % loop-invariant critical value
        for k=1:K
            Jstat(k) = 1 - sigma - ncx2cdf(crit, Md, nu(k));
        end
    end
    % Sign taken from -sign(sum(gamma,2)); the name suggests
    % excitatory/inhibitory coordination (NOTE(review): confirm).
    ex_in = -sign(sum(gamma,2));
    Jstat = ex_in.*(h.*Jstat);

    Jstat_ord{r} = Jstat;
    h_ord{r} = h;
    Ms(r) = Md;
    Devs{r} = Dev;
    nus{r} = nu;
end

%%
% Stack the per-order results into L x K / L x T matrices for plotting:
%   Jstat_im(r,:) - signed J-statistic of order r in each window
%   h_im(r,:)     - test decision of order r in each window
%   n_Rords(r,:)  - per time bin, how many included order-r patterns occurred
Jstat_im = zeros(L, K);
h_im = zeros(L, K);
n_Rords = zeros(L, size(n,2));
for r = 1:L
    Jstat_im(r,:) = Jstat_ord{r}';
    h_im(r,:) = h_ord{r}';

    is_order_r = (ord_idx == r);
    n_Rords(r,:) = sum( n_star(is_order_r,:), 1 );
end

% Figure layout, top to bottom: spike raster of n, raster of pattern orders
% 2..L (n_Rords), and the signed J-statistics of orders 2..L per window.
figure;

subplot(6,1,5:6)
imagesc(Jstat_im(2:end,:), [-1 1]); colormap redblue;
yticks('');
xticks([0:T/6:T]/W); xticklabels('');

% Each raster is drawn as ONE NaN-separated line (a vertical tick per
% nonzero entry) instead of one line object per tick, which is far faster
% to create and render for long recordings.
tmp = n_Rords;
hspacing = 5; %>=1, integer
vspacing = 1.5; %>=1
subplot(6,1,3:4)
hold on;
[jj, ii] = find(tmp(2:end,:));   % order 1 is not shown
jj = jj(:)' + 1;                 % back to row indices of tmp
ii = ii(:)';
if ~isempty(ii)
    x = hspacing*ii;
    yc = (size(tmp,1) - jj)*vspacing + 1;  % higher orders plotted lower
    xs = [x; x; nan(size(x))];
    ys = [yc-0.5; yc+0.5; nan(size(x))];
    line(xs(:), ys(:), 'Color', 'k');
end
hold off;
ylim([vspacing/2 vspacing*(size(tmp,1))-vspacing/2]); yticklabels(''); yticks('');
xlim([0, T*hspacing]); xticks([0:T/6:T]*hspacing); xticklabels('');

tmp = n(:, :);
hspacing = 5; %>=1, integer
vspacing = 1.5; %>=1
subplot(6,1,1:2)
hold on;
[jj, ii] = find(tmp);
jj = jj(:)';
ii = ii(:)';
if ~isempty(ii)
    x = hspacing*ii;
    yc = jj*vspacing + 1;                  % neuron j on row j
    xs = [x; x; nan(size(x))];
    ys = [yc-0.5; yc+0.5; nan(size(x))];
    line(xs(:), ys(:), 'Color', 'k');
end
hold off;
ylim([vspacing/2 vspacing*(size(tmp,1))+2+vspacing/2]); yticklabels(''); yticks('');
xlim([0, T*hspacing]); xticks([0:T/6:T]*hspacing); xticklabels('');
