%% Create train/validation/test data for repeated 10-fold cross-validation.
% For every repetition and every fold this script writes one .mat file
% containing imaging (I_*), genetics (G_*) and label (Y_*) matrices for
%   - the "joint" subjects (suffix _j), who have N-back, SDMT and SNP data,
%   - the subjects who only have N-back + SNP data,
%   - the subjects who only have SDMT + SNP data.
% The helpers cv_partition, generate_training_* and generate_testing_* are
% defined outside this file.
clear;clc;

sdmt_info  = readtable('../../joint_analysis/sdmt_list_final.csv'); % load SDMT data demographics

nback_info = readtable('../../joint_analysis/nback_list_final.csv'); % load Nback data demographics

% data_snp is taken from this file. data_gene and class are requested here
% too, but they are requested again from the pathway file below.
load('../../data/data_final_pgc.mat','data_gene','data_snp','class'); 
nn = 5; % number of graph layers; selects which connection/pathway files are loaded and is part of the output file name

load(sprintf('../../data/%d_layer_connection.mat',nn),'mask_con'); % mask_con is the connectivity adjacency matrix.

load(sprintf('../../data/gene_pathway_layer_connection_and_input_data_%d.mat',nn),'data_gene','class','mask_gene_encode','pool_n'); % load genetics data and its class labels. mask_gene_encode is the gene -> pathway information. pool_n is an array which contains number of nodes in each layer of ontology.


fold    = 10; % number of folds.
n_runs  = 10; % number of times the whole cross-validation is repeated.
n_joint = 62; % number of subjects with all 3 modalities; they are stacked at the beginning of both subject lists.
out_dir = '../../data/multiple_runs';

load('../../joint_analysis/F_scz_nc_nback.mat','F'); % load nback imaging data.
F_nback = F;

load('../../joint_analysis/F_scz_nc_sdmt_aversive.mat','F'); % load sdmt imaging data.
F_sdmt = F;

% Only create the output directory when it is missing, so that re-running
% the script does not emit a "directory already exists" warning.
if ~exist(out_dir,'dir')
    mkdir(out_dir);
end

% data_snp contains all the information about the genetics data.
% nback_info and sdmt_info contain all the information about the imaging
% data.

% Map every imaging subject to its row in the genetics data: familyno*100,
% rounded, is looked up in data_snp.IID.
[found_nback, id_nback] = ismember(round(nback_info.familyno*100),data_snp.IID);
[found_sdmt,  id_sdmt]  = ismember(round(sdmt_info.familyno*100),data_snp.IID);

% ismember returns index 0 for subjects that are not found; indexing with 0
% would fail below with an unhelpful message, so stop here instead.
if ~all(found_nback)
    error('cv_data:unmatchedSubject', ...
        '%d N-back subject(s) have no matching IID in data_snp.', nnz(~found_nback));
end
if ~all(found_sdmt)
    error('cv_data:unmatchedSubject', ...
        '%d SDMT subject(s) have no matching IID in data_snp.', nnz(~found_sdmt));
end

% The joint partition computed from the N-back list is reused for SDMT, which
% is only meaningful when both lists start with the same subjects in the
% same order.
if numel(id_nback) >= n_joint && numel(id_sdmt) >= n_joint && ...
        ~isequal(id_nback(1:n_joint), id_sdmt(1:n_joint))
    warning('cv_data:jointOrderMismatch', ...
        'The first %d subjects of the N-back and SDMT lists map to different genetics rows.', n_joint);
end


data_gene_nback = data_gene(id_nback,:); % this is the genetics data for those subjects who have nback as well.
class_nback     = class(id_nback);

data_gene_sdmt  = data_gene(id_sdmt,:);  % this is the genetics data for those subjects who have sdmt as well.
class_sdmt      = class(id_sdmt);

for m_n=1:n_runs
    % Column 1 of each partition is used as the training indices and
    % column 2 as the test indices of a fold; column 3 (validation) is
    % filled in below. Indices of the non-joint partitions are relative to
    % the first non-joint subject, hence the n_joint offsets further down.
    cpart_nback_joint = cv_partition(class_nback(1:n_joint)); % Train-Test for subjects who have all 3 modalities

    cpart_nback = cv_partition(class_nback(n_joint+1:end)); % Train-Test for subjects who have only Nback + SNP 
    cpart_sdmt  = cv_partition(class_sdmt(n_joint+1:end)); % Train-Test for subjects who have only  SDMT + SNP 

    % Carve a validation set out of the training indices of every fold.
    for i=1:fold
    
        % Nback
        [cpart_nback_joint{i,1}, cpart_nback_joint{i,3}] = split(cpart_nback_joint{i,1},class_nback(cpart_nback_joint{i,1})' ); % Train-Val for subjects who have all 3 data modalities 
        [cpart_nback{i,1}, cpart_nback{i,3}] = split(cpart_nback{i,1},class_nback(n_joint+cpart_nback{i,1})' );% Train-Val for subjects who have only  Nback + SNP 
    
        % SDMT
        [cpart_sdmt{i,1}, cpart_sdmt{i,3}] = split(cpart_sdmt{i,1},class_sdmt(n_joint+cpart_sdmt{i,1})' );% Train-Val for subjects who have only  SDMT + SNP 
    
    end

    % The joint subjects use the same train/val/test split for both tasks.
    cpart_sdmt_joint = cpart_nback_joint;
    
    for i=1:fold
    
        % ---------------- Train ----------------
    
        % Joint + N-back-only subjects
        train_id_nback                      = [cpart_nback_joint{i,1},n_joint+cpart_nback{i,1}];
        
        % Generate contrast maps from the training subjects only; b1/b2 are
        % passed to generate_testing_nback for the test and validation data.
        [b1_nback, b2_nback, F_train_nback] = generate_training_nback(F_nback, train_id_nback);
    
        % The joint training subjects come first in train_id_nback, so the
        % leading columns belong to them. Transpose to subjects-by-features.
        I_train_Nback_j                     = F_train_nback(:,1:length(cpart_nback_joint{i,1}))';
        I_train_Nback                       = F_train_nback(:,length(cpart_nback_joint{i,1})+1:end)';
    
        G_train_j                           = data_gene_nback(cpart_nback_joint{i,1},:);
        G_train_Nback                       = data_gene_nback(n_joint+cpart_nback{i,1},:);
    
        Y_train_j                           = class_nback(cpart_nback_joint{i,1})';
        Y_train_Nback                       = class_nback(cpart_nback{i,1}+n_joint)';
    
    
        % Joint + SDMT-only subjects
        train_id_sdmt                    = [cpart_sdmt_joint{i,1},n_joint+cpart_sdmt{i,1}];
        [b1_sdmt, b2_sdmt, F_train_sdmt] = generate_training_sdmt(F_sdmt, train_id_sdmt);
    
        I_train_SDMT_j                   = F_train_sdmt(:,1:length(cpart_sdmt_joint{i,1}))';
        I_train_SDMT                     = F_train_sdmt(:,length(cpart_sdmt_joint{i,1})+1:end)';
    
        G_train_SDMT                     = data_gene_sdmt(n_joint+cpart_sdmt{i,1},:);
        Y_train_SDMT                     = class_sdmt(cpart_sdmt{i,1}+n_joint)';
    
    
        % ---------------- Testing ----------------
    
        % Joint + N-back-only subjects
        I_test_Nback_j  = generate_testing_nback(F_nback, cpart_nback_joint{i,2}, b1_nback, b2_nback);
        I_test_Nback_j  = I_test_Nback_j';
    
        I_test_Nback    = generate_testing_nback(F_nback, n_joint+cpart_nback{i,2},    b1_nback, b2_nback);
        I_test_Nback    = I_test_Nback';
    
        G_test_j       = data_gene_nback(cpart_nback_joint{i,2},:);
        G_test_Nback   = data_gene_nback(n_joint+cpart_nback{i,2},:);
    
        Y_test_j        = class_nback(cpart_nback_joint{i,2})';
        Y_test_Nback    = class_nback(cpart_nback{i,2}+n_joint)';
    
    
        % Joint + SDMT-only subjects
        I_test_SDMT_j  = generate_testing_sdmt(F_sdmt, cpart_sdmt_joint{i,2}, b1_sdmt, b2_sdmt);
        I_test_SDMT_j  = I_test_SDMT_j';
    
        I_test_SDMT    = generate_testing_sdmt(F_sdmt, n_joint+cpart_sdmt{i,2},    b1_sdmt, b2_sdmt);
        I_test_SDMT    = I_test_SDMT';
    
        G_test_SDMT     = data_gene_sdmt(n_joint+cpart_sdmt{i,2},:);
        Y_test_SDMT     = class_sdmt(cpart_sdmt{i,2}+n_joint)';
    
    
        % ---------------- Validation ----------------
    
        % Joint + N-back-only subjects
        I_val_Nback_j  = generate_testing_nback(F_nback, cpart_nback_joint{i,3}, b1_nback, b2_nback);
        I_val_Nback_j  = I_val_Nback_j';
    
        I_val_Nback    = generate_testing_nback(F_nback, n_joint+cpart_nback{i,3},    b1_nback, b2_nback);
        I_val_Nback    = I_val_Nback';
    
        G_val_j       = data_gene_nback(cpart_nback_joint{i,3},:);
        G_val_Nback   = data_gene_nback(n_joint+cpart_nback{i,3},:);
    
        Y_val_j        = class_nback(cpart_nback_joint{i,3})';
        Y_val_Nback    = class_nback(cpart_nback{i,3}+n_joint)';
    
    
        % Joint + SDMT-only subjects
        I_val_SDMT_j  = generate_testing_sdmt(F_sdmt, cpart_sdmt_joint{i,3}, b1_sdmt, b2_sdmt);
        I_val_SDMT_j  = I_val_SDMT_j';
    
        I_val_SDMT    = generate_testing_sdmt(F_sdmt, n_joint+cpart_sdmt{i,3},    b1_sdmt, b2_sdmt);
        I_val_SDMT    = I_val_SDMT';
    
        G_val_SDMT     = data_gene_sdmt(n_joint+cpart_sdmt{i,3},:);
        Y_val_SDMT     = class_sdmt(cpart_sdmt{i,3}+n_joint)';
    

        % One file per (fold, repetition); the name encodes nn, the fold
        % index i and the repetition index m_n, in that order.
        save(sprintf('%s/train_test_val_imaging_data_folds_with_%d_graph_layers%d_3_%d.mat',out_dir,nn, ...
            i, m_n), 'I_train_Nback_j', 'I_train_Nback', 'G_train_j', 'G_train_Nback', 'Y_train_j', 'Y_train_Nback', ...
            'I_train_SDMT_j', 'I_train_SDMT', 'G_train_SDMT', 'Y_train_SDMT', 'I_test_Nback_j', 'I_test_Nback', ...
            'G_test_j', 'G_test_Nback', 'Y_test_j', 'Y_test_Nback', 'I_test_SDMT_j', 'I_test_SDMT', 'G_test_SDMT', 'Y_test_SDMT',...
            'I_val_Nback_j', 'I_val_Nback', ...
            'G_val_j', 'G_val_Nback', 'Y_val_j', 'Y_val_Nback', 'I_val_SDMT_j', 'I_val_SDMT', 'G_val_SDMT', 'Y_val_SDMT')

    end
end

function [tr_id, cv_id] = split(id, class, val_pct)
% SPLIT  Stratified random split of a set of ids into training and validation.
%
%   [tr_id, cv_id] = split(id, class) moves floor(10%) of the ids whose
%   label is 1 and floor(10%) of the ids whose label is 0 into cv_id
%   (in random order) and returns the remaining ids in tr_id.
%
%   [tr_id, cv_id] = split(id, class, val_pct) uses val_pct percent per
%   class instead of 10.
%
%   Inputs
%     id      - vector of subject ids/indices to be split.
%     class   - vector of labels (1 or 0), one per element of id; row and
%               column vectors are both accepted.
%     val_pct - optional percentage of each class placed in cv_id
%               (default 10).
%
%   Outputs
%     tr_id - elements of id that were not selected for validation.
%     cv_id - elements of id selected for validation.
%
%   Note: inside this file the function takes precedence over MATLAB's
%   built-in string function of the same name.

if nargin < 3
    val_pct = 10;
end

if numel(id) ~= numel(class)
    error('split:sizeMismatch', ...
        'id has %d elements but class has %d.', numel(id), numel(class));
end

% Work on a column copy so that the result does not depend on whether the
% caller passed a row or a column of labels.
labels = class(:);

dis = find(labels == 1);
dis = dis(randperm(length(dis)));

con = find(labels == 0);
con = con(randperm(length(con)));

% Draw the validation positions separately per class (stratified).
cv_pos_dis = dis(randperm(length(dis), floor((length(dis)*val_pct)/100)));
cv_pos_con = con(randperm(length(con), floor((length(con)*val_pct)/100)));

% (:)' forces both parts to be rows, so the concatenation is valid for any
% combination of sizes, including empty selections.
cv_pos = [cv_pos_con(:)', cv_pos_dis(:)'];
cv_pos = cv_pos(randperm(length(cv_pos)));

cv_id = id(cv_pos);
tr_id = id(~ismember(id, cv_id));
end




