% Start from a clean session inside the project folder.
clear;
clc;
cd('G:\EEG_AV_speech');

% EEG recordings and speech envelopes live in the same MAT-file.
load('G:\EEG_AV_speech\data_dsp.mat','data_dsp','env_dsp');

% Remaining stimulus features: each row is {MAT-file, variable read from it}.
feature_src = { ...
    'G:\EEG_AV_speech\stimuli feature\word_surprisal_new.mat', 'wsurp'; ...
    'G:\EEG_AV_speech\stimuli feature\word_onset_new.mat',     'wonset'; ...
    'G:\EEG_AV_speech\stimuli feature\noise_envelop_dsp.mat',  'noise_env_dsp'};
for n_src = 1:size(feature_src,1)
    load(feature_src{n_src,1},feature_src{n_src,2});
end
% Author's note on the inputs: data_dsp, env_dsp and noise_env_dsp were
% filtered to 1-10 Hz and downsampled to 25 Hz before being stored.

%%
% Keep only the loaded inputs, then define every constant used by the
% trimming, shifting and model-fitting code further down.
clearvars('-except','env_dsp','data_dsp','wsurp','wonset','noise_env_dsp');
addpath('G:\analysis_code');

% --- data dimensions ---
fs_d  = 25;   % sampling rate of the downsampled data (Hz)
nsubs = 21;   % number of subjects
nchan = 64;   % number of EEG channels
nFeat = 4;    % number of stimulus features in the model
nCon  = 3;    % number of conditions

% --- lag window ---
D_pre = 25;            % candidate numbers of time lags
nd    = numel(D_pre);  % not referenced again in this script
d     = 1;             % index of the candidate in use
D     = D_pre(d);      % time lags
pre   = 0.2;           % pre-onset part of the window; a previously used
                       % alternative was (D/fs_d)*0.2

% --- regularisation grid ---
% A sparser grid that was tried before: [1e-03 1e-01 1 1e+01 inf]
tikf  = [1e-04 1e-03 1e-02 1e-01 1 1e+01 1e+02 1e+03 1e+04 inf];
ntikf = numel(tikf);

% 1-10 Hz band-pass FIR kernel (not referenced again in this script).
b_pass = fir1(fs_d*6,[1 10]/(fs_d/2),'bandpass');

% Working copies of the word-level features.
wsurp_2  = wsurp;
wonset_2 = wonset;

% --- presentation delays ---
delay_A = 0.2;   % delay of audio caused by the matlab function sound()
delay_V = 0.058; % delay of video (not referenced again in this script)

% Stimulus regressors: trim each of the 16 trials, then stack the trials
% along time. The envelopes are z-scored trial by trial; the word-level
% features are stacked as they are.
first_keep = 1+ceil((1-delay_A)*fs_d);           % first retained sample
trim_trial = @(x) x(first_keep:(end-fs_d*1),1);  % also drops the last fs_d samples

A_all  = cell(1,3);   % speech envelope
W_all  = cell(1,3);   % word surprisal
WO_all = cell(1,3);   % word onset
AN_all = cell(1,3);   % noise envelope

for j = 1:16
    A_all{1,1}  = [A_all{1,1};  zscore(trim_trial(env_dsp{j,1}))];
    W_all{1,1}  = [W_all{1,1};  trim_trial(wsurp_2{j,1})];
    WO_all{1,1} = [WO_all{1,1}; trim_trial(wonset_2{j,1})];
    AN_all{1,1} = [AN_all{1,1}; zscore(trim_trial(noise_env_dsp{j,1}))];
end

% Conditions 2 and 3 reuse the regressors built from trials 1-16.
[A_all{1,2:3}]  = deal(A_all{1,1});
[W_all{1,2:3}]  = deal(W_all{1,1});
[WO_all{1,2:3}] = deal(WO_all{1,1});
[AN_all{1,2:3}] = deal(AN_all{1,1});


% EEG: per trial, drop the first fs_d samples and the last
% ceil((1-delay_A)*fs_d) samples, rotate the result by the pre-onset window
% along time, and stack the trials of each condition.
condi = {1:16, 17:32, 33:48};       % trial indices of the three conditions
data_shift_all = cell(1,3);

n_tail  = ceil((1-delay_A)*fs_d);   % samples removed at the end of a trial
n_shift = ceil(pre*fs_d);           % circular shift applied along time

for c = 1:3
    for j = condi{1,c}
        eeg_j = data_dsp{j,1};
        % All subjects are handled in one slice; circshift along dimension 1
        % rotates every channel/subject column by the same amount.
        eeg_trim = eeg_j((1+fs_d*1):(end-n_tail),:,1:nsubs);
        data_shift_sub = circshift(eeg_trim,n_shift,1);
        data_shift_all{1,c} = cat(1,data_shift_all{1,c},data_shift_sub);
    end
end

%% compute TRFs in three conditions: 
clc;
% One model per regularisation value, condition and subject.
% h_shift : lags x features x channels x subjects x conditions x lambdas
% CR_test : channels x subjects x conditions x lambdas
h_shift = zeros(D,nFeat,nchan,nsubs,nCon,ntikf);
CR_test = zeros(nchan,nsubs,nCon,ntikf);

for k = 1:ntikf
    tic;
    for c = 1:3
        % Feature matrix, one row per feature: envelope, word onset,
        % word surprisal, noise envelope. It is the same for every subject.
        % (An earlier variant fitted only the first three rows into
        % h_shift(:,[1:3],...); another used a V_all row in place of AN_all
        % for conditions 2 and 3.)
        stim_c = [A_all{1,c}';WO_all{1,c}';W_all{1,c}';AN_all{1,c}'];
        for i = 1:nsubs
            eeg_ci = data_shift_all{1,c}(:,:,i)';
            [h_shift(:,:,:,i,c,k),CR_test(:,i,c,k)] = normRCtik_Z(stim_c,eeg_ci,D,tikf(k));
            disp(sprintf('round %d subj %d condi %d',k,i,c));
        end
    end
    toc;
end

save('.\TRF_AVinput_21subj_3feat_impulse_new_onset_noise.mat','h_shift','CR_test','A_all','WO_all','W_all','AN_all');

%%

% Pick the regularisation value with the highest mean prediction accuracy
% and plot accuracy against lambda.
load('.\TRF_AVinput_21subj_3feat_impulse_new_onset_noise.mat','CR_test');
tikf = [1e-04 1e-03 1e-02 1e-01 1 1e+01 1e+02 1e+03 1e+04 inf];
ntikf = length(tikf);

sub_keep = [1:6 8:21];   % subject 7 is left out of the group statistics

% Preallocate so that a re-run can never inherit columns from an older,
% differently sized pred.
pred = zeros(size(CR_test,1),ntikf);
for k = 1:ntikf
    pred(:,k) = mean(CR_test(:,sub_keep,:,k),[2 3]); % average across subjects and conditions
end
P = mean(pred,1);        % then average across channels

% max returns a single index (the first one on ties), so id is always a
% scalar; find(P == max(P)) could return several indices and silently
% change the dimensionality of everything sliced with id afterwards.
[~,id] = max(P);
CR = CR_test(:,sub_keep,:,id);   % channels x subjects x conditions

figure;
set(gcf,'Position',[200 100 700 500],'color','w');
plot(P,'Marker','s','LineWidth',3);
hold on
scatter(id,P(id),72,'s','r','filled');   % highlight the selected lambda
xticks(linspace(1,10,10));
xticklabels({'10^{-4}','10^{-3}', '10^{-2}', '10^{-1}', '10^{0}','10^{1}','10^{2}','10^{3}','10^{4}', ' +inf'})
xlabel('λ');
yticks([0 0.01 0.02 0.03 0.04 0.05 0.06]);ylabel('Mean predictive accuracy');
ylim([0 0.06]);
xlim([0.5 10.5])
set(gca,'fontsize',18,'fontweight','bold');
set(gca,'linewidth',1.5);
box off

%%
% Scalp maps of the prediction accuracy, one panel per condition, each
% titled with the mean accuracy over channels and subjects.
%
% The maps are read from CR_sig, but nothing in this script creates that
% variable, so the section used to stop with an "undefined variable" error.
% When CR_sig is absent, fall back to CR (channels x subjects x conditions,
% built in the lambda-selection section), which has exactly the layout the
% indexing below expects. A CR_sig already present in the workspace is used
% unchanged.
if ~exist('CR_sig','var')
    CR_sig = CR;
end

figure;
set(gcf,'Position',[200 100 800 500],'color','w');
% topoplot(pred(:,id),'BioSemi64.loc','electrodes','on','maplimits',[-0.05 0.05]);
for c = 1:3
    subplot(1,3,c)
    topoplot(mean(CR_sig(:,:,c),2),'BioSemi64.loc','electrodes','on','maplimits',[-0.05 0.05]);
    colormap(jet);
    % Optional colorbar (disabled). Use a handle name other than c when
    % re-enabling it, otherwise the loop variable is overwritten:
    % cb = colorbar;
    % cb.Position = ([0.9 0.35 0.04 0.3]);
    % cb.Label.String = 'predictive accuracy';
    % cb.Limits = [0 0.05];
    % cb.Ticks = [-0.05 0 0.05];
    % cb.Box = 'off';
    % cb.AxisLocation = 'in';
    title(num2str(mean(CR_sig(:,:,c),'all')));
    box off
    set(gca,'fontsize',18,'fontweight','bold');
    set(gca,'linewidth',1.5);
end


 %%
% Baseline-correct the TRFs obtained with the selected lambda and store
% them per feature and condition (h) and averaged over conditions (h_all).
load('.\TRF_AVinput_21subj_3feat_impulse_new_onset_noise.mat','h_shift');

% id comes from the lambda-selection section. Taking its first element
% keeps h_shift_2 five-dimensional even if id holds several indices.
k = id(1);
h_shift_2 = h_shift(:,:,:,[1:6 8:21],:,k);   % lags x features x channels x subjects x conditions

% Subtract, for every feature/channel/subject/condition, the mean of the
% first five lags. One vectorised statement replaces the element-wise
% triple loop and rebuilds h_bc from scratch on every run, so values from
% an earlier run cannot survive in it.
h_bc = h_shift_2 - mean(h_shift_2(1:5,:,:,:,:),1);

n_feat_bc = size(h_bc,2);
n_con_bc  = size(h_bc,5);
h     = cell(n_feat_bc,n_con_bc);   % rows: features, columns: conditions
h_all = cell(n_feat_bc,1);          % per feature, mean over conditions
for f = 1:n_feat_bc
    for c = 1:n_con_bc
        h{f,c} = squeeze(h_bc(:,f,:,:,c));            % lags x channels x subjects
    end
    h_all{f,1} = squeeze(mean(h_bc(:,f,:,:,:),5));    % lags x channels x subjects
end

% id and CR are taken from the workspace as left by the lambda-selection section.
save('.\TRF_AVinput_21subj_3feat_impulse_new_onset_noise_bc.mat','h','id','h_all','CR');

%%
clear;clc;
% One figure per feature: subject-averaged TRF of every channel, with the
% channels listed in chan{j} drawn as thick solid lines and all other
% channels as thin dotted lines. Each figure is written to an EMF file.
load('.\TRF_AVinput_21subj_3feat_impulse_new_onset_noise_bc.mat','h_all');
attr  = {'Auditory','Onset','Surprisal','Noise'};   % used in the output file names
nFeat = 4;

% Channel groups; only achan is used below.
achan  = [12 32 47:49];
vchan  = [27 28 29 30 64]; %28 Iz
fcchan = [11 38 46 47 48];
tchan  = [23 25 60 62 24 61];%P7 PO7 P8 PO8 P9 P10

chan = {achan,achan,achan,achan};                                  % highlighted channels per feature
lim  = {[-0.15 0.15],[-0.15 0.15],[-0.15 0.15],[-0.15 0.15]};      % y-limits per feature
fsize = 18;

for j = 1:nFeat
    figure;
    set(gcf,'Position',[200 100 900 500],'color','white');
    H_m = mean(h_all{j,1},3);   % average over the third dimension -> lags x channels
    H = h_all{j,1};
    hold on
    for ch = 1:64
        if any(chan{j} == ch)
            plot(H_m(:,ch),'LineWidth',2);
        else
            plot(H_m(:,ch),':','LineWidth',0.5);
        end
    end
    ylabel('Amplitude (a.u.)');
    % Alternative axis for a longer lag window:
    %   xticks([0 5 10 15 20 25 30]);xticklabels({'-200','0','200','400','600','800','1000'});
    xticks([0 5 10 15 20 25]);
    xticklabels({'-200','0','200','400','600','800 (ms)'});
    % title(attr{j},'fontsize',fsize,'fontweight','bold');
    set(gca,'fontsize',fsize,'fontweight','bold');
    set(gca,'linewidth',1.5);
    ylim(lim{j})
    yticks([-0.15 0 0.15]);
    box off
    saveas(gcf,['.\figure_TRF_3feat_noise\plot_allchan' attr{j}  '.emf']);
    close;
end
%%
clear;clc;
% Same per-feature all-channel plot as the previous section, with the
% surrogate TRF (averaged over its second and third dimensions) drawn in
% black on the same axes. Each figure is written to an EMF file.
load('.\TRF_AVinput_21subj_3feat_impulse_new_onset_noise_bc.mat','h_all');
load('.\surrogate_TRF_AVinput_3feat_impulse_new_onset_noise_all.mat','h_rand');
attr  = {'Auditory','Onset','Surprisal','Noise'};   % used in the output file names
nFeat = 4;

% Channel groups; only achan is used below.
achan  = [12 32 47:49];
vchan  = [27 28 29 30 64]; %28 Iz
fcchan = [11 38 46 47 48];
tchan  = [23 25 60 62 24 61];%P7 PO7 P8 PO8 P9 P10

chan = {achan,achan,achan,achan};                                  % highlighted channels per feature
lim  = {[-0.15 0.15],[-0.15 0.15],[-0.15 0.15],[-0.15 0.15]};      % y-limits per feature
fsize = 24;

for j = 1:nFeat
    figure;
    set(gcf,'Position',[200 100 900 500],'color','white');
    H = h_all{j,1};
    H_m = mean(H,3);                 % lags x channels
    H_rand = h_rand{j,1};
    H_rand_m = mean(H_rand,[2 3]);   % one surrogate trace over lags
    for ch = 1:64
        % NOTE(review): the surrogate trace is redrawn before every channel
        % line, i.e. 64 times per figure. Confirm whether a single draw was
        % intended; the order is kept as is because it determines layering
        % and the automatic colours of the channel lines.
        plot(H_rand_m,'LineWidth',2,'Color','k')
        hold on
        if any(chan{j} == ch)
            plot(H_m(:,ch),'LineWidth',2);
        else
            plot(H_m(:,ch),':','LineWidth',0.5);
        end
    end
    ylabel('Amplitude (a.u.)');
    xticks([0 5 10 15 20 25]);
    xticklabels({'-200','0','200','400','600','800 (ms)'});

    set(gca,'fontsize',fsize,'fontweight','bold');
    set(gca,'linewidth',1.5);
    ylim(lim{j})
    yticks([-0.15 0 0.15]);
    box off
    saveas(gcf,['.\figure_TRF_3feat_noise\plot_allchan' attr{j}  '_with_surrogate_2.emf']);
    close;
end

%% find significant clusters
clear;clc;
cd('G:\EEG_AV_speech');
addpath('G:\analysis_code');
load('.\TRF_AVinput_21subj_3feat_impulse_new_onset_noise_bc.mat','h_all');
load('.\surrogate_TRF_AVinput_3feat_impulse_new_onset_noise_all.mat','h_rand');

% The cluster-based permutation test uses a re-ordered channel table in
% which adjacent electrodes are arranged together.
clearvars('-except','h_all','h_rand');
load('BioSemi64_new_2.mat');
D = 25;

clusters = cell(1,4);
p_values = cell(1,4);
t_sums   = cell(1,4);
t_value  = cell(1,4);

for j = 1:4
    chan_order = cell2mat(loc_new_2(:,1));   % channel index for each position of the new layout

    % Observed TRFs in the new channel order, baseline lags (-200~0) zeroed.
    X = h_all{j}(:,chan_order,:);
    temp1 = X;
    temp1(1:5,:,:) = 0;

    % Surrogate TRFs, treated the same way.
    Y = h_rand{j}(:,chan_order,:);
    temp2 = Y;
    temp2(1:5,:,:) = 0;

    [clusters{j}, p_values{j}, t_sums{j}, t_value{j}] = permutest_ywy(temp1,temp2,1,0.025,1024,1);
end
save('.\permutest_TRF_AVinput_21subj_3feat_impulse_new_onset_noise.mat','clusters','p_values','t_sums','t_value');

%% figure: spatiotemporal clusters and their topographies 
% For every feature with at least one cluster at p < 0.05: draw one scalp
% map per cluster (t-values averaged over the lags the cluster covers) and
% one lag-by-channel image holding the t-values of all those clusters.
% Only the image is saved (EMF) and closed; the scalp-map figures stay open.
clear;
cd('G:\EEG_AV_speech');
addpath G:\analysis_code
% t_sums is loaded but not used in this section.
load('.\permutest_TRF_AVinput_21subj_3feat_impulse_new_onset_noise.mat','clusters','p_values','t_sums','t_value');
attr = {'A','WO','W','AN'};
D = 25;      
for j = 1:length(attr)
    clear nc
    % Number of clusters of feature j with p < 0.05.
    nc = length(find(p_values{j}<0.05));
    if nc ~= 0
    % Lag x channel map collecting the t-values of all clusters handled below.
    sig_t_value{j} = zeros(D,64);
    % NOTE(review): the loop takes the FIRST nc entries of clusters{j}; this
    % is only right if permutest_ywy lists the significant clusters first -
    % confirm against that function.
    for c = 1:nc
        % NOTE(review): this clears idx, while the variable assigned on the
        % next line is id.
        clear idx cx cy cluster_sig_t
        % Cluster members are linear indices into a D x 64 grid; convert
        % them to (lag, channel) subscripts.
        id = clusters{j}{c};
        sz = [D 64];
        [cx cy] = ind2sub(sz,id);
        len = length(id);
        % Copy the cluster's t-values into the map shared by all clusters ...
        for n = 1:len
            sig_t_value{j}(cx(n),cy(n)) = t_value{j}(cx(n),cy(n)); % new location
        end
        % ... and into a map that holds this cluster only.
        cluster_sig_t = zeros(D,64);
        for n = 1:len
            cluster_sig_t(cx(n),cy(n)) = t_value{j}(cx(n),cy(n)); % new location
        end
    
        % Lag indices that occur in the cluster: tabulate rows with a
        % non-zero count.
        clear aa idx tlim
        aa = tabulate(cx);
        idx = find(aa(:,2)~=0);
        tlim = aa(idx,1);
        % Channel indices that occur in the cluster, found the same way.
        % NOTE(review): this clears idb, while the variable assigned below is idy.
        clear bb idb chlim
        bb = tabulate(cy);
        idy = find(bb(:,2)~=0);
        chlim = bb(idy,1);
        % project to original chanloc
        % NOTE(review): the clear below discards the chlim computed just
        % above; chlim is then rebuilt two statements later from the two
        % channel tables.
        clear X cc chlim
        X = zeros(D,64);
        % Both tables are re-loaded for every cluster; loc and loc_new_2 are
        % presumably provided by these two files - confirm.
        load('BioSemi64.mat');
        load('BioSemi64_new_2.mat');
        X(:,cell2mat(loc_new_2(:,1))) = cluster_sig_t(:,cell2mat(loc(:,1))); 
        cc = [1:64];chlim(cell2mat(loc_new_2(:,1)),1) = cc(1,cell2mat(loc(:,1))); %find original chanloc
        
        clear XX
        % Zero, within the cluster's lags, the channels of 1:64 that are
        % missing from chlim.
        % NOTE(review): if both index columns are permutations of 1:64,
        % chlim contains every channel, setdiff returns an empty set and
        % this zeroing does nothing - confirm the intent.
        XX = X; chan = setdiff(cc,chlim);XX(tlim,chan)=0;
        figure;
        set(gcf,'Position',[200 100 500 500],'color','white');
        lim = [-10 10];
        % 26 values from -200 to 800 in steps of 40, used to label lag
        % indices (presumably milliseconds, as in the axis labels below).
        aaa = [ceil(roundn(linspace(-200,800,26),-02))];        
        % Scalp map of the t-values averaged over the cluster's lags.
        topoplot(mean(XX(tlim,:),1),'BioSemi64.loc','electrodes','on','maplimits',lim);
        % "first-last" label of the cluster's lag range; only the
        % commented-out saveas below uses it.
        titles = [num2str(ceil(aaa(tlim(1)))) '-' num2str(ceil(aaa(tlim(end))))];
        colormap(redblue);
%         saveas(gcf,['.\figure_TRF_3feat_noise\TRF_cluster_' attr{j} '_topo_' titles '.emf']);
%         close;
        cluster_tlim{j,c} = tlim;
    end
    
    % Channel x lag image of all cluster t-values, colour range [-10 10].
    figure;
    set(gcf,'Position',[200 100 900 600],'color','white');
    imagesc(sig_t_value{j}',[-10 10]);colormap(redblue)
    % Channel labels are taken from column 4 of loc_new_2 at the listed rows.
    ylabel('Channel');yticks([2 6 13 22 31 40 50 58 62]);yticklabels(loc_new_2([2 6 13 22 31 40 50 58 62],4));
    xticks([1 5 10 15 20 25]);xticklabels({'-200','0','200','400','600','800 (ms)'});
    % The x-range starts at lag index 5, hiding lag indices 1-4.
    xlim([5 25]);ylim([0 65]);

    set(gca,'fontsize',24,'fontweight','bold');
    set(gca,'linewidth',1.5);
    box off
    saveas(gcf,['.\figure_TRF_3feat_noise\TRF_cluster_' attr{j} '_2.emf']);
    close
    end
end
% save('.\cluster_threeinput_tlim_withsurrogat_concate.mat','cluster_tlim');
% save('.\sig_cluster_AVinput_21subj_3feat_impulse_new_onset_noise.mat','sig_t_value');

%%
clear;
% One figure per feature comparing the three conditions: the TRF averaged
% over the channels in chan{f} and over the third dimension of h{f,c}.
load('.\TRF_AVinput_21subj_3feat_impulse_new_onset_noise_bc.mat','h'); % h: row:feat; col: condi
attr = {'A','WO','W','AN'};

% Channel groups; only achan is used below.
achan  = [12 32 47:49];
vchan  = [27 28 29 30 64]; %28 Iz
fcchan = [11 38 46 47 48];
pcchan = [48 19 32 56 31]; % CZ CP1 CPZ CP2 PZ 
%%
chan = {achan,achan,achan,achan};                                  % channels averaged per feature
lim  = {[-0.15 0.15],[-0.15 0.15],[-0.15 0.15],[-0.15 0.15]};      % y-limits per feature
color = {'k','b','r'};                                             % line colour per condition
fsize = 16;

for f = 1:4
    figure;
    set(gcf,'Position',[200 100 500 200],'color','white');
    for c = 1:3
        trf_fc = mean(h{f,c}(:,chan{f},:),[2 3]);   % one trace over lags
        plot(trf_fc,'LineWidth',1.5,'Color',color{c});
        hold on
    end
    ylabel('Amplitude (a.u.)');
    yticks([-0.15 0 0.15]);
    % Shorter alternative labels: xticklabels({'-200','0','200','400','600','(ms)'});
    xticks([0 5 10 15 20 25 30]);
    xticklabels({'-200','0','200','400','600','800','1000'});
    % title(attr{j},'fontsize',fsize,'fontweight','bold');
    set(gca,'fontsize',fsize,'fontweight','bold');
    set(gca,'linewidth',1.5);
    ylim(lim{f})
    box off
    % saveas(gcf,['.\figure_TRF\TRF_waveform_' attr{f} '.emf']);
    % close
end
